From bdf5b13240d07fc3cad5ca6767333d2220b46640 Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Fri, 13 Jun 2025 13:41:30 +0200 Subject: [PATCH 01/33] feat: Add pilot management: create/delete/patch and query --- .../src/diracx/client/_generated/_client.py | 5 +- .../diracx/client/_generated/aio/_client.py | 5 +- .../_generated/aio/operations/__init__.py | 2 + .../_generated/aio/operations/_operations.py | 572 ++++++++++++++- .../client/_generated/models/__init__.py | 16 +- .../client/_generated/models/_models.py | 307 ++++++-- .../client/_generated/operations/__init__.py | 2 + .../_generated/operations/_operations.py | 665 +++++++++++++++++- diracx-core/src/diracx/core/exceptions.py | 53 +- diracx-core/src/diracx/core/models.py | 29 +- diracx-db/src/diracx/db/sql/__init__.py | 2 +- diracx-db/src/diracx/db/sql/job/db.py | 20 +- .../src/diracx/db/sql/pilot_agents/db.py | 45 -- .../sql/{pilot_agents => pilots}/__init__.py | 0 diracx-db/src/diracx/db/sql/pilots/db.py | 300 ++++++++ .../db/sql/{pilot_agents => pilots}/schema.py | 1 + diracx-db/src/diracx/db/sql/utils/__init__.py | 22 +- .../src/diracx/db/sql/utils/functions.py | 90 ++- .../pilot_agents/test_pilot_agents_db.py | 30 - .../{pilot_agents => pilots}/__init__.py | 0 .../tests/pilots/test_pilot_management.py | 405 +++++++++++ diracx-db/tests/pilots/test_query.py | 301 ++++++++ diracx-logic/src/diracx/logic/jobs/query.py | 6 +- .../src/diracx/logic/pilots/management.py | 95 +++ diracx-logic/src/diracx/logic/pilots/query.py | 39 + diracx-routers/pyproject.toml | 2 + .../src/diracx/routers/jobs/query.py | 4 +- .../src/diracx/routers/pilots/__init__.py | 13 + .../diracx/routers/pilots/access_policies.py | 120 ++++ .../src/diracx/routers/pilots/management.py | 232 ++++++ .../src/diracx/routers/pilots/query.py | 149 ++++ .../tests/pilots/test_pilot_creation.py | 183 +++++ diracx-routers/tests/pilots/test_query.py | 372 ++++++++++ .../src/gubbins/client/_generated/_client.py | 12 +- .../gubbins/client/_generated/aio/_client.py | 12 +- .../_generated/aio/operations/__init__.py | 2 + .../_generated/aio/operations/_operations.py | 572 ++++++++++++++- .../client/_generated/models/__init__.py | 16 +- .../client/_generated/models/_models.py | 307 ++++++-- .../client/_generated/operations/__init__.py | 2 + .../_generated/operations/_operations.py | 665 +++++++++++++++++- 41 files changed, 5427 insertions(+), 248 deletions(-) delete mode 100644 diracx-db/src/diracx/db/sql/pilot_agents/db.py rename diracx-db/src/diracx/db/sql/{pilot_agents => pilots}/__init__.py (100%) create mode 100644 diracx-db/src/diracx/db/sql/pilots/db.py rename diracx-db/src/diracx/db/sql/{pilot_agents => pilots}/schema.py (98%) delete mode 100644 diracx-db/tests/pilot_agents/test_pilot_agents_db.py rename diracx-db/tests/{pilot_agents => pilots}/__init__.py (100%) create mode 100644 diracx-db/tests/pilots/test_pilot_management.py create mode 100644 diracx-db/tests/pilots/test_query.py create mode 100644 diracx-logic/src/diracx/logic/pilots/management.py create mode 100644 diracx-logic/src/diracx/logic/pilots/query.py create mode 100644 diracx-routers/src/diracx/routers/pilots/__init__.py create mode 100644 diracx-routers/src/diracx/routers/pilots/access_policies.py create mode 100644 diracx-routers/src/diracx/routers/pilots/management.py create mode 100644 diracx-routers/src/diracx/routers/pilots/query.py create mode 100644 diracx-routers/tests/pilots/test_pilot_creation.py create mode 100644 diracx-routers/tests/pilots/test_query.py diff --git 
a/diracx-client/src/diracx/client/_generated/_client.py b/diracx-client/src/diracx/client/_generated/_client.py index aa558f636..9e37d5081 100644 --- a/diracx-client/src/diracx/client/_generated/_client.py +++ b/diracx-client/src/diracx/client/_generated/_client.py @@ -15,7 +15,7 @@ from . import models as _models from ._configuration import DiracConfiguration from ._utils.serialization import Deserializer, Serializer -from .operations import AuthOperations, ConfigOperations, JobsOperations, WellKnownOperations +from .operations import AuthOperations, ConfigOperations, JobsOperations, PilotsOperations, WellKnownOperations class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -29,6 +29,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype config: _generated.operations.ConfigOperations :ivar jobs: JobsOperations operations :vartype jobs: _generated.operations.JobsOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -65,6 +67,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.auth = AuthOperations(self._client, self._config, self._serialize, self._deserialize) self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/diracx-client/src/diracx/client/_generated/aio/_client.py b/diracx-client/src/diracx/client/_generated/aio/_client.py index 10cfad884..397b7f989 100644 --- a/diracx-client/src/diracx/client/_generated/aio/_client.py +++ b/diracx-client/src/diracx/client/_generated/aio/_client.py @@ -15,7 +15,7 @@ from .. import models as _models from .._utils.serialization import Deserializer, Serializer from ._configuration import DiracConfiguration -from .operations import AuthOperations, ConfigOperations, JobsOperations, WellKnownOperations +from .operations import AuthOperations, ConfigOperations, JobsOperations, PilotsOperations, WellKnownOperations class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -29,6 +29,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype config: _generated.aio.operations.ConfigOperations :ivar jobs: JobsOperations operations :vartype jobs: _generated.aio.operations.JobsOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.aio.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". 
:paramtype endpoint: str """ @@ -65,6 +67,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.auth = AuthOperations(self._client, self._config, self._serialize, self._deserialize) self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py index 10db0c7a9..be02776fc 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py @@ -14,6 +14,7 @@ from ._operations import AuthOperations # type: ignore from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -24,6 +25,7 @@ "AuthOperations", "ConfigOperations", "JobsOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index 0916d8a28..85b8cc406 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -52,6 +52,12 @@ build_jobs_summary_request, build_jobs_unassign_bulk_jobs_sandboxes_request, build_jobs_unassign_job_sandboxes_request, + build_pilots_add_pilot_stamps_request, + build_pilots_associate_pilot_with_jobs_request, + build_pilots_clear_pilots_request, + build_pilots_delete_pilots_request, + build_pilots_search_request, + build_pilots_update_pilot_fields_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -1826,7 +1832,7 @@ async def patch_metadata(self, body: Union[Dict[str, Dict[str, Any]], IO[bytes]] @overload async def search( self, - body: Optional[_models.JobSearchParams] = None, + body: Optional[_models.SearchParams] = None, *, page: int = 1, per_page: int = 100, @@ -1840,7 +1846,7 @@ async def search( **TODO: Add more docs**. :param body: Default value is None. - :type body: ~_generated.models.JobSearchParams + :type body: ~_generated.models.SearchParams :keyword page: Default value is 1. :paramtype page: int :keyword per_page: Default value is 100. @@ -1886,7 +1892,7 @@ async def search( @distributed_trace_async async def search( self, - body: Optional[Union[_models.JobSearchParams, IO[bytes]]] = None, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, *, page: int = 1, per_page: int = 100, @@ -1898,8 +1904,8 @@ async def search( **TODO: Add more docs**. - :param body: Is either a JobSearchParams type or a IO[bytes] type. Default value is None. - :type body: ~_generated.models.JobSearchParams or IO[bytes] + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] :keyword page: Default value is 1. 
:paramtype page: int :keyword per_page: Default value is 100. @@ -1929,7 +1935,7 @@ async def search( _content = body else: if body is not None: - _json = self._serialize.body(body, "JobSearchParams") + _json = self._serialize.body(body, "SearchParams") else: _json = None @@ -2157,3 +2163,557 @@ async def submit_jdl_jobs(self, body: Union[List[str], IO[bytes]], **kwargs: Any return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + async def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def delete_pilots(self, *, pilot_stamps: List[str], **kwargs: Any) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + If at least one pilot is not found, it WILL rollback. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Required. + :paramtype pilot_stamps: list[str] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + async def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. 
Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def update_pilot_fields( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def update_pilot_fields( + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace_async + async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = True, **kwargs: Any) -> None: + """Clear Pilots. + + Endpoint for DIRAC to delete all pilots that lived more than age_in_days. + + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Required. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake. Default value is True. 
+ :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_clear_pilots_request( + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + async def associate_pilot_with_jobs( + self, body: _models.BodyPilotsAssociatePilotWithJobs, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Associate Pilot With Jobs. + + Endpoint only for DIRAC services, to associate a pilot with a job. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def associate_pilot_with_jobs( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Associate Pilot With Jobs. + + Endpoint only for DIRAC services, to associate a pilot with a job. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def associate_pilot_with_jobs( + self, body: Union[_models.BodyPilotsAssociatePilotWithJobs, IO[bytes]], **kwargs: Any + ) -> None: + """Associate Pilot With Jobs. + + Endpoint only for DIRAC services, to associate a pilot with a job. + + :param body: Is either a BodyPilotsAssociatePilotWithJobs type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAssociatePilotWithJobs") + + _request = build_pilots_associate_pilot_with_jobs_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + async def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. 
+ :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore diff --git a/diracx-client/src/diracx/client/_generated/models/__init__.py b/diracx-client/src/diracx/client/_generated/models/__init__.py index 7343700e4..c6f8fb19a 100644 --- a/diracx-client/src/diracx/client/_generated/models/__init__.py +++ b/diracx-client/src/diracx/client/_generated/models/__init__.py @@ -14,24 +14,28 @@ from ._models import ( # type: ignore BodyAuthGetOidcToken, BodyAuthGetOidcTokenGrantType, + BodyPilotsAddPilotStamps, + BodyPilotsAssociatePilotWithJobs, + BodyPilotsUpdatePilotFields, GroupInfo, HTTPValidationError, HeartbeatData, InitiateDeviceFlowResponse, InsertedJob, JobCommand, - JobSearchParams, - JobSearchParamsSearchItem, JobStatusUpdate, JobSummaryParams, JobSummaryParamsSearchItem, Metadata, OpenIDConfiguration, + PilotFieldsMapping, SandboxDownloadResponse, SandboxInfo, SandboxUploadResponse, ScalarSearchSpec, ScalarSearchSpecValue, + SearchParams, + SearchParamsSearchItem, SetJobStatusReturn, SetJobStatusReturnSuccess, SortSpec, @@ -61,24 +65,28 @@ __all__ = [ "BodyAuthGetOidcToken", "BodyAuthGetOidcTokenGrantType", + "BodyPilotsAddPilotStamps", + "BodyPilotsAssociatePilotWithJobs", + "BodyPilotsUpdatePilotFields", "GroupInfo", "HTTPValidationError", "HeartbeatData", "InitiateDeviceFlowResponse", "InsertedJob", "JobCommand", - "JobSearchParams", - "JobSearchParamsSearchItem", "JobStatusUpdate", "JobSummaryParams", "JobSummaryParamsSearchItem", "Metadata", "OpenIDConfiguration", + "PilotFieldsMapping", "SandboxDownloadResponse", "SandboxInfo", "SandboxUploadResponse", "ScalarSearchSpec", "ScalarSearchSpecValue", + 
"SearchParams", + "SearchParamsSearchItem", "SetJobStatusReturn", "SetJobStatusReturnSuccess", "SortSpec", diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index 14045211b..9a224f824 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -94,6 +94,119 @@ class BodyAuthGetOidcTokenGrantType(_serialization.Model): """OAuth2 Grant type.""" +class BodyPilotsAddPilotStamps(_serialization.Model): + """Body_pilots_add_pilot_stamps. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :vartype pilot_stamps: list[str] + :ivar vo: Virtual Organisation associated with the inserted pilots. Required. + :vartype vo: str + :ivar grid_type: Grid type of the pilots. + :vartype grid_type: str + :ivar pilot_references: Association of a pilot reference with a pilot stamp. + :vartype pilot_references: dict[str, str] + """ + + _validation = { + "pilot_stamps": {"required": True}, + "vo": {"required": True}, + } + + _attribute_map = { + "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"}, + "vo": {"key": "vo", "type": "str"}, + "grid_type": {"key": "grid_type", "type": "str"}, + "pilot_references": {"key": "pilot_references", "type": "{str}"}, + } + + def __init__( + self, + *, + pilot_stamps: List[str], + vo: str, + grid_type: str = "Dirac", + pilot_references: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :paramtype pilot_stamps: list[str] + :keyword vo: Virtual Organisation associated with the inserted pilots. Required. + :paramtype vo: str + :keyword grid_type: Grid type of the pilots. + :paramtype grid_type: str + :keyword pilot_references: Association of a pilot reference with a pilot stamp. + :paramtype pilot_references: dict[str, str] + """ + super().__init__(**kwargs) + self.pilot_stamps = pilot_stamps + self.vo = vo + self.grid_type = grid_type + self.pilot_references = pilot_references + + +class BodyPilotsAssociatePilotWithJobs(_serialization.Model): + """Body_pilots_associate_pilot_with_jobs. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: The stamp of the pilot. Required. + :vartype pilot_stamp: str + :ivar pilot_jobs_ids: The jobs we want to add to the pilot. Required. + :vartype pilot_jobs_ids: list[int] + """ + + _validation = { + "pilot_stamp": {"required": True}, + "pilot_jobs_ids": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + "pilot_jobs_ids": {"key": "pilot_jobs_ids", "type": "[int]"}, + } + + def __init__(self, *, pilot_stamp: str, pilot_jobs_ids: List[int], **kwargs: Any) -> None: + """ + :keyword pilot_stamp: The stamp of the pilot. Required. + :paramtype pilot_stamp: str + :keyword pilot_jobs_ids: The jobs we want to add to the pilot. Required. + :paramtype pilot_jobs_ids: list[int] + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.pilot_jobs_ids = pilot_jobs_ids + + +class BodyPilotsUpdatePilotFields(_serialization.Model): + """Body_pilots_update_pilot_fields. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. Required. 
+ :vartype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + + _validation = { + "pilot_stamps_to_fields_mapping": {"required": True}, + } + + _attribute_map = { + "pilot_stamps_to_fields_mapping": {"key": "pilot_stamps_to_fields_mapping", "type": "[PilotFieldsMapping]"}, + } + + def __init__(self, *, pilot_stamps_to_fields_mapping: List["_models.PilotFieldsMapping"], **kwargs: Any) -> None: + """ + :keyword pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. + Required. + :paramtype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + super().__init__(**kwargs) + self.pilot_stamps_to_fields_mapping = pilot_stamps_to_fields_mapping + + class GroupInfo(_serialization.Model): """GroupInfo. @@ -358,56 +471,6 @@ def __init__(self, *, job_id: int, command: str, arguments: Optional[str] = None self.arguments = arguments -class JobSearchParams(_serialization.Model): - """JobSearchParams. - - :ivar parameters: Parameters. - :vartype parameters: list[str] - :ivar search: Search. - :vartype search: list[~_generated.models.JobSearchParamsSearchItem] - :ivar sort: Sort. - :vartype sort: list[~_generated.models.SortSpec] - :ivar distinct: Distinct. - :vartype distinct: bool - """ - - _attribute_map = { - "parameters": {"key": "parameters", "type": "[str]"}, - "search": {"key": "search", "type": "[JobSearchParamsSearchItem]"}, - "sort": {"key": "sort", "type": "[SortSpec]"}, - "distinct": {"key": "distinct", "type": "bool"}, - } - - def __init__( - self, - *, - parameters: Optional[List[str]] = None, - search: List["_models.JobSearchParamsSearchItem"] = [], - sort: List["_models.SortSpec"] = [], - distinct: bool = False, - **kwargs: Any - ) -> None: - """ - :keyword parameters: Parameters. - :paramtype parameters: list[str] - :keyword search: Search. - :paramtype search: list[~_generated.models.JobSearchParamsSearchItem] - :keyword sort: Sort. - :paramtype sort: list[~_generated.models.SortSpec] - :keyword distinct: Distinct. - :paramtype distinct: bool - """ - super().__init__(**kwargs) - self.parameters = parameters - self.search = search - self.sort = sort - self.distinct = distinct - - -class JobSearchParamsSearchItem(_serialization.Model): - """JobSearchParamsSearchItem.""" - - class JobStatusUpdate(_serialization.Model): """JobStatusUpdate. @@ -655,6 +718,100 @@ def __init__( self.code_challenge_methods_supported = code_challenge_methods_supported +class PilotFieldsMapping(_serialization.Model): + """All the fields that a user can modify on a Pilot (except PilotStamp). + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: Pilotstamp. Required. + :vartype pilot_stamp: str + :ivar status_reason: Statusreason. + :vartype status_reason: str + :ivar status: Status. + :vartype status: str + :ivar bench_mark: Benchmark. + :vartype bench_mark: float + :ivar destination_site: Destinationsite. + :vartype destination_site: str + :ivar queue: Queue. + :vartype queue: str + :ivar grid_site: Gridsite. + :vartype grid_site: str + :ivar grid_type: Gridtype. + :vartype grid_type: str + :ivar accounting_sent: Accountingsent. + :vartype accounting_sent: bool + :ivar current_job_id: Currentjobid. 
+ :vartype current_job_id: int + """ + + _validation = { + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "PilotStamp", "type": "str"}, + "status_reason": {"key": "StatusReason", "type": "str"}, + "status": {"key": "Status", "type": "str"}, + "bench_mark": {"key": "BenchMark", "type": "float"}, + "destination_site": {"key": "DestinationSite", "type": "str"}, + "queue": {"key": "Queue", "type": "str"}, + "grid_site": {"key": "GridSite", "type": "str"}, + "grid_type": {"key": "GridType", "type": "str"}, + "accounting_sent": {"key": "AccountingSent", "type": "bool"}, + "current_job_id": {"key": "CurrentJobID", "type": "int"}, + } + + def __init__( + self, + *, + pilot_stamp: str, + status_reason: Optional[str] = None, + status: Optional[str] = None, + bench_mark: Optional[float] = None, + destination_site: Optional[str] = None, + queue: Optional[str] = None, + grid_site: Optional[str] = None, + grid_type: Optional[str] = None, + accounting_sent: Optional[bool] = None, + current_job_id: Optional[int] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamp: Pilotstamp. Required. + :paramtype pilot_stamp: str + :keyword status_reason: Statusreason. + :paramtype status_reason: str + :keyword status: Status. + :paramtype status: str + :keyword bench_mark: Benchmark. + :paramtype bench_mark: float + :keyword destination_site: Destinationsite. + :paramtype destination_site: str + :keyword queue: Queue. + :paramtype queue: str + :keyword grid_site: Gridsite. + :paramtype grid_site: str + :keyword grid_type: Gridtype. + :paramtype grid_type: str + :keyword accounting_sent: Accountingsent. + :paramtype accounting_sent: bool + :keyword current_job_id: Currentjobid. + :paramtype current_job_id: int + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.status_reason = status_reason + self.status = status + self.bench_mark = bench_mark + self.destination_site = destination_site + self.queue = queue + self.grid_site = grid_site + self.grid_type = grid_type + self.accounting_sent = accounting_sent + self.current_job_id = current_job_id + + class SandboxDownloadResponse(_serialization.Model): """SandboxDownloadResponse. @@ -836,6 +993,56 @@ class ScalarSearchSpecValue(_serialization.Model): """Value.""" +class SearchParams(_serialization.Model): + """SearchParams. + + :ivar parameters: Parameters. + :vartype parameters: list[str] + :ivar search: Search. + :vartype search: list[~_generated.models.SearchParamsSearchItem] + :ivar sort: Sort. + :vartype sort: list[~_generated.models.SortSpec] + :ivar distinct: Distinct. + :vartype distinct: bool + """ + + _attribute_map = { + "parameters": {"key": "parameters", "type": "[str]"}, + "search": {"key": "search", "type": "[SearchParamsSearchItem]"}, + "sort": {"key": "sort", "type": "[SortSpec]"}, + "distinct": {"key": "distinct", "type": "bool"}, + } + + def __init__( + self, + *, + parameters: Optional[List[str]] = None, + search: List["_models.SearchParamsSearchItem"] = [], + sort: List["_models.SortSpec"] = [], + distinct: bool = False, + **kwargs: Any + ) -> None: + """ + :keyword parameters: Parameters. + :paramtype parameters: list[str] + :keyword search: Search. + :paramtype search: list[~_generated.models.SearchParamsSearchItem] + :keyword sort: Sort. + :paramtype sort: list[~_generated.models.SortSpec] + :keyword distinct: Distinct. 
+ :paramtype distinct: bool + """ + super().__init__(**kwargs) + self.parameters = parameters + self.search = search + self.sort = sort + self.distinct = distinct + + +class SearchParamsSearchItem(_serialization.Model): + """SearchParamsSearchItem.""" + + class SetJobStatusReturn(_serialization.Model): """SetJobStatusReturn. diff --git a/diracx-client/src/diracx/client/_generated/operations/__init__.py b/diracx-client/src/diracx/client/_generated/operations/__init__.py index 10db0c7a9..be02776fc 100644 --- a/diracx-client/src/diracx/client/_generated/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/operations/__init__.py @@ -14,6 +14,7 @@ from ._operations import AuthOperations # type: ignore from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -24,6 +25,7 @@ "AuthOperations", "ConfigOperations", "JobsOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index 0259e5aaf..26353b973 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -590,6 +590,103 @@ def build_jobs_submit_jdl_jobs_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) +def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/management/pilot" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + +def build_pilots_delete_pilots_request(*, pilot_stamps: List[str], **kwargs: Any) -> HttpRequest: + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + # Construct URL + _url = "/api/pilots/management/pilot" + + # Construct parameters + _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]") + + return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) + + +def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/management/pilot" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) + + +def build_pilots_clear_pilots_request( + *, age_in_days: int, delete_only_aborted: bool = True, **kwargs: Any +) -> HttpRequest: + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + # Construct URL + _url = "/api/pilots/management/pilot/interval" + + # Construct parameters + _params["age_in_days"] = 
_SERIALIZER.query("age_in_days", age_in_days, "int") + if delete_only_aborted is not None: + _params["delete_only_aborted"] = _SERIALIZER.query("delete_only_aborted", delete_only_aborted, "bool") + + return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) + + +def build_pilots_associate_pilot_with_jobs_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/management/jobs" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) + + +def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/management/search" + + # Construct parameters + if page is not None: + _params["page"] = _SERIALIZER.query("page", page, "int") + if per_page is not None: + _params["per_page"] = _SERIALIZER.query("per_page", per_page, "int") + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -2351,7 +2448,7 @@ def patch_metadata( # pylint: disable=inconsistent-return-statements @overload def search( self, - body: Optional[_models.JobSearchParams] = None, + body: Optional[_models.SearchParams] = None, *, page: int = 1, per_page: int = 100, @@ -2365,7 +2462,7 @@ def search( **TODO: Add more docs**. :param body: Default value is None. - :type body: ~_generated.models.JobSearchParams + :type body: ~_generated.models.SearchParams :keyword page: Default value is 1. :paramtype page: int :keyword per_page: Default value is 100. @@ -2411,7 +2508,7 @@ def search( @distributed_trace def search( self, - body: Optional[Union[_models.JobSearchParams, IO[bytes]]] = None, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, *, page: int = 1, per_page: int = 100, @@ -2423,8 +2520,8 @@ def search( **TODO: Add more docs**. - :param body: Is either a JobSearchParams type or a IO[bytes] type. Default value is None. - :type body: ~_generated.models.JobSearchParams or IO[bytes] + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] :keyword page: Default value is 1. :paramtype page: int :keyword per_page: Default value is 100. @@ -2454,7 +2551,7 @@ def search( _content = body else: if body is not None: - _json = self._serialize.body(body, "JobSearchParams") + _json = self._serialize.body(body, "SearchParams") else: _json = None @@ -2680,3 +2777,559 @@ def submit_jdl_jobs(self, body: Union[List[str], IO[bytes]], **kwargs: Any) -> L return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. 
warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def delete_pilots( # pylint: disable=inconsistent-return-statements + self, *, pilot_stamps: List[str], **kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + If at least one pilot is not found, it WILL rollback. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Required. + :paramtype pilot_stamps: list[str] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. 
Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def update_pilot_fields(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def update_pilot_fields( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace + def clear_pilots( # pylint: disable=inconsistent-return-statements + self, *, age_in_days: int, delete_only_aborted: bool = True, **kwargs: Any + ) -> None: + """Clear Pilots. + + Endpoint for DIRAC to delete all pilots that lived more than age_in_days. + + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Required. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake. Default value is True. 
+ :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_clear_pilots_request( + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + def associate_pilot_with_jobs( + self, body: _models.BodyPilotsAssociatePilotWithJobs, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Associate Pilot With Jobs. + + Endpoint only for DIRAC services, to associate a pilot with a job. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def associate_pilot_with_jobs( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Associate Pilot With Jobs. + + Endpoint only for DIRAC services, to associate a pilot with a job. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def associate_pilot_with_jobs( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsAssociatePilotWithJobs, IO[bytes]], **kwargs: Any + ) -> None: + """Associate Pilot With Jobs. + + Endpoint only for DIRAC services, to associate a pilot with a job. + + :param body: Is either a BodyPilotsAssociatePilotWithJobs type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAssociatePilotWithJobs") + + _request = build_pilots_associate_pilot_with_jobs_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. 
+ :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore diff --git a/diracx-core/src/diracx/core/exceptions.py b/diracx-core/src/diracx/core/exceptions.py index 54d7c240d..a9a571795 100644 --- a/diracx-core/src/diracx/core/exceptions.py +++ b/diracx-core/src/diracx/core/exceptions.py @@ -15,6 +15,7 @@ class DiracError(RuntimeError): def __init__(self, detail: str = "Unknown"): self.detail = detail + super().__init__(detail) class AuthorizationError(DiracError): ... 
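To give a sense of how the new `pilots.search` client operation shown above might be used, here is a minimal sketch. It assumes `client` is a configured `Dirac` instance and that the generated `SearchParams` model exposes the same field names as the Pydantic model introduced later in this patch; both are assumptions, not something guaranteed by the generated code itself:

```py
from diracx.client._generated import models

# Ask only for a few columns of every pilot visible to the caller, 50 per page.
body = models.SearchParams(parameters=["PilotStamp", "Status", "GridSite"])
pilots = client.pilots.search(body=body, page=1, per_page=50)
for pilot in pilots:
    print(pilot["PilotStamp"], pilot["Status"])
```

Both 200 and 206 responses return the requested page of results; callers simply page through with the `page` and `per_page` keywords, and the `Content-Range` header is only surfaced through the optional `cls` callback.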
@@ -49,19 +50,19 @@ class InvalidQueryError(DiracError): class TokenNotFoundError(DiracError): - def __init__(self, jti: str, detail: str | None = None): + def __init__(self, jti: str, detail: str = ""): self.jti: str = jti super().__init__(f"Token {jti} not found" + (f" ({detail})" if detail else "")) class JobNotFoundError(DiracError): - def __init__(self, job_id: int, detail: str | None = None): + def __init__(self, job_id: int, detail: str = ""): self.job_id: int = job_id super().__init__(f"Job {job_id} not found" + (f" ({detail})" if detail else "")) class SandboxNotFoundError(DiracError): - def __init__(self, pfn: str, se_name: str, detail: str | None = None): + def __init__(self, pfn: str, se_name: str, detail: str = ""): self.pfn: str = pfn self.se_name: str = se_name super().__init__( @@ -71,7 +72,7 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): class SandboxAlreadyAssignedError(DiracError): - def __init__(self, pfn: str, se_name: str, detail: str | None = None): + def __init__(self, pfn: str, se_name: str, detail: str = ""): self.pfn: str = pfn self.se_name: str = se_name super().__init__( @@ -81,7 +82,7 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): class SandboxAlreadyInsertedError(DiracError): - def __init__(self, pfn: str, se_name: str, detail: str | None = None): + def __init__(self, pfn: str, se_name: str, detail: str = ""): self.pfn: str = pfn self.se_name: str = se_name super().__init__( @@ -91,7 +92,7 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): class JobError(DiracError): - def __init__(self, job_id, detail: str | None = None): + def __init__(self, job_id, detail: str = ""): self.job_id: int = job_id super().__init__( f"Error concerning job {job_id}" + (f" ({detail})" if detail else "") @@ -100,3 +101,43 @@ def __init__(self, job_id, detail: str | None = None): class NotReadyError(DiracError): """Tried to access a value which is asynchronously loaded but not yet available.""" + + +class DiracFormattedError(DiracError): + # TODO: Refactor? 
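    # How subclasses use this base: they override `pattern` with a single "%s"
    # placeholder; the message is built by joining "(key: value)" pairs from `data`
    # and interpolating them into the pattern, optionally followed by ": <detail>".
    # For example, PilotNotFoundError(data={"pilot_stamps": "stamp_1"}) produces
    # the message "Pilot (pilot_stamps: stamp_1) not found".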
+ pattern = "Error %s" + + def __init__(self, data: dict[str, str], detail: str = ""): + self.data = data + + parts = [f"({key}: {value})" for key, value in data.items()] + message = type(self).pattern % (" ".join(parts)) + if detail: + message += f": {detail}" + + super().__init__(message) + + +class PilotNotFoundError(DiracFormattedError): + pattern = "Pilot %s not found" + + def __init__( + self, + data: dict[str, str], + detail: str = "", + non_existing_pilots: set = set(), + ): + super().__init__(data, detail) + self.non_existing_pilots = non_existing_pilots + + +class PilotAlreadyExistsError(DiracFormattedError): + pattern = "Pilot %s already exists" + + +class PilotJobsNotFoundError(DiracFormattedError): + pattern = "Pilots or Jobs %s not found" + + +class PilotAlreadyAssociatedWithJobError(DiracFormattedError): + pattern = "Pilot is already associated with a job %s " diff --git a/diracx-core/src/diracx/core/models.py b/diracx-core/src/diracx/core/models.py index 415e36295..93dba188b 100644 --- a/diracx-core/src/diracx/core/models.py +++ b/diracx-core/src/diracx/core/models.py @@ -7,7 +7,7 @@ from datetime import datetime from enum import StrEnum -from typing import Literal +from typing import Literal, Optional from pydantic import BaseModel, Field from typing_extensions import TypedDict @@ -65,7 +65,7 @@ class JobSummaryParams(BaseModel): # TODO: Add more validation -class JobSearchParams(BaseModel): +class SearchParams(BaseModel): parameters: list[str] | None = None search: list[SearchSpec] = [] sort: list[SortSpec] = [] @@ -272,3 +272,28 @@ class JobCommand(BaseModel): job_id: int command: Literal["Kill"] arguments: str | None = None + + +class PilotInfo(BaseModel): + sub: str + pilot_stamp: str + vo: str + + +class PilotStampInfo(BaseModel): + pilot_stamp: str + + +class PilotFieldsMapping(BaseModel): + """All the fields that a user can modify on a Pilot (except PilotStamp).""" + + PilotStamp: str + StatusReason: Optional[str] = None + Status: Optional[str] = None + BenchMark: Optional[float] = None + DestinationSite: Optional[str] = None + Queue: Optional[str] = None + GridSite: Optional[str] = None + GridType: Optional[str] = None + AccountingSent: Optional[bool] = None + CurrentJobID: Optional[int] = None diff --git a/diracx-db/src/diracx/db/sql/__init__.py b/diracx-db/src/diracx/db/sql/__init__.py index 3be3af8a3..e2f141ad5 100644 --- a/diracx-db/src/diracx/db/sql/__init__.py +++ b/diracx-db/src/diracx/db/sql/__init__.py @@ -12,6 +12,6 @@ from .auth.db import AuthDB from .job.db import JobDB from .job_logging.db import JobLoggingDB -from .pilot_agents.db import PilotAgentsDB +from .pilots.db import PilotAgentsDB from .sandbox_metadata.db import SandboxMetadataDB from .task_queue.db import TaskQueueDB diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 89f2bb49d..809fed97e 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -13,8 +13,13 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.models import JobCommand, SearchSpec, SortSpec -from ..utils import BaseSQLDB, apply_search_filters, apply_sort_constraints -from ..utils.functions import utcnow +from ..utils import ( + BaseSQLDB, + _get_columns, + apply_search_filters, + apply_sort_constraints, + utcnow, +) from .schema import ( HeartBeatLoggingInfo, InputData, @@ -25,17 +30,6 @@ ) -def _get_columns(table, parameters): - columns = [x for x in table.columns] - if parameters: - if unrecognised_parameters := 
set(parameters) - set(table.columns.keys()): - raise InvalidQueryError( - f"Unrecognised parameters requested {unrecognised_parameters}" - ) - columns = [c for c in columns if c.name in parameters] - return columns - - class JobDB(BaseSQLDB): metadata = JobDBBase.metadata diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py deleted file mode 100644 index 954f081b1..000000000 --- a/diracx-db/src/diracx/db/sql/pilot_agents/db.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from datetime import datetime, timezone - -from sqlalchemy import insert - -from ..utils import BaseSQLDB -from .schema import PilotAgents, PilotAgentsDBBase - - -class PilotAgentsDB(BaseSQLDB): - """PilotAgentsDB class is a front-end to the PilotAgents Database.""" - - metadata = PilotAgentsDBBase.metadata - - async def add_pilot_references( - self, - pilot_ref: list[str], - vo: str, - grid_type: str = "DIRAC", - pilot_stamps: dict | None = None, - ) -> None: - if pilot_stamps is None: - pilot_stamps = {} - - now = datetime.now(tz=timezone.utc) - - # Prepare the list of dictionaries for bulk insertion - values = [ - { - "PilotJobReference": ref, - "VO": vo, - "GridType": grid_type, - "SubmissionTime": now, - "LastUpdateTime": now, - "Status": "Submitted", - "PilotStamp": pilot_stamps.get(ref, ""), - } - for ref in pilot_ref - ] - - # Insert multiple rows in a single execute call - stmt = insert(PilotAgents).values(values) - await self.conn.execute(stmt) - return diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/__init__.py b/diracx-db/src/diracx/db/sql/pilots/__init__.py similarity index 100% rename from diracx-db/src/diracx/db/sql/pilot_agents/__init__.py rename to diracx-db/src/diracx/db/sql/pilots/__init__.py diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py new file mode 100644 index 000000000..279b227e7 --- /dev/null +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any, Sequence + +from sqlalchemy import RowMapping, bindparam, func +from sqlalchemy.exc import IntegrityError +from sqlalchemy.sql import delete, insert, select, update + +from diracx.core.exceptions import ( + InvalidQueryError, + PilotAlreadyAssociatedWithJobError, + PilotJobsNotFoundError, + PilotNotFoundError, +) +from diracx.core.models import ( + PilotFieldsMapping, + SearchSpec, + SortSpec, +) + +from ..utils import ( + BaseSQLDB, + _get_columns, + apply_search_filters, + apply_sort_constraints, + fetch_records_bulk_or_raises, +) +from .schema import ( + JobToPilotMapping, + PilotAgents, + PilotAgentsDBBase, +) + + +class PilotAgentsDB(BaseSQLDB): + """PilotAgentsDB class is a front-end to the PilotAgents Database.""" + + metadata = PilotAgentsDBBase.metadata + + async def add_pilots_bulk( + self, + pilot_stamps: list[str], + vo: str, + grid_type: str = "DIRAC", + pilot_references: dict | None = None, + ): + """Bulk add pilots in the DB. + + If we can't find a pilot_reference associated with a stamp, we take the stamp by default. 
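        For example, a call might look like the following sketch (stamps, VO and
        references are illustrative placeholders):

        ```py
        await db.add_pilots_bulk(
            pilot_stamps=["stamp_1", "stamp_2"],
            vo="lhcb",
            pilot_references={"stamp_1": "ref_1"},
            # "stamp_2" has no reference, so its stamp is reused as its reference
        )
        ```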
+ """ + if pilot_references is None: + pilot_references = {} + + now = datetime.now(tz=timezone.utc) + + # Prepare the list of dictionaries for bulk insertion + values = [ + { + "PilotJobReference": pilot_references.get(stamp, stamp), + "VO": vo, + "GridType": grid_type, + "SubmissionTime": now, + "LastUpdateTime": now, + "Status": "Submitted", + "PilotStamp": stamp, + } + for stamp in pilot_stamps + ] + + # Insert multiple rows in a single execute call and use 'returning' to get primary keys + stmt = insert(PilotAgents).values(values) # Assuming 'id' is the primary key + + await self.conn.execute(stmt) + + async def delete_pilots_by_stamps_bulk(self, pilot_stamps: list[str]): + """Bulk delete pilots. + + Raises PilotNotFound if one of the pilot was not found. + """ + stmt = delete(PilotAgents).where(PilotAgents.pilot_stamp.in_(pilot_stamps)) + + res = await self.conn.execute(stmt) + + if res.rowcount != len(pilot_stamps): + raise PilotNotFoundError(data={"pilot_stamps": str(pilot_stamps)}) + + async def associate_pilot_with_jobs(self, job_to_pilot_mapping: list[dict]): + """Associate a pilot with jobs. + + job_to_pilot_mapping format: + ```py + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + ] + ``` + + Raises: + - PilotNotFoundError if a pilot_id is not associated with a pilot. + - PilotAlreadyAssociatedWithJobError if the pilot is already associated with a job. + - NotImplementedError if the integrity error is not caught. + + **Important note**: We assume that a job exists. + + """ + # Insert multiple rows in a single execute call + stmt = insert(JobToPilotMapping).values(job_to_pilot_mapping) + + try: + await self.conn.execute(stmt) + except IntegrityError as e: + if "foreign key" in str(e.orig).lower(): + raise PilotNotFoundError( + data={"pilot_stamps": str(job_to_pilot_mapping)}, + detail="at least one of these pilots does not exist", + ) from e + + if ( + "duplicate entry" in str(e.orig).lower() + or "unique constraint" in str(e.orig).lower() + ): + raise PilotAlreadyAssociatedWithJobError( + data={"job_to_pilot_mapping": str(job_to_pilot_mapping)} + ) from e + + # Other errors to catch + raise NotImplementedError( + "Engine Specific error not caught" + str(e) + ) from e + + async def update_pilot_fields_bulk( + self, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] + ): + """Bulk update pilots with a mapping. + + pilot_stamps_to_fields_mapping format: + ```py + [ + { + "PilotStamp": pilot_stamp, + "BenchMark": bench_mark, + "StatusReason": pilot_reason, + "AccountingSent": accounting_sent, + "Status": status, + "CurrentJobID": current_job_id, + "Queue": queue, + ... + } + ] + ``` + + The mapping helps to update multiple fields at a time. + + Raises PilotNotFoundError if one of the pilots is not found. 
+ """ + stmt = ( + update(PilotAgents) + .where(PilotAgents.pilot_stamp == bindparam("b_pilot_stamp")) + .values( + { + key: bindparam(key) + for key in pilot_stamps_to_fields_mapping[0] + .model_dump(exclude_none=True) + .keys() + if key != "PilotStamp" + } + ) + ) + + values = [ + { + **{"b_pilot_stamp": mapping.PilotStamp}, + **mapping.model_dump(exclude={"PilotStamp"}, exclude_none=True), + } + for mapping in pilot_stamps_to_fields_mapping + ] + + res = await self.conn.execute(stmt, values) + + if res.rowcount != len(pilot_stamps_to_fields_mapping): + raise PilotNotFoundError( + data={"mapping": str(pilot_stamps_to_fields_mapping)} + ) + + async def get_pilots_by_stamp_bulk( + self, pilot_stamps: list[str] + ) -> Sequence[RowMapping]: + """Bulk fetch pilots. + + Raises PilotNotFoundError if one of the stamp is not associated with a pilot. + + """ + results = await fetch_records_bulk_or_raises( + self.conn, + PilotAgents, + PilotNotFoundError, + "pilot_stamp", + "PilotStamp", + pilot_stamps, + allow_no_result=True, + ) + + # Custom handling, to see which pilot_stamp does not exist (if so, say which one) + found_keys = {row["PilotStamp"] for row in results} + missing = set(pilot_stamps) - found_keys + + if missing: + raise PilotNotFoundError( + data={"pilot_stamp": str(missing)}, + detail=str(missing), + non_existing_pilots=missing, + ) + + return results + + async def get_pilot_jobs_ids_by_pilot_id(self, pilot_id: int) -> list[int]: + """Fetch pilot jobs.""" + job_to_pilot_mapping = await fetch_records_bulk_or_raises( + self.conn, + JobToPilotMapping, + PilotJobsNotFoundError, + "pilot_id", + "PilotID", + [pilot_id], + allow_more_than_one_result_per_input=True, + allow_no_result=True, + ) + + return [mapping["JobID"] for mapping in job_to_pilot_mapping] + + async def get_pilot_ids_by_stamps(self, pilot_stamps: list[str]) -> list[int]: + """Get pilot ids.""" + # This function is currently needed while we are relying on pilot_ids instead of pilot_stamps + # (Ex: JobToPilotMapping) + pilots = await self.get_pilots_by_stamp_bulk(pilot_stamps) + + return [pilot["PilotID"] for pilot in pilots] + + async def search( + self, + parameters: list[str] | None, + search: list[SearchSpec], + sorts: list[SortSpec], + *, + distinct: bool = False, + per_page: int = 100, + page: int | None = None, + ) -> tuple[int, list[dict[Any, Any]]]: + """Search for pilots in the database.""" + # TODO: Refactorize with the search function for jobs. 
+ # Find which columns to select + columns = _get_columns(PilotAgents.__table__, parameters) + + stmt = select(*columns) + + stmt = apply_search_filters( + PilotAgents.__table__.columns.__getitem__, stmt, search + ) + stmt = apply_sort_constraints( + PilotAgents.__table__.columns.__getitem__, stmt, sorts + ) + + if distinct: + stmt = stmt.distinct() + + # Calculate total count before applying pagination + total_count_subquery = stmt.alias() + total_count_stmt = select(func.count()).select_from(total_count_subquery) + total = (await self.conn.execute(total_count_stmt)).scalar_one() + + # Apply pagination + if page is not None: + if page < 1: + raise InvalidQueryError("Page must be a positive integer") + if per_page < 1: + raise InvalidQueryError("Per page must be a positive integer") + stmt = stmt.offset((page - 1) * per_page).limit(per_page) + + # Execute the query + return total, [ + dict(row._mapping) async for row in (await self.conn.stream(stmt)) + ] + + async def clear_pilots_bulk( + self, cutoff_date: datetime, delete_only_aborted: bool + ) -> int: + """Bulk delete pilots that have SubmissionTime before the 'cutoff_date'. + Returns the number of deletion. + """ + # TODO: Add test (Millisec?) + stmt = delete(PilotAgents).where(PilotAgents.submission_time < cutoff_date) + + # If delete_only_aborted is True, add the condition for 'Status' being 'Aborted' + if delete_only_aborted: + stmt = stmt.where(PilotAgents.status == "Aborted") + + # Execute the statement + res = await self.conn.execute(stmt) + + return res.rowcount diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilots/schema.py similarity index 98% rename from diracx-db/src/diracx/db/sql/pilot_agents/schema.py rename to diracx-db/src/diracx/db/sql/pilots/schema.py index bff7c460c..032e36510 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py +++ b/diracx-db/src/diracx/db/sql/pilots/schema.py @@ -37,6 +37,7 @@ class PilotAgents(PilotAgentsDBBase): __table_args__ = ( Index("PilotJobReference", "PilotJobReference"), + Index("PilotStamp", "PilotStamp"), Index("Status", "Status"), Index("Statuskey", "GridSite", "DestinationSite", "Status"), ) diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py index 69b78b4bf..e3d0747a1 100644 --- a/diracx-db/src/diracx/db/sql/utils/__init__.py +++ b/diracx-db/src/diracx/db/sql/utils/__init__.py @@ -6,20 +6,28 @@ apply_search_filters, apply_sort_constraints, ) -from .functions import hash, substract_date, utcnow +from .functions import ( + _get_columns, + fetch_records_bulk_or_raises, + hash, + substract_date, + utcnow, +) from .types import Column, DateNowColumn, EnumBackedBool, EnumColumn, NullColumn __all__ = ( - "utcnow", + "_get_columns", + "apply_search_filters", + "apply_sort_constraints", + "BaseSQLDB", "Column", - "NullColumn", "DateNowColumn", - "BaseSQLDB", "EnumBackedBool", "EnumColumn", - "apply_search_filters", - "apply_sort_constraints", - "substract_date", + "fetch_records_bulk_or_raises", "hash", + "NullColumn", + "substract_date", "SQLDBUnavailableError", + "utcnow", ) diff --git a/diracx-db/src/diracx/db/sql/utils/functions.py b/diracx-db/src/diracx/db/sql/utils/functions.py index 34cb2a0da..536412406 100644 --- a/diracx-db/src/diracx/db/sql/utils/functions.py +++ b/diracx-db/src/diracx/db/sql/utils/functions.py @@ -2,16 +2,30 @@ import hashlib from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, 
Sequence, Type -from sqlalchemy import DateTime, func +from sqlalchemy import DateTime, RowMapping, asc, desc, func, select +from sqlalchemy.ext.asyncio import AsyncConnection from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql import expression +from sqlalchemy.sql import ColumnElement, expression + +from diracx.core.exceptions import DiracFormattedError, InvalidQueryError if TYPE_CHECKING: from sqlalchemy.types import TypeEngine +def _get_columns(table, parameters): + columns = [x for x in table.columns] + if parameters: + if unrecognised_parameters := set(parameters) - set(table.columns.keys()): + raise InvalidQueryError( + f"Unrecognised parameters requested {unrecognised_parameters}" + ) + columns = [c for c in columns if c.name in parameters] + return columns + + class utcnow(expression.FunctionElement): # noqa: N801 type: TypeEngine = DateTime() inherit_cache: bool = True @@ -140,3 +154,73 @@ def substract_date(**kwargs: float) -> datetime: def hash(code: str): return hashlib.sha256(code.encode()).hexdigest() + + +def raw_hash(code: str): + return hashlib.sha256(code.encode()).digest() + + +async def fetch_records_bulk_or_raises( + conn: AsyncConnection, + model: Any, # Here, we currently must use `Any` because `declarative_base()` returns any + missing_elements_error_cls: Type[DiracFormattedError], + column_attribute_name: str, + column_name: str, + elements_to_fetch: list, + order_by: tuple[str, str] | None = None, + allow_more_than_one_result_per_input: bool = False, + allow_no_result: bool = False, +) -> Sequence[RowMapping]: + """Fetches a list of elements in a table, returns a list of elements. + All elements from the `element_to_fetch` **must** be present. + Raises the specified error if at least one is missing. + + Example: + fetch_records_bulk_or_raises( + self.conn, + PilotAgents, + PilotNotFound, + "pilot_id", + "PilotID", + [1,2,3] + ) + + """ + assert elements_to_fetch + + # Get the column that needs to be in elements_to_fetch + column = getattr(model, column_attribute_name) + + # Create the request + stmt = select(model).with_for_update().where(column.in_(elements_to_fetch)) + + if order_by: + column_name_to_order_by, direction = order_by + column_to_order_by = getattr(model, column_name_to_order_by) + + operator: ColumnElement = ( + asc(column_to_order_by) if direction == "asc" else desc(column_to_order_by) + ) + + stmt = stmt.order_by(operator) + + # Transform into dictionaries + raw_results = await conn.execute(stmt) + results = raw_results.mappings().all() + + # Detects duplicates + if not allow_more_than_one_result_per_input: + if len(results) > len(elements_to_fetch): + raise RuntimeError("Seems to have duplicates in the database.") + + if not allow_no_result: + # Checks if we have every elements we wanted + found_keys = {row[column_name] for row in results} + missing = set(elements_to_fetch) - found_keys + + if missing: + raise missing_elements_error_cls( + data={column_name: str(missing)}, detail=str(missing) + ) + + return results diff --git a/diracx-db/tests/pilot_agents/test_pilot_agents_db.py b/diracx-db/tests/pilot_agents/test_pilot_agents_db.py deleted file mode 100644 index 3ca989885..000000000 --- a/diracx-db/tests/pilot_agents/test_pilot_agents_db.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -import pytest - -from diracx.db.sql.pilot_agents.db import PilotAgentsDB - - -@pytest.fixture -async def pilot_agents_db(tmp_path) -> PilotAgentsDB: - agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") - async with 
agents_db.engine_context(): - async with agents_db.engine.begin() as conn: - await conn.run_sync(agents_db.metadata.create_all) - yield agents_db - - -async def test_insert_and_select(pilot_agents_db: PilotAgentsDB): - async with pilot_agents_db as pilot_agents_db: - # Add a pilot reference - refs = [f"ref_{i}" for i in range(10)] - stamps = [f"stamp_{i}" for i in range(10)] - stamp_dict = dict(zip(refs, stamps)) - - await pilot_agents_db.add_pilot_references( - refs, "test_vo", grid_type="DIRAC", pilot_stamps=stamp_dict - ) - - await pilot_agents_db.add_pilot_references( - refs, "test_vo", grid_type="DIRAC", pilot_stamps=None - ) diff --git a/diracx-db/tests/pilot_agents/__init__.py b/diracx-db/tests/pilots/__init__.py similarity index 100% rename from diracx-db/tests/pilot_agents/__init__.py rename to diracx-db/tests/pilots/__init__.py diff --git a/diracx-db/tests/pilots/test_pilot_management.py b/diracx-db/tests/pilots/test_pilot_management.py new file mode 100644 index 000000000..18fa1119c --- /dev/null +++ b/diracx-db/tests/pilots/test_pilot_management.py @@ -0,0 +1,405 @@ +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest +from sqlalchemy.sql import update + +from diracx.core.exceptions import ( + PilotAlreadyAssociatedWithJobError, + PilotNotFoundError, +) +from diracx.core.models import PilotFieldsMapping +from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.db.sql.pilots.schema import PilotAgents + +MAIN_VO = "lhcb" +N = 100 + + +@pytest.fixture +async def pilot_db(tmp_path): + agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") + async with agents_db.engine_context(): + async with agents_db.engine.begin() as conn: + await conn.run_sync(agents_db.metadata.create_all) + yield agents_db + + +@pytest.fixture +async def add_stamps(pilot_db): + async def _add_stamps(start_n=0): + async with pilot_db as db: + # Add pilots + refs = [f"ref_{i}" for i in range(start_n, start_n + N)] + stamps = [f"stamp_{i}" for i in range(start_n, start_n + N)] + pilot_references = dict(zip(stamps, refs)) + + vo = MAIN_VO + + await db.add_pilots_bulk( + stamps, vo, grid_type="DIRAC", pilot_references=pilot_references + ) + + pilots = await db.get_pilots_by_stamp_bulk(stamps) + + return pilots + + return _add_stamps + + +@pytest.fixture +async def create_timed_pilots(pilot_db, add_stamps): + async def _create_timed_pilots( + old_date: datetime, aborted: bool = False, start_n=0 + ): + # Get pilots + pilots = await add_stamps(start_n) + + async with pilot_db as db: + # Update manually their age + # Collect PilotStamps + pilot_stamps = [pilot["PilotStamp"] for pilot in pilots] + + stmt = ( + update(PilotAgents) + .where(PilotAgents.pilot_stamp.in_(pilot_stamps)) + .values(SubmissionTime=old_date) + ) + + if aborted: + stmt = stmt.values(Status="Aborted") + + res = await db.conn.execute(stmt) + assert res.rowcount == len(pilot_stamps) + + pilots = await db.get_pilots_by_stamp_bulk(pilot_stamps) + return pilots + + return _create_timed_pilots + + +@pytest.fixture +async def create_old_pilots_environment(pilot_db, create_timed_pilots): + non_aborted_recent = await create_timed_pilots( + datetime(2025, 1, 1, tzinfo=timezone.utc), False, N + ) + aborted_recent = await create_timed_pilots( + datetime(2025, 1, 1, tzinfo=timezone.utc), True, 2 * N + ) + + aborted_very_old = await create_timed_pilots( + datetime(2003, 3, 10, tzinfo=timezone.utc), True, 3 * N + ) + non_aborted_very_old = await create_timed_pilots( + datetime(2003, 3, 10, tzinfo=timezone.utc), 
False, 4 * N + ) + + pilot_number = 4 * N + + assert pilot_number == ( + len(non_aborted_recent) + + len(aborted_recent) + + len(aborted_very_old) + + len(non_aborted_very_old) + ) + + # Phase 0. Verify that we have the right environment + async with pilot_db as pilot_db: + # Ensure that we can get every pilot (only get first of each group) + await pilot_db.get_pilots_by_stamp_bulk([non_aborted_recent[0]["PilotStamp"]]) + await pilot_db.get_pilots_by_stamp_bulk([aborted_recent[0]["PilotStamp"]]) + await pilot_db.get_pilots_by_stamp_bulk([aborted_very_old[0]["PilotStamp"]]) + await pilot_db.get_pilots_by_stamp_bulk([non_aborted_very_old[0]["PilotStamp"]]) + + return non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old + + +@pytest.mark.asyncio +async def test_insert_and_select(pilot_db: PilotAgentsDB): + async with pilot_db as pilot_db: + # Add pilots + refs = [f"ref_{i}" for i in range(10)] + stamps = [f"stamp_{i}" for i in range(10)] + pilot_references = dict(zip(stamps, refs)) + + await pilot_db.add_pilots_bulk( + stamps, MAIN_VO, grid_type="DIRAC", pilot_references=pilot_references + ) + + # Accept duplicates because it is checked by the logic + await pilot_db.add_pilots_bulk( + stamps, MAIN_VO, grid_type="DIRAC", pilot_references=None + ) + + +@pytest.mark.asyncio +async def test_insert_and_delete(pilot_db: PilotAgentsDB): + async with pilot_db as pilot_db: + # Add pilots + refs = [f"ref_{i}" for i in range(2)] + stamps = [f"stamp_{i}" for i in range(2)] + pilot_references = dict(zip(stamps, refs)) + + await pilot_db.add_pilots_bulk( + stamps, MAIN_VO, grid_type="DIRAC", pilot_references=pilot_references + ) + + # Works, the pilots exists + await pilot_db.get_pilots_by_stamp_bulk([stamps[0]]) + await pilot_db.get_pilots_by_stamp_bulk([stamps[0]]) + + # We delete the first pilot + await pilot_db.delete_pilots_by_stamps_bulk([stamps[0]]) + + # We get the 2nd pilot that is not delete (no error) + await pilot_db.get_pilots_by_stamp_bulk([stamps[1]]) + # We get the 1st pilot that is delete (error) + with pytest.raises(PilotNotFoundError): + await pilot_db.get_pilots_by_stamp_bulk([stamps[0]]) + + +@pytest.mark.asyncio +async def test_insert_and_delete_only_old_aborted( + pilot_db: PilotAgentsDB, create_old_pilots_environment +): + non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old = ( + create_old_pilots_environment + ) + + async with pilot_db as pilot_db: + # Delete all aborted that were born before 2020 + # Every aborted that are old may be delete + await pilot_db.clear_pilots_bulk( + datetime(2020, 1, 1, tzinfo=timezone.utc), True + ) + + # Assert who still live + for normally_exiting_pilot_list in [ + non_aborted_recent, + aborted_recent, + non_aborted_very_old, + ]: + stamps = [pilot["PilotStamp"] for pilot in normally_exiting_pilot_list] + + await pilot_db.get_pilots_by_stamp_bulk(stamps) + + # Assert who normally does not live + for normally_deleted_pilot_list in [aborted_very_old]: + stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] + + with pytest.raises(PilotNotFoundError): + await pilot_db.get_pilots_by_stamp_bulk(stamps) + + +@pytest.mark.asyncio +async def test_insert_and_delete_old( + pilot_db: PilotAgentsDB, create_old_pilots_environment +): + non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old = ( + create_old_pilots_environment + ) + + async with pilot_db as pilot_db: + # Delete all aborted that were born before 2020 + # Every aborted that are old may be delete + await 
pilot_db.clear_pilots_bulk( + datetime(2020, 1, 1, tzinfo=timezone.utc), False + ) + + # Assert who still live + for normally_exiting_pilot_list in [ + non_aborted_recent, + aborted_recent, + ]: + stamps = [pilot["PilotStamp"] for pilot in normally_exiting_pilot_list] + + await pilot_db.get_pilots_by_stamp_bulk(stamps) + + # Assert who normally does not live + for normally_deleted_pilot_list in [ + aborted_very_old, + non_aborted_very_old, + ]: + stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] + + with pytest.raises(PilotNotFoundError): + await pilot_db.get_pilots_by_stamp_bulk(stamps) + + +@pytest.mark.asyncio +async def test_insert_and_delete_recent_only_aborted( + pilot_db: PilotAgentsDB, create_old_pilots_environment +): + non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old = ( + create_old_pilots_environment + ) + + async with pilot_db as pilot_db: + # Delete all aborted that were born before 2020 + # Every aborted that are old may be delete + await pilot_db.clear_pilots_bulk( + datetime(2025, 3, 10, tzinfo=timezone.utc), True + ) + + # Assert who still live + for normally_exiting_pilot_list in [non_aborted_recent, non_aborted_very_old]: + stamps = [pilot["PilotStamp"] for pilot in normally_exiting_pilot_list] + + await pilot_db.get_pilots_by_stamp_bulk(stamps) + + # Assert who normally does not live + for normally_deleted_pilot_list in [ + aborted_very_old, + aborted_recent, + ]: + stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] + + with pytest.raises(PilotNotFoundError): + await pilot_db.get_pilots_by_stamp_bulk(stamps) + + +@pytest.mark.asyncio +async def test_insert_and_delete_recent( + pilot_db: PilotAgentsDB, create_old_pilots_environment +): + non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old = ( + create_old_pilots_environment + ) + + async with pilot_db as pilot_db: + # Delete all aborted that were born before 2020 + # Every aborted that are old may be delete + await pilot_db.clear_pilots_bulk( + datetime(2025, 3, 10, tzinfo=timezone.utc), False + ) + + # Assert who normally does not live + for normally_deleted_pilot_list in [ + aborted_very_old, + aborted_recent, + non_aborted_recent, + non_aborted_very_old, + ]: + stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] + + with pytest.raises(PilotNotFoundError): + await pilot_db.get_pilots_by_stamp_bulk(stamps) + + +@pytest.mark.asyncio +async def test_insert_and_select_single_then_modify(pilot_db: PilotAgentsDB): + async with pilot_db as pilot_db: + pilot_stamp = "stamp-test" + await pilot_db.add_pilots_bulk( + vo=MAIN_VO, + pilot_stamps=[pilot_stamp], + grid_type="grid-type", + ) + + res = await pilot_db.get_pilots_by_stamp_bulk([pilot_stamp]) + assert len(res) == 1 + pilot = res[0] + + # Assert values + assert pilot["VO"] == MAIN_VO + assert pilot["PilotStamp"] == pilot_stamp + assert pilot["GridType"] == "grid-type" + assert pilot["BenchMark"] == 0.0 + assert pilot["Status"] == "Submitted" + assert pilot["StatusReason"] == "Unknown" + assert not pilot["AccountingSent"] + + # + # Modify a pilot, then check if every change is done + # + await pilot_db.update_pilot_fields_bulk( + [ + PilotFieldsMapping( + PilotStamp=pilot_stamp, + BenchMark=1.0, + StatusReason="NewReason", + AccountingSent=True, + Status="WAITING", + ) + ] + ) + + res = await pilot_db.get_pilots_by_stamp_bulk([pilot_stamp]) + assert len(res) == 1 + pilot = res[0] + + # Set values + assert pilot["VO"] == MAIN_VO + assert pilot["PilotStamp"] == 
pilot_stamp + assert pilot["GridType"] == "grid-type" + assert pilot["BenchMark"] == 1.0 + assert pilot["Status"] == "WAITING" + assert pilot["StatusReason"] == "NewReason" + assert pilot["AccountingSent"] + + +@pytest.mark.asyncio +async def test_associate_pilot_with_job_and_get_it(pilot_db: PilotAgentsDB): + """We will proceed in few steps. + + 1. Create a pilot + 2. Verify that he is not associated with any job + 3. Associate with jobs + 4. Verify that he is associate with this job + 5. Associate with jobs that he already has and two that he has not + 6. Associate with jobs that he has not, but were involved in a crash + """ + async with pilot_db as pilot_db: + pilot_stamp = "stamp-test" + # Add pilot + await pilot_db.add_pilots_bulk( + vo=MAIN_VO, + pilot_stamps=[pilot_stamp], + grid_type="grid-type", + ) + + res = await pilot_db.get_pilots_by_stamp_bulk([pilot_stamp]) + assert len(res) == 1 + pilot = res[0] + pilot_id = pilot["PilotID"] + + # Verify that he has no jobs + assert len(await pilot_db.get_pilot_jobs_ids_by_pilot_id(pilot_id)) == 0 + + now = datetime.now(tz=timezone.utc) + + # Associate pilot with jobs + pilot_jobs = [1, 2, 3] + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for job_id in pilot_jobs + ] + await pilot_db.associate_pilot_with_jobs(job_to_pilot_mapping) + + # Verify that he has all jobs + db_jobs = await pilot_db.get_pilot_jobs_ids_by_pilot_id(pilot_id) + # We test both length and if every job is included if for any reason we have duplicates + assert all(job in db_jobs for job in pilot_jobs) + assert len(pilot_jobs) == len(db_jobs) + + # Associate pilot with a job that he already has, and one that he has not + pilot_jobs = [10, 1, 5] + with pytest.raises(PilotAlreadyAssociatedWithJobError): + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for job_id in pilot_jobs + ] + await pilot_db.associate_pilot_with_jobs(job_to_pilot_mapping) + + # Associate pilot with jobs that he has not, but was previously in an error + # To test that the rollback worked + pilot_jobs = [5, 10] + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for job_id in pilot_jobs + ] + await pilot_db.associate_pilot_with_jobs(job_to_pilot_mapping) diff --git a/diracx-db/tests/pilots/test_query.py b/diracx-db/tests/pilots/test_query.py new file mode 100644 index 000000000..e2511e169 --- /dev/null +++ b/diracx-db/tests/pilots/test_query.py @@ -0,0 +1,301 @@ +from __future__ import annotations + +import pytest + +from diracx.core.exceptions import InvalidQueryError +from diracx.core.models import ( + PilotFieldsMapping, + ScalarSearchOperator, + ScalarSearchSpec, + SortDirection, + SortSpec, + VectorSearchOperator, + VectorSearchSpec, +) +from diracx.db.sql.pilots.db import PilotAgentsDB + +MAIN_VO = "lhcb" +N = 100 + + +@pytest.fixture +async def pilot_db(tmp_path): + agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") + async with agents_db.engine_context(): + async with agents_db.engine.begin() as conn: + await conn.run_sync(agents_db.metadata.create_all) + yield agents_db + + +PILOT_REASONS = [ + "I was sick", + "I can't, I have a pony.", + "I was shopping", + "I was sleeping", +] + +PILOT_STATUSES = ["Started", "Stopped", "Waiting"] + + +@pytest.fixture +async def populated_pilot_db(pilot_db): + async with 
pilot_db as pilot_db: + # Add pilots + refs = [f"ref_{i + 1}" for i in range(N)] + stamps = [f"stamp_{i + 1}" for i in range(N)] + pilot_references = dict(zip(stamps, refs)) + + vo = MAIN_VO + + await pilot_db.add_pilots_bulk( + stamps, vo, grid_type="DIRAC", pilot_references=pilot_references + ) + + await pilot_db.get_pilots_by_stamp_bulk(stamps) + + await pilot_db.update_pilot_fields_bulk( + [ + PilotFieldsMapping( + PilotStamp=pilot_stamp, + BenchMark=i**2, + StatusReason=PILOT_REASONS[i % len(PILOT_REASONS)], + AccountingSent=True, + Status=PILOT_STATUSES[i % len(PILOT_STATUSES)], + CurrentJobID=i, + Queue=f"queue_{i}", + ) + for i, pilot_stamp in enumerate(stamps) + ] + ) + + yield pilot_db + + +async def test_search_parameters(populated_pilot_db): + """Test that we can search specific parameters for pilots in the database.""" + async with populated_pilot_db as pilot_db: + # Search a specific parameter: PilotID + total, result = await pilot_db.search(["PilotID"], [], []) + assert total == N + assert result + for r in result: + assert r.keys() == {"PilotID"} + + # Search a specific parameter: Status + total, result = await pilot_db.search(["Status"], [], []) + assert total == N + assert result + for r in result: + assert r.keys() == {"Status"} + + # Search for multiple parameters: PilotID, Status + total, result = await pilot_db.search(["PilotID", "Status"], [], []) + assert total == N + assert result + for r in result: + assert r.keys() == {"PilotID", "Status"} + + # Search for a specific parameter but use distinct: Status + total, result = await pilot_db.search(["Status"], [], [], distinct=True) + assert total == len(PILOT_STATUSES) + assert result + + # Search for a non-existent parameter: Dummy + with pytest.raises(InvalidQueryError): + total, result = await pilot_db.search(["Dummy"], [], []) + + +async def test_search_conditions(populated_pilot_db): + """Test that we can search for specific pilots in the database.""" + async with populated_pilot_db as pilot_db: + # Search a specific scalar condition: PilotID eq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=3 + ) + total, result = await pilot_db.search([], [condition], []) + assert total == 1 + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 3 + + # Search a specific scalar condition: PilotID lt 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.LESS_THAN, value=3 + ) + total, result = await pilot_db.search([], [condition], []) + assert total == 2 + assert result + assert len(result) == 2 + assert result[0]["PilotID"] == 1 + assert result[1]["PilotID"] == 2 + + # Search a specific scalar condition: PilotID neq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.NOT_EQUAL, value=3 + ) + total, result = await pilot_db.search([], [condition], []) + assert total == 99 + assert result + assert len(result) == 99 + assert all(r["PilotID"] != 3 for r in result) + + # Search a specific scalar condition: PilotID eq 5873 (does not exist) + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=5873 + ) + total, result = await pilot_db.search([], [condition], []) + assert not result + + # Search a specific vector condition: PilotID in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 3] + ) + total, result = await pilot_db.search([], [condition], []) + assert total == 3 + assert result + assert 
len(result) == 3 + assert all(r["PilotID"] in [1, 2, 3] for r in result) + + # Search a specific vector condition: PilotID in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 5873] + ) + total, result = await pilot_db.search([], [condition], []) + assert total == 2 + assert result + assert len(result) == 2 + assert all(r["PilotID"] in [1, 2] for r in result) + + # Search a specific vector condition: PilotID not in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3] + ) + total, result = await pilot_db.search([], [condition], []) + assert total == 97 + assert result + assert len(result) == 97 + assert all(r["PilotID"] not in [1, 2, 3] for r in result) + + # Search a specific vector condition: PilotID not in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", + operator=VectorSearchOperator.NOT_IN, + values=[1, 2, 5873], + ) + total, result = await pilot_db.search([], [condition], []) + assert total == 98 + assert result + assert len(result) == 98 + assert all(r["PilotID"] not in [1, 2] for r in result) + + # Search for multiple conditions based on different parameters: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotStamp", operator=ScalarSearchOperator.EQUAL, value="stamp_5" + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + total, result = await pilot_db.search([], [condition1, condition2], []) + assert total == 1 + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 5 + assert result[0]["PilotStamp"] == "stamp_5" + + # Search for multiple conditions based on the same parameter: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=70 + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + total, result = await pilot_db.search([], [condition1, condition2], []) + assert total == 0 + assert not result + + +async def test_search_sorts(populated_pilot_db): + """Test that we can search for pilots in the database and sort the results.""" + async with populated_pilot_db as pilot_db: + # Search and sort by PilotID in ascending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.ASC) + total, result = await pilot_db.search([], [], [sort]) + assert total == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == i + 1 + + # Search and sort by PilotID in descending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + total, result = await pilot_db.search([], [], [sort]) + assert total == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == N - i + + # Search and sort by PilotStamp in ascending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + total, result = await pilot_db.search([], [], [sort]) + assert total == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[2]["PilotStamp"] == "stamp_100" + assert result[12]["PilotStamp"] == "stamp_2" + + # Search and sort by PilotStamp in descending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.DESC) + total, result = await pilot_db.search([], [], [sort]) + assert total == N + assert result + # Assert that 
stamp_10 is before stamp_2 because of the lexicographical order + assert result[97]["PilotStamp"] == "stamp_100" + assert result[87]["PilotStamp"] == "stamp_2" + + # Search and sort by PilotStamp in ascending order and PilotID in descending order + sort1 = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + sort2 = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + total, result = await pilot_db.search([], [], [sort1, sort2]) + assert total == N + assert result + assert result[0]["PilotStamp"] == "stamp_1" + assert result[0]["PilotID"] == 1 + assert result[99]["PilotStamp"] == "stamp_99" + assert result[99]["PilotID"] == 99 + + +@pytest.mark.parametrize( + "per_page, page, expected_len, expected_first_id, expect_exception", + [ + (10, 1, 10, 1, None), # Page 1 + (10, 2, 10, 11, None), # Page 2 + (10, 10, 10, 91, None), # Page 10 + (50, 2, 50, 51, None), # Page 2 with 50 per page + (10, 11, 0, None, None), # Page beyond range, should return empty + (10, 0, None, None, InvalidQueryError), # Invalid page + (0, 1, None, None, InvalidQueryError), # Invalid per_page + ], +) +async def test_search_pagination( + populated_pilot_db, + per_page, + page, + expected_len, + expected_first_id, + expect_exception, +): + """Test pagination logic in pilot search.""" + async with populated_pilot_db as pilot_db: + if expect_exception: + with pytest.raises(expect_exception): + await pilot_db.search([], [], [], per_page=per_page, page=page) + else: + total, result = await pilot_db.search( + [], [], [], per_page=per_page, page=page + ) + assert total == N + if expected_len == 0: + assert not result + else: + assert result + assert len(result) == expected_len + assert result[0]["PilotID"] == expected_first_id diff --git a/diracx-logic/src/diracx/logic/jobs/query.py b/diracx-logic/src/diracx/logic/jobs/query.py index efb4b2fc5..ba3e6269b 100644 --- a/diracx-logic/src/diracx/logic/jobs/query.py +++ b/diracx-logic/src/diracx/logic/jobs/query.py @@ -5,9 +5,9 @@ from diracx.core.config.schema import Config from diracx.core.models import ( - JobSearchParams, JobSummaryParams, ScalarSearchOperator, + SearchParams, ) from diracx.db.os.job_parameters import JobParametersDB from diracx.db.sql.job.db import JobDB @@ -27,7 +27,7 @@ async def search( preferred_username: str | None, page: int = 1, per_page: int = 100, - body: JobSearchParams | None = None, + body: SearchParams | None = None, ) -> tuple[int, list[dict[str, Any]]]: """Retrieve information about jobs.""" # Apply a limit to per_page to prevent abuse of the API @@ -35,7 +35,7 @@ async def search( per_page = MAX_PER_PAGE if body is None: - body = JobSearchParams() + body = SearchParams() if query_logging_info := ("LoggingInfo" in (body.parameters or [])): if body.parameters: diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py new file mode 100644 index 000000000..f1c0ee3c8 --- /dev/null +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +from diracx.core.exceptions import PilotAlreadyExistsError, PilotNotFoundError +from diracx.core.models import PilotFieldsMapping +from diracx.db.sql import PilotAgentsDB + + +async def register_new_pilots( + pilot_db: PilotAgentsDB, + pilot_stamps: list[str], + vo: str, + grid_type: str = "Dirac", + pilot_job_references: dict[str, str] | None = None, +): + # [IMPORTANT] Check unicity of pilot references + # If a pilot already 
exists, it will undo everything and raise an error + try: + await pilot_db.get_pilots_by_stamp_bulk(pilot_stamps=pilot_stamps) + raise PilotAlreadyExistsError(data={"pilot_stamps": str(pilot_stamps)}) + except PilotNotFoundError as e: + # e.non_existing_pilots is set of the pilot that are not found + # We can compare it with the pilot references that want to add + # If both sets are the same, it means that every pilots is new, and so we can add them to the db + # If not, it means that at least one is already in the db + + non_existing_pilots = e.non_existing_pilots + pilots_that_already_exist = set(pilot_stamps) - non_existing_pilots + + if pilots_that_already_exist: + raise PilotAlreadyExistsError( + data={"pilot_stamps": str(pilots_that_already_exist)} + ) from e + + await pilot_db.add_pilots_bulk( + pilot_stamps=pilot_stamps, + vo=vo, + grid_type=grid_type, + pilot_references=pilot_job_references, + ) + + +async def clear_pilots_bulk( + pilot_db: PilotAgentsDB, age_in_days: int, delete_only_aborted: bool +): + """Delete pilots that have been submitted before interval_in_days.""" + cutoff_date = datetime.now(tz=timezone.utc) - timedelta(days=age_in_days) + + await pilot_db.clear_pilots_bulk( + cutoff_date=cutoff_date, delete_only_aborted=delete_only_aborted + ) + + +async def delete_pilots_by_stamps_bulk( + pilot_db: PilotAgentsDB, pilot_stamps: list[str] +): + await pilot_db.delete_pilots_by_stamps_bulk(pilot_stamps) + + +async def update_pilots_fields( + pilot_db: PilotAgentsDB, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] +): + await pilot_db.update_pilot_fields_bulk(pilot_stamps_to_fields_mapping) + + +async def associate_pilot_with_jobs( + pilot_db: PilotAgentsDB, pilot_stamp: str, pilot_jobs_ids: list[int] +): + pilot_ids = await pilot_db.get_pilot_ids_by_stamps([pilot_stamp]) + # Semantic assured by fetch_records_bulk_or_raises + pilot_id = pilot_ids[0] + + now = datetime.now(tz=timezone.utc) + + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for job_id in pilot_jobs_ids + ] + + await pilot_db.associate_pilot_with_jobs( + job_to_pilot_mapping=job_to_pilot_mapping, + ) + + +async def get_pilot_jobs_ids_by_stamp( + pilot_db: PilotAgentsDB, pilot_stamp: str +) -> list[int]: + """Fetch pilot jobs by stamp.""" + pilot_ids = await pilot_db.get_pilot_ids_by_stamps([pilot_stamp]) + # Semantic assured by fetch_records_bulk_or_raises + pilot_id = pilot_ids[0] + + return await pilot_db.get_pilot_jobs_ids_by_pilot_id(pilot_id) diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py new file mode 100644 index 000000000..dbdb686dc --- /dev/null +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Any + +from diracx.core.models import ScalarSearchOperator, SearchParams +from diracx.db.sql import PilotAgentsDB + +MAX_PER_PAGE = 10000 + + +async def search( + pilot_db: PilotAgentsDB, + user_vo: str, + page: int = 1, + per_page: int = 100, + body: SearchParams | None = None, +) -> tuple[int, list[dict[str, Any]]]: + """Retrieve information about jobs.""" + # Apply a limit to per_page to prevent abuse of the API + if per_page > MAX_PER_PAGE: + per_page = MAX_PER_PAGE + + if body is None: + body = SearchParams() + + body.search.append( + {"parameter": "VO", "operator": ScalarSearchOperator.EQUAL, "value": user_vo} + ) + + total, pilots = await pilot_db.search( + 
body.parameters, + body.search, + body.sort, + distinct=body.distinct, + page=page, + per_page=per_page, + ) + + return total, pilots diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index 6f554c74e..2038223ce 100644 --- a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -46,10 +46,12 @@ auth = "diracx.routers.auth:router" config = "diracx.routers.configuration:router" health = "diracx.routers.health:router" jobs = "diracx.routers.jobs:router" +pilots = "diracx.routers.pilots:router" [project.entry-points."diracx.access_policies"] WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy" SandboxAccessPolicy = "diracx.routers.jobs.access_policies:SandboxAccessPolicy" +PilotManagementAccessPolicy = "diracx.routers.pilots.access_policies:PilotManagementAccessPolicy" # Minimum version of the client supported [project.entry-points."diracx.min_client_version"] diff --git a/diracx-routers/src/diracx/routers/jobs/query.py b/diracx-routers/src/diracx/routers/jobs/query.py index a8667b7dd..f2f8dd323 100644 --- a/diracx-routers/src/diracx/routers/jobs/query.py +++ b/diracx-routers/src/diracx/routers/jobs/query.py @@ -6,8 +6,8 @@ from fastapi import Body, Depends, Response from diracx.core.models import ( - JobSearchParams, JobSummaryParams, + SearchParams, ) from diracx.core.properties import JOB_ADMINISTRATOR from diracx.logic.jobs.query import search as search_bl @@ -135,7 +135,7 @@ async def search( page: int = 1, per_page: int = 100, body: Annotated[ - JobSearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES) + SearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES) ] = None, ) -> list[dict[str, Any]]: """Retrieve information about jobs. diff --git a/diracx-routers/src/diracx/routers/pilots/__init__.py b/diracx-routers/src/diracx/routers/pilots/__init__.py new file mode 100644 index 000000000..03f9b8422 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/__init__.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import logging + +from ..fastapi_classes import DiracxRouter +from .management import router as management_router +from .query import router as query_router + +logger = logging.getLogger(__name__) + +router = DiracxRouter() +router.include_router(management_router) +router.include_router(query_router) diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py new file mode 100644 index 000000000..02d6de0e8 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from collections.abc import Callable +from enum import StrEnum, auto +from typing import Annotated + +from fastapi import Depends, HTTPException, status + +from diracx.core.exceptions import PilotNotFoundError +from diracx.core.properties import NORMAL_USER, TRUSTED_HOST +from diracx.db.sql import PilotAgentsDB +from diracx.routers.access_policies import BaseAccessPolicy +from diracx.routers.utils.users import AuthorizedUserInfo + + +class ActionType(StrEnum): + # Create a pilot + CREATE_PILOT = auto() + # Change some pilot fields + CHANGE_PILOT_FIELD = auto() + # Read some pilot info + READ_PILOT_FIELDS = auto() + + +class PilotManagementAccessPolicy(BaseAccessPolicy): + """Rules: + * You need either NORMAL_USER in your properties + * A NORMAL_USER can create a pilot. 
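    As an illustration of how endpoints are expected to call this policy (sketched
    from the routes later in this patch), it receives either an explicit VO, or pilot
    stamps together with the DB so that the VO can be resolved from the pilots:

    ```py
    await check_permissions(action=ActionType.CREATE_PILOT, vo=vo)
    await check_permissions(
        action=ActionType.CHANGE_PILOT_FIELD, pilot_stamps=pilot_stamps, pilot_db=pilot_db
    )
    ```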
+ """ + + @staticmethod + async def policy( + policy_name: str, + user_info: AuthorizedUserInfo, + /, + *, + pilot_db: PilotAgentsDB | None = None, + pilot_stamps: list[str] | None = None, + vo: str | None = None, + action: ActionType | None = None, + ): + assert action, "action is a mandatory parameter" + + if action == ActionType.READ_PILOT_FIELDS: + if NORMAL_USER in user_info.properties: + return + + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You have to be logged on to see pilots.", + ) + + if not vo: + assert pilot_stamps and pilot_db, ( + "if vo is not provided, " + "pilot_stamp and pilot_db are mandatory to determine the vo" + ) + + try: + pilots = await pilot_db.get_pilots_by_stamp_bulk(pilot_stamps) + except PilotNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The given stamp is not associated with a pilot", + ) from e + + # Semantic assured by get_pilots_by_stamp_bulk + first_vo = pilots[0]["VO"] + + if not all(pilot["VO"] == first_vo for pilot in pilots): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You gave pilots with different VOs.", + ) + + vo = first_vo + + if not vo == user_info.vo: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the right VO for this resource.", + ) + + if NORMAL_USER not in user_info.properties: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the rights to create pilots.", + ) + + if action == ActionType.CREATE_PILOT: + return + + if action == ActionType.CHANGE_PILOT_FIELD: + return + + raise ValueError("Unknown action.") + + +CheckPilotManagementPolicyCallable = Annotated[ + Callable, Depends(PilotManagementAccessPolicy.check) +] + + +class DiracServicesAccessPolicy(BaseAccessPolicy): + """This access policy is used by DIRAC services (ex: Matcher).""" + + @staticmethod + async def policy(policy_name: str, user_info: AuthorizedUserInfo): + if TRUSTED_HOST in user_info.properties: + return + + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="This endpoint is reserved only for DIRAC services.", + ) + + +CheckDiracServicesPolicyCallable = Annotated[ + Callable, Depends(DiracServicesAccessPolicy.check) +] diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py new file mode 100644 index 000000000..2d959e378 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import logging +from http import HTTPStatus +from typing import Annotated + +from fastapi import Body, Depends, HTTPException, Query, status + +from diracx.core.exceptions import ( + PilotAlreadyAssociatedWithJobError, + PilotAlreadyExistsError, + PilotNotFoundError, +) +from diracx.core.models import ( + PilotFieldsMapping, +) +from diracx.logic.pilots.management import ( + associate_pilot_with_jobs as associate_pilot_with_jobs_bl, +) +from diracx.logic.pilots.management import ( + clear_pilots_bulk, + delete_pilots_by_stamps_bulk, + register_new_pilots, + update_pilots_fields, +) + +from ..dependencies import PilotAgentsDB +from ..fastapi_classes import DiracxRouter +from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token +from .access_policies import ( + ActionType, + CheckDiracServicesPolicyCallable, + CheckPilotManagementPolicyCallable, +) + +router = DiracxRouter() + +logger = logging.getLogger(__name__) + + 
+@router.post("/management/pilot") +async def add_pilot_stamps( + pilot_db: PilotAgentsDB, + pilot_stamps: Annotated[ + list[str], + Body(description="List of the pilot stamps we want to add to the db."), + ], + vo: Annotated[ + str, + Body(description="Virtual Organisation associated with the inserted pilots."), + ], + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: CheckPilotManagementPolicyCallable, + grid_type: Annotated[str, Body(description="Grid type of the pilots.")] = "Dirac", + pilot_references: Annotated[ + dict[str, str] | None, + Body(description="Association of a pilot reference with a pilot stamp."), + ] = None, +): + """Endpoint where you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + """ + await check_permissions(action=ActionType.CREATE_PILOT, vo=vo) + + try: + await register_new_pilots( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + vo=vo, + grid_type=grid_type, + pilot_job_references=pilot_references, + ) + + # Log the pilot creation + logger.debug( + f"{user_info.preferred_username} added {len(pilot_stamps)} pilots." + ) + except PilotAlreadyExistsError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e + + +@router.delete("/management/pilot", status_code=HTTPStatus.NO_CONTENT) +async def delete_pilots( + pilot_stamps: Annotated[ + list[str], Query(description="Stamps of the pilots we want to delete.") + ], + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, +): + """Endpoint to delete pilots. + + If at least one pilot is not found, the whole deletion is rolled back. + """ + await check_permissions( + action=ActionType.CHANGE_PILOT_FIELD, + pilot_stamps=pilot_stamps, + pilot_db=pilot_db, + ) + + try: + await delete_pilots_by_stamps_bulk(pilot_db=pilot_db, pilot_stamps=pilot_stamps) + except PilotNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="At least one pilot has not been found.", + ) from e + + +@router.delete("/management/pilot/interval", status_code=HTTPStatus.NO_CONTENT) +async def clear_pilots( + pilot_db: PilotAgentsDB, + age_in_days: Annotated[ + int, + Query( + description=( + "The number of days that define the maximum age of pilots to be deleted. " + "Pilots older than this age will be considered for deletion." + ) + ), + ], + check_permissions: CheckDiracServicesPolicyCallable, + delete_only_aborted: Annotated[ + bool, + Query( + description=( + "Flag indicating whether to only delete pilots whose status is 'Aborted'. " + "If set to True, only pilots with the 'Aborted' status will be deleted. " + "It is set by default as True to avoid any mistake."
+ ) + ), + ] = True, +): + """Endpoint for DIRAC to delete all pilots that lived more than age_in_days.""" + await check_permissions() + + if age_in_days < 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="age_in_days must be non-negative.", + ) + + await clear_pilots_bulk( + pilot_db=pilot_db, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + ) + + +EXAMPLE_UPDATE_FIELDS = { + "Update the BenchMark field": { + "summary": "Update BenchMark", + "description": "Update only the BenchMark for one pilot.", + "value": { + "pilot_stamps_to_fields_mapping": [ + {"PilotStamp": "the_pilot_stamp", "BenchMark": 1.0} + ] + }, + }, + "Update multiple statuses": { + "summary": "Update multiple pilots", + "description": "Update the statuses of multiple pilots.", + "value": { + "pilot_stamps_to_fields_mapping": [ + {"PilotStamp": "the_first_pilot_stamp", "Status": "Waiting"}, + {"PilotStamp": "the_second_pilot_stamp", "Status": "Waiting"}, + ] + }, + }, +} + + +@router.patch("/management/pilot", status_code=HTTPStatus.NO_CONTENT) +async def update_pilot_fields( + pilot_stamps_to_fields_mapping: Annotated[ + list[PilotFieldsMapping], + Body( + description="(pilot_stamp, pilot_fields) mapping to change.", + embed=True, + openapi_examples=EXAMPLE_UPDATE_FIELDS, + ), + ], + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, +): + """Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + """ + pilot_stamps = [mapping.PilotStamp for mapping in pilot_stamps_to_fields_mapping] + + # Ensure the stamps are valid + await check_permissions( + action=ActionType.CHANGE_PILOT_FIELD, + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + ) + + await update_pilots_fields( + pilot_db=pilot_db, + pilot_stamps_to_fields_mapping=pilot_stamps_to_fields_mapping, + ) + + +@router.patch("/management/jobs", status_code=HTTPStatus.NO_CONTENT) +async def associate_pilot_with_jobs( + pilot_db: PilotAgentsDB, + pilot_stamp: Annotated[str, Body(description="The stamp of the pilot.")], + pilot_jobs_ids: Annotated[ + list[int], Body(description="The jobs we want to add to the pilot.") + ], + check_permissions: CheckDiracServicesPolicyCallable, +): + """Endpoint only for DIRAC services, to associate a pilot with a job.""" + await check_permissions() + + try: + await associate_pilot_with_jobs_bl( + pilot_db=pilot_db, + pilot_stamp=pilot_stamp, + pilot_jobs_ids=pilot_jobs_ids, + ) + except PilotNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="This pilot does not exist."
+ ) from e + except PilotAlreadyAssociatedWithJobError as e: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="This pilot is already associated with this job.", + ) from e diff --git a/diracx-routers/src/diracx/routers/pilots/query.py b/diracx-routers/src/diracx/routers/pilots/query.py new file mode 100644 index 000000000..f494ce1f6 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/query.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from http import HTTPStatus +from typing import Annotated, Any + +from fastapi import Body, Depends, Response + +from diracx.core.models import SearchParams +from diracx.logic.pilots.query import search as search_bl + +from ..dependencies import PilotAgentsDB +from ..fastapi_classes import DiracxRouter +from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token +from .access_policies import ( + ActionType, + CheckPilotManagementPolicyCallable, +) + +router = DiracxRouter() + +EXAMPLE_SEARCHES = { + "Show all": { + "summary": "Show all", + "description": "Shows all pilots the current user has access to.", + "value": {}, + }, + "A specific pilot": { + "summary": "A specific pilot", + "description": "Search for a specific pilot by ID", + "value": {"search": [{"parameter": "PilotID", "operator": "eq", "value": "5"}]}, + }, + "Get ordered pilot statuses": { + "summary": "Get ordered pilot statuses", + "description": "Get only pilot statuses for specific pilots, ordered by status", + "value": { + "parameters": ["PilotID", "Status"], + "search": [ + {"parameter": "PilotID", "operator": "in", "values": ["6", "2", "3"]} + ], + "sort": [{"parameter": "PilotID", "direction": "asc"}], + }, + }, +} + + +EXAMPLE_RESPONSES: dict[int | str, dict[str, Any]] = { + 200: { + "description": "List of matching results", + "content": { + "application/json": { + "example": [ + { + "PilotID": 3, + "SubmissionTime": "2023-05-25T07:03:35.602654", + "LastUpdateTime": "2023-05-25T07:03:35.602656", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 1.0, + }, + { + "PilotID": 5, + "SubmissionTime": "2023-06-25T07:03:35.602654", + "LastUpdateTime": "2023-07-25T07:03:35.602652", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 63.1, + }, + ] + } + }, + }, + 206: { + "description": "Partial Content. 
Only a part of the requested range could be served.", + "headers": { + "Content-Range": { + "description": "The range of pilots returned in this response", + "schema": {"type": "string", "example": "pilots 0-1/4"}, + } + }, + "model": list[dict[str, Any]], + "content": { + "application/json": { + "example": [ + { + "PilotID": 3, + "SubmissionTime": "2023-05-25T07:03:35.602654", + "LastUpdateTime": "2023-05-25T07:03:35.602656", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 1.0, + }, + { + "PilotID": 5, + "SubmissionTime": "2023-06-25T07:03:35.602654", + "LastUpdateTime": "2023-07-25T07:03:35.602652", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 63.1, + }, + ] + } + }, + }, +} + + +@router.post("/management/search", responses=EXAMPLE_RESPONSES) +async def search( + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, + response: Response, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + page: int = 1, + per_page: int = 100, + body: Annotated[ + SearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES) + ] = None, +) -> list[dict[str, Any]]: + """Retrieve information about pilots.""" + # Inspired by /api/jobs/query + await check_permissions(action=ActionType.READ_PILOT_FIELDS) + + user_vo = user_info.vo + + total, pilots = await search_bl( + pilot_db=pilot_db, + user_vo=user_vo, + page=page, + per_page=per_page, + body=body, + ) + + # Set the Content-Range header if needed + # https://datatracker.ietf.org/doc/html/rfc7233#section-4 + + # No pilots found but there are pilots for the requested search + # https://datatracker.ietf.org/doc/html/rfc7233#section-4.4 + if len(pilots) == 0 and total > 0: + response.headers["Content-Range"] = f"pilots */{total}" + response.status_code = HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE + + # The total number of pilots is greater than the number of pilots returned + # https://datatracker.ietf.org/doc/html/rfc7233#section-4.2 + elif len(pilots) < total: + first_idx = per_page * (page - 1) + last_idx = min(first_idx + len(pilots), total) - 1 if total > 0 else 0 + response.headers["Content-Range"] = f"pilots {first_idx}-{last_idx}/{total}" + response.status_code = HTTPStatus.PARTIAL_CONTENT + return pilots diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py new file mode 100644 index 000000000..b4e5d2eec --- /dev/null +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import pytest + +from diracx.core.models import PilotFieldsMapping + +pytestmark = pytest.mark.enabled_dependencies( + [ + "PilotCredentialsAccessPolicy", + "DevelopmentSettings", + "AuthDB", + "AuthSettings", + "ConfigSource", + "BaseAccessPolicy", + "PilotAgentsDB", + "PilotManagementAccessPolicy", + ] +) + +MAIN_VO = "lhcb" +N = 100 + + +@pytest.fixture +def test_client(client_factory): + with client_factory.unauthenticated() as client: + yield client + + +@pytest.fixture +def normal_test_client(client_factory): + with client_factory.normal_user() as client: + yield client + + +async def test_create_pilots(normal_test_client): + # Lots of request, to validate that it returns the credentials in the same order as the input references + pilot_stamps = [f"stamps_{i}" for i in range(N)] + + # -------------- Bulk insert -------------- + body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} + + r = normal_test_client.post( + "/api/pilots/management/pilot", + json=body, + ) + + assert 
r.status_code == 200, r.json() + + # -------------- Register a pilot that already exists, and one that does not -------------- + + body = { + "vo": MAIN_VO, + "pilot_stamps": [pilot_stamps[0], pilot_stamps[0] + "_new_one"], + } + + r = normal_test_client.post( + "/api/pilots/management/pilot", + json=body, + headers={ + "Content-Type": "application/json", + }, + ) + + assert r.status_code == 409 + assert ( + r.json()["detail"] + == f"Pilot (pilot_stamps: {{'{pilot_stamps[0]}'}}) already exists" + ) + + # -------------- Register a pilot that does not exist **but** was part of a failed call before -------------- + # This proves that, if we tried to register a new pilot together with one that already exists, + # we can still add the new one afterwards (the failed call must not have inserted it) + body = {"vo": MAIN_VO, "pilot_stamps": [pilot_stamps[0] + "_new_one"]} + + r = normal_test_client.post( + "/api/pilots/management/pilot", + json=body, + headers={ + "Content-Type": "application/json", + }, + ) + + assert r.status_code == 200 + + +async def test_create_pilot_and_delete_it(normal_test_client): + pilot_stamp = "stamps_1" + + # -------------- Insert -------------- + body = {"vo": MAIN_VO, "pilot_stamps": [pilot_stamp]} + + # Create a pilot + r = normal_test_client.post( + "/api/pilots/management/pilot", + json=body, + ) + + assert r.status_code == 200, r.json() + + # -------------- Duplicate -------------- + # Duplicate because it already exists, so we should get a 409 + r = normal_test_client.post( + "/api/pilots/management/pilot", + json=body, + ) + + assert r.status_code == 409, r.json() + + # -------------- Delete -------------- + params = {"pilot_stamps": [pilot_stamp]} + + # We delete the pilot + r = normal_test_client.delete( + "/api/pilots/management/pilot", + params=params, + ) + + assert r.status_code == 204 + + # -------------- Insert -------------- + # Create the same pilot again; this works because it does not exist anymore + r = normal_test_client.post( + "/api/pilots/management/pilot", + json=body, + ) + + assert r.status_code == 200, r.json() + + +async def test_create_pilot_and_modify_it(normal_test_client): + pilot_stamps = ["stamps_1", "stamp_2"] + + # -------------- Insert -------------- + body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} + + # Create pilots + r = normal_test_client.post( + "/api/pilots/management/pilot", + json=body, + ) + + assert r.status_code == 200, r.json() + + # -------------- Modify -------------- + # We modify only the first pilot + body = { + "pilot_stamps_to_fields_mapping": [ + PilotFieldsMapping( + PilotStamp=pilot_stamps[0], + BenchMark=1.0, + StatusReason="NewReason", + AccountingSent=True, + Status="Waiting", + ).model_dump(exclude_unset=True) + ] + } + + r = normal_test_client.patch("/api/pilots/management/pilot", json=body) + + assert r.status_code == 204 + + body = { + "parameters": [], + "search": [], + "sort": [], + "distinct": True, + } + + r = normal_test_client.post("/api/pilots/management/search", json=body) + assert r.status_code == 200, r.json() + pilot1 = r.json()[0] + pilot2 = r.json()[1] + + assert pilot1["BenchMark"] == 1.0 + assert pilot1["StatusReason"] == "NewReason" + assert pilot1["AccountingSent"] + assert pilot1["Status"] == "Waiting" + + assert pilot2["BenchMark"] != pilot1["BenchMark"] + assert pilot2["StatusReason"] != pilot1["StatusReason"] + assert pilot2["AccountingSent"] != pilot1["AccountingSent"] + assert pilot2["Status"] != pilot1["Status"] diff --git a/diracx-routers/tests/pilots/test_query.py
b/diracx-routers/tests/pilots/test_query.py new file mode 100644 index 000000000..fe92ef0b9 --- /dev/null +++ b/diracx-routers/tests/pilots/test_query.py @@ -0,0 +1,372 @@ +"""Inspired by pilots and jobs db search tests.""" + +from __future__ import annotations + +import pytest + +from diracx.core.exceptions import InvalidQueryError +from diracx.core.models import ( + PilotFieldsMapping, + ScalarSearchOperator, + ScalarSearchSpec, + SortDirection, + SortSpec, + VectorSearchOperator, + VectorSearchSpec, +) + +pytestmark = pytest.mark.enabled_dependencies( + [ + "AuthSettings", + "ConfigSource", + "DevelopmentSettings", + "PilotAgentsDB", + "PilotManagementAccessPolicy", + ] +) + + +@pytest.fixture +def normal_test_client(client_factory): + with client_factory.normal_user() as client: + yield client + + +MAIN_VO = "lhcb" +N = 100 + +PILOT_REASONS = [ + "I was sick", + "I can't, I have a pony.", + "I was shopping", + "I was sleeping", +] + +PILOT_STATUSES = ["Started", "Stopped", "Waiting"] + + +@pytest.fixture +async def populated_pilot_client(normal_test_client): + pilot_stamps = [f"stamp_{i}" for i in range(1, N + 1)] + + # -------------- Bulk insert -------------- + body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} + + r = normal_test_client.post( + "/api/pilots/management/pilot", + json=body, + ) + + assert r.status_code == 200, r.json() + + body = { + "pilot_stamps_to_fields_mapping": [ + PilotFieldsMapping( + PilotStamp=pilot_stamp, + BenchMark=i**2, + StatusReason=PILOT_REASONS[i % len(PILOT_REASONS)], + AccountingSent=True, + Status=PILOT_STATUSES[i % len(PILOT_STATUSES)], + CurrentJobID=i, + Queue=f"queue_{i}", + ).model_dump(exclude_unset=True) + for i, pilot_stamp in enumerate(pilot_stamps) + ] + } + + r = normal_test_client.patch("/api/pilots/management/pilot", json=body) + + assert r.status_code == 204 + + yield normal_test_client + + +@pytest.fixture +async def search(populated_pilot_client): + async def _search( + parameters, conditions, sorts, distinct=False, page=1, per_page=100 + ): + body = { + "parameters": parameters, + "search": conditions, + "sort": sorts, + "distinct": distinct, + } + + params = {"per_page": per_page, "page": page} + + r = populated_pilot_client.post( + "/api/pilots/management/search", json=body, params=params + ) + + if r.status_code == 400: + # If we have a status_code 400, that means that the query failed + raise InvalidQueryError() + + return r.json(), r.headers + + return _search + + +async def test_search_parameters(search): + """Test that we can search specific parameters for pilots.""" + # Search a specific parameter: PilotID + result, headers = await search(["PilotID"], [], []) + assert len(result) == N + assert result + for r in result: + assert r.keys() == {"PilotID"} + assert "Content-Range" not in headers + + # Search a specific parameter: Status + result, headers = await search(["Status"], [], []) + assert len(result) == N + assert result + for r in result: + assert r.keys() == {"Status"} + assert "Content-Range" not in headers + + # Search for multiple parameters: PilotID, Status + result, headers = await search(["PilotID", "Status"], [], []) + assert len(result) == N + assert result + for r in result: + assert r.keys() == {"PilotID", "Status"} + assert "Content-Range" not in headers + + # Search for a specific parameter but use distinct: Status + result, headers = await search(["Status"], [], [], distinct=True) + assert len(result) == len(PILOT_STATUSES) + assert result + assert "Content-Range" not in headers + + # Search for a 
non-existent parameter: Dummy + with pytest.raises(InvalidQueryError): + result, headers = await search(["Dummy"], [], []) + + +async def test_search_conditions(search): + """Test that we can search for specific pilots.""" + # Search a specific scalar condition: PilotID eq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=3 + ) + result, headers = await search([], [condition], []) + assert len(result) == 1 + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 3 + assert "Content-Range" not in headers + + # Search a specific scalar condition: PilotID lt 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.LESS_THAN, value=3 + ) + result, headers = await search([], [condition], []) + assert len(result) == 2 + assert result + assert len(result) == 2 + assert result[0]["PilotID"] == 1 + assert result[1]["PilotID"] == 2 + assert "Content-Range" not in headers + + # Search a specific scalar condition: PilotID neq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.NOT_EQUAL, value=3 + ) + result, headers = await search([], [condition], []) + assert len(result) == 99 + assert result + assert len(result) == 99 + assert all(r["PilotID"] != 3 for r in result) + assert "Content-Range" not in headers + + # Search a specific scalar condition: PilotID eq 5873 (does not exist) + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=5873 + ) + result, headers = await search([], [condition], []) + assert not result + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 3] + ) + result, headers = await search([], [condition], []) + assert len(result) == 3 + assert result + assert len(result) == 3 + assert all(r["PilotID"] in [1, 2, 3] for r in result) + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 5873] + ) + result, headers = await search([], [condition], []) + assert len(result) == 2 + assert result + assert len(result) == 2 + assert all(r["PilotID"] in [1, 2] for r in result) + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID not in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3] + ) + result, headers = await search([], [condition], []) + assert len(result) == 97 + assert result + assert len(result) == 97 + assert all(r["PilotID"] not in [1, 2, 3] for r in result) + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID not in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", + operator=VectorSearchOperator.NOT_IN, + values=[1, 2, 5873], + ) + result, headers = await search([], [condition], []) + assert len(result) == 98 + assert result + assert len(result) == 98 + assert all(r["PilotID"] not in [1, 2] for r in result) + assert "Content-Range" not in headers + + # Search for multiple conditions based on different parameters: PilotStamp eq stamp_5, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotStamp", operator=ScalarSearchOperator.EQUAL, value="stamp_5" + ) + condition2 = VectorSearchSpec(
parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + result, headers = await search([], [condition1, condition2], []) + + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 5 + assert result[0]["PilotStamp"] == "stamp_5" + assert "Content-Range" not in headers + + # Search for multiple conditions based on the same parameter: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=70 + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + result, headers = await search([], [condition1, condition2], []) + assert len(result) == 0 + assert not result + assert "Content-Range" not in headers + + +async def test_search_sorts(search): + """Test that we can search for pilots and sort the results.""" + # Search and sort by PilotID in ascending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.ASC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == i + 1 + assert "Content-Range" not in headers + + # Search and sort by PilotID in descending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == N - i + assert "Content-Range" not in headers + + # Search and sort by PilotStamp in ascending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[2]["PilotStamp"] == "stamp_100" + assert result[12]["PilotStamp"] == "stamp_2" + assert "Content-Range" not in headers + + # Search and sort by PilotStamp in descending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.DESC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[97]["PilotStamp"] == "stamp_100" + assert result[87]["PilotStamp"] == "stamp_2" + assert "Content-Range" not in headers + + # Search and sort by PilotStamp in ascending order and PilotID in descending order + sort1 = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + sort2 = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + result, headers = await search([], [], [sort1, sort2]) + assert len(result) == N + assert result + assert result[0]["PilotStamp"] == "stamp_1" + assert result[0]["PilotID"] == 1 + assert result[99]["PilotStamp"] == "stamp_99" + assert result[99]["PilotID"] == 99 + assert "Content-Range" not in headers + + +async def test_search_pagination(search): + """Test that we can search for pilots.""" + # Search for the first 10 pilots + result, headers = await search([], [], [], per_page=10, page=1) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert total == N + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 1 + + # Search for the second 10 pilots + result, headers = await search([], [], [], per_page=10, page=2) + assert "Content-Range" in headers + # Because Content-Range = f"pilots 
{first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert total == N + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 11 + + # Search for the last 10 pilots + result, headers = await search([], [], [], per_page=10, page=10) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 91 + + # Search for the second 50 pilots + result, headers = await search([], [], [], per_page=50, page=2) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert result + assert len(result) == 50 + assert result[0]["PilotID"] == 51 + + # Invalid page number + result, headers = await search([], [], [], per_page=10, page=11) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert not result + + # Invalid page number + with pytest.raises(InvalidQueryError): + result = await search([], [], [], per_page=10, page=0) + + # Invalid per_page number + with pytest.raises(InvalidQueryError): + result = await search([], [], [], per_page=0, page=1) diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py index 65282efb6..fdf17b6a3 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py @@ -15,7 +15,14 @@ from . import models as _models from ._configuration import DiracConfiguration from ._utils.serialization import Deserializer, Serializer -from .operations import AuthOperations, ConfigOperations, JobsOperations, LollygagOperations, WellKnownOperations +from .operations import ( + AuthOperations, + ConfigOperations, + JobsOperations, + LollygagOperations, + PilotsOperations, + WellKnownOperations, +) class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -31,6 +38,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype jobs: _generated.operations.JobsOperations :ivar lollygag: LollygagOperations operations :vartype lollygag: _generated.operations.LollygagOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -68,6 +77,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. 
diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py index d67986dae..76280797e 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py @@ -15,7 +15,14 @@ from .. import models as _models from .._utils.serialization import Deserializer, Serializer from ._configuration import DiracConfiguration -from .operations import AuthOperations, ConfigOperations, JobsOperations, LollygagOperations, WellKnownOperations +from .operations import ( + AuthOperations, + ConfigOperations, + JobsOperations, + LollygagOperations, + PilotsOperations, + WellKnownOperations, +) class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -31,6 +38,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype jobs: _generated.aio.operations.JobsOperations :ivar lollygag: LollygagOperations operations :vartype lollygag: _generated.aio.operations.LollygagOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.aio.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -68,6 +77,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py index 572930a93..3408891fc 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py @@ -15,6 +15,7 @@ from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -26,6 +27,7 @@ "ConfigOperations", "JobsOperations", "LollygagOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index 30d2e1c17..23082b70c 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -55,6 +55,12 @@ build_lollygag_get_gubbins_secrets_request, build_lollygag_get_owner_object_request, build_lollygag_insert_owner_object_request, + build_pilots_add_pilot_stamps_request, + build_pilots_associate_pilot_with_jobs_request, + build_pilots_clear_pilots_request, + 
build_pilots_delete_pilots_request, + build_pilots_search_request, + build_pilots_update_pilot_fields_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -1829,7 +1835,7 @@ async def patch_metadata(self, body: Union[Dict[str, Dict[str, Any]], IO[bytes]] @overload async def search( self, - body: Optional[_models.JobSearchParams] = None, + body: Optional[_models.SearchParams] = None, *, page: int = 1, per_page: int = 100, @@ -1843,7 +1849,7 @@ async def search( **TODO: Add more docs**. :param body: Default value is None. - :type body: ~_generated.models.JobSearchParams + :type body: ~_generated.models.SearchParams :keyword page: Default value is 1. :paramtype page: int :keyword per_page: Default value is 100. @@ -1889,7 +1895,7 @@ async def search( @distributed_trace_async async def search( self, - body: Optional[Union[_models.JobSearchParams, IO[bytes]]] = None, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, *, page: int = 1, per_page: int = 100, @@ -1901,8 +1907,8 @@ async def search( **TODO: Add more docs**. - :param body: Is either a JobSearchParams type or a IO[bytes] type. Default value is None. - :type body: ~_generated.models.JobSearchParams or IO[bytes] + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] :keyword page: Default value is 1. :paramtype page: int :keyword per_page: Default value is 100. @@ -1932,7 +1938,7 @@ async def search( _content = body else: if body is not None: - _json = self._serialize.body(body, "JobSearchParams") + _json = self._serialize.body(body, "SearchParams") else: _json = None @@ -2324,3 +2330,557 @@ async def get_gubbins_secrets(self, **kwargs: Any) -> Any: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + async def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. 
+ + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def delete_pilots(self, *, pilot_stamps: List[str], **kwargs: Any) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + If at least one pilot is not found, it WILL rollback. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Required. 
+ :paramtype pilot_stamps: list[str] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + async def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def update_pilot_fields( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def update_pilot_fields( + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace_async + async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = True, **kwargs: Any) -> None: + """Clear Pilots. + + Endpoint for DIRAC to delete all pilots that lived more than age_in_days. + + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Required. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake. Default value is True. 
+ :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_clear_pilots_request( + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + async def associate_pilot_with_jobs( + self, body: _models.BodyPilotsAssociatePilotWithJobs, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Associate Pilot With Jobs. + + Endpoint only for DIRAC services, to associate a pilot with a job. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def associate_pilot_with_jobs( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Associate Pilot With Jobs. + + Endpoint only for DIRAC services, to associate a pilot with a job. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def associate_pilot_with_jobs( + self, body: Union[_models.BodyPilotsAssociatePilotWithJobs, IO[bytes]], **kwargs: Any + ) -> None: + """Associate Pilot With Jobs. + + Endpoint only for DIRAC services, to associate a pilot with a job. + + :param body: Is either a BodyPilotsAssociatePilotWithJobs type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAssociatePilotWithJobs") + + _request = build_pilots_associate_pilot_with_jobs_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + async def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. 
+ :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py index 2c1fc99e9..60f09c531 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py @@ -14,6 +14,9 @@ from ._models import ( # type: ignore BodyAuthGetOidcToken, BodyAuthGetOidcTokenGrantType, + BodyPilotsAddPilotStamps, + BodyPilotsAssociatePilotWithJobs, + BodyPilotsUpdatePilotFields, ExtendedMetadata, GroupInfo, HTTPValidationError, @@ -21,17 +24,18 @@ InitiateDeviceFlowResponse, InsertedJob, JobCommand, - JobSearchParams, - JobSearchParamsSearchItem, JobStatusUpdate, JobSummaryParams, JobSummaryParamsSearchItem, OpenIDConfiguration, + PilotFieldsMapping, SandboxDownloadResponse, SandboxInfo, SandboxUploadResponse, ScalarSearchSpec, ScalarSearchSpecValue, + SearchParams, + SearchParamsSearchItem, SetJobStatusReturn, SetJobStatusReturnSuccess, SortSpec, @@ -61,6 +65,9 @@ __all__ = [ "BodyAuthGetOidcToken", "BodyAuthGetOidcTokenGrantType", + "BodyPilotsAddPilotStamps", + "BodyPilotsAssociatePilotWithJobs", + "BodyPilotsUpdatePilotFields", "ExtendedMetadata", "GroupInfo", "HTTPValidationError", @@ -68,17 +75,18 @@ "InitiateDeviceFlowResponse", "InsertedJob", "JobCommand", - "JobSearchParams", - "JobSearchParamsSearchItem", "JobStatusUpdate", "JobSummaryParams", "JobSummaryParamsSearchItem", "OpenIDConfiguration", + "PilotFieldsMapping", 
"SandboxDownloadResponse", "SandboxInfo", "SandboxUploadResponse", "ScalarSearchSpec", "ScalarSearchSpecValue", + "SearchParams", + "SearchParamsSearchItem", "SetJobStatusReturn", "SetJobStatusReturnSuccess", "SortSpec", diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 714f0317a..2400791e0 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -94,6 +94,119 @@ class BodyAuthGetOidcTokenGrantType(_serialization.Model): """OAuth2 Grant type.""" +class BodyPilotsAddPilotStamps(_serialization.Model): + """Body_pilots_add_pilot_stamps. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :vartype pilot_stamps: list[str] + :ivar vo: Virtual Organisation associated with the inserted pilots. Required. + :vartype vo: str + :ivar grid_type: Grid type of the pilots. + :vartype grid_type: str + :ivar pilot_references: Association of a pilot reference with a pilot stamp. + :vartype pilot_references: dict[str, str] + """ + + _validation = { + "pilot_stamps": {"required": True}, + "vo": {"required": True}, + } + + _attribute_map = { + "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"}, + "vo": {"key": "vo", "type": "str"}, + "grid_type": {"key": "grid_type", "type": "str"}, + "pilot_references": {"key": "pilot_references", "type": "{str}"}, + } + + def __init__( + self, + *, + pilot_stamps: List[str], + vo: str, + grid_type: str = "Dirac", + pilot_references: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :paramtype pilot_stamps: list[str] + :keyword vo: Virtual Organisation associated with the inserted pilots. Required. + :paramtype vo: str + :keyword grid_type: Grid type of the pilots. + :paramtype grid_type: str + :keyword pilot_references: Association of a pilot reference with a pilot stamp. + :paramtype pilot_references: dict[str, str] + """ + super().__init__(**kwargs) + self.pilot_stamps = pilot_stamps + self.vo = vo + self.grid_type = grid_type + self.pilot_references = pilot_references + + +class BodyPilotsAssociatePilotWithJobs(_serialization.Model): + """Body_pilots_associate_pilot_with_jobs. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: The stamp of the pilot. Required. + :vartype pilot_stamp: str + :ivar pilot_jobs_ids: The jobs we want to add to the pilot. Required. + :vartype pilot_jobs_ids: list[int] + """ + + _validation = { + "pilot_stamp": {"required": True}, + "pilot_jobs_ids": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + "pilot_jobs_ids": {"key": "pilot_jobs_ids", "type": "[int]"}, + } + + def __init__(self, *, pilot_stamp: str, pilot_jobs_ids: List[int], **kwargs: Any) -> None: + """ + :keyword pilot_stamp: The stamp of the pilot. Required. + :paramtype pilot_stamp: str + :keyword pilot_jobs_ids: The jobs we want to add to the pilot. Required. + :paramtype pilot_jobs_ids: list[int] + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.pilot_jobs_ids = pilot_jobs_ids + + +class BodyPilotsUpdatePilotFields(_serialization.Model): + """Body_pilots_update_pilot_fields. 
+ + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. Required. + :vartype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + + _validation = { + "pilot_stamps_to_fields_mapping": {"required": True}, + } + + _attribute_map = { + "pilot_stamps_to_fields_mapping": {"key": "pilot_stamps_to_fields_mapping", "type": "[PilotFieldsMapping]"}, + } + + def __init__(self, *, pilot_stamps_to_fields_mapping: List["_models.PilotFieldsMapping"], **kwargs: Any) -> None: + """ + :keyword pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. + Required. + :paramtype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + super().__init__(**kwargs) + self.pilot_stamps_to_fields_mapping = pilot_stamps_to_fields_mapping + + class ExtendedMetadata(_serialization.Model): """ExtendedMetadata. @@ -405,56 +518,6 @@ def __init__(self, *, job_id: int, command: str, arguments: Optional[str] = None self.arguments = arguments -class JobSearchParams(_serialization.Model): - """JobSearchParams. - - :ivar parameters: Parameters. - :vartype parameters: list[str] - :ivar search: Search. - :vartype search: list[~_generated.models.JobSearchParamsSearchItem] - :ivar sort: Sort. - :vartype sort: list[~_generated.models.SortSpec] - :ivar distinct: Distinct. - :vartype distinct: bool - """ - - _attribute_map = { - "parameters": {"key": "parameters", "type": "[str]"}, - "search": {"key": "search", "type": "[JobSearchParamsSearchItem]"}, - "sort": {"key": "sort", "type": "[SortSpec]"}, - "distinct": {"key": "distinct", "type": "bool"}, - } - - def __init__( - self, - *, - parameters: Optional[List[str]] = None, - search: List["_models.JobSearchParamsSearchItem"] = [], - sort: List["_models.SortSpec"] = [], - distinct: bool = False, - **kwargs: Any - ) -> None: - """ - :keyword parameters: Parameters. - :paramtype parameters: list[str] - :keyword search: Search. - :paramtype search: list[~_generated.models.JobSearchParamsSearchItem] - :keyword sort: Sort. - :paramtype sort: list[~_generated.models.SortSpec] - :keyword distinct: Distinct. - :paramtype distinct: bool - """ - super().__init__(**kwargs) - self.parameters = parameters - self.search = search - self.sort = sort - self.distinct = distinct - - -class JobSearchParamsSearchItem(_serialization.Model): - """JobSearchParamsSearchItem.""" - - class JobStatusUpdate(_serialization.Model): """JobStatusUpdate. @@ -676,6 +739,100 @@ def __init__( self.code_challenge_methods_supported = code_challenge_methods_supported +class PilotFieldsMapping(_serialization.Model): + """All the fields that a user can modify on a Pilot (except PilotStamp). + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: Pilotstamp. Required. + :vartype pilot_stamp: str + :ivar status_reason: Statusreason. + :vartype status_reason: str + :ivar status: Status. + :vartype status: str + :ivar bench_mark: Benchmark. + :vartype bench_mark: float + :ivar destination_site: Destinationsite. + :vartype destination_site: str + :ivar queue: Queue. + :vartype queue: str + :ivar grid_site: Gridsite. + :vartype grid_site: str + :ivar grid_type: Gridtype. + :vartype grid_type: str + :ivar accounting_sent: Accountingsent. + :vartype accounting_sent: bool + :ivar current_job_id: Currentjobid. 
+ :vartype current_job_id: int + """ + + _validation = { + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "PilotStamp", "type": "str"}, + "status_reason": {"key": "StatusReason", "type": "str"}, + "status": {"key": "Status", "type": "str"}, + "bench_mark": {"key": "BenchMark", "type": "float"}, + "destination_site": {"key": "DestinationSite", "type": "str"}, + "queue": {"key": "Queue", "type": "str"}, + "grid_site": {"key": "GridSite", "type": "str"}, + "grid_type": {"key": "GridType", "type": "str"}, + "accounting_sent": {"key": "AccountingSent", "type": "bool"}, + "current_job_id": {"key": "CurrentJobID", "type": "int"}, + } + + def __init__( + self, + *, + pilot_stamp: str, + status_reason: Optional[str] = None, + status: Optional[str] = None, + bench_mark: Optional[float] = None, + destination_site: Optional[str] = None, + queue: Optional[str] = None, + grid_site: Optional[str] = None, + grid_type: Optional[str] = None, + accounting_sent: Optional[bool] = None, + current_job_id: Optional[int] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamp: Pilotstamp. Required. + :paramtype pilot_stamp: str + :keyword status_reason: Statusreason. + :paramtype status_reason: str + :keyword status: Status. + :paramtype status: str + :keyword bench_mark: Benchmark. + :paramtype bench_mark: float + :keyword destination_site: Destinationsite. + :paramtype destination_site: str + :keyword queue: Queue. + :paramtype queue: str + :keyword grid_site: Gridsite. + :paramtype grid_site: str + :keyword grid_type: Gridtype. + :paramtype grid_type: str + :keyword accounting_sent: Accountingsent. + :paramtype accounting_sent: bool + :keyword current_job_id: Currentjobid. + :paramtype current_job_id: int + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.status_reason = status_reason + self.status = status + self.bench_mark = bench_mark + self.destination_site = destination_site + self.queue = queue + self.grid_site = grid_site + self.grid_type = grid_type + self.accounting_sent = accounting_sent + self.current_job_id = current_job_id + + class SandboxDownloadResponse(_serialization.Model): """SandboxDownloadResponse. @@ -857,6 +1014,56 @@ class ScalarSearchSpecValue(_serialization.Model): """Value.""" +class SearchParams(_serialization.Model): + """SearchParams. + + :ivar parameters: Parameters. + :vartype parameters: list[str] + :ivar search: Search. + :vartype search: list[~_generated.models.SearchParamsSearchItem] + :ivar sort: Sort. + :vartype sort: list[~_generated.models.SortSpec] + :ivar distinct: Distinct. + :vartype distinct: bool + """ + + _attribute_map = { + "parameters": {"key": "parameters", "type": "[str]"}, + "search": {"key": "search", "type": "[SearchParamsSearchItem]"}, + "sort": {"key": "sort", "type": "[SortSpec]"}, + "distinct": {"key": "distinct", "type": "bool"}, + } + + def __init__( + self, + *, + parameters: Optional[List[str]] = None, + search: List["_models.SearchParamsSearchItem"] = [], + sort: List["_models.SortSpec"] = [], + distinct: bool = False, + **kwargs: Any + ) -> None: + """ + :keyword parameters: Parameters. + :paramtype parameters: list[str] + :keyword search: Search. + :paramtype search: list[~_generated.models.SearchParamsSearchItem] + :keyword sort: Sort. + :paramtype sort: list[~_generated.models.SortSpec] + :keyword distinct: Distinct. 
+ :paramtype distinct: bool + """ + super().__init__(**kwargs) + self.parameters = parameters + self.search = search + self.sort = sort + self.distinct = distinct + + +class SearchParamsSearchItem(_serialization.Model): + """SearchParamsSearchItem.""" + + class SetJobStatusReturn(_serialization.Model): """SetJobStatusReturn. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py index 572930a93..3408891fc 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py @@ -15,6 +15,7 @@ from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -26,6 +27,7 @@ "ConfigOperations", "JobsOperations", "LollygagOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index 4e429a056..c594dd602 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -639,6 +639,103 @@ def build_lollygag_get_gubbins_secrets_request(**kwargs: Any) -> HttpRequest: # return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) +def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/management/pilot" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + +def build_pilots_delete_pilots_request(*, pilot_stamps: List[str], **kwargs: Any) -> HttpRequest: + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + # Construct URL + _url = "/api/pilots/management/pilot" + + # Construct parameters + _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]") + + return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) + + +def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/management/pilot" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) + + +def build_pilots_clear_pilots_request( + *, age_in_days: int, delete_only_aborted: bool = True, **kwargs: Any +) -> HttpRequest: + _params = 
case_insensitive_dict(kwargs.pop("params", {}) or {}) + + # Construct URL + _url = "/api/pilots/management/pilot/interval" + + # Construct parameters + _params["age_in_days"] = _SERIALIZER.query("age_in_days", age_in_days, "int") + if delete_only_aborted is not None: + _params["delete_only_aborted"] = _SERIALIZER.query("delete_only_aborted", delete_only_aborted, "bool") + + return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) + + +def build_pilots_associate_pilot_with_jobs_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/management/jobs" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) + + +def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/management/search" + + # Construct parameters + if page is not None: + _params["page"] = _SERIALIZER.query("page", page, "int") + if per_page is not None: + _params["per_page"] = _SERIALIZER.query("per_page", per_page, "int") + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -2400,7 +2497,7 @@ def patch_metadata( # pylint: disable=inconsistent-return-statements @overload def search( self, - body: Optional[_models.JobSearchParams] = None, + body: Optional[_models.SearchParams] = None, *, page: int = 1, per_page: int = 100, @@ -2414,7 +2511,7 @@ def search( **TODO: Add more docs**. :param body: Default value is None. - :type body: ~_generated.models.JobSearchParams + :type body: ~_generated.models.SearchParams :keyword page: Default value is 1. :paramtype page: int :keyword per_page: Default value is 100. @@ -2460,7 +2557,7 @@ def search( @distributed_trace def search( self, - body: Optional[Union[_models.JobSearchParams, IO[bytes]]] = None, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, *, page: int = 1, per_page: int = 100, @@ -2472,8 +2569,8 @@ def search( **TODO: Add more docs**. - :param body: Is either a JobSearchParams type or a IO[bytes] type. Default value is None. - :type body: ~_generated.models.JobSearchParams or IO[bytes] + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] :keyword page: Default value is 1. :paramtype page: int :keyword per_page: Default value is 100. 
@@ -2503,7 +2600,7 @@ def search(
             _content = body
         else:
             if body is not None:
-                _json = self._serialize.body(body, "JobSearchParams")
+                _json = self._serialize.body(body, "SearchParams")
             else:
                 _json = None
 
@@ -2893,3 +2990,559 @@ def get_gubbins_secrets(self, **kwargs: Any) -> Any:
             return cls(pipeline_response, deserialized, {})  # type: ignore
 
         return deserialized  # type: ignore
+
+
+class PilotsOperations:
+    """
+    .. warning::
+        **DO NOT** instantiate this class directly.
+
+        Instead, you should access the following operations through
+        :class:`~_generated.Dirac`'s
+        :attr:`pilots` attribute.
+    """
+
+    models = _models
+
+    def __init__(self, *args, **kwargs) -> None:
+        input_args = list(args)
+        self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client")
+        self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config")
+        self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer")
+        self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer")
+
+    @overload
+    def add_pilot_stamps(
+        self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any
+    ) -> Any:
+        """Add Pilot Stamps.
+
+        Endpoint where you can create pilots with their references.
+
+        If a pilot stamp already exists, it will block the insertion.
+
+        :param body: Required.
+        :type body: ~_generated.models.BodyPilotsAddPilotStamps
+        :keyword content_type: Body Parameter content-type. Content type parameter for JSON body.
+         Default value is "application/json".
+        :paramtype content_type: str
+        :return: any
+        :rtype: any
+        :raises ~azure.core.exceptions.HttpResponseError:
+        """
+
+    @overload
+    def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any:
+        """Add Pilot Stamps.
+
+        Endpoint where you can create pilots with their references.
+
+        If a pilot stamp already exists, it will block the insertion.
+
+        :param body: Required.
+        :type body: IO[bytes]
+        :keyword content_type: Body Parameter content-type. Content type parameter for binary body.
+         Default value is "application/json".
+        :paramtype content_type: str
+        :return: any
+        :rtype: any
+        :raises ~azure.core.exceptions.HttpResponseError:
+        """
+
+    @distributed_trace
+    def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any:
+        """Add Pilot Stamps.
+
+        Endpoint where you can create pilots with their references.
+
+        If a pilot stamp already exists, it will block the insertion.
+
+        :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required.
+ :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def delete_pilots( # pylint: disable=inconsistent-return-statements + self, *, pilot_stamps: List[str], **kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + If at least one pilot is not found, it WILL rollback. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Required. + :paramtype pilot_stamps: list[str] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. 
Content type parameter for JSON body.
+         Default value is "application/json".
+        :paramtype content_type: str
+        :return: None
+        :rtype: None
+        :raises ~azure.core.exceptions.HttpResponseError:
+        """
+
+    @overload
+    def update_pilot_fields(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None:
+        """Update Pilot Fields.
+
+        Modify a field of a pilot.
+
+        Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp.
+
+        :param body: Required.
+        :type body: IO[bytes]
+        :keyword content_type: Body Parameter content-type. Content type parameter for binary body.
+         Default value is "application/json".
+        :paramtype content_type: str
+        :return: None
+        :rtype: None
+        :raises ~azure.core.exceptions.HttpResponseError:
+        """
+
+    @distributed_trace
+    def update_pilot_fields(  # pylint: disable=inconsistent-return-statements
+        self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any
+    ) -> None:
+        """Update Pilot Fields.
+
+        Modify a field of a pilot.
+
+        Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp.
+
+        :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required.
+        :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes]
+        :return: None
+        :rtype: None
+        :raises ~azure.core.exceptions.HttpResponseError:
+        """
+        error_map: MutableMapping = {
+            401: ClientAuthenticationError,
+            404: ResourceNotFoundError,
+            409: ResourceExistsError,
+            304: ResourceNotModifiedError,
+        }
+        error_map.update(kwargs.pop("error_map", {}) or {})
+
+        _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
+        _params = kwargs.pop("params", {}) or {}
+
+        content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
+        cls: ClsType[None] = kwargs.pop("cls", None)
+
+        content_type = content_type or "application/json"
+        _json = None
+        _content = None
+        if isinstance(body, (IOBase, bytes)):
+            _content = body
+        else:
+            _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields")
+
+        _request = build_pilots_update_pilot_fields_request(
+            content_type=content_type,
+            json=_json,
+            content=_content,
+            headers=_headers,
+            params=_params,
+        )
+        _request.url = self._client.format_url(_request.url)
+
+        _stream = False
+        pipeline_response: PipelineResponse = self._client._pipeline.run(  # pylint: disable=protected-access
+            _request, stream=_stream, **kwargs
+        )
+
+        response = pipeline_response.http_response
+
+        if response.status_code not in [204]:
+            map_error(status_code=response.status_code, response=response, error_map=error_map)
+            raise HttpResponseError(response=response)
+
+        if cls:
+            return cls(pipeline_response, None, {})  # type: ignore
+
+    @distributed_trace
+    def clear_pilots(  # pylint: disable=inconsistent-return-statements
+        self, *, age_in_days: int, delete_only_aborted: bool = True, **kwargs: Any
+    ) -> None:
+        """Clear Pilots.
+
+        Endpoint for DIRAC to delete all pilots that lived more than age_in_days.
+
+        :keyword age_in_days: The number of days that define the maximum age of pilots to be
+         deleted. Pilots older than this age will be considered for deletion. Required.
+        :paramtype age_in_days: int
+        :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is
+         'Aborted'. If set to True, only pilots with the 'Aborted' status will be deleted. It is set by
+         default as True to avoid any mistake. Default value is True.
+ :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_clear_pilots_request( + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + def associate_pilot_with_jobs( + self, body: _models.BodyPilotsAssociatePilotWithJobs, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Associate Pilot With Jobs. + + Endpoint only for DIRAC services, to associate a pilot with a job. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def associate_pilot_with_jobs( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Associate Pilot With Jobs. + + Endpoint only for DIRAC services, to associate a pilot with a job. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def associate_pilot_with_jobs( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsAssociatePilotWithJobs, IO[bytes]], **kwargs: Any + ) -> None: + """Associate Pilot With Jobs. + + Endpoint only for DIRAC services, to associate a pilot with a job. + + :param body: Is either a BodyPilotsAssociatePilotWithJobs type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAssociatePilotWithJobs") + + _request = build_pilots_associate_pilot_with_jobs_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. 
+ :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore From 65894156d5db0a3cf06dc17318818f747f00ef19 Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Sun, 15 Jun 2025 08:10:13 +0200 Subject: [PATCH 02/33] feat: Fixes and now use of search engine instead of DIY fetch records --- diracx-db/src/diracx/db/sql/job/db.py | 47 ++----- diracx-db/src/diracx/db/sql/pilots/db.py | 131 +++++------------- diracx-db/src/diracx/db/sql/utils/__init__.py | 4 +- diracx-db/src/diracx/db/sql/utils/base.py | 58 +++++++- .../src/diracx/db/sql/utils/functions.py | 90 +----------- diracx-db/tests/jobs/test_job_db.py | 54 ++++---- .../tests/pilots/test_pilot_management.py | 106 +++++++++++--- diracx-db/tests/pilots/test_query.py | 46 +++--- diracx-logic/src/diracx/logic/jobs/query.py | 2 +- diracx-logic/src/diracx/logic/jobs/status.py | 6 +- .../src/diracx/logic/pilots/__init__.py | 0 .../src/diracx/logic/pilots/management.py | 20 ++- diracx-logic/src/diracx/logic/pilots/query.py | 76 +++++++++- .../diracx/routers/pilots/access_policies.py | 7 +- 14 files changed, 334 insertions(+), 313 deletions(-) create mode 100644 diracx-logic/src/diracx/logic/pilots/__init__.py diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 809fed97e..682475b9b 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -13,13 +13,7 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.models import JobCommand, SearchSpec, SortSpec -from ..utils import ( - BaseSQLDB, - _get_columns, - apply_search_filters, - apply_sort_constraints, - utcnow, -) +from 
..utils import BaseSQLDB, _get_columns, apply_search_filters, utcnow from .schema import ( HeartBeatLoggingInfo, InputData, @@ -63,7 +57,7 @@ async def summary(self, group_by, search) -> list[dict[str, str | int]]: if row.count > 0 # type: ignore ] - async def search( + async def search_jobs( self, parameters: list[str] | None, search: list[SearchSpec], @@ -74,34 +68,15 @@ async def search( page: int | None = None, ) -> tuple[int, list[dict[Any, Any]]]: """Search for jobs in the database.""" - # Find which columns to select - columns = _get_columns(Jobs.__table__, parameters) - - stmt = select(*columns) - - stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search) - stmt = apply_sort_constraints(Jobs.__table__.columns.__getitem__, stmt, sorts) - - if distinct: - stmt = stmt.distinct() - - # Calculate total count before applying pagination - total_count_subquery = stmt.alias() - total_count_stmt = select(func.count()).select_from(total_count_subquery) - total = (await self.conn.execute(total_count_stmt)).scalar_one() - - # Apply pagination - if page is not None: - if page < 1: - raise InvalidQueryError("Page must be a positive integer") - if per_page < 1: - raise InvalidQueryError("Per page must be a positive integer") - stmt = stmt.offset((page - 1) * per_page).limit(per_page) - - # Execute the query - return total, [ - dict(row._mapping) async for row in (await self.conn.stream(stmt)) - ] + return await self.search( + model=Jobs, + parameters=parameters, + search=search, + sorts=sorts, + distinct=distinct, + per_page=per_page, + page=page, + ) async def create_job(self, compressed_original_jdl: str): """Used to insert a new job with original JDL. Returns inserted job id.""" diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py index 279b227e7..cb86864c7 100644 --- a/diracx-db/src/diracx/db/sql/pilots/db.py +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -1,16 +1,14 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import Any, Sequence +from typing import Any -from sqlalchemy import RowMapping, bindparam, func +from sqlalchemy import bindparam from sqlalchemy.exc import IntegrityError -from sqlalchemy.sql import delete, insert, select, update +from sqlalchemy.sql import delete, insert, update from diracx.core.exceptions import ( - InvalidQueryError, PilotAlreadyAssociatedWithJobError, - PilotJobsNotFoundError, PilotNotFoundError, ) from diracx.core.models import ( @@ -21,10 +19,6 @@ from ..utils import ( BaseSQLDB, - _get_columns, - apply_search_filters, - apply_sort_constraints, - fetch_records_bulk_or_raises, ) from .schema import ( JobToPilotMapping, @@ -43,7 +37,7 @@ async def add_pilots_bulk( pilot_stamps: list[str], vo: str, grid_type: str = "DIRAC", - pilot_references: dict | None = None, + pilot_references: dict[str, str] | None = None, ): """Bulk add pilots in the DB. @@ -85,7 +79,9 @@ async def delete_pilots_by_stamps_bulk(self, pilot_stamps: list[str]): if res.rowcount != len(pilot_stamps): raise PilotNotFoundError(data={"pilot_stamps": str(pilot_stamps)}) - async def associate_pilot_with_jobs(self, job_to_pilot_mapping: list[dict]): + async def associate_pilot_with_jobs( + self, job_to_pilot_mapping: list[dict[str, Any]] + ): """Associate a pilot with jobs. 
job_to_pilot_mapping format: @@ -182,61 +178,28 @@ async def update_pilot_fields_bulk( data={"mapping": str(pilot_stamps_to_fields_mapping)} ) - async def get_pilots_by_stamp_bulk( - self, pilot_stamps: list[str] - ) -> Sequence[RowMapping]: - """Bulk fetch pilots. - - Raises PilotNotFoundError if one of the stamp is not associated with a pilot. - - """ - results = await fetch_records_bulk_or_raises( - self.conn, - PilotAgents, - PilotNotFoundError, - "pilot_stamp", - "PilotStamp", - pilot_stamps, - allow_no_result=True, - ) - - # Custom handling, to see which pilot_stamp does not exist (if so, say which one) - found_keys = {row["PilotStamp"] for row in results} - missing = set(pilot_stamps) - found_keys - - if missing: - raise PilotNotFoundError( - data={"pilot_stamp": str(missing)}, - detail=str(missing), - non_existing_pilots=missing, - ) - - return results - - async def get_pilot_jobs_ids_by_pilot_id(self, pilot_id: int) -> list[int]: - """Fetch pilot jobs.""" - job_to_pilot_mapping = await fetch_records_bulk_or_raises( - self.conn, - JobToPilotMapping, - PilotJobsNotFoundError, - "pilot_id", - "PilotID", - [pilot_id], - allow_more_than_one_result_per_input=True, - allow_no_result=True, + async def search_pilots( + self, + parameters: list[str] | None, + search: list[SearchSpec], + sorts: list[SortSpec], + *, + distinct: bool = False, + per_page: int = 100, + page: int | None = None, + ) -> tuple[int, list[dict[Any, Any]]]: + """Search for pilots in the database.""" + return await self.search( + model=PilotAgents, + parameters=parameters, + search=search, + sorts=sorts, + distinct=distinct, + per_page=per_page, + page=page, ) - return [mapping["JobID"] for mapping in job_to_pilot_mapping] - - async def get_pilot_ids_by_stamps(self, pilot_stamps: list[str]) -> list[int]: - """Get pilot ids.""" - # This function is currently needed while we are relying on pilot_ids instead of pilot_stamps - # (Ex: JobToPilotMapping) - pilots = await self.get_pilots_by_stamp_bulk(pilot_stamps) - - return [pilot["PilotID"] for pilot in pilots] - - async def search( + async def search_pilot_to_job_mapping( self, parameters: list[str] | None, search: list[SearchSpec], @@ -247,39 +210,15 @@ async def search( page: int | None = None, ) -> tuple[int, list[dict[Any, Any]]]: """Search for pilots in the database.""" - # TODO: Refactorize with the search function for jobs. 
-        # Find which columns to select
-        columns = _get_columns(PilotAgents.__table__, parameters)
-
-        stmt = select(*columns)
-
-        stmt = apply_search_filters(
-            PilotAgents.__table__.columns.__getitem__, stmt, search
+        return await self.search(
+            model=JobToPilotMapping,
+            parameters=parameters,
+            search=search,
+            sorts=sorts,
+            distinct=distinct,
+            per_page=per_page,
+            page=page,
         )
-        stmt = apply_sort_constraints(
-            PilotAgents.__table__.columns.__getitem__, stmt, sorts
-        )
-
-        if distinct:
-            stmt = stmt.distinct()
-
-        # Calculate total count before applying pagination
-        total_count_subquery = stmt.alias()
-        total_count_stmt = select(func.count()).select_from(total_count_subquery)
-        total = (await self.conn.execute(total_count_stmt)).scalar_one()
-
-        # Apply pagination
-        if page is not None:
-            if page < 1:
-                raise InvalidQueryError("Page must be a positive integer")
-            if per_page < 1:
-                raise InvalidQueryError("Per page must be a positive integer")
-            stmt = stmt.offset((page - 1) * per_page).limit(per_page)
-
-        # Execute the query
-        return total, [
-            dict(row._mapping) async for row in (await self.conn.stream(stmt))
-        ]
 
     async def clear_pilots_bulk(
         self, cutoff_date: datetime, delete_only_aborted: bool
diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py
index e3d0747a1..53b3f3c96 100644
--- a/diracx-db/src/diracx/db/sql/utils/__init__.py
+++ b/diracx-db/src/diracx/db/sql/utils/__init__.py
@@ -3,12 +3,11 @@
 from .base import (
     BaseSQLDB,
     SQLDBUnavailableError,
+    _get_columns,
     apply_search_filters,
     apply_sort_constraints,
 )
 from .functions import (
-    _get_columns,
-    fetch_records_bulk_or_raises,
     hash,
     substract_date,
     utcnow,
@@ -24,7 +23,6 @@
     "DateNowColumn",
     "EnumBackedBool",
     "EnumColumn",
-    "fetch_records_bulk_or_raises",
     "hash",
     "NullColumn",
     "substract_date",
diff --git a/diracx-db/src/diracx/db/sql/utils/base.py b/diracx-db/src/diracx/db/sql/utils/base.py
index 6286364af..dc10754fc 100644
--- a/diracx-db/src/diracx/db/sql/utils/base.py
+++ b/diracx-db/src/diracx/db/sql/utils/base.py
@@ -8,16 +8,16 @@
 from collections.abc import AsyncIterator
 from contextvars import ContextVar
 from datetime import datetime
-from typing import Self, cast
+from typing import Any, Self, cast
 
 from pydantic import TypeAdapter
-from sqlalchemy import DateTime, MetaData, select
+from sqlalchemy import DateTime, MetaData, func, select
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
 
 from diracx.core.exceptions import InvalidQueryError
 from diracx.core.extensions import select_from_extension
-from diracx.core.models import SortDirection
+from diracx.core.models import SearchSpec, SortDirection, SortSpec
 from diracx.core.settings import SqlalchemyDsn
 from diracx.db.exceptions import DBUnavailableError
 
@@ -227,6 +227,47 @@ async def ping(self):
         except OperationalError as e:
             raise SQLDBUnavailableError("Cannot ping the DB") from e
 
+    async def search(
+        self,
+        model: Any,
+        parameters: list[str] | None,
+        search: list[SearchSpec],
+        sorts: list[SortSpec],
+        *,
+        distinct: bool = False,
+        per_page: int = 100,
+        page: int | None = None,
+    ) -> tuple[int, list[dict[Any, Any]]]:
+        """Search for rows of the given model in the database."""
+        # Find which columns to select
+        columns = _get_columns(model.__table__, parameters)
+
+        stmt = select(*columns)
+
+        stmt = apply_search_filters(model.__table__.columns.__getitem__, stmt, search)
+        stmt = apply_sort_constraints(model.__table__.columns.__getitem__, stmt, 
sorts) + + if distinct: + stmt = stmt.distinct() + + # Calculate total count before applying pagination + total_count_subquery = stmt.alias() + total_count_stmt = select(func.count()).select_from(total_count_subquery) + total = (await self.conn.execute(total_count_stmt)).scalar_one() + + # Apply pagination + if page is not None: + if page < 1: + raise InvalidQueryError("Page must be a positive integer") + if per_page < 1: + raise InvalidQueryError("Per page must be a positive integer") + stmt = stmt.offset((page - 1) * per_page).limit(per_page) + + # Execute the query + return total, [ + dict(row._mapping) async for row in (await self.conn.stream(stmt)) + ] + def find_time_resolution(value): if isinstance(value, datetime): @@ -258,6 +299,17 @@ def find_time_resolution(value): raise InvalidQueryError(f"Cannot parse {value=}") +def _get_columns(table, parameters): + columns = [x for x in table.columns] + if parameters: + if unrecognised_parameters := set(parameters) - set(table.columns.keys()): + raise InvalidQueryError( + f"Unrecognised parameters requested {unrecognised_parameters}" + ) + columns = [c for c in columns if c.name in parameters] + return columns + + def apply_search_filters(column_mapping, stmt, search): for query in search: try: diff --git a/diracx-db/src/diracx/db/sql/utils/functions.py b/diracx-db/src/diracx/db/sql/utils/functions.py index 536412406..34cb2a0da 100644 --- a/diracx-db/src/diracx/db/sql/utils/functions.py +++ b/diracx-db/src/diracx/db/sql/utils/functions.py @@ -2,30 +2,16 @@ import hashlib from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Sequence, Type +from typing import TYPE_CHECKING -from sqlalchemy import DateTime, RowMapping, asc, desc, func, select -from sqlalchemy.ext.asyncio import AsyncConnection +from sqlalchemy import DateTime, func from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql import ColumnElement, expression - -from diracx.core.exceptions import DiracFormattedError, InvalidQueryError +from sqlalchemy.sql import expression if TYPE_CHECKING: from sqlalchemy.types import TypeEngine -def _get_columns(table, parameters): - columns = [x for x in table.columns] - if parameters: - if unrecognised_parameters := set(parameters) - set(table.columns.keys()): - raise InvalidQueryError( - f"Unrecognised parameters requested {unrecognised_parameters}" - ) - columns = [c for c in columns if c.name in parameters] - return columns - - class utcnow(expression.FunctionElement): # noqa: N801 type: TypeEngine = DateTime() inherit_cache: bool = True @@ -154,73 +140,3 @@ def substract_date(**kwargs: float) -> datetime: def hash(code: str): return hashlib.sha256(code.encode()).hexdigest() - - -def raw_hash(code: str): - return hashlib.sha256(code.encode()).digest() - - -async def fetch_records_bulk_or_raises( - conn: AsyncConnection, - model: Any, # Here, we currently must use `Any` because `declarative_base()` returns any - missing_elements_error_cls: Type[DiracFormattedError], - column_attribute_name: str, - column_name: str, - elements_to_fetch: list, - order_by: tuple[str, str] | None = None, - allow_more_than_one_result_per_input: bool = False, - allow_no_result: bool = False, -) -> Sequence[RowMapping]: - """Fetches a list of elements in a table, returns a list of elements. - All elements from the `element_to_fetch` **must** be present. - Raises the specified error if at least one is missing. 
- - Example: - fetch_records_bulk_or_raises( - self.conn, - PilotAgents, - PilotNotFound, - "pilot_id", - "PilotID", - [1,2,3] - ) - - """ - assert elements_to_fetch - - # Get the column that needs to be in elements_to_fetch - column = getattr(model, column_attribute_name) - - # Create the request - stmt = select(model).with_for_update().where(column.in_(elements_to_fetch)) - - if order_by: - column_name_to_order_by, direction = order_by - column_to_order_by = getattr(model, column_name_to_order_by) - - operator: ColumnElement = ( - asc(column_to_order_by) if direction == "asc" else desc(column_to_order_by) - ) - - stmt = stmt.order_by(operator) - - # Transform into dictionaries - raw_results = await conn.execute(stmt) - results = raw_results.mappings().all() - - # Detects duplicates - if not allow_more_than_one_result_per_input: - if len(results) > len(elements_to_fetch): - raise RuntimeError("Seems to have duplicates in the database.") - - if not allow_no_result: - # Checks if we have every elements we wanted - found_keys = {row[column_name] for row in results} - missing = set(elements_to_fetch) - found_keys - - if missing: - raise missing_elements_error_cls( - data={column_name: str(missing)}, detail=str(missing) - ) - - return results diff --git a/diracx-db/tests/jobs/test_job_db.py b/diracx-db/tests/jobs/test_job_db.py index e6ca58ce9..5ae49ad10 100644 --- a/diracx-db/tests/jobs/test_job_db.py +++ b/diracx-db/tests/jobs/test_job_db.py @@ -51,34 +51,34 @@ async def test_search_parameters(populated_job_db): """Test that we can search specific parameters for jobs in the database.""" async with populated_job_db as job_db: # Search a specific parameter: JobID - total, result = await job_db.search(["JobID"], [], []) + total, result = await job_db.search_jobs(["JobID"], [], []) assert total == 100 assert result for r in result: assert r.keys() == {"JobID"} # Search a specific parameter: Status - total, result = await job_db.search(["Status"], [], []) + total, result = await job_db.search_jobs(["Status"], [], []) assert total == 100 assert result for r in result: assert r.keys() == {"Status"} # Search for multiple parameters: JobID, Status - total, result = await job_db.search(["JobID", "Status"], [], []) + total, result = await job_db.search_jobs(["JobID", "Status"], [], []) assert total == 100 assert result for r in result: assert r.keys() == {"JobID", "Status"} # Search for a specific parameter but use distinct: Status - total, result = await job_db.search(["Status"], [], [], distinct=True) + total, result = await job_db.search_jobs(["Status"], [], [], distinct=True) assert total == 1 assert result # Search for a non-existent parameter: Dummy with pytest.raises(InvalidQueryError): - total, result = await job_db.search(["Dummy"], [], []) + total, result = await job_db.search_jobs(["Dummy"], [], []) async def test_search_conditions(populated_job_db): @@ -88,7 +88,7 @@ async def test_search_conditions(populated_job_db): condition = ScalarSearchSpec( parameter="JobID", operator=ScalarSearchOperator.EQUAL, value=3 ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert total == 1 assert result assert len(result) == 1 @@ -98,7 +98,7 @@ async def test_search_conditions(populated_job_db): condition = ScalarSearchSpec( parameter="JobID", operator=ScalarSearchOperator.LESS_THAN, value=3 ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert total == 
2 assert result assert len(result) == 2 @@ -109,7 +109,7 @@ async def test_search_conditions(populated_job_db): condition = ScalarSearchSpec( parameter="JobID", operator=ScalarSearchOperator.NOT_EQUAL, value=3 ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert total == 99 assert result assert len(result) == 99 @@ -119,14 +119,14 @@ async def test_search_conditions(populated_job_db): condition = ScalarSearchSpec( parameter="JobID", operator=ScalarSearchOperator.EQUAL, value=5873 ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert not result # Search a specific vector condition: JobID in 1,2,3 condition = VectorSearchSpec( parameter="JobID", operator=VectorSearchOperator.IN, values=[1, 2, 3] ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert total == 3 assert result assert len(result) == 3 @@ -136,7 +136,7 @@ async def test_search_conditions(populated_job_db): condition = VectorSearchSpec( parameter="JobID", operator=VectorSearchOperator.IN, values=[1, 2, 5873] ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert total == 2 assert result assert len(result) == 2 @@ -146,7 +146,7 @@ async def test_search_conditions(populated_job_db): condition = VectorSearchSpec( parameter="JobID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3] ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert total == 97 assert result assert len(result) == 97 @@ -156,7 +156,7 @@ async def test_search_conditions(populated_job_db): condition = VectorSearchSpec( parameter="JobID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 5873] ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert total == 98 assert result assert len(result) == 98 @@ -169,7 +169,7 @@ async def test_search_conditions(populated_job_db): condition2 = VectorSearchSpec( parameter="JobID", operator=VectorSearchOperator.IN, values=[4, 5, 6] ) - total, result = await job_db.search([], [condition1, condition2], []) + total, result = await job_db.search_jobs([], [condition1, condition2], []) assert total == 1 assert result assert len(result) == 1 @@ -183,7 +183,7 @@ async def test_search_conditions(populated_job_db): condition2 = VectorSearchSpec( parameter="JobID", operator=VectorSearchOperator.IN, values=[4, 5, 6] ) - total, result = await job_db.search([], [condition1, condition2], []) + total, result = await job_db.search_jobs([], [condition1, condition2], []) assert total == 0 assert not result @@ -193,7 +193,7 @@ async def test_search_sorts(populated_job_db): async with populated_job_db as job_db: # Search and sort by JobID in ascending order sort = SortSpec(parameter="JobID", direction=SortDirection.ASC) - total, result = await job_db.search([], [], [sort]) + total, result = await job_db.search_jobs([], [], [sort]) assert total == 100 assert result for i, r in enumerate(result): @@ -201,7 +201,7 @@ async def test_search_sorts(populated_job_db): # Search and sort by JobID in descending order sort = SortSpec(parameter="JobID", direction=SortDirection.DESC) - total, result = await job_db.search([], [], [sort]) + total, result = await job_db.search_jobs([], [], [sort]) assert 
total == 100 assert result for i, r in enumerate(result): @@ -209,7 +209,7 @@ async def test_search_sorts(populated_job_db): # Search and sort by Owner in ascending order sort = SortSpec(parameter="Owner", direction=SortDirection.ASC) - total, result = await job_db.search([], [], [sort]) + total, result = await job_db.search_jobs([], [], [sort]) assert total == 100 assert result # Assert that owner10 is before owner2 because of the lexicographical order @@ -218,7 +218,7 @@ async def test_search_sorts(populated_job_db): # Search and sort by Owner in descending order sort = SortSpec(parameter="Owner", direction=SortDirection.DESC) - total, result = await job_db.search([], [], [sort]) + total, result = await job_db.search_jobs([], [], [sort]) assert total == 100 assert result # Assert that owner10 is before owner2 because of the lexicographical order @@ -228,7 +228,7 @@ async def test_search_sorts(populated_job_db): # Search and sort by OwnerGroup in ascending order and JobID in descending order sort1 = SortSpec(parameter="OwnerGroup", direction=SortDirection.ASC) sort2 = SortSpec(parameter="JobID", direction=SortDirection.DESC) - total, result = await job_db.search([], [], [sort1, sort2]) + total, result = await job_db.search_jobs([], [], [sort1, sort2]) assert total == 100 assert result assert result[0]["OwnerGroup"] == "owner_group1" @@ -241,45 +241,45 @@ async def test_search_pagination(populated_job_db): """Test that we can search for jobs in the database.""" async with populated_job_db as job_db: # Search for the first 10 jobs - total, result = await job_db.search([], [], [], per_page=10, page=1) + total, result = await job_db.search_jobs([], [], [], per_page=10, page=1) assert total == 100 assert result assert len(result) == 10 assert result[0]["JobID"] == 1 # Search for the second 10 jobs - total, result = await job_db.search([], [], [], per_page=10, page=2) + total, result = await job_db.search_jobs([], [], [], per_page=10, page=2) assert total == 100 assert result assert len(result) == 10 assert result[0]["JobID"] == 11 # Search for the last 10 jobs - total, result = await job_db.search([], [], [], per_page=10, page=10) + total, result = await job_db.search_jobs([], [], [], per_page=10, page=10) assert total == 100 assert result assert len(result) == 10 assert result[0]["JobID"] == 91 # Search for the second 50 jobs - total, result = await job_db.search([], [], [], per_page=50, page=2) + total, result = await job_db.search_jobs([], [], [], per_page=50, page=2) assert total == 100 assert result assert len(result) == 50 assert result[0]["JobID"] == 51 # Invalid page number - total, result = await job_db.search([], [], [], per_page=10, page=11) + total, result = await job_db.search_jobs([], [], [], per_page=10, page=11) assert total == 100 assert not result # Invalid page number with pytest.raises(InvalidQueryError): - result = await job_db.search([], [], [], per_page=10, page=0) + result = await job_db.search_jobs([], [], [], per_page=10, page=0) # Invalid per_page number with pytest.raises(InvalidQueryError): - result = await job_db.search([], [], [], per_page=0, page=1) + result = await job_db.search_jobs([], [], [], per_page=0, page=1) async def test_set_job_commands_invalid_job_id(job_db: JobDB): diff --git a/diracx-db/tests/pilots/test_pilot_management.py b/diracx-db/tests/pilots/test_pilot_management.py index 18fa1119c..f73b39bdb 100644 --- a/diracx-db/tests/pilots/test_pilot_management.py +++ b/diracx-db/tests/pilots/test_pilot_management.py @@ -1,6 +1,7 @@ from __future__ 
import annotations from datetime import datetime, timezone +from typing import Any import pytest from sqlalchemy.sql import update @@ -9,7 +10,13 @@ PilotAlreadyAssociatedWithJobError, PilotNotFoundError, ) -from diracx.core.models import PilotFieldsMapping +from diracx.core.models import ( + PilotFieldsMapping, + ScalarSearchOperator, + ScalarSearchSpec, + VectorSearchOperator, + VectorSearchSpec, +) from diracx.db.sql.pilots.db import PilotAgentsDB from diracx.db.sql.pilots.schema import PilotAgents @@ -26,6 +33,57 @@ async def pilot_db(tmp_path): yield agents_db +async def get_pilot_jobs_ids_by_pilot_id( + pilot_db: PilotAgentsDB, pilot_id: int +) -> list[int]: + _, jobs = await pilot_db.search_pilot_to_job_mapping( + parameters=["JobID"], + search=[ + ScalarSearchSpec( + parameter="PilotID", + operator=ScalarSearchOperator.EQUAL, + value=pilot_id, + ) + ], + sorts=[], + distinct=True, + per_page=10000, + ) + + return [job["JobID"] for job in jobs] + + +async def get_pilots_by_stamp_bulk( + pilot_db: PilotAgentsDB, pilot_stamps: list[str], parameters: list[str] = [] +) -> list[dict[Any, Any]]: + _, pilots = await pilot_db.search_pilots( + parameters=parameters, + search=[ + VectorSearchSpec( + parameter="PilotStamp", + operator=VectorSearchOperator.IN, + values=pilot_stamps, + ) + ], + sorts=[], + distinct=True, + per_page=1000, + ) + + # Custom handling, to see which pilot_stamp does not exist (if so, say which one) + found_keys = {row["PilotStamp"] for row in pilots} + missing = set(pilot_stamps) - found_keys + + if missing: + raise PilotNotFoundError( + data={"pilot_stamp": str(missing)}, + detail=str(missing), + non_existing_pilots=missing, + ) + + return pilots + + @pytest.fixture async def add_stamps(pilot_db): async def _add_stamps(start_n=0): @@ -41,7 +99,7 @@ async def _add_stamps(start_n=0): stamps, vo, grid_type="DIRAC", pilot_references=pilot_references ) - pilots = await db.get_pilots_by_stamp_bulk(stamps) + pilots = await get_pilots_by_stamp_bulk(db, stamps) return pilots @@ -73,7 +131,7 @@ async def _create_timed_pilots( res = await db.conn.execute(stmt) assert res.rowcount == len(pilot_stamps) - pilots = await db.get_pilots_by_stamp_bulk(pilot_stamps) + pilots = await get_pilots_by_stamp_bulk(db, pilot_stamps) return pilots return _create_timed_pilots @@ -107,10 +165,12 @@ async def create_old_pilots_environment(pilot_db, create_timed_pilots): # Phase 0. 
Verify that we have the right environment async with pilot_db as pilot_db: # Ensure that we can get every pilot (only get first of each group) - await pilot_db.get_pilots_by_stamp_bulk([non_aborted_recent[0]["PilotStamp"]]) - await pilot_db.get_pilots_by_stamp_bulk([aborted_recent[0]["PilotStamp"]]) - await pilot_db.get_pilots_by_stamp_bulk([aborted_very_old[0]["PilotStamp"]]) - await pilot_db.get_pilots_by_stamp_bulk([non_aborted_very_old[0]["PilotStamp"]]) + await get_pilots_by_stamp_bulk(pilot_db, [non_aborted_recent[0]["PilotStamp"]]) + await get_pilots_by_stamp_bulk(pilot_db, [aborted_recent[0]["PilotStamp"]]) + await get_pilots_by_stamp_bulk(pilot_db, [aborted_very_old[0]["PilotStamp"]]) + await get_pilots_by_stamp_bulk( + pilot_db, [non_aborted_very_old[0]["PilotStamp"]] + ) return non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old @@ -146,17 +206,17 @@ async def test_insert_and_delete(pilot_db: PilotAgentsDB): ) # Works, the pilots exists - await pilot_db.get_pilots_by_stamp_bulk([stamps[0]]) - await pilot_db.get_pilots_by_stamp_bulk([stamps[0]]) + await get_pilots_by_stamp_bulk(pilot_db, [stamps[0]]) + await get_pilots_by_stamp_bulk(pilot_db, [stamps[0]]) # We delete the first pilot await pilot_db.delete_pilots_by_stamps_bulk([stamps[0]]) # We get the 2nd pilot that is not delete (no error) - await pilot_db.get_pilots_by_stamp_bulk([stamps[1]]) + await get_pilots_by_stamp_bulk(pilot_db, [stamps[1]]) # We get the 1st pilot that is delete (error) with pytest.raises(PilotNotFoundError): - await pilot_db.get_pilots_by_stamp_bulk([stamps[0]]) + await get_pilots_by_stamp_bulk(pilot_db, [stamps[0]]) @pytest.mark.asyncio @@ -182,14 +242,14 @@ async def test_insert_and_delete_only_old_aborted( ]: stamps = [pilot["PilotStamp"] for pilot in normally_exiting_pilot_list] - await pilot_db.get_pilots_by_stamp_bulk(stamps) + await get_pilots_by_stamp_bulk(pilot_db, stamps) # Assert who normally does not live for normally_deleted_pilot_list in [aborted_very_old]: stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] with pytest.raises(PilotNotFoundError): - await pilot_db.get_pilots_by_stamp_bulk(stamps) + await get_pilots_by_stamp_bulk(pilot_db, stamps) @pytest.mark.asyncio @@ -214,7 +274,7 @@ async def test_insert_and_delete_old( ]: stamps = [pilot["PilotStamp"] for pilot in normally_exiting_pilot_list] - await pilot_db.get_pilots_by_stamp_bulk(stamps) + await get_pilots_by_stamp_bulk(pilot_db, stamps) # Assert who normally does not live for normally_deleted_pilot_list in [ @@ -224,7 +284,7 @@ async def test_insert_and_delete_old( stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] with pytest.raises(PilotNotFoundError): - await pilot_db.get_pilots_by_stamp_bulk(stamps) + await get_pilots_by_stamp_bulk(pilot_db, stamps) @pytest.mark.asyncio @@ -246,7 +306,7 @@ async def test_insert_and_delete_recent_only_aborted( for normally_exiting_pilot_list in [non_aborted_recent, non_aborted_very_old]: stamps = [pilot["PilotStamp"] for pilot in normally_exiting_pilot_list] - await pilot_db.get_pilots_by_stamp_bulk(stamps) + await get_pilots_by_stamp_bulk(pilot_db, stamps) # Assert who normally does not live for normally_deleted_pilot_list in [ @@ -256,7 +316,7 @@ async def test_insert_and_delete_recent_only_aborted( stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] with pytest.raises(PilotNotFoundError): - await pilot_db.get_pilots_by_stamp_bulk(stamps) + await get_pilots_by_stamp_bulk(pilot_db, stamps) @pytest.mark.asyncio 
@@ -284,7 +344,7 @@ async def test_insert_and_delete_recent( stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] with pytest.raises(PilotNotFoundError): - await pilot_db.get_pilots_by_stamp_bulk(stamps) + await get_pilots_by_stamp_bulk(pilot_db, stamps) @pytest.mark.asyncio @@ -297,7 +357,7 @@ async def test_insert_and_select_single_then_modify(pilot_db: PilotAgentsDB): grid_type="grid-type", ) - res = await pilot_db.get_pilots_by_stamp_bulk([pilot_stamp]) + res = await get_pilots_by_stamp_bulk(pilot_db, [pilot_stamp]) assert len(res) == 1 pilot = res[0] @@ -325,7 +385,7 @@ async def test_insert_and_select_single_then_modify(pilot_db: PilotAgentsDB): ] ) - res = await pilot_db.get_pilots_by_stamp_bulk([pilot_stamp]) + res = await get_pilots_by_stamp_bulk(pilot_db, [pilot_stamp]) assert len(res) == 1 pilot = res[0] @@ -359,13 +419,13 @@ async def test_associate_pilot_with_job_and_get_it(pilot_db: PilotAgentsDB): grid_type="grid-type", ) - res = await pilot_db.get_pilots_by_stamp_bulk([pilot_stamp]) + res = await get_pilots_by_stamp_bulk(pilot_db, [pilot_stamp]) assert len(res) == 1 pilot = res[0] pilot_id = pilot["PilotID"] # Verify that he has no jobs - assert len(await pilot_db.get_pilot_jobs_ids_by_pilot_id(pilot_id)) == 0 + assert len(await get_pilot_jobs_ids_by_pilot_id(pilot_db, pilot_id)) == 0 now = datetime.now(tz=timezone.utc) @@ -379,7 +439,7 @@ async def test_associate_pilot_with_job_and_get_it(pilot_db: PilotAgentsDB): await pilot_db.associate_pilot_with_jobs(job_to_pilot_mapping) # Verify that he has all jobs - db_jobs = await pilot_db.get_pilot_jobs_ids_by_pilot_id(pilot_id) + db_jobs = await get_pilot_jobs_ids_by_pilot_id(pilot_db, pilot_id) # We test both length and if every job is included if for any reason we have duplicates assert all(job in db_jobs for job in pilot_jobs) assert len(pilot_jobs) == len(db_jobs) diff --git a/diracx-db/tests/pilots/test_query.py b/diracx-db/tests/pilots/test_query.py index e2511e169..c594017fe 100644 --- a/diracx-db/tests/pilots/test_query.py +++ b/diracx-db/tests/pilots/test_query.py @@ -51,8 +51,6 @@ async def populated_pilot_db(pilot_db): stamps, vo, grid_type="DIRAC", pilot_references=pilot_references ) - await pilot_db.get_pilots_by_stamp_bulk(stamps) - await pilot_db.update_pilot_fields_bulk( [ PilotFieldsMapping( @@ -75,34 +73,34 @@ async def test_search_parameters(populated_pilot_db): """Test that we can search specific parameters for pilots in the database.""" async with populated_pilot_db as pilot_db: # Search a specific parameter: PilotID - total, result = await pilot_db.search(["PilotID"], [], []) + total, result = await pilot_db.search_pilots(["PilotID"], [], []) assert total == N assert result for r in result: assert r.keys() == {"PilotID"} # Search a specific parameter: Status - total, result = await pilot_db.search(["Status"], [], []) + total, result = await pilot_db.search_pilots(["Status"], [], []) assert total == N assert result for r in result: assert r.keys() == {"Status"} # Search for multiple parameters: PilotID, Status - total, result = await pilot_db.search(["PilotID", "Status"], [], []) + total, result = await pilot_db.search_pilots(["PilotID", "Status"], [], []) assert total == N assert result for r in result: assert r.keys() == {"PilotID", "Status"} # Search for a specific parameter but use distinct: Status - total, result = await pilot_db.search(["Status"], [], [], distinct=True) + total, result = await pilot_db.search_pilots(["Status"], [], [], distinct=True) assert total == 
len(PILOT_STATUSES) assert result # Search for a non-existent parameter: Dummy with pytest.raises(InvalidQueryError): - total, result = await pilot_db.search(["Dummy"], [], []) + total, result = await pilot_db.search_pilots(["Dummy"], [], []) async def test_search_conditions(populated_pilot_db): @@ -112,7 +110,7 @@ async def test_search_conditions(populated_pilot_db): condition = ScalarSearchSpec( parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=3 ) - total, result = await pilot_db.search([], [condition], []) + total, result = await pilot_db.search_pilots([], [condition], []) assert total == 1 assert result assert len(result) == 1 @@ -122,7 +120,7 @@ async def test_search_conditions(populated_pilot_db): condition = ScalarSearchSpec( parameter="PilotID", operator=ScalarSearchOperator.LESS_THAN, value=3 ) - total, result = await pilot_db.search([], [condition], []) + total, result = await pilot_db.search_pilots([], [condition], []) assert total == 2 assert result assert len(result) == 2 @@ -133,7 +131,7 @@ async def test_search_conditions(populated_pilot_db): condition = ScalarSearchSpec( parameter="PilotID", operator=ScalarSearchOperator.NOT_EQUAL, value=3 ) - total, result = await pilot_db.search([], [condition], []) + total, result = await pilot_db.search_pilots([], [condition], []) assert total == 99 assert result assert len(result) == 99 @@ -143,14 +141,14 @@ async def test_search_conditions(populated_pilot_db): condition = ScalarSearchSpec( parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=5873 ) - total, result = await pilot_db.search([], [condition], []) + total, result = await pilot_db.search_pilots([], [condition], []) assert not result # Search a specific vector condition: PilotID in 1,2,3 condition = VectorSearchSpec( parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 3] ) - total, result = await pilot_db.search([], [condition], []) + total, result = await pilot_db.search_pilots([], [condition], []) assert total == 3 assert result assert len(result) == 3 @@ -160,7 +158,7 @@ async def test_search_conditions(populated_pilot_db): condition = VectorSearchSpec( parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 5873] ) - total, result = await pilot_db.search([], [condition], []) + total, result = await pilot_db.search_pilots([], [condition], []) assert total == 2 assert result assert len(result) == 2 @@ -170,7 +168,7 @@ async def test_search_conditions(populated_pilot_db): condition = VectorSearchSpec( parameter="PilotID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3] ) - total, result = await pilot_db.search([], [condition], []) + total, result = await pilot_db.search_pilots([], [condition], []) assert total == 97 assert result assert len(result) == 97 @@ -182,7 +180,7 @@ async def test_search_conditions(populated_pilot_db): operator=VectorSearchOperator.NOT_IN, values=[1, 2, 5873], ) - total, result = await pilot_db.search([], [condition], []) + total, result = await pilot_db.search_pilots([], [condition], []) assert total == 98 assert result assert len(result) == 98 @@ -195,7 +193,7 @@ async def test_search_conditions(populated_pilot_db): condition2 = VectorSearchSpec( parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] ) - total, result = await pilot_db.search([], [condition1, condition2], []) + total, result = await pilot_db.search_pilots([], [condition1, condition2], []) assert total == 1 assert result assert len(result) == 1 @@ -209,7 +207,7 @@ async def 
test_search_conditions(populated_pilot_db): condition2 = VectorSearchSpec( parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] ) - total, result = await pilot_db.search([], [condition1, condition2], []) + total, result = await pilot_db.search_pilots([], [condition1, condition2], []) assert total == 0 assert not result @@ -219,7 +217,7 @@ async def test_search_sorts(populated_pilot_db): async with populated_pilot_db as pilot_db: # Search and sort by PilotID in ascending order sort = SortSpec(parameter="PilotID", direction=SortDirection.ASC) - total, result = await pilot_db.search([], [], [sort]) + total, result = await pilot_db.search_pilots([], [], [sort]) assert total == N assert result for i, r in enumerate(result): @@ -227,7 +225,7 @@ async def test_search_sorts(populated_pilot_db): # Search and sort by PilotID in descending order sort = SortSpec(parameter="PilotID", direction=SortDirection.DESC) - total, result = await pilot_db.search([], [], [sort]) + total, result = await pilot_db.search_pilots([], [], [sort]) assert total == N assert result for i, r in enumerate(result): @@ -235,7 +233,7 @@ async def test_search_sorts(populated_pilot_db): # Search and sort by PilotStamp in ascending order sort = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) - total, result = await pilot_db.search([], [], [sort]) + total, result = await pilot_db.search_pilots([], [], [sort]) assert total == N assert result # Assert that stamp_10 is before stamp_2 because of the lexicographical order @@ -244,7 +242,7 @@ async def test_search_sorts(populated_pilot_db): # Search and sort by PilotStamp in descending order sort = SortSpec(parameter="PilotStamp", direction=SortDirection.DESC) - total, result = await pilot_db.search([], [], [sort]) + total, result = await pilot_db.search_pilots([], [], [sort]) assert total == N assert result # Assert that stamp_10 is before stamp_2 because of the lexicographical order @@ -254,7 +252,7 @@ async def test_search_sorts(populated_pilot_db): # Search and sort by PilotStamp in ascending order and PilotID in descending order sort1 = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) sort2 = SortSpec(parameter="PilotID", direction=SortDirection.DESC) - total, result = await pilot_db.search([], [], [sort1, sort2]) + total, result = await pilot_db.search_pilots([], [], [sort1, sort2]) assert total == N assert result assert result[0]["PilotStamp"] == "stamp_1" @@ -287,9 +285,9 @@ async def test_search_pagination( async with populated_pilot_db as pilot_db: if expect_exception: with pytest.raises(expect_exception): - await pilot_db.search([], [], [], per_page=per_page, page=page) + await pilot_db.search_pilots([], [], [], per_page=per_page, page=page) else: - total, result = await pilot_db.search( + total, result = await pilot_db.search_pilots( [], [], [], per_page=per_page, page=page ) assert total == N diff --git a/diracx-logic/src/diracx/logic/jobs/query.py b/diracx-logic/src/diracx/logic/jobs/query.py index ba3e6269b..0ec9738cf 100644 --- a/diracx-logic/src/diracx/logic/jobs/query.py +++ b/diracx-logic/src/diracx/logic/jobs/query.py @@ -62,7 +62,7 @@ async def search( } ) - total, jobs = await job_db.search( + total, jobs = await job_db.search_jobs( body.parameters, body.search, body.sort, diff --git a/diracx-logic/src/diracx/logic/jobs/status.py b/diracx-logic/src/diracx/logic/jobs/status.py index 82b670137..90d62f819 100644 --- a/diracx-logic/src/diracx/logic/jobs/status.py +++ b/diracx-logic/src/diracx/logic/jobs/status.py @@ -124,7 
+124,7 @@ async def set_job_statuses( } # search all jobs at once - _, results = await job_db.search( + _, results = await job_db.search_jobs( parameters=["Status", "StartExecTime", "EndExecTime", "JobID", "VO"], search=[ { @@ -291,7 +291,7 @@ async def reschedule_jobs( attribute_changes: defaultdict[int, dict[str, str]] = defaultdict(dict) jdl_changes = {} - _, results = await job_db.search( + _, results = await job_db.search_jobs( parameters=[ "Status", "MinorStatus", @@ -558,7 +558,7 @@ async def add_heartbeat( "operator": VectorSearchOperator.IN, "values": list(data), } - _, results = await job_db.search( + _, results = await job_db.search_jobs( parameters=["Status", "JobID"], search=[search_query], sorts=[] ) if len(results) != len(data): diff --git a/diracx-logic/src/diracx/logic/pilots/__init__.py b/diracx-logic/src/diracx/logic/pilots/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py index f1c0ee3c8..b51fe3b1f 100644 --- a/diracx-logic/src/diracx/logic/pilots/management.py +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -6,6 +6,12 @@ from diracx.core.models import PilotFieldsMapping from diracx.db.sql import PilotAgentsDB +from .query import ( + get_pilot_ids_by_stamps, + get_pilot_jobs_ids_by_pilot_id, + get_pilots_by_stamp_bulk, +) + async def register_new_pilots( pilot_db: PilotAgentsDB, @@ -17,7 +23,7 @@ async def register_new_pilots( # [IMPORTANT] Check unicity of pilot references # If a pilot already exists, it will undo everything and raise an error try: - await pilot_db.get_pilots_by_stamp_bulk(pilot_stamps=pilot_stamps) + await get_pilots_by_stamp_bulk(pilot_db=pilot_db, pilot_stamps=pilot_stamps) raise PilotAlreadyExistsError(data={"pilot_stamps": str(pilot_stamps)}) except PilotNotFoundError as e: # e.non_existing_pilots is set of the pilot that are not found @@ -67,8 +73,9 @@ async def update_pilots_fields( async def associate_pilot_with_jobs( pilot_db: PilotAgentsDB, pilot_stamp: str, pilot_jobs_ids: list[int] ): - pilot_ids = await pilot_db.get_pilot_ids_by_stamps([pilot_stamp]) - # Semantic assured by fetch_records_bulk_or_raises + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + ) pilot_id = pilot_ids[0] now = datetime.now(tz=timezone.utc) @@ -88,8 +95,9 @@ async def get_pilot_jobs_ids_by_stamp( pilot_db: PilotAgentsDB, pilot_stamp: str ) -> list[int]: """Fetch pilot jobs by stamp.""" - pilot_ids = await pilot_db.get_pilot_ids_by_stamps([pilot_stamp]) - # Semantic assured by fetch_records_bulk_or_raises + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + ) pilot_id = pilot_ids[0] - return await pilot_db.get_pilot_jobs_ids_by_pilot_id(pilot_id) + return await get_pilot_jobs_ids_by_pilot_id(pilot_db=pilot_db, pilot_id=pilot_id) diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py index dbdb686dc..578b12280 100644 --- a/diracx-logic/src/diracx/logic/pilots/query.py +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -2,7 +2,14 @@ from typing import Any -from diracx.core.models import ScalarSearchOperator, SearchParams +from diracx.core.exceptions import PilotNotFoundError +from diracx.core.models import ( + ScalarSearchOperator, + ScalarSearchSpec, + SearchParams, + VectorSearchOperator, + VectorSearchSpec, +) from diracx.db.sql import PilotAgentsDB MAX_PER_PAGE = 10000 @@ -24,10 +31,12 @@ 
async def search( body = SearchParams() body.search.append( - {"parameter": "VO", "operator": ScalarSearchOperator.EQUAL, "value": user_vo} + ScalarSearchSpec( + parameter="VO", operator=ScalarSearchOperator.EQUAL, value=user_vo + ) ) - total, pilots = await pilot_db.search( + total, pilots = await pilot_db.search_pilots( body.parameters, body.search, body.sort, @@ -37,3 +46,64 @@ async def search( ) return total, pilots + + +async def get_pilots_by_stamp_bulk( + pilot_db: PilotAgentsDB, pilot_stamps: list[str], parameters: list[str] = [] +) -> list[dict[Any, Any]]: + _, pilots = await pilot_db.search_pilots( + parameters=parameters, + search=[ + VectorSearchSpec( + parameter="PilotStamp", + operator=VectorSearchOperator.IN, + values=pilot_stamps, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + # Custom handling, to see which pilot_stamp does not exist (if so, say which one) + found_keys = {row["PilotStamp"] for row in pilots} + missing = set(pilot_stamps) - found_keys + + if missing: + raise PilotNotFoundError( + data={"pilot_stamp": str(missing)}, + detail=str(missing), + non_existing_pilots=missing, + ) + + return pilots + + +async def get_pilot_ids_by_stamps( + pilot_db: PilotAgentsDB, pilot_stamps: list[str] +) -> list[int]: + pilots = await get_pilots_by_stamp_bulk( + pilot_db=pilot_db, pilot_stamps=pilot_stamps, parameters=["PilotID"] + ) + + return [pilot["PilotID"] for pilot in pilots] + + +async def get_pilot_jobs_ids_by_pilot_id( + pilot_db: PilotAgentsDB, pilot_id: int +) -> list[int]: + _, jobs = await pilot_db.search_pilot_to_job_mapping( + parameters=["JobID"], + search=[ + ScalarSearchSpec( + parameter="PilotID", + operator=ScalarSearchOperator.EQUAL, + value=pilot_id, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + return [job["JobID"] for job in jobs] diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py index 02d6de0e8..52e00e74c 100644 --- a/diracx-routers/src/diracx/routers/pilots/access_policies.py +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -9,6 +9,7 @@ from diracx.core.exceptions import PilotNotFoundError from diracx.core.properties import NORMAL_USER, TRUSTED_HOST from diracx.db.sql import PilotAgentsDB +from diracx.logic.pilots.query import get_pilots_by_stamp_bulk from diracx.routers.access_policies import BaseAccessPolicy from diracx.routers.utils.users import AuthorizedUserInfo @@ -57,7 +58,11 @@ async def policy( ) try: - pilots = await pilot_db.get_pilots_by_stamp_bulk(pilot_stamps) + pilots = await get_pilots_by_stamp_bulk( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + parameters=["VO"], # For efficiency + ) except PilotNotFoundError as e: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, From a49b3830d05078d284ded9e78cea4b80f7babdd0 Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Sun, 15 Jun 2025 21:09:40 +0200 Subject: [PATCH 03/33] fix: Fixed models and reoranizing db functions --- diracx-core/src/diracx/core/models.py | 12 +---- diracx-db/src/diracx/db/sql/pilots/db.py | 68 +++++++++++++----------- 2 files changed, 39 insertions(+), 41 deletions(-) diff --git a/diracx-core/src/diracx/core/models.py b/diracx-core/src/diracx/core/models.py index 93dba188b..4fddc90c3 100644 --- a/diracx-core/src/diracx/core/models.py +++ b/diracx-core/src/diracx/core/models.py @@ -274,17 +274,7 @@ class JobCommand(BaseModel): arguments: str | None = None -class PilotInfo(BaseModel): - sub: str - 
pilot_stamp: str - vo: str - - -class PilotStampInfo(BaseModel): - pilot_stamp: str - - -class PilotFieldsMapping(BaseModel): +class PilotFieldsMapping(BaseModel, extra="forbid"): """All the fields that a user can modify on a Pilot (except PilotStamp).""" PilotStamp: str diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py index cb86864c7..d66d6a443 100644 --- a/diracx-db/src/diracx/db/sql/pilots/db.py +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -32,6 +32,8 @@ class PilotAgentsDB(BaseSQLDB): metadata = PilotAgentsDBBase.metadata + # ----------------------------- Insert Functions ----------------------------- + async def add_pilots_bulk( self, pilot_stamps: list[str], @@ -67,18 +69,6 @@ async def add_pilots_bulk( await self.conn.execute(stmt) - async def delete_pilots_by_stamps_bulk(self, pilot_stamps: list[str]): - """Bulk delete pilots. - - Raises PilotNotFound if one of the pilot was not found. - """ - stmt = delete(PilotAgents).where(PilotAgents.pilot_stamp.in_(pilot_stamps)) - - res = await self.conn.execute(stmt) - - if res.rowcount != len(pilot_stamps): - raise PilotNotFoundError(data={"pilot_stamps": str(pilot_stamps)}) - async def associate_pilot_with_jobs( self, job_to_pilot_mapping: list[dict[str, Any]] ): @@ -124,6 +114,40 @@ async def associate_pilot_with_jobs( "Engine Specific error not caught" + str(e) ) from e + # ----------------------------- Delete Functions ----------------------------- + + async def delete_pilots_by_stamps_bulk(self, pilot_stamps: list[str]): + """Bulk delete pilots. + + Raises PilotNotFound if one of the pilot was not found. + """ + stmt = delete(PilotAgents).where(PilotAgents.pilot_stamp.in_(pilot_stamps)) + + res = await self.conn.execute(stmt) + + if res.rowcount != len(pilot_stamps): + raise PilotNotFoundError(data={"pilot_stamps": str(pilot_stamps)}) + + async def clear_pilots_bulk( + self, cutoff_date: datetime, delete_only_aborted: bool + ) -> int: + """Bulk delete pilots that have SubmissionTime before the 'cutoff_date'. + Returns the number of deletion. + """ + # TODO: Add test (Millisec?) + stmt = delete(PilotAgents).where(PilotAgents.submission_time < cutoff_date) + + # If delete_only_aborted is True, add the condition for 'Status' being 'Aborted' + if delete_only_aborted: + stmt = stmt.where(PilotAgents.status == "Aborted") + + # Execute the statement + res = await self.conn.execute(stmt) + + return res.rowcount + + # ----------------------------- Update Functions ----------------------------- + async def update_pilot_fields_bulk( self, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] ): @@ -178,6 +202,8 @@ async def update_pilot_fields_bulk( data={"mapping": str(pilot_stamps_to_fields_mapping)} ) + # ----------------------------- Search Functions ----------------------------- + async def search_pilots( self, parameters: list[str] | None, @@ -219,21 +245,3 @@ async def search_pilot_to_job_mapping( per_page=per_page, page=page, ) - - async def clear_pilots_bulk( - self, cutoff_date: datetime, delete_only_aborted: bool - ) -> int: - """Bulk delete pilots that have SubmissionTime before the 'cutoff_date'. - Returns the number of deletion. - """ - # TODO: Add test (Millisec?) 
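# --- Illustrative sketch, not part of the patch ---------------------------------
# How a cleanup task might call the relocated clear_pilots_bulk documented above.
# The retention period and the helper name are made-up examples, not project defaults.
from datetime import datetime, timedelta, timezone

async def purge_aborted_pilots(pilot_db, retention_days: int) -> int:
    """Delete pilots marked Aborted whose SubmissionTime is older than retention_days."""
    cutoff = datetime.now(tz=timezone.utc) - timedelta(days=retention_days)
    async with pilot_db as db:
        # Returns the number of deleted rows
        return await db.clear_pilots_bulk(cutoff_date=cutoff, delete_only_aborted=True)
# ---------------------------------------------------------------------------------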
- stmt = delete(PilotAgents).where(PilotAgents.submission_time < cutoff_date) - - # If delete_only_aborted is True, add the condition for 'Status' being 'Aborted' - if delete_only_aborted: - stmt = stmt.where(PilotAgents.status == "Aborted") - - # Execute the statement - res = await self.conn.execute(stmt) - - return res.rowcount From 76d4e7168e17e78cde82dd69a668e43ef3b4da1d Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Mon, 16 Jun 2025 11:22:14 +0200 Subject: [PATCH 04/33] fix: Fixed possible bad behaviour --- diracx-logic/src/diracx/logic/pilots/query.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py index 578b12280..7f54cc1d9 100644 --- a/diracx-logic/src/diracx/logic/pilots/query.py +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -51,6 +51,9 @@ async def search( async def get_pilots_by_stamp_bulk( pilot_db: PilotAgentsDB, pilot_stamps: list[str], parameters: list[str] = [] ) -> list[dict[Any, Any]]: + if parameters: + parameters.append("PilotStamp") + _, pilots = await pilot_db.search_pilots( parameters=parameters, search=[ From b15e6d8b8fe72410acb8191123bf81b375f55196 Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Wed, 18 Jun 2025 10:36:59 +0200 Subject: [PATCH 05/33] fix: Small typo fix --- diracx-db/src/diracx/db/sql/utils/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diracx-db/src/diracx/db/sql/utils/base.py b/diracx-db/src/diracx/db/sql/utils/base.py index dc10754fc..143c0675c 100644 --- a/diracx-db/src/diracx/db/sql/utils/base.py +++ b/diracx-db/src/diracx/db/sql/utils/base.py @@ -238,7 +238,7 @@ async def search( per_page: int = 100, page: int | None = None, ) -> tuple[int, list[dict[Any, Any]]]: - """Search for pilots in the database.""" + """Search in a SQL database, with filters.""" # Find which columns to select columns = _get_columns(model.__table__, parameters) From e20a15c9378faae74fe74108a47021a446ee6586 Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Thu, 19 Jun 2025 17:29:01 +0200 Subject: [PATCH 06/33] fix: Some minor fixes including adding PILOT_STATUS --- diracx-core/src/diracx/core/models.py | 21 ++++++++++++++++++- diracx-db/src/diracx/db/sql/pilots/db.py | 4 ++-- .../tests/pilots/test_pilot_management.py | 9 ++++---- diracx-db/tests/pilots/test_query.py | 3 ++- .../src/diracx/routers/pilots/management.py | 2 +- .../tests/pilots/test_pilot_creation.py | 6 +++--- diracx-routers/tests/pilots/test_query.py | 3 ++- 7 files changed, 35 insertions(+), 13 deletions(-) diff --git a/diracx-core/src/diracx/core/models.py b/diracx-core/src/diracx/core/models.py index 4fddc90c3..6ed7cd9ff 100644 --- a/diracx-core/src/diracx/core/models.py +++ b/diracx-core/src/diracx/core/models.py @@ -279,7 +279,7 @@ class PilotFieldsMapping(BaseModel, extra="forbid"): PilotStamp: str StatusReason: Optional[str] = None - Status: Optional[str] = None + Status: Optional[PilotStatus] = None BenchMark: Optional[float] = None DestinationSite: Optional[str] = None Queue: Optional[str] = None @@ -287,3 +287,22 @@ class PilotFieldsMapping(BaseModel, extra="forbid"): GridType: Optional[str] = None AccountingSent: Optional[bool] = None CurrentJobID: Optional[int] = None + + +class PilotStatus(StrEnum): + #: The pilot has been generated and is transferred to a remote site: + SUBMITTED = "Submitted" + #: The pilot is waiting for a computing resource in a batch queue: + WAITING = "Waiting" + #: The pilot is running a payload on a worker 
node: + RUNNING = "Running" + #: The pilot finished its execution: + DONE = "Done" + #: The pilot execution failed: + FAILED = "Failed" + #: The pilot was deleted: + DELETED = "Deleted" + #: The pilot execution was aborted: + ABORTED = "Aborted" + #: Cannot get information about the pilot status: + UNKNOWN = "Unknown" diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py index d66d6a443..860ef9c8e 100644 --- a/diracx-db/src/diracx/db/sql/pilots/db.py +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -98,7 +98,7 @@ async def associate_pilot_with_jobs( if "foreign key" in str(e.orig).lower(): raise PilotNotFoundError( data={"pilot_stamps": str(job_to_pilot_mapping)}, - detail="at least one of these pilots does not exist", + detail="at least one of these pilots do not exist", ) from e if ( @@ -129,7 +129,7 @@ async def delete_pilots_by_stamps_bulk(self, pilot_stamps: list[str]): raise PilotNotFoundError(data={"pilot_stamps": str(pilot_stamps)}) async def clear_pilots_bulk( - self, cutoff_date: datetime, delete_only_aborted: bool + self, cutoff_date: datetime, delete_only_aborted: bool = False ) -> int: """Bulk delete pilots that have SubmissionTime before the 'cutoff_date'. Returns the number of deletion. diff --git a/diracx-db/tests/pilots/test_pilot_management.py b/diracx-db/tests/pilots/test_pilot_management.py index f73b39bdb..5f92a0485 100644 --- a/diracx-db/tests/pilots/test_pilot_management.py +++ b/diracx-db/tests/pilots/test_pilot_management.py @@ -12,6 +12,7 @@ ) from diracx.core.models import ( PilotFieldsMapping, + PilotStatus, ScalarSearchOperator, ScalarSearchSpec, VectorSearchOperator, @@ -126,7 +127,7 @@ async def _create_timed_pilots( ) if aborted: - stmt = stmt.values(Status="Aborted") + stmt = stmt.values(Status=PilotStatus.ABORTED) res = await db.conn.execute(stmt) assert res.rowcount == len(pilot_stamps) @@ -366,7 +367,7 @@ async def test_insert_and_select_single_then_modify(pilot_db: PilotAgentsDB): assert pilot["PilotStamp"] == pilot_stamp assert pilot["GridType"] == "grid-type" assert pilot["BenchMark"] == 0.0 - assert pilot["Status"] == "Submitted" + assert pilot["Status"] == PilotStatus.SUBMITTED assert pilot["StatusReason"] == "Unknown" assert not pilot["AccountingSent"] @@ -380,7 +381,7 @@ async def test_insert_and_select_single_then_modify(pilot_db: PilotAgentsDB): BenchMark=1.0, StatusReason="NewReason", AccountingSent=True, - Status="WAITING", + Status=PilotStatus.WAITING, ) ] ) @@ -394,7 +395,7 @@ async def test_insert_and_select_single_then_modify(pilot_db: PilotAgentsDB): assert pilot["PilotStamp"] == pilot_stamp assert pilot["GridType"] == "grid-type" assert pilot["BenchMark"] == 1.0 - assert pilot["Status"] == "WAITING" + assert pilot["Status"] == PilotStatus.WAITING assert pilot["StatusReason"] == "NewReason" assert pilot["AccountingSent"] diff --git a/diracx-db/tests/pilots/test_query.py b/diracx-db/tests/pilots/test_query.py index c594017fe..4b36e6ec5 100644 --- a/diracx-db/tests/pilots/test_query.py +++ b/diracx-db/tests/pilots/test_query.py @@ -5,6 +5,7 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.models import ( PilotFieldsMapping, + PilotStatus, ScalarSearchOperator, ScalarSearchSpec, SortDirection, @@ -34,7 +35,7 @@ async def pilot_db(tmp_path): "I was sleeping", ] -PILOT_STATUSES = ["Started", "Stopped", "Waiting"] +PILOT_STATUSES = list(PilotStatus) @pytest.fixture diff --git a/diracx-routers/src/diracx/routers/pilots/management.py 
b/diracx-routers/src/diracx/routers/pilots/management.py index 2d959e378..174a6ce46 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -129,7 +129,7 @@ async def clear_pilots( "It is set by default as True to avoid any mistake." ) ), - ] = True, + ] = False, ): """Endpoint for DIRAC to delete all pilots that lived more than age_in_days.""" await check_permissions() diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py index b4e5d2eec..7a42be399 100644 --- a/diracx-routers/tests/pilots/test_pilot_creation.py +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -2,7 +2,7 @@ import pytest -from diracx.core.models import PilotFieldsMapping +from diracx.core.models import PilotFieldsMapping, PilotStatus pytestmark = pytest.mark.enabled_dependencies( [ @@ -151,7 +151,7 @@ async def test_create_pilot_and_modify_it(normal_test_client): BenchMark=1.0, StatusReason="NewReason", AccountingSent=True, - Status="Waiting", + Status=PilotStatus.WAITING, ).model_dump(exclude_unset=True) ] } @@ -175,7 +175,7 @@ async def test_create_pilot_and_modify_it(normal_test_client): assert pilot1["BenchMark"] == 1.0 assert pilot1["StatusReason"] == "NewReason" assert pilot1["AccountingSent"] - assert pilot1["Status"] == "Waiting" + assert pilot1["Status"] == PilotStatus.WAITING assert pilot2["BenchMark"] != pilot1["BenchMark"] assert pilot2["StatusReason"] != pilot1["StatusReason"] diff --git a/diracx-routers/tests/pilots/test_query.py b/diracx-routers/tests/pilots/test_query.py index fe92ef0b9..981126798 100644 --- a/diracx-routers/tests/pilots/test_query.py +++ b/diracx-routers/tests/pilots/test_query.py @@ -7,6 +7,7 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.models import ( PilotFieldsMapping, + PilotStatus, ScalarSearchOperator, ScalarSearchSpec, SortDirection, @@ -42,7 +43,7 @@ def normal_test_client(client_factory): "I was sleeping", ] -PILOT_STATUSES = ["Started", "Stopped", "Waiting"] +PILOT_STATUSES = list(PilotStatus) @pytest.fixture From 07274f29ac5268f3e67d671ffbfb906b6928753d Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Mon, 23 Jun 2025 10:45:14 +0200 Subject: [PATCH 07/33] fix: Small fixes and rename function about pilot and job association --- .../_generated/aio/operations/_operations.py | 32 ++++---- .../client/_generated/models/__init__.py | 6 +- .../diracx/client/_generated/models/_enums.py | 13 ++++ .../client/_generated/models/_models.py | 78 ++++++++++--------- .../_generated/operations/_operations.py | 36 ++++----- diracx-db/src/diracx/db/sql/pilots/db.py | 9 +-- diracx-db/src/diracx/db/sql/pilots/schema.py | 4 +- .../src/diracx/logic/pilots/management.py | 4 +- .../src/diracx/routers/pilots/management.py | 6 +- .../_generated/aio/operations/_operations.py | 32 ++++---- .../client/_generated/models/__init__.py | 6 +- .../client/_generated/models/_enums.py | 13 ++++ .../client/_generated/models/_models.py | 78 ++++++++++--------- .../_generated/operations/_operations.py | 36 ++++----- 14 files changed, 190 insertions(+), 163 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index 85b8cc406..bfa1bf1de 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -52,8 +52,8 @@ 
build_jobs_summary_request, build_jobs_unassign_bulk_jobs_sandboxes_request, build_jobs_unassign_job_sandboxes_request, + build_pilots_add_jobs_to_pilot_request, build_pilots_add_pilot_stamps_request, - build_pilots_associate_pilot_with_jobs_request, build_pilots_clear_pilots_request, build_pilots_delete_pilots_request, build_pilots_search_request, @@ -2435,7 +2435,7 @@ async def update_pilot_fields( return cls(pipeline_response, None, {}) # type: ignore @distributed_trace_async - async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = True, **kwargs: Any) -> None: + async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = False, **kwargs: Any) -> None: """Clear Pilots. Endpoint for DIRAC to delete all pilots that lived more than age_in_days. @@ -2445,7 +2445,7 @@ async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = Tr :paramtype age_in_days: int :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by - default as True to avoid any mistake. Default value is True. + default as True to avoid any mistake. Default value is False. :paramtype delete_only_aborted: bool :return: None :rtype: None @@ -2487,15 +2487,15 @@ async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = Tr return cls(pipeline_response, None, {}) # type: ignore @overload - async def associate_pilot_with_jobs( - self, body: _models.BodyPilotsAssociatePilotWithJobs, *, content_type: str = "application/json", **kwargs: Any + async def add_jobs_to_pilot( + self, body: _models.BodyPilotsAddJobsToPilot, *, content_type: str = "application/json", **kwargs: Any ) -> None: - """Associate Pilot With Jobs. + """Add Jobs To Pilot. Endpoint only for DIRAC services, to associate a pilot with a job. :param body: Required. - :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs + :type body: ~_generated.models.BodyPilotsAddJobsToPilot :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str @@ -2505,10 +2505,10 @@ async def associate_pilot_with_jobs( """ @overload - async def associate_pilot_with_jobs( + async def add_jobs_to_pilot( self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any ) -> None: - """Associate Pilot With Jobs. + """Add Jobs To Pilot. Endpoint only for DIRAC services, to associate a pilot with a job. @@ -2523,15 +2523,13 @@ async def associate_pilot_with_jobs( """ @distributed_trace_async - async def associate_pilot_with_jobs( - self, body: Union[_models.BodyPilotsAssociatePilotWithJobs, IO[bytes]], **kwargs: Any - ) -> None: - """Associate Pilot With Jobs. + async def add_jobs_to_pilot(self, body: Union[_models.BodyPilotsAddJobsToPilot, IO[bytes]], **kwargs: Any) -> None: + """Add Jobs To Pilot. Endpoint only for DIRAC services, to associate a pilot with a job. - :param body: Is either a BodyPilotsAssociatePilotWithJobs type or a IO[bytes] type. Required. - :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs or IO[bytes] + :param body: Is either a BodyPilotsAddJobsToPilot type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsAddJobsToPilot or IO[bytes] :return: None :rtype: None :raises ~azure.core.exceptions.HttpResponseError: @@ -2556,9 +2554,9 @@ async def associate_pilot_with_jobs( if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "BodyPilotsAssociatePilotWithJobs") + _json = self._serialize.body(body, "BodyPilotsAddJobsToPilot") - _request = build_pilots_associate_pilot_with_jobs_request( + _request = build_pilots_add_jobs_to_pilot_request( content_type=content_type, json=_json, content=_content, diff --git a/diracx-client/src/diracx/client/_generated/models/__init__.py b/diracx-client/src/diracx/client/_generated/models/__init__.py index c6f8fb19a..5a6c6e047 100644 --- a/diracx-client/src/diracx/client/_generated/models/__init__.py +++ b/diracx-client/src/diracx/client/_generated/models/__init__.py @@ -14,8 +14,8 @@ from ._models import ( # type: ignore BodyAuthGetOidcToken, BodyAuthGetOidcTokenGrantType, + BodyPilotsAddJobsToPilot, BodyPilotsAddPilotStamps, - BodyPilotsAssociatePilotWithJobs, BodyPilotsUpdatePilotFields, GroupInfo, HTTPValidationError, @@ -52,6 +52,7 @@ from ._enums import ( # type: ignore ChecksumAlgorithm, JobStatus, + PilotStatus, SandboxFormat, SandboxType, ScalarSearchOperator, @@ -65,8 +66,8 @@ __all__ = [ "BodyAuthGetOidcToken", "BodyAuthGetOidcTokenGrantType", + "BodyPilotsAddJobsToPilot", "BodyPilotsAddPilotStamps", - "BodyPilotsAssociatePilotWithJobs", "BodyPilotsUpdatePilotFields", "GroupInfo", "HTTPValidationError", @@ -100,6 +101,7 @@ "VectorSearchSpecValues", "ChecksumAlgorithm", "JobStatus", + "PilotStatus", "SandboxFormat", "SandboxType", "ScalarSearchOperator", diff --git a/diracx-client/src/diracx/client/_generated/models/_enums.py b/diracx-client/src/diracx/client/_generated/models/_enums.py index 8098c62f4..44da9887d 100644 --- a/diracx-client/src/diracx/client/_generated/models/_enums.py +++ b/diracx-client/src/diracx/client/_generated/models/_enums.py @@ -34,6 +34,19 @@ class JobStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): RESCHEDULED = "Rescheduled" +class PilotStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """PilotStatus.""" + + SUBMITTED = "Submitted" + WAITING = "Waiting" + RUNNING = "Running" + DONE = "Done" + FAILED = "Failed" + DELETED = "Deleted" + ABORTED = "Aborted" + UNKNOWN = "Unknown" + + class SandboxFormat(str, Enum, metaclass=CaseInsensitiveEnumMeta): """SandboxFormat.""" diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index 9a224f824..81a5760bf 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -94,6 +94,39 @@ class BodyAuthGetOidcTokenGrantType(_serialization.Model): """OAuth2 Grant type.""" +class BodyPilotsAddJobsToPilot(_serialization.Model): + """Body_pilots_add_jobs_to_pilot. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: The stamp of the pilot. Required. + :vartype pilot_stamp: str + :ivar pilot_jobs_ids: The jobs we want to add to the pilot. Required. 
+ :vartype pilot_jobs_ids: list[int] + """ + + _validation = { + "pilot_stamp": {"required": True}, + "pilot_jobs_ids": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + "pilot_jobs_ids": {"key": "pilot_jobs_ids", "type": "[int]"}, + } + + def __init__(self, *, pilot_stamp: str, pilot_jobs_ids: List[int], **kwargs: Any) -> None: + """ + :keyword pilot_stamp: The stamp of the pilot. Required. + :paramtype pilot_stamp: str + :keyword pilot_jobs_ids: The jobs we want to add to the pilot. Required. + :paramtype pilot_jobs_ids: list[int] + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.pilot_jobs_ids = pilot_jobs_ids + + class BodyPilotsAddPilotStamps(_serialization.Model): """Body_pilots_add_pilot_stamps. @@ -147,39 +180,6 @@ def __init__( self.pilot_references = pilot_references -class BodyPilotsAssociatePilotWithJobs(_serialization.Model): - """Body_pilots_associate_pilot_with_jobs. - - All required parameters must be populated in order to send to server. - - :ivar pilot_stamp: The stamp of the pilot. Required. - :vartype pilot_stamp: str - :ivar pilot_jobs_ids: The jobs we want to add to the pilot. Required. - :vartype pilot_jobs_ids: list[int] - """ - - _validation = { - "pilot_stamp": {"required": True}, - "pilot_jobs_ids": {"required": True}, - } - - _attribute_map = { - "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, - "pilot_jobs_ids": {"key": "pilot_jobs_ids", "type": "[int]"}, - } - - def __init__(self, *, pilot_stamp: str, pilot_jobs_ids: List[int], **kwargs: Any) -> None: - """ - :keyword pilot_stamp: The stamp of the pilot. Required. - :paramtype pilot_stamp: str - :keyword pilot_jobs_ids: The jobs we want to add to the pilot. Required. - :paramtype pilot_jobs_ids: list[int] - """ - super().__init__(**kwargs) - self.pilot_stamp = pilot_stamp - self.pilot_jobs_ids = pilot_jobs_ids - - class BodyPilotsUpdatePilotFields(_serialization.Model): """Body_pilots_update_pilot_fields. @@ -727,8 +727,9 @@ class PilotFieldsMapping(_serialization.Model): :vartype pilot_stamp: str :ivar status_reason: Statusreason. :vartype status_reason: str - :ivar status: Status. - :vartype status: str + :ivar status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :vartype status: str or ~_generated.models.PilotStatus :ivar bench_mark: Benchmark. :vartype bench_mark: float :ivar destination_site: Destinationsite. @@ -767,7 +768,7 @@ def __init__( *, pilot_stamp: str, status_reason: Optional[str] = None, - status: Optional[str] = None, + status: Optional[Union[str, "_models.PilotStatus"]] = None, bench_mark: Optional[float] = None, destination_site: Optional[str] = None, queue: Optional[str] = None, @@ -782,8 +783,9 @@ def __init__( :paramtype pilot_stamp: str :keyword status_reason: Statusreason. :paramtype status_reason: str - :keyword status: Status. - :paramtype status: str + :keyword status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype status: str or ~_generated.models.PilotStatus :keyword bench_mark: Benchmark. :paramtype bench_mark: float :keyword destination_site: Destinationsite. 
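An illustrative sketch, not part of the generated client: building the typed models added above, with PilotStatus replacing the previous free-form status string. The stamp and field values are example data, and the import path is assumed from the generated package layout.

from diracx.client._generated.models import PilotFieldsMapping, PilotStatus

fields = PilotFieldsMapping(
    pilot_stamp="stamp_1",               # example stamp
    status=PilotStatus.WAITING,          # typed status instead of a raw string
    status_reason="Matched to a queue",  # example reason
    bench_mark=12.5,                     # example benchmark
)
# `fields` would then be sent through the pilots update endpoint
# (client.pilots.update_pilot_fields), whose request body is defined elsewhere in the patch.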
diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index 26353b973..9554ba7fe 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -634,7 +634,7 @@ def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: def build_pilots_clear_pilots_request( - *, age_in_days: int, delete_only_aborted: bool = True, **kwargs: Any + *, age_in_days: int, delete_only_aborted: bool = False, **kwargs: Any ) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) @@ -649,7 +649,7 @@ def build_pilots_clear_pilots_request( return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) -def build_pilots_associate_pilot_with_jobs_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long +def build_pilots_add_jobs_to_pilot_request(**kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) @@ -3050,7 +3050,7 @@ def update_pilot_fields( # pylint: disable=inconsistent-return-statements @distributed_trace def clear_pilots( # pylint: disable=inconsistent-return-statements - self, *, age_in_days: int, delete_only_aborted: bool = True, **kwargs: Any + self, *, age_in_days: int, delete_only_aborted: bool = False, **kwargs: Any ) -> None: """Clear Pilots. @@ -3061,7 +3061,7 @@ def clear_pilots( # pylint: disable=inconsistent-return-statements :paramtype age_in_days: int :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by - default as True to avoid any mistake. Default value is True. + default as True to avoid any mistake. Default value is False. :paramtype delete_only_aborted: bool :return: None :rtype: None @@ -3103,15 +3103,15 @@ def clear_pilots( # pylint: disable=inconsistent-return-statements return cls(pipeline_response, None, {}) # type: ignore @overload - def associate_pilot_with_jobs( - self, body: _models.BodyPilotsAssociatePilotWithJobs, *, content_type: str = "application/json", **kwargs: Any + def add_jobs_to_pilot( + self, body: _models.BodyPilotsAddJobsToPilot, *, content_type: str = "application/json", **kwargs: Any ) -> None: - """Associate Pilot With Jobs. + """Add Jobs To Pilot. Endpoint only for DIRAC services, to associate a pilot with a job. :param body: Required. - :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs + :type body: ~_generated.models.BodyPilotsAddJobsToPilot :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str @@ -3121,10 +3121,8 @@ def associate_pilot_with_jobs( """ @overload - def associate_pilot_with_jobs( - self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any - ) -> None: - """Associate Pilot With Jobs. + def add_jobs_to_pilot(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Add Jobs To Pilot. Endpoint only for DIRAC services, to associate a pilot with a job. 
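# --- Illustrative usage sketch, not part of the generated file -------------------
# Calling the renamed operation through an already-constructed synchronous client.
# `client` is assumed to be a Dirac client instance exposing the pilots group; the
# stamp and job IDs are example values.
from diracx.client._generated import models

def attach_jobs_to_pilot(client, pilot_stamp: str, job_ids: list[int]) -> None:
    body = models.BodyPilotsAddJobsToPilot(pilot_stamp=pilot_stamp, pilot_jobs_ids=job_ids)
    client.pilots.add_jobs_to_pilot(body)  # the aio variant is identical but awaited
# ---------------------------------------------------------------------------------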
@@ -3139,15 +3137,15 @@ def associate_pilot_with_jobs( """ @distributed_trace - def associate_pilot_with_jobs( # pylint: disable=inconsistent-return-statements - self, body: Union[_models.BodyPilotsAssociatePilotWithJobs, IO[bytes]], **kwargs: Any + def add_jobs_to_pilot( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsAddJobsToPilot, IO[bytes]], **kwargs: Any ) -> None: - """Associate Pilot With Jobs. + """Add Jobs To Pilot. Endpoint only for DIRAC services, to associate a pilot with a job. - :param body: Is either a BodyPilotsAssociatePilotWithJobs type or a IO[bytes] type. Required. - :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs or IO[bytes] + :param body: Is either a BodyPilotsAddJobsToPilot type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsAddJobsToPilot or IO[bytes] :return: None :rtype: None :raises ~azure.core.exceptions.HttpResponseError: @@ -3172,9 +3170,9 @@ def associate_pilot_with_jobs( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "BodyPilotsAssociatePilotWithJobs") + _json = self._serialize.body(body, "BodyPilotsAddJobsToPilot") - _request = build_pilots_associate_pilot_with_jobs_request( + _request = build_pilots_add_jobs_to_pilot_request( content_type=content_type, json=_json, content=_content, diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py index 860ef9c8e..b788458c0 100644 --- a/diracx-db/src/diracx/db/sql/pilots/db.py +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -13,6 +13,7 @@ ) from diracx.core.models import ( PilotFieldsMapping, + PilotStatus, SearchSpec, SortSpec, ) @@ -69,9 +70,7 @@ async def add_pilots_bulk( await self.conn.execute(stmt) - async def associate_pilot_with_jobs( - self, job_to_pilot_mapping: list[dict[str, Any]] - ): + async def add_jobs_to_pilot_bulk(self, job_to_pilot_mapping: list[dict[str, Any]]): """Associate a pilot with jobs. job_to_pilot_mapping format: @@ -83,7 +82,7 @@ async def associate_pilot_with_jobs( Raises: - PilotNotFoundError if a pilot_id is not associated with a pilot. - - PilotAlreadyAssociatedWithJobError if the pilot is already associated with a job. + - PilotAlreadyAssociatedWithJobError if the pilot is already associated with one of the given jobs. - NotImplementedError if the integrity error is not caught. **Important note**: We assume that a job exists. 
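# --- Illustrative sketch, not part of the patch ---------------------------------
# A minimal mapping for add_jobs_to_pilot_bulk. Only the PilotID/JobID keys that
# are visible in the pilot-to-job schema are used here; the full format documented
# above may carry additional fields. IDs are example values and the jobs are
# assumed to already exist, as the docstring requires.
async def link_pilot_to_jobs(pilot_db, pilot_id: int, job_ids: list[int]) -> None:
    mapping = [{"PilotID": pilot_id, "JobID": job_id} for job_id in job_ids]
    async with pilot_db as db:
        # Raises PilotNotFoundError / PilotAlreadyAssociatedWithJobError on conflict
        await db.add_jobs_to_pilot_bulk(job_to_pilot_mapping=mapping)
# ---------------------------------------------------------------------------------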
@@ -139,7 +138,7 @@ async def clear_pilots_bulk( # If delete_only_aborted is True, add the condition for 'Status' being 'Aborted' if delete_only_aborted: - stmt = stmt.where(PilotAgents.status == "Aborted") + stmt = stmt.where(PilotAgents.status == PilotStatus.ABORTED) # Execute the statement res = await self.conn.execute(stmt) diff --git a/diracx-db/src/diracx/db/sql/pilots/schema.py b/diracx-db/src/diracx/db/sql/pilots/schema.py index 032e36510..af087f1f8 100644 --- a/diracx-db/src/diracx/db/sql/pilots/schema.py +++ b/diracx-db/src/diracx/db/sql/pilots/schema.py @@ -10,6 +10,8 @@ ) from sqlalchemy.orm import declarative_base +from diracx.core.models import PilotStatus + from ..utils import Column, EnumBackedBool, NullColumn PilotAgentsDBBase = declarative_base() @@ -31,7 +33,7 @@ class PilotAgents(PilotAgentsDBBase): benchmark = Column("BenchMark", Double, default=0.0) submission_time = NullColumn("SubmissionTime", DateTime) last_update_time = NullColumn("LastUpdateTime", DateTime) - status = Column("Status", String(32), default="Unknown") + status = Column("Status", String(32), default=PilotStatus.UNKNOWN) status_reason = Column("StatusReason", String(255), default="Unknown") accounting_sent = Column("AccountingSent", EnumBackedBool(), default=False) diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py index b51fe3b1f..1e098d122 100644 --- a/diracx-logic/src/diracx/logic/pilots/management.py +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -70,7 +70,7 @@ async def update_pilots_fields( await pilot_db.update_pilot_fields_bulk(pilot_stamps_to_fields_mapping) -async def associate_pilot_with_jobs( +async def add_jobs_to_pilot( pilot_db: PilotAgentsDB, pilot_stamp: str, pilot_jobs_ids: list[int] ): pilot_ids = await get_pilot_ids_by_stamps( @@ -86,7 +86,7 @@ async def associate_pilot_with_jobs( for job_id in pilot_jobs_ids ] - await pilot_db.associate_pilot_with_jobs( + await pilot_db.add_jobs_to_pilot_bulk( job_to_pilot_mapping=job_to_pilot_mapping, ) diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 174a6ce46..19d78d4d6 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -15,7 +15,7 @@ PilotFieldsMapping, ) from diracx.logic.pilots.management import ( - associate_pilot_with_jobs as associate_pilot_with_jobs_bl, + add_jobs_to_pilot as add_jobs_to_pilot_bl, ) from diracx.logic.pilots.management import ( clear_pilots_bulk, @@ -204,7 +204,7 @@ async def update_pilot_fields( @router.patch("/management/jobs", status_code=HTTPStatus.NO_CONTENT) -async def associate_pilot_with_jobs( +async def add_jobs_to_pilot( pilot_db: PilotAgentsDB, pilot_stamp: Annotated[str, Body(description="The stamp of the pilot.")], pilot_jobs_ids: Annotated[ @@ -216,7 +216,7 @@ async def associate_pilot_with_jobs( await check_permissions() try: - await associate_pilot_with_jobs_bl( + await add_jobs_to_pilot_bl( pilot_db=pilot_db, pilot_stamp=pilot_stamp, pilot_jobs_ids=pilot_jobs_ids, diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index 23082b70c..4c32f530c 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ 
b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -55,8 +55,8 @@ build_lollygag_get_gubbins_secrets_request, build_lollygag_get_owner_object_request, build_lollygag_insert_owner_object_request, + build_pilots_add_jobs_to_pilot_request, build_pilots_add_pilot_stamps_request, - build_pilots_associate_pilot_with_jobs_request, build_pilots_clear_pilots_request, build_pilots_delete_pilots_request, build_pilots_search_request, @@ -2602,7 +2602,7 @@ async def update_pilot_fields( return cls(pipeline_response, None, {}) # type: ignore @distributed_trace_async - async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = True, **kwargs: Any) -> None: + async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = False, **kwargs: Any) -> None: """Clear Pilots. Endpoint for DIRAC to delete all pilots that lived more than age_in_days. @@ -2612,7 +2612,7 @@ async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = Tr :paramtype age_in_days: int :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by - default as True to avoid any mistake. Default value is True. + default as True to avoid any mistake. Default value is False. :paramtype delete_only_aborted: bool :return: None :rtype: None @@ -2654,15 +2654,15 @@ async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = Tr return cls(pipeline_response, None, {}) # type: ignore @overload - async def associate_pilot_with_jobs( - self, body: _models.BodyPilotsAssociatePilotWithJobs, *, content_type: str = "application/json", **kwargs: Any + async def add_jobs_to_pilot( + self, body: _models.BodyPilotsAddJobsToPilot, *, content_type: str = "application/json", **kwargs: Any ) -> None: - """Associate Pilot With Jobs. + """Add Jobs To Pilot. Endpoint only for DIRAC services, to associate a pilot with a job. :param body: Required. - :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs + :type body: ~_generated.models.BodyPilotsAddJobsToPilot :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str @@ -2672,10 +2672,10 @@ async def associate_pilot_with_jobs( """ @overload - async def associate_pilot_with_jobs( + async def add_jobs_to_pilot( self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any ) -> None: - """Associate Pilot With Jobs. + """Add Jobs To Pilot. Endpoint only for DIRAC services, to associate a pilot with a job. @@ -2690,15 +2690,13 @@ async def associate_pilot_with_jobs( """ @distributed_trace_async - async def associate_pilot_with_jobs( - self, body: Union[_models.BodyPilotsAssociatePilotWithJobs, IO[bytes]], **kwargs: Any - ) -> None: - """Associate Pilot With Jobs. + async def add_jobs_to_pilot(self, body: Union[_models.BodyPilotsAddJobsToPilot, IO[bytes]], **kwargs: Any) -> None: + """Add Jobs To Pilot. Endpoint only for DIRAC services, to associate a pilot with a job. - :param body: Is either a BodyPilotsAssociatePilotWithJobs type or a IO[bytes] type. Required. - :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs or IO[bytes] + :param body: Is either a BodyPilotsAddJobsToPilot type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsAddJobsToPilot or IO[bytes] :return: None :rtype: None :raises ~azure.core.exceptions.HttpResponseError: @@ -2723,9 +2721,9 @@ async def associate_pilot_with_jobs( if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "BodyPilotsAssociatePilotWithJobs") + _json = self._serialize.body(body, "BodyPilotsAddJobsToPilot") - _request = build_pilots_associate_pilot_with_jobs_request( + _request = build_pilots_add_jobs_to_pilot_request( content_type=content_type, json=_json, content=_content, diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py index 60f09c531..7f6b0f274 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py @@ -14,8 +14,8 @@ from ._models import ( # type: ignore BodyAuthGetOidcToken, BodyAuthGetOidcTokenGrantType, + BodyPilotsAddJobsToPilot, BodyPilotsAddPilotStamps, - BodyPilotsAssociatePilotWithJobs, BodyPilotsUpdatePilotFields, ExtendedMetadata, GroupInfo, @@ -52,6 +52,7 @@ from ._enums import ( # type: ignore ChecksumAlgorithm, JobStatus, + PilotStatus, SandboxFormat, SandboxType, ScalarSearchOperator, @@ -65,8 +66,8 @@ __all__ = [ "BodyAuthGetOidcToken", "BodyAuthGetOidcTokenGrantType", + "BodyPilotsAddJobsToPilot", "BodyPilotsAddPilotStamps", - "BodyPilotsAssociatePilotWithJobs", "BodyPilotsUpdatePilotFields", "ExtendedMetadata", "GroupInfo", @@ -100,6 +101,7 @@ "VectorSearchSpecValues", "ChecksumAlgorithm", "JobStatus", + "PilotStatus", "SandboxFormat", "SandboxType", "ScalarSearchOperator", diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py index 8098c62f4..44da9887d 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py @@ -34,6 +34,19 @@ class JobStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): RESCHEDULED = "Rescheduled" +class PilotStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """PilotStatus.""" + + SUBMITTED = "Submitted" + WAITING = "Waiting" + RUNNING = "Running" + DONE = "Done" + FAILED = "Failed" + DELETED = "Deleted" + ABORTED = "Aborted" + UNKNOWN = "Unknown" + + class SandboxFormat(str, Enum, metaclass=CaseInsensitiveEnumMeta): """SandboxFormat.""" diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 2400791e0..7724f2823 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -94,6 +94,39 @@ class BodyAuthGetOidcTokenGrantType(_serialization.Model): """OAuth2 Grant type.""" +class BodyPilotsAddJobsToPilot(_serialization.Model): + """Body_pilots_add_jobs_to_pilot. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: The stamp of the pilot. Required. + :vartype pilot_stamp: str + :ivar pilot_jobs_ids: The jobs we want to add to the pilot. Required. 
+ :vartype pilot_jobs_ids: list[int] + """ + + _validation = { + "pilot_stamp": {"required": True}, + "pilot_jobs_ids": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + "pilot_jobs_ids": {"key": "pilot_jobs_ids", "type": "[int]"}, + } + + def __init__(self, *, pilot_stamp: str, pilot_jobs_ids: List[int], **kwargs: Any) -> None: + """ + :keyword pilot_stamp: The stamp of the pilot. Required. + :paramtype pilot_stamp: str + :keyword pilot_jobs_ids: The jobs we want to add to the pilot. Required. + :paramtype pilot_jobs_ids: list[int] + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.pilot_jobs_ids = pilot_jobs_ids + + class BodyPilotsAddPilotStamps(_serialization.Model): """Body_pilots_add_pilot_stamps. @@ -147,39 +180,6 @@ def __init__( self.pilot_references = pilot_references -class BodyPilotsAssociatePilotWithJobs(_serialization.Model): - """Body_pilots_associate_pilot_with_jobs. - - All required parameters must be populated in order to send to server. - - :ivar pilot_stamp: The stamp of the pilot. Required. - :vartype pilot_stamp: str - :ivar pilot_jobs_ids: The jobs we want to add to the pilot. Required. - :vartype pilot_jobs_ids: list[int] - """ - - _validation = { - "pilot_stamp": {"required": True}, - "pilot_jobs_ids": {"required": True}, - } - - _attribute_map = { - "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, - "pilot_jobs_ids": {"key": "pilot_jobs_ids", "type": "[int]"}, - } - - def __init__(self, *, pilot_stamp: str, pilot_jobs_ids: List[int], **kwargs: Any) -> None: - """ - :keyword pilot_stamp: The stamp of the pilot. Required. - :paramtype pilot_stamp: str - :keyword pilot_jobs_ids: The jobs we want to add to the pilot. Required. - :paramtype pilot_jobs_ids: list[int] - """ - super().__init__(**kwargs) - self.pilot_stamp = pilot_stamp - self.pilot_jobs_ids = pilot_jobs_ids - - class BodyPilotsUpdatePilotFields(_serialization.Model): """Body_pilots_update_pilot_fields. @@ -748,8 +748,9 @@ class PilotFieldsMapping(_serialization.Model): :vartype pilot_stamp: str :ivar status_reason: Statusreason. :vartype status_reason: str - :ivar status: Status. - :vartype status: str + :ivar status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :vartype status: str or ~_generated.models.PilotStatus :ivar bench_mark: Benchmark. :vartype bench_mark: float :ivar destination_site: Destinationsite. @@ -788,7 +789,7 @@ def __init__( *, pilot_stamp: str, status_reason: Optional[str] = None, - status: Optional[str] = None, + status: Optional[Union[str, "_models.PilotStatus"]] = None, bench_mark: Optional[float] = None, destination_site: Optional[str] = None, queue: Optional[str] = None, @@ -803,8 +804,9 @@ def __init__( :paramtype pilot_stamp: str :keyword status_reason: Statusreason. :paramtype status_reason: str - :keyword status: Status. - :paramtype status: str + :keyword status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype status: str or ~_generated.models.PilotStatus :keyword bench_mark: Benchmark. :paramtype bench_mark: float :keyword destination_site: Destinationsite. 
diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index c594dd602..b677a98cb 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -683,7 +683,7 @@ def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: def build_pilots_clear_pilots_request( - *, age_in_days: int, delete_only_aborted: bool = True, **kwargs: Any + *, age_in_days: int, delete_only_aborted: bool = False, **kwargs: Any ) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) @@ -698,7 +698,7 @@ def build_pilots_clear_pilots_request( return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) -def build_pilots_associate_pilot_with_jobs_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long +def build_pilots_add_jobs_to_pilot_request(**kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) @@ -3263,7 +3263,7 @@ def update_pilot_fields( # pylint: disable=inconsistent-return-statements @distributed_trace def clear_pilots( # pylint: disable=inconsistent-return-statements - self, *, age_in_days: int, delete_only_aborted: bool = True, **kwargs: Any + self, *, age_in_days: int, delete_only_aborted: bool = False, **kwargs: Any ) -> None: """Clear Pilots. @@ -3274,7 +3274,7 @@ def clear_pilots( # pylint: disable=inconsistent-return-statements :paramtype age_in_days: int :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by - default as True to avoid any mistake. Default value is True. + default as True to avoid any mistake. Default value is False. :paramtype delete_only_aborted: bool :return: None :rtype: None @@ -3316,15 +3316,15 @@ def clear_pilots( # pylint: disable=inconsistent-return-statements return cls(pipeline_response, None, {}) # type: ignore @overload - def associate_pilot_with_jobs( - self, body: _models.BodyPilotsAssociatePilotWithJobs, *, content_type: str = "application/json", **kwargs: Any + def add_jobs_to_pilot( + self, body: _models.BodyPilotsAddJobsToPilot, *, content_type: str = "application/json", **kwargs: Any ) -> None: - """Associate Pilot With Jobs. + """Add Jobs To Pilot. Endpoint only for DIRAC services, to associate a pilot with a job. :param body: Required. - :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs + :type body: ~_generated.models.BodyPilotsAddJobsToPilot :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str @@ -3334,10 +3334,8 @@ def associate_pilot_with_jobs( """ @overload - def associate_pilot_with_jobs( - self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any - ) -> None: - """Associate Pilot With Jobs. + def add_jobs_to_pilot(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Add Jobs To Pilot. Endpoint only for DIRAC services, to associate a pilot with a job. 
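For reference, a minimal usage sketch of the renamed operation, assuming `pilots_ops` is the pilots operation group of an already configured generated client (the client wiring itself is not shown in this hunk; the diracx client exposes the same group):

from gubbins.client._generated.models import BodyPilotsAddJobsToPilot


def associate_jobs(pilots_ops, stamp: str, job_ids: list[int]) -> None:
    # Build the request body introduced by this patch and call the renamed operation.
    body = BodyPilotsAddJobsToPilot(pilot_stamp=stamp, pilot_jobs_ids=job_ids)
    pilots_ops.add_jobs_to_pilot(body)  # returns None; raises HttpResponseError on failure
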
@@ -3352,15 +3350,15 @@ def associate_pilot_with_jobs( """ @distributed_trace - def associate_pilot_with_jobs( # pylint: disable=inconsistent-return-statements - self, body: Union[_models.BodyPilotsAssociatePilotWithJobs, IO[bytes]], **kwargs: Any + def add_jobs_to_pilot( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsAddJobsToPilot, IO[bytes]], **kwargs: Any ) -> None: - """Associate Pilot With Jobs. + """Add Jobs To Pilot. Endpoint only for DIRAC services, to associate a pilot with a job. - :param body: Is either a BodyPilotsAssociatePilotWithJobs type or a IO[bytes] type. Required. - :type body: ~_generated.models.BodyPilotsAssociatePilotWithJobs or IO[bytes] + :param body: Is either a BodyPilotsAddJobsToPilot type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsAddJobsToPilot or IO[bytes] :return: None :rtype: None :raises ~azure.core.exceptions.HttpResponseError: @@ -3385,9 +3383,9 @@ def associate_pilot_with_jobs( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "BodyPilotsAssociatePilotWithJobs") + _json = self._serialize.body(body, "BodyPilotsAddJobsToPilot") - _request = build_pilots_associate_pilot_with_jobs_request( + _request = build_pilots_add_jobs_to_pilot_request( content_type=content_type, json=_json, content=_content, From cb840de0772c76adcc380f6f00bbe8e0fd332dd7 Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Mon, 23 Jun 2025 10:51:31 +0200 Subject: [PATCH 08/33] fix: Small fix in a test --- diracx-db/tests/pilots/test_pilot_management.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/diracx-db/tests/pilots/test_pilot_management.py b/diracx-db/tests/pilots/test_pilot_management.py index 5f92a0485..0312913e5 100644 --- a/diracx-db/tests/pilots/test_pilot_management.py +++ b/diracx-db/tests/pilots/test_pilot_management.py @@ -437,7 +437,7 @@ async def test_associate_pilot_with_job_and_get_it(pilot_db: PilotAgentsDB): {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} for job_id in pilot_jobs ] - await pilot_db.associate_pilot_with_jobs(job_to_pilot_mapping) + await pilot_db.add_jobs_to_pilot_bulk(job_to_pilot_mapping) # Verify that he has all jobs db_jobs = await get_pilot_jobs_ids_by_pilot_id(pilot_db, pilot_id) @@ -453,7 +453,7 @@ async def test_associate_pilot_with_job_and_get_it(pilot_db: PilotAgentsDB): {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} for job_id in pilot_jobs ] - await pilot_db.associate_pilot_with_jobs(job_to_pilot_mapping) + await pilot_db.add_jobs_to_pilot_bulk(job_to_pilot_mapping) # Associate pilot with jobs that he has not, but was previously in an error # To test that the rollback worked @@ -463,4 +463,4 @@ async def test_associate_pilot_with_job_and_get_it(pilot_db: PilotAgentsDB): {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} for job_id in pilot_jobs ] - await pilot_db.associate_pilot_with_jobs(job_to_pilot_mapping) + await pilot_db.add_jobs_to_pilot_bulk(job_to_pilot_mapping) From a851e710498ab7862d11d282d09db7e8245041ac Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Wed, 25 Jun 2025 11:32:00 +0200 Subject: [PATCH 09/33] fix: Lot of fixes, refactoring tests, and modifying access policy to support DIRAC --- diracx-db/src/diracx/db/sql/pilots/db.py | 20 +- .../tests/pilots/test_pilot_management.py | 240 ++++-------------- diracx-db/tests/pilots/test_query.py | 4 +- diracx-db/tests/pilots/utils.py | 151 +++++++++++ 
.../src/diracx/logic/pilots/management.py | 59 ++--- diracx-logic/src/diracx/logic/pilots/query.py | 40 ++- .../diracx/routers/pilots/access_policies.py | 95 +------ .../src/diracx/routers/pilots/management.py | 81 ++++-- .../tests/pilots/test_pilot_creation.py | 41 +-- diracx-routers/tests/pilots/test_query.py | 4 +- 10 files changed, 362 insertions(+), 373 deletions(-) create mode 100644 diracx-db/tests/pilots/utils.py diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py index b788458c0..a175ab866 100644 --- a/diracx-db/src/diracx/db/sql/pilots/db.py +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -35,12 +35,15 @@ class PilotAgentsDB(BaseSQLDB): # ----------------------------- Insert Functions ----------------------------- - async def add_pilots_bulk( + async def add_pilots( self, pilot_stamps: list[str], vo: str, grid_type: str = "DIRAC", + grid_site: str = "Unknown", + destination_site: str = "NotAssigned", pilot_references: dict[str, str] | None = None, + status_reason: str = "Unknown", ): """Bulk add pilots in the DB. @@ -57,9 +60,12 @@ async def add_pilots_bulk( "PilotJobReference": pilot_references.get(stamp, stamp), "VO": vo, "GridType": grid_type, + "GridSite": grid_site, + "DestinationSite": destination_site, "SubmissionTime": now, "LastUpdateTime": now, "Status": "Submitted", + "StatusReason": status_reason, "PilotStamp": stamp, } for stamp in pilot_stamps @@ -70,7 +76,7 @@ async def add_pilots_bulk( await self.conn.execute(stmt) - async def add_jobs_to_pilot_bulk(self, job_to_pilot_mapping: list[dict[str, Any]]): + async def add_jobs_to_pilot(self, job_to_pilot_mapping: list[dict[str, Any]]): """Associate a pilot with jobs. job_to_pilot_mapping format: @@ -115,7 +121,7 @@ async def add_jobs_to_pilot_bulk(self, job_to_pilot_mapping: list[dict[str, Any] # ----------------------------- Delete Functions ----------------------------- - async def delete_pilots_by_stamps_bulk(self, pilot_stamps: list[str]): + async def delete_pilots_by_stamps(self, pilot_stamps: list[str]): """Bulk delete pilots. Raises PilotNotFound if one of the pilot was not found. @@ -127,7 +133,7 @@ async def delete_pilots_by_stamps_bulk(self, pilot_stamps: list[str]): if res.rowcount != len(pilot_stamps): raise PilotNotFoundError(data={"pilot_stamps": str(pilot_stamps)}) - async def clear_pilots_bulk( + async def clear_pilots( self, cutoff_date: datetime, delete_only_aborted: bool = False ) -> int: """Bulk delete pilots that have SubmissionTime before the 'cutoff_date'. @@ -147,7 +153,7 @@ async def clear_pilots_bulk( # ----------------------------- Update Functions ----------------------------- - async def update_pilot_fields_bulk( + async def update_pilot_fields( self, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] ): """Bulk update pilots with a mapping. 
@@ -213,7 +219,7 @@ async def search_pilots( per_page: int = 100, page: int | None = None, ) -> tuple[int, list[dict[Any, Any]]]: - """Search for pilots in the database.""" + """Search for pilot information in the database.""" return await self.search( model=PilotAgents, parameters=parameters, @@ -234,7 +240,7 @@ async def search_pilot_to_job_mapping( per_page: int = 100, page: int | None = None, ) -> tuple[int, list[dict[Any, Any]]]: - """Search for pilots in the database.""" + """Search for jobs that are associated with pilots.""" return await self.search( model=JobToPilotMapping, parameters=parameters, diff --git a/diracx-db/tests/pilots/test_pilot_management.py b/diracx-db/tests/pilots/test_pilot_management.py index 0312913e5..ff09fdb74 100644 --- a/diracx-db/tests/pilots/test_pilot_management.py +++ b/diracx-db/tests/pilots/test_pilot_management.py @@ -1,25 +1,25 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import Any import pytest -from sqlalchemy.sql import update from diracx.core.exceptions import ( PilotAlreadyAssociatedWithJobError, - PilotNotFoundError, ) from diracx.core.models import ( PilotFieldsMapping, PilotStatus, - ScalarSearchOperator, - ScalarSearchSpec, - VectorSearchOperator, - VectorSearchSpec, ) from diracx.db.sql.pilots.db import PilotAgentsDB -from diracx.db.sql.pilots.schema import PilotAgents + +from .utils import ( + add_stamps, # noqa: F401 + create_old_pilots_environment, # noqa: F401 + create_timed_pilots, # noqa: F401 + get_pilot_jobs_ids_by_pilot_id, + get_pilots_by_stamp, +) MAIN_VO = "lhcb" N = 100 @@ -34,148 +34,6 @@ async def pilot_db(tmp_path): yield agents_db -async def get_pilot_jobs_ids_by_pilot_id( - pilot_db: PilotAgentsDB, pilot_id: int -) -> list[int]: - _, jobs = await pilot_db.search_pilot_to_job_mapping( - parameters=["JobID"], - search=[ - ScalarSearchSpec( - parameter="PilotID", - operator=ScalarSearchOperator.EQUAL, - value=pilot_id, - ) - ], - sorts=[], - distinct=True, - per_page=10000, - ) - - return [job["JobID"] for job in jobs] - - -async def get_pilots_by_stamp_bulk( - pilot_db: PilotAgentsDB, pilot_stamps: list[str], parameters: list[str] = [] -) -> list[dict[Any, Any]]: - _, pilots = await pilot_db.search_pilots( - parameters=parameters, - search=[ - VectorSearchSpec( - parameter="PilotStamp", - operator=VectorSearchOperator.IN, - values=pilot_stamps, - ) - ], - sorts=[], - distinct=True, - per_page=1000, - ) - - # Custom handling, to see which pilot_stamp does not exist (if so, say which one) - found_keys = {row["PilotStamp"] for row in pilots} - missing = set(pilot_stamps) - found_keys - - if missing: - raise PilotNotFoundError( - data={"pilot_stamp": str(missing)}, - detail=str(missing), - non_existing_pilots=missing, - ) - - return pilots - - -@pytest.fixture -async def add_stamps(pilot_db): - async def _add_stamps(start_n=0): - async with pilot_db as db: - # Add pilots - refs = [f"ref_{i}" for i in range(start_n, start_n + N)] - stamps = [f"stamp_{i}" for i in range(start_n, start_n + N)] - pilot_references = dict(zip(stamps, refs)) - - vo = MAIN_VO - - await db.add_pilots_bulk( - stamps, vo, grid_type="DIRAC", pilot_references=pilot_references - ) - - pilots = await get_pilots_by_stamp_bulk(db, stamps) - - return pilots - - return _add_stamps - - -@pytest.fixture -async def create_timed_pilots(pilot_db, add_stamps): - async def _create_timed_pilots( - old_date: datetime, aborted: bool = False, start_n=0 - ): - # Get pilots - pilots = await add_stamps(start_n) - - async with 
pilot_db as db: - # Update manually their age - # Collect PilotStamps - pilot_stamps = [pilot["PilotStamp"] for pilot in pilots] - - stmt = ( - update(PilotAgents) - .where(PilotAgents.pilot_stamp.in_(pilot_stamps)) - .values(SubmissionTime=old_date) - ) - - if aborted: - stmt = stmt.values(Status=PilotStatus.ABORTED) - - res = await db.conn.execute(stmt) - assert res.rowcount == len(pilot_stamps) - - pilots = await get_pilots_by_stamp_bulk(db, pilot_stamps) - return pilots - - return _create_timed_pilots - - -@pytest.fixture -async def create_old_pilots_environment(pilot_db, create_timed_pilots): - non_aborted_recent = await create_timed_pilots( - datetime(2025, 1, 1, tzinfo=timezone.utc), False, N - ) - aborted_recent = await create_timed_pilots( - datetime(2025, 1, 1, tzinfo=timezone.utc), True, 2 * N - ) - - aborted_very_old = await create_timed_pilots( - datetime(2003, 3, 10, tzinfo=timezone.utc), True, 3 * N - ) - non_aborted_very_old = await create_timed_pilots( - datetime(2003, 3, 10, tzinfo=timezone.utc), False, 4 * N - ) - - pilot_number = 4 * N - - assert pilot_number == ( - len(non_aborted_recent) - + len(aborted_recent) - + len(aborted_very_old) - + len(non_aborted_very_old) - ) - - # Phase 0. Verify that we have the right environment - async with pilot_db as pilot_db: - # Ensure that we can get every pilot (only get first of each group) - await get_pilots_by_stamp_bulk(pilot_db, [non_aborted_recent[0]["PilotStamp"]]) - await get_pilots_by_stamp_bulk(pilot_db, [aborted_recent[0]["PilotStamp"]]) - await get_pilots_by_stamp_bulk(pilot_db, [aborted_very_old[0]["PilotStamp"]]) - await get_pilots_by_stamp_bulk( - pilot_db, [non_aborted_very_old[0]["PilotStamp"]] - ) - - return non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old - - @pytest.mark.asyncio async def test_insert_and_select(pilot_db: PilotAgentsDB): async with pilot_db as pilot_db: @@ -184,12 +42,12 @@ async def test_insert_and_select(pilot_db: PilotAgentsDB): stamps = [f"stamp_{i}" for i in range(10)] pilot_references = dict(zip(stamps, refs)) - await pilot_db.add_pilots_bulk( + await pilot_db.add_pilots( stamps, MAIN_VO, grid_type="DIRAC", pilot_references=pilot_references ) # Accept duplicates because it is checked by the logic - await pilot_db.add_pilots_bulk( + await pilot_db.add_pilots( stamps, MAIN_VO, grid_type="DIRAC", pilot_references=None ) @@ -202,27 +60,28 @@ async def test_insert_and_delete(pilot_db: PilotAgentsDB): stamps = [f"stamp_{i}" for i in range(2)] pilot_references = dict(zip(stamps, refs)) - await pilot_db.add_pilots_bulk( + await pilot_db.add_pilots( stamps, MAIN_VO, grid_type="DIRAC", pilot_references=pilot_references ) # Works, the pilots exists - await get_pilots_by_stamp_bulk(pilot_db, [stamps[0]]) - await get_pilots_by_stamp_bulk(pilot_db, [stamps[0]]) + await get_pilots_by_stamp(pilot_db, [stamps[0]]) + await get_pilots_by_stamp(pilot_db, [stamps[0]]) # We delete the first pilot - await pilot_db.delete_pilots_by_stamps_bulk([stamps[0]]) + await pilot_db.delete_pilots_by_stamps([stamps[0]]) # We get the 2nd pilot that is not delete (no error) - await get_pilots_by_stamp_bulk(pilot_db, [stamps[1]]) + await get_pilots_by_stamp(pilot_db, [stamps[1]]) # We get the 1st pilot that is delete (error) - with pytest.raises(PilotNotFoundError): - await get_pilots_by_stamp_bulk(pilot_db, [stamps[0]]) + + assert not await get_pilots_by_stamp(pilot_db, [stamps[0]]) @pytest.mark.asyncio async def test_insert_and_delete_only_old_aborted( - pilot_db: PilotAgentsDB, 
create_old_pilots_environment + pilot_db: PilotAgentsDB, + create_old_pilots_environment, # noqa: F811 ): non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old = ( create_old_pilots_environment @@ -231,9 +90,7 @@ async def test_insert_and_delete_only_old_aborted( async with pilot_db as pilot_db: # Delete all aborted that were born before 2020 # Every aborted that are old may be delete - await pilot_db.clear_pilots_bulk( - datetime(2020, 1, 1, tzinfo=timezone.utc), True - ) + await pilot_db.clear_pilots(datetime(2020, 1, 1, tzinfo=timezone.utc), True) # Assert who still live for normally_exiting_pilot_list in [ @@ -243,19 +100,19 @@ async def test_insert_and_delete_only_old_aborted( ]: stamps = [pilot["PilotStamp"] for pilot in normally_exiting_pilot_list] - await get_pilots_by_stamp_bulk(pilot_db, stamps) + await get_pilots_by_stamp(pilot_db, stamps) # Assert who normally does not live for normally_deleted_pilot_list in [aborted_very_old]: stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] - with pytest.raises(PilotNotFoundError): - await get_pilots_by_stamp_bulk(pilot_db, stamps) + assert not await get_pilots_by_stamp(pilot_db, stamps) @pytest.mark.asyncio async def test_insert_and_delete_old( - pilot_db: PilotAgentsDB, create_old_pilots_environment + pilot_db: PilotAgentsDB, + create_old_pilots_environment, # noqa: F811 ): non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old = ( create_old_pilots_environment @@ -264,9 +121,7 @@ async def test_insert_and_delete_old( async with pilot_db as pilot_db: # Delete all aborted that were born before 2020 # Every aborted that are old may be delete - await pilot_db.clear_pilots_bulk( - datetime(2020, 1, 1, tzinfo=timezone.utc), False - ) + await pilot_db.clear_pilots(datetime(2020, 1, 1, tzinfo=timezone.utc), False) # Assert who still live for normally_exiting_pilot_list in [ @@ -275,7 +130,7 @@ async def test_insert_and_delete_old( ]: stamps = [pilot["PilotStamp"] for pilot in normally_exiting_pilot_list] - await get_pilots_by_stamp_bulk(pilot_db, stamps) + await get_pilots_by_stamp(pilot_db, stamps) # Assert who normally does not live for normally_deleted_pilot_list in [ @@ -284,13 +139,13 @@ async def test_insert_and_delete_old( ]: stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] - with pytest.raises(PilotNotFoundError): - await get_pilots_by_stamp_bulk(pilot_db, stamps) + assert not await get_pilots_by_stamp(pilot_db, stamps) @pytest.mark.asyncio async def test_insert_and_delete_recent_only_aborted( - pilot_db: PilotAgentsDB, create_old_pilots_environment + pilot_db: PilotAgentsDB, + create_old_pilots_environment, # noqa: F811 ): non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old = ( create_old_pilots_environment @@ -299,15 +154,13 @@ async def test_insert_and_delete_recent_only_aborted( async with pilot_db as pilot_db: # Delete all aborted that were born before 2020 # Every aborted that are old may be delete - await pilot_db.clear_pilots_bulk( - datetime(2025, 3, 10, tzinfo=timezone.utc), True - ) + await pilot_db.clear_pilots(datetime(2025, 3, 10, tzinfo=timezone.utc), True) # Assert who still live for normally_exiting_pilot_list in [non_aborted_recent, non_aborted_very_old]: stamps = [pilot["PilotStamp"] for pilot in normally_exiting_pilot_list] - await get_pilots_by_stamp_bulk(pilot_db, stamps) + await get_pilots_by_stamp(pilot_db, stamps) # Assert who normally does not live for normally_deleted_pilot_list in [ @@ -316,13 +169,13 @@ 
async def test_insert_and_delete_recent_only_aborted( ]: stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] - with pytest.raises(PilotNotFoundError): - await get_pilots_by_stamp_bulk(pilot_db, stamps) + assert not await get_pilots_by_stamp(pilot_db, stamps) @pytest.mark.asyncio async def test_insert_and_delete_recent( - pilot_db: PilotAgentsDB, create_old_pilots_environment + pilot_db: PilotAgentsDB, + create_old_pilots_environment, # noqa: F811 ): non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old = ( create_old_pilots_environment @@ -331,9 +184,7 @@ async def test_insert_and_delete_recent( async with pilot_db as pilot_db: # Delete all aborted that were born before 2020 # Every aborted that are old may be delete - await pilot_db.clear_pilots_bulk( - datetime(2025, 3, 10, tzinfo=timezone.utc), False - ) + await pilot_db.clear_pilots(datetime(2025, 3, 10, tzinfo=timezone.utc), False) # Assert who normally does not live for normally_deleted_pilot_list in [ @@ -344,21 +195,20 @@ async def test_insert_and_delete_recent( ]: stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] - with pytest.raises(PilotNotFoundError): - await get_pilots_by_stamp_bulk(pilot_db, stamps) + assert not await get_pilots_by_stamp(pilot_db, stamps) @pytest.mark.asyncio async def test_insert_and_select_single_then_modify(pilot_db: PilotAgentsDB): async with pilot_db as pilot_db: pilot_stamp = "stamp-test" - await pilot_db.add_pilots_bulk( + await pilot_db.add_pilots( vo=MAIN_VO, pilot_stamps=[pilot_stamp], grid_type="grid-type", ) - res = await get_pilots_by_stamp_bulk(pilot_db, [pilot_stamp]) + res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) assert len(res) == 1 pilot = res[0] @@ -374,7 +224,7 @@ async def test_insert_and_select_single_then_modify(pilot_db: PilotAgentsDB): # # Modify a pilot, then check if every change is done # - await pilot_db.update_pilot_fields_bulk( + await pilot_db.update_pilot_fields( [ PilotFieldsMapping( PilotStamp=pilot_stamp, @@ -386,7 +236,7 @@ async def test_insert_and_select_single_then_modify(pilot_db: PilotAgentsDB): ] ) - res = await get_pilots_by_stamp_bulk(pilot_db, [pilot_stamp]) + res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) assert len(res) == 1 pilot = res[0] @@ -414,13 +264,13 @@ async def test_associate_pilot_with_job_and_get_it(pilot_db: PilotAgentsDB): async with pilot_db as pilot_db: pilot_stamp = "stamp-test" # Add pilot - await pilot_db.add_pilots_bulk( + await pilot_db.add_pilots( vo=MAIN_VO, pilot_stamps=[pilot_stamp], grid_type="grid-type", ) - res = await get_pilots_by_stamp_bulk(pilot_db, [pilot_stamp]) + res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) assert len(res) == 1 pilot = res[0] pilot_id = pilot["PilotID"] @@ -437,7 +287,7 @@ async def test_associate_pilot_with_job_and_get_it(pilot_db: PilotAgentsDB): {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} for job_id in pilot_jobs ] - await pilot_db.add_jobs_to_pilot_bulk(job_to_pilot_mapping) + await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) # Verify that he has all jobs db_jobs = await get_pilot_jobs_ids_by_pilot_id(pilot_db, pilot_id) @@ -453,7 +303,7 @@ async def test_associate_pilot_with_job_and_get_it(pilot_db: PilotAgentsDB): {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} for job_id in pilot_jobs ] - await pilot_db.add_jobs_to_pilot_bulk(job_to_pilot_mapping) + await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) # Associate pilot with jobs that he has not, but was previously in an error # To 
test that the rollback worked @@ -463,4 +313,4 @@ async def test_associate_pilot_with_job_and_get_it(pilot_db: PilotAgentsDB): {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} for job_id in pilot_jobs ] - await pilot_db.add_jobs_to_pilot_bulk(job_to_pilot_mapping) + await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) diff --git a/diracx-db/tests/pilots/test_query.py b/diracx-db/tests/pilots/test_query.py index 4b36e6ec5..be80f0179 100644 --- a/diracx-db/tests/pilots/test_query.py +++ b/diracx-db/tests/pilots/test_query.py @@ -48,11 +48,11 @@ async def populated_pilot_db(pilot_db): vo = MAIN_VO - await pilot_db.add_pilots_bulk( + await pilot_db.add_pilots( stamps, vo, grid_type="DIRAC", pilot_references=pilot_references ) - await pilot_db.update_pilot_fields_bulk( + await pilot_db.update_pilot_fields( [ PilotFieldsMapping( PilotStamp=pilot_stamp, diff --git a/diracx-db/tests/pilots/utils.py b/diracx-db/tests/pilots/utils.py new file mode 100644 index 000000000..793310d0d --- /dev/null +++ b/diracx-db/tests/pilots/utils.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +import pytest +from sqlalchemy import update + +from diracx.core.models import ( + ScalarSearchOperator, + ScalarSearchSpec, + VectorSearchOperator, + VectorSearchSpec, +) +from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.db.sql.pilots.schema import PilotAgents + +MAIN_VO = "lhcb" +N = 100 + +# ------------ Fetching data ------------ + + +async def get_pilots_by_stamp( + pilot_db: PilotAgentsDB, pilot_stamps: list[str], parameters: list[str] = [] +) -> list[dict[Any, Any]]: + _, pilots = await pilot_db.search_pilots( + parameters=parameters, + search=[ + VectorSearchSpec( + parameter="PilotStamp", + operator=VectorSearchOperator.IN, + values=pilot_stamps, + ) + ], + sorts=[], + distinct=True, + per_page=1000, + ) + + return pilots + + +async def get_pilot_jobs_ids_by_pilot_id( + pilot_db: PilotAgentsDB, pilot_id: int +) -> list[int]: + _, jobs = await pilot_db.search_pilot_to_job_mapping( + parameters=["JobID"], + search=[ + ScalarSearchSpec( + parameter="PilotID", + operator=ScalarSearchOperator.EQUAL, + value=pilot_id, + ) + ], + sorts=[], + distinct=True, + per_page=10000, + ) + + return [job["JobID"] for job in jobs] + + +# ------------ Creating data ------------ + + +@pytest.fixture +async def add_stamps(pilot_db): + async def _add_stamps(start_n=0): + async with pilot_db as db: + # Add pilots + refs = [f"ref_{i}" for i in range(start_n, start_n + N)] + stamps = [f"stamp_{i}" for i in range(start_n, start_n + N)] + pilot_references = dict(zip(stamps, refs)) + + vo = MAIN_VO + + await db.add_pilots( + stamps, vo, grid_type="DIRAC", pilot_references=pilot_references + ) + + return await get_pilots_by_stamp(db, stamps) + + return _add_stamps + + +@pytest.fixture +async def create_timed_pilots(pilot_db, add_stamps): + async def _create_timed_pilots( + old_date: datetime, aborted: bool = False, start_n=0 + ): + # Get pilots + pilots = await add_stamps(start_n) + + async with pilot_db as db: + # Update manually their age + # Collect PilotStamps + pilot_stamps = [pilot["PilotStamp"] for pilot in pilots] + + stmt = ( + update(PilotAgents) + .where(PilotAgents.pilot_stamp.in_(pilot_stamps)) + .values(SubmissionTime=old_date) + ) + + if aborted: + stmt = stmt.values(Status="Aborted") + + res = await db.conn.execute(stmt) + assert res.rowcount == len(pilot_stamps) + + pilots = await get_pilots_by_stamp(db, pilot_stamps) + return 
pilots + + return _create_timed_pilots + + +@pytest.fixture +async def create_old_pilots_environment(pilot_db, create_timed_pilots): + non_aborted_recent = await create_timed_pilots( + datetime(2025, 1, 1, tzinfo=timezone.utc), False, N + ) + aborted_recent = await create_timed_pilots( + datetime(2025, 1, 1, tzinfo=timezone.utc), True, 2 * N + ) + + aborted_very_old = await create_timed_pilots( + datetime(2003, 3, 10, tzinfo=timezone.utc), True, 3 * N + ) + non_aborted_very_old = await create_timed_pilots( + datetime(2003, 3, 10, tzinfo=timezone.utc), False, 4 * N + ) + + pilot_number = 4 * N + + assert pilot_number == ( + len(non_aborted_recent) + + len(aborted_recent) + + len(aborted_very_old) + + len(non_aborted_very_old) + ) + + # Phase 0. Verify that we have the right environment + async with pilot_db as pilot_db: + # Ensure that we can get every pilot (only get first of each group) + await get_pilots_by_stamp(pilot_db, [non_aborted_recent[0]["PilotStamp"]]) + await get_pilots_by_stamp(pilot_db, [aborted_recent[0]["PilotStamp"]]) + await get_pilots_by_stamp(pilot_db, [aborted_very_old[0]["PilotStamp"]]) + await get_pilots_by_stamp(pilot_db, [non_aborted_very_old[0]["PilotStamp"]]) + + return non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py index 1e098d122..82c9daad5 100644 --- a/diracx-logic/src/diracx/logic/pilots/management.py +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -2,14 +2,14 @@ from datetime import datetime, timedelta, timezone -from diracx.core.exceptions import PilotAlreadyExistsError, PilotNotFoundError +from diracx.core.exceptions import PilotAlreadyExistsError from diracx.core.models import PilotFieldsMapping from diracx.db.sql import PilotAgentsDB from .query import ( get_pilot_ids_by_stamps, get_pilot_jobs_ids_by_pilot_id, - get_pilots_by_stamp_bulk, + get_pilots_by_stamp, ) @@ -17,57 +17,54 @@ async def register_new_pilots( pilot_db: PilotAgentsDB, pilot_stamps: list[str], vo: str, - grid_type: str = "Dirac", - pilot_job_references: dict[str, str] | None = None, + grid_type: str, + grid_site: str, + destination_site: str, + status_reason: str, + pilot_job_references: dict[str, str] | None, ): # [IMPORTANT] Check unicity of pilot references - # If a pilot already exists, it will undo everything and raise an error - try: - await get_pilots_by_stamp_bulk(pilot_db=pilot_db, pilot_stamps=pilot_stamps) - raise PilotAlreadyExistsError(data={"pilot_stamps": str(pilot_stamps)}) - except PilotNotFoundError as e: - # e.non_existing_pilots is set of the pilot that are not found - # We can compare it with the pilot references that want to add - # If both sets are the same, it means that every pilots is new, and so we can add them to the db - # If not, it means that at least one is already in the db - - non_existing_pilots = e.non_existing_pilots - pilots_that_already_exist = set(pilot_stamps) - non_existing_pilots - - if pilots_that_already_exist: - raise PilotAlreadyExistsError( - data={"pilot_stamps": str(pilots_that_already_exist)} - ) from e - - await pilot_db.add_pilots_bulk( + # If a pilot already exists, we raise an error (transaction will rollback) + existing_pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, pilot_stamps=pilot_stamps + ) + + # If we found pilots from the list, this means some pilots already exists + if len(existing_pilots) > 0: + found_keys = {pilot["PilotStamp"] for pilot in existing_pilots} 
+ + raise PilotAlreadyExistsError(data={"pilot_stamps": str(found_keys)}) + + await pilot_db.add_pilots( pilot_stamps=pilot_stamps, vo=vo, grid_type=grid_type, + grid_site=grid_site, + destination_site=destination_site, pilot_references=pilot_job_references, + status_reason=status_reason, ) -async def clear_pilots_bulk( +async def clear_pilots( pilot_db: PilotAgentsDB, age_in_days: int, delete_only_aborted: bool ): """Delete pilots that have been submitted before interval_in_days.""" cutoff_date = datetime.now(tz=timezone.utc) - timedelta(days=age_in_days) - await pilot_db.clear_pilots_bulk( + await pilot_db.clear_pilots( cutoff_date=cutoff_date, delete_only_aborted=delete_only_aborted ) -async def delete_pilots_by_stamps_bulk( - pilot_db: PilotAgentsDB, pilot_stamps: list[str] -): - await pilot_db.delete_pilots_by_stamps_bulk(pilot_stamps) +async def delete_pilots_by_stamps(pilot_db: PilotAgentsDB, pilot_stamps: list[str]): + await pilot_db.delete_pilots_by_stamps(pilot_stamps) async def update_pilots_fields( pilot_db: PilotAgentsDB, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] ): - await pilot_db.update_pilot_fields_bulk(pilot_stamps_to_fields_mapping) + await pilot_db.update_pilot_fields(pilot_stamps_to_fields_mapping) async def add_jobs_to_pilot( @@ -86,7 +83,7 @@ async def add_jobs_to_pilot( for job_id in pilot_jobs_ids ] - await pilot_db.add_jobs_to_pilot_bulk( + await pilot_db.add_jobs_to_pilot( job_to_pilot_mapping=job_to_pilot_mapping, ) diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py index 7f54cc1d9..2db667f07 100644 --- a/diracx-logic/src/diracx/logic/pilots/query.py +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -48,9 +48,16 @@ async def search( return total, pilots -async def get_pilots_by_stamp_bulk( - pilot_db: PilotAgentsDB, pilot_stamps: list[str], parameters: list[str] = [] +async def get_pilots_by_stamp( + pilot_db: PilotAgentsDB, + pilot_stamps: list[str], + parameters: list[str] = [], + allow_missing: bool = True, ) -> list[dict[Any, Any]]: + """Get pilots by their stamp. + + If `allow_missing` is set to False, if a pilot is missing, PilotNotFoundError will be raised. 
+ """ if parameters: parameters.append("PilotStamp") @@ -68,16 +75,18 @@ async def get_pilots_by_stamp_bulk( per_page=MAX_PER_PAGE, ) - # Custom handling, to see which pilot_stamp does not exist (if so, say which one) - found_keys = {row["PilotStamp"] for row in pilots} - missing = set(pilot_stamps) - found_keys - - if missing: - raise PilotNotFoundError( - data={"pilot_stamp": str(missing)}, - detail=str(missing), - non_existing_pilots=missing, - ) + # allow_missing is set as True by default to mark explicitly when we allow or not + if not allow_missing: + # Custom handling, to see which pilot_stamp does not exist (if so, say which one) + found_keys = {row["PilotStamp"] for row in pilots} + missing = set(pilot_stamps) - found_keys + + if missing: + raise PilotNotFoundError( + data={"pilot_stamp": str(missing)}, + detail=str(missing), + non_existing_pilots=missing, + ) return pilots @@ -85,8 +94,11 @@ async def get_pilots_by_stamp_bulk( async def get_pilot_ids_by_stamps( pilot_db: PilotAgentsDB, pilot_stamps: list[str] ) -> list[int]: - pilots = await get_pilots_by_stamp_bulk( - pilot_db=pilot_db, pilot_stamps=pilot_stamps, parameters=["PilotID"] + pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + parameters=["PilotID"], + allow_missing=False, ) return [pilot["PilotID"] for pilot in pilots] diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py index 52e00e74c..2a170b342 100644 --- a/diracx-routers/src/diracx/routers/pilots/access_policies.py +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -6,27 +6,22 @@ from fastapi import Depends, HTTPException, status -from diracx.core.exceptions import PilotNotFoundError -from diracx.core.properties import NORMAL_USER, TRUSTED_HOST -from diracx.db.sql import PilotAgentsDB -from diracx.logic.pilots.query import get_pilots_by_stamp_bulk +from diracx.core.properties import SERVICE_ADMINISTRATOR, TRUSTED_HOST from diracx.routers.access_policies import BaseAccessPolicy from diracx.routers.utils.users import AuthorizedUserInfo class ActionType(StrEnum): - # Create a pilot - CREATE_PILOT = auto() # Change some pilot fields - CHANGE_PILOT_FIELD = auto() + MANAGE_PILOTS = auto() # Read some pilot info READ_PILOT_FIELDS = auto() class PilotManagementAccessPolicy(BaseAccessPolicy): """Rules: - * You need either NORMAL_USER in your properties - * A NORMAL_USER can create a pilot. + * Every user can access data about his VO + * An administrator, as well as a DIRAC service can modify a pilot. 
""" @staticmethod @@ -35,91 +30,27 @@ async def policy( user_info: AuthorizedUserInfo, /, *, - pilot_db: PilotAgentsDB | None = None, - pilot_stamps: list[str] | None = None, - vo: str | None = None, action: ActionType | None = None, ): assert action, "action is a mandatory parameter" + # Users can query if action == ActionType.READ_PILOT_FIELDS: - if NORMAL_USER in user_info.properties: - return - - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="You have to be logged on to see pilots.", - ) - - if not vo: - assert pilot_stamps and pilot_db, ( - "if vo is not provided, " - "pilot_stamp and pilot_db are mandatory to determine the vo" - ) - - try: - pilots = await get_pilots_by_stamp_bulk( - pilot_db=pilot_db, - pilot_stamps=pilot_stamps, - parameters=["VO"], # For efficiency - ) - except PilotNotFoundError as e: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The given stamp is not associated with a pilot", - ) from e - - # Semantic assured by get_pilots_by_stamp_bulk - first_vo = pilots[0]["VO"] - - if not all(pilot["VO"] == first_vo for pilot in pilots): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="You gave pilots with different VOs.", - ) - - vo = first_vo - - if not vo == user_info.vo: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="You don't have the right VO for this resource.", - ) - - if NORMAL_USER not in user_info.properties: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="You don't have the rights to create pilots.", - ) - - if action == ActionType.CREATE_PILOT: - return - - if action == ActionType.CHANGE_PILOT_FIELD: return - raise ValueError("Unknown action.") - - -CheckPilotManagementPolicyCallable = Annotated[ - Callable, Depends(PilotManagementAccessPolicy.check) -] - - -class DiracServicesAccessPolicy(BaseAccessPolicy): - """This access policy is used by DIRAC services (ex: Matcher).""" - - @staticmethod - async def policy(policy_name: str, user_info: AuthorizedUserInfo): - if TRUSTED_HOST in user_info.properties: + # If we want to modify pilots, we allow only admins and DIRAC + if ( + TRUSTED_HOST in user_info.properties + or SERVICE_ADMINISTRATOR in user_info.properties + ): return raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="This endpoint is reserved only for DIRAC services.", + detail="You don't have the rights to modify a pilot.", ) -CheckDiracServicesPolicyCallable = Annotated[ - Callable, Depends(DiracServicesAccessPolicy.check) +CheckPilotManagementPolicyCallable = Annotated[ + Callable, Depends(PilotManagementAccessPolicy.check) ] diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 19d78d4d6..a642c9969 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -18,8 +18,11 @@ add_jobs_to_pilot as add_jobs_to_pilot_bl, ) from diracx.logic.pilots.management import ( - clear_pilots_bulk, - delete_pilots_by_stamps_bulk, + clear_pilots as clear_pilots_bl, +) +from diracx.logic.pilots.management import ( + delete_pilots_by_stamps, + get_pilot_jobs_ids_by_stamp, register_new_pilots, update_pilots_fields, ) @@ -29,7 +32,6 @@ from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token from .access_policies import ( ActionType, - CheckDiracServicesPolicyCallable, CheckPilotManagementPolicyCallable, ) @@ -38,7 +40,7 @@ logger = logging.getLogger(__name__) 
-@router.post("/management/pilot") +@router.post("/") async def add_pilot_stamps( pilot_db: PilotAgentsDB, pilot_stamps: Annotated[ @@ -52,16 +54,24 @@ async def add_pilot_stamps( user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], check_permissions: CheckPilotManagementPolicyCallable, grid_type: Annotated[str, Body(description="Grid type of the pilots.")] = "Dirac", + grid_site: Annotated[str, Body(description="Pilots grid site.")] = "Unknown", + destination_site: Annotated[ + str, Body(description="Pilots destination site.") + ] = "NotAssigned", pilot_references: Annotated[ dict[str, str] | None, Body(description="Association of a pilot reference with a pilot stamp."), ] = None, + status_reason: Annotated[ + str, Body(description="Status reason of the pilots.") + ] = "Unknown", ): """Endpoint where a you can create pilots with their references. If a pilot stamp already exists, it will block the insertion. """ - await check_permissions(action=ActionType.CREATE_PILOT, vo=vo) + # TODO: Verify that grid types, sites, destination sites, etc. are valids + await check_permissions(action=ActionType.MANAGE_PILOTS) try: await register_new_pilots( @@ -69,7 +79,10 @@ async def add_pilot_stamps( pilot_stamps=pilot_stamps, vo=vo, grid_type=grid_type, + grid_site=grid_site, + destination_site=destination_site, pilot_job_references=pilot_references, + status_reason=status_reason, ) # Logs credentials creation @@ -80,7 +93,7 @@ async def add_pilot_stamps( raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e -@router.delete("/management/pilot", status_code=HTTPStatus.NO_CONTENT) +@router.delete("/", status_code=HTTPStatus.NO_CONTENT) async def delete_pilots( pilot_stamps: Annotated[ list[str], Query(description="Stamps of the pilots we want to delete.") @@ -93,13 +106,11 @@ async def delete_pilots( If at least one pilot is not found, it WILL rollback. """ await check_permissions( - action=ActionType.CHANGE_PILOT_FIELD, - pilot_stamps=pilot_stamps, - pilot_db=pilot_db, + action=ActionType.MANAGE_PILOTS, ) try: - await delete_pilots_by_stamps_bulk(pilot_db=pilot_db, pilot_stamps=pilot_stamps) + await delete_pilots_by_stamps(pilot_db=pilot_db, pilot_stamps=pilot_stamps) except PilotNotFoundError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -119,7 +130,7 @@ async def clear_pilots( ) ), ], - check_permissions: CheckDiracServicesPolicyCallable, + check_permissions: CheckPilotManagementPolicyCallable, delete_only_aborted: Annotated[ bool, Query( @@ -132,7 +143,7 @@ async def clear_pilots( ] = False, ): """Endpoint for DIRAC to delete all pilots that lived more than age_in_days.""" - await check_permissions() + await check_permissions(ActionType.MANAGE_PILOTS) if age_in_days < 0: raise HTTPException( @@ -140,7 +151,7 @@ async def clear_pilots( detail="age_in_days must be positive.", ) - await clear_pilots_bulk( + await clear_pilots_bl( pilot_db=pilot_db, age_in_days=age_in_days, delete_only_aborted=delete_only_aborted, @@ -170,7 +181,7 @@ async def clear_pilots( } -@router.patch("/management/pilot", status_code=HTTPStatus.NO_CONTENT) +@router.patch("/metadata", status_code=HTTPStatus.NO_CONTENT) async def update_pilot_fields( pilot_stamps_to_fields_mapping: Annotated[ list[PilotFieldsMapping], @@ -187,15 +198,8 @@ async def update_pilot_fields( Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. 
""" - # TODO: Add an example for openapi - pilot_stamps = [mapping.PilotStamp for mapping in pilot_stamps_to_fields_mapping] - # Ensures stamps validity - await check_permissions( - action=ActionType.CHANGE_PILOT_FIELD, - pilot_db=pilot_db, - pilot_stamps=pilot_stamps, - ) + await check_permissions(action=ActionType.MANAGE_PILOTS) await update_pilots_fields( pilot_db=pilot_db, @@ -203,17 +207,44 @@ async def update_pilot_fields( ) -@router.patch("/management/jobs", status_code=HTTPStatus.NO_CONTENT) +@router.get("/jobs") +async def get_pilot_jobs( + pilot_db: PilotAgentsDB, + pilot_stamp: Annotated[str, Body(description="The stamp of the pilot.")], + check_permissions: CheckPilotManagementPolicyCallable, +) -> list[int]: + """Endpoint only for DIRAC services, to get jobs of a pilot.""" + # FIXME: To be tested + await check_permissions(action=ActionType.READ_PILOT_FIELDS) + + try: + return await get_pilot_jobs_ids_by_stamp( + pilot_db=pilot_db, + pilot_stamp=pilot_stamp, + ) + except PilotNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="This pilot does not exist." + ) from e + except PilotAlreadyAssociatedWithJobError as e: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="This pilot is already associated with this job.", + ) from e + + +@router.patch("/jobs", status_code=HTTPStatus.NO_CONTENT) async def add_jobs_to_pilot( pilot_db: PilotAgentsDB, pilot_stamp: Annotated[str, Body(description="The stamp of the pilot.")], pilot_jobs_ids: Annotated[ list[int], Body(description="The jobs we want to add to the pilot.") ], - check_permissions: CheckDiracServicesPolicyCallable, + check_permissions: CheckPilotManagementPolicyCallable, ): """Endpoint only for DIRAC services, to associate a pilot with a job.""" - await check_permissions() + # FIXME: To be tested + await check_permissions(ActionType.MANAGE_PILOTS) try: await add_jobs_to_pilot_bl( diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py index 7a42be399..0de071e44 100644 --- a/diracx-routers/tests/pilots/test_pilot_creation.py +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -21,12 +21,6 @@ N = 100 -@pytest.fixture -def test_client(client_factory): - with client_factory.unauthenticated() as client: - yield client - - @pytest.fixture def normal_test_client(client_factory): with client_factory.normal_user() as client: @@ -41,7 +35,7 @@ async def test_create_pilots(normal_test_client): body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} r = normal_test_client.post( - "/api/pilots/management/pilot", + "/api/pilots/", json=body, ) @@ -55,7 +49,7 @@ async def test_create_pilots(normal_test_client): } r = normal_test_client.post( - "/api/pilots/management/pilot", + "/api/pilots/", json=body, headers={ "Content-Type": "application/json", @@ -74,7 +68,7 @@ async def test_create_pilots(normal_test_client): body = {"vo": MAIN_VO, "pilot_stamps": [pilot_stamps[0] + "_new_one"]} r = normal_test_client.post( - "/api/pilots/management/pilot", + "/api/pilots/", json=body, headers={ "Content-Type": "application/json", @@ -92,7 +86,7 @@ async def test_create_pilot_and_delete_it(normal_test_client): # Create a pilot r = normal_test_client.post( - "/api/pilots/management/pilot", + "/api/pilots/", json=body, ) @@ -101,7 +95,7 @@ async def test_create_pilot_and_delete_it(normal_test_client): # -------------- Duplicate -------------- # Duplicate because it exists, should have 409 r = normal_test_client.post( - 
"/api/pilots/management/pilot", + "/api/pilots/", json=body, ) @@ -112,7 +106,7 @@ async def test_create_pilot_and_delete_it(normal_test_client): # We delete the pilot r = normal_test_client.delete( - "/api/pilots/management/pilot", + "/api/pilots/", params=params, ) @@ -121,7 +115,7 @@ async def test_create_pilot_and_delete_it(normal_test_client): # -------------- Insert -------------- # Create a the same pilot, but works because it does not exist anymore r = normal_test_client.post( - "/api/pilots/management/pilot", + "/api/pilots/", json=body, ) @@ -136,7 +130,7 @@ async def test_create_pilot_and_modify_it(normal_test_client): # Create pilots r = normal_test_client.post( - "/api/pilots/management/pilot", + "/api/pilots/", json=body, ) @@ -156,7 +150,7 @@ async def test_create_pilot_and_modify_it(normal_test_client): ] } - r = normal_test_client.patch("/api/pilots/management/pilot", json=body) + r = normal_test_client.patch("/api/pilots/metadata", json=body) assert r.status_code == 204 @@ -181,3 +175,20 @@ async def test_create_pilot_and_modify_it(normal_test_client): assert pilot2["StatusReason"] != pilot1["StatusReason"] assert pilot2["AccountingSent"] != pilot1["AccountingSent"] assert pilot2["Status"] != pilot1["Status"] + + +async def test_associate_job_with_pilot_and_get_it(normal_test_client): + pilot_stamps = ["stamps_1", "stamp_2"] + + # -------------- Insert -------------- + body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} + + # Create pilots + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + # --------------- As DIRAC, associate a job with a pilot -------- diff --git a/diracx-routers/tests/pilots/test_query.py b/diracx-routers/tests/pilots/test_query.py index 981126798..a254ed481 100644 --- a/diracx-routers/tests/pilots/test_query.py +++ b/diracx-routers/tests/pilots/test_query.py @@ -54,7 +54,7 @@ async def populated_pilot_client(normal_test_client): body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} r = normal_test_client.post( - "/api/pilots/management/pilot", + "/api/pilots/", json=body, ) @@ -75,7 +75,7 @@ async def populated_pilot_client(normal_test_client): ] } - r = normal_test_client.patch("/api/pilots/management/pilot", json=body) + r = normal_test_client.patch("/api/pilots/metadata", json=body) assert r.status_code == 204 From 9b75a84fff259d67e98b9f5be3dba596f2592ca3 Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Wed, 25 Jun 2025 11:47:37 +0200 Subject: [PATCH 10/33] refactor: Refactored delete pilot endpoints --- .../_generated/aio/operations/_operations.py | 71 +++++++---- .../client/_generated/models/_models.py | 21 ++++ .../_generated/operations/_operations.py | 113 +++++++++++------- .../src/diracx/routers/pilots/management.py | 72 ++++++----- .../_generated/aio/operations/_operations.py | 71 +++++++---- .../client/_generated/models/_models.py | 21 ++++ .../_generated/operations/_operations.py | 113 +++++++++++------- 7 files changed, 319 insertions(+), 163 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index bfa1bf1de..86a54ccab 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -54,8 +54,8 @@ build_jobs_unassign_job_sandboxes_request, build_pilots_add_jobs_to_pilot_request, build_pilots_add_pilot_stamps_request, - 
build_pilots_clear_pilots_request, build_pilots_delete_pilots_request, + build_pilots_get_pilot_jobs_request, build_pilots_search_request, build_pilots_update_pilot_fields_request, build_well_known_get_installation_metadata_request, @@ -2286,15 +2286,36 @@ async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, I return deserialized # type: ignore @distributed_trace_async - async def delete_pilots(self, *, pilot_stamps: List[str], **kwargs: Any) -> None: + async def delete_pilots( + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any + ) -> None: """Delete Pilots. Endpoint to delete a pilot. - If at least one pilot is not found, it WILL rollback. + Two features: + - :keyword pilot_stamps: Stamps of the pilots we want to delete. Required. + #. Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + If deleting by stamps, if at least one pilot is not found, it WILL rollback. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. :paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool :return: None :rtype: None :raises ~azure.core.exceptions.HttpResponseError: @@ -2314,6 +2335,8 @@ async def delete_pilots(self, *, pilot_stamps: List[str], **kwargs: Any) -> None _request = build_pilots_delete_pilots_request( pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, headers=_headers, params=_params, ) @@ -2435,20 +2458,15 @@ async def update_pilot_fields( return cls(pipeline_response, None, {}) # type: ignore @distributed_trace_async - async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = False, **kwargs: Any) -> None: - """Clear Pilots. + async def get_pilot_jobs(self, body: str, **kwargs: Any) -> List[int]: + """Get Pilot Jobs. - Endpoint for DIRAC to delete all pilots that lived more than age_in_days. + Endpoint only for DIRAC services, to get jobs of a pilot. - :keyword age_in_days: The number of days that define the maximum age of pilots to be - deleted.Pilots older than this age will be considered for deletion. Required. - :paramtype age_in_days: int - :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is - 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by - default as True to avoid any mistake. Default value is False. - :paramtype delete_only_aborted: bool - :return: None - :rtype: None + :param body: Required. 
+ :type body: str + :return: list of int + :rtype: list[int] :raises ~azure.core.exceptions.HttpResponseError: """ error_map: MutableMapping = { @@ -2459,14 +2477,17 @@ async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = Fa } error_map.update(kwargs.pop("error_map", {}) or {}) - _headers = kwargs.pop("headers", {}) or {} + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = kwargs.pop("params", {}) or {} - cls: ClsType[None] = kwargs.pop("cls", None) + content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) + cls: ClsType[List[int]] = kwargs.pop("cls", None) - _request = build_pilots_clear_pilots_request( - age_in_days=age_in_days, - delete_only_aborted=delete_only_aborted, + _content = self._serialize.body(body, "str") + + _request = build_pilots_get_pilot_jobs_request( + content_type=content_type, + content=_content, headers=_headers, params=_params, ) @@ -2479,12 +2500,16 @@ async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = Fa response = pipeline_response.http_response - if response.status_code not in [204]: + if response.status_code not in [200]: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) + deserialized = self._deserialize("[int]", pipeline_response.http_response) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @overload async def add_jobs_to_pilot( diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index 81a5760bf..3e73d22fc 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -138,8 +138,14 @@ class BodyPilotsAddPilotStamps(_serialization.Model): :vartype vo: str :ivar grid_type: Grid type of the pilots. :vartype grid_type: str + :ivar grid_site: Pilots grid site. + :vartype grid_site: str + :ivar destination_site: Pilots destination site. + :vartype destination_site: str :ivar pilot_references: Association of a pilot reference with a pilot stamp. :vartype pilot_references: dict[str, str] + :ivar status_reason: Status reason of the pilots. + :vartype status_reason: str """ _validation = { @@ -151,7 +157,10 @@ class BodyPilotsAddPilotStamps(_serialization.Model): "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"}, "vo": {"key": "vo", "type": "str"}, "grid_type": {"key": "grid_type", "type": "str"}, + "grid_site": {"key": "grid_site", "type": "str"}, + "destination_site": {"key": "destination_site", "type": "str"}, "pilot_references": {"key": "pilot_references", "type": "{str}"}, + "status_reason": {"key": "status_reason", "type": "str"}, } def __init__( @@ -160,7 +169,10 @@ def __init__( pilot_stamps: List[str], vo: str, grid_type: str = "Dirac", + grid_site: str = "Unknown", + destination_site: str = "NotAssigned", pilot_references: Optional[Dict[str, str]] = None, + status_reason: str = "Unknown", **kwargs: Any ) -> None: """ @@ -170,14 +182,23 @@ def __init__( :paramtype vo: str :keyword grid_type: Grid type of the pilots. :paramtype grid_type: str + :keyword grid_site: Pilots grid site. + :paramtype grid_site: str + :keyword destination_site: Pilots destination site. + :paramtype destination_site: str :keyword pilot_references: Association of a pilot reference with a pilot stamp. 
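As a rough usage sketch of the extended registration body: assuming an already-constructed, authenticated generated client bound to the name `client` (hypothetical here) and an illustrative VO name, every field other than `pilot_stamps` and `vo` simply falls back to the defaults declared above:

    from diracx.client._generated import models

    body = models.BodyPilotsAddPilotStamps(
        pilot_stamps=["stamp_1", "stamp_2"],
        vo="my_vo",                            # illustrative VO name
        grid_type="Dirac",                     # default
        grid_site="Unknown",                   # default
        destination_site="NotAssigned",        # default
        status_reason="Unknown",               # default
        pilot_references={"ref_1": "stamp_1"}, # illustrative reference-to-stamp mapping
    )
    client.pilots.add_pilot_stamps(body)       # POST /api/pilots/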
:paramtype pilot_references: dict[str, str] + :keyword status_reason: Status reason of the pilots. + :paramtype status_reason: str """ super().__init__(**kwargs) self.pilot_stamps = pilot_stamps self.vo = vo self.grid_type = grid_type + self.grid_site = grid_site + self.destination_site = destination_site self.pilot_references = pilot_references + self.status_reason = status_reason class BodyPilotsUpdatePilotFields(_serialization.Model): diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index 9554ba7fe..cc244e8aa 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -597,7 +597,7 @@ def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest: accept = _headers.pop("Accept", "application/json") # Construct URL - _url = "/api/pilots/management/pilot" + _url = "/api/pilots/" # Construct headers if content_type is not None: @@ -607,14 +607,25 @@ def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) -def build_pilots_delete_pilots_request(*, pilot_stamps: List[str], **kwargs: Any) -> HttpRequest: +def build_pilots_delete_pilots_request( + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any +) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) # Construct URL - _url = "/api/pilots/management/pilot" + _url = "/api/pilots/" # Construct parameters - _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]") + if pilot_stamps is not None: + _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]") + if age_in_days is not None: + _params["age_in_days"] = _SERIALIZER.query("age_in_days", age_in_days, "int") + if delete_only_aborted is not None: + _params["delete_only_aborted"] = _SERIALIZER.query("delete_only_aborted", delete_only_aborted, "bool") return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) @@ -624,7 +635,7 @@ def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) # Construct URL - _url = "/api/pilots/management/pilot" + _url = "/api/pilots/metadata" # Construct headers if content_type is not None: @@ -633,20 +644,21 @@ def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) -def build_pilots_clear_pilots_request( - *, age_in_days: int, delete_only_aborted: bool = False, **kwargs: Any -) -> HttpRequest: - _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) +def build_pilots_get_pilot_jobs_request(*, content: str, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") # Construct URL - _url = "/api/pilots/management/pilot/interval" + _url = "/api/pilots/jobs" - # Construct parameters - _params["age_in_days"] = _SERIALIZER.query("age_in_days", age_in_days, "int") - if delete_only_aborted is not None: - _params["delete_only_aborted"] = _SERIALIZER.query("delete_only_aborted", delete_only_aborted, "bool") 
+ # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") - return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) + return HttpRequest(method="GET", url=_url, headers=_headers, content=content, **kwargs) def build_pilots_add_jobs_to_pilot_request(**kwargs: Any) -> HttpRequest: @@ -654,7 +666,7 @@ def build_pilots_add_jobs_to_pilot_request(**kwargs: Any) -> HttpRequest: content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) # Construct URL - _url = "/api/pilots/management/jobs" + _url = "/api/pilots/jobs" # Construct headers if content_type is not None: @@ -2901,16 +2913,35 @@ def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[byte @distributed_trace def delete_pilots( # pylint: disable=inconsistent-return-statements - self, *, pilot_stamps: List[str], **kwargs: Any + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any ) -> None: """Delete Pilots. Endpoint to delete a pilot. - If at least one pilot is not found, it WILL rollback. + Two features: + - :keyword pilot_stamps: Stamps of the pilots we want to delete. Required. + #. Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + If deleting by stamps, if at least one pilot is not found, it WILL rollback. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. :paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool :return: None :rtype: None :raises ~azure.core.exceptions.HttpResponseError: @@ -2930,6 +2961,8 @@ def delete_pilots( # pylint: disable=inconsistent-return-statements _request = build_pilots_delete_pilots_request( pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, headers=_headers, params=_params, ) @@ -3049,22 +3082,15 @@ def update_pilot_fields( # pylint: disable=inconsistent-return-statements return cls(pipeline_response, None, {}) # type: ignore @distributed_trace - def clear_pilots( # pylint: disable=inconsistent-return-statements - self, *, age_in_days: int, delete_only_aborted: bool = False, **kwargs: Any - ) -> None: - """Clear Pilots. + def get_pilot_jobs(self, body: str, **kwargs: Any) -> List[int]: + """Get Pilot Jobs. - Endpoint for DIRAC to delete all pilots that lived more than age_in_days. + Endpoint only for DIRAC services, to get jobs of a pilot. - :keyword age_in_days: The number of days that define the maximum age of pilots to be - deleted.Pilots older than this age will be considered for deletion. Required. 
- :paramtype age_in_days: int - :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is - 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by - default as True to avoid any mistake. Default value is False. - :paramtype delete_only_aborted: bool - :return: None - :rtype: None + :param body: Required. + :type body: str + :return: list of int + :rtype: list[int] :raises ~azure.core.exceptions.HttpResponseError: """ error_map: MutableMapping = { @@ -3075,14 +3101,17 @@ def clear_pilots( # pylint: disable=inconsistent-return-statements } error_map.update(kwargs.pop("error_map", {}) or {}) - _headers = kwargs.pop("headers", {}) or {} + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = kwargs.pop("params", {}) or {} - cls: ClsType[None] = kwargs.pop("cls", None) + content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) + cls: ClsType[List[int]] = kwargs.pop("cls", None) - _request = build_pilots_clear_pilots_request( - age_in_days=age_in_days, - delete_only_aborted=delete_only_aborted, + _content = self._serialize.body(body, "str") + + _request = build_pilots_get_pilot_jobs_request( + content_type=content_type, + content=_content, headers=_headers, params=_params, ) @@ -3095,12 +3124,16 @@ def clear_pilots( # pylint: disable=inconsistent-return-statements response = pipeline_response.http_response - if response.status_code not in [204]: + if response.status_code not in [200]: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) + deserialized = self._deserialize("[int]", pipeline_response.http_response) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @overload def add_jobs_to_pilot( diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index a642c9969..be617ec88 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -95,42 +95,20 @@ async def add_pilot_stamps( @router.delete("/", status_code=HTTPStatus.NO_CONTENT) async def delete_pilots( - pilot_stamps: Annotated[ - list[str], Query(description="Stamps of the pilots we want to delete.") - ], pilot_db: PilotAgentsDB, check_permissions: CheckPilotManagementPolicyCallable, -): - """Endpoint to delete a pilot. - - If at least one pilot is not found, it WILL rollback. - """ - await check_permissions( - action=ActionType.MANAGE_PILOTS, - ) - - try: - await delete_pilots_by_stamps(pilot_db=pilot_db, pilot_stamps=pilot_stamps) - except PilotNotFoundError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="At least one pilot has not been found.", - ) from e - - -@router.delete("/management/pilot/interval", status_code=HTTPStatus.NO_CONTENT) -async def clear_pilots( - pilot_db: PilotAgentsDB, + pilot_stamps: Annotated[ + list[str] | None, Query(description="Stamps of the pilots we want to delete.") + ] = None, age_in_days: Annotated[ - int, + int | None, Query( description=( "The number of days that define the maximum age of pilots to be deleted." "Pilots older than this age will be considered for deletion." 
) ), - ], - check_permissions: CheckPilotManagementPolicyCallable, + ] = None, delete_only_aborted: Annotated[ bool, Query( @@ -138,25 +116,45 @@ async def clear_pilots( "Flag indicating whether to only delete pilots whose status is 'Aborted'." "If set to True, only pilots with the 'Aborted' status will be deleted." "It is set by default as True to avoid any mistake." + "This flag is only used for deletion by time." ) ), ] = False, ): - """Endpoint for DIRAC to delete all pilots that lived more than age_in_days.""" - await check_permissions(ActionType.MANAGE_PILOTS) + """Endpoint to delete a pilot. + + Two features: + + 1. Or you provide pilot_stamps, so you can delete pilots by their stamp + 2. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. - if age_in_days < 0: + If deleting by stamps, if at least one pilot is not found, it WILL rollback. + """ + await check_permissions( + action=ActionType.MANAGE_PILOTS, + ) + + if pilot_stamps: + try: + await delete_pilots_by_stamps(pilot_db=pilot_db, pilot_stamps=pilot_stamps) + except PilotNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="At least one pilot has not been found.", + ) from e + + elif age_in_days: + await clear_pilots_bl( + pilot_db=pilot_db, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + ) + else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="age_in_days must be positive.", + detail="You must provide either age_in_days or pilot_stamps.", ) - await clear_pilots_bl( - pilot_db=pilot_db, - age_in_days=age_in_days, - delete_only_aborted=delete_only_aborted, - ) - EXAMPLE_UPDATE_FIELDS = { "Update the BenchMark field": { diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index 4c32f530c..d18b54a4d 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -57,8 +57,8 @@ build_lollygag_insert_owner_object_request, build_pilots_add_jobs_to_pilot_request, build_pilots_add_pilot_stamps_request, - build_pilots_clear_pilots_request, build_pilots_delete_pilots_request, + build_pilots_get_pilot_jobs_request, build_pilots_search_request, build_pilots_update_pilot_fields_request, build_well_known_get_installation_metadata_request, @@ -2453,15 +2453,36 @@ async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, I return deserialized # type: ignore @distributed_trace_async - async def delete_pilots(self, *, pilot_stamps: List[str], **kwargs: Any) -> None: + async def delete_pilots( + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any + ) -> None: """Delete Pilots. Endpoint to delete a pilot. - If at least one pilot is not found, it WILL rollback. + Two features: + - :keyword pilot_stamps: Stamps of the pilots we want to delete. Required. + #. Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + If deleting by stamps, if at least one pilot is not found, it WILL rollback. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. 
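A minimal sketch of the two deletion modes now exposed by DELETE /api/pilots/, reusing the hypothetical `client` from the registration sketch; exactly one of the two parameter groups should be passed, otherwise the router answers 400:

    # Mode 1: delete pilots by stamp; rolls back if any stamp is unknown.
    client.pilots.delete_pilots(pilot_stamps=["stamp_1", "stamp_2"])

    # Mode 2: delete pilots submitted more than 30 days ago,
    # here restricted to pilots whose status is 'Aborted'.
    client.pilots.delete_pilots(age_in_days=30, delete_only_aborted=True)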
:paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool :return: None :rtype: None :raises ~azure.core.exceptions.HttpResponseError: @@ -2481,6 +2502,8 @@ async def delete_pilots(self, *, pilot_stamps: List[str], **kwargs: Any) -> None _request = build_pilots_delete_pilots_request( pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, headers=_headers, params=_params, ) @@ -2602,20 +2625,15 @@ async def update_pilot_fields( return cls(pipeline_response, None, {}) # type: ignore @distributed_trace_async - async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = False, **kwargs: Any) -> None: - """Clear Pilots. + async def get_pilot_jobs(self, body: str, **kwargs: Any) -> List[int]: + """Get Pilot Jobs. - Endpoint for DIRAC to delete all pilots that lived more than age_in_days. + Endpoint only for DIRAC services, to get jobs of a pilot. - :keyword age_in_days: The number of days that define the maximum age of pilots to be - deleted.Pilots older than this age will be considered for deletion. Required. - :paramtype age_in_days: int - :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is - 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by - default as True to avoid any mistake. Default value is False. - :paramtype delete_only_aborted: bool - :return: None - :rtype: None + :param body: Required. 
+ :type body: str + :return: list of int + :rtype: list[int] :raises ~azure.core.exceptions.HttpResponseError: """ error_map: MutableMapping = { @@ -2626,14 +2644,17 @@ async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = Fa } error_map.update(kwargs.pop("error_map", {}) or {}) - _headers = kwargs.pop("headers", {}) or {} + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = kwargs.pop("params", {}) or {} - cls: ClsType[None] = kwargs.pop("cls", None) + content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) + cls: ClsType[List[int]] = kwargs.pop("cls", None) - _request = build_pilots_clear_pilots_request( - age_in_days=age_in_days, - delete_only_aborted=delete_only_aborted, + _content = self._serialize.body(body, "str") + + _request = build_pilots_get_pilot_jobs_request( + content_type=content_type, + content=_content, headers=_headers, params=_params, ) @@ -2646,12 +2667,16 @@ async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = Fa response = pipeline_response.http_response - if response.status_code not in [204]: + if response.status_code not in [200]: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) + deserialized = self._deserialize("[int]", pipeline_response.http_response) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @overload async def add_jobs_to_pilot( diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 7724f2823..5746318c9 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -138,8 +138,14 @@ class BodyPilotsAddPilotStamps(_serialization.Model): :vartype vo: str :ivar grid_type: Grid type of the pilots. :vartype grid_type: str + :ivar grid_site: Pilots grid site. + :vartype grid_site: str + :ivar destination_site: Pilots destination site. + :vartype destination_site: str :ivar pilot_references: Association of a pilot reference with a pilot stamp. :vartype pilot_references: dict[str, str] + :ivar status_reason: Status reason of the pilots. + :vartype status_reason: str """ _validation = { @@ -151,7 +157,10 @@ class BodyPilotsAddPilotStamps(_serialization.Model): "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"}, "vo": {"key": "vo", "type": "str"}, "grid_type": {"key": "grid_type", "type": "str"}, + "grid_site": {"key": "grid_site", "type": "str"}, + "destination_site": {"key": "destination_site", "type": "str"}, "pilot_references": {"key": "pilot_references", "type": "{str}"}, + "status_reason": {"key": "status_reason", "type": "str"}, } def __init__( @@ -160,7 +169,10 @@ def __init__( pilot_stamps: List[str], vo: str, grid_type: str = "Dirac", + grid_site: str = "Unknown", + destination_site: str = "NotAssigned", pilot_references: Optional[Dict[str, str]] = None, + status_reason: str = "Unknown", **kwargs: Any ) -> None: """ @@ -170,14 +182,23 @@ def __init__( :paramtype vo: str :keyword grid_type: Grid type of the pilots. :paramtype grid_type: str + :keyword grid_site: Pilots grid site. + :paramtype grid_site: str + :keyword destination_site: Pilots destination site. 
+ :paramtype destination_site: str :keyword pilot_references: Association of a pilot reference with a pilot stamp. :paramtype pilot_references: dict[str, str] + :keyword status_reason: Status reason of the pilots. + :paramtype status_reason: str """ super().__init__(**kwargs) self.pilot_stamps = pilot_stamps self.vo = vo self.grid_type = grid_type + self.grid_site = grid_site + self.destination_site = destination_site self.pilot_references = pilot_references + self.status_reason = status_reason class BodyPilotsUpdatePilotFields(_serialization.Model): diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index b677a98cb..4a6ee360d 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -646,7 +646,7 @@ def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest: accept = _headers.pop("Accept", "application/json") # Construct URL - _url = "/api/pilots/management/pilot" + _url = "/api/pilots/" # Construct headers if content_type is not None: @@ -656,14 +656,25 @@ def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) -def build_pilots_delete_pilots_request(*, pilot_stamps: List[str], **kwargs: Any) -> HttpRequest: +def build_pilots_delete_pilots_request( + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any +) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) # Construct URL - _url = "/api/pilots/management/pilot" + _url = "/api/pilots/" # Construct parameters - _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]") + if pilot_stamps is not None: + _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]") + if age_in_days is not None: + _params["age_in_days"] = _SERIALIZER.query("age_in_days", age_in_days, "int") + if delete_only_aborted is not None: + _params["delete_only_aborted"] = _SERIALIZER.query("delete_only_aborted", delete_only_aborted, "bool") return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) @@ -673,7 +684,7 @@ def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) # Construct URL - _url = "/api/pilots/management/pilot" + _url = "/api/pilots/metadata" # Construct headers if content_type is not None: @@ -682,20 +693,21 @@ def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) -def build_pilots_clear_pilots_request( - *, age_in_days: int, delete_only_aborted: bool = False, **kwargs: Any -) -> HttpRequest: - _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) +def build_pilots_get_pilot_jobs_request(*, content: str, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") # Construct URL - _url = "/api/pilots/management/pilot/interval" + _url = "/api/pilots/jobs" - # Construct parameters - _params["age_in_days"] 
= _SERIALIZER.query("age_in_days", age_in_days, "int") - if delete_only_aborted is not None: - _params["delete_only_aborted"] = _SERIALIZER.query("delete_only_aborted", delete_only_aborted, "bool") + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") - return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) + return HttpRequest(method="GET", url=_url, headers=_headers, content=content, **kwargs) def build_pilots_add_jobs_to_pilot_request(**kwargs: Any) -> HttpRequest: @@ -703,7 +715,7 @@ def build_pilots_add_jobs_to_pilot_request(**kwargs: Any) -> HttpRequest: content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) # Construct URL - _url = "/api/pilots/management/jobs" + _url = "/api/pilots/jobs" # Construct headers if content_type is not None: @@ -3114,16 +3126,35 @@ def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[byte @distributed_trace def delete_pilots( # pylint: disable=inconsistent-return-statements - self, *, pilot_stamps: List[str], **kwargs: Any + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any ) -> None: """Delete Pilots. Endpoint to delete a pilot. - If at least one pilot is not found, it WILL rollback. + Two features: + - :keyword pilot_stamps: Stamps of the pilots we want to delete. Required. + #. Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + If deleting by stamps, if at least one pilot is not found, it WILL rollback. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. :paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool :return: None :rtype: None :raises ~azure.core.exceptions.HttpResponseError: @@ -3143,6 +3174,8 @@ def delete_pilots( # pylint: disable=inconsistent-return-statements _request = build_pilots_delete_pilots_request( pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, headers=_headers, params=_params, ) @@ -3262,22 +3295,15 @@ def update_pilot_fields( # pylint: disable=inconsistent-return-statements return cls(pipeline_response, None, {}) # type: ignore @distributed_trace - def clear_pilots( # pylint: disable=inconsistent-return-statements - self, *, age_in_days: int, delete_only_aborted: bool = False, **kwargs: Any - ) -> None: - """Clear Pilots. + def get_pilot_jobs(self, body: str, **kwargs: Any) -> List[int]: + """Get Pilot Jobs. - Endpoint for DIRAC to delete all pilots that lived more than age_in_days. + Endpoint only for DIRAC services, to get jobs of a pilot. 
- :keyword age_in_days: The number of days that define the maximum age of pilots to be - deleted.Pilots older than this age will be considered for deletion. Required. - :paramtype age_in_days: int - :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is - 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by - default as True to avoid any mistake. Default value is False. - :paramtype delete_only_aborted: bool - :return: None - :rtype: None + :param body: Required. + :type body: str + :return: list of int + :rtype: list[int] :raises ~azure.core.exceptions.HttpResponseError: """ error_map: MutableMapping = { @@ -3288,14 +3314,17 @@ def clear_pilots( # pylint: disable=inconsistent-return-statements } error_map.update(kwargs.pop("error_map", {}) or {}) - _headers = kwargs.pop("headers", {}) or {} + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = kwargs.pop("params", {}) or {} - cls: ClsType[None] = kwargs.pop("cls", None) + content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) + cls: ClsType[List[int]] = kwargs.pop("cls", None) - _request = build_pilots_clear_pilots_request( - age_in_days=age_in_days, - delete_only_aborted=delete_only_aborted, + _content = self._serialize.body(body, "str") + + _request = build_pilots_get_pilot_jobs_request( + content_type=content_type, + content=_content, headers=_headers, params=_params, ) @@ -3308,12 +3337,16 @@ def clear_pilots( # pylint: disable=inconsistent-return-statements response = pipeline_response.http_response - if response.status_code not in [204]: + if response.status_code not in [200]: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) + deserialized = self._deserialize("[int]", pipeline_response.http_response) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @overload def add_jobs_to_pilot( From 9df63006774ee5a9c5129cee4ad6c424541a2883 Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Thu, 26 Jun 2025 17:18:31 +0200 Subject: [PATCH 11/33] feat: We can search for job associated to a pilot and vice versa --- .../_generated/aio/operations/_operations.py | 19 ++-- .../client/_generated/models/_models.py | 16 ++-- .../_generated/operations/_operations.py | 35 ++++--- .../src/diracx/logic/pilots/management.py | 18 ++-- diracx-logic/src/diracx/logic/pilots/query.py | 18 ++++ .../src/diracx/routers/pilots/management.py | 32 +++---- .../tests/pilots/test_pilot_creation.py | 93 ++++++++++++++++++- .../_generated/aio/operations/_operations.py | 19 ++-- .../client/_generated/models/_models.py | 16 ++-- .../_generated/operations/_operations.py | 35 ++++--- 10 files changed, 213 insertions(+), 88 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index 86a54ccab..232bcec5f 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -2458,13 +2458,17 @@ async def update_pilot_fields( return cls(pipeline_response, None, {}) # type: ignore @distributed_trace_async - async def get_pilot_jobs(self, body: str, **kwargs: Any) -> List[int]: + async def get_pilot_jobs( + self, *, pilot_stamp: 
Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: """Get Pilot Jobs. Endpoint only for DIRAC services, to get jobs of a pilot. - :param body: Required. - :type body: str + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. + :paramtype job_id: int :return: list of int :rtype: list[int] :raises ~azure.core.exceptions.HttpResponseError: @@ -2477,17 +2481,14 @@ async def get_pilot_jobs(self, body: str, **kwargs: Any) -> List[int]: } error_map.update(kwargs.pop("error_map", {}) or {}) - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _headers = kwargs.pop("headers", {}) or {} _params = kwargs.pop("params", {}) or {} - content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) cls: ClsType[List[int]] = kwargs.pop("cls", None) - _content = self._serialize.body(body, "str") - _request = build_pilots_get_pilot_jobs_request( - content_type=content_type, - content=_content, + pilot_stamp=pilot_stamp, + job_id=job_id, headers=_headers, params=_params, ) diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index 3e73d22fc..73677b300 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -101,30 +101,30 @@ class BodyPilotsAddJobsToPilot(_serialization.Model): :ivar pilot_stamp: The stamp of the pilot. Required. :vartype pilot_stamp: str - :ivar pilot_jobs_ids: The jobs we want to add to the pilot. Required. - :vartype pilot_jobs_ids: list[int] + :ivar job_ids: The jobs we want to add to the pilot. Required. + :vartype job_ids: list[int] """ _validation = { "pilot_stamp": {"required": True}, - "pilot_jobs_ids": {"required": True}, + "job_ids": {"required": True}, } _attribute_map = { "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, - "pilot_jobs_ids": {"key": "pilot_jobs_ids", "type": "[int]"}, + "job_ids": {"key": "job_ids", "type": "[int]"}, } - def __init__(self, *, pilot_stamp: str, pilot_jobs_ids: List[int], **kwargs: Any) -> None: + def __init__(self, *, pilot_stamp: str, job_ids: List[int], **kwargs: Any) -> None: """ :keyword pilot_stamp: The stamp of the pilot. Required. :paramtype pilot_stamp: str - :keyword pilot_jobs_ids: The jobs we want to add to the pilot. Required. - :paramtype pilot_jobs_ids: list[int] + :keyword job_ids: The jobs we want to add to the pilot. Required. 
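With the field renamed from `pilot_jobs_ids` to `job_ids`, associating jobs with a pilot could look like the following sketch (same hypothetical `client` and `models` import as in the earlier sketches; the stamp and job IDs are illustrative):

    body = models.BodyPilotsAddJobsToPilot(pilot_stamp="stamp_1", job_ids=[1, 2, 3])
    client.pilots.add_jobs_to_pilot(body)  # PATCH /api/pilots/jobs, 204 on success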
+ :paramtype job_ids: list[int] """ super().__init__(**kwargs) self.pilot_stamp = pilot_stamp - self.pilot_jobs_ids = pilot_jobs_ids + self.job_ids = job_ids class BodyPilotsAddPilotStamps(_serialization.Model): diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index cc244e8aa..fda6ff2ce 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -644,21 +644,27 @@ def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) -def build_pilots_get_pilot_jobs_request(*, content: str, **kwargs: Any) -> HttpRequest: +def build_pilots_get_pilot_jobs_request( + *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any +) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") # Construct URL _url = "/api/pilots/jobs" + # Construct parameters + if pilot_stamp is not None: + _params["pilot_stamp"] = _SERIALIZER.query("pilot_stamp", pilot_stamp, "str") + if job_id is not None: + _params["job_id"] = _SERIALIZER.query("job_id", job_id, "int") + # Construct headers - if content_type is not None: - _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") - return HttpRequest(method="GET", url=_url, headers=_headers, content=content, **kwargs) + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) def build_pilots_add_jobs_to_pilot_request(**kwargs: Any) -> HttpRequest: @@ -3082,13 +3088,17 @@ def update_pilot_fields( # pylint: disable=inconsistent-return-statements return cls(pipeline_response, None, {}) # type: ignore @distributed_trace - def get_pilot_jobs(self, body: str, **kwargs: Any) -> List[int]: + def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: """Get Pilot Jobs. Endpoint only for DIRAC services, to get jobs of a pilot. - :param body: Required. - :type body: str + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. 
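The lookup now works in both directions through the same GET /api/pilots/jobs endpoint; roughly, with the same hypothetical `client`:

    # Jobs executed by a given pilot (an empty list if the stamp is unknown).
    job_ids = client.pilots.get_pilot_jobs(pilot_stamp="stamp_1")

    # Pilots that ran a given job; the returned IDs can then be resolved
    # to full records through the pilot search endpoint.
    pilot_ids = client.pilots.get_pilot_jobs(job_id=42)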
+ :paramtype job_id: int :return: list of int :rtype: list[int] :raises ~azure.core.exceptions.HttpResponseError: @@ -3101,17 +3111,14 @@ def get_pilot_jobs(self, body: str, **kwargs: Any) -> List[int]: } error_map.update(kwargs.pop("error_map", {}) or {}) - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _headers = kwargs.pop("headers", {}) or {} _params = kwargs.pop("params", {}) or {} - content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) cls: ClsType[List[int]] = kwargs.pop("cls", None) - _content = self._serialize.body(body, "str") - _request = build_pilots_get_pilot_jobs_request( - content_type=content_type, - content=_content, + pilot_stamp=pilot_stamp, + job_id=job_id, headers=_headers, params=_params, ) diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py index 82c9daad5..2e96237fe 100644 --- a/diracx-logic/src/diracx/logic/pilots/management.py +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -2,7 +2,7 @@ from datetime import datetime, timedelta, timezone -from diracx.core.exceptions import PilotAlreadyExistsError +from diracx.core.exceptions import PilotAlreadyExistsError, PilotNotFoundError from diracx.core.models import PilotFieldsMapping from diracx.db.sql import PilotAgentsDB @@ -68,7 +68,7 @@ async def update_pilots_fields( async def add_jobs_to_pilot( - pilot_db: PilotAgentsDB, pilot_stamp: str, pilot_jobs_ids: list[int] + pilot_db: PilotAgentsDB, pilot_stamp: str, job_ids: list[int] ): pilot_ids = await get_pilot_ids_by_stamps( pilot_db=pilot_db, pilot_stamps=[pilot_stamp] @@ -79,8 +79,7 @@ async def add_jobs_to_pilot( # Prepare the list of dictionaries for bulk insertion job_to_pilot_mapping = [ - {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} - for job_id in pilot_jobs_ids + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} for job_id in job_ids ] await pilot_db.add_jobs_to_pilot( @@ -92,9 +91,12 @@ async def get_pilot_jobs_ids_by_stamp( pilot_db: PilotAgentsDB, pilot_stamp: str ) -> list[int]: """Fetch pilot jobs by stamp.""" - pilot_ids = await get_pilot_ids_by_stamps( - pilot_db=pilot_db, pilot_stamps=[pilot_stamp] - ) - pilot_id = pilot_ids[0] + try: + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + ) + pilot_id = pilot_ids[0] + except PilotNotFoundError: + return [] return await get_pilot_jobs_ids_by_pilot_id(pilot_db=pilot_db, pilot_id=pilot_id) diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py index 2db667f07..223352d5c 100644 --- a/diracx-logic/src/diracx/logic/pilots/query.py +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -122,3 +122,21 @@ async def get_pilot_jobs_ids_by_pilot_id( ) return [job["JobID"] for job in jobs] + + +async def get_pilot_ids_by_job_id(pilot_db: PilotAgentsDB, job_id: int) -> list[int]: + _, pilots = await pilot_db.search_pilot_to_job_mapping( + parameters=["PilotID"], + search=[ + ScalarSearchSpec( + parameter="JobID", + operator=ScalarSearchOperator.EQUAL, + value=job_id, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + return [pilot["PilotID"] for pilot in pilots] diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index be617ec88..25b98db14 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -26,6 +26,7 
@@ register_new_pilots, update_pilots_fields, ) +from diracx.logic.pilots.query import get_pilot_ids_by_job_id from ..dependencies import PilotAgentsDB from ..fastapi_classes import DiracxRouter @@ -208,47 +209,46 @@ async def update_pilot_fields( @router.get("/jobs") async def get_pilot_jobs( pilot_db: PilotAgentsDB, - pilot_stamp: Annotated[str, Body(description="The stamp of the pilot.")], check_permissions: CheckPilotManagementPolicyCallable, + pilot_stamp: Annotated[ + str | None, Query(description="The stamp of the pilot.") + ] = None, + job_id: Annotated[int | None, Query(description="The ID of the job.")] = None, ) -> list[int]: """Endpoint only for DIRAC services, to get jobs of a pilot.""" - # FIXME: To be tested await check_permissions(action=ActionType.READ_PILOT_FIELDS) - try: + if pilot_stamp: return await get_pilot_jobs_ids_by_stamp( pilot_db=pilot_db, pilot_stamp=pilot_stamp, ) - except PilotNotFoundError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="This pilot does not exist." - ) from e - except PilotAlreadyAssociatedWithJobError as e: - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail="This pilot is already associated with this job.", - ) from e + elif job_id: + return await get_pilot_ids_by_job_id(pilot_db=pilot_db, job_id=job_id) + + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="You must provide either pilot_stamp or job_id", + ) @router.patch("/jobs", status_code=HTTPStatus.NO_CONTENT) async def add_jobs_to_pilot( pilot_db: PilotAgentsDB, pilot_stamp: Annotated[str, Body(description="The stamp of the pilot.")], - pilot_jobs_ids: Annotated[ + job_ids: Annotated[ list[int], Body(description="The jobs we want to add to the pilot.") ], check_permissions: CheckPilotManagementPolicyCallable, ): """Endpoint only for DIRAC services, to associate a pilot with a job.""" - # FIXME: To be tested - await check_permissions(ActionType.MANAGE_PILOTS) + await check_permissions(action=ActionType.MANAGE_PILOTS) try: await add_jobs_to_pilot_bl( pilot_db=pilot_db, pilot_stamp=pilot_stamp, - pilot_jobs_ids=pilot_jobs_ids, + job_ids=job_ids, ) except PilotNotFoundError as e: raise HTTPException( diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py index 0de071e44..6b9d25221 100644 --- a/diracx-routers/tests/pilots/test_pilot_creation.py +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -1,8 +1,14 @@ from __future__ import annotations import pytest +from fastapi.testclient import TestClient -from diracx.core.models import PilotFieldsMapping, PilotStatus +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, + ScalarSearchOperator, + ScalarSearchSpec, +) pytestmark = pytest.mark.enabled_dependencies( [ @@ -177,7 +183,7 @@ async def test_create_pilot_and_modify_it(normal_test_client): assert pilot2["Status"] != pilot1["Status"] -async def test_associate_job_with_pilot_and_get_it(normal_test_client): +async def test_associate_job_with_pilot_and_get_it(normal_test_client: TestClient): pilot_stamps = ["stamps_1", "stamp_2"] # -------------- Insert -------------- @@ -192,3 +198,86 @@ async def test_associate_job_with_pilot_and_get_it(normal_test_client): assert r.status_code == 200, r.json() # --------------- As DIRAC, associate a job with a pilot -------- + job_ids = [1, 2] + body = {"pilot_stamp": pilot_stamps[0], "job_ids": job_ids} + + # Create pilots + r = normal_test_client.patch( + "/api/pilots/jobs", + json=body, + ) + 
+ assert r.status_code == 204 + + # -------------- Redo it, expect 409 (Conflict) --------------------- + job_ids = [1, 2, 3] # Note for next test : add 3 + body = {"pilot_stamp": pilot_stamps[0], "job_ids": job_ids} + + # Create pilots + r = normal_test_client.patch( + "/api/pilots/jobs", + json=body, + ) + + assert r.status_code == 409 + + # -------------- Add 3 --------------------- + body = {"pilot_stamp": pilot_stamps[0], "job_ids": [3]} + + # Create pilots + r = normal_test_client.patch( + "/api/pilots/jobs", + json=body, + ) + + assert r.status_code == 204 + + # -------------- Add with unknown pilot --------------------- + body = {"pilot_stamp": "stampounet", "job_ids": job_ids} + + # Create pilots + r = normal_test_client.patch( + "/api/pilots/jobs", + json=body, + ) + + assert r.status_code == 400 + + # -------------- Get its jobs --------------------- + r = normal_test_client.get( + "/api/pilots/jobs", params={"pilot_stamp": pilot_stamps[0]} + ) + + assert r.status_code == 200 + assert r.json() == job_ids + + # -------------- Get the other pilot's jobs --------------------- + r = normal_test_client.get( + "/api/pilots/jobs", params={"pilot_stamp": pilot_stamps[1]} + ) + + assert r.status_code == 200 + assert r.json() == [] + + # -------------- Get pilots associated to job 1 --------------------- + r = normal_test_client.get("/api/pilots/jobs", params={"job_id": job_ids[0]}) + + assert r.status_code == 200, r.json() + assert len(r.json()) == 1 + expected_pilot_id = r.json()[0] + + # -------------- Get pilot info to verify that its id is expected_pilot_id --------------------- + condition = ScalarSearchSpec( + parameter="PilotID", + operator=ScalarSearchOperator.EQUAL, + value=expected_pilot_id, + ) + + r = normal_test_client.post( + "/api/pilots/management/search", + json={"parameters": [], "search": [condition], "sorts": []}, + ) + + assert r.status_code == 200, r.json() + assert len(r.json()) == 1 + assert r.json()[0]["PilotStamp"] == pilot_stamps[0] diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index d18b54a4d..29d676656 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -2625,13 +2625,17 @@ async def update_pilot_fields( return cls(pipeline_response, None, {}) # type: ignore @distributed_trace_async - async def get_pilot_jobs(self, body: str, **kwargs: Any) -> List[int]: + async def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: """Get Pilot Jobs. Endpoint only for DIRAC services, to get jobs of a pilot. - :param body: Required. - :type body: str + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. 
+ :paramtype job_id: int :return: list of int :rtype: list[int] :raises ~azure.core.exceptions.HttpResponseError: @@ -2644,17 +2648,14 @@ async def get_pilot_jobs(self, body: str, **kwargs: Any) -> List[int]: } error_map.update(kwargs.pop("error_map", {}) or {}) - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _headers = kwargs.pop("headers", {}) or {} _params = kwargs.pop("params", {}) or {} - content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) cls: ClsType[List[int]] = kwargs.pop("cls", None) - _content = self._serialize.body(body, "str") - _request = build_pilots_get_pilot_jobs_request( - content_type=content_type, - content=_content, + pilot_stamp=pilot_stamp, + job_id=job_id, headers=_headers, params=_params, ) diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 5746318c9..cd97b272f 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -101,30 +101,30 @@ class BodyPilotsAddJobsToPilot(_serialization.Model): :ivar pilot_stamp: The stamp of the pilot. Required. :vartype pilot_stamp: str - :ivar pilot_jobs_ids: The jobs we want to add to the pilot. Required. - :vartype pilot_jobs_ids: list[int] + :ivar job_ids: The jobs we want to add to the pilot. Required. + :vartype job_ids: list[int] """ _validation = { "pilot_stamp": {"required": True}, - "pilot_jobs_ids": {"required": True}, + "job_ids": {"required": True}, } _attribute_map = { "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, - "pilot_jobs_ids": {"key": "pilot_jobs_ids", "type": "[int]"}, + "job_ids": {"key": "job_ids", "type": "[int]"}, } - def __init__(self, *, pilot_stamp: str, pilot_jobs_ids: List[int], **kwargs: Any) -> None: + def __init__(self, *, pilot_stamp: str, job_ids: List[int], **kwargs: Any) -> None: """ :keyword pilot_stamp: The stamp of the pilot. Required. :paramtype pilot_stamp: str - :keyword pilot_jobs_ids: The jobs we want to add to the pilot. Required. - :paramtype pilot_jobs_ids: list[int] + :keyword job_ids: The jobs we want to add to the pilot. Required. 
+ :paramtype job_ids: list[int] """ super().__init__(**kwargs) self.pilot_stamp = pilot_stamp - self.pilot_jobs_ids = pilot_jobs_ids + self.job_ids = job_ids class BodyPilotsAddPilotStamps(_serialization.Model): diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index 4a6ee360d..9243f0619 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -693,21 +693,27 @@ def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) -def build_pilots_get_pilot_jobs_request(*, content: str, **kwargs: Any) -> HttpRequest: +def build_pilots_get_pilot_jobs_request( + *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any +) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") # Construct URL _url = "/api/pilots/jobs" + # Construct parameters + if pilot_stamp is not None: + _params["pilot_stamp"] = _SERIALIZER.query("pilot_stamp", pilot_stamp, "str") + if job_id is not None: + _params["job_id"] = _SERIALIZER.query("job_id", job_id, "int") + # Construct headers - if content_type is not None: - _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") - return HttpRequest(method="GET", url=_url, headers=_headers, content=content, **kwargs) + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) def build_pilots_add_jobs_to_pilot_request(**kwargs: Any) -> HttpRequest: @@ -3295,13 +3301,17 @@ def update_pilot_fields( # pylint: disable=inconsistent-return-statements return cls(pipeline_response, None, {}) # type: ignore @distributed_trace - def get_pilot_jobs(self, body: str, **kwargs: Any) -> List[int]: + def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: """Get Pilot Jobs. Endpoint only for DIRAC services, to get jobs of a pilot. - :param body: Required. - :type body: str + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. 
+ :paramtype job_id: int :return: list of int :rtype: list[int] :raises ~azure.core.exceptions.HttpResponseError: @@ -3314,17 +3324,14 @@ def get_pilot_jobs(self, body: str, **kwargs: Any) -> List[int]: } error_map.update(kwargs.pop("error_map", {}) or {}) - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _headers = kwargs.pop("headers", {}) or {} _params = kwargs.pop("params", {}) or {} - content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) cls: ClsType[List[int]] = kwargs.pop("cls", None) - _content = self._serialize.body(body, "str") - _request = build_pilots_get_pilot_jobs_request( - content_type=content_type, - content=_content, + pilot_stamp=pilot_stamp, + job_id=job_id, headers=_headers, params=_params, ) From 65346df26fedd96debe94ee14ccd9cd6f8f59602 Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Fri, 27 Jun 2025 09:41:42 +0200 Subject: [PATCH 12/33] feat: Better pilot deletion (following DIRAC behaviour) --- diracx-core/src/diracx/core/models.py | 2 +- diracx-db/src/diracx/db/sql/pilots/db.py | 40 ++--- .../tests/pilots/test_pilot_management.py | 124 +-------------- .../src/diracx/logic/pilots/management.py | 33 ++-- diracx-logic/src/diracx/logic/pilots/query.py | 37 ++++- .../src/diracx/routers/pilots/management.py | 31 ++-- .../tests/pilots/test_pilot_creation.py | 143 ++++++++++++++++++ 7 files changed, 231 insertions(+), 179 deletions(-) diff --git a/diracx-core/src/diracx/core/models.py b/diracx-core/src/diracx/core/models.py index 6ed7cd9ff..fb14e2738 100644 --- a/diracx-core/src/diracx/core/models.py +++ b/diracx-core/src/diracx/core/models.py @@ -29,7 +29,7 @@ class VectorSearchOperator(StrEnum): class ScalarSearchSpec(TypedDict): parameter: str operator: ScalarSearchOperator - value: str | int + value: str | int | datetime class VectorSearchSpec(TypedDict): diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py index a175ab866..de795729c 100644 --- a/diracx-db/src/diracx/db/sql/pilots/db.py +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -13,7 +13,6 @@ ) from diracx.core.models import ( PilotFieldsMapping, - PilotStatus, SearchSpec, SortSpec, ) @@ -25,6 +24,7 @@ JobToPilotMapping, PilotAgents, PilotAgentsDBBase, + PilotOutput, ) @@ -121,35 +121,25 @@ async def add_jobs_to_pilot(self, job_to_pilot_mapping: list[dict[str, Any]]): # ----------------------------- Delete Functions ----------------------------- - async def delete_pilots_by_stamps(self, pilot_stamps: list[str]): - """Bulk delete pilots. + async def delete_pilots(self, pilot_ids: list[int]): + """Destructive function. Delete pilots.""" + stmt = delete(PilotAgents).where(PilotAgents.pilot_id.in_(pilot_ids)) - Raises PilotNotFound if one of the pilot was not found. - """ - stmt = delete(PilotAgents).where(PilotAgents.pilot_stamp.in_(pilot_stamps)) - - res = await self.conn.execute(stmt) - - if res.rowcount != len(pilot_stamps): - raise PilotNotFoundError(data={"pilot_stamps": str(pilot_stamps)}) + await self.conn.execute(stmt) - async def clear_pilots( - self, cutoff_date: datetime, delete_only_aborted: bool = False - ) -> int: - """Bulk delete pilots that have SubmissionTime before the 'cutoff_date'. - Returns the number of deletion. - """ - # TODO: Add test (Millisec?) - stmt = delete(PilotAgents).where(PilotAgents.submission_time < cutoff_date) + async def remove_jobs_to_pilots(self, pilot_ids: list[int]): + """Destructive function. 
De-associate jobs and pilots.""" + stmt = delete(JobToPilotMapping).where( + JobToPilotMapping.pilot_id.in_(pilot_ids) + ) - # If delete_only_aborted is True, add the condition for 'Status' being 'Aborted' - if delete_only_aborted: - stmt = stmt.where(PilotAgents.status == PilotStatus.ABORTED) + await self.conn.execute(stmt) - # Execute the statement - res = await self.conn.execute(stmt) + async def delete_pilot_logs(self, pilot_ids: list[int]): + """Destructive function. Remove logs from pilots.""" + stmt = delete(PilotOutput).where(PilotOutput.pilot_id.in_(pilot_ids)) - return res.rowcount + await self.conn.execute(stmt) # ----------------------------- Update Functions ----------------------------- diff --git a/diracx-db/tests/pilots/test_pilot_management.py b/diracx-db/tests/pilots/test_pilot_management.py index ff09fdb74..1e7397b39 100644 --- a/diracx-db/tests/pilots/test_pilot_management.py +++ b/diracx-db/tests/pilots/test_pilot_management.py @@ -65,11 +65,11 @@ async def test_insert_and_delete(pilot_db: PilotAgentsDB): ) # Works, the pilots exists - await get_pilots_by_stamp(pilot_db, [stamps[0]]) + res = await get_pilots_by_stamp(pilot_db, [stamps[0]]) await get_pilots_by_stamp(pilot_db, [stamps[0]]) # We delete the first pilot - await pilot_db.delete_pilots_by_stamps([stamps[0]]) + await pilot_db.delete_pilots([res[0]["PilotID"]]) # We get the 2nd pilot that is not delete (no error) await get_pilots_by_stamp(pilot_db, [stamps[1]]) @@ -78,126 +78,6 @@ async def test_insert_and_delete(pilot_db: PilotAgentsDB): assert not await get_pilots_by_stamp(pilot_db, [stamps[0]]) -@pytest.mark.asyncio -async def test_insert_and_delete_only_old_aborted( - pilot_db: PilotAgentsDB, - create_old_pilots_environment, # noqa: F811 -): - non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old = ( - create_old_pilots_environment - ) - - async with pilot_db as pilot_db: - # Delete all aborted that were born before 2020 - # Every aborted that are old may be delete - await pilot_db.clear_pilots(datetime(2020, 1, 1, tzinfo=timezone.utc), True) - - # Assert who still live - for normally_exiting_pilot_list in [ - non_aborted_recent, - aborted_recent, - non_aborted_very_old, - ]: - stamps = [pilot["PilotStamp"] for pilot in normally_exiting_pilot_list] - - await get_pilots_by_stamp(pilot_db, stamps) - - # Assert who normally does not live - for normally_deleted_pilot_list in [aborted_very_old]: - stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] - - assert not await get_pilots_by_stamp(pilot_db, stamps) - - -@pytest.mark.asyncio -async def test_insert_and_delete_old( - pilot_db: PilotAgentsDB, - create_old_pilots_environment, # noqa: F811 -): - non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old = ( - create_old_pilots_environment - ) - - async with pilot_db as pilot_db: - # Delete all aborted that were born before 2020 - # Every aborted that are old may be delete - await pilot_db.clear_pilots(datetime(2020, 1, 1, tzinfo=timezone.utc), False) - - # Assert who still live - for normally_exiting_pilot_list in [ - non_aborted_recent, - aborted_recent, - ]: - stamps = [pilot["PilotStamp"] for pilot in normally_exiting_pilot_list] - - await get_pilots_by_stamp(pilot_db, stamps) - - # Assert who normally does not live - for normally_deleted_pilot_list in [ - aborted_very_old, - non_aborted_very_old, - ]: - stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] - - assert not await get_pilots_by_stamp(pilot_db, stamps) - - 
-@pytest.mark.asyncio -async def test_insert_and_delete_recent_only_aborted( - pilot_db: PilotAgentsDB, - create_old_pilots_environment, # noqa: F811 -): - non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old = ( - create_old_pilots_environment - ) - - async with pilot_db as pilot_db: - # Delete all aborted that were born before 2020 - # Every aborted that are old may be delete - await pilot_db.clear_pilots(datetime(2025, 3, 10, tzinfo=timezone.utc), True) - - # Assert who still live - for normally_exiting_pilot_list in [non_aborted_recent, non_aborted_very_old]: - stamps = [pilot["PilotStamp"] for pilot in normally_exiting_pilot_list] - - await get_pilots_by_stamp(pilot_db, stamps) - - # Assert who normally does not live - for normally_deleted_pilot_list in [ - aborted_very_old, - aborted_recent, - ]: - stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] - - assert not await get_pilots_by_stamp(pilot_db, stamps) - - -@pytest.mark.asyncio -async def test_insert_and_delete_recent( - pilot_db: PilotAgentsDB, - create_old_pilots_environment, # noqa: F811 -): - non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old = ( - create_old_pilots_environment - ) - - async with pilot_db as pilot_db: - # Delete all aborted that were born before 2020 - # Every aborted that are old may be delete - await pilot_db.clear_pilots(datetime(2025, 3, 10, tzinfo=timezone.utc), False) - - # Assert who normally does not live - for normally_deleted_pilot_list in [ - aborted_very_old, - aborted_recent, - non_aborted_recent, - non_aborted_very_old, - ]: - stamps = [pilot["PilotStamp"] for pilot in normally_deleted_pilot_list] - - assert not await get_pilots_by_stamp(pilot_db, stamps) - - @pytest.mark.asyncio async def test_insert_and_select_single_then_modify(pilot_db: PilotAgentsDB): async with pilot_db as pilot_db: diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py index 2e96237fe..a1b7df0f2 100644 --- a/diracx-logic/src/diracx/logic/pilots/management.py +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -7,6 +7,7 @@ from diracx.db.sql import PilotAgentsDB from .query import ( + get_outdated_pilots, get_pilot_ids_by_stamps, get_pilot_jobs_ids_by_pilot_id, get_pilots_by_stamp, @@ -46,19 +47,33 @@ async def register_new_pilots( ) -async def clear_pilots( - pilot_db: PilotAgentsDB, age_in_days: int, delete_only_aborted: bool +async def delete_pilots( + pilot_db: PilotAgentsDB, + pilot_stamps: list[str] | None = None, + age_in_days: int | None = None, + delete_only_aborted: bool = True, ): - """Delete pilots that have been submitted before interval_in_days.""" - cutoff_date = datetime.now(tz=timezone.utc) - timedelta(days=age_in_days) + if pilot_stamps: + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=pilot_stamps, allow_missing=True + ) + else: + assert age_in_days - await pilot_db.clear_pilots( - cutoff_date=cutoff_date, delete_only_aborted=delete_only_aborted - ) + cutoff_date = datetime.now(tz=timezone.utc) - timedelta(days=age_in_days) + + pilots = await get_outdated_pilots( + pilot_db=pilot_db, + cutoff_date=cutoff_date, + only_aborted=delete_only_aborted, + parameters=["PilotID"], + ) + pilot_ids = [pilot["PilotID"] for pilot in pilots] -async def delete_pilots_by_stamps(pilot_db: PilotAgentsDB, pilot_stamps: list[str]): - await pilot_db.delete_pilots_by_stamps(pilot_stamps) + await pilot_db.remove_jobs_to_pilots(pilot_ids) + await 
pilot_db.delete_pilot_logs(pilot_ids) + await pilot_db.delete_pilots(pilot_ids) async def update_pilots_fields( diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py index 223352d5c..8cd57428a 100644 --- a/diracx-logic/src/diracx/logic/pilots/query.py +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -1,12 +1,15 @@ from __future__ import annotations +from datetime import datetime from typing import Any from diracx.core.exceptions import PilotNotFoundError from diracx.core.models import ( + PilotStatus, ScalarSearchOperator, ScalarSearchSpec, SearchParams, + SearchSpec, VectorSearchOperator, VectorSearchSpec, ) @@ -92,13 +95,13 @@ async def get_pilots_by_stamp( async def get_pilot_ids_by_stamps( - pilot_db: PilotAgentsDB, pilot_stamps: list[str] + pilot_db: PilotAgentsDB, pilot_stamps: list[str], allow_missing=False ) -> list[int]: pilots = await get_pilots_by_stamp( pilot_db=pilot_db, pilot_stamps=pilot_stamps, parameters=["PilotID"], - allow_missing=False, + allow_missing=allow_missing, ) return [pilot["PilotID"] for pilot in pilots] @@ -140,3 +143,33 @@ async def get_pilot_ids_by_job_id(pilot_db: PilotAgentsDB, job_id: int) -> list[ ) return [pilot["PilotID"] for pilot in pilots] + + +async def get_outdated_pilots( + pilot_db: PilotAgentsDB, + cutoff_date: datetime, + only_aborted: bool = True, + parameters: list[str] = [], +): + query: list[SearchSpec] = [ + ScalarSearchSpec( + parameter="SubmissionTime", + operator=ScalarSearchOperator.LESS_THAN, + value=cutoff_date, + ) + ] + + if only_aborted: + query.append( + ScalarSearchSpec( + parameter="Status", + operator=ScalarSearchOperator.EQUAL, + value=PilotStatus.ABORTED, + ) + ) + + _, pilots = await pilot_db.search_pilots( + parameters=parameters, search=query, sorts=[] + ) + + return pilots diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 25b98db14..8fb2a3ea1 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -18,10 +18,9 @@ add_jobs_to_pilot as add_jobs_to_pilot_bl, ) from diracx.logic.pilots.management import ( - clear_pilots as clear_pilots_bl, + delete_pilots as delete_pilots_bl, ) from diracx.logic.pilots.management import ( - delete_pilots_by_stamps, get_pilot_jobs_ids_by_stamp, register_new_pilots, update_pilots_fields, @@ -129,33 +128,25 @@ async def delete_pilots( 1. Or you provide pilot_stamps, so you can delete pilots by their stamp 2. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. - If deleting by stamps, if at least one pilot is not found, it WILL rollback. + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. 
""" await check_permissions( action=ActionType.MANAGE_PILOTS, ) - if pilot_stamps: - try: - await delete_pilots_by_stamps(pilot_db=pilot_db, pilot_stamps=pilot_stamps) - except PilotNotFoundError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="At least one pilot has not been found.", - ) from e - - elif age_in_days: - await clear_pilots_bl( - pilot_db=pilot_db, - age_in_days=age_in_days, - delete_only_aborted=delete_only_aborted, - ) - else: + if not pilot_stamps and not age_in_days: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="You must provide either age_in_days or pilot_stamps.", + detail="pilot_stamps or age_in_days have to be provided.", ) + await delete_pilots_bl( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + ) + EXAMPLE_UPDATE_FIELDS = { "Update the BenchMark field": { diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py index 6b9d25221..44c87e2ea 100644 --- a/diracx-routers/tests/pilots/test_pilot_creation.py +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -1,7 +1,10 @@ from __future__ import annotations +from datetime import datetime, timezone + import pytest from fastapi.testclient import TestClient +from sqlalchemy import update from diracx.core.models import ( PilotFieldsMapping, @@ -9,6 +12,8 @@ ScalarSearchOperator, ScalarSearchSpec, ) +from diracx.db.sql import PilotAgentsDB +from diracx.db.sql.pilots.schema import PilotAgents pytestmark = pytest.mark.enabled_dependencies( [ @@ -281,3 +286,141 @@ async def test_associate_job_with_pilot_and_get_it(normal_test_client: TestClien assert r.status_code == 200, r.json() assert len(r.json()) == 1 assert r.json()[0]["PilotStamp"] == pilot_stamps[0] + + +@pytest.mark.asyncio +async def test_delete_pilots_by_age_and_stamp(normal_test_client): + # Generate 100 pilot stamps + pilot_stamps = [f"stamp_{i}" for i in range(100)] + + # -------------- Insert all pilots -------------- + body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} + r = normal_test_client.post("/api/pilots/", json=body) + assert r.status_code == 200, r.json() + + # -------------- Modify last 50 pilots' fields -------------- + to_modify = pilot_stamps[50:] + mappings = [] + for idx, stamp in enumerate(to_modify): + # First 25 of modified set to ABORTED, others to WAITING + status = PilotStatus.ABORTED if idx < 25 else PilotStatus.WAITING + mapping = PilotFieldsMapping( + PilotStamp=stamp, + BenchMark=idx + 0.1, + StatusReason=f"Reason_{idx}", + AccountingSent=(idx % 2 == 0), + Status=status, + ).model_dump(exclude_unset=True) + mappings.append(mapping) + + r = normal_test_client.patch( + "/api/pilots/metadata", + json={"pilot_stamps_to_fields_mapping": mappings}, + ) + assert r.status_code == 204 + + # -------------- Directly set SubmissionTime to March 14, 2003 for last 50 -------------- + old_date = datetime(2003, 3, 14, tzinfo=timezone.utc) + # Access DB session from normal_test_client fixtures + db = normal_test_client.app.dependency_overrides[PilotAgentsDB.transaction].args[0] + + async with db: + stmt = ( + update(PilotAgents) + .where(PilotAgents.pilot_stamp.in_(to_modify)) + .values(SubmissionTime=old_date) + ) + await db.conn.execute(stmt) + await db.conn.commit() + + # -------------- Verify all 100 pilots exist -------------- + search_body = {"parameters": [], "search": [], "sort": [], "distinct": True} + r = 
normal_test_client.post("/api/pilots/management/search", json=search_body) + assert r.status_code == 200, r.json() + assert len(r.json()) == 100 + + # -------------- 1) Delete only old aborted pilots (25 expected) -------------- + # age_in_days large enough to include 2003-03-14 + r = normal_test_client.delete( + "/api/pilots/", + params={"age_in_days": 15, "delete_only_aborted": True}, + ) + assert r.status_code == 204 + # Expect 75 remaining + r = normal_test_client.post("/api/pilots/management/search", json=search_body) + assert len(r.json()) == 75 + + # -------------- 2) Delete all old pilots (remaining 25 old) -------------- + r = normal_test_client.delete( + "/api/pilots/", + params={"age_in_days": 15}, + ) + assert r.status_code == 204 + + # Expect 50 remaining + r = normal_test_client.post("/api/pilots/management/search", json=search_body) + assert len(r.json()) == 50 + + # -------------- 3) Delete one recent pilot by stamp -------------- + one_stamp = pilot_stamps[10] + r = normal_test_client.delete("/api/pilots/", params={"pilot_stamps": [one_stamp]}) + assert r.status_code == 204 + # Expect 49 remaining + r = normal_test_client.post("/api/pilots/management/search", json=search_body) + assert len(r.json()) == 49 + + # -------------- 4) Delete all remaining pilots -------------- + # Collect remaining stamps + remaining = [p["PilotStamp"] for p in r.json()] + r = normal_test_client.delete("/api/pilots/", params={"pilot_stamps": remaining}) + assert r.status_code == 204 + # Expect none remaining + r = normal_test_client.post("/api/pilots/management/search", json=search_body) + assert r.status_code == 200 + assert len(r.json()) == 0 + + # -------------- 5) Attempt deleting an unknown pilot, expect 204 (unknown stamps are ignored) -------------- + r = normal_test_client.delete( + "/api/pilots/", params={"pilot_stamps": ["unknown_stamp"]} + ) + assert r.status_code == 204 + + +@pytest.mark.asyncio +async def test_associate_two_pilots_share_jobs_and_delete_first(normal_test_client): + # 1) Create two pilots + pilot_stamps = ["stamp_1", "stamp_2"] + body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} + r = normal_test_client.post("/api/pilots/", json=body) + assert r.status_code == 200, r.json() + + # 2) Associate first pilot with jobs 1-10 + job_ids = list(range(1, 11)) + body = {"pilot_stamp": pilot_stamps[0], "job_ids": job_ids} + r = normal_test_client.patch("/api/pilots/jobs", json=body) + assert r.status_code == 204 + + # 3) Associate second pilot with the same jobs + body = {"pilot_stamp": pilot_stamps[1], "job_ids": job_ids} + r = normal_test_client.patch("/api/pilots/jobs", json=body) + assert r.status_code == 204 + + # 4) Delete first pilot + r = normal_test_client.delete( + "/api/pilots/", params={"pilot_stamps": [pilot_stamps[0]]} + ) + assert r.status_code == 204 + + # 5) Get jobs for pilot_1: expect empty list + r = normal_test_client.get( + "/api/pilots/jobs", params={"pilot_stamp": pilot_stamps[0]} + ) + assert r.status_code == 200 + assert r.json() == [] + + # 6) Get jobs for pilot_2: expect original job_ids + r = normal_test_client.get( + "/api/pilots/jobs", params={"pilot_stamp": pilot_stamps[1]} + ) + assert r.status_code == 200 + assert r.json() == job_ids From ec450b8d4f2d07c41bf553b34578093f1821e6cc Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Fri, 27 Jun 2025 09:45:27 +0200 Subject: [PATCH 13/33] fix: Generate client --- .../src/diracx/client/_generated/aio/operations/_operations.py | 2 +- .../src/diracx/client/_generated/operations/_operations.py | 2 +-
.../src/gubbins/client/_generated/aio/operations/_operations.py | 2 +- .../src/gubbins/client/_generated/operations/_operations.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index 232bcec5f..f11440961 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -2304,7 +2304,7 @@ async def delete_pilots( #. Or you provide pilot_stamps, so you can delete pilots by their stamp #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. - If deleting by stamps, if at least one pilot is not found, it WILL rollback. + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. :paramtype pilot_stamps: list[str] diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index fda6ff2ce..a7f5257e7 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -2936,7 +2936,7 @@ def delete_pilots( # pylint: disable=inconsistent-return-statements #. Or you provide pilot_stamps, so you can delete pilots by their stamp #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. - If deleting by stamps, if at least one pilot is not found, it WILL rollback. + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. :paramtype pilot_stamps: list[str] diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index 29d676656..c22bd73db 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -2471,7 +2471,7 @@ async def delete_pilots( #. Or you provide pilot_stamps, so you can delete pilots by their stamp #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. - If deleting by stamps, if at least one pilot is not found, it WILL rollback. + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. :paramtype pilot_stamps: list[str] diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index 9243f0619..1b099ca39 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -3149,7 +3149,7 @@ def delete_pilots( # pylint: disable=inconsistent-return-statements #. Or you provide pilot_stamps, so you can delete pilots by their stamp #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. 
- If deleting by stamps, if at least one pilot is not found, it WILL rollback. + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. :paramtype pilot_stamps: list[str] From 507a08c5493233f092b0634699db0f155b95ea33 Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Fri, 27 Jun 2025 09:49:52 +0200 Subject: [PATCH 14/33] fix: Syntax error --- diracx-db/src/diracx/db/sql/pilots/db.py | 2 +- diracx-logic/src/diracx/logic/pilots/management.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py index de795729c..efc9279a3 100644 --- a/diracx-db/src/diracx/db/sql/pilots/db.py +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -127,7 +127,7 @@ async def delete_pilots(self, pilot_ids: list[int]): await self.conn.execute(stmt) - async def remove_jobs_to_pilots(self, pilot_ids: list[int]): + async def remove_jobs_from_pilots(self, pilot_ids: list[int]): """Destructive function. De-associate jobs and pilots.""" stmt = delete(JobToPilotMapping).where( JobToPilotMapping.pilot_id.in_(pilot_ids) diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py index a1b7df0f2..359947baf 100644 --- a/diracx-logic/src/diracx/logic/pilots/management.py +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -71,7 +71,7 @@ async def delete_pilots( pilot_ids = [pilot["PilotID"] for pilot in pilots] - await pilot_db.remove_jobs_to_pilots(pilot_ids) + await pilot_db.remove_jobs_from_pilots(pilot_ids) await pilot_db.delete_pilot_logs(pilot_ids) await pilot_db.delete_pilots(pilot_ids) From 7e2231264d5ae59e4df4ad8e72132c190b2abafb Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Mon, 30 Jun 2025 15:09:16 +0200 Subject: [PATCH 15/33] fix: Removed from management --- diracx-routers/src/diracx/routers/pilots/management.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 8fb2a3ea1..582bd7053 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging from http import HTTPStatus from typing import Annotated @@ -37,8 +36,6 @@ router = DiracxRouter() -logger = logging.getLogger(__name__) - @router.post("/") async def add_pilot_stamps( @@ -51,7 +48,6 @@ async def add_pilot_stamps( str, Body(description="Virtual Organisation associated with the inserted pilots."), ], - user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], check_permissions: CheckPilotManagementPolicyCallable, grid_type: Annotated[str, Body(description="Grid type of the pilots.")] = "Dirac", grid_site: Annotated[str, Body(description="Pilots grid site.")] = "Unknown", @@ -84,11 +80,6 @@ async def add_pilot_stamps( pilot_job_references=pilot_references, status_reason=status_reason, ) - - # Logs credentials creation - logger.debug( - f"{user_info.preferred_username} added {len(pilot_stamps)} pilots." 
- ) except PilotAlreadyExistsError as e: raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e From 941df03900b6a8ba2eeb92670d67a6a6c8ffec4a Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Mon, 30 Jun 2025 15:12:23 +0200 Subject: [PATCH 16/33] fix: Changed search endpoint --- .../client/_generated/operations/_operations.py | 2 +- .../src/diracx/routers/pilots/management.py | 3 +-- diracx-routers/src/diracx/routers/pilots/query.py | 2 +- diracx-routers/tests/pilots/test_pilot_creation.py | 14 +++++++------- diracx-routers/tests/pilots/test_query.py | 4 +--- .../client/_generated/operations/_operations.py | 2 +- 6 files changed, 12 insertions(+), 15 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index a7f5257e7..a442124a4 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -689,7 +689,7 @@ def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: accept = _headers.pop("Accept", "application/json") # Construct URL - _url = "/api/pilots/management/search" + _url = "/api/pilots/search" # Construct parameters if page is not None: diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 582bd7053..d18ba7719 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -3,7 +3,7 @@ from http import HTTPStatus from typing import Annotated -from fastapi import Body, Depends, HTTPException, Query, status +from fastapi import Body, HTTPException, Query, status from diracx.core.exceptions import ( PilotAlreadyAssociatedWithJobError, @@ -28,7 +28,6 @@ from ..dependencies import PilotAgentsDB from ..fastapi_classes import DiracxRouter -from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token from .access_policies import ( ActionType, CheckPilotManagementPolicyCallable, diff --git a/diracx-routers/src/diracx/routers/pilots/query.py b/diracx-routers/src/diracx/routers/pilots/query.py index f494ce1f6..d6867fcbd 100644 --- a/diracx-routers/src/diracx/routers/pilots/query.py +++ b/diracx-routers/src/diracx/routers/pilots/query.py @@ -104,7 +104,7 @@ } -@router.post("/management/search", responses=EXAMPLE_RESPONSES) +@router.post("/search", responses=EXAMPLE_RESPONSES) async def search( pilot_db: PilotAgentsDB, check_permissions: CheckPilotManagementPolicyCallable, diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py index 44c87e2ea..0280f77ea 100644 --- a/diracx-routers/tests/pilots/test_pilot_creation.py +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -172,7 +172,7 @@ async def test_create_pilot_and_modify_it(normal_test_client): "distinct": True, } - r = normal_test_client.post("/api/pilots/management/search", json=body) + r = normal_test_client.post("/api/pilots/search", json=body) assert r.status_code == 200, r.json() pilot1 = r.json()[0] pilot2 = r.json()[1] @@ -279,7 +279,7 @@ async def test_associate_job_with_pilot_and_get_it(normal_test_client: TestClien ) r = normal_test_client.post( - "/api/pilots/management/search", + "/api/pilots/search", json={"parameters": [], "search": [condition], "sorts": []}, ) @@ -335,7 +335,7 @@ async def test_delete_pilots_by_age_and_stamp(normal_test_client): 
# -------------- Verify all 100 pilots exist -------------- search_body = {"parameters": [], "search": [], "sort": [], "distinct": True} - r = normal_test_client.post("/api/pilots/management/search", json=search_body) + r = normal_test_client.post("/api/pilots/search", json=search_body) assert r.status_code == 200, r.json() assert len(r.json()) == 100 @@ -347,7 +347,7 @@ async def test_delete_pilots_by_age_and_stamp(normal_test_client): ) assert r.status_code == 204 # Expect 75 remaining - r = normal_test_client.post("/api/pilots/management/search", json=search_body) + r = normal_test_client.post("/api/pilots/search", json=search_body) assert len(r.json()) == 75 # -------------- 2) Delete all old pilots (remaining 25 old) -------------- @@ -358,7 +358,7 @@ async def test_delete_pilots_by_age_and_stamp(normal_test_client): assert r.status_code == 204 # Expect 50 remaining - r = normal_test_client.post("/api/pilots/management/search", json=search_body) + r = normal_test_client.post("/api/pilots/search", json=search_body) assert len(r.json()) == 50 # -------------- 3) Delete one recent pilot by stamp -------------- @@ -366,7 +366,7 @@ async def test_delete_pilots_by_age_and_stamp(normal_test_client): r = normal_test_client.delete("/api/pilots/", params={"pilot_stamps": [one_stamp]}) assert r.status_code == 204 # Expect 49 remaining - r = normal_test_client.post("/api/pilots/management/search", json=search_body) + r = normal_test_client.post("/api/pilots/search", json=search_body) assert len(r.json()) == 49 # -------------- 4) Delete all remaining pilots -------------- @@ -375,7 +375,7 @@ async def test_delete_pilots_by_age_and_stamp(normal_test_client): r = normal_test_client.delete("/api/pilots/", params={"pilot_stamps": remaining}) assert r.status_code == 204 # Expect none remaining - r = normal_test_client.post("/api/pilots/management/search", json=search_body) + r = normal_test_client.post("/api/pilots/search", json=search_body) assert r.status_code == 200 assert len(r.json()) == 0 diff --git a/diracx-routers/tests/pilots/test_query.py b/diracx-routers/tests/pilots/test_query.py index a254ed481..0672e25af 100644 --- a/diracx-routers/tests/pilots/test_query.py +++ b/diracx-routers/tests/pilots/test_query.py @@ -96,9 +96,7 @@ async def _search( params = {"per_page": per_page, "page": page} - r = populated_pilot_client.post( - "/api/pilots/management/search", json=body, params=params - ) + r = populated_pilot_client.post("/api/pilots/search", json=body, params=params) if r.status_code == 400: # If we have a status_code 400, that means that the query failed diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index 1b099ca39..7db833022 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -738,7 +738,7 @@ def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: accept = _headers.pop("Accept", "application/json") # Construct URL - _url = "/api/pilots/management/search" + _url = "/api/pilots/search" # Construct parameters if page is not None: From fd5233218154a348dbbdbabdc24e14400ca29f70 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Tue, 1 Jul 2025 10:51:12 +0200 Subject: [PATCH 17/33] fix: Change from status reason to status (cf dirac-admin-add-pilot) --- 
.../diracx/client/_generated/models/_models.py | 16 +++++++++------- diracx-db/src/diracx/db/sql/pilots/db.py | 6 +++--- .../src/diracx/logic/pilots/management.py | 4 ++-- .../src/diracx/routers/pilots/management.py | 9 +++++---- .../gubbins/client/_generated/models/_models.py | 16 +++++++++------- 5 files changed, 28 insertions(+), 23 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index 73677b300..956ee2b8b 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -144,8 +144,9 @@ class BodyPilotsAddPilotStamps(_serialization.Model): :vartype destination_site: str :ivar pilot_references: Association of a pilot reference with a pilot stamp. :vartype pilot_references: dict[str, str] - :ivar status_reason: Status reason of the pilots. - :vartype status_reason: str + :ivar pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", "Running", + "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :vartype pilot_status: str or ~_generated.models.PilotStatus """ _validation = { @@ -160,7 +161,7 @@ class BodyPilotsAddPilotStamps(_serialization.Model): "grid_site": {"key": "grid_site", "type": "str"}, "destination_site": {"key": "destination_site", "type": "str"}, "pilot_references": {"key": "pilot_references", "type": "{str}"}, - "status_reason": {"key": "status_reason", "type": "str"}, + "pilot_status": {"key": "pilot_status", "type": "str"}, } def __init__( @@ -172,7 +173,7 @@ def __init__( grid_site: str = "Unknown", destination_site: str = "NotAssigned", pilot_references: Optional[Dict[str, str]] = None, - status_reason: str = "Unknown", + pilot_status: Optional[Union[str, "_models.PilotStatus"]] = None, **kwargs: Any ) -> None: """ @@ -188,8 +189,9 @@ def __init__( :paramtype destination_site: str :keyword pilot_references: Association of a pilot reference with a pilot stamp. :paramtype pilot_references: dict[str, str] - :keyword status_reason: Status reason of the pilots. - :paramtype status_reason: str + :keyword pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", + "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype pilot_status: str or ~_generated.models.PilotStatus """ super().__init__(**kwargs) self.pilot_stamps = pilot_stamps @@ -198,7 +200,7 @@ def __init__( self.grid_site = grid_site self.destination_site = destination_site self.pilot_references = pilot_references - self.status_reason = status_reason + self.pilot_status = pilot_status class BodyPilotsUpdatePilotFields(_serialization.Model): diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py index efc9279a3..106adbf30 100644 --- a/diracx-db/src/diracx/db/sql/pilots/db.py +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -13,6 +13,7 @@ ) from diracx.core.models import ( PilotFieldsMapping, + PilotStatus, SearchSpec, SortSpec, ) @@ -43,7 +44,7 @@ async def add_pilots( grid_site: str = "Unknown", destination_site: str = "NotAssigned", pilot_references: dict[str, str] | None = None, - status_reason: str = "Unknown", + status: str = PilotStatus.SUBMITTED, ): """Bulk add pilots in the DB. 
@@ -64,8 +65,7 @@ async def add_pilots( "DestinationSite": destination_site, "SubmissionTime": now, "LastUpdateTime": now, - "Status": "Submitted", - "StatusReason": status_reason, + "Status": status, "PilotStamp": stamp, } for stamp in pilot_stamps diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py index 359947baf..bb8d43fd2 100644 --- a/diracx-logic/src/diracx/logic/pilots/management.py +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -21,7 +21,7 @@ async def register_new_pilots( grid_type: str, grid_site: str, destination_site: str, - status_reason: str, + status: str, pilot_job_references: dict[str, str] | None, ): # [IMPORTANT] Check unicity of pilot references @@ -43,7 +43,7 @@ async def register_new_pilots( grid_site=grid_site, destination_site=destination_site, pilot_references=pilot_job_references, - status_reason=status_reason, + status=status, ) diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index d18ba7719..798ae4936 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -12,6 +12,7 @@ ) from diracx.core.models import ( PilotFieldsMapping, + PilotStatus, ) from diracx.logic.pilots.management import ( add_jobs_to_pilot as add_jobs_to_pilot_bl, @@ -57,9 +58,9 @@ async def add_pilot_stamps( dict[str, str] | None, Body(description="Association of a pilot reference with a pilot stamp."), ] = None, - status_reason: Annotated[ - str, Body(description="Status reason of the pilots.") - ] = "Unknown", + pilot_status: Annotated[ + PilotStatus, Body(description="Status of the pilots.") + ] = PilotStatus.SUBMITTED, ): """Endpoint where a you can create pilots with their references. @@ -77,7 +78,7 @@ async def add_pilot_stamps( grid_site=grid_site, destination_site=destination_site, pilot_job_references=pilot_references, - status_reason=status_reason, + status=pilot_status, ) except PilotAlreadyExistsError as e: raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index cd97b272f..8c5e9a05c 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -144,8 +144,9 @@ class BodyPilotsAddPilotStamps(_serialization.Model): :vartype destination_site: str :ivar pilot_references: Association of a pilot reference with a pilot stamp. :vartype pilot_references: dict[str, str] - :ivar status_reason: Status reason of the pilots. - :vartype status_reason: str + :ivar pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", "Running", + "Done", "Failed", "Deleted", "Aborted", and "Unknown". 
+ :vartype pilot_status: str or ~_generated.models.PilotStatus """ _validation = { @@ -160,7 +161,7 @@ class BodyPilotsAddPilotStamps(_serialization.Model): "grid_site": {"key": "grid_site", "type": "str"}, "destination_site": {"key": "destination_site", "type": "str"}, "pilot_references": {"key": "pilot_references", "type": "{str}"}, - "status_reason": {"key": "status_reason", "type": "str"}, + "pilot_status": {"key": "pilot_status", "type": "str"}, } def __init__( @@ -172,7 +173,7 @@ def __init__( grid_site: str = "Unknown", destination_site: str = "NotAssigned", pilot_references: Optional[Dict[str, str]] = None, - status_reason: str = "Unknown", + pilot_status: Optional[Union[str, "_models.PilotStatus"]] = None, **kwargs: Any ) -> None: """ @@ -188,8 +189,9 @@ def __init__( :paramtype destination_site: str :keyword pilot_references: Association of a pilot reference with a pilot stamp. :paramtype pilot_references: dict[str, str] - :keyword status_reason: Status reason of the pilots. - :paramtype status_reason: str + :keyword pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", + "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype pilot_status: str or ~_generated.models.PilotStatus """ super().__init__(**kwargs) self.pilot_stamps = pilot_stamps @@ -198,7 +200,7 @@ def __init__( self.grid_site = grid_site self.destination_site = destination_site self.pilot_references = pilot_references - self.status_reason = status_reason + self.pilot_status = pilot_status class BodyPilotsUpdatePilotFields(_serialization.Model): From 0d5428544e0d93e32501fdbfcdf8bda0396a27c8 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Fri, 4 Jul 2025 12:06:00 +0200 Subject: [PATCH 18/33] fix: Removed vo from creation, and better policy --- .../_generated/aio/operations/_operations.py | 8 ++-- .../client/_generated/models/_models.py | 8 ---- .../_generated/operations/_operations.py | 8 ++-- .../src/diracx/logic/pilots/management.py | 3 ++ diracx-logic/src/diracx/logic/pilots/query.py | 7 ++- .../diracx/routers/pilots/access_policies.py | 45 ++++++++++++++--- .../src/diracx/routers/pilots/management.py | 48 ++++++++++++------- .../src/diracx/routers/pilots/query.py | 6 +-- .../tests/pilots/test_pilot_creation.py | 15 +++--- .../_generated/aio/operations/_operations.py | 8 ++-- .../client/_generated/models/_models.py | 8 ---- .../_generated/operations/_operations.py | 8 ++-- 12 files changed, 104 insertions(+), 68 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index f11440961..81f6aecab 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -2463,7 +2463,7 @@ async def get_pilot_jobs( ) -> List[int]: """Get Pilot Jobs. - Endpoint only for DIRAC services, to get jobs of a pilot. + Endpoint only for admins, to get jobs of a pilot. :keyword pilot_stamp: The stamp of the pilot. Default value is None. :paramtype pilot_stamp: str @@ -2518,7 +2518,7 @@ async def add_jobs_to_pilot( ) -> None: """Add Jobs To Pilot. - Endpoint only for DIRAC services, to associate a pilot with a job. + Endpoint only for admins, to associate a pilot with a job. :param body: Required. :type body: ~_generated.models.BodyPilotsAddJobsToPilot @@ -2536,7 +2536,7 @@ async def add_jobs_to_pilot( ) -> None: """Add Jobs To Pilot. 
- Endpoint only for DIRAC services, to associate a pilot with a job. + Endpoint only for admins, to associate a pilot with a job. :param body: Required. :type body: IO[bytes] @@ -2552,7 +2552,7 @@ async def add_jobs_to_pilot( async def add_jobs_to_pilot(self, body: Union[_models.BodyPilotsAddJobsToPilot, IO[bytes]], **kwargs: Any) -> None: """Add Jobs To Pilot. - Endpoint only for DIRAC services, to associate a pilot with a job. + Endpoint only for admins, to associate a pilot with a job. :param body: Is either a BodyPilotsAddJobsToPilot type or a IO[bytes] type. Required. :type body: ~_generated.models.BodyPilotsAddJobsToPilot or IO[bytes] diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index 956ee2b8b..14ef7e9ec 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -134,8 +134,6 @@ class BodyPilotsAddPilotStamps(_serialization.Model): :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required. :vartype pilot_stamps: list[str] - :ivar vo: Virtual Organisation associated with the inserted pilots. Required. - :vartype vo: str :ivar grid_type: Grid type of the pilots. :vartype grid_type: str :ivar grid_site: Pilots grid site. @@ -151,12 +149,10 @@ class BodyPilotsAddPilotStamps(_serialization.Model): _validation = { "pilot_stamps": {"required": True}, - "vo": {"required": True}, } _attribute_map = { "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"}, - "vo": {"key": "vo", "type": "str"}, "grid_type": {"key": "grid_type", "type": "str"}, "grid_site": {"key": "grid_site", "type": "str"}, "destination_site": {"key": "destination_site", "type": "str"}, @@ -168,7 +164,6 @@ def __init__( self, *, pilot_stamps: List[str], - vo: str, grid_type: str = "Dirac", grid_site: str = "Unknown", destination_site: str = "NotAssigned", @@ -179,8 +174,6 @@ def __init__( """ :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required. :paramtype pilot_stamps: list[str] - :keyword vo: Virtual Organisation associated with the inserted pilots. Required. - :paramtype vo: str :keyword grid_type: Grid type of the pilots. :paramtype grid_type: str :keyword grid_site: Pilots grid site. @@ -195,7 +188,6 @@ def __init__( """ super().__init__(**kwargs) self.pilot_stamps = pilot_stamps - self.vo = vo self.grid_type = grid_type self.grid_site = grid_site self.destination_site = destination_site diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index a442124a4..2fa761eaf 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -3093,7 +3093,7 @@ def get_pilot_jobs( ) -> List[int]: """Get Pilot Jobs. - Endpoint only for DIRAC services, to get jobs of a pilot. + Endpoint only for admins, to get jobs of a pilot. :keyword pilot_stamp: The stamp of the pilot. Default value is None. :paramtype pilot_stamp: str @@ -3148,7 +3148,7 @@ def add_jobs_to_pilot( ) -> None: """Add Jobs To Pilot. - Endpoint only for DIRAC services, to associate a pilot with a job. + Endpoint only for admins, to associate a pilot with a job. :param body: Required. 
:type body: ~_generated.models.BodyPilotsAddJobsToPilot @@ -3164,7 +3164,7 @@ def add_jobs_to_pilot( def add_jobs_to_pilot(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: """Add Jobs To Pilot. - Endpoint only for DIRAC services, to associate a pilot with a job. + Endpoint only for admins, to associate a pilot with a job. :param body: Required. :type body: IO[bytes] @@ -3182,7 +3182,7 @@ def add_jobs_to_pilot( # pylint: disable=inconsistent-return-statements ) -> None: """Add Jobs To Pilot. - Endpoint only for DIRAC services, to associate a pilot with a job. + Endpoint only for admins, to associate a pilot with a job. :param body: Is either a BodyPilotsAddJobsToPilot type or a IO[bytes] type. Required. :type body: ~_generated.models.BodyPilotsAddJobsToPilot or IO[bytes] diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py index bb8d43fd2..feaad0725 100644 --- a/diracx-logic/src/diracx/logic/pilots/management.py +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -52,6 +52,7 @@ async def delete_pilots( pilot_stamps: list[str] | None = None, age_in_days: int | None = None, delete_only_aborted: bool = True, + vo_constraint: str | None = None, ): if pilot_stamps: pilot_ids = await get_pilot_ids_by_stamps( @@ -59,6 +60,7 @@ async def delete_pilots( ) else: assert age_in_days + assert vo_constraint cutoff_date = datetime.now(tz=timezone.utc) - timedelta(days=age_in_days) @@ -67,6 +69,7 @@ async def delete_pilots( cutoff_date=cutoff_date, only_aborted=delete_only_aborted, parameters=["PilotID"], + vo_constraint=vo_constraint, ) pilot_ids = [pilot["PilotID"] for pilot in pilots] diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py index 8cd57428a..7c58a95e3 100644 --- a/diracx-logic/src/diracx/logic/pilots/query.py +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -148,6 +148,7 @@ async def get_pilot_ids_by_job_id(pilot_db: PilotAgentsDB, job_id: int) -> list[ async def get_outdated_pilots( pilot_db: PilotAgentsDB, cutoff_date: datetime, + vo_constraint: str, only_aborted: bool = True, parameters: list[str] = [], ): @@ -156,7 +157,11 @@ async def get_outdated_pilots( parameter="SubmissionTime", operator=ScalarSearchOperator.LESS_THAN, value=cutoff_date, - ) + ), + # Add VO to avoid deleting other VO's pilots + ScalarSearchSpec( + parameter="VO", operator=ScalarSearchOperator.EQUAL, value=vo_constraint + ), ] if only_aborted: diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py index 2a170b342..eadafdb88 100644 --- a/diracx-routers/src/diracx/routers/pilots/access_policies.py +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -6,7 +6,9 @@ from fastapi import Depends, HTTPException, status -from diracx.core.properties import SERVICE_ADMINISTRATOR, TRUSTED_HOST +from diracx.core.properties import SERVICE_ADMINISTRATOR +from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.logic.pilots.query import get_pilots_by_stamp from diracx.routers.access_policies import BaseAccessPolicy from diracx.routers.utils.users import AuthorizedUserInfo @@ -21,7 +23,7 @@ class ActionType(StrEnum): class PilotManagementAccessPolicy(BaseAccessPolicy): """Rules: * Every user can access data about his VO - * An administrator, as well as a DIRAC service can modify a pilot. + * An administrator can modify a pilot. 
""" @staticmethod @@ -31,18 +33,47 @@ async def policy( /, *, action: ActionType | None = None, + pilot_db: PilotAgentsDB | None = None, + pilot_stamps: list[str] | None = None, ): assert action, "action is a mandatory parameter" # Users can query + # NOTE: Add into queries a VO constraint if action == ActionType.READ_PILOT_FIELDS: return - # If we want to modify pilots, we allow only admins and DIRAC - if ( - TRUSTED_HOST in user_info.properties - or SERVICE_ADMINISTRATOR in user_info.properties - ): + # If we want to modify pilots, we allow only admins + # TODO: See if we add other types of admins + if SERVICE_ADMINISTRATOR in user_info.properties: + # If we don't provide pilot_db and pilot_stamps, we accept directly + # This is for example when we submit pilots, we use the user VO, so no need to verify + if not (pilot_db and pilot_stamps): + return + + # Else, check its VO + assert pilot_db, "PilotDB is needed to determine pilot VO." + assert pilot_stamps, "PilotStamps are needed to determine pilot VO." + + pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + parameters=["VO"], + allow_missing=True, + ) + + if len(pilots) != len(pilot_stamps): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="At least one pilot does not exist.", + ) + + if not all(pilot["VO"] == user_info.vo for pilot in pilots): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have access to all pilots.", + ) + return raise HTTPException( diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 798ae4936..a5a07b183 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -3,7 +3,7 @@ from http import HTTPStatus from typing import Annotated -from fastapi import Body, HTTPException, Query, status +from fastapi import Body, Depends, HTTPException, Query, status from diracx.core.exceptions import ( PilotAlreadyAssociatedWithJobError, @@ -26,6 +26,7 @@ update_pilots_fields, ) from diracx.logic.pilots.query import get_pilot_ids_by_job_id +from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token from ..dependencies import PilotAgentsDB from ..fastapi_classes import DiracxRouter @@ -44,11 +45,8 @@ async def add_pilot_stamps( list[str], Body(description="List of the pilot stamps we want to add to the db."), ], - vo: Annotated[ - str, - Body(description="Virtual Organisation associated with the inserted pilots."), - ], check_permissions: CheckPilotManagementPolicyCallable, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], grid_type: Annotated[str, Body(description="Grid type of the pilots.")] = "Dirac", grid_site: Annotated[str, Body(description="Pilots grid site.")] = "Unknown", destination_site: Annotated[ @@ -73,7 +71,7 @@ async def add_pilot_stamps( await register_new_pilots( pilot_db=pilot_db, pilot_stamps=pilot_stamps, - vo=vo, + vo=user_info.vo, grid_type=grid_type, grid_site=grid_site, destination_site=destination_site, @@ -88,6 +86,7 @@ async def add_pilot_stamps( async def delete_pilots( pilot_db: PilotAgentsDB, check_permissions: CheckPilotManagementPolicyCallable, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], pilot_stamps: Annotated[ list[str] | None, Query(description="Stamps of the pilots we want to delete.") ] = None, @@ -121,9 +120,18 @@ async def delete_pilots( Note: If you delete 
a pilot, its logs and its associations with jobs WILL be deleted. """ - await check_permissions( - action=ActionType.MANAGE_PILOTS, - ) + vo_constraint: str | None = None + + # If we delete by pilot_stamps, we check that we can access them + # Else, we add a constraint to the request, to avoid deleting pilots from another VO + if pilot_stamps: + await check_permissions( + action=ActionType.MANAGE_PILOTS, + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + ) + else: + vo_constraint = user_info.vo if not pilot_stamps and not age_in_days: raise HTTPException( @@ -136,6 +144,7 @@ async def delete_pilots( pilot_stamps=pilot_stamps, age_in_days=age_in_days, delete_only_aborted=delete_only_aborted, + vo_constraint=vo_constraint, ) @@ -169,7 +178,7 @@ async def update_pilot_fields( Body( description="(pilot_stamp, pilot_fields) mapping to change.", embed=True, - openapi_examples=EXAMPLE_UPDATE_FIELDS, + openapi_examples=EXAMPLE_UPDATE_FIELDS, # type: ignore ), ], pilot_db: PilotAgentsDB, @@ -180,7 +189,10 @@ async def update_pilot_fields( Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. """ # Ensures stamps validity - await check_permissions(action=ActionType.MANAGE_PILOTS) + pilot_stamps = [mapping.PilotStamp for mapping in pilot_stamps_to_fields_mapping] + await check_permissions( + action=ActionType.MANAGE_PILOTS, pilot_db=pilot_db, pilot_stamps=pilot_stamps + ) await update_pilots_fields( pilot_db=pilot_db, @@ -197,15 +209,16 @@ async def get_pilot_jobs( ] = None, job_id: Annotated[int | None, Query(description="The ID of the job.")] = None, ) -> list[int]: - """Endpoint only for DIRAC services, to get jobs of a pilot.""" - await check_permissions(action=ActionType.READ_PILOT_FIELDS) - + """Endpoint only for admins, to get jobs of a pilot.""" if pilot_stamp: + await check_permissions(action=ActionType.READ_PILOT_FIELDS) + return await get_pilot_jobs_ids_by_stamp( pilot_db=pilot_db, pilot_stamp=pilot_stamp, ) elif job_id: + # FIXME: Add some policy, verify that it's the user's job? 
return await get_pilot_ids_by_job_id(pilot_db=pilot_db, job_id=job_id) raise HTTPException( @@ -223,8 +236,11 @@ async def add_jobs_to_pilot( ], check_permissions: CheckPilotManagementPolicyCallable, ): - """Endpoint only for DIRAC services, to associate a pilot with a job.""" - await check_permissions(action=ActionType.MANAGE_PILOTS) + """Endpoint only for admins, to associate a pilot with a job.""" + await check_permissions( + action=ActionType.MANAGE_PILOTS, pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + ) + # FIXME: Also verify job_ids try: await add_jobs_to_pilot_bl( diff --git a/diracx-routers/src/diracx/routers/pilots/query.py b/diracx-routers/src/diracx/routers/pilots/query.py index d6867fcbd..9c7236af3 100644 --- a/diracx-routers/src/diracx/routers/pilots/query.py +++ b/diracx-routers/src/diracx/routers/pilots/query.py @@ -113,18 +113,16 @@ async def search( page: int = 1, per_page: int = 100, body: Annotated[ - SearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES) + SearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES) # type: ignore ] = None, ) -> list[dict[str, Any]]: """Retrieve information about pilots.""" # Inspired by /api/jobs/query await check_permissions(action=ActionType.READ_PILOT_FIELDS) - user_vo = user_info.vo - total, pilots = await search_bl( pilot_db=pilot_db, - user_vo=user_vo, + user_vo=user_info.vo, page=page, per_page=per_page, body=body, diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py index 0280f77ea..ee78ff65f 100644 --- a/diracx-routers/tests/pilots/test_pilot_creation.py +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -43,7 +43,7 @@ async def test_create_pilots(normal_test_client): pilot_stamps = [f"stamps_{i}" for i in range(N)] # -------------- Bulk insert -------------- - body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} + body = {"pilot_stamps": pilot_stamps} r = normal_test_client.post( "/api/pilots/", @@ -55,7 +55,6 @@ async def test_create_pilots(normal_test_client): # -------------- Register a pilot that already exists, and one that does not -------------- body = { - "vo": MAIN_VO, "pilot_stamps": [pilot_stamps[0], pilot_stamps[0] + "_new_one"], } @@ -76,7 +75,7 @@ async def test_create_pilots(normal_test_client): # -------------- Register a pilot that does not exists **but** was called before in an error -------------- # To prove that, if I tried to register a pilot that does not exist with one that already exists, # i can normally add the one that did not exist before (it should not have added it before) - body = {"vo": MAIN_VO, "pilot_stamps": [pilot_stamps[0] + "_new_one"]} + body = {"pilot_stamps": [pilot_stamps[0] + "_new_one"]} r = normal_test_client.post( "/api/pilots/", @@ -93,7 +92,7 @@ async def test_create_pilot_and_delete_it(normal_test_client): pilot_stamp = "stamps_1" # -------------- Insert -------------- - body = {"vo": MAIN_VO, "pilot_stamps": [pilot_stamp]} + body = {"pilot_stamps": [pilot_stamp]} # Create a pilot r = normal_test_client.post( @@ -137,7 +136,7 @@ async def test_create_pilot_and_modify_it(normal_test_client): pilot_stamps = ["stamps_1", "stamp_2"] # -------------- Insert -------------- - body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} + body = {"pilot_stamps": pilot_stamps} # Create pilots r = normal_test_client.post( @@ -192,7 +191,7 @@ async def test_associate_job_with_pilot_and_get_it(normal_test_client: TestClien pilot_stamps = ["stamps_1", "stamp_2"] # -------------- Insert -------------- - body = {"vo": 
MAIN_VO, "pilot_stamps": pilot_stamps} + body = {"pilot_stamps": pilot_stamps} # Create pilots r = normal_test_client.post( @@ -294,7 +293,7 @@ async def test_delete_pilots_by_age_and_stamp(normal_test_client): pilot_stamps = [f"stamp_{i}" for i in range(100)] # -------------- Insert all pilots -------------- - body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} + body = {"pilot_stamps": pilot_stamps} r = normal_test_client.post("/api/pilots/", json=body) assert r.status_code == 200, r.json() @@ -390,7 +389,7 @@ async def test_delete_pilots_by_age_and_stamp(normal_test_client): async def test_associate_two_pilots_share_jobs_and_delete_first(normal_test_client): # 1) Create two pilots pilot_stamps = ["stamp_1", "stamp_2"] - body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} + body = {"pilot_stamps": pilot_stamps} r = normal_test_client.post("/api/pilots/", json=body) assert r.status_code == 200, r.json() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index c22bd73db..7c17cde07 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -2630,7 +2630,7 @@ async def get_pilot_jobs( ) -> List[int]: """Get Pilot Jobs. - Endpoint only for DIRAC services, to get jobs of a pilot. + Endpoint only for admins, to get jobs of a pilot. :keyword pilot_stamp: The stamp of the pilot. Default value is None. :paramtype pilot_stamp: str @@ -2685,7 +2685,7 @@ async def add_jobs_to_pilot( ) -> None: """Add Jobs To Pilot. - Endpoint only for DIRAC services, to associate a pilot with a job. + Endpoint only for admins, to associate a pilot with a job. :param body: Required. :type body: ~_generated.models.BodyPilotsAddJobsToPilot @@ -2703,7 +2703,7 @@ async def add_jobs_to_pilot( ) -> None: """Add Jobs To Pilot. - Endpoint only for DIRAC services, to associate a pilot with a job. + Endpoint only for admins, to associate a pilot with a job. :param body: Required. :type body: IO[bytes] @@ -2719,7 +2719,7 @@ async def add_jobs_to_pilot( async def add_jobs_to_pilot(self, body: Union[_models.BodyPilotsAddJobsToPilot, IO[bytes]], **kwargs: Any) -> None: """Add Jobs To Pilot. - Endpoint only for DIRAC services, to associate a pilot with a job. + Endpoint only for admins, to associate a pilot with a job. :param body: Is either a BodyPilotsAddJobsToPilot type or a IO[bytes] type. Required. :type body: ~_generated.models.BodyPilotsAddJobsToPilot or IO[bytes] diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 8c5e9a05c..4cdd81cfc 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -134,8 +134,6 @@ class BodyPilotsAddPilotStamps(_serialization.Model): :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required. :vartype pilot_stamps: list[str] - :ivar vo: Virtual Organisation associated with the inserted pilots. Required. - :vartype vo: str :ivar grid_type: Grid type of the pilots. :vartype grid_type: str :ivar grid_site: Pilots grid site. 
@@ -151,12 +149,10 @@ class BodyPilotsAddPilotStamps(_serialization.Model): _validation = { "pilot_stamps": {"required": True}, - "vo": {"required": True}, } _attribute_map = { "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"}, - "vo": {"key": "vo", "type": "str"}, "grid_type": {"key": "grid_type", "type": "str"}, "grid_site": {"key": "grid_site", "type": "str"}, "destination_site": {"key": "destination_site", "type": "str"}, @@ -168,7 +164,6 @@ def __init__( self, *, pilot_stamps: List[str], - vo: str, grid_type: str = "Dirac", grid_site: str = "Unknown", destination_site: str = "NotAssigned", @@ -179,8 +174,6 @@ def __init__( """ :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required. :paramtype pilot_stamps: list[str] - :keyword vo: Virtual Organisation associated with the inserted pilots. Required. - :paramtype vo: str :keyword grid_type: Grid type of the pilots. :paramtype grid_type: str :keyword grid_site: Pilots grid site. @@ -195,7 +188,6 @@ def __init__( """ super().__init__(**kwargs) self.pilot_stamps = pilot_stamps - self.vo = vo self.grid_type = grid_type self.grid_site = grid_site self.destination_site = destination_site diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index 7db833022..c2af06038 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -3306,7 +3306,7 @@ def get_pilot_jobs( ) -> List[int]: """Get Pilot Jobs. - Endpoint only for DIRAC services, to get jobs of a pilot. + Endpoint only for admins, to get jobs of a pilot. :keyword pilot_stamp: The stamp of the pilot. Default value is None. :paramtype pilot_stamp: str @@ -3361,7 +3361,7 @@ def add_jobs_to_pilot( ) -> None: """Add Jobs To Pilot. - Endpoint only for DIRAC services, to associate a pilot with a job. + Endpoint only for admins, to associate a pilot with a job. :param body: Required. :type body: ~_generated.models.BodyPilotsAddJobsToPilot @@ -3377,7 +3377,7 @@ def add_jobs_to_pilot( def add_jobs_to_pilot(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: """Add Jobs To Pilot. - Endpoint only for DIRAC services, to associate a pilot with a job. + Endpoint only for admins, to associate a pilot with a job. :param body: Required. :type body: IO[bytes] @@ -3395,7 +3395,7 @@ def add_jobs_to_pilot( # pylint: disable=inconsistent-return-statements ) -> None: """Add Jobs To Pilot. - Endpoint only for DIRAC services, to associate a pilot with a job. + Endpoint only for admins, to associate a pilot with a job. :param body: Is either a BodyPilotsAddJobsToPilot type or a IO[bytes] type. Required. 
:type body: ~_generated.models.BodyPilotsAddJobsToPilot or IO[bytes] From 85d3c4ade046a68fae37063b9027ca2d411c461a Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Fri, 4 Jul 2025 12:35:31 +0200 Subject: [PATCH 19/33] feat: Better policy, to support also jobs --- .../diracx/routers/pilots/access_policies.py | 56 ++++++++++++------- .../src/diracx/routers/pilots/management.py | 24 ++++++-- 2 files changed, 56 insertions(+), 24 deletions(-) diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py index eadafdb88..f6227db7f 100644 --- a/diracx-routers/src/diracx/routers/pilots/access_policies.py +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -7,6 +7,7 @@ from fastapi import Depends, HTTPException, status from diracx.core.properties import SERVICE_ADMINISTRATOR +from diracx.db.sql.job.db import JobDB from diracx.db.sql.pilots.db import PilotAgentsDB from diracx.logic.pilots.query import get_pilots_by_stamp from diracx.routers.access_policies import BaseAccessPolicy @@ -35,26 +36,50 @@ async def policy( action: ActionType | None = None, pilot_db: PilotAgentsDB | None = None, pilot_stamps: list[str] | None = None, + job_db: JobDB | None = None, + job_ids: list[int] | None = None, ): assert action, "action is a mandatory parameter" # Users can query # NOTE: Add into queries a VO constraint - if action == ActionType.READ_PILOT_FIELDS: - return + # To manage pilots, user have to be an admin + if ( + action == ActionType.MANAGE_PILOTS + and SERVICE_ADMINISTRATOR not in user_info.properties + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the permission to manage pilots.", + ) - # If we want to modify pilots, we allow only admins - # TODO: See if we add other types of admins - if SERVICE_ADMINISTRATOR in user_info.properties: - # If we don't provide pilot_db and pilot_stamps, we accept directly - # This is for example when we submit pilots, we use the user VO, so no need to verify - if not (pilot_db and pilot_stamps): - return + # + # Additional checks if job_ids or pilot_stamps are provided + # - # Else, check its VO - assert pilot_db, "PilotDB is needed to determine pilot VO." - assert pilot_stamps, "PilotStamps are needed to determine pilot VO." 
+ # First, if job_ids are provided, we check who is the owner + if job_db and job_ids: + job_owners = await job_db.summary( + ["Owner", "VO"], + [{"parameter": "JobID", "operator": "in", "values": job_ids}], + ) + + expected_owner = { + "Owner": user_info.preferred_username, + "VO": user_info.vo, + "count": len(set(job_ids)), + } + # All the jobs belong to the user doing the query + # and all of them are present + if not job_owners == [expected_owner]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the rights to modify a pilot.", + ) + # This is for example when we submit pilots, we use the user VO, so no need to verify + if pilot_db and pilot_stamps: + # Else, check its VO pilots = await get_pilots_by_stamp( pilot_db=pilot_db, pilot_stamps=pilot_stamps, @@ -74,13 +99,6 @@ async def policy( detail="You don't have access to all pilots.", ) - return - - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="You don't have the rights to modify a pilot.", - ) - CheckPilotManagementPolicyCallable = Annotated[ Callable, Depends(PilotManagementAccessPolicy.check) diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index a5a07b183..8f783352e 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -28,7 +28,7 @@ from diracx.logic.pilots.query import get_pilot_ids_by_job_id from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token -from ..dependencies import PilotAgentsDB +from ..dependencies import JobDB, PilotAgentsDB from ..fastapi_classes import DiracxRouter from .access_policies import ( ActionType, @@ -203,6 +203,7 @@ async def update_pilot_fields( @router.get("/jobs") async def get_pilot_jobs( pilot_db: PilotAgentsDB, + job_db: JobDB, check_permissions: CheckPilotManagementPolicyCallable, pilot_stamp: Annotated[ str | None, Query(description="The stamp of the pilot.") @@ -211,14 +212,23 @@ async def get_pilot_jobs( ) -> list[int]: """Endpoint only for admins, to get jobs of a pilot.""" if pilot_stamp: - await check_permissions(action=ActionType.READ_PILOT_FIELDS) + # Check VO + await check_permissions( + action=ActionType.READ_PILOT_FIELDS, + pilot_db=pilot_db, + pilot_stamps=[pilot_stamp], + ) return await get_pilot_jobs_ids_by_stamp( pilot_db=pilot_db, pilot_stamp=pilot_stamp, ) elif job_id: - # FIXME: Add some policy, verify that it's the user's job? 
+ # Check job owner + await check_permissions( + action=ActionType.READ_PILOT_FIELDS, job_db=job_db, job_ids=[job_id] + ) + return await get_pilot_ids_by_job_id(pilot_db=pilot_db, job_id=job_id) raise HTTPException( @@ -230,6 +240,7 @@ async def get_pilot_jobs( @router.patch("/jobs", status_code=HTTPStatus.NO_CONTENT) async def add_jobs_to_pilot( pilot_db: PilotAgentsDB, + job_db: JobDB, pilot_stamp: Annotated[str, Body(description="The stamp of the pilot.")], job_ids: Annotated[ list[int], Body(description="The jobs we want to add to the pilot.") @@ -238,9 +249,12 @@ async def add_jobs_to_pilot( ): """Endpoint only for admins, to associate a pilot with a job.""" await check_permissions( - action=ActionType.MANAGE_PILOTS, pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + action=ActionType.MANAGE_PILOTS, + pilot_db=pilot_db, + pilot_stamps=[pilot_stamp], + job_db=job_db, + job_ids=job_ids, ) - # FIXME: Also verify job_ids try: await add_jobs_to_pilot_bl( From fb7f4d9165c7f40698d32e63510b105dcd5309f4 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Fri, 4 Jul 2025 13:27:31 +0200 Subject: [PATCH 20/33] feat: Add pilot summary, and refactor jobs one --- .../_generated/aio/operations/_operations.py | 108 +++++++++++++++- .../client/_generated/models/__init__.py | 8 +- .../client/_generated/models/_models.py | 76 +++++------ .../_generated/operations/_operations.py | 122 +++++++++++++++++- diracx-core/src/diracx/core/models.py | 2 +- diracx-db/src/diracx/db/sql/dummy/db.py | 23 ++-- diracx-db/src/diracx/db/sql/job/db.py | 21 +-- diracx-db/src/diracx/db/sql/pilots/db.py | 6 + diracx-db/src/diracx/db/sql/utils/base.py | 24 ++++ diracx-db/tests/test_dummy_db.py | 35 ++--- diracx-logic/src/diracx/logic/jobs/query.py | 6 +- diracx-logic/src/diracx/logic/jobs/status.py | 10 +- diracx-logic/src/diracx/logic/pilots/query.py | 13 ++ .../diracx/routers/jobs/access_policies.py | 9 +- .../src/diracx/routers/jobs/query.py | 4 +- .../diracx/routers/pilots/access_policies.py | 11 +- .../src/diracx/routers/pilots/query.py | 21 ++- .../tests/jobs/test_wms_access_policy.py | 10 +- .../tests/pilots/test_pilot_creation.py | 1 + .../_generated/aio/operations/_operations.py | 108 +++++++++++++++- .../client/_generated/models/__init__.py | 8 +- .../client/_generated/models/_models.py | 76 +++++------ .../_generated/operations/_operations.py | 122 +++++++++++++++++- .../src/gubbins/db/sql/lollygag/db.py | 2 +- .../gubbins-db/tests/test_lollygag_db.py | 8 +- 25 files changed, 656 insertions(+), 178 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index 81f6aecab..b02f5f037 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -57,6 +57,7 @@ build_pilots_delete_pilots_request, build_pilots_get_pilot_jobs_request, build_pilots_search_request, + build_pilots_summary_request, build_pilots_update_pilot_fields_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, @@ -1974,14 +1975,14 @@ async def search( @overload async def summary( - self, body: _models.JobSummaryParams, *, content_type: str = "application/json", **kwargs: Any + self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any ) -> Any: """Summary. Show information suitable for plotting. :param body: Required. 
- :type body: ~_generated.models.JobSummaryParams + :type body: ~_generated.models.SummaryParams :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str @@ -2007,13 +2008,13 @@ async def summary(self, body: IO[bytes], *, content_type: str = "application/jso """ @distributed_trace_async - async def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwargs: Any) -> Any: + async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: """Summary. Show information suitable for plotting. - :param body: Is either a JobSummaryParams type or a IO[bytes] type. Required. - :type body: ~_generated.models.JobSummaryParams or IO[bytes] + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] :return: any :rtype: any :raises ~azure.core.exceptions.HttpResponseError: @@ -2038,7 +2039,7 @@ async def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwar if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "JobSummaryParams") + _json = self._serialize.body(body, "SummaryParams") _request = build_jobs_summary_request( content_type=content_type, @@ -2741,3 +2742,98 @@ async def search( return cls(pipeline_response, deserialized, response_headers) # type: ignore return deserialized # type: ignore + + @overload + async def summary( + self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/diracx-client/src/diracx/client/_generated/models/__init__.py b/diracx-client/src/diracx/client/_generated/models/__init__.py index 5a6c6e047..d9f48e28a 100644 --- a/diracx-client/src/diracx/client/_generated/models/__init__.py +++ b/diracx-client/src/diracx/client/_generated/models/__init__.py @@ -24,8 +24,6 @@ InsertedJob, JobCommand, JobStatusUpdate, - JobSummaryParams, - JobSummaryParamsSearchItem, Metadata, OpenIDConfiguration, PilotFieldsMapping, @@ -39,6 +37,8 @@ SetJobStatusReturn, SetJobStatusReturnSuccess, SortSpec, + SummaryParams, + SummaryParamsSearchItem, SupportInfo, TokenResponse, UserInfoResponse, @@ -76,8 +76,6 @@ "InsertedJob", "JobCommand", "JobStatusUpdate", - "JobSummaryParams", - "JobSummaryParamsSearchItem", "Metadata", "OpenIDConfiguration", "PilotFieldsMapping", @@ -91,6 +89,8 @@ "SetJobStatusReturn", "SetJobStatusReturnSuccess", "SortSpec", + "SummaryParams", + "SummaryParamsSearchItem", "SupportInfo", "TokenResponse", "UserInfoResponse", diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index 14ef7e9ec..a8457473c 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -536,44 +536,6 @@ def __init__( self.source = source -class JobSummaryParams(_serialization.Model): - """JobSummaryParams. - - All required parameters must be populated in order to send to server. - - :ivar grouping: Grouping. Required. - :vartype grouping: list[str] - :ivar search: Search. 
- :vartype search: list[~_generated.models.JobSummaryParamsSearchItem] - """ - - _validation = { - "grouping": {"required": True}, - } - - _attribute_map = { - "grouping": {"key": "grouping", "type": "[str]"}, - "search": {"key": "search", "type": "[JobSummaryParamsSearchItem]"}, - } - - def __init__( - self, *, grouping: List[str], search: List["_models.JobSummaryParamsSearchItem"] = [], **kwargs: Any - ) -> None: - """ - :keyword grouping: Grouping. Required. - :paramtype grouping: list[str] - :keyword search: Search. - :paramtype search: list[~_generated.models.JobSummaryParamsSearchItem] - """ - super().__init__(**kwargs) - self.grouping = grouping - self.search = search - - -class JobSummaryParamsSearchItem(_serialization.Model): - """JobSummaryParamsSearchItem.""" - - class Metadata(_serialization.Model): """Metadata. @@ -1203,6 +1165,44 @@ def __init__(self, *, parameter: str, direction: Union[str, "_models.SortDirecti self.direction = direction +class SummaryParams(_serialization.Model): + """SummaryParams. + + All required parameters must be populated in order to send to server. + + :ivar grouping: Grouping. Required. + :vartype grouping: list[str] + :ivar search: Search. + :vartype search: list[~_generated.models.SummaryParamsSearchItem] + """ + + _validation = { + "grouping": {"required": True}, + } + + _attribute_map = { + "grouping": {"key": "grouping", "type": "[str]"}, + "search": {"key": "search", "type": "[SummaryParamsSearchItem]"}, + } + + def __init__( + self, *, grouping: List[str], search: List["_models.SummaryParamsSearchItem"] = [], **kwargs: Any + ) -> None: + """ + :keyword grouping: Grouping. Required. + :paramtype grouping: list[str] + :keyword search: Search. + :paramtype search: list[~_generated.models.SummaryParamsSearchItem] + """ + super().__init__(**kwargs) + self.grouping = grouping + self.search = search + + +class SummaryParamsSearchItem(_serialization.Model): + """SummaryParamsSearchItem.""" + + class SupportInfo(_serialization.Model): """SupportInfo. diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index 2fa761eaf..f317d9353 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -705,6 +705,23 @@ def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) +def build_pilots_summary_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/summary" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -2607,13 +2624,13 @@ def search( return deserialized # type: ignore @overload - def summary(self, body: _models.JobSummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: + def summary(self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: """Summary. 
Show information suitable for plotting. :param body: Required. - :type body: ~_generated.models.JobSummaryParams + :type body: ~_generated.models.SummaryParams :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str @@ -2639,13 +2656,13 @@ def summary(self, body: IO[bytes], *, content_type: str = "application/json", ** """ @distributed_trace - def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwargs: Any) -> Any: + def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: """Summary. Show information suitable for plotting. - :param body: Is either a JobSummaryParams type or a IO[bytes] type. Required. - :type body: ~_generated.models.JobSummaryParams or IO[bytes] + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] :return: any :rtype: any :raises ~azure.core.exceptions.HttpResponseError: @@ -2670,7 +2687,7 @@ def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwargs: An if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "JobSummaryParams") + _json = self._serialize.body(body, "SummaryParams") _request = build_jobs_summary_request( content_type=content_type, @@ -3371,3 +3388,96 @@ def search( return cls(pipeline_response, deserialized, response_headers) # type: ignore return deserialized # type: ignore + + @overload + def summary(self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/diracx-core/src/diracx/core/models.py b/diracx-core/src/diracx/core/models.py index fb14e2738..773ebac11 100644 --- a/diracx-core/src/diracx/core/models.py +++ b/diracx-core/src/diracx/core/models.py @@ -59,7 +59,7 @@ class InsertedJob(TypedDict): TimeStamp: datetime -class JobSummaryParams(BaseModel): +class SummaryParams(BaseModel): grouping: list[str] search: list[SearchSpec] = [] # TODO: Add more validation diff --git a/diracx-db/src/diracx/db/sql/dummy/db.py b/diracx-db/src/diracx/db/sql/dummy/db.py index 0c25df43c..b68bbe64d 100644 --- a/diracx-db/src/diracx/db/sql/dummy/db.py +++ b/diracx-db/src/diracx/db/sql/dummy/db.py @@ -1,9 +1,10 @@ from __future__ import annotations -from sqlalchemy import func, insert, select +from sqlalchemy import insert from uuid_utils import UUID -from diracx.db.sql.utils import BaseSQLDB, apply_search_filters +from diracx.core.models import SearchSpec +from diracx.db.sql.utils import BaseSQLDB from .schema import Base as DummyDBBase from .schema import Cars, Owners @@ -21,19 +22,11 @@ class DummyDB(BaseSQLDB): # This needs to be here for the BaseSQLDB to create the engine metadata = DummyDBBase.metadata - async def summary(self, group_by, search) -> list[dict[str, str | int]]: - columns = [Cars.__table__.columns[x] for x in group_by] - - stmt = select(*columns, func.count(Cars.license_plate).label("count")) - stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search) - stmt = stmt.group_by(*columns) - - # Execute the query - return [ - dict(row._mapping) - async for row in (await self.conn.stream(stmt)) - if row.count > 0 # type: ignore - ] + async def dummy_summary( + self, group_by: list[str], search: list[SearchSpec] + ) -> list[dict[str, str | int]]: + """Get a summary of the pilots.""" + return await self.summary(Cars, group_by=group_by, search=search) async def insert_owner(self, name: str) -> int: stmt = insert(Owners).values(name=name) 
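The refactor above turns the per-database `summary` implementations into thin wrappers around a shared `BaseSQLDB.summary` helper (added to `diracx-db/src/diracx/db/sql/utils/base.py` further down in this patch). A minimal usage sketch of such a wrapper, illustrative only and not part of the patch, reusing the `DummyDB` names from this diff and the dict-style `SearchSpec` filters already used in the tests:

from diracx.db.sql.dummy.db import DummyDB

async def show_owner_counts(dummy_db: DummyDB) -> None:
    # The DB classes are async context managers; entering one opens a transaction.
    async with dummy_db as db:
        # Group cars by OwnerID, keeping only rows whose Model is "model_1".
        rows = await db.dummy_summary(
            ["OwnerID"],
            [{"parameter": "Model", "operator": "eq", "value": "model_1"}],
        )
        # Each entry looks like {"OwnerID": <id>, "count": <n>}; groups with a
        # zero count are dropped by the shared helper.
        print(rows)

The `JobDB.job_summary` and `PilotAgentsDB.pilot_summary` wrappers introduced below follow the same pattern; only the table passed to the shared helper changes.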
diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 682475b9b..8e563fb0e 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -5,7 +5,7 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Iterable -from sqlalchemy import bindparam, case, delete, func, insert, select, update +from sqlalchemy import bindparam, case, delete, insert, select, update if TYPE_CHECKING: from sqlalchemy.sql.elements import BindParameter @@ -13,7 +13,7 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.models import JobCommand, SearchSpec, SortSpec -from ..utils import BaseSQLDB, _get_columns, apply_search_filters, utcnow +from ..utils import BaseSQLDB, _get_columns, utcnow from .schema import ( HeartBeatLoggingInfo, InputData, @@ -42,20 +42,11 @@ class JobDB(BaseSQLDB): # to find a way to make it dynamic jdl_2_db_parameters = ["JobName", "JobType", "JobGroup"] - async def summary(self, group_by, search) -> list[dict[str, str | int]]: + async def job_summary( + self, group_by: list[str], search: list[SearchSpec] + ) -> list[dict[str, str | int]]: """Get a summary of the jobs.""" - columns = _get_columns(Jobs.__table__, group_by) - - stmt = select(*columns, func.count(Jobs.job_id).label("count")) - stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search) - stmt = stmt.group_by(*columns) - - # Execute the query - return [ - dict(row._mapping) - async for row in (await self.conn.stream(stmt)) - if row.count > 0 # type: ignore - ] + return await self.summary(Jobs, group_by=group_by, search=search) async def search_jobs( self, diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py index 106adbf30..86611553c 100644 --- a/diracx-db/src/diracx/db/sql/pilots/db.py +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -240,3 +240,9 @@ async def search_pilot_to_job_mapping( per_page=per_page, page=page, ) + + async def pilot_summary( + self, group_by: list[str], search: list[SearchSpec] + ) -> list[dict[str, str | int]]: + """Get a summary of the pilots.""" + return await self.summary(PilotAgents, group_by=group_by, search=search) diff --git a/diracx-db/src/diracx/db/sql/utils/base.py b/diracx-db/src/diracx/db/sql/utils/base.py index 143c0675c..cdf753208 100644 --- a/diracx-db/src/diracx/db/sql/utils/base.py +++ b/diracx-db/src/diracx/db/sql/utils/base.py @@ -268,6 +268,30 @@ async def search( dict(row._mapping) async for row in (await self.conn.stream(stmt)) ] + async def summary( + self, model: Any, group_by: list[str], search: list[SearchSpec] + ) -> list[dict[str, str | int]]: + """Get a summary of a table.""" + columns = _get_columns(model.__table__, group_by) + + pk_columns = list(model.__table__.primary_key.columns) + if not pk_columns: + raise ValueError( + "Model has no primary key and no count_column was provided." 
+ ) + count_col = pk_columns[0] + + stmt = select(*columns, func.count(count_col).label("count")) + stmt = apply_search_filters(model.__table__.columns.__getitem__, stmt, search) + stmt = stmt.group_by(*columns) + + # Execute the query + return [ + dict(row._mapping) + async for row in (await self.conn.stream(stmt)) + if row.count > 0 # type: ignore + ] + def find_time_resolution(value): if isinstance(value, datetime): diff --git a/diracx-db/tests/test_dummy_db.py b/diracx-db/tests/test_dummy_db.py index e0106d833..9c10d9be2 100644 --- a/diracx-db/tests/test_dummy_db.py +++ b/diracx-db/tests/test_dummy_db.py @@ -27,7 +27,7 @@ async def test_insert_and_summary(dummy_db: DummyDB): # So it is important to write test this way async with dummy_db as dummy_db: # First we check that the DB is empty - result = await dummy_db.summary(["Model"], []) + result = await dummy_db.dummy_summary(["Model"], []) assert not result # Now we add some data in the DB @@ -44,13 +44,13 @@ async def test_insert_and_summary(dummy_db: DummyDB): # Check that there are now 10 cars assigned to a single driver async with dummy_db as dummy_db: - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert result[0]["count"] == 10 # Test the selection async with dummy_db as dummy_db: - result = await dummy_db.summary( + result = await dummy_db.dummy_summary( ["OwnerID"], [{"parameter": "Model", "operator": "eq", "value": "model_1"}] ) @@ -58,7 +58,7 @@ async def test_insert_and_summary(dummy_db: DummyDB): async with dummy_db as dummy_db: with pytest.raises(InvalidQueryError): - result = await dummy_db.summary( + result = await dummy_db.dummy_summary( ["OwnerID"], [ { @@ -93,7 +93,7 @@ async def test_successful_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert not result # Add data @@ -104,7 +104,7 @@ async def test_successful_transaction(dummy_db): ) assert result - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert result[0]["count"] == 10 # The connection is closed when the context manager is exited @@ -114,7 +114,7 @@ async def test_successful_transaction(dummy_db): # Start a new transaction # The previous data should still be there because the transaction was committed (successful) async with dummy_db as dummy_db: - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert result[0]["count"] == 10 @@ -129,12 +129,12 @@ async def test_failed_transaction(dummy_db): # The connection is created when the context manager is entered # This is our transaction - with pytest.raises(KeyError): + with pytest.raises(InvalidQueryError): async with dummy_db as dummy_db: assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert not result # Add data @@ -149,7 +149,8 @@ async def test_failed_transaction(dummy_db): assert result # This will raise an exception and the transaction will be rolled back - result = await dummy_db.summary(["unexistingfieldraisinganerror"], []) + + result = await dummy_db.dummy_summary(["unexistingfieldraisinganerror"], []) assert result[0]["count"] == 10 # The connection is closed when the context manager is exited @@ -159,7 +160,7 @@ async def test_failed_transaction(dummy_db): # Start 
a new transaction # The previous data should not be there because the transaction was rolled back (failed) async with dummy_db as dummy_db: - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert not result @@ -203,7 +204,7 @@ async def test_successful_with_exception_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert not result # Add data @@ -217,7 +218,7 @@ async def test_successful_with_exception_transaction(dummy_db): ) assert result - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert result[0]["count"] == 10 # This will raise an exception but the transaction will be rolled back @@ -231,7 +232,7 @@ async def test_successful_with_exception_transaction(dummy_db): # Start a new transaction # The previous data should not be there because the transaction was rolled back (failed) async with dummy_db as dummy_db: - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert not result # Start a new transaction, this time we commit it manually @@ -240,7 +241,7 @@ async def test_successful_with_exception_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert not result # Add data @@ -254,7 +255,7 @@ async def test_successful_with_exception_transaction(dummy_db): ) assert result - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert result[0]["count"] == 10 # Manually commit the transaction, and then raise an exception @@ -271,5 +272,5 @@ async def test_successful_with_exception_transaction(dummy_db): # Start a new transaction # The previous data should be there because the transaction was committed before the exception async with dummy_db as dummy_db: - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert result[0]["count"] == 10 diff --git a/diracx-logic/src/diracx/logic/jobs/query.py b/diracx-logic/src/diracx/logic/jobs/query.py index 0ec9738cf..23fb4557e 100644 --- a/diracx-logic/src/diracx/logic/jobs/query.py +++ b/diracx-logic/src/diracx/logic/jobs/query.py @@ -5,9 +5,9 @@ from diracx.core.config.schema import Config from diracx.core.models import ( - JobSummaryParams, ScalarSearchOperator, SearchParams, + SummaryParams, ) from diracx.db.os.job_parameters import JobParametersDB from diracx.db.sql.job.db import JobDB @@ -85,7 +85,7 @@ async def summary( config: Config, job_db: JobDB, preferred_username: str, - body: JobSummaryParams, + body: SummaryParams, ): """Show information suitable for plotting.""" if not config.Operations["Defaults"].Services.JobMonitoring.GlobalJobsInfo: @@ -100,4 +100,4 @@ async def summary( "value": preferred_username, } ) - return await job_db.summary(body.grouping, body.search) + return await job_db.job_summary(body.grouping, body.search) diff --git a/diracx-logic/src/diracx/logic/jobs/status.py b/diracx-logic/src/diracx/logic/jobs/status.py index 90d62f819..2f1138cac 100644 --- a/diracx-logic/src/diracx/logic/jobs/status.py +++ b/diracx-logic/src/diracx/logic/jobs/status.py @@ -623,9 +623,15 @@ async def _insert_parameters( if not updates: return # Get the VOs for the job IDs (required for 
the index template) - job_vos = await job_db.summary( + job_vos = await job_db.job_summary( ["JobID", "VO"], - [{"parameter": "JobID", "operator": "in", "values": list(updates)}], + [ + VectorSearchSpec( + parameter="JobID", + operator=VectorSearchOperator.IN, + values=list(updates), + ) + ], ) job_id_to_vo = {int(x["JobID"]): str(x["VO"]) for x in job_vos} # Upsert the parameters into the JobParametersDB diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py index 7c58a95e3..4dcdc9861 100644 --- a/diracx-logic/src/diracx/logic/pilots/query.py +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -10,6 +10,7 @@ ScalarSearchSpec, SearchParams, SearchSpec, + SummaryParams, VectorSearchOperator, VectorSearchSpec, ) @@ -178,3 +179,15 @@ async def get_outdated_pilots( ) return pilots + + +async def summary(pilot_db: PilotAgentsDB, body: SummaryParams, vo: str): + """Show information suitable for plotting.""" + body.search.append( + { + "parameter": "VO", + "operator": ScalarSearchOperator.EQUAL, + "value": vo, + } + ) + return await pilot_db.pilot_summary(body.grouping, body.search) diff --git a/diracx-routers/src/diracx/routers/jobs/access_policies.py b/diracx-routers/src/diracx/routers/jobs/access_policies.py index 1fd5a63ae..2239e4764 100644 --- a/diracx-routers/src/diracx/routers/jobs/access_policies.py +++ b/diracx-routers/src/diracx/routers/jobs/access_policies.py @@ -6,6 +6,7 @@ from fastapi import Depends, HTTPException, status +from diracx.core.models import VectorSearchOperator, VectorSearchSpec from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER from diracx.db.sql import JobDB, SandboxMetadataDB from diracx.routers.access_policies import BaseAccessPolicy @@ -82,9 +83,13 @@ async def policy( # Now we know we are either in READ/MODIFY for a NORMAL_USER # so just make sure that whatever job_id was given belongs # to the current user - job_owners = await job_db.summary( + job_owners = await job_db.job_summary( ["Owner", "VO"], - [{"parameter": "JobID", "operator": "in", "values": job_ids}], + [ + VectorSearchSpec( + parameter="JobID", operator=VectorSearchOperator.IN, values=job_ids + ) + ], ) expected_owner = { diff --git a/diracx-routers/src/diracx/routers/jobs/query.py b/diracx-routers/src/diracx/routers/jobs/query.py index f2f8dd323..db270ca4d 100644 --- a/diracx-routers/src/diracx/routers/jobs/query.py +++ b/diracx-routers/src/diracx/routers/jobs/query.py @@ -6,8 +6,8 @@ from fastapi import Body, Depends, Response from diracx.core.models import ( - JobSummaryParams, SearchParams, + SummaryParams, ) from diracx.core.properties import JOB_ADMINISTRATOR from diracx.logic.jobs.query import search as search_bl @@ -183,7 +183,7 @@ async def summary( config: Config, job_db: JobDB, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], - body: JobSummaryParams, + body: SummaryParams, check_permissions: CheckWMSPolicyCallable, ): """Show information suitable for plotting.""" diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py index f6227db7f..c88cd959f 100644 --- a/diracx-routers/src/diracx/routers/pilots/access_policies.py +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -6,6 +6,7 @@ from fastapi import Depends, HTTPException, status +from diracx.core.models import VectorSearchOperator, VectorSearchSpec from diracx.core.properties import SERVICE_ADMINISTRATOR from diracx.db.sql.job.db import JobDB 
from diracx.db.sql.pilots.db import PilotAgentsDB @@ -59,9 +60,15 @@ async def policy( # First, if job_ids are provided, we check who is the owner if job_db and job_ids: - job_owners = await job_db.summary( + job_owners = await job_db.job_summary( ["Owner", "VO"], - [{"parameter": "JobID", "operator": "in", "values": job_ids}], + [ + VectorSearchSpec( + parameter="JobID", + operator=VectorSearchOperator.IN, + values=job_ids, + ) + ], ) expected_owner = { diff --git a/diracx-routers/src/diracx/routers/pilots/query.py b/diracx-routers/src/diracx/routers/pilots/query.py index 9c7236af3..3afbd7cfd 100644 --- a/diracx-routers/src/diracx/routers/pilots/query.py +++ b/diracx-routers/src/diracx/routers/pilots/query.py @@ -5,8 +5,9 @@ from fastapi import Body, Depends, Response -from diracx.core.models import SearchParams +from diracx.core.models import SearchParams, SummaryParams from diracx.logic.pilots.query import search as search_bl +from diracx.logic.pilots.query import summary as summary_bl from ..dependencies import PilotAgentsDB from ..fastapi_classes import DiracxRouter @@ -145,3 +146,21 @@ async def search( response.headers["Content-Range"] = f"pilots {first_idx}-{last_idx}/{total}" response.status_code = HTTPStatus.PARTIAL_CONTENT return pilots + + +@router.post("/summary") +async def summary( + pilot_db: PilotAgentsDB, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + body: SummaryParams, + check_permissions: CheckPilotManagementPolicyCallable, +): + """Show information suitable for plotting.""" + # TODO: Test me. + await check_permissions(action=ActionType.READ_PILOT_FIELDS) + + return await summary_bl( + pilot_db=pilot_db, + body=body, + vo=user_info.vo, + ) diff --git a/diracx-routers/tests/jobs/test_wms_access_policy.py b/diracx-routers/tests/jobs/test_wms_access_policy.py index 351db139f..1a805899c 100644 --- a/diracx-routers/tests/jobs/test_wms_access_policy.py +++ b/diracx-routers/tests/jobs/test_wms_access_policy.py @@ -23,7 +23,7 @@ class FakeJobDB: - async def summary(self, *args): ... + async def job_summary(self, *args): ... 
class FakeSBMetadataDB: @@ -159,7 +159,7 @@ async def test_wms_access_policy_read_modify(job_db, monkeypatch): async def summary_matching(*args): return [{"Owner": "preferred_username", "VO": "lhcb", "count": 3}] - monkeypatch.setattr(job_db, "summary", summary_matching) + monkeypatch.setattr(job_db, "job_summary", summary_matching) await WMSAccessPolicy.policy( WMS_POLICY_NAME, @@ -182,7 +182,7 @@ async def summary_matching(*args): async def summary_other_owner(*args): return [{"Owner": "other_owner", "VO": "lhcb", "count": 3}] - monkeypatch.setattr(job_db, "summary", summary_other_owner) + monkeypatch.setattr(job_db, "job_summary", summary_other_owner) with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}"): await WMSAccessPolicy.policy( WMS_POLICY_NAME, @@ -196,7 +196,7 @@ async def summary_other_owner(*args): async def summary_other_vo(*args): return [{"Owner": "preferred_username", "VO": "gridpp", "count": 3}] - monkeypatch.setattr(job_db, "summary", summary_other_vo) + monkeypatch.setattr(job_db, "job_summary", summary_other_vo) with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}"): await WMSAccessPolicy.policy( WMS_POLICY_NAME, @@ -210,7 +210,7 @@ async def summary_other_vo(*args): async def summary_other_vo(*args): return [{"Owner": "preferred_username", "VO": "lhcb", "count": 2}] - monkeypatch.setattr(job_db, "summary", summary_other_vo) + monkeypatch.setattr(job_db, "job_summary", summary_other_vo) with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}"): await WMSAccessPolicy.policy( WMS_POLICY_NAME, diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py index ee78ff65f..0233f17b1 100644 --- a/diracx-routers/tests/pilots/test_pilot_creation.py +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -25,6 +25,7 @@ "BaseAccessPolicy", "PilotAgentsDB", "PilotManagementAccessPolicy", + "JobDB" ] ) diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index 7c17cde07..bbc466d1d 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -60,6 +60,7 @@ build_pilots_delete_pilots_request, build_pilots_get_pilot_jobs_request, build_pilots_search_request, + build_pilots_summary_request, build_pilots_update_pilot_fields_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, @@ -1977,14 +1978,14 @@ async def search( @overload async def summary( - self, body: _models.JobSummaryParams, *, content_type: str = "application/json", **kwargs: Any + self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any ) -> Any: """Summary. Show information suitable for plotting. :param body: Required. - :type body: ~_generated.models.JobSummaryParams + :type body: ~_generated.models.SummaryParams :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". 
:paramtype content_type: str @@ -2010,13 +2011,13 @@ async def summary(self, body: IO[bytes], *, content_type: str = "application/jso """ @distributed_trace_async - async def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwargs: Any) -> Any: + async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: """Summary. Show information suitable for plotting. - :param body: Is either a JobSummaryParams type or a IO[bytes] type. Required. - :type body: ~_generated.models.JobSummaryParams or IO[bytes] + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] :return: any :rtype: any :raises ~azure.core.exceptions.HttpResponseError: @@ -2041,7 +2042,7 @@ async def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwar if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "JobSummaryParams") + _json = self._serialize.body(body, "SummaryParams") _request = build_jobs_summary_request( content_type=content_type, @@ -2908,3 +2909,98 @@ async def search( return cls(pipeline_response, deserialized, response_headers) # type: ignore return deserialized # type: ignore + + @overload + async def summary( + self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py index 7f6b0f274..7a4362d9e 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py @@ -25,8 +25,6 @@ InsertedJob, JobCommand, JobStatusUpdate, - JobSummaryParams, - JobSummaryParamsSearchItem, OpenIDConfiguration, PilotFieldsMapping, SandboxDownloadResponse, @@ -39,6 +37,8 @@ SetJobStatusReturn, SetJobStatusReturnSuccess, SortSpec, + SummaryParams, + SummaryParamsSearchItem, SupportInfo, TokenResponse, UserInfoResponse, @@ -77,8 +77,6 @@ "InsertedJob", "JobCommand", "JobStatusUpdate", - "JobSummaryParams", - "JobSummaryParamsSearchItem", "OpenIDConfiguration", "PilotFieldsMapping", "SandboxDownloadResponse", @@ -91,6 +89,8 @@ "SetJobStatusReturn", "SetJobStatusReturnSuccess", "SortSpec", + "SummaryParams", + "SummaryParamsSearchItem", "SupportInfo", "TokenResponse", "UserInfoResponse", diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 4cdd81cfc..b31b3eeba 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -583,44 +583,6 @@ def __init__( self.source = source -class JobSummaryParams(_serialization.Model): - """JobSummaryParams. - - All required parameters must be populated in order to send to server. - - :ivar grouping: Grouping. Required. - :vartype grouping: list[str] - :ivar search: Search. 
- :vartype search: list[~_generated.models.JobSummaryParamsSearchItem] - """ - - _validation = { - "grouping": {"required": True}, - } - - _attribute_map = { - "grouping": {"key": "grouping", "type": "[str]"}, - "search": {"key": "search", "type": "[JobSummaryParamsSearchItem]"}, - } - - def __init__( - self, *, grouping: List[str], search: List["_models.JobSummaryParamsSearchItem"] = [], **kwargs: Any - ) -> None: - """ - :keyword grouping: Grouping. Required. - :paramtype grouping: list[str] - :keyword search: Search. - :paramtype search: list[~_generated.models.JobSummaryParamsSearchItem] - """ - super().__init__(**kwargs) - self.grouping = grouping - self.search = search - - -class JobSummaryParamsSearchItem(_serialization.Model): - """JobSummaryParamsSearchItem.""" - - class OpenIDConfiguration(_serialization.Model): """OpenIDConfiguration. @@ -1224,6 +1186,44 @@ def __init__(self, *, parameter: str, direction: Union[str, "_models.SortDirecti self.direction = direction +class SummaryParams(_serialization.Model): + """SummaryParams. + + All required parameters must be populated in order to send to server. + + :ivar grouping: Grouping. Required. + :vartype grouping: list[str] + :ivar search: Search. + :vartype search: list[~_generated.models.SummaryParamsSearchItem] + """ + + _validation = { + "grouping": {"required": True}, + } + + _attribute_map = { + "grouping": {"key": "grouping", "type": "[str]"}, + "search": {"key": "search", "type": "[SummaryParamsSearchItem]"}, + } + + def __init__( + self, *, grouping: List[str], search: List["_models.SummaryParamsSearchItem"] = [], **kwargs: Any + ) -> None: + """ + :keyword grouping: Grouping. Required. + :paramtype grouping: list[str] + :keyword search: Search. + :paramtype search: list[~_generated.models.SummaryParamsSearchItem] + """ + super().__init__(**kwargs) + self.grouping = grouping + self.search = search + + +class SummaryParamsSearchItem(_serialization.Model): + """SummaryParamsSearchItem.""" + + class SupportInfo(_serialization.Model): """SupportInfo. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index c2af06038..6824ca0f0 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -754,6 +754,23 @@ def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) +def build_pilots_summary_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/summary" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. 
warning:: @@ -2656,13 +2673,13 @@ def search( return deserialized # type: ignore @overload - def summary(self, body: _models.JobSummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: + def summary(self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: """Summary. Show information suitable for plotting. :param body: Required. - :type body: ~_generated.models.JobSummaryParams + :type body: ~_generated.models.SummaryParams :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str @@ -2688,13 +2705,13 @@ def summary(self, body: IO[bytes], *, content_type: str = "application/json", ** """ @distributed_trace - def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwargs: Any) -> Any: + def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: """Summary. Show information suitable for plotting. - :param body: Is either a JobSummaryParams type or a IO[bytes] type. Required. - :type body: ~_generated.models.JobSummaryParams or IO[bytes] + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] :return: any :rtype: any :raises ~azure.core.exceptions.HttpResponseError: @@ -2719,7 +2736,7 @@ def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwargs: An if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "JobSummaryParams") + _json = self._serialize.body(body, "SummaryParams") _request = build_jobs_summary_request( content_type=content_type, @@ -3584,3 +3601,96 @@ def search( return cls(pipeline_response, deserialized, response_headers) # type: ignore return deserialized # type: ignore + + @overload + def summary(self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py index 467577394..a32823a3a 100644 --- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py +++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py @@ -21,7 +21,7 @@ class LollygagDB(BaseSQLDB): # This needs to be here for the BaseSQLDB to create the engine metadata = LollygagDBBase.metadata - async def summary(self, group_by, search) -> list[dict[str, str | int]]: + async def lollygag_summary(self, group_by, search) -> list[dict[str, str | int]]: columns = [Cars.__table__.columns[x] for x in group_by] stmt = select(*columns, func.count(Cars.license_plate).label("count")) diff --git a/extensions/gubbins/gubbins-db/tests/test_lollygag_db.py b/extensions/gubbins/gubbins-db/tests/test_lollygag_db.py index b5ff7b84e..43d69e91e 100644 --- a/extensions/gubbins/gubbins-db/tests/test_lollygag_db.py +++ b/extensions/gubbins/gubbins-db/tests/test_lollygag_db.py @@ -31,7 +31,7 @@ async def test_insert_and_summary(lollygag_db: LollygagDB): # So it is important to write test this way async with lollygag_db as lollygag_db: # First we check that the DB is empty - result = await lollygag_db.summary(["Model"], []) + result = await lollygag_db.lollygag_summary(["Model"], []) assert not result # Now we add some data in the DB @@ -51,13 +51,13 @@ async def test_insert_and_summary(lollygag_db: LollygagDB): # Check that there are now 10 cars assigned to a single driver async with lollygag_db as lollygag_db: - result = await lollygag_db.summary(["OwnerID"], []) + result = await lollygag_db.lollygag_summary(["OwnerID"], []) assert result[0]["count"] == 10 # Test the selection async with lollygag_db as lollygag_db: - result = await lollygag_db.summary( + result = await lollygag_db.lollygag_summary( ["OwnerID"], [{"parameter": "Model", "operator": "eq", "value": 
"model_1"}] ) @@ -65,7 +65,7 @@ async def test_insert_and_summary(lollygag_db: LollygagDB): async with lollygag_db as lollygag_db: with pytest.raises(InvalidQueryError): - result = await lollygag_db.summary( + result = await lollygag_db.lollygag_summary( ["OwnerID"], [ { From 4df97ff1d0941a1145faed0b4d45d94910070960 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Fri, 4 Jul 2025 13:47:51 +0200 Subject: [PATCH 21/33] feat: Add legacy pilot support in management (dirac-admin-add-pilot) --- .../diracx/routers/pilots/access_policies.py | 30 +++++++++++------ .../src/diracx/routers/pilots/management.py | 32 +++++++++++++++++-- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py index c88cd959f..1a7b49e7b 100644 --- a/diracx-routers/src/diracx/routers/pilots/access_policies.py +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -7,7 +7,7 @@ from fastapi import Depends, HTTPException, status from diracx.core.models import VectorSearchOperator, VectorSearchSpec -from diracx.core.properties import SERVICE_ADMINISTRATOR +from diracx.core.properties import GENERIC_PILOT, SERVICE_ADMINISTRATOR from diracx.db.sql.job.db import JobDB from diracx.db.sql.pilots.db import PilotAgentsDB from diracx.logic.pilots.query import get_pilots_by_stamp @@ -39,20 +39,32 @@ async def policy( pilot_stamps: list[str] | None = None, job_db: JobDB | None = None, job_ids: list[int] | None = None, + allow_legacy_pilots: bool = False ): assert action, "action is a mandatory parameter" # Users can query # NOTE: Add into queries a VO constraint # To manage pilots, user have to be an admin - if ( - action == ActionType.MANAGE_PILOTS - and SERVICE_ADMINISTRATOR not in user_info.properties - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="You don't have the permission to manage pilots.", - ) + # In some special cases (described with allow_legacy_pilots), we can allow pilots + if action == ActionType.MANAGE_PILOTS: + + # To make it clear, we separate + is_an_admin = SERVICE_ADMINISTRATOR in user_info.properties + is_a_pilot_if_allowed = allow_legacy_pilots and GENERIC_PILOT in user_info.properties + + if not is_an_admin and not is_a_pilot_if_allowed: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the permission to manage pilots.", + ) + + if action == ActionType.READ_PILOT_FIELDS: + if GENERIC_PILOT in user_info.properties: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Pilots can't read other pilots info." + ) # # Additional checks if job_ids or pilot_stamps are provided diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 8f783352e..66c0dbba0 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -3,6 +3,7 @@ from http import HTTPStatus from typing import Annotated +from diracx.core.properties import GENERIC_PILOT from fastapi import Body, Depends, HTTPException, Query, status from diracx.core.exceptions import ( @@ -65,7 +66,19 @@ async def add_pilot_stamps( If a pilot stamp already exists, it will block the insertion. """ # TODO: Verify that grid types, sites, destination sites, etc. 
are valids - await check_permissions(action=ActionType.MANAGE_PILOTS) + await check_permissions( + action=ActionType.MANAGE_PILOTS, + allow_legacy_pilots=True # dirac-admin-add-pilot + ) + + # Prevent someone who stole a pilot X509 to create thousands of pilots at a time + # (It would be still able to create thousands of pilots, but slower) + if GENERIC_PILOT in user_info.properties: + if len(pilot_stamps) != 1: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="As a pilot, you can only create yourself." + ) try: await register_new_pilots( @@ -183,6 +196,7 @@ async def update_pilot_fields( ], pilot_db: PilotAgentsDB, check_permissions: CheckPilotManagementPolicyCallable, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], ): """Modify a field of a pilot. @@ -191,9 +205,23 @@ async def update_pilot_fields( # Ensures stamps validity pilot_stamps = [mapping.PilotStamp for mapping in pilot_stamps_to_fields_mapping] await check_permissions( - action=ActionType.MANAGE_PILOTS, pilot_db=pilot_db, pilot_stamps=pilot_stamps + action=ActionType.MANAGE_PILOTS, + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + allow_legacy_pilots=True # dirac-admin-add-pilot ) + # Prevent someone who stole a pilot X509 to modify thousands of pilots at a time + # (It would be still able to modify thousands of pilots, but slower) + # We are not able to affirm that this pilots modifies itself + if GENERIC_PILOT in user_info.properties: + if len(pilot_stamps) != 1: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="As a pilot, you can only modify yourself." + ) + + await update_pilots_fields( pilot_db=pilot_db, pilot_stamps_to_fields_mapping=pilot_stamps_to_fields_mapping, From 3ecd3b30d64fc8d0014b3d587c7d14480f6abe27 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Fri, 4 Jul 2025 16:02:23 +0200 Subject: [PATCH 22/33] chore: Removed association of a job with a pilot because it's internal --- .../src/diracx/routers/pilots/management.py | 36 ------------------- 1 file changed, 36 deletions(-) diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 66c0dbba0..50743e78e 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -263,39 +263,3 @@ async def get_pilot_jobs( status_code=status.HTTP_400_BAD_REQUEST, detail="You must provide either pilot_stamp or job_id", ) - - -@router.patch("/jobs", status_code=HTTPStatus.NO_CONTENT) -async def add_jobs_to_pilot( - pilot_db: PilotAgentsDB, - job_db: JobDB, - pilot_stamp: Annotated[str, Body(description="The stamp of the pilot.")], - job_ids: Annotated[ - list[int], Body(description="The jobs we want to add to the pilot.") - ], - check_permissions: CheckPilotManagementPolicyCallable, -): - """Endpoint only for admins, to associate a pilot with a job.""" - await check_permissions( - action=ActionType.MANAGE_PILOTS, - pilot_db=pilot_db, - pilot_stamps=[pilot_stamp], - job_db=job_db, - job_ids=job_ids, - ) - - try: - await add_jobs_to_pilot_bl( - pilot_db=pilot_db, - pilot_stamp=pilot_stamp, - job_ids=job_ids, - ) - except PilotNotFoundError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="This pilot does not exist." 
- ) from e - except PilotAlreadyAssociatedWithJobError as e: - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail="This pilot is already associated with this job.", - ) from e From 2a72da1d259aa4073e11a1feb8ee268bd8029cf9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 4 Jul 2025 14:04:29 +0000 Subject: [PATCH 23/33] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../diracx/routers/pilots/access_policies.py | 9 +++++---- .../src/diracx/routers/pilots/management.py | 20 +++++++------------ .../tests/pilots/test_pilot_creation.py | 2 +- 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py index 1a7b49e7b..40e942e60 100644 --- a/diracx-routers/src/diracx/routers/pilots/access_policies.py +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -39,7 +39,7 @@ async def policy( pilot_stamps: list[str] | None = None, job_db: JobDB | None = None, job_ids: list[int] | None = None, - allow_legacy_pilots: bool = False + allow_legacy_pilots: bool = False, ): assert action, "action is a mandatory parameter" @@ -48,10 +48,11 @@ async def policy( # To manage pilots, user have to be an admin # In some special cases (described with allow_legacy_pilots), we can allow pilots if action == ActionType.MANAGE_PILOTS: - # To make it clear, we separate is_an_admin = SERVICE_ADMINISTRATOR in user_info.properties - is_a_pilot_if_allowed = allow_legacy_pilots and GENERIC_PILOT in user_info.properties + is_a_pilot_if_allowed = ( + allow_legacy_pilots and GENERIC_PILOT in user_info.properties + ) if not is_an_admin and not is_a_pilot_if_allowed: raise HTTPException( @@ -63,7 +64,7 @@ async def policy( if GENERIC_PILOT in user_info.properties: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Pilots can't read other pilots info." + detail="Pilots can't read other pilots info.", ) # diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 50743e78e..7b632f1cc 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -3,21 +3,16 @@ from http import HTTPStatus from typing import Annotated -from diracx.core.properties import GENERIC_PILOT from fastapi import Body, Depends, HTTPException, Query, status from diracx.core.exceptions import ( - PilotAlreadyAssociatedWithJobError, PilotAlreadyExistsError, - PilotNotFoundError, ) from diracx.core.models import ( PilotFieldsMapping, PilotStatus, ) -from diracx.logic.pilots.management import ( - add_jobs_to_pilot as add_jobs_to_pilot_bl, -) +from diracx.core.properties import GENERIC_PILOT from diracx.logic.pilots.management import ( delete_pilots as delete_pilots_bl, ) @@ -68,7 +63,7 @@ async def add_pilot_stamps( # TODO: Verify that grid types, sites, destination sites, etc. are valids await check_permissions( action=ActionType.MANAGE_PILOTS, - allow_legacy_pilots=True # dirac-admin-add-pilot + allow_legacy_pilots=True, # dirac-admin-add-pilot ) # Prevent someone who stole a pilot X509 to create thousands of pilots at a time @@ -77,7 +72,7 @@ async def add_pilot_stamps( if len(pilot_stamps) != 1: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="As a pilot, you can only create yourself." 
+ detail="As a pilot, you can only create yourself.", ) try: @@ -205,10 +200,10 @@ async def update_pilot_fields( # Ensures stamps validity pilot_stamps = [mapping.PilotStamp for mapping in pilot_stamps_to_fields_mapping] await check_permissions( - action=ActionType.MANAGE_PILOTS, - pilot_db=pilot_db, + action=ActionType.MANAGE_PILOTS, + pilot_db=pilot_db, pilot_stamps=pilot_stamps, - allow_legacy_pilots=True # dirac-admin-add-pilot + allow_legacy_pilots=True, # dirac-admin-add-pilot ) # Prevent someone who stole a pilot X509 to modify thousands of pilots at a time @@ -218,10 +213,9 @@ async def update_pilot_fields( if len(pilot_stamps) != 1: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="As a pilot, you can only modify yourself." + detail="As a pilot, you can only modify yourself.", ) - await update_pilots_fields( pilot_db=pilot_db, pilot_stamps_to_fields_mapping=pilot_stamps_to_fields_mapping, diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py index 0233f17b1..9204ca38b 100644 --- a/diracx-routers/tests/pilots/test_pilot_creation.py +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -25,7 +25,7 @@ "BaseAccessPolicy", "PilotAgentsDB", "PilotManagementAccessPolicy", - "JobDB" + "JobDB", ] ) From 7ca26a08bdb996faa91e7058ec1bd86256b66efc Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Tue, 8 Jul 2025 08:24:09 +0200 Subject: [PATCH 24/33] fix: Fixes to pass CI --- .../_generated/aio/operations/_operations.py | 94 ------------ .../client/_generated/models/__init__.py | 2 - .../client/_generated/models/_models.py | 33 ---- .../_generated/operations/_operations.py | 107 ------------- .../tests/pilots/test_pilot_creation.py | 143 ------------------ .../_generated/aio/operations/_operations.py | 94 ------------ .../client/_generated/models/__init__.py | 2 - .../client/_generated/models/_models.py | 33 ---- .../_generated/operations/_operations.py | 107 ------------- 9 files changed, 615 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index b02f5f037..e1aad0601 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -52,7 +52,6 @@ build_jobs_summary_request, build_jobs_unassign_bulk_jobs_sandboxes_request, build_jobs_unassign_job_sandboxes_request, - build_pilots_add_jobs_to_pilot_request, build_pilots_add_pilot_stamps_request, build_pilots_delete_pilots_request, build_pilots_get_pilot_jobs_request, @@ -2513,99 +2512,6 @@ async def get_pilot_jobs( return deserialized # type: ignore - @overload - async def add_jobs_to_pilot( - self, body: _models.BodyPilotsAddJobsToPilot, *, content_type: str = "application/json", **kwargs: Any - ) -> None: - """Add Jobs To Pilot. - - Endpoint only for admins, to associate a pilot with a job. - - :param body: Required. - :type body: ~_generated.models.BodyPilotsAddJobsToPilot - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". - :paramtype content_type: str - :return: None - :rtype: None - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - async def add_jobs_to_pilot( - self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any - ) -> None: - """Add Jobs To Pilot. 
- - Endpoint only for admins, to associate a pilot with a job. - - :param body: Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: None - :rtype: None - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @distributed_trace_async - async def add_jobs_to_pilot(self, body: Union[_models.BodyPilotsAddJobsToPilot, IO[bytes]], **kwargs: Any) -> None: - """Add Jobs To Pilot. - - Endpoint only for admins, to associate a pilot with a job. - - :param body: Is either a BodyPilotsAddJobsToPilot type or a IO[bytes] type. Required. - :type body: ~_generated.models.BodyPilotsAddJobsToPilot or IO[bytes] - :return: None - :rtype: None - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map: MutableMapping = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = kwargs.pop("params", {}) or {} - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[None] = kwargs.pop("cls", None) - - content_type = content_type or "application/json" - _json = None - _content = None - if isinstance(body, (IOBase, bytes)): - _content = body - else: - _json = self._serialize.body(body, "BodyPilotsAddJobsToPilot") - - _request = build_pilots_add_jobs_to_pilot_request( - content_type=content_type, - json=_json, - content=_content, - headers=_headers, - params=_params, - ) - _request.url = self._client.format_url(_request.url) - - _stream = False - pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [204]: - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - if cls: - return cls(pipeline_response, None, {}) # type: ignore - @overload async def search( self, diff --git a/diracx-client/src/diracx/client/_generated/models/__init__.py b/diracx-client/src/diracx/client/_generated/models/__init__.py index d9f48e28a..33716f267 100644 --- a/diracx-client/src/diracx/client/_generated/models/__init__.py +++ b/diracx-client/src/diracx/client/_generated/models/__init__.py @@ -14,7 +14,6 @@ from ._models import ( # type: ignore BodyAuthGetOidcToken, BodyAuthGetOidcTokenGrantType, - BodyPilotsAddJobsToPilot, BodyPilotsAddPilotStamps, BodyPilotsUpdatePilotFields, GroupInfo, @@ -66,7 +65,6 @@ __all__ = [ "BodyAuthGetOidcToken", "BodyAuthGetOidcTokenGrantType", - "BodyPilotsAddJobsToPilot", "BodyPilotsAddPilotStamps", "BodyPilotsUpdatePilotFields", "GroupInfo", diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index a8457473c..b101fdb10 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -94,39 +94,6 @@ class BodyAuthGetOidcTokenGrantType(_serialization.Model): """OAuth2 Grant type.""" -class BodyPilotsAddJobsToPilot(_serialization.Model): - """Body_pilots_add_jobs_to_pilot. - - All required parameters must be populated in order to send to server. 
- - :ivar pilot_stamp: The stamp of the pilot. Required. - :vartype pilot_stamp: str - :ivar job_ids: The jobs we want to add to the pilot. Required. - :vartype job_ids: list[int] - """ - - _validation = { - "pilot_stamp": {"required": True}, - "job_ids": {"required": True}, - } - - _attribute_map = { - "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, - "job_ids": {"key": "job_ids", "type": "[int]"}, - } - - def __init__(self, *, pilot_stamp: str, job_ids: List[int], **kwargs: Any) -> None: - """ - :keyword pilot_stamp: The stamp of the pilot. Required. - :paramtype pilot_stamp: str - :keyword job_ids: The jobs we want to add to the pilot. Required. - :paramtype job_ids: list[int] - """ - super().__init__(**kwargs) - self.pilot_stamp = pilot_stamp - self.job_ids = job_ids - - class BodyPilotsAddPilotStamps(_serialization.Model): """Body_pilots_add_pilot_stamps. diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index f317d9353..3673bdf32 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -667,20 +667,6 @@ def build_pilots_get_pilot_jobs_request( return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) -def build_pilots_add_jobs_to_pilot_request(**kwargs: Any) -> HttpRequest: - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - # Construct URL - _url = "/api/pilots/jobs" - - # Construct headers - if content_type is not None: - _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") - - return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) - - def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) @@ -3159,99 +3145,6 @@ def get_pilot_jobs( return deserialized # type: ignore - @overload - def add_jobs_to_pilot( - self, body: _models.BodyPilotsAddJobsToPilot, *, content_type: str = "application/json", **kwargs: Any - ) -> None: - """Add Jobs To Pilot. - - Endpoint only for admins, to associate a pilot with a job. - - :param body: Required. - :type body: ~_generated.models.BodyPilotsAddJobsToPilot - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". - :paramtype content_type: str - :return: None - :rtype: None - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - def add_jobs_to_pilot(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: - """Add Jobs To Pilot. - - Endpoint only for admins, to associate a pilot with a job. - - :param body: Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: None - :rtype: None - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @distributed_trace - def add_jobs_to_pilot( # pylint: disable=inconsistent-return-statements - self, body: Union[_models.BodyPilotsAddJobsToPilot, IO[bytes]], **kwargs: Any - ) -> None: - """Add Jobs To Pilot. 
- - Endpoint only for admins, to associate a pilot with a job. - - :param body: Is either a BodyPilotsAddJobsToPilot type or a IO[bytes] type. Required. - :type body: ~_generated.models.BodyPilotsAddJobsToPilot or IO[bytes] - :return: None - :rtype: None - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map: MutableMapping = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = kwargs.pop("params", {}) or {} - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[None] = kwargs.pop("cls", None) - - content_type = content_type or "application/json" - _json = None - _content = None - if isinstance(body, (IOBase, bytes)): - _content = body - else: - _json = self._serialize.body(body, "BodyPilotsAddJobsToPilot") - - _request = build_pilots_add_jobs_to_pilot_request( - content_type=content_type, - json=_json, - content=_content, - headers=_headers, - params=_params, - ) - _request.url = self._client.format_url(_request.url) - - _stream = False - pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [204]: - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - if cls: - return cls(pipeline_response, None, {}) # type: ignore - @overload def search( self, diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py index 9204ca38b..11293ff12 100644 --- a/diracx-routers/tests/pilots/test_pilot_creation.py +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -3,14 +3,11 @@ from datetime import datetime, timezone import pytest -from fastapi.testclient import TestClient from sqlalchemy import update from diracx.core.models import ( PilotFieldsMapping, PilotStatus, - ScalarSearchOperator, - ScalarSearchSpec, ) from diracx.db.sql import PilotAgentsDB from diracx.db.sql.pilots.schema import PilotAgents @@ -188,106 +185,6 @@ async def test_create_pilot_and_modify_it(normal_test_client): assert pilot2["Status"] != pilot1["Status"] -async def test_associate_job_with_pilot_and_get_it(normal_test_client: TestClient): - pilot_stamps = ["stamps_1", "stamp_2"] - - # -------------- Insert -------------- - body = {"pilot_stamps": pilot_stamps} - - # Create pilots - r = normal_test_client.post( - "/api/pilots/", - json=body, - ) - - assert r.status_code == 200, r.json() - - # --------------- As DIRAC, associate a job with a pilot -------- - job_ids = [1, 2] - body = {"pilot_stamp": pilot_stamps[0], "job_ids": job_ids} - - # Create pilots - r = normal_test_client.patch( - "/api/pilots/jobs", - json=body, - ) - - assert r.status_code == 204 - - # -------------- Redo it, expect 409 (Conflict) --------------------- - job_ids = [1, 2, 3] # Note for next test : add 3 - body = {"pilot_stamp": pilot_stamps[0], "job_ids": job_ids} - - # Create pilots - r = normal_test_client.patch( - "/api/pilots/jobs", - json=body, - ) - - assert r.status_code == 409 - - # -------------- Add 3 --------------------- - body = {"pilot_stamp": pilot_stamps[0], "job_ids": [3]} - - # Create pilots - r = normal_test_client.patch( - "/api/pilots/jobs", - 
json=body, - ) - - assert r.status_code == 204 - - # -------------- Add with unknown pilot --------------------- - body = {"pilot_stamp": "stampounet", "job_ids": job_ids} - - # Create pilots - r = normal_test_client.patch( - "/api/pilots/jobs", - json=body, - ) - - assert r.status_code == 400 - - # -------------- Get its jobs --------------------- - r = normal_test_client.get( - "/api/pilots/jobs", params={"pilot_stamp": pilot_stamps[0]} - ) - - assert r.status_code == 200 - assert r.json() == job_ids - - # -------------- Get the other pilot's jobs --------------------- - r = normal_test_client.get( - "/api/pilots/jobs", params={"pilot_stamp": pilot_stamps[1]} - ) - - assert r.status_code == 200 - assert r.json() == [] - - # -------------- Get pilots associated to job 1 --------------------- - r = normal_test_client.get("/api/pilots/jobs", params={"job_id": job_ids[0]}) - - assert r.status_code == 200, r.json() - assert len(r.json()) == 1 - expected_pilot_id = r.json()[0] - - # -------------- Get pilot info to verify that its id is expected_pilot_id --------------------- - condition = ScalarSearchSpec( - parameter="PilotID", - operator=ScalarSearchOperator.EQUAL, - value=expected_pilot_id, - ) - - r = normal_test_client.post( - "/api/pilots/search", - json={"parameters": [], "search": [condition], "sorts": []}, - ) - - assert r.status_code == 200, r.json() - assert len(r.json()) == 1 - assert r.json()[0]["PilotStamp"] == pilot_stamps[0] - - @pytest.mark.asyncio async def test_delete_pilots_by_age_and_stamp(normal_test_client): # Generate 100 pilot stamps @@ -384,43 +281,3 @@ async def test_delete_pilots_by_age_and_stamp(normal_test_client): "/api/pilots/", params={"pilot_stamps": ["unknown_stamp"]} ) assert r.status_code == 204 - - -@pytest.mark.asyncio -async def test_associate_two_pilots_share_jobs_and_delete_first(normal_test_client): - # 1) Create two pilots - pilot_stamps = ["stamp_1", "stamp_2"] - body = {"pilot_stamps": pilot_stamps} - r = normal_test_client.post("/api/pilots/", json=body) - assert r.status_code == 200, r.json() - - # 2) Associate first pilot with jobs 1-10 - job_ids = list(range(1, 11)) - body = {"pilot_stamp": pilot_stamps[0], "job_ids": job_ids} - r = normal_test_client.patch("/api/pilots/jobs", json=body) - assert r.status_code == 204 - - # 3) Associate second pilot with the same jobs - body = {"pilot_stamp": pilot_stamps[1], "job_ids": job_ids} - r = normal_test_client.patch("/api/pilots/jobs", json=body) - assert r.status_code == 204 - - # 4) Delete first pilot - r = normal_test_client.delete( - "/api/pilots/", params={"pilot_stamps": [pilot_stamps[0]]} - ) - assert r.status_code == 204 - - # 5) Get jobs for pilot_1: expect empty list - r = normal_test_client.get( - "/api/pilots/jobs", params={"pilot_stamp": pilot_stamps[0]} - ) - assert r.status_code == 200 - assert r.json() == [] - - # 6) Get jobs for pilot_2: expect original job_ids - r = normal_test_client.get( - "/api/pilots/jobs", params={"pilot_stamp": pilot_stamps[1]} - ) - assert r.status_code == 200 - assert r.json() == job_ids diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index bbc466d1d..3d9951b6d 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -55,7 +55,6 @@ 
build_lollygag_get_gubbins_secrets_request, build_lollygag_get_owner_object_request, build_lollygag_insert_owner_object_request, - build_pilots_add_jobs_to_pilot_request, build_pilots_add_pilot_stamps_request, build_pilots_delete_pilots_request, build_pilots_get_pilot_jobs_request, @@ -2680,99 +2679,6 @@ async def get_pilot_jobs( return deserialized # type: ignore - @overload - async def add_jobs_to_pilot( - self, body: _models.BodyPilotsAddJobsToPilot, *, content_type: str = "application/json", **kwargs: Any - ) -> None: - """Add Jobs To Pilot. - - Endpoint only for admins, to associate a pilot with a job. - - :param body: Required. - :type body: ~_generated.models.BodyPilotsAddJobsToPilot - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". - :paramtype content_type: str - :return: None - :rtype: None - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - async def add_jobs_to_pilot( - self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any - ) -> None: - """Add Jobs To Pilot. - - Endpoint only for admins, to associate a pilot with a job. - - :param body: Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: None - :rtype: None - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @distributed_trace_async - async def add_jobs_to_pilot(self, body: Union[_models.BodyPilotsAddJobsToPilot, IO[bytes]], **kwargs: Any) -> None: - """Add Jobs To Pilot. - - Endpoint only for admins, to associate a pilot with a job. - - :param body: Is either a BodyPilotsAddJobsToPilot type or a IO[bytes] type. Required. 
- :type body: ~_generated.models.BodyPilotsAddJobsToPilot or IO[bytes] - :return: None - :rtype: None - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map: MutableMapping = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = kwargs.pop("params", {}) or {} - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[None] = kwargs.pop("cls", None) - - content_type = content_type or "application/json" - _json = None - _content = None - if isinstance(body, (IOBase, bytes)): - _content = body - else: - _json = self._serialize.body(body, "BodyPilotsAddJobsToPilot") - - _request = build_pilots_add_jobs_to_pilot_request( - content_type=content_type, - json=_json, - content=_content, - headers=_headers, - params=_params, - ) - _request.url = self._client.format_url(_request.url) - - _stream = False - pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [204]: - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - if cls: - return cls(pipeline_response, None, {}) # type: ignore - @overload async def search( self, diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py index 7a4362d9e..537867ac1 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py @@ -14,7 +14,6 @@ from ._models import ( # type: ignore BodyAuthGetOidcToken, BodyAuthGetOidcTokenGrantType, - BodyPilotsAddJobsToPilot, BodyPilotsAddPilotStamps, BodyPilotsUpdatePilotFields, ExtendedMetadata, @@ -66,7 +65,6 @@ __all__ = [ "BodyAuthGetOidcToken", "BodyAuthGetOidcTokenGrantType", - "BodyPilotsAddJobsToPilot", "BodyPilotsAddPilotStamps", "BodyPilotsUpdatePilotFields", "ExtendedMetadata", diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index b31b3eeba..5d5e4ad8c 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -94,39 +94,6 @@ class BodyAuthGetOidcTokenGrantType(_serialization.Model): """OAuth2 Grant type.""" -class BodyPilotsAddJobsToPilot(_serialization.Model): - """Body_pilots_add_jobs_to_pilot. - - All required parameters must be populated in order to send to server. - - :ivar pilot_stamp: The stamp of the pilot. Required. - :vartype pilot_stamp: str - :ivar job_ids: The jobs we want to add to the pilot. Required. 
- :vartype job_ids: list[int] - """ - - _validation = { - "pilot_stamp": {"required": True}, - "job_ids": {"required": True}, - } - - _attribute_map = { - "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, - "job_ids": {"key": "job_ids", "type": "[int]"}, - } - - def __init__(self, *, pilot_stamp: str, job_ids: List[int], **kwargs: Any) -> None: - """ - :keyword pilot_stamp: The stamp of the pilot. Required. - :paramtype pilot_stamp: str - :keyword job_ids: The jobs we want to add to the pilot. Required. - :paramtype job_ids: list[int] - """ - super().__init__(**kwargs) - self.pilot_stamp = pilot_stamp - self.job_ids = job_ids - - class BodyPilotsAddPilotStamps(_serialization.Model): """Body_pilots_add_pilot_stamps. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index 6824ca0f0..e861b9841 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -716,20 +716,6 @@ def build_pilots_get_pilot_jobs_request( return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) -def build_pilots_add_jobs_to_pilot_request(**kwargs: Any) -> HttpRequest: - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - # Construct URL - _url = "/api/pilots/jobs" - - # Construct headers - if content_type is not None: - _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") - - return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) - - def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) @@ -3372,99 +3358,6 @@ def get_pilot_jobs( return deserialized # type: ignore - @overload - def add_jobs_to_pilot( - self, body: _models.BodyPilotsAddJobsToPilot, *, content_type: str = "application/json", **kwargs: Any - ) -> None: - """Add Jobs To Pilot. - - Endpoint only for admins, to associate a pilot with a job. - - :param body: Required. - :type body: ~_generated.models.BodyPilotsAddJobsToPilot - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". - :paramtype content_type: str - :return: None - :rtype: None - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - def add_jobs_to_pilot(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: - """Add Jobs To Pilot. - - Endpoint only for admins, to associate a pilot with a job. - - :param body: Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: None - :rtype: None - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @distributed_trace - def add_jobs_to_pilot( # pylint: disable=inconsistent-return-statements - self, body: Union[_models.BodyPilotsAddJobsToPilot, IO[bytes]], **kwargs: Any - ) -> None: - """Add Jobs To Pilot. - - Endpoint only for admins, to associate a pilot with a job. 
-
-        :param body: Is either a BodyPilotsAddJobsToPilot type or a IO[bytes] type. Required.
-        :type body: ~_generated.models.BodyPilotsAddJobsToPilot or IO[bytes]
-        :return: None
-        :rtype: None
-        :raises ~azure.core.exceptions.HttpResponseError:
-        """
-        error_map: MutableMapping = {
-            401: ClientAuthenticationError,
-            404: ResourceNotFoundError,
-            409: ResourceExistsError,
-            304: ResourceNotModifiedError,
-        }
-        error_map.update(kwargs.pop("error_map", {}) or {})
-
-        _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
-        _params = kwargs.pop("params", {}) or {}
-
-        content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
-        cls: ClsType[None] = kwargs.pop("cls", None)
-
-        content_type = content_type or "application/json"
-        _json = None
-        _content = None
-        if isinstance(body, (IOBase, bytes)):
-            _content = body
-        else:
-            _json = self._serialize.body(body, "BodyPilotsAddJobsToPilot")
-
-        _request = build_pilots_add_jobs_to_pilot_request(
-            content_type=content_type,
-            json=_json,
-            content=_content,
-            headers=_headers,
-            params=_params,
-        )
-        _request.url = self._client.format_url(_request.url)
-
-        _stream = False
-        pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
-            _request, stream=_stream, **kwargs
-        )
-
-        response = pipeline_response.http_response
-
-        if response.status_code not in [204]:
-            map_error(status_code=response.status_code, response=response, error_map=error_map)
-            raise HttpResponseError(response=response)
-
-        if cls:
-            return cls(pipeline_response, None, {}) # type: ignore
-
     @overload
     def search(
         self,

From c2e4992075198f42af091354f0b86ddc5b1dfb6b Mon Sep 17 00:00:00 2001
From: Robin VAN DE MERGHEL
Date: Mon, 14 Jul 2025 08:41:33 +0200
Subject: [PATCH 25/33] fix: Patch client to use in DIRAC

---
 .../_generated/aio/operations/_patch.py | 2 +
 .../client/_generated/operations/_patch.py | 2 +
 .../src/diracx/client/patches/pilots/aio.py | 34 ++++++++
 .../diracx/client/patches/pilots/common.py | 85 +++++++++++++++++++
 .../src/diracx/client/patches/pilots/sync.py | 34 ++++++++
 5 files changed, 157 insertions(+)
 create mode 100644 diracx-client/src/diracx/client/patches/pilots/aio.py
 create mode 100644 diracx-client/src/diracx/client/patches/pilots/common.py
 create mode 100644 diracx-client/src/diracx/client/patches/pilots/sync.py

diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py b/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py
index a408e57d2..aa1cbd79d 100644
--- a/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py
+++ b/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py
@@ -11,10 +11,12 @@
 __all__ = [
     "AuthOperations",
     "JobsOperations",
+    "PilotsOperations"
 ] # Add all objects you want publicly available to users at this package level
 from ....patches.auth.aio import AuthOperations
 from ....patches.jobs.aio import JobsOperations
+from ....patches.pilots.aio import PilotsOperations


 def patch_sdk():
diff --git a/diracx-client/src/diracx/client/_generated/operations/_patch.py b/diracx-client/src/diracx/client/_generated/operations/_patch.py
index b7b8c67fa..9f1e07b4e 100644
--- a/diracx-client/src/diracx/client/_generated/operations/_patch.py
+++ b/diracx-client/src/diracx/client/_generated/operations/_patch.py
@@ -11,10 +11,12 @@
 __all__ = [
     "AuthOperations",
     "JobsOperations",
+    "PilotsOperations"
 ] # Add all objects you want publicly available to users at this package level
 from ...patches.auth.sync
import AuthOperations from ...patches.jobs.sync import JobsOperations +from ...patches.pilots.sync import PilotsOperations def patch_sdk(): diff --git a/diracx-client/src/diracx/client/patches/pilots/aio.py b/diracx-client/src/diracx/client/patches/pilots/aio.py new file mode 100644 index 000000000..496aaeeda --- /dev/null +++ b/diracx-client/src/diracx/client/patches/pilots/aio.py @@ -0,0 +1,34 @@ +"""Patches for the autorest-generated pilots client. + +This file can be used to customize the generated code for the pilots client. +When adding new classes to this file, make sure to also add them to the +__all__ list in the corresponding file in the patches directory. +""" + +from __future__ import annotations + +__all__ = [ + "PilotsOperations", +] + +from typing import Any, Unpack + +from azure.core.tracing.decorator_async import distributed_trace_async + +from ..._generated.aio.operations._operations import PilotsOperations as _PilotsOperations +from .common import make_search_body, make_summary_body, SearchKwargs, SummaryKwargs + +# We're intentionally ignoring overrides here because we want to change the interface. +# mypy: disable-error-code=override + + +class PilotsOperations(_PilotsOperations): + @distributed_trace_async + async def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]: + """TODO""" + return await super().search(**make_search_body(**kwargs)) + + @distributed_trace_async + async def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]: + """TODO""" + return await super().summary(**make_summary_body(**kwargs)) diff --git a/diracx-client/src/diracx/client/patches/pilots/common.py b/diracx-client/src/diracx/client/patches/pilots/common.py new file mode 100644 index 000000000..b332f8382 --- /dev/null +++ b/diracx-client/src/diracx/client/patches/pilots/common.py @@ -0,0 +1,85 @@ +"""Utilities which are common to the sync and async pilots operator patches.""" + +from __future__ import annotations + +__all__ = [ + "make_search_body", + "SearchKwargs", + "make_summary_body", + "SummaryKwargs", +] + +import json +from io import BytesIO +from typing import Any, IO, TypedDict, Unpack, cast, Literal + +from diracx.core.models import SearchSpec + + +class ResponseExtra(TypedDict, total=False): + content_type: str + headers: dict[str, str] + params: dict[str, str] + cls: Any + + +class SearchBody(TypedDict, total=False): + parameters: list[str] | None + search: list[SearchSpec] | None + sort: list[str] | None + + +class SearchExtra(ResponseExtra, total=False): + page: int + per_page: int + + +class SearchKwargs(SearchBody, SearchExtra): ... + + +class UnderlyingSearchArgs(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + + +def make_search_body(**kwargs: Unpack[SearchKwargs]) -> UnderlyingSearchArgs: + body: SearchBody = {} + for key in SearchBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["parameters", "search", "sort"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingSearchArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(SearchExtra, kwargs)) + return result + + +class SummaryBody(TypedDict, total=False): + grouping: list[str] + search: list[str] + + +class SummaryKwargs(SummaryBody, ResponseExtra): ... 
+ + +class UnderlyingSummaryArgs(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + + +def make_summary_body(**kwargs: Unpack[SummaryKwargs]) -> UnderlyingSummaryArgs: + body: SummaryBody = {} + for key in SummaryBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["grouping", "search"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingSummaryArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(ResponseExtra, kwargs)) + return result diff --git a/diracx-client/src/diracx/client/patches/pilots/sync.py b/diracx-client/src/diracx/client/patches/pilots/sync.py new file mode 100644 index 000000000..6e404aec6 --- /dev/null +++ b/diracx-client/src/diracx/client/patches/pilots/sync.py @@ -0,0 +1,34 @@ +"""Patches for the autorest-generated pilots client. + +This file can be used to customize the generated code for the pilots client. +When adding new classes to this file, make sure to also add them to the +__all__ list in the corresponding file in the patches directory. +""" + +from __future__ import annotations + +__all__ = [ + "PilotsOperations", +] + +from typing import Any, Unpack + +from azure.core.tracing.decorator import distributed_trace + +from ..._generated.operations._operations import PilotsOperations as _PilotsOperations +from .common import make_search_body, make_summary_body, SearchKwargs, SummaryKwargs + +# We're intentionally ignoring overrides here because we want to change the interface. +# mypy: disable-error-code=override + + +class PilotsOperations(_PilotsOperations): + @distributed_trace + def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]: + """TODO""" + return super().search(**make_search_body(**kwargs)) + + @distributed_trace + def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]: + """TODO""" + return super().summary(**make_summary_body(**kwargs)) From b60f4f29514f6a29611d70394c13d60187dcb51a Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Mon, 14 Jul 2025 14:28:48 +0200 Subject: [PATCH 26/33] fix: Minor fixes, and testing pilot summary --- .../src/diracx/logic/pilots/management.py | 2 +- .../src/diracx/routers/pilots/management.py | 2 +- .../src/diracx/routers/pilots/query.py | 1 - diracx-routers/tests/pilots/test_query.py | 43 +++++++++++++++++++ 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py index feaad0725..417d3a9a6 100644 --- a/diracx-logic/src/diracx/logic/pilots/management.py +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -24,7 +24,7 @@ async def register_new_pilots( status: str, pilot_job_references: dict[str, str] | None, ): - # [IMPORTANT] Check unicity of pilot references + # [IMPORTANT] Check unicity of pilot stamps # If a pilot already exists, we raise an error (transaction will rollback) existing_pilots = await get_pilots_by_stamp( pilot_db=pilot_db, pilot_stamps=pilot_stamps diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 7b632f1cc..8d6bb2b5e 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -208,7 +208,7 @@ async def update_pilot_fields( # Prevent someone who stole a pilot X509 
to modify thousands of pilots at a time
     # (It would be still able to modify thousands of pilots, but slower)
-    # We are not able to affirm that this pilots modifies itself
+    # We are not able to affirm that this pilot modifies itself
     if GENERIC_PILOT in user_info.properties:
         if len(pilot_stamps) != 1:
             raise HTTPException(
diff --git a/diracx-routers/src/diracx/routers/pilots/query.py b/diracx-routers/src/diracx/routers/pilots/query.py
index 3afbd7cfd..56655b46c 100644
--- a/diracx-routers/src/diracx/routers/pilots/query.py
+++ b/diracx-routers/src/diracx/routers/pilots/query.py
@@ -156,7 +156,6 @@ async def summary(
     check_permissions: CheckPilotManagementPolicyCallable,
 ):
     """Show information suitable for plotting."""
-    # TODO: Test me.
     await check_permissions(action=ActionType.READ_PILOT_FIELDS)

     return await summary_bl(
diff --git a/diracx-routers/tests/pilots/test_query.py b/diracx-routers/tests/pilots/test_query.py
index 0672e25af..c6d5cedb4 100644
--- a/diracx-routers/tests/pilots/test_query.py
+++ b/diracx-routers/tests/pilots/test_query.py
@@ -3,6 +3,7 @@
 from __future__ import annotations

 import pytest
+from fastapi.testclient import TestClient

 from diracx.core.exceptions import InvalidQueryError
 from diracx.core.models import (
@@ -82,6 +83,48 @@ async def populated_pilot_client(normal_test_client):
     yield normal_test_client


+async def test_pilot_summary(populated_pilot_client: TestClient):
+    # Group by StatusReason
+    r = populated_pilot_client.post(
+        "/api/pilots/summary",
+        json={
+            "grouping": ["StatusReason"],
+        },
+    )
+
+    assert r.status_code == 200
+
+    assert sum([el["count"] for el in r.json()]) == N
+    assert len(r.json()) == len(PILOT_REASONS)
+
+    # Group by CurrentJobID
+    r = populated_pilot_client.post(
+        "/api/pilots/summary",
+        json={
+            "grouping": ["CurrentJobID"],
+        },
+    )
+
+    assert r.status_code == 200
+
+    assert all(el["count"] == 1 for el in r.json())
+    assert len(r.json()) == N
+
+    # Group by CurrentJobID where BenchMark < 10^2
+    r = populated_pilot_client.post(
+        "/api/pilots/summary",
+        json={
+            "grouping": ["CurrentJobID"],
+            "search": [{"parameter": "BenchMark", "operator": "lt", "value": 10**2}],
+        },
+    )
+
+    assert r.status_code == 200, r.json()
+
+    assert all(el["count"] == 1 for el in r.json())
+    assert len(r.json()) == 10
+
+
 @pytest.fixture
 async def search(populated_pilot_client):
     async def _search(

From 6674d37ffc5900e1f347833c74a801cc8956674f Mon Sep 17 00:00:00 2001
From: Robin VAN DE MERGHEL
Date: Wed, 16 Jul 2025 13:46:31 +0200
Subject: [PATCH 27/33] fix: Fix client (missing comma)

---
 .../src/diracx/client/_generated/aio/operations/_patch.py | 2 +-
 diracx-client/src/diracx/client/_generated/operations/_patch.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py b/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py
index aa1cbd79d..0c70ce3e9 100644
--- a/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py
+++ b/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py
@@ -11,7 +11,7 @@
 __all__ = [
     "AuthOperations",
     "JobsOperations",
-    "PilotsOperations"
+    "PilotsOperations",
 ] # Add all objects you want publicly available to users at this package level
 from ....patches.auth.aio import AuthOperations
diff --git a/diracx-client/src/diracx/client/_generated/operations/_patch.py b/diracx-client/src/diracx/client/_generated/operations/_patch.py
index 9f1e07b4e..b14e98b84 100644
---
a/diracx-client/src/diracx/client/_generated/operations/_patch.py +++ b/diracx-client/src/diracx/client/_generated/operations/_patch.py @@ -11,7 +11,7 @@ __all__ = [ "AuthOperations", "JobsOperations", - "PilotsOperations" + "PilotsOperations", ] # Add all objects you want publicly available to users at this package level from ...patches.auth.sync import AuthOperations From 3e358bba576443af870db3cc492a9859ef9bf6bc Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Fri, 1 Aug 2025 10:37:22 +0200 Subject: [PATCH 28/33] fix: Add patch for the pilot client --- .../src/diracx/client/patches/pilots/aio.py | 21 ++++++- .../diracx/client/patches/pilots/common.py | 62 ++++++++++++++++++- .../src/diracx/client/patches/pilots/sync.py | 21 ++++++- 3 files changed, 101 insertions(+), 3 deletions(-) diff --git a/diracx-client/src/diracx/client/patches/pilots/aio.py b/diracx-client/src/diracx/client/patches/pilots/aio.py index 496aaeeda..ac533a67c 100644 --- a/diracx-client/src/diracx/client/patches/pilots/aio.py +++ b/diracx-client/src/diracx/client/patches/pilots/aio.py @@ -16,7 +16,16 @@ from azure.core.tracing.decorator_async import distributed_trace_async from ..._generated.aio.operations._operations import PilotsOperations as _PilotsOperations -from .common import make_search_body, make_summary_body, SearchKwargs, SummaryKwargs +from .common import ( + make_search_body, + make_summary_body, + make_add_pilot_stamps_body, + make_update_pilot_fields_body, + SearchKwargs, + SummaryKwargs, + AddPilotStampsKwargs, + UpdatePilotFieldsKwargs +) # We're intentionally ignoring overrides here because we want to change the interface. # mypy: disable-error-code=override @@ -32,3 +41,13 @@ async def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]: async def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]: """TODO""" return await super().summary(**make_summary_body(**kwargs)) + + @distributed_trace_async + async def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None: + """TODO""" + return await super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs)) + + @distributed_trace_async + async def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None: + """TODO""" + return await super().update_pilot_fields(**make_update_pilot_fields_body(**kwargs)) diff --git a/diracx-client/src/diracx/client/patches/pilots/common.py b/diracx-client/src/diracx/client/patches/pilots/common.py index b332f8382..f7f22105d 100644 --- a/diracx-client/src/diracx/client/patches/pilots/common.py +++ b/diracx-client/src/diracx/client/patches/pilots/common.py @@ -7,13 +7,17 @@ "SearchKwargs", "make_summary_body", "SummaryKwargs", + "AddPilotStampsKwargs", + "make_add_pilot_stamps_body", + "UpdatePilotFieldsKwargs", + "make_update_pilot_fields_body" ] import json from io import BytesIO from typing import Any, IO, TypedDict, Unpack, cast, Literal -from diracx.core.models import SearchSpec +from diracx.core.models import SearchSpec, PilotStatus, PilotFieldsMapping class ResponseExtra(TypedDict, total=False): @@ -23,6 +27,7 @@ class ResponseExtra(TypedDict, total=False): cls: Any +# ------------------ Search ------------------ class SearchBody(TypedDict, total=False): parameters: list[str] | None search: list[SearchSpec] | None @@ -56,6 +61,7 @@ def make_search_body(**kwargs: Unpack[SearchKwargs]) -> UnderlyingSearchArgs: result.update(cast(SearchExtra, kwargs)) return result +# ------------------ Summary ------------------ class 
SummaryBody(TypedDict, total=False): grouping: list[str] @@ -83,3 +89,57 @@ def make_summary_body(**kwargs: Unpack[SummaryKwargs]) -> UnderlyingSummaryArgs: result: UnderlyingSummaryArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))} result.update(cast(ResponseExtra, kwargs)) return result + +# ------------------ AddPilotStamps ------------------ + +class AddPilotStampsBody(TypedDict, total=False): + pilot_stamps: list[str] + grid_type: str + grid_site: str + pilot_references: dict[str, str] + pilot_status: PilotStatus + +class AddPilotStampsKwargs(AddPilotStampsBody, ResponseExtra): ... + +class UnderlyingAddPilotStampsArgs(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + +def make_add_pilot_stamps_body(**kwargs: Unpack[AddPilotStampsKwargs]) -> UnderlyingAddPilotStampsArgs: + body: AddPilotStampsBody = {} + for key in AddPilotStampsBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["pilot_stamps", "grid_type", "grid_site", "pilot_references", "pilot_status"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingAddPilotStampsArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(ResponseExtra, kwargs)) + return result + +# ------------------ UpdatePilotFields ------------------ + +class UpdatePilotFieldsBody(TypedDict, total=False): + pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] + +class UpdatePilotFieldsKwargs(UpdatePilotFieldsBody, ResponseExtra): ... + +class UnderlyingUpdatePilotFields(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + +def make_update_pilot_fields_body(**kwargs: Unpack[UpdatePilotFieldsKwargs]) -> UnderlyingUpdatePilotFields: + body: UpdatePilotFieldsBody = {} + for key in UpdatePilotFieldsBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["pilot_stamps_to_fields_mapping"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingUpdatePilotFields = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(ResponseExtra, kwargs)) + return result diff --git a/diracx-client/src/diracx/client/patches/pilots/sync.py b/diracx-client/src/diracx/client/patches/pilots/sync.py index 6e404aec6..744cee161 100644 --- a/diracx-client/src/diracx/client/patches/pilots/sync.py +++ b/diracx-client/src/diracx/client/patches/pilots/sync.py @@ -16,7 +16,16 @@ from azure.core.tracing.decorator import distributed_trace from ..._generated.operations._operations import PilotsOperations as _PilotsOperations -from .common import make_search_body, make_summary_body, SearchKwargs, SummaryKwargs +from .common import ( + make_search_body, + make_summary_body, + make_add_pilot_stamps_body, + make_update_pilot_fields_body, + SearchKwargs, + SummaryKwargs, + AddPilotStampsKwargs, + UpdatePilotFieldsKwargs +) # We're intentionally ignoring overrides here because we want to change the interface. 
# mypy: disable-error-code=override @@ -32,3 +41,13 @@ def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]: def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]: """TODO""" return super().summary(**make_summary_body(**kwargs)) + + @distributed_trace + def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None: + """TODO""" + return super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs)) + + @distributed_trace + def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None: + """TODO""" + return super().update_pilot_fields(**make_update_pilot_fields_body(**kwargs)) From cae74110c1eb6d41c2a71fcfc2794b0da2e32ef3 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Fri, 1 Aug 2025 11:29:42 +0200 Subject: [PATCH 29/33] fix: Use also VO as parameter (in case DIRAC has a bad VO) --- .../src/diracx/client/_generated/models/_models.py | 8 ++++++++ .../src/diracx/client/patches/pilots/common.py | 3 ++- .../src/diracx/routers/pilots/management.py | 3 ++- diracx-routers/tests/pilots/test_pilot_creation.py | 11 ++++++----- .../src/gubbins/client/_generated/models/_models.py | 8 ++++++++ 5 files changed, 26 insertions(+), 7 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index b101fdb10..5a1ca98a9 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -101,6 +101,8 @@ class BodyPilotsAddPilotStamps(_serialization.Model): :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required. :vartype pilot_stamps: list[str] + :ivar vo: Pilot virtual organization. Required. + :vartype vo: str :ivar grid_type: Grid type of the pilots. :vartype grid_type: str :ivar grid_site: Pilots grid site. @@ -116,10 +118,12 @@ class BodyPilotsAddPilotStamps(_serialization.Model): _validation = { "pilot_stamps": {"required": True}, + "vo": {"required": True}, } _attribute_map = { "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"}, + "vo": {"key": "vo", "type": "str"}, "grid_type": {"key": "grid_type", "type": "str"}, "grid_site": {"key": "grid_site", "type": "str"}, "destination_site": {"key": "destination_site", "type": "str"}, @@ -131,6 +135,7 @@ def __init__( self, *, pilot_stamps: List[str], + vo: str, grid_type: str = "Dirac", grid_site: str = "Unknown", destination_site: str = "NotAssigned", @@ -141,6 +146,8 @@ def __init__( """ :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required. :paramtype pilot_stamps: list[str] + :keyword vo: Pilot virtual organization. Required. + :paramtype vo: str :keyword grid_type: Grid type of the pilots. :paramtype grid_type: str :keyword grid_site: Pilots grid site. 
@@ -155,6 +162,7 @@ def __init__( """ super().__init__(**kwargs) self.pilot_stamps = pilot_stamps + self.vo = vo self.grid_type = grid_type self.grid_site = grid_site self.destination_site = destination_site diff --git a/diracx-client/src/diracx/client/patches/pilots/common.py b/diracx-client/src/diracx/client/patches/pilots/common.py index f7f22105d..3f5ec8c4b 100644 --- a/diracx-client/src/diracx/client/patches/pilots/common.py +++ b/diracx-client/src/diracx/client/patches/pilots/common.py @@ -98,6 +98,7 @@ class AddPilotStampsBody(TypedDict, total=False): grid_site: str pilot_references: dict[str, str] pilot_status: PilotStatus + vo: str class AddPilotStampsKwargs(AddPilotStampsBody, ResponseExtra): ... @@ -111,7 +112,7 @@ def make_add_pilot_stamps_body(**kwargs: Unpack[AddPilotStampsKwargs]) -> Underl for key in AddPilotStampsBody.__optional_keys__: if key not in kwargs: continue - key = cast(Literal["pilot_stamps", "grid_type", "grid_site", "pilot_references", "pilot_status"], key) + key = cast(Literal["pilot_stamps", "grid_type", "grid_site", "pilot_references", "pilot_status", "vo"], key) value = kwargs.pop(key) if value is not None: body[key] = value diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 8d6bb2b5e..21ff63796 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -41,6 +41,7 @@ async def add_pilot_stamps( list[str], Body(description="List of the pilot stamps we want to add to the db."), ], + vo: Annotated[str, Body(description="Pilot virtual organization.")], check_permissions: CheckPilotManagementPolicyCallable, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], grid_type: Annotated[str, Body(description="Grid type of the pilots.")] = "Dirac", @@ -79,7 +80,7 @@ async def add_pilot_stamps( await register_new_pilots( pilot_db=pilot_db, pilot_stamps=pilot_stamps, - vo=user_info.vo, + vo=vo, grid_type=grid_type, grid_site=grid_site, destination_site=destination_site, diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py index 11293ff12..2171bbf9d 100644 --- a/diracx-routers/tests/pilots/test_pilot_creation.py +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -41,7 +41,7 @@ async def test_create_pilots(normal_test_client): pilot_stamps = [f"stamps_{i}" for i in range(N)] # -------------- Bulk insert -------------- - body = {"pilot_stamps": pilot_stamps} + body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO} r = normal_test_client.post( "/api/pilots/", @@ -54,6 +54,7 @@ async def test_create_pilots(normal_test_client): body = { "pilot_stamps": [pilot_stamps[0], pilot_stamps[0] + "_new_one"], + "vo": MAIN_VO, } r = normal_test_client.post( @@ -73,7 +74,7 @@ async def test_create_pilots(normal_test_client): # -------------- Register a pilot that does not exists **but** was called before in an error -------------- # To prove that, if I tried to register a pilot that does not exist with one that already exists, # i can normally add the one that did not exist before (it should not have added it before) - body = {"pilot_stamps": [pilot_stamps[0] + "_new_one"]} + body = {"pilot_stamps": [pilot_stamps[0] + "_new_one"], "vo": MAIN_VO} r = normal_test_client.post( "/api/pilots/", @@ -90,7 +91,7 @@ async def test_create_pilot_and_delete_it(normal_test_client): pilot_stamp = "stamps_1" # -------------- Insert -------------- - body = 
{"pilot_stamps": [pilot_stamp]} + body = {"pilot_stamps": [pilot_stamp], "vo": MAIN_VO} # Create a pilot r = normal_test_client.post( @@ -134,7 +135,7 @@ async def test_create_pilot_and_modify_it(normal_test_client): pilot_stamps = ["stamps_1", "stamp_2"] # -------------- Insert -------------- - body = {"pilot_stamps": pilot_stamps} + body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO} # Create pilots r = normal_test_client.post( @@ -191,7 +192,7 @@ async def test_delete_pilots_by_age_and_stamp(normal_test_client): pilot_stamps = [f"stamp_{i}" for i in range(100)] # -------------- Insert all pilots -------------- - body = {"pilot_stamps": pilot_stamps} + body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO} r = normal_test_client.post("/api/pilots/", json=body) assert r.status_code == 200, r.json() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 5d5e4ad8c..751efcfb9 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -101,6 +101,8 @@ class BodyPilotsAddPilotStamps(_serialization.Model): :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required. :vartype pilot_stamps: list[str] + :ivar vo: Pilot virtual organization. Required. + :vartype vo: str :ivar grid_type: Grid type of the pilots. :vartype grid_type: str :ivar grid_site: Pilots grid site. @@ -116,10 +118,12 @@ class BodyPilotsAddPilotStamps(_serialization.Model): _validation = { "pilot_stamps": {"required": True}, + "vo": {"required": True}, } _attribute_map = { "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"}, + "vo": {"key": "vo", "type": "str"}, "grid_type": {"key": "grid_type", "type": "str"}, "grid_site": {"key": "grid_site", "type": "str"}, "destination_site": {"key": "destination_site", "type": "str"}, @@ -131,6 +135,7 @@ def __init__( self, *, pilot_stamps: List[str], + vo: str, grid_type: str = "Dirac", grid_site: str = "Unknown", destination_site: str = "NotAssigned", @@ -141,6 +146,8 @@ def __init__( """ :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required. :paramtype pilot_stamps: list[str] + :keyword vo: Pilot virtual organization. Required. + :paramtype vo: str :keyword grid_type: Grid type of the pilots. :paramtype grid_type: str :keyword grid_site: Pilots grid site. @@ -155,6 +162,7 @@ def __init__( """ super().__init__(**kwargs) self.pilot_stamps = pilot_stamps + self.vo = vo self.grid_type = grid_type self.grid_site = grid_site self.destination_site = destination_site From 8a978faf1e78f5e0a5aae6305f7fcb6633235138 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Fri, 1 Aug 2025 10:45:03 +0200 Subject: [PATCH 30/33] fix: Use also VO as parameter (in case DIRAC has a bad VO) --- diracx-routers/src/diracx/routers/pilots/management.py | 1 + 1 file changed, 1 insertion(+) diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 21ff63796..8bb9ea514 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -56,6 +56,7 @@ async def add_pilot_stamps( pilot_status: Annotated[ PilotStatus, Body(description="Status of the pilots.") ] = PilotStatus.SUBMITTED, + vo: str | None = None, ): """Endpoint where a you can create pilots with their references. 
From 82c552cdc9e9322657875073313b87fa6514f4d6 Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Fri, 13 Jun 2025 13:41:30 +0200 Subject: [PATCH 31/33] feat: Add pilot management: create/delete/patch and query --- .../src/diracx/db/sql/utils/functions.py | 90 ++++++++++++++++++- 1 file changed, 87 insertions(+), 3 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/utils/functions.py b/diracx-db/src/diracx/db/sql/utils/functions.py index 34cb2a0da..536412406 100644 --- a/diracx-db/src/diracx/db/sql/utils/functions.py +++ b/diracx-db/src/diracx/db/sql/utils/functions.py @@ -2,16 +2,30 @@ import hashlib from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Sequence, Type -from sqlalchemy import DateTime, func +from sqlalchemy import DateTime, RowMapping, asc, desc, func, select +from sqlalchemy.ext.asyncio import AsyncConnection from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql import expression +from sqlalchemy.sql import ColumnElement, expression + +from diracx.core.exceptions import DiracFormattedError, InvalidQueryError if TYPE_CHECKING: from sqlalchemy.types import TypeEngine +def _get_columns(table, parameters): + columns = [x for x in table.columns] + if parameters: + if unrecognised_parameters := set(parameters) - set(table.columns.keys()): + raise InvalidQueryError( + f"Unrecognised parameters requested {unrecognised_parameters}" + ) + columns = [c for c in columns if c.name in parameters] + return columns + + class utcnow(expression.FunctionElement): # noqa: N801 type: TypeEngine = DateTime() inherit_cache: bool = True @@ -140,3 +154,73 @@ def substract_date(**kwargs: float) -> datetime: def hash(code: str): return hashlib.sha256(code.encode()).hexdigest() + + +def raw_hash(code: str): + return hashlib.sha256(code.encode()).digest() + + +async def fetch_records_bulk_or_raises( + conn: AsyncConnection, + model: Any, # Here, we currently must use `Any` because `declarative_base()` returns any + missing_elements_error_cls: Type[DiracFormattedError], + column_attribute_name: str, + column_name: str, + elements_to_fetch: list, + order_by: tuple[str, str] | None = None, + allow_more_than_one_result_per_input: bool = False, + allow_no_result: bool = False, +) -> Sequence[RowMapping]: + """Fetches a list of elements in a table, returns a list of elements. + All elements from the `element_to_fetch` **must** be present. + Raises the specified error if at least one is missing. 
+ + Example: + fetch_records_bulk_or_raises( + self.conn, + PilotAgents, + PilotNotFound, + "pilot_id", + "PilotID", + [1,2,3] + ) + + """ + assert elements_to_fetch + + # Get the column that needs to be in elements_to_fetch + column = getattr(model, column_attribute_name) + + # Create the request + stmt = select(model).with_for_update().where(column.in_(elements_to_fetch)) + + if order_by: + column_name_to_order_by, direction = order_by + column_to_order_by = getattr(model, column_name_to_order_by) + + operator: ColumnElement = ( + asc(column_to_order_by) if direction == "asc" else desc(column_to_order_by) + ) + + stmt = stmt.order_by(operator) + + # Transform into dictionaries + raw_results = await conn.execute(stmt) + results = raw_results.mappings().all() + + # Detects duplicates + if not allow_more_than_one_result_per_input: + if len(results) > len(elements_to_fetch): + raise RuntimeError("Seems to have duplicates in the database.") + + if not allow_no_result: + # Checks if we have every elements we wanted + found_keys = {row[column_name] for row in results} + missing = set(elements_to_fetch) - found_keys + + if missing: + raise missing_elements_error_cls( + data={column_name: str(missing)}, detail=str(missing) + ) + + return results From 619e3368dc49c3e80afdbe13355be2a151c7b8af Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Wed, 30 Jul 2025 14:50:39 +0200 Subject: [PATCH 32/33] feat: POC to have legacy pilots sending logs --- .../src/diracx/client/_generated/_client.py | 12 +- .../diracx/client/_generated/aio/_client.py | 12 +- .../_generated/aio/operations/__init__.py | 2 + .../_generated/aio/operations/_operations.py | 250 +++++++++++++++ .../client/_generated/models/__init__.py | 4 + .../client/_generated/models/_models.py | 82 +++++ .../client/_generated/operations/__init__.py | 2 + .../_generated/operations/_operations.py | 288 ++++++++++++++++++ diracx-core/src/diracx/core/models.py | 7 + diracx-db/pyproject.toml | 1 + diracx-db/src/diracx/db/os/__init__.py | 3 +- diracx-db/src/diracx/db/os/pilot_logs.py | 22 ++ diracx-db/src/diracx/db/os/utils.py | 41 ++- diracx-db/tests/opensearch/test_search.py | 168 ++++++---- diracx-logic/src/diracx/logic/pilots/query.py | 39 +++ .../src/diracx/logic/pilots/resources.py | 45 +++ diracx-routers/pyproject.toml | 2 + .../src/diracx/routers/dependencies.py | 2 + .../legacy_pilot_resources/__init__.py | 11 + .../legacy_pilot_resources/access_policies.py | 39 +++ .../routers/legacy_pilot_resources/logs.py | 51 ++++ .../diracx/routers/pilots/access_policies.py | 18 +- .../src/diracx/routers/pilots/query.py | 123 +++++++- .../tests/pilots/test_pilot_logging.py | 233 ++++++++++++++ .../src/diracx/testing/mock_osdb.py | 34 ++- .../src/gubbins/client/_generated/_client.py | 6 +- .../gubbins/client/_generated/aio/_client.py | 6 +- .../_generated/aio/operations/__init__.py | 2 + .../_generated/aio/operations/_operations.py | 250 +++++++++++++++ .../client/_generated/models/__init__.py | 4 + .../client/_generated/models/_models.py | 82 +++++ .../client/_generated/operations/__init__.py | 2 + .../_generated/operations/_operations.py | 288 ++++++++++++++++++ 33 files changed, 2055 insertions(+), 76 deletions(-) create mode 100644 diracx-db/src/diracx/db/os/pilot_logs.py create mode 100644 diracx-logic/src/diracx/logic/pilots/resources.py create mode 100644 diracx-routers/src/diracx/routers/legacy_pilot_resources/__init__.py create mode 100644 diracx-routers/src/diracx/routers/legacy_pilot_resources/access_policies.py create mode 100644 
diracx-routers/src/diracx/routers/legacy_pilot_resources/logs.py create mode 100644 diracx-routers/tests/pilots/test_pilot_logging.py diff --git a/diracx-client/src/diracx/client/_generated/_client.py b/diracx-client/src/diracx/client/_generated/_client.py index 9e37d5081..80e44eb8b 100644 --- a/diracx-client/src/diracx/client/_generated/_client.py +++ b/diracx-client/src/diracx/client/_generated/_client.py @@ -15,7 +15,14 @@ from . import models as _models from ._configuration import DiracConfiguration from ._utils.serialization import Deserializer, Serializer -from .operations import AuthOperations, ConfigOperations, JobsOperations, PilotsOperations, WellKnownOperations +from .operations import ( + AuthOperations, + ConfigOperations, + JobsOperations, + PilotsLegacyOperations, + PilotsOperations, + WellKnownOperations, +) class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -31,6 +38,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype jobs: _generated.operations.JobsOperations :ivar pilots: PilotsOperations operations :vartype pilots: _generated.operations.PilotsOperations + :ivar pilots_legacy: PilotsLegacyOperations operations + :vartype pilots_legacy: _generated.operations.PilotsLegacyOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -68,6 +77,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots_legacy = PilotsLegacyOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/diracx-client/src/diracx/client/_generated/aio/_client.py b/diracx-client/src/diracx/client/_generated/aio/_client.py index 397b7f989..5083a584f 100644 --- a/diracx-client/src/diracx/client/_generated/aio/_client.py +++ b/diracx-client/src/diracx/client/_generated/aio/_client.py @@ -15,7 +15,14 @@ from .. import models as _models from .._utils.serialization import Deserializer, Serializer from ._configuration import DiracConfiguration -from .operations import AuthOperations, ConfigOperations, JobsOperations, PilotsOperations, WellKnownOperations +from .operations import ( + AuthOperations, + ConfigOperations, + JobsOperations, + PilotsLegacyOperations, + PilotsOperations, + WellKnownOperations, +) class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -31,6 +38,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype jobs: _generated.aio.operations.JobsOperations :ivar pilots: PilotsOperations operations :vartype pilots: _generated.aio.operations.PilotsOperations + :ivar pilots_legacy: PilotsLegacyOperations operations + :vartype pilots_legacy: _generated.aio.operations.PilotsLegacyOperations :keyword endpoint: Service URL. Required. Default value is "". 
:paramtype endpoint: str """ @@ -68,6 +77,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots_legacy = PilotsLegacyOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py index be02776fc..53c8a8f82 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py @@ -15,6 +15,7 @@ from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore from ._operations import PilotsOperations # type: ignore +from ._operations import PilotsLegacyOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -26,6 +27,7 @@ "ConfigOperations", "JobsOperations", "PilotsOperations", + "PilotsLegacyOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index e1aad0601..df72fed83 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -55,6 +55,8 @@ build_pilots_add_pilot_stamps_request, build_pilots_delete_pilots_request, build_pilots_get_pilot_jobs_request, + build_pilots_legacy_send_message_request, + build_pilots_search_logs_request, build_pilots_search_request, build_pilots_summary_request, build_pilots_update_pilot_fields_request, @@ -2649,6 +2651,143 @@ async def search( return deserialized # type: ignore + @overload + async def search_logs( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def search_logs( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". 
+ :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def search_logs( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_logs_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + @overload async def summary( self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any @@ -2743,3 +2882,114 @@ async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsLegacyOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`pilots_legacy` attribute. 
+ """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + async def send_message( + self, body: _models.BodyPilotsLegacySendMessage, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsLegacySendMessage + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def send_message(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def send_message(self, body: Union[_models.BodyPilotsLegacySendMessage, IO[bytes]], **kwargs: Any) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Is either a BodyPilotsLegacySendMessage type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsLegacySendMessage or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsLegacySendMessage") + + _request = build_pilots_legacy_send_message_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore diff --git a/diracx-client/src/diracx/client/_generated/models/__init__.py b/diracx-client/src/diracx/client/_generated/models/__init__.py index 33716f267..ce33e799c 100644 --- a/diracx-client/src/diracx/client/_generated/models/__init__.py +++ b/diracx-client/src/diracx/client/_generated/models/__init__.py @@ -15,6 +15,7 @@ BodyAuthGetOidcToken, BodyAuthGetOidcTokenGrantType, BodyPilotsAddPilotStamps, + BodyPilotsLegacySendMessage, BodyPilotsUpdatePilotFields, GroupInfo, HTTPValidationError, @@ -23,6 +24,7 @@ InsertedJob, JobCommand, JobStatusUpdate, + LogLine, Metadata, OpenIDConfiguration, PilotFieldsMapping, @@ -66,6 +68,7 @@ "BodyAuthGetOidcToken", "BodyAuthGetOidcTokenGrantType", "BodyPilotsAddPilotStamps", + "BodyPilotsLegacySendMessage", "BodyPilotsUpdatePilotFields", "GroupInfo", "HTTPValidationError", @@ -74,6 +77,7 @@ "InsertedJob", "JobCommand", "JobStatusUpdate", + "LogLine", "Metadata", "OpenIDConfiguration", "PilotFieldsMapping", diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index 5a1ca98a9..187d8d87b 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -170,6 +170,41 @@ def __init__( self.pilot_status = pilot_status +class BodyPilotsLegacySendMessage(_serialization.Model): + """Body_pilots/legacy_send_message. + + All required parameters must be populated in order to send to server. + + :ivar lines: Message from the pilot to the logging system. Required. + :vartype lines: list[~_generated.models.LogLine] + :ivar pilot_stamp: PilotStamp, required as legacy pilots do not have a token with stamp in it. + Required. 
+ :vartype pilot_stamp: str + """ + + _validation = { + "lines": {"required": True}, + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "lines": {"key": "lines", "type": "[LogLine]"}, + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + } + + def __init__(self, *, lines: List["_models.LogLine"], pilot_stamp: str, **kwargs: Any) -> None: + """ + :keyword lines: Message from the pilot to the logging system. Required. + :paramtype lines: list[~_generated.models.LogLine] + :keyword pilot_stamp: PilotStamp, required as legacy pilots do not have a token with stamp in + it. Required. + :paramtype pilot_stamp: str + """ + super().__init__(**kwargs) + self.lines = lines + self.pilot_stamp = pilot_stamp + + class BodyPilotsUpdatePilotFields(_serialization.Model): """Body_pilots_update_pilot_fields. @@ -511,6 +546,53 @@ def __init__( self.source = source +class LogLine(_serialization.Model): + """LogLine. + + All required parameters must be populated in order to send to server. + + :ivar timestamp: Timestamp. Required. + :vartype timestamp: str + :ivar severity: Severity. Required. + :vartype severity: str + :ivar message: Message. Required. + :vartype message: str + :ivar scope: Scope. Required. + :vartype scope: str + """ + + _validation = { + "timestamp": {"required": True}, + "severity": {"required": True}, + "message": {"required": True}, + "scope": {"required": True}, + } + + _attribute_map = { + "timestamp": {"key": "timestamp", "type": "str"}, + "severity": {"key": "severity", "type": "str"}, + "message": {"key": "message", "type": "str"}, + "scope": {"key": "scope", "type": "str"}, + } + + def __init__(self, *, timestamp: str, severity: str, message: str, scope: str, **kwargs: Any) -> None: + """ + :keyword timestamp: Timestamp. Required. + :paramtype timestamp: str + :keyword severity: Severity. Required. + :paramtype severity: str + :keyword message: Message. Required. + :paramtype message: str + :keyword scope: Scope. Required. + :paramtype scope: str + """ + super().__init__(**kwargs) + self.timestamp = timestamp + self.severity = severity + self.message = message + self.scope = scope + + class Metadata(_serialization.Model): """Metadata. 
diff --git a/diracx-client/src/diracx/client/_generated/operations/__init__.py b/diracx-client/src/diracx/client/_generated/operations/__init__.py index be02776fc..53c8a8f82 100644 --- a/diracx-client/src/diracx/client/_generated/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/operations/__init__.py @@ -15,6 +15,7 @@ from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore from ._operations import PilotsOperations # type: ignore +from ._operations import PilotsLegacyOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -26,6 +27,7 @@ "ConfigOperations", "JobsOperations", "PilotsOperations", + "PilotsLegacyOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index 3673bdf32..b6789fb77 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -691,6 +691,30 @@ def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) +def build_pilots_search_logs_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/search/logs" + + # Construct parameters + if page is not None: + _params["page"] = _SERIALIZER.query("page", page, "int") + if per_page is not None: + _params["per_page"] = _SERIALIZER.query("per_page", per_page, "int") + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + def build_pilots_summary_request(**kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) @@ -708,6 +732,20 @@ def build_pilots_summary_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) +def build_pilots_legacy_send_message_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/legacy/message" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -3282,6 +3320,143 @@ def search( return deserialized # type: ignore + @overload + def search_logs( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. 
+ :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def search_logs( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def search_logs( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_logs_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # 
type: ignore + + return deserialized # type: ignore + @overload def summary(self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: """Summary. @@ -3374,3 +3549,116 @@ def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsLegacyOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`pilots_legacy` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + def send_message( + self, body: _models.BodyPilotsLegacySendMessage, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsLegacySendMessage + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def send_message(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def send_message( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsLegacySendMessage, IO[bytes]], **kwargs: Any + ) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Is either a BodyPilotsLegacySendMessage type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsLegacySendMessage or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsLegacySendMessage") + + _request = build_pilots_legacy_send_message_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore diff --git a/diracx-core/src/diracx/core/models.py b/diracx-core/src/diracx/core/models.py index 773ebac11..fc32d30c6 100644 --- a/diracx-core/src/diracx/core/models.py +++ b/diracx-core/src/diracx/core/models.py @@ -306,3 +306,10 @@ class PilotStatus(StrEnum): ABORTED = "Aborted" #: Cannot get information about the pilot status: UNKNOWN = "Unknown" + + +class LogLine(BaseModel): + timestamp: str + severity: str + message: str + scope: str diff --git a/diracx-db/pyproject.toml b/diracx-db/pyproject.toml index 8a5e87d8a..2ebc5cca1 100644 --- a/diracx-db/pyproject.toml +++ b/diracx-db/pyproject.toml @@ -34,6 +34,7 @@ TaskQueueDB = "diracx.db.sql:TaskQueueDB" [project.entry-points."diracx.dbs.os"] JobParametersDB = "diracx.db.os:JobParametersDB" +PilotLogsDB = "diracx.db.os:PilotLogsDB" [build-system] requires = ["hatchling", "hatch-vcs"] diff --git a/diracx-db/src/diracx/db/os/__init__.py b/diracx-db/src/diracx/db/os/__init__.py index 535e2a954..d8a450754 100644 --- a/diracx-db/src/diracx/db/os/__init__.py +++ b/diracx-db/src/diracx/db/os/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations -__all__ = ("JobParametersDB",) +__all__ = ("JobParametersDB", "PilotLogsDB") from .job_parameters import JobParametersDB +from .pilot_logs import PilotLogsDB diff --git a/diracx-db/src/diracx/db/os/pilot_logs.py b/diracx-db/src/diracx/db/os/pilot_logs.py new file mode 100644 index 000000000..614c3cb50 --- /dev/null +++ b/diracx-db/src/diracx/db/os/pilot_logs.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from diracx.db.os.utils import BaseOSDB + + +class PilotLogsDB(BaseOSDB): + fields = { + "PilotStamp": {"type": "keyword"}, + "PilotID": {"type": "long"}, + "Severity": {"type": "keyword"}, + "Message": {"type": "text"}, + "VO": {"type": "keyword"}, + "TimeStamp": {"type": "date_nanos"}, + "Scope": {"type": "keyword"}, + } + index_prefix = "pilot_logs" + + def index_name(self, vo: str, doc_id: int) -> str: + split = int(int(doc_id) // 1e6) + # We split docs into smaller one (grouped by 
1 million pilot) + # Ex: pilot_logs_dteam_1030m + return f"{self.index_prefix}_{vo.lower()}_{split}m" diff --git a/diracx-db/src/diracx/db/os/utils.py b/diracx-db/src/diracx/db/os/utils.py index ea5d292e6..7beb0f104 100644 --- a/diracx-db/src/diracx/db/os/utils.py +++ b/diracx-db/src/diracx/db/os/utils.py @@ -1,5 +1,7 @@ from __future__ import annotations +from opensearchpy.helpers import async_bulk + __all__ = ("BaseOSDB",) import contextlib @@ -197,13 +199,35 @@ async def upsert(self, vo: str, doc_id: int, document: Any) -> None: response, ) + async def bulk_insert(self, index_name: str, docs: list[dict[str, Any]]) -> None: + """Bulk inserting to database.""" + n_inserted, failed = await async_bulk( + self.client, actions=[doc | {"_index": index_name} for doc in docs] + ) + logger.info("Inserted %d documents to %s", n_inserted, index_name) + + if failed: + logger.error("Fail to insert %d documents to %s", failed, index_name) + async def search( - self, parameters, search, sorts, *, per_page: int = 100, page: int | None = None - ) -> list[dict[str, Any]]: + self, + parameters, + search, + sorts, + *, + per_page: int = 10000, + page: int | None = None, + ) -> tuple[int, list[dict[str, Any]]]: """Search the database for matching results. See the DiracX search API documentation for details. """ + if page: + if page < 1: + raise InvalidQueryError("Page must be a positive integer") + if per_page < 1: + raise InvalidQueryError("Per page must be a positive integer") + body = {} if parameters: body["_source"] = parameters @@ -213,7 +237,12 @@ async def search( for sort in sorts: field_name = sort["parameter"] field_type = self.fields.get(field_name, {}).get("type") - require_type("sort", field_name, field_type, {"keyword", "long", "date"}) + require_type( + "sort", + field_name, + field_type, + {"keyword", "long", "date", "date_nanos"}, + ) body["sort"].append({field_name: {"order": sort["direction"]}}) params = {} @@ -226,17 +255,19 @@ async def search( ) hits = [hit["_source"] for hit in response["hits"]["hits"]] + total_hits = response["hits"]["total"]["value"] + # Dates are returned as strings, convert them to Python datetimes for hit in hits: for field_name in hit: if field_name not in self.fields: continue - if self.fields[field_name]["type"] == "date": + if self.fields[field_name]["type"] in ["date", "date_nanos"]: hit[field_name] = datetime.strptime( hit[field_name], "%Y-%m-%dT%H:%M:%S.%f%z" ) - return hits + return total_hits, hits def require_type(operator, field_name, field_type, allowed_types): diff --git a/diracx-db/tests/opensearch/test_search.py b/diracx-db/tests/opensearch/test_search.py index 93998ac3e..8013edd9a 100644 --- a/diracx-db/tests/opensearch/test_search.py +++ b/diracx-db/tests/opensearch/test_search.py @@ -120,15 +120,15 @@ async def prefilled_db(request): async def test_specified_parameters(prefilled_db: DummyOSDB): - results = await prefilled_db.search(None, [], []) - assert len(results) == 3 + total, results = await prefilled_db.search(None, [], []) + assert total == 3 assert DOC1 in results and DOC2 in results and DOC3 in results - results = await prefilled_db.search([], [], []) - assert len(results) == 3 + total, results = await prefilled_db.search([], [], []) + assert total == 3 assert DOC1 in results and DOC2 in results and DOC3 in results - results = await prefilled_db.search(["IntField"], [], []) + total, results = await prefilled_db.search(["IntField"], [], []) expected_results = [] for doc in [DOC1, DOC2, DOC3]: expected_doc = {key: doc[key] for key in 
{"IntField"}} @@ -136,58 +136,67 @@ async def test_specified_parameters(prefilled_db: DummyOSDB): # If it is the all() check below no longer makes sense assert expected_doc not in expected_results expected_results.append(expected_doc) - assert len(results) == len(expected_results) + assert total == len(expected_results) assert all(result in expected_results for result in results) - results = await prefilled_db.search(["IntField", "UnknownField"], [], []) + total, results = await prefilled_db.search(["IntField", "UnknownField"], [], []) expected_results = [ {"IntField": DOC1["IntField"], "UnknownField": DOC1["UnknownField"]}, {"IntField": DOC2["IntField"], "UnknownField": DOC2["UnknownField"]}, {"IntField": DOC3["IntField"]}, ] - assert len(results) == len(expected_results) + assert total == len(expected_results) assert all(result in expected_results for result in results) async def test_pagination_asc(prefilled_db: DummyOSDB): sort = [{"parameter": "IntField", "direction": "asc"}] - results = await prefilled_db.search(None, [], sort) + total, results = await prefilled_db.search(None, [], sort) assert results == [DOC3, DOC2, DOC1] + assert total == 3 # Pagination has no effect if a specific page isn't requested - results = await prefilled_db.search(None, [], sort, per_page=2) + total, results = await prefilled_db.search(None, [], sort, per_page=2) assert results == [DOC3, DOC2, DOC1] + assert total == 3 - results = await prefilled_db.search(None, [], sort, per_page=2, page=1) + total, results = await prefilled_db.search(None, [], sort, per_page=2, page=1) assert results == [DOC3, DOC2] + assert total == 3 - results = await prefilled_db.search(None, [], sort, per_page=2, page=2) + total, results = await prefilled_db.search(None, [], sort, per_page=2, page=2) assert results == [DOC1] + assert total == 3 - results = await prefilled_db.search(None, [], sort, per_page=2, page=3) + total, results = await prefilled_db.search(None, [], sort, per_page=2, page=3) assert results == [] + assert total == 3 - results = await prefilled_db.search(None, [], sort, per_page=1, page=1) + total, results = await prefilled_db.search(None, [], sort, per_page=1, page=1) assert results == [DOC3] + assert total == 3 - results = await prefilled_db.search(None, [], sort, per_page=1, page=2) + total, results = await prefilled_db.search(None, [], sort, per_page=1, page=2) assert results == [DOC2] + assert total == 3 - results = await prefilled_db.search(None, [], sort, per_page=1, page=3) + total, results = await prefilled_db.search(None, [], sort, per_page=1, page=3) assert results == [DOC1] + assert total == 3 - results = await prefilled_db.search(None, [], sort, per_page=1, page=4) + total, results = await prefilled_db.search(None, [], sort, per_page=1, page=4) assert results == [] + assert total == 3 async def test_pagination_desc(prefilled_db: DummyOSDB): sort = [{"parameter": "IntField", "direction": "desc"}] - results = await prefilled_db.search(None, [], sort, per_page=2, page=1) + total, results = await prefilled_db.search(None, [], sort, per_page=2, page=1) assert results == [DOC1, DOC2] - results = await prefilled_db.search(None, [], sort, per_page=2, page=2) + total, results = await prefilled_db.search(None, [], sort, per_page=2, page=2) assert results == [DOC3] @@ -195,22 +204,26 @@ async def test_eq_filter_long(prefilled_db: DummyOSDB): part = {"parameter": "IntField", "operator": "eq"} # Search for an ID which doesn't exist - results = await prefilled_db.search(None, [part | {"value": "78"}], []) + total, 
results = await prefilled_db.search(None, [part | {"value": "78"}], []) assert results == [] + assert total == 0 # Check the DB contains what we expect when not filtering - results = await prefilled_db.search(None, [], []) - assert len(results) == 3 + total, results = await prefilled_db.search(None, [], []) + assert total == 3 assert DOC1 in results assert DOC2 in results assert DOC3 in results # Search separately for the two documents which do exist - results = await prefilled_db.search(None, [part | {"value": "1234"}], []) + total, results = await prefilled_db.search(None, [part | {"value": "1234"}], []) assert results == [DOC1] - results = await prefilled_db.search(None, [part | {"value": "679"}], []) + assert total == 1 + total, results = await prefilled_db.search(None, [part | {"value": "679"}], []) assert results == [DOC2] - results = await prefilled_db.search(None, [part | {"value": "42"}], []) + assert total == 1 + total, results = await prefilled_db.search(None, [part | {"value": "42"}], []) + assert total == 1 assert results == [DOC3] @@ -218,80 +231,97 @@ async def test_operators_long(prefilled_db: DummyOSDB): part = {"parameter": "IntField"} query = part | {"operator": "neq", "value": "1234"} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC2["IntField"], DOC3["IntField"]} + assert total == 2 query = part | {"operator": "in", "values": ["1234", "42"]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC3["IntField"]} + assert total == 2 query = part | {"operator": "not in", "values": ["1234", "42"]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC2["IntField"]} + assert total == 1 query = part | {"operator": "lt", "value": "1234"} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC2["IntField"], DOC3["IntField"]} + assert total == 2 query = part | {"operator": "lt", "value": "679"} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC3["IntField"]} + assert total == 1 query = part | {"operator": "gt", "value": "1234"} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == set() + assert total == 0 query = part | {"operator": "lt", "value": "42"} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == set() + assert total == 0 async def test_operators_date(prefilled_db: DummyOSDB): part = {"parameter": "DateField"} query = part | {"operator": "eq", "value": DOC3["DateField"]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC3["IntField"]} + assert total == 1 query = part | {"operator": "neq", "value": DOC2["DateField"]} - results 
= await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC3["IntField"]} + assert total == 2 doc1_time = DOC1["DateField"].strftime("%Y-%m-%dT%H:%M") doc2_time = DOC2["DateField"].strftime("%Y-%m-%dT%H:%M") doc3_time = DOC3["DateField"].strftime("%Y-%m-%dT%H:%M") query = part | {"operator": "in", "values": [doc1_time, doc2_time]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC2["IntField"]} + assert total == 2 query = part | {"operator": "not in", "values": [doc1_time, doc2_time]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC3["IntField"]} + assert total == 1 query = part | {"operator": "lt", "value": doc1_time} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC2["IntField"], DOC3["IntField"]} + assert total == 2 query = part | {"operator": "lt", "value": doc3_time} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC2["IntField"]} + assert total == 1 query = part | {"operator": "lt", "value": doc2_time} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == set() + assert total == 0 query = part | {"operator": "gt", "value": doc1_time} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == set() + assert total == 0 query = part | {"operator": "gt", "value": doc3_time} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"]} + assert total == 1 query = part | {"operator": "gt", "value": doc2_time} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC3["IntField"]} + assert total == 2 @pytest.mark.parametrize( @@ -312,24 +342,28 @@ async def test_operators_date_partial_doc1(prefilled_db: DummyOSDB, date_format: formatted_date = DOC1["DateField"].strftime(date_format) query = {"parameter": "DateField", "operator": "eq", "value": formatted_date} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"]} + assert total == 1 query = {"parameter": "DateField", "operator": "neq", "value": formatted_date} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC2["IntField"], DOC3["IntField"]} + assert total == 2 async def test_operators_keyword(prefilled_db: DummyOSDB): part = {"parameter": "KeywordField1"} query = part | 
{"operator": "eq", "value": DOC1["KeywordField1"]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC2["IntField"]} + assert total == 2 query = part | {"operator": "neq", "value": DOC1["KeywordField1"]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC3["IntField"]} + assert total == 1 part = {"parameter": "KeywordField0"} @@ -337,27 +371,31 @@ async def test_operators_keyword(prefilled_db: DummyOSDB): "operator": "in", "values": [DOC1["KeywordField0"], DOC3["KeywordField0"]], } - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC3["IntField"]} + assert total == 2 query = part | {"operator": "in", "values": ["missing"]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == set() + assert total == 0 query = part | { "operator": "not in", "values": [DOC1["KeywordField0"], DOC3["KeywordField0"]], } - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC2["IntField"]} + assert total == 1 query = part | {"operator": "not in", "values": ["missing"]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == { DOC1["IntField"], DOC2["IntField"], DOC3["IntField"], } + assert total == 3 # The MockOSDBMixin doesn't validate if types are indexed correctly if not isinstance(prefilled_db, MockOSDBMixin): @@ -387,36 +425,42 @@ async def test_unindexed_field(prefilled_db: DummyOSDB): async def test_sort_long(prefilled_db: DummyOSDB): - results = await prefilled_db.search( + total, results = await prefilled_db.search( None, [], [{"parameter": "IntField", "direction": "asc"}] ) assert results == [DOC3, DOC2, DOC1] - results = await prefilled_db.search( + assert total == 3 + total, results = await prefilled_db.search( None, [], [{"parameter": "IntField", "direction": "desc"}] ) assert results == [DOC1, DOC2, DOC3] + assert total == 3 async def test_sort_date(prefilled_db: DummyOSDB): - results = await prefilled_db.search( + total, results = await prefilled_db.search( None, [], [{"parameter": "DateField", "direction": "asc"}] ) assert results == [DOC2, DOC3, DOC1] - results = await prefilled_db.search( + assert total == 3 + total, results = await prefilled_db.search( None, [], [{"parameter": "DateField", "direction": "desc"}] ) assert results == [DOC1, DOC3, DOC2] + assert total == 3 async def test_sort_keyword(prefilled_db: DummyOSDB): - results = await prefilled_db.search( + total, results = await prefilled_db.search( None, [], [{"parameter": "KeywordField0", "direction": "asc"}] ) assert results == [DOC1, DOC3, DOC2] - results = await prefilled_db.search( + assert total == 3 + total, results = await prefilled_db.search( None, [], [{"parameter": "KeywordField0", "direction": "desc"}] ) assert results == [DOC2, DOC3, DOC1] + assert total == 3 async def test_sort_text(prefilled_db: DummyOSDB): @@ -436,7 +480,7 
@@ async def test_sort_unknown(prefilled_db: DummyOSDB): async def test_sort_multiple(prefilled_db: DummyOSDB): - results = await prefilled_db.search( + total, results = await prefilled_db.search( None, [], [ @@ -445,8 +489,9 @@ async def test_sort_multiple(prefilled_db: DummyOSDB): ], ) assert results == [DOC2, DOC1, DOC3] + assert total == 3 - results = await prefilled_db.search( + total, results = await prefilled_db.search( None, [], [ @@ -455,8 +500,9 @@ async def test_sort_multiple(prefilled_db: DummyOSDB): ], ) assert results == [DOC1, DOC2, DOC3] + assert total == 3 - results = await prefilled_db.search( + total, results = await prefilled_db.search( None, [], [ @@ -465,8 +511,9 @@ async def test_sort_multiple(prefilled_db: DummyOSDB): ], ) assert results == [DOC3, DOC2, DOC1] + assert total == 3 - results = await prefilled_db.search( + total, results = await prefilled_db.search( None, [], [ @@ -475,3 +522,4 @@ async def test_sort_multiple(prefilled_db: DummyOSDB): ], ) assert results == [DOC3, DOC2, DOC1] + assert total == 3 diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py index 4dcdc9861..6dd46abaa 100644 --- a/diracx-logic/src/diracx/logic/pilots/query.py +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -10,10 +10,12 @@ ScalarSearchSpec, SearchParams, SearchSpec, + SortDirection, SummaryParams, VectorSearchOperator, VectorSearchSpec, ) +from diracx.db.os.pilot_logs import PilotLogsDB from diracx.db.sql import PilotAgentsDB MAX_PER_PAGE = 10000 @@ -191,3 +193,40 @@ async def summary(pilot_db: PilotAgentsDB, body: SummaryParams, vo: str): } ) return await pilot_db.pilot_summary(body.grouping, body.search) + + +async def search_logs( + vo: str, + body: SearchParams | None, + per_page: int, + page: int, + pilot_logs_db: PilotLogsDB, +) -> tuple[int, list[dict]]: + """Search pilot logs in OpenSearch, restricted to the caller's VO.""" + # Apply a limit to per_page to prevent abuse of the API + if per_page > MAX_PER_PAGE: + per_page = MAX_PER_PAGE + + if body is None: + body = SearchParams() + + search = body.search + parameters = body.parameters + sorts = body.sort + + # Add the VO to make sure that we filter for pilots we can see + # TODO: Test it + search = search + [ + { + "parameter": "VO", + "operator": "eq", + "value": vo, + } + ] + + if not sorts: + sorts = [{"parameter": "TimeStamp", "direction": SortDirection("asc")}] + + return await pilot_logs_db.search( + parameters=parameters, search=search, sorts=sorts, per_page=per_page, page=page + ) diff --git a/diracx-logic/src/diracx/logic/pilots/resources.py b/diracx-logic/src/diracx/logic/pilots/resources.py new file mode 100644 index 000000000..292a74c9a --- /dev/null +++ b/diracx-logic/src/diracx/logic/pilots/resources.py @@ -0,0 +1,45 @@ +"""File dedicated to logic for pilot-only resources (logs, jobs, etc.).""" + +from __future__ import annotations + +from diracx.core.exceptions import PilotNotFoundError +from diracx.core.models import LogLine +from diracx.db.os.pilot_logs import PilotLogsDB +from diracx.db.sql.pilots.db import PilotAgentsDB + +from .query import get_pilot_ids_by_stamps + + +async def send_message( + lines: list[LogLine], + pilot_logs_db: PilotLogsDB, + pilot_db: PilotAgentsDB, + vo: str, + pilot_stamp: str, +): + try: + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + ) + pilot_id = pilot_ids[0] # A single stamp was requested, so take its single ID + except PilotNotFoundError: + # If the pilot is not found, we still store the data (so as not to lose it) + # We
log it as this is not supposed to happen + # If we arrive here, the pilot has been deleted but is still "alive" + pilot_id = -1 # Sentinel value so this case can be detected + + docs = [] + for line in lines: + docs.append( + { + "PilotStamp": pilot_stamp, + "PilotID": pilot_id, + "VO": vo, + "Severity": line.severity, + "Message": line.message, + "TimeStamp": line.timestamp, + "Scope": line.scope, + } + ) + # bulk insert pilot logs to OpenSearch DB: + await pilot_logs_db.bulk_insert(pilot_logs_db.index_name(vo, pilot_id), docs) diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index 2038223ce..d97fa1aa2 100644 --- a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -47,11 +47,13 @@ config = "diracx.routers.configuration:router" health = "diracx.routers.health:router" jobs = "diracx.routers.jobs:router" pilots = "diracx.routers.pilots:router" +"pilots/legacy" = "diracx.routers.legacy_pilot_resources:router" [project.entry-points."diracx.access_policies"] WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy" SandboxAccessPolicy = "diracx.routers.jobs.access_policies:SandboxAccessPolicy" PilotManagementAccessPolicy = "diracx.routers.pilots.access_policies:PilotManagementAccessPolicy" +LegacyPilotAccessPolicy = "diracx.routers.legacy_pilot_resources.access_policies:LegacyPilotAccessPolicy" # Minimum version of the client supported [project.entry-points."diracx.min_client_version"] diff --git a/diracx-routers/src/diracx/routers/dependencies.py b/diracx-routers/src/diracx/routers/dependencies.py index 8eb2bd265..88a5be6d0 100644 --- a/diracx-routers/src/diracx/routers/dependencies.py +++ b/diracx-routers/src/diracx/routers/dependencies.py @@ -23,6 +23,7 @@ from diracx.core.settings import DevelopmentSettings as _DevelopmentSettings from diracx.core.settings import SandboxStoreSettings as _SandboxStoreSettings from diracx.db.os import JobParametersDB as _JobParametersDB +from diracx.db.os import PilotLogsDB as _PilotLogsDB from diracx.db.sql import AuthDB as _AuthDB from diracx.db.sql import JobDB as _JobDB from diracx.db.sql import JobLoggingDB as _JobLoggingDB @@ -50,6 +51,7 @@ def add_settings_annotation(cls: T) -> T: # Opensearch databases JobParametersDB = Annotated[_JobParametersDB, Depends(_JobParametersDB.session)] +PilotLogsDB = Annotated[_PilotLogsDB, Depends(_PilotLogsDB.session)] # Miscellaneous diff --git a/diracx-routers/src/diracx/routers/legacy_pilot_resources/__init__.py b/diracx-routers/src/diracx/routers/legacy_pilot_resources/__init__.py new file mode 100644 index 000000000..367fc6c93 --- /dev/null +++ b/diracx-routers/src/diracx/routers/legacy_pilot_resources/__init__.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import logging + +from ..fastapi_classes import DiracxRouter +from .logs import router as legacy_router + +logger = logging.getLogger(__name__) + +router = DiracxRouter(require_auth=False) +router.include_router(legacy_router) diff --git a/diracx-routers/src/diracx/routers/legacy_pilot_resources/access_policies.py b/diracx-routers/src/diracx/routers/legacy_pilot_resources/access_policies.py new file mode 100644 index 000000000..187323e8e --- /dev/null +++ b/diracx-routers/src/diracx/routers/legacy_pilot_resources/access_policies.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Annotated + +from fastapi import Depends, HTTPException, status + +from diracx.core.properties import GENERIC_PILOT, LIMITED_DELEGATION +from diracx.routers.access_policies import
BaseAccessPolicy +from diracx.routers.utils.users import AuthorizedUserInfo + + +class LegacyPilotAccessPolicy(BaseAccessPolicy): + """Rules: + * Only pilots can access these resources. + * The GENERIC_PILOT or LIMITED_DELEGATION property is required. + """ + + @staticmethod + async def policy( + policy_name: str, + user_info: AuthorizedUserInfo, + /, + ): + if ( + LIMITED_DELEGATION not in user_info.properties + and GENERIC_PILOT not in user_info.properties + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You must be a pilot to access this resource.", + ) + + return + + +CheckLegacyPilotPolicyCallable = Annotated[ + Callable, Depends(LegacyPilotAccessPolicy.check) +] diff --git a/diracx-routers/src/diracx/routers/legacy_pilot_resources/logs.py b/diracx-routers/src/diracx/routers/legacy_pilot_resources/logs.py new file mode 100644 index 000000000..452957d95 --- /dev/null +++ b/diracx-routers/src/diracx/routers/legacy_pilot_resources/logs.py @@ -0,0 +1,51 @@ +"""File dedicated to legacy pilot resources: pilots with DIRAC auth, without JWT.""" + +from __future__ import annotations + +import logging +from http import HTTPStatus +from typing import Annotated + +from fastapi import Body, Depends + +from diracx.core.models import LogLine +from diracx.logic.pilots.resources import send_message as send_message_bl +from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token + +from ..dependencies import PilotAgentsDB, PilotLogsDB +from ..fastapi_classes import DiracxRouter +from .access_policies import ( + CheckLegacyPilotPolicyCallable, +) + +logger = logging.getLogger(__name__) +router = DiracxRouter() + + +@router.post("/message", status_code=HTTPStatus.NO_CONTENT) +async def send_message( + lines: Annotated[ + list[LogLine], + Body(description="Message from the pilot to the logging system.", embed=True), + ], + pilot_stamp: Annotated[ + str, + Body( + description="PilotStamp, required as legacy pilots do not have a token with stamp in it."
+ ), + ], + pilot_logs_db: PilotLogsDB, + pilot_db: PilotAgentsDB, + check_permissions: CheckLegacyPilotPolicyCallable, + pilot_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], +): + """Send logs with legacy pilot.""" + await check_permissions() + + await send_message_bl( + lines=lines, + pilot_logs_db=pilot_logs_db, + pilot_db=pilot_db, + vo=pilot_info.vo, + pilot_stamp=pilot_stamp, + ) diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py index 40e942e60..a19ca4537 100644 --- a/diracx-routers/src/diracx/routers/pilots/access_policies.py +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -20,6 +20,8 @@ class ActionType(StrEnum): MANAGE_PILOTS = auto() # Read some pilot info READ_PILOT_FIELDS = auto() + # Legacy Pilot + LEGACY_PILOT = auto() class PilotManagementAccessPolicy(BaseAccessPolicy): @@ -43,6 +45,19 @@ async def policy( ): assert action, "action is a mandatory parameter" + is_a_pilot_if_allowed = ( + allow_legacy_pilots and GENERIC_PILOT in user_info.properties + ) + + if action == ActionType.LEGACY_PILOT: + if is_a_pilot_if_allowed: + return + + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You must be a pilot to access this resource.", + ) + # Users can query # NOTE: Add into queries a VO constraint # To manage pilots, user have to be an admin @@ -50,9 +65,6 @@ if action == ActionType.MANAGE_PILOTS: # To make it clear, we separate is_an_admin = SERVICE_ADMINISTRATOR in user_info.properties - is_a_pilot_if_allowed = ( - allow_legacy_pilots and GENERIC_PILOT in user_info.properties - ) if not is_an_admin and not is_a_pilot_if_allowed: raise HTTPException( diff --git a/diracx-routers/src/diracx/routers/pilots/query.py b/diracx-routers/src/diracx/routers/pilots/query.py index 56655b46c..7044a0326 100644 --- a/diracx-routers/src/diracx/routers/pilots/query.py +++ b/diracx-routers/src/diracx/routers/pilots/query.py @@ -4,12 +4,14 @@ from typing import Annotated, Any from fastapi import Body, Depends, Response +from opensearchpy import RequestError from diracx.core.models import SearchParams, SummaryParams from diracx.logic.pilots.query import search as search_bl +from diracx.logic.pilots.query import search_logs as search_logs_bl from diracx.logic.pilots.query import summary as summary_bl -from ..dependencies import PilotAgentsDB +from ..dependencies import PilotAgentsDB, PilotLogsDB from ..fastapi_classes import DiracxRouter from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token from .access_policies import ( @@ -148,6 +150,125 @@ async def search( return pilots + +EXAMPLE_SEARCHES_LOGS = { + "Show all": { + "summary": "Show all", + "description": "Shows all pilot logs the current user has access to.", + "value": {}, + }, + "A specific pilot": { + "summary": "A specific pilot", + "description": "Search for the logs of a specific pilot by ID", + "value": {"search": [{"parameter": "PilotID", "operator": "eq", "value": "5"}]}, + }, + "Get a specific severity": { + "summary": "Get a specific severity", + "description": 'Get only pilot logs that have a severity of "ERROR", ordered by PilotID', + "value": { + "parameters": ["PilotID", "Severity"], + "search": [{"parameter": "Severity", "operator": "eq", "value": "ERROR"}], + "sort": [{"parameter": "PilotID", "direction": "asc"}], + }, + }, +} + + +EXAMPLE_RESPONSES_LOGS: dict[int | str, dict[str, Any]] = { + 200: { + "description": "List of matching results", + "content":
{ + "application/json": { + "example": [ + { + "PilotID": 3, + "Severity": "ERROR", + "TimeStamp": "2023-05-25T07:03:35.602656", + }, + { + "PilotID": 5, + "Severity": "INFO", + "TimeStamp": "2023-07-25T07:03:35.602652", + }, + ] + } + }, + }, + 206: { + "description": "Partial Content. Only a part of the requested range could be served.", + "headers": { + "Content-Range": { + "description": "The range of logs returned in this response", + "schema": {"type": "string", "example": "logs 0-1/4"}, + } + }, + "model": list[dict[str, Any]], + "content": { + "application/json": { + "example": [ + { + "PilotID": 3, + "Severity": "ERROR", + "TimeStamp": "2023-05-25T07:03:35.602656", + }, + { + "PilotID": 5, + "Severity": "INFO", + "TimeStamp": "2023-07-25T07:03:35.602652", + }, + ] + } + }, + }, +} + + +@router.post("/search/logs", responses=EXAMPLE_RESPONSES_LOGS) +async def search_logs( + pilot_logs_db: PilotLogsDB, + response: Response, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: CheckPilotManagementPolicyCallable, + page: int = 1, + per_page: int = 100, + body: Annotated[ + SearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES_LOGS) + ] = None, +) -> list[dict]: + # users will only see logs from their own VO if enforced by a policy: + await check_permissions( + action=ActionType.READ_PILOT_FIELDS, + ) + + try: + total, logs = await search_logs_bl( + vo=user_info.vo, + body=body, + per_page=per_page, + page=page, + pilot_logs_db=pilot_logs_db, + ) + except RequestError: + total, logs = 0, [] + + # Set the Content-Range header if needed + # https://datatracker.ietf.org/doc/html/rfc7233#section-4 + + # No logs found but there are logs for the requested search + # https://datatracker.ietf.org/doc/html/rfc7233#section-4.4 + if len(logs) == 0 and total > 0: + response.headers["Content-Range"] = f"logs */{total}" + response.status_code = HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE + + # The total number of logs is greater than the number of pilots returned + # https://datatracker.ietf.org/doc/html/rfc7233#section-4.2 + elif len(logs) < total: + first_idx = per_page * (page - 1) + last_idx = min(first_idx + len(logs), total) - 1 if total > 0 else 0 + response.headers["Content-Range"] = f"logs {first_idx}-{last_idx}/{total}" + response.status_code = HTTPStatus.PARTIAL_CONTENT + return logs + + @router.post("/summary") async def summary( pilot_db: PilotAgentsDB, diff --git a/diracx-routers/tests/pilots/test_pilot_logging.py b/diracx-routers/tests/pilots/test_pilot_logging.py new file mode 100644 index 000000000..136160a0e --- /dev/null +++ b/diracx-routers/tests/pilots/test_pilot_logging.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import pytest +from fastapi.testclient import TestClient + +from diracx.core.exceptions import InvalidQueryError + +pytestmark = pytest.mark.enabled_dependencies( + [ + "AuthDB", + "AuthSettings", + "PilotAgentsDB", + "PilotLogsDB", + "DevelopmentSettings", + "PilotManagementAccessPolicy", + "LegacyPilotAccessPolicy", + ] +) + +N = 100 + + +@pytest.fixture +def test_client(client_factory): + with client_factory.unauthenticated() as client: + yield client + + +@pytest.fixture +def normal_test_client(client_factory): + with client_factory.normal_user() as client: + yield client + + +@pytest.fixture +def create_pilots(normal_test_client: TestClient): + # Add a pilot stamps + pilot_stamps = [f"stamp_{i}" for i in range(N)] + + body = {"vo": "lhcb", "pilot_stamps": pilot_stamps} + + r = 
normal_test_client.post( + "/api/pilots/", + json=body, + ) + assert r.status_code == 200, r.json() + + return pilot_stamps + + +@pytest.fixture +async def create_logs(create_pilots, normal_test_client): + for i, stamp in enumerate(create_pilots): + lines = [ + { + "message": stamp, + "timestamp": "2022-02-26 13:48:35.123456", + "scope": "PilotParams" if i % 2 == 1 else "Commands", + "severity": "DEBUG" if i % 2 == 0 else "INFO", + } + ] + msg_dict = {"lines": lines, "pilot_stamp": stamp} + r = normal_test_client.post("/api/pilots/legacy/message", json=msg_dict) + + assert r.status_code == 204, r.json() + # Return only stamps + return create_pilots + + +@pytest.fixture +async def search(normal_test_client): + async def _search( + parameters, conditions, sorts, distinct=False, page=1, per_page=100 + ): + body = { + "parameters": parameters, + "search": conditions, + "sort": sorts, + "distinct": distinct, + } + + params = {"per_page": per_page, "page": page} + + r = normal_test_client.post("/api/pilots/search/logs", json=body, params=params) + + if r.status_code == 400: + # If we have a status_code 400, that means that the query failed + raise InvalidQueryError() + + return r.json(), r.headers + + return _search + + +async def test_single_send_and_retrieve_logs(normal_test_client: TestClient): + # Add a pilot stamps + pilot_stamp = ["stamp_1"] + + # -------------- Bulk insert -------------- + body = {"vo": "lhcb", "pilot_stamps": pilot_stamp} + + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + msg = "JSON file loaded: pilot.json\nJSON file analysed: pilot.json" + # message dict + lines = [] + for line in msg.split("\n"): + lines.append( + { + "message": line, + "timestamp": "2022-02-26 13:48:35.123456", + "scope": "PilotParams", + "severity": "DEBUG", + } + ) + msg_dict = {"lines": lines, "pilot_stamp": "stamp_1"} + + # send message + r = normal_test_client.post("/api/pilots/legacy/message", json=msg_dict) + + assert r.status_code == 204, r.json() + # get the message back: + data = { + "search": [{"parameter": "PilotStamp", "operator": "eq", "value": "stamp_1"}] + } + r = normal_test_client.post("/api/pilots/search/logs", json=data) + assert r.status_code == 200, r.text + assert [hit["Message"] for hit in r.json()] == msg.split("\n") + + +async def test_query_invalid_stamp(create_logs, normal_test_client): + data = { + "search": [ + {"parameter": "PilotStamp", "operator": "eq", "value": "not_a_stamp"} + ] + } + r = normal_test_client.post("/api/pilots/search/logs", json=data) + assert r.status_code == 200, r.text + assert len(r.json()) == 0 + + +async def test_query_each_length(create_logs, normal_test_client): + for stamp in create_logs: + data = { + "search": [{"parameter": "PilotStamp", "operator": "eq", "value": stamp}] + } + r = normal_test_client.post("/api/pilots/search/logs", json=data) + assert r.status_code == 200, r.text + assert len(r.json()) == 1 + + +async def test_query_each_field(create_logs, normal_test_client): + for i, stamp in enumerate(create_logs): + data = { + "search": [{"parameter": "PilotStamp", "operator": "eq", "value": stamp}], + "sort": [{"parameter": "PilotStamp", "direction": "asc"}], + } + r = normal_test_client.post("/api/pilots/search/logs", json=data) + assert r.status_code == 200, r.text + assert len(r.json()) == 1 + + # Reminder: + + # "message": str(i), + # "timestamp": "2022-02-26 13:48:35.123456", + # "scope": "PilotParams" if i % 2 == 1 else "Commands", + # "severity": "DEBUG" if i % 2 == 
0 else "INFO", + log = r.json()[0] + + assert log["Message"] == f"stamp_{i}" + assert log["Scope"] == ("PilotParams" if i % 2 == 1 else "Commands") + assert log["Severity"] == ("DEBUG" if i % 2 == 0 else "INFO") + + +async def test_search_pagination(create_logs, search): + """Test that we can search for logs.""" + # Search for the first 10 logs + result, headers = await search([], [], [], per_page=10, page=1) + assert "Content-Range" in headers + # Because Content-Range = f"logs {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert total == N + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 1 + + # Search for the second 10 logs + result, headers = await search([], [], [], per_page=10, page=2) + assert "Content-Range" in headers + # Because Content-Range = f"logs {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert total == N + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 11 + + # Search for the last 10 logs + result, headers = await search([], [], [], per_page=10, page=10) + assert "Content-Range" in headers + # Because Content-Range = f"logs {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 91 + + # Search for the second 50 logs + result, headers = await search([], [], [], per_page=50, page=2) + assert "Content-Range" in headers + # Because Content-Range = f"logs {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert result + assert len(result) == 50 + assert result[0]["PilotID"] == 51 + + # Invalid page number + result, headers = await search([], [], [], per_page=10, page=11) + assert "Content-Range" in headers + # Because Content-Range = f"logs {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert not result + + # Invalid page number + with pytest.raises(InvalidQueryError): + result = await search([], [], [], per_page=10, page=0) + + # Invalid per_page number + with pytest.raises(InvalidQueryError): + result = await search([], [], [], per_page=0, page=1) diff --git a/diracx-testing/src/diracx/testing/mock_osdb.py b/diracx-testing/src/diracx/testing/mock_osdb.py index 5f7fe7f93..6a87b0579 100644 --- a/diracx-testing/src/diracx/testing/mock_osdb.py +++ b/diracx-testing/src/diracx/testing/mock_osdb.py @@ -10,9 +10,10 @@ from functools import partial from typing import Any, AsyncIterator -from sqlalchemy import select +from sqlalchemy import func, select from sqlalchemy.dialects.sqlite import insert as sqlite_insert +from diracx.core.exceptions import InvalidQueryError from diracx.core.models import SearchSpec, SortSpec from diracx.db.sql import utils as sql_utils @@ -53,7 +54,11 @@ def __init__(self, connection_kwargs: dict[str, Any]) -> None: for field, field_type in self.fields.items(): match field_type["type"]: case "date": + # TODO: Warning, maybe this will crash? 
See date_nanos + # I needed to set Varchar because it is sent as 2022-06-15T10:12:52.382719622Z, and not datetime column_type = DateNowColumn + case "date_nanos": + column_type = partial(Column, type_=String(32)) case "long": column_type = partial(Column, type_=Integer) case "keyword": @@ -100,6 +105,21 @@ async def upsert(self, vo, doc_id, document) -> None: stmt = stmt.on_conflict_do_update(index_elements=["doc_id"], set_=values) await self._sql_db.conn.execute(stmt) + async def bulk_insert(self, index_name: str, docs: list[dict[str, Any]]) -> None: + async with self._sql_db: + rows = [] + for doc in docs: + # don't use doc_id column explicitly. This ensures that doc_id is unique. + values = {} + for key, value in doc.items(): + if key in self.fields: + values[key] = value + else: + values.setdefault("extra", {})[key] = value + rows.append(values) + stmt = sqlite_insert(self._table).values(rows) + await self._sql_db.conn.execute(stmt) + async def search( self, parameters: list[str] | None, @@ -135,8 +155,17 @@ async def search( self._table.columns.__getitem__, stmt, sorts ) + # Calculate total count before applying pagination + total_count_subquery = stmt.alias() + total_count_stmt = select(func.count()).select_from(total_count_subquery) + total = (await self._sql_db.conn.execute(total_count_stmt)).scalar_one() + # Apply pagination if page is not None: + if page < 1: + raise InvalidQueryError("Page must be a positive integer") + if per_page < 1: + raise InvalidQueryError("Per page must be a positive integer") stmt = stmt.offset((page - 1) * per_page).limit(per_page) results = [] @@ -151,7 +180,8 @@ async def search( if v is None: result.pop(k) results.append(result) - return results + + return total, results async def ping(self): async with self._sql_db: diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py index fdf17b6a3..321d3238c 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py @@ -20,12 +20,13 @@ ConfigOperations, JobsOperations, LollygagOperations, + PilotsLegacyOperations, PilotsOperations, WellKnownOperations, ) -class Dirac: # pylint: disable=client-accepts-api-version-keyword +class Dirac: # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes """Dirac. :ivar well_known: WellKnownOperations operations @@ -40,6 +41,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype lollygag: _generated.operations.LollygagOperations :ivar pilots: PilotsOperations operations :vartype pilots: _generated.operations.PilotsOperations + :ivar pilots_legacy: PilotsLegacyOperations operations + :vartype pilots_legacy: _generated.operations.PilotsLegacyOperations :keyword endpoint: Service URL. Required. Default value is "". 
:paramtype endpoint: str """ @@ -78,6 +81,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots_legacy = PilotsLegacyOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py index 76280797e..07253331f 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py @@ -20,12 +20,13 @@ ConfigOperations, JobsOperations, LollygagOperations, + PilotsLegacyOperations, PilotsOperations, WellKnownOperations, ) -class Dirac: # pylint: disable=client-accepts-api-version-keyword +class Dirac: # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes """Dirac. :ivar well_known: WellKnownOperations operations @@ -40,6 +41,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype lollygag: _generated.aio.operations.LollygagOperations :ivar pilots: PilotsOperations operations :vartype pilots: _generated.aio.operations.PilotsOperations + :ivar pilots_legacy: PilotsLegacyOperations operations + :vartype pilots_legacy: _generated.aio.operations.PilotsLegacyOperations :keyword endpoint: Service URL. Required. Default value is "". 
:paramtype endpoint: str """ @@ -78,6 +81,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots_legacy = PilotsLegacyOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py index 3408891fc..759b5d4e6 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py @@ -16,6 +16,7 @@ from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore from ._operations import PilotsOperations # type: ignore +from ._operations import PilotsLegacyOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -28,6 +29,7 @@ "JobsOperations", "LollygagOperations", "PilotsOperations", + "PilotsLegacyOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index 3d9951b6d..a8b08e1ac 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -58,6 +58,8 @@ build_pilots_add_pilot_stamps_request, build_pilots_delete_pilots_request, build_pilots_get_pilot_jobs_request, + build_pilots_legacy_send_message_request, + build_pilots_search_logs_request, build_pilots_search_request, build_pilots_summary_request, build_pilots_update_pilot_fields_request, @@ -2816,6 +2818,143 @@ async def search( return deserialized # type: ignore + @overload + async def search_logs( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def search_logs( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. 
+ :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def search_logs( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_logs_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + @overload async def summary( self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any @@ -2910,3 +3049,114 @@ async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsLegacyOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`pilots_legacy` attribute. 
+ """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + async def send_message( + self, body: _models.BodyPilotsLegacySendMessage, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsLegacySendMessage + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def send_message(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def send_message(self, body: Union[_models.BodyPilotsLegacySendMessage, IO[bytes]], **kwargs: Any) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Is either a BodyPilotsLegacySendMessage type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsLegacySendMessage or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsLegacySendMessage") + + _request = build_pilots_legacy_send_message_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py index 537867ac1..c889fb017 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py @@ -15,6 +15,7 @@ BodyAuthGetOidcToken, BodyAuthGetOidcTokenGrantType, BodyPilotsAddPilotStamps, + BodyPilotsLegacySendMessage, BodyPilotsUpdatePilotFields, ExtendedMetadata, GroupInfo, @@ -24,6 +25,7 @@ InsertedJob, JobCommand, JobStatusUpdate, + LogLine, OpenIDConfiguration, PilotFieldsMapping, SandboxDownloadResponse, @@ -66,6 +68,7 @@ "BodyAuthGetOidcToken", "BodyAuthGetOidcTokenGrantType", "BodyPilotsAddPilotStamps", + "BodyPilotsLegacySendMessage", "BodyPilotsUpdatePilotFields", "ExtendedMetadata", "GroupInfo", @@ -75,6 +78,7 @@ "InsertedJob", "JobCommand", "JobStatusUpdate", + "LogLine", "OpenIDConfiguration", "PilotFieldsMapping", "SandboxDownloadResponse", diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 751efcfb9..c597bfcc1 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -170,6 +170,41 @@ def __init__( self.pilot_status = pilot_status +class BodyPilotsLegacySendMessage(_serialization.Model): + """Body_pilots/legacy_send_message. + + All required parameters must be populated in order to send to server. + + :ivar lines: Message from the pilot to the logging system. Required. + :vartype lines: list[~_generated.models.LogLine] + :ivar pilot_stamp: PilotStamp, required as legacy pilots do not have a token with stamp in it. + Required. 
+ :vartype pilot_stamp: str + """ + + _validation = { + "lines": {"required": True}, + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "lines": {"key": "lines", "type": "[LogLine]"}, + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + } + + def __init__(self, *, lines: List["_models.LogLine"], pilot_stamp: str, **kwargs: Any) -> None: + """ + :keyword lines: Message from the pilot to the logging system. Required. + :paramtype lines: list[~_generated.models.LogLine] + :keyword pilot_stamp: PilotStamp, required as legacy pilots do not have a token with stamp in + it. Required. + :paramtype pilot_stamp: str + """ + super().__init__(**kwargs) + self.lines = lines + self.pilot_stamp = pilot_stamp + + class BodyPilotsUpdatePilotFields(_serialization.Model): """Body_pilots_update_pilot_fields. @@ -558,6 +593,53 @@ def __init__( self.source = source +class LogLine(_serialization.Model): + """LogLine. + + All required parameters must be populated in order to send to server. + + :ivar timestamp: Timestamp. Required. + :vartype timestamp: str + :ivar severity: Severity. Required. + :vartype severity: str + :ivar message: Message. Required. + :vartype message: str + :ivar scope: Scope. Required. + :vartype scope: str + """ + + _validation = { + "timestamp": {"required": True}, + "severity": {"required": True}, + "message": {"required": True}, + "scope": {"required": True}, + } + + _attribute_map = { + "timestamp": {"key": "timestamp", "type": "str"}, + "severity": {"key": "severity", "type": "str"}, + "message": {"key": "message", "type": "str"}, + "scope": {"key": "scope", "type": "str"}, + } + + def __init__(self, *, timestamp: str, severity: str, message: str, scope: str, **kwargs: Any) -> None: + """ + :keyword timestamp: Timestamp. Required. + :paramtype timestamp: str + :keyword severity: Severity. Required. + :paramtype severity: str + :keyword message: Message. Required. + :paramtype message: str + :keyword scope: Scope. Required. + :paramtype scope: str + """ + super().__init__(**kwargs) + self.timestamp = timestamp + self.severity = severity + self.message = message + self.scope = scope + + class OpenIDConfiguration(_serialization.Model): """OpenIDConfiguration. 
diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py index 3408891fc..759b5d4e6 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py @@ -16,6 +16,7 @@ from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore from ._operations import PilotsOperations # type: ignore +from ._operations import PilotsLegacyOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -28,6 +29,7 @@ "JobsOperations", "LollygagOperations", "PilotsOperations", + "PilotsLegacyOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index e861b9841..0db5ffa19 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -740,6 +740,30 @@ def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) +def build_pilots_search_logs_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/search/logs" + + # Construct parameters + if page is not None: + _params["page"] = _SERIALIZER.query("page", page, "int") + if per_page is not None: + _params["per_page"] = _SERIALIZER.query("per_page", per_page, "int") + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + def build_pilots_summary_request(**kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) @@ -757,6 +781,20 @@ def build_pilots_summary_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) +def build_pilots_legacy_send_message_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/legacy/message" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. 
warning:: @@ -3495,6 +3533,143 @@ def search( return deserialized # type: ignore + @overload + def search_logs( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def search_logs( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def search_logs( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. 
+ :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_logs_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + @overload def summary(self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: """Summary. @@ -3587,3 +3762,116 @@ def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsLegacyOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`pilots_legacy` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + def send_message( + self, body: _models.BodyPilotsLegacySendMessage, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsLegacySendMessage + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". 
+ :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def send_message(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def send_message( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsLegacySendMessage, IO[bytes]], **kwargs: Any + ) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Is either a BodyPilotsLegacySendMessage type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsLegacySendMessage or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsLegacySendMessage") + + _request = build_pilots_legacy_send_message_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore From 98d90039a7e186f9540193b66391b6e1500efbcf Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Fri, 1 Aug 2025 13:36:16 +0200 Subject: [PATCH 33/33] fix: Parameter duplication in pilot management --- diracx-routers/src/diracx/routers/pilots/management.py | 1 - 1 file changed, 1 deletion(-) diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 8bb9ea514..21ff63796 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -56,7 +56,6 @@ async def add_pilot_stamps( pilot_status: Annotated[ PilotStatus, Body(description="Status of the pilots.") ] = PilotStatus.SUBMITTED, - vo: str | None = None, ): """Endpoint where you can create pilots with their references.
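
For illustration, a minimal client-side sketch of how the new paginated pilot-log search could be driven through the generated client. It assumes an already-configured generated Dirac client instance named client, and that the generated SearchParams model exposes parameters, search and sort fields mirroring the router's SearchParams; both names are assumptions, not guaranteed by this patch. Server-side, per_page is clamped to MAX_PER_PAGE (10000) and the overall total is reported through the Content-Range header on 206 responses.

    # Hypothetical usage sketch, not shipped with this patch series.
    from diracx.client._generated import models

    # Only ERROR lines, ordered by PilotID (mirrors the EXAMPLE_SEARCHES_LOGS example).
    body = models.SearchParams(
        parameters=["PilotID", "Severity", "Message"],
        search=[{"parameter": "Severity", "operator": "eq", "value": "ERROR"}],
        sort=[{"parameter": "PilotID", "direction": "asc"}],
    )

    page = 1
    while True:
        # Each call returns at most per_page log documents, filtered to the caller's VO.
        logs = client.pilots.search_logs(body=body, page=page, per_page=100)
        for log in logs:
            print(log["PilotID"], log["Severity"], log["Message"])
        if len(logs) < 100:  # Last (possibly empty) page reached.
            break
        page += 1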