Skip to content

Commit 5809859

Browse files
committed
Strict typing for aiida.repository module
1 parent 13cb318 commit 5809859

File tree

15 files changed

+92
-91
lines changed

15 files changed

+92
-91
lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,6 @@ repos:
105105
src/aiida/orm/utils/builders/computer.py|
106106
src/aiida/orm/utils/calcjob.py|
107107
src/aiida/orm/utils/node.py|
108-
src/aiida/repository/backend/disk_object_store.py|
109-
src/aiida/repository/backend/sandbox.py|
110108
src/aiida/restapi/common/utils.py|
111109
src/aiida/restapi/resources.py|
112110
src/aiida/restapi/run_api.py|

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ module = [
357357
'aiida.cmdline.params.*',
358358
'aiida.cmdline.groups.*',
359359
'aiida.tools.query.*'
360+
'aiida.repository.*',
360361
]
361362
warn_return_any = true
362363

src/aiida/cmdline/commands/cmd_archive.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from aiida.common.exceptions import CorruptStorage, IncompatibleStorageSchema, UnreachableStorage
2727
from aiida.common.links import GraphTraversalRules
2828
from aiida.common.log import AIIDA_LOGGER
29+
from aiida.common.typing import FilePath
2930
from aiida.common.utils import DEFAULT_BATCH_SIZE, DEFAULT_FILTER_SIZE
3031

