Skip to content

Commit 2a10a4a

Browse files
fix: use Protocol type (#23)
1 parent 0c1a2ee commit 2a10a4a

File tree

2 files changed

+52
-21
lines changed

2 files changed

+52
-21
lines changed

snakemake_interface_software_deployment_plugins/__init__.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,17 @@
99
import hashlib
1010
from pathlib import Path
1111
import shutil
12-
from typing import Any, Dict, Iterable, Optional, Self, Tuple, Type, Union
12+
from typing import (
13+
Any,
14+
ClassVar,
15+
Dict,
16+
Iterable,
17+
Optional,
18+
Self,
19+
Tuple,
20+
Type,
21+
Union,
22+
)
1323
import subprocess as sp
1424

1525
from snakemake_interface_software_deployment_plugins.settings import (
@@ -124,7 +134,18 @@ def __str__(self) -> str:
124134

125135

126136
class EnvBase(ABC):
127-
_cache: Dict[Tuple[Type["EnvBase"], Optional["EnvBase"]], Any] = {}
137+
_cache: ClassVar[Dict[Tuple[Type["EnvBase"], Optional["EnvBase"]], Any]] = {}
138+
spec: EnvSpecBase
139+
within: Optional["EnvBase"]
140+
settings: Optional[SoftwareDeploymentSettingsBase]
141+
shell_executable: str
142+
tempdir: Path
143+
_cache_prefix: Path
144+
_deployment_prefix: Path
145+
_pinfile_prefix: Path
146+
_managed_hash_store: Optional[str] = None
147+
_managed_deployment_hash_store: Optional[str] = None
148+
_obj_hash: Optional[int] = None
128149

129150
def __init__(
130151
self,
@@ -142,9 +163,6 @@ def __init__(
142163
self.settings: Optional[SoftwareDeploymentSettingsBase] = settings
143164
self.shell_executable: str = shell_executable
144165
self.tempdir = tempdir
145-
self._managed_hash_store: Optional[str] = None
146-
self._managed_deployment_hash_store: Optional[str] = None
147-
self._obj_hash: Optional[int] = None
148166
self._deployment_prefix: Path = deployment_prefix
149167
self._cache_prefix: Path = cache_prefix
150168
self._pinfile_prefix: Path = pinfile_prefix
@@ -177,6 +195,21 @@ def decorate_shellcmd(self, cmd: str) -> str:
177195
"""
178196
...
179197

198+
def is_deployable(self) -> bool:
199+
"""Overwrite this in case the deployability of the environment depends on
200+
the spec or settings."""
201+
return isinstance(self, DeployableEnvBase)
202+
203+
def is_pinnable(self) -> bool:
204+
"""Overwrite this in case the pinability of the environment depends on
205+
the spec or settings."""
206+
return isinstance(self, PinnableEnvBase)
207+
208+
def is_cacheable(self) -> bool:
209+
"""Overwrite this in case the cacheability of the environment depends on
210+
the spec or settings."""
211+
return isinstance(self, CacheableEnvBase)
212+
180213
@abstractmethod
181214
def record_hash(self, hash_object) -> None:
182215
"""Update given hash object (using hash_object.update()) such that it changes
@@ -241,7 +274,7 @@ def __eq__(self, other) -> bool:
241274
)
242275

243276

244-
class PinnableEnvBase(ABC):
277+
class PinnableEnvBase(EnvBase, ABC):
245278
@classmethod
246279
@abstractmethod
247280
def pinfile_extension(cls) -> str: ...
@@ -256,7 +289,6 @@ async def pin(self) -> None:
256289

257290
@property
258291
def pinfile(self) -> Path:
259-
assert isinstance(self, EnvBase)
260292
ext = self.pinfile_extension()
261293
if not ext.startswith("."):
262294
raise ValueError("pinfile_extension must start with a dot.")
@@ -265,7 +297,7 @@ def pinfile(self) -> Path:
265297
)
266298

267299

268-
class CacheableEnvBase(ABC):
300+
class CacheableEnvBase(EnvBase, ABC):
269301
async def get_cache_assets(self) -> Iterable[str]: ...
270302

271303
@abstractmethod
@@ -277,12 +309,10 @@ async def cache_assets(self) -> None:
277309

278310
@property
279311
def cache_path(self) -> Path:
280-
assert isinstance(self, EnvBase)
281312
return self._cache_prefix
282313

283314
async def remove_cache(self) -> None:
284315
"""Remove the cached environment assets."""
285-
assert isinstance(self, EnvBase)
286316
for asset in await self.get_cache_assets():
287317
asset_path = self.cache_path / asset
288318
if asset_path.exists():
@@ -297,7 +327,7 @@ async def remove_cache(self) -> None:
297327
)
298328

299329

300-
class DeployableEnvBase(ABC):
330+
class DeployableEnvBase(EnvBase, ABC):
301331
@abstractmethod
302332
def is_deployment_path_portable(self) -> bool:
303333
"""Return whether the deployment path matters for the environment, i.e.
@@ -325,7 +355,6 @@ def record_deployment_hash(self, hash_object) -> None:
325355
deployment is senstivive to the path (e.g. in case of conda, which patches
326356
the RPATH in binaries).
327357
"""
328-
assert isinstance(self, EnvBase)
329358
self.record_hash(hash_object)
330359
if not self.is_deployment_path_portable():
331360
hash_object.update(str(self._deployment_prefix).encode())
@@ -337,24 +366,21 @@ def remove(self) -> None:
337366

338367
def managed_remove(self) -> None:
339368
"""Remove the deployed environment, handling exceptions."""
340-
assert isinstance(self, EnvBase)
341369
try:
342370
self.remove()
343371
except Exception as e:
344372
raise WorkflowError(f"Removal of {self.spec} failed: {e}")
345373

346374
async def managed_deploy(self) -> None:
347-
assert isinstance(self, EnvBase)
348375
try:
349376
await self.deploy()
350377
except Exception as e:
351378
raise WorkflowError(f"Deployment of {self.spec} failed: {e}")
352379

353380
def deployment_hash(self) -> str:
354-
assert isinstance(self, EnvBase)
355381
return self._managed_generic_hash("deployment_hash")
356382

357383
@property
358384
def deployment_path(self) -> Path:
359-
assert isinstance(self, EnvBase) and self._deployment_prefix is not None
385+
assert self._deployment_prefix is not None
360386
return self._deployment_prefix / self.deployment_hash()

snakemake_interface_software_deployment_plugins/tests.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,13 @@ def test_deploy(self, tmp_path):
9090
cmd = env.managed_decorate_shellcmd(self.get_test_cmd())
9191
assert sp.run(cmd, shell=True, executable=self.shell_executable).returncode == 0
9292

93-
def test_archive(self, tmp_path):
93+
def test_cache(self, tmp_path):
9494
env = self._get_env(tmp_path)
95-
if not isinstance(env, CacheableEnvBase):
95+
if not env.is_cacheable():
9696
pytest.skip("Environment either not deployable or not cacheable.")
9797

98+
assert isinstance(env, CacheableEnvBase)
99+
98100
asyncio.run(env.cache_assets())
99101

100102
self._deploy(env, tmp_path)
@@ -103,9 +105,11 @@ def test_archive(self, tmp_path):
103105

104106
def test_pin(self, tmp_path):
105107
env = self._get_env(tmp_path)
106-
if not isinstance(env, PinnableEnvBase):
108+
if not env.is_pinnable():
107109
pytest.skip("Environment is not pinnable.")
108110

111+
assert isinstance(env, PinnableEnvBase)
112+
109113
asyncio.run(env.pin())
110114
assert env.pinfile.exists()
111115
print("Pinfile content:", env.pinfile.read_text(), sep="\n")
@@ -164,9 +168,10 @@ def _get_env(self, tmp_path) -> EnvBase:
164168
pinfile_prefix=pinfile_prefix,
165169
)
166170

167-
def _deploy(self, env: DeployableEnvBase, tmp_path):
168-
if not isinstance(env, DeployableEnvBase):
171+
def _deploy(self, env: EnvBase, tmp_path):
172+
if not env.is_deployable():
169173
pytest.skip("Environment is not deployable.")
170174

175+
assert isinstance(env, DeployableEnvBase)
171176
asyncio.run(env.deploy())
172177
assert any((tmp_path / "deployments").iterdir())

0 commit comments

Comments
 (0)