Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions cads_processing_api_service/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,11 @@ def post_process_execution(
"""
structlog.contextvars.bind_contextvars(user_uid=auth_info.user_uid)
_ = limits.check_rate_limits(
SETTINGS.rate_limits.process_execution.post,
SETTINGS.rate_limits,
"processes_processid_execution",
"post",
auth_info,
process_id,
)
request_body = execution_content.model_dump()
catalogue_sessionmaker = db_utils.get_catalogue_sessionmaker(
Expand Down Expand Up @@ -396,7 +399,9 @@ def get_jobs(
"""
structlog.contextvars.bind_contextvars(user_uid=auth_info.user_uid)
_ = limits.check_rate_limits(
SETTINGS.rate_limits.jobs.get,
SETTINGS.rate_limits,
"jobs",
"get",
auth_info,
)
job_filters = {
Expand Down Expand Up @@ -526,7 +531,9 @@ def get_job(
"""
structlog.contextvars.bind_contextvars(user_uid=auth_info.user_uid)
_ = limits.check_rate_limits(
SETTINGS.rate_limits.job.get,
SETTINGS.rate_limits,
"jobs_jobid",
"get",
auth_info,
)
compute_connection_mode = (
Expand Down Expand Up @@ -646,7 +653,9 @@ def get_job_results(
"""
structlog.contextvars.bind_contextvars(user_uid=auth_info.user_uid)
_ = limits.check_rate_limits(
SETTINGS.rate_limits.job_results.get,
SETTINGS.rate_limits,
"jobs_jobsid_results",
"get",
auth_info,
)
compute_connection_mode = (
Expand Down Expand Up @@ -711,7 +720,9 @@ def delete_job(
"""
structlog.contextvars.bind_contextvars(user_uid=auth_info.user_uid)
_ = limits.check_rate_limits(
SETTINGS.rate_limits.job.delete,
SETTINGS.rate_limits,
"jobs_jobsid",
"delete",
auth_info,
)
compute_sessionmaker = db_utils.get_compute_sessionmaker(
Expand Down
44 changes: 19 additions & 25 deletions cads_processing_api_service/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,45 +108,39 @@ class RateLimitsRouteConfig(pydantic.BaseModel):
delete: RateLimitsMethodConfig = pydantic.Field(default=RateLimitsMethodConfig())


class RateLimitsProcessExecutionRouteConfig(pydantic.BaseModel):
__pydantic_extra__: dict[str, RateLimitsRouteConfig] = pydantic.Field(init=False)

default: RateLimitsRouteConfig = pydantic.Field(default=RateLimitsRouteConfig())

model_config = pydantic.ConfigDict(extra="allow")


class RateLimitsConfig(pydantic.BaseModel):
default: RateLimitsRouteConfig = pydantic.Field(
default=RateLimitsRouteConfig(), validate_default=True
)
process_execution: RateLimitsRouteConfig = pydantic.Field(
default=RateLimitsRouteConfig(),
alias="/processes/{process_id}/execution",
validate_default=True,
processes_processid_execution: RateLimitsProcessExecutionRouteConfig = (
pydantic.Field(
alias="/processes/{process_id}/execution",
default=RateLimitsProcessExecutionRouteConfig(),
validate_default=True,
)
)
jobs: RateLimitsRouteConfig = pydantic.Field(
default=RateLimitsRouteConfig(), alias="/jobs", validate_default=True
)
job: RateLimitsRouteConfig = pydantic.Field(
jobs_jobsid: RateLimitsRouteConfig = pydantic.Field(
default=RateLimitsRouteConfig(), alias="/jobs/{job_id}", validate_default=True
)
job_results: RateLimitsRouteConfig = pydantic.Field(
jobs_jobsid_results: RateLimitsRouteConfig = pydantic.Field(
default=RateLimitsRouteConfig(),
alias="/jobs/{job_id}/results",
validate_default=True,
)

@pydantic.model_validator(mode="after") # type: ignore
def populate_fields_with_default(self) -> pydantic.BaseModel:
default = self.default
if default is RateLimitsRouteConfig():
return self
routes = self.model_fields
for route in routes:
if route == "default":
continue
route_config: RateLimitsRouteConfig = getattr(self, route)
for method in route_config.model_fields:
method_config: RateLimitsMethodConfig = getattr(route_config, method)
for origin in method_config.model_fields:
set_value = getattr(getattr(getattr(self, route), method), origin)
if not set_value:
default_value = getattr(getattr(default, method), origin)
setattr(getattr(route_config, method), origin, default_value)
return self
jobs_delete: RateLimitsRouteConfig = pydantic.Field(
default=RateLimitsRouteConfig(), alias="/jobs/delete", validate_default=True
)


def load_rate_limits(rate_limits_file: str | None) -> RateLimitsConfig:
Expand Down
4 changes: 3 additions & 1 deletion cads_processing_api_service/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def delete_jobs(
"""
structlog.contextvars.bind_contextvars(user_uid=auth_info.user_uid)
limits.check_rate_limits(
SETTINGS.rate_limits.jobs.delete,
SETTINGS.rate_limits,
"jobs_delete",
"post",
auth_info,
)
job_ids = request.job_ids
Expand Down
58 changes: 53 additions & 5 deletions cads_processing_api_service/limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License

from typing import Any

import limits
import structlog

Expand All @@ -27,6 +29,47 @@
limiter = config.RATE_LIMITS_LIMITER


def get_rate_limits(
rate_limits_config: config.RateLimitsConfig,
route: str,
method: str,
request_origin: str,
route_param: str | None = None,
) -> list[str]:
"""Get the rate limits for a specific route and method."""
rate_limits = rate_limits_config.model_dump()
route_rate_limits: dict[str, Any] = rate_limits.get(route, {})
if route_param is not None:
route_param_rate_limits: dict[str, Any] = route_rate_limits.get(route_param, {})
else:
route_param_rate_limits = route_rate_limits
method_rate_limits: dict[str, Any] = route_param_rate_limits.get(method, {})
rate_limit_ids: list[str] = method_rate_limits.get(request_origin, [])
return rate_limit_ids


def get_rate_limits_defaulted(
rate_limits_config: config.RateLimitsConfig,
route: str,
method: str,
request_origin: str,
route_param: str | None = None,
) -> list[str]:
"""Get the rate limits for a specific route and method, with defaults."""
rate_limits = get_rate_limits(
rate_limits_config, route, method, request_origin, route_param
)
if not rate_limits:
rate_limits = get_rate_limits(
rate_limits_config, route, method, request_origin, "default"
)
if not rate_limits:
rate_limits = get_rate_limits(
rate_limits_config, "default", method, request_origin
)
return rate_limits


def check_rate_limits_for_user(
user_uid: str, rate_limits: list[limits.RateLimitItem]
) -> None:
Expand All @@ -52,13 +95,18 @@ def check_rate_limits_for_user(


def check_rate_limits(
method_rate_limits: config.RateLimitsMethodConfig,
rate_limits_config: config.RateLimitsConfig,
route: str,
method: str,
auth_info: models.AuthInfo,
route_param: str | None = None,
) -> None:
"""Check if the rate limits are exceeded."""
user_uid = auth_info.user_uid
request_origin = auth_info.request_origin
rate_limit_ids = getattr(method_rate_limits, request_origin)
rate_limits = [limits.parse(rate_limit_id) for rate_limit_id in rate_limit_ids]
check_rate_limits_for_user(user_uid, rate_limits)
user_uid = auth_info.user_uid
rate_limits = get_rate_limits_defaulted(
rate_limits_config, route, method, request_origin, route_param
)
rate_limits_parsed = [limits.parse(rate_limit) for rate_limit in rate_limits]
check_rate_limits_for_user(user_uid, rate_limits_parsed)
return None
60 changes: 11 additions & 49 deletions tests/test_10_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,20 @@ def test_load_rate_limits(tmp_path: pathlib.Path, caplog) -> None:
loaded_rate_limits = config.load_rate_limits(rate_limits_file)
assert loaded_rate_limits == config.RateLimitsConfig()

rate_limits_file = str(tmp_path / "rate-limits.yaml")
rate_limits = {
"/processes/{process_id}/execution": {
"post": {"api": ["1/second"], "ui": ["2/second"]}
},
}
with open(rate_limits_file, "w") as file:
yaml.dump(rate_limits, file)
loaded_rate_limits = config.load_rate_limits(rate_limits_file)
assert loaded_rate_limits == config.RateLimitsConfig(**rate_limits)
# rate_limits_file = str(tmp_path / "rate-limits.yaml")
# rate_limits = {
# "/jobs/{job_id}": {
# "get": {"api": ["1/second"], "ui": ["2/second"]}
# },
# }
# with open(rate_limits_file, "w") as file:
# yaml.dump(rate_limits, file)
# loaded_rate_limits = config.load_rate_limits(rate_limits_file)
# assert loaded_rate_limits == config.RateLimitsConfig(**rate_limits)

rate_limits_file = str(tmp_path / "invalid-rate-limits.yaml")
rate_limits = {
"/processes/{process_id}/execution": {"post": {"api": ["invalid_limit"]}},
"/jobs/{job_id}": {"get": {"api": ["invalid_limit"]}},
}
with open(rate_limits_file, "w") as file:
yaml.dump(rate_limits, file)
Expand All @@ -94,41 +94,3 @@ def test_load_rate_limits(tmp_path: pathlib.Path, caplog) -> None:
rate_limits_file = str(tmp_path / "not-found-rate-limits.yaml")
loaded_rate_limits = config.load_rate_limits(rate_limits_file)
assert loaded_rate_limits == config.RateLimitsConfig()


def test_rate_limits_config_populate_with_default() -> None:
rate_limits_config = config.RateLimitsConfig(
**{
"default": {
"post": {"api": ["1/second"], "ui": ["2/second"]},
"get": {"api": ["2/second"]},
},
"/processes/{process_id}/execution": {"post": {"api": ["1/minute"]}},
}
)
exp_populated_rate_limits_config = {
"default": {
"post": {"api": ["1/second"], "ui": ["2/second"]},
"get": {"api": ["2/second"]},
},
"process_execution": {
"post": {"api": ["1/minute"], "ui": ["2/second"]},
"get": {"api": ["2/second"]},
},
"jobs": {
"post": {"api": ["1/second"], "ui": ["2/second"]},
"get": {"api": ["2/second"]},
},
"job": {
"post": {"api": ["1/second"], "ui": ["2/second"]},
"get": {"api": ["2/second"]},
},
"job_results": {
"post": {"api": ["1/second"], "ui": ["2/second"]},
"get": {"api": ["2/second"]},
},
}
assert (
rate_limits_config.model_dump(exclude_defaults=True)
== exp_populated_rate_limits_config
)
Loading
Loading