3132
EXTRAS_MODE_EXISTING = ['keep_existing', 'update_existing', 'mirror', 'none']
@@ -488,7 +489,7 @@ def _import_archive_and_migrate(
488489
dry_run_success = f'import dry-run of archive {archive} completed. Profile storage unmodified.'
489490

490491
with SandboxFolder(filepath=filepath) as temp_folder:
491-
archive_path = archive
492+
archive_path: FilePath = archive
492493

493494
if web_based:
494495
echo.echo_report(f'downloading archive: {archive}')
@@ -501,14 +502,15 @@ def _import_archive_and_migrate(
501502
archive_path = temp_folder.get_abs_path('downloaded_archive.zip')
502503
echo.echo_success('archive downloaded, proceeding with import')
503504

505+
archive_path = str(archive_path)
504506
echo.echo_report(f'starting import: {archive}')
505507
try:
506508
_import_archive(archive_path, archive_format=archive_format, **import_kwargs)
507509
except IncompatibleStorageSchema as exception:
508510
if try_migration:
509511
echo.echo_report(f'incompatible version detected for {archive}, trying migration')
510512
try:
511-
new_path = temp_folder.get_abs_path('migrated_archive.aiida')
513+
new_path = str(temp_folder.get_abs_path('migrated_archive.aiida'))
512514
archive_format.migrate(archive_path, new_path, archive_format.latest_version, compression=0)
513515
archive_path = new_path
514516
except Exception as sub_exception:

src/aiida/common/folders.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
import pathlib
1818
import shutil
1919
import tempfile
20+
import typing as t
21+
from collections.abc import Iterator
2022

2123
from . import timezone
2224
from .lang import type_check
25+
from .typing import FilePath, Self
2326

2427
# If True, tries to make everything (dirs, files) group-writable.
2528
# Otherwise, tries to make everything only readable and writable by the user.
@@ -45,7 +48,7 @@ class Folder:
4548
to os.path.abspath or normpath are quite slow).
4649
"""
4750

48-
def __init__(self, abspath, folder_limit=None):
51+
def __init__(self, abspath: FilePath, folder_limit: FilePath | None = None):
4952
"""Construct a new instance."""
5053
abspath = os.path.abspath(abspath)
5154
if folder_limit is None:
@@ -64,22 +67,22 @@ def __init__(self, abspath, folder_limit=None):
6467
self._folder_limit = folder_limit
6568

6669
@property
67-
def mode_dir(self):
70+
def mode_dir(self) -> int:
6871
"""Return the mode with which the folders should be created"""
6972
if GROUP_WRITABLE:
7073
return 0o770
7174

7275
return 0o700
7376

7477
@property
75-
def mode_file(self):
78+
def mode_file(self) -> int:
7679
"""Return the mode with which the files should be created"""
7780
if GROUP_WRITABLE:
7881
return 0o660
7982

8083
return 0o600
8184

82-
def get_subfolder(self, subfolder, create=False, reset_limit=False):
85+
def get_subfolder(self, subfolder: FilePath, create=False, reset_limit=False) -> Folder:
8386
"""Return a Folder object pointing to a subfolder.
8487
8588
:param subfolder: a string with the relative path of the subfolder,
@@ -110,7 +113,7 @@ def get_subfolder(self, subfolder, create=False, reset_limit=False):
110113

111114
return new_folder
112115

113-
def get_content_list(self, pattern='*', only_paths=True):
116+
def get_content_list(self, pattern: str = '*', only_paths: bool = True) -> list:
114117
"""Return a list of files (and subfolders) in the folder, matching a given pattern.
115118
116119
Example: If you want to exclude files starting with a dot, you can
@@ -134,7 +137,7 @@ def get_content_list(self, pattern='*', only_paths=True):
134137

135138
return [(fname, not os.path.isdir(os.path.join(self.abspath, fname))) for fname in file_list]
136139

137-
def create_symlink(self, src, name):
140+
def create_symlink(self, src: FilePath, name: FilePath) -> None:
138141
"""Create a symlink inside the folder to the location 'src'.
139142
140143
:param src: the location to which the symlink must point. Can be
@@ -148,7 +151,7 @@ def create_symlink(self, src, name):
148151

149152
# For symlinks, permissions should not be set
150153

151-
def insert_path(self, src, dest_name=None, overwrite=True):
154+
def insert_path(self, src: FilePath, dest_name: FilePath | None = None, overwrite: bool = True) -> FilePath:
152155
"""Copy a file to the folder.
153156
154157
:param src: the source filename to copy
@@ -205,7 +208,9 @@ def insert_path(self, src, dest_name=None, overwrite=True):
205208

206209
return dest_abs_path
207210

208-
def create_file_from_filelike(self, filelike, filename, mode='wb', encoding=None):
211+
def create_file_from_filelike(
212+
self, filelike: t.IO[t.AnyStr], filename: FilePath, mode: str = 'wb', encoding: str | None = None
213+
) -> FilePath:
209214
"""Create a file with the given filename from a filelike object.
210215
211216
:param filelike: a filelike object whose contents to copy
@@ -227,7 +232,7 @@ def create_file_from_filelike(self, filelike, filename, mode='wb', encoding=None
227232

228233
return filepath
229234

230-
def remove_path(self, filename):
235+
def remove_path(self, filename: FilePath) -> None:
231236
"""Remove a file or folder from the folder.
232237
233238
:param filename: the relative path name to remove
@@ -241,7 +246,7 @@ def remove_path(self, filename):
241246
else:
242247
os.remove(dest_abs_path)
243248

244-
def get_abs_path(self, relpath, check_existence=False):
249+
def get_abs_path(self, relpath: FilePath, check_existence: bool = False) -> FilePath:
245250
"""Return an absolute path for a file or folder in this folder.
246251
247252
The advantage of using this method is that it checks that filename
@@ -268,7 +273,9 @@ def get_abs_path(self, relpath, check_existence=False):
268273
return dest_abs_path
269274

270275
@contextlib.contextmanager
271-
def open(self, name, mode='r', encoding='utf8', check_existence=False):
276+
def open(
277+
self, name: FilePath, mode: str = 'r', encoding: str | None = 'utf8', check_existence: bool = False
278+
) -> Iterator[t.Any]:
272279
"""Open a file in the current folder and return the corresponding file object.
273280
274281
:param check_existence: if False, just return the file path.
@@ -282,32 +289,32 @@ def open(self, name, mode='r', encoding='utf8', check_existence=False):
282289
yield handle
283290

284291
@property
285-
def abspath(self):
292+
def abspath(self) -> FilePath:
286293
"""The absolute path of the folder."""
287294
return self._abspath
288295

289296
@property
290-
def folder_limit(self):
297+
def folder_limit(self) -> FilePath:
291298
"""The folder limit that cannot be crossed when creating files and folders."""
292299
return self._folder_limit
293300

294-
def exists(self):
301+
def exists(self) -> bool:
295302
"""Return True if the folder exists, False otherwise."""
296303
return os.path.exists(self.abspath)
297304

298-
def isfile(self, relpath):
305+
def isfile(self, relpath: FilePath) -> bool:
299306
"""Return True if 'relpath' exists inside the folder and is a file,
300307
False otherwise.
301308
"""
302309
return os.path.isfile(os.path.join(self.abspath, relpath))
303310

304-
def isdir(self, relpath):
311+
def isdir(self, relpath: FilePath) -> bool:
305312
"""Return True if 'relpath' exists inside the folder and is a directory,
306313
False otherwise.
307314
"""
308315
return os.path.isdir(os.path.join(self.abspath, relpath))
309316

310-
def erase(self, create_empty_folder=False):
317+
def erase(self, create_empty_folder: bool = False) -> None:
311318
"""Erases the folder. Should be called only in very specific cases,
312319
in general folder should not be erased!
313320
@@ -321,7 +328,7 @@ def erase(self, create_empty_folder=False):
321328
if create_empty_folder:
322329
self.create()
323330

324-
def create(self):
331+
def create(self) -> None:
325332
"""Creates the folder, if it does not exist on the disk yet.
326333
327334
It will also create top directories, if absent.
@@ -331,7 +338,7 @@ def create(self):
331338
"""
332339
os.makedirs(self.abspath, mode=self.mode_dir, exist_ok=True)
333340

334-
def replace_with_folder(self, srcdir, move=False, overwrite=False):
341+
def replace_with_folder(self, srcdir: FilePath, move: bool = False, overwrite: bool = False) -> None:
335342
"""This routine copies or moves the source folder 'srcdir' to the local folder pointed to by this Folder.
336343
337344
:param srcdir: the source folder on the disk; this must be an absolute path
@@ -399,11 +406,11 @@ def __init__(self, filepath: pathlib.Path | None = None):
399406

400407
super().__init__(abspath=tempfile.mkdtemp(dir=filepath))
401408

402-
def __enter__(self):
409+
def __enter__(self) -> Self:
403410
"""Enter a context and return self."""
404411
return self
405412

406-
def __exit__(self, exc_type, exc_value, traceback):
413+
def __exit__(self, exc_type, exc_value, traceback) -> None:
407414
"""Erase the temporary directory created in the constructor."""
408415
self.erase()
409416

@@ -416,9 +423,7 @@ class SubmitTestFolder(Folder):
416423
not overwrite already existing created test folders.
417424
"""
418425

419-
_sub_folder = None
420-
421-
def __init__(self, basepath=CALC_JOB_DRY_RUN_BASE_PATH):
426+
def __init__(self, basepath: FilePath = CALC_JOB_DRY_RUN_BASE_PATH):
422427
"""Construct and create the sandbox folder.
423428
424429
The directory will be created in the current working directory with the name given by `basepath`.
@@ -451,9 +456,9 @@ def __init__(self, basepath=CALC_JOB_DRY_RUN_BASE_PATH):
451456

452457
self._sub_folder = self.get_subfolder(os.path.relpath(subfolder_path, self.abspath), reset_limit=True)
453458

454-
def __enter__(self):
459+
def __enter__(self) -> Folder:
455460
"""Return the sub folder that should be Called when entering in the with statement."""
456461
return self._sub_folder
457462

458-
def __exit__(self, exc_type, exc_value, traceback):
463+
def __exit__(self, exc_type, exc_value, traceback) -> None:
459464
"""When context manager is exited, do not delete the folder."""

src/aiida/engine/daemon/execmanager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from aiida.common.datastructures import CalcInfo, FileCopyOperation
3434
from aiida.common.folders import Folder, SandboxFolder
3535
from aiida.common.links import LinkType
36+
from aiida.common.typing import FilePath
3637
from aiida.engine.processes.exit_code import ExitCode
3738
from aiida.manage.configuration import get_config_option
3839
from aiida.orm import CalcJobNode, Code, FolderData, Node, PortableCode, RemoteData, load_node
@@ -694,7 +695,7 @@ def traverse(node_):
694695

695696

696697
async def retrieve_calculation(
697-
calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: str
698+
calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: FilePath
698699
) -> FolderData | None:
699700
"""Retrieve all the files of a completed job calculation using the given transport.
700701

src/aiida/engine/processes/calcjobs/calcjob.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from aiida.common.folders import Folder
2727
from aiida.common.lang import classproperty, override
2828
from aiida.common.links import LinkType
29+
from aiida.common.typing import FilePath
2930

3031
from ..exit_code import ExitCode
3132
from ..ports import PortNamespace
@@ -743,7 +744,7 @@ async def _perform_import(self):
743744
return self.parse(retrieved_temporary_folder.abspath)
744745

745746
def parse(
746-
self, retrieved_temporary_folder: Optional[str] = None, existing_exit_code: ExitCode | None = None
747+
self, retrieved_temporary_folder: FilePath | None = None, existing_exit_code: ExitCode | None = None
747748
) -> ExitCode:
748749
"""Parse a retrieved job calculation.
749750
@@ -771,7 +772,7 @@ def parse(
771772

772773
# Call the retrieved output parser
773774
try:
774-
exit_code_retrieved = self.parse_retrieved_output(retrieved_temporary_folder)
775+
exit_code_retrieved = self.parse_retrieved_output(str(retrieved_temporary_folder))
775776
finally:
776777
if retrieved_temporary_folder is not None:
777778
shutil.rmtree(retrieved_temporary_folder, ignore_errors=True)
@@ -1122,7 +1123,10 @@ def presubmit(self, folder: Folder) -> CalcInfo:
11221123
job_tmpl.max_wallclock_seconds = max_wallclock_seconds
11231124

11241125
submit_script_filename = self.node.get_option('submit_script_filename')
1126+
assert submit_script_filename is not None
11251127
script_content = scheduler.get_submit_script(job_tmpl)
1128+
# TODO: mypy error: Argument 2 to "create_file_from_filelike" of "Folder"
1129+
# has incompatible type "Any | None"; expected "str | PurePath"
11261130
folder.create_file_from_filelike(io.StringIO(script_content), submit_script_filename, 'w', encoding='utf8')
11271131

11281132
def encoder(obj):

src/aiida/repository/backend/abstract.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@
1010
import hashlib
1111
import io
1212
import pathlib
13-
from typing import BinaryIO, Iterable, Iterator, List, Optional, Tuple, Union
13+
from collections.abc import Iterable, Iterator
14+
from typing import Any, BinaryIO, List, Optional, Tuple, Union
1415

1516
from aiida.common.hashing import chunked_file_hash
1617

1718
__all__ = ('AbstractRepositoryBackend',)
1819

20+
InfoDictType = dict[str, Union[int, str, dict[str, int], dict[str, float]]]
21+
1922

2023
class AbstractRepositoryBackend(metaclass=abc.ABCMeta):
2124
"""Class that defines the abstract interface for an object repository.
@@ -44,7 +47,7 @@ def key_format(self) -> Optional[str]:
4447
"""
4548

4649
@abc.abstractmethod
47-
def initialise(self, **kwargs) -> None:
50+
def initialise(self, **kwargs: Any) -> None:
4851
"""Initialise the repository if it hasn't already been initialised.
4952
5053
:param kwargs: parameters for the initialisation.
@@ -65,7 +68,7 @@ def erase(self) -> None:
6568
"""
6669

6770
@staticmethod
68-
def is_readable_byte_stream(handle) -> bool:
71+
def is_readable_byte_stream(handle: Any) -> bool:
6972
return hasattr(handle, 'read') and hasattr(handle, 'mode') and 'b' in handle.mode
7073

7174
def put_object_from_filelike(self, handle: BinaryIO) -> str:
@@ -120,7 +123,7 @@ def list_objects(self) -> Iterable[str]:
120123
"""
121124

122125
@abc.abstractmethod
123-
def get_info(self, detailed: bool = False, **kwargs) -> dict:
126+
def get_info(self, detailed: bool = False) -> InfoDictType:
124127
"""Returns relevant information about the content of the repository.
125128
126129
:param detailed:
@@ -129,19 +132,6 @@ def get_info(self, detailed: bool = False, **kwargs) -> dict:
129132
:return: a dictionary with the information.
130133
"""
131134

132-
@abc.abstractmethod
133-
def maintain(self, dry_run: bool = False, live: bool = True, **kwargs) -> None:
134-
"""Performs maintenance operations.
135-
136-
:param dry_run:
137-
flag to only print the actions that would be taken without actually executing them.
138-
139-
:param live:
140-
flag to indicate to the backend whether AiiDA is live or not (i.e. if the profile of the
141-
backend is currently being used/accessed). The backend is expected then to only allow (and
142-
thus set by default) the operations that are safe to perform in this state.
143-
"""
144-
145135
@contextlib.contextmanager
146136
def open(self, key: str) -> Iterator[BinaryIO]: # type: ignore[return]
147137
"""Open a file handle to an object stored under the given key.
@@ -168,7 +158,7 @@ def get_object_content(self, key: str) -> bytes:
168158
return handle.read()
169159

170160
@abc.abstractmethod
171-
def iter_object_streams(self, keys: List[str]) -> Iterator[Tuple[str, BinaryIO]]:
161+
def iter_object_streams(self, keys: Iterable[str]) -> Iterator[Tuple[str, BinaryIO]]:
172162
"""Return an iterator over the (read-only) byte streams of objects identified by key.
173163
174164
.. note:: handles should only be read within the context of this iterator.

0 commit comments

Comments
 (0)