Rate limits: resolve per-route and per-process limits at request time.

Summary of the patch below (restored to one diff line per text line):
* clients.py / endpoints.py: check_rate_limits() now receives the whole
  rate-limits config plus the route and method names and, for the
  /processes/{process_id}/... routes, the process id.
* config.py: adds RateLimitsRouteParamConfig (a "default" entry plus extra
  per-parameter entries), renames the route fields and removes the
  populate_fields_with_default validator.
* limits.py: adds get_rate_limits() and get_rate_limits_defaulted(); the
  latter falls back from the route parameter to the route's "default" entry
  and then to the top-level "default" route.
* tests: updated and extended for the new lookup functions.

Review fix: the get_job hunk passed the route key "jobs_jobid", but the
config field is named "jobs_jobsid". get_rate_limits() looks routes up by
field name in model_dump(), so the limits configured for "/jobs/{job_id}"
were never selected for GET requests and the lookup silently fell through to
the defaults. The key is corrected below; hunk line counts are unchanged.

NOTE(review): the leading indentation of hunk lines was reconstructed from a
whitespace-collapsed copy of this patch (clients.py hunks are assumed to be
class methods); confirm against the target files before applying. The
"index" lines are kept exactly as recorded.

diff --git a/cads_processing_api_service/clients.py b/cads_processing_api_service/clients.py
index 6d93a80..1681eb4 100644
--- a/cads_processing_api_service/clients.py
+++ b/cads_processing_api_service/clients.py
@@ -230,8 +230,11 @@ def post_process_execution(
         """
         structlog.contextvars.bind_contextvars(user_uid=auth_info.user_uid)
         _ = limits.check_rate_limits(
-            SETTINGS.rate_limits.process_execution.post,
+            SETTINGS.rate_limits,
+            "processes_processid_execution",
+            "post",
             auth_info,
+            process_id,
         )
         request_body = execution_content.model_dump()
         catalogue_sessionmaker = db_utils.get_catalogue_sessionmaker(
@@ -396,7 +399,9 @@ def get_jobs(
         """
         structlog.contextvars.bind_contextvars(user_uid=auth_info.user_uid)
         _ = limits.check_rate_limits(
-            SETTINGS.rate_limits.jobs.get,
+            SETTINGS.rate_limits,
+            "jobs",
+            "get",
             auth_info,
         )
         job_filters = {
@@ -526,7 +531,9 @@ def get_job(
         """
         structlog.contextvars.bind_contextvars(user_uid=auth_info.user_uid)
         _ = limits.check_rate_limits(
-            SETTINGS.rate_limits.job.get,
+            SETTINGS.rate_limits,
+            "jobs_jobsid",
+            "get",
             auth_info,
         )
         compute_connection_mode = (
@@ -646,7 +653,9 @@ def get_job_results(
         """
         structlog.contextvars.bind_contextvars(user_uid=auth_info.user_uid)
         _ = limits.check_rate_limits(
-            SETTINGS.rate_limits.job_results.get,
+            SETTINGS.rate_limits,
+            "jobs_jobsid_results",
+            "get",
             auth_info,
         )
         compute_connection_mode = (
@@ -711,7 +720,9 @@ def delete_job(
         """
         structlog.contextvars.bind_contextvars(user_uid=auth_info.user_uid)
         _ = limits.check_rate_limits(
-            SETTINGS.rate_limits.job.delete,
+            SETTINGS.rate_limits,
+            "jobs_jobsid",
+            "delete",
             auth_info,
         )
         compute_sessionmaker = db_utils.get_compute_sessionmaker(
diff --git a/cads_processing_api_service/config.py b/cads_processing_api_service/config.py
index a47017f..233b423 100644
--- a/cads_processing_api_service/config.py
+++ b/cads_processing_api_service/config.py
@@ -108,45 +108,47 @@ class RateLimitsRouteConfig(pydantic.BaseModel):
     delete: RateLimitsMethodConfig = pydantic.Field(default=RateLimitsMethodConfig())
 
 
+class RateLimitsRouteParamConfig(pydantic.BaseModel):
+    __pydantic_extra__: dict[str, RateLimitsRouteConfig] = pydantic.Field(init=False)
+
+    default: RateLimitsRouteConfig = pydantic.Field(default=RateLimitsRouteConfig())
+
+    model_config = pydantic.ConfigDict(extra="allow")
+
+
 class RateLimitsConfig(pydantic.BaseModel):
     default: RateLimitsRouteConfig = pydantic.Field(
         default=RateLimitsRouteConfig(), validate_default=True
     )
-    process_execution: RateLimitsRouteConfig = pydantic.Field(
-        default=RateLimitsRouteConfig(),
+    processes_processid_execution: RateLimitsRouteParamConfig = pydantic.Field(
         alias="/processes/{process_id}/execution",
+        default=RateLimitsRouteParamConfig(),
+        validate_default=True,
+    )
+    processes_processid_constraints: RateLimitsRouteParamConfig = pydantic.Field(
+        alias="/processes/{process_id}/constraints",
+        default=RateLimitsRouteParamConfig(),
+        validate_default=True,
+    )
+    processes_processid_costing: RateLimitsRouteParamConfig = pydantic.Field(
+        alias="/processes/{process_id}/costing",
+        default=RateLimitsRouteParamConfig(),
         validate_default=True,
     )
     jobs: RateLimitsRouteConfig = pydantic.Field(
         default=RateLimitsRouteConfig(), alias="/jobs", validate_default=True
     )
-    job: RateLimitsRouteConfig = pydantic.Field(
+    jobs_jobsid: RateLimitsRouteConfig = pydantic.Field(
         default=RateLimitsRouteConfig(), alias="/jobs/{job_id}", validate_default=True
     )
-    job_results: RateLimitsRouteConfig = pydantic.Field(
+    jobs_jobsid_results: RateLimitsRouteConfig = pydantic.Field(
         default=RateLimitsRouteConfig(),
         alias="/jobs/{job_id}/results",
         validate_default=True,
     )
-
-    @pydantic.model_validator(mode="after")  # type: ignore
-    def populate_fields_with_default(self) -> pydantic.BaseModel:
-        default = self.default
-        if default is RateLimitsRouteConfig():
-            return self
-        routes = self.model_fields
-        for route in routes:
-            if route == "default":
-                continue
-            route_config: RateLimitsRouteConfig = getattr(self, route)
-            for method in route_config.model_fields:
-                method_config: RateLimitsMethodConfig = getattr(route_config, method)
-                for origin in method_config.model_fields:
-                    set_value = getattr(getattr(getattr(self, route), method), origin)
-                    if not set_value:
-                        default_value = getattr(getattr(default, method), origin)
-                        setattr(getattr(route_config, method), origin, default_value)
-        return self
+    jobs_delete: RateLimitsRouteConfig = pydantic.Field(
+        default=RateLimitsRouteConfig(), alias="/jobs/delete", validate_default=True
+    )
 
 
 def load_rate_limits(rate_limits_file: str | None) -> RateLimitsConfig:
diff --git a/cads_processing_api_service/endpoints.py b/cads_processing_api_service/endpoints.py
index cc6d571..155aee9 100644
--- a/cads_processing_api_service/endpoints.py
+++ b/cads_processing_api_service/endpoints.py
@@ -26,8 +26,6 @@
 
 SETTINGS = config.settings
 
-logger: structlog.stdlib.BoundLogger = structlog.get_logger(__name__)
-
 
 @exceptions.exception_logger
 def apply_constraints(
@@ -39,6 +37,13 @@ def apply_constraints(
         )
     ),
 ) -> dict[str, Any]:
+    limits.check_rate_limits(
+        SETTINGS.rate_limits,
+        "processes_processid_constraints",
+        "post",
+        auth_info,
+        process_id,
+    )
     request = execution_content.model_dump()
     table = cads_catalogue.database.Resource
     catalogue_sessionmaker = db_utils.get_catalogue_sessionmaker(
@@ -93,6 +98,13 @@ def estimate_cost(
     models.RequestCost
         Info on the cost with the highest cost/limit ratio.
     """
+    limits.check_rate_limits(
+        SETTINGS.rate_limits,
+        "processes_processid_costing",
+        "post",
+        auth_info,
+        process_id,
+    )
     request = execution_content.model_dump()
     table = cads_catalogue.database.Resource
     catalogue_sessionmaker = db_utils.get_catalogue_sessionmaker(
@@ -172,7 +184,9 @@ def delete_jobs(
     """
     structlog.contextvars.bind_contextvars(user_uid=auth_info.user_uid)
     limits.check_rate_limits(
-        SETTINGS.rate_limits.jobs.delete,
+        SETTINGS.rate_limits,
+        "jobs_delete",
+        "post",
         auth_info,
     )
     job_ids = request.job_ids
diff --git a/cads_processing_api_service/limits.py b/cads_processing_api_service/limits.py
index 13e4ac6..20e7b46 100644
--- a/cads_processing_api_service/limits.py
+++ b/cads_processing_api_service/limits.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 
+from typing import Any
+
 import limits
 import structlog
 
@@ -27,6 +29,47 @@
 limiter = config.RATE_LIMITS_LIMITER
 
 
+def get_rate_limits(
+    rate_limits_config: config.RateLimitsConfig,
+    route: str,
+    method: str,
+    request_origin: str,
+    route_param: str | None = None,
+) -> list[str]:
+    """Get the rate limits for a specific route and method."""
+    rate_limits = rate_limits_config.model_dump()
+    route_rate_limits: dict[str, Any] = rate_limits.get(route, {})
+    if route_param is not None:
+        route_param_rate_limits: dict[str, Any] = route_rate_limits.get(route_param, {})
+    else:
+        route_param_rate_limits = route_rate_limits
+    method_rate_limits: dict[str, Any] = route_param_rate_limits.get(method, {})
+    rate_limit_ids: list[str] = method_rate_limits.get(request_origin, [])
+    return rate_limit_ids
+
+
+def get_rate_limits_defaulted(
+    rate_limits_config: config.RateLimitsConfig,
+    route: str,
+    method: str,
+    request_origin: str,
+    route_param: str | None = None,
+) -> list[str]:
+    """Get the rate limits for a specific route and method, with defaults."""
+    rate_limits = get_rate_limits(
+        rate_limits_config, route, method, request_origin, route_param
+    )
+    if not rate_limits:
+        rate_limits = get_rate_limits(
+            rate_limits_config, route, method, request_origin, "default"
+        )
+    if not rate_limits:
+        rate_limits = get_rate_limits(
+            rate_limits_config, "default", method, request_origin
+        )
+    return rate_limits
+
+
 def check_rate_limits_for_user(
     user_uid: str, rate_limits: list[limits.RateLimitItem]
 ) -> None:
@@ -52,13 +95,18 @@ def check_rate_limits_for_user(
 
 
 def check_rate_limits(
-    method_rate_limits: config.RateLimitsMethodConfig,
+    rate_limits_config: config.RateLimitsConfig,
+    route: str,
+    method: str,
     auth_info: models.AuthInfo,
+    route_param: str | None = None,
 ) -> None:
     """Check if the rate limits are exceeded."""
-    user_uid = auth_info.user_uid
     request_origin = auth_info.request_origin
-    rate_limit_ids = getattr(method_rate_limits, request_origin)
-    rate_limits = [limits.parse(rate_limit_id) for rate_limit_id in rate_limit_ids]
-    check_rate_limits_for_user(user_uid, rate_limits)
+    user_uid = auth_info.user_uid
+    rate_limits = get_rate_limits_defaulted(
+        rate_limits_config, route, method, request_origin, route_param
+    )
+    rate_limits_parsed = [limits.parse(rate_limit) for rate_limit in rate_limits]
+    check_rate_limits_for_user(user_uid, rate_limits_parsed)
     return None
diff --git a/tests/test_10_config.py b/tests/test_10_config.py
index ba7260c..3ded98d 100644
--- a/tests/test_10_config.py
+++ b/tests/test_10_config.py
@@ -73,18 +73,41 @@ def test_load_rate_limits(tmp_path: pathlib.Path, caplog) -> None:
 
     rate_limits_file = str(tmp_path / "rate-limits.yaml")
     rate_limits = {
-        "/processes/{process_id}/execution": {
-            "post": {"api": ["1/second"], "ui": ["2/second"]}
+        "/jobs/{job_id}": {"get": {"api": ["1/second"], "ui": ["2/second"]}},
+        "/processes/{process_id}/constraints": {
+            "default": {"get": {"api": ["1/second"], "ui": ["2/second"]}},
+            "process-id": {"post": {"api": ["1/second"], "ui": ["2/second"]}},
         },
     }
     with open(rate_limits_file, "w") as file:
         yaml.dump(rate_limits, file)
-    loaded_rate_limits = config.load_rate_limits(rate_limits_file)
-    assert loaded_rate_limits == config.RateLimitsConfig(**rate_limits)
+    loaded_rate_limits = config.load_rate_limits(rate_limits_file).model_dump()
+    expected_jobs_limits = {
+        "get": {"api": ["1/second"], "ui": ["2/second"]},
+        "post": {"api": [], "ui": []},
+        "delete": {"api": [], "ui": []},
+    }
+    assert loaded_rate_limits["jobs_jobsid"] == expected_jobs_limits
+    expected_process_constraints_limits = {
+        "default": {
+            "get": {"api": ["1/second"], "ui": ["2/second"]},
+            "post": {"api": [], "ui": []},
+            "delete": {"api": [], "ui": []},
+        },
+        "process-id": {
+            "get": {"api": [], "ui": []},
+            "post": {"api": ["1/second"], "ui": ["2/second"]},
+            "delete": {"api": [], "ui": []},
+        },
+    }
+    assert (
+        loaded_rate_limits["processes_processid_constraints"]
+        == expected_process_constraints_limits
+    )
 
     rate_limits_file = str(tmp_path / "invalid-rate-limits.yaml")
     rate_limits = {
-        "/processes/{process_id}/execution": {"post": {"api": ["invalid_limit"]}},
+        "/jobs/{job_id}": {"get": {"api": ["invalid_limit"]}},
     }
     with open(rate_limits_file, "w") as file:
         yaml.dump(rate_limits, file)
@@ -94,41 +117,3 @@ def test_load_rate_limits(tmp_path: pathlib.Path, caplog) -> None:
     rate_limits_file = str(tmp_path / "not-found-rate-limits.yaml")
     loaded_rate_limits = config.load_rate_limits(rate_limits_file)
     assert loaded_rate_limits == config.RateLimitsConfig()
-
-
-def test_rate_limits_config_populate_with_default() -> None:
-    rate_limits_config = config.RateLimitsConfig(
-        **{
-            "default": {
-                "post": {"api": ["1/second"], "ui": ["2/second"]},
-                "get": {"api": ["2/second"]},
-            },
-            "/processes/{process_id}/execution": {"post": {"api": ["1/minute"]}},
-        }
-    )
-    exp_populated_rate_limits_config = {
-        "default": {
-            "post": {"api": ["1/second"], "ui": ["2/second"]},
-            "get": {"api": ["2/second"]},
-        },
-        "process_execution": {
-            "post": {"api": ["1/minute"], "ui": ["2/second"]},
-            "get": {"api": ["2/second"]},
-        },
-        "jobs": {
-            "post": {"api": ["1/second"], "ui": ["2/second"]},
-            "get": {"api": ["2/second"]},
-        },
-        "job": {
-            "post": {"api": ["1/second"], "ui": ["2/second"]},
-            "get": {"api": ["2/second"]},
-        },
-        "job_results": {
-            "post": {"api": ["1/second"], "ui": ["2/second"]},
-            "get": {"api": ["2/second"]},
-        },
-    }
-    assert (
-        rate_limits_config.model_dump(exclude_defaults=True)
-        == exp_populated_rate_limits_config
-    )
diff --git a/tests/test_30_limits.py b/tests/test_30_limits.py
index 11997f1..43a5d35 100644
--- a/tests/test_30_limits.py
+++ b/tests/test_30_limits.py
@@ -18,7 +18,176 @@
 import pytest
 
 import cads_processing_api_service.limits
-from cads_processing_api_service import exceptions
+from cads_processing_api_service import config, exceptions
+
+
+def test_get_rate_limits() -> None:
+    rate_limits = {"/jobs/{job_id}": {"get": {"api": ["2/second"]}}}
+    rate_limits_config = config.RateLimitsConfig(**rate_limits)
+
+    route = "jobs_jobsid"
+    method = "get"
+    request_origin = "api"
+    rate_limits = cads_processing_api_service.limits.get_rate_limits(
+        rate_limits_config, route, method, request_origin
+    )
+    exp_rate_limits = ["2/second"]
+    assert rate_limits == exp_rate_limits
+
+
+def test_get_rate_limits_route_param() -> None:
+    rate_limits = {
+        "/processes/{process_id}/execution": {
+            "process_id": {"post": {"api": ["2/second"]}}
+        }
+    }
+    rate_limits_config = config.RateLimitsConfig(**rate_limits)
+
+    route = "processes_processid_execution"
+    route_param = "process_id"
+    method = "post"
+    request_origin = "api"
+    rate_limits = cads_processing_api_service.limits.get_rate_limits(
+        rate_limits_config, route, method, request_origin, route_param
+    )
+    exp_rate_limits = ["2/second"]
+    assert rate_limits == exp_rate_limits
+
+
+def test_get_rate_limits_defaulted_actual_value() -> None:
+    rate_limits = {
+        "/jobs/{job_id}": {"get": {"api": ["2/second"]}},
+        "default": {"get": {"api": ["1/second"]}},
+    }
+    rate_limits_config = config.RateLimitsConfig(**rate_limits)
+
+    route = "jobs_jobsid"
+    method = "get"
+    request_origin = "api"
+    rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
+        rate_limits_config, route, method, request_origin
+    )
+    exp_rate_limits = ["2/second"]
+    assert rate_limits == exp_rate_limits
+
+
+def test_get_rate_limits_defaulted_default_value() -> None:
+    rate_limits = {
+        "/jobs/{job_id}": {"post": {"api": ["2/second"]}},
+        "/jobs": {"get": {"api": ["2/second"]}},
+        "default": {"post": {"ui": ["1/second"]}},
+    }
+    rate_limits_config = config.RateLimitsConfig(**rate_limits)
+
+    route = "jobs_jobsid"
+    method = "post"
+    request_origin = "ui"
+    rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
+        rate_limits_config, route, method, request_origin
+    )
+    exp_rate_limits = ["1/second"]
+    assert rate_limits == exp_rate_limits
+
+    route = "jobs"
+    method = "post"
+    request_origin = "ui"
+    rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
+        rate_limits_config, route, method, request_origin
+    )
+    exp_rate_limits = ["1/second"]
+    assert rate_limits == exp_rate_limits
+
+    route = "processes_processid_execute"
+    method = "post"
+    request_origin = "ui"
+    rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
+        rate_limits_config, route, method, request_origin
+    )
+    exp_rate_limits = ["1/second"]
+    assert rate_limits == exp_rate_limits
+
+
+def test_get_rate_limits_defaulted_route_param_actual_value() -> None:
+    rate_limits = {
+        "/processes/{process_id}/execution": {
+            "test_process_id": {"post": {"api": ["2/second"]}}
+        },
+        "default": {"post": {"ui": ["1/second"]}},
+    }
+    rate_limits_config = config.RateLimitsConfig(**rate_limits)
+
+    route = "processes_processid_execution"
+    method = "post"
+    request_origin = "api"
+    route_param = "test_process_id"
+    rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
+        rate_limits_config, route, method, request_origin, route_param
+    )
+    exp_rate_limits = ["2/second"]
+    assert rate_limits == exp_rate_limits
+
+
+def test_get_rate_limits_defaulted_route_param_default_value() -> None:
+    rate_limits = {
+        "/processes/{process_id}/execution": {
+            "test_process_id": {"post": {"api": ["2/second"]}},
+            "default": {"post": {"api": ["1/second"]}},
+        },
+        "default": {"post": {"ui": ["1/minute"]}},
+    }
+    rate_limits_config = config.RateLimitsConfig(**rate_limits)
+
+    route = "processes_processid_execution"
+    method = "post"
+    request_origin = "api"
+    route_param = "missing_test_process_id"
+    rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
+        rate_limits_config, route, method, request_origin, route_param
+    )
+    exp_rate_limits = ["1/second"]
+    assert rate_limits == exp_rate_limits
+
+    route = "processes_processid_execution"
+    method = "post"
+    request_origin = "ui"
+    route_param = "missing_test_process_id"
+    rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
+        rate_limits_config, route, method, request_origin, route_param
+    )
+    exp_rate_limits = ["1/minute"]
+    assert rate_limits == exp_rate_limits
+
+
+def test_get_rate_limits_undefined() -> None:
+    rate_limits = {"/jobs": {"get": {"api": ["2/second"]}}}
+    rate_limits_config = config.RateLimitsConfig.model_validate(rate_limits)
+
+    route = "jobs"
+    method = "get"
+    request_origin = "ui"
+    rate_limits = cads_processing_api_service.limits.get_rate_limits(
+        rate_limits_config, route, method, request_origin
+    )
+    exp_rate_limits = []
+    assert rate_limits == exp_rate_limits
+
+    route = "jobs"
+    method = "post"
+    request_origin = "ui"
+    rate_limits = cads_processing_api_service.limits.get_rate_limits(
+        rate_limits_config, route, method, request_origin
+    )
+    exp_rate_limits = []
+    assert rate_limits == exp_rate_limits
+
+    route = "job"
+    method = "get"
+    request_origin = "ui"
+    rate_limits = cads_processing_api_service.limits.get_rate_limits(
+        rate_limits_config, route, method, request_origin
+    )
+    exp_rate_limits = []
+    assert rate_limits == exp_rate_limits
 
 
 def test_check_rate_limits_for_user() -> None: