Skip to content

Commit 56a8651

Browse files
authored
feat: ✨ override parameter
1 parent 3c8841a commit 56a8651

File tree

5 files changed

+72
-22
lines changed

5 files changed

+72
-22
lines changed

documentation/basic-usage.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,20 @@ class C(B):
120120
...
121121
```
122122

123+
If a class is registered in a package and you want to override it, there is the `override` parameter:
124+
125+
```python
126+
@singleton
127+
class A:
128+
...
129+
130+
# ...
131+
132+
@singleton(on=A, override=True)
133+
class B(A):
134+
...
135+
```
136+
123137
## Recipes
124138

125139
A recipe is a function that tells the injector how to construct the instance to be injected. It is important to specify

injection/_pkg.pyi

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class Module:
6060
*,
6161
cls: type[Injectable] = ...,
6262
on: type | Iterable[type] | UnionType = ...,
63+
override: bool = ...,
6364
):
6465
"""
6566
Decorator applicable to a class or function. It is used to indicate how the
@@ -73,6 +74,7 @@ class Module:
7374
/,
7475
*,
7576
on: type | Iterable[type] | UnionType = ...,
77+
override: bool = ...,
7678
):
7779
"""
7880
Decorator applicable to a class or function. It is used to indicate how the
@@ -84,6 +86,8 @@ class Module:
8486
self,
8587
instance: _T,
8688
on: type | Iterable[type] | UnionType = ...,
89+
*,
90+
override: bool = ...,
8791
) -> _T:
8892
"""
8993
Function for registering a specific instance to be injected. This is useful for
@@ -149,7 +153,7 @@ class ModulePriorities(Enum):
149153

150154
@runtime_checkable
151155
class Injectable(Protocol[_T]):
152-
def __init__(self, factory: Callable[[], _T] = ..., *args, **kwargs): ...
156+
def __init__(self, factory: Callable[[], _T] = ..., /): ...
153157
@property
154158
def is_locked(self) -> bool: ...
155159
def unlock(self): ...

injection/core/module.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class ContainerEvent(Event, ABC):
6161
@dataclass(frozen=True, slots=True)
6262
class ContainerDependenciesUpdated(ContainerEvent):
6363
classes: Collection[type]
64+
override: bool
6465

6566
def __str__(self) -> str:
6667
length = len(self.classes)
@@ -212,26 +213,19 @@ def is_locked(self) -> bool:
212213
def __injectables(self) -> frozenset[Injectable]:
213214
return frozenset(self.__data.values())
214215

215-
def update(self, classes: Types, injectable: Injectable):
216-
classes = frozenset(get_origins(*classes))
216+
def update(self, classes: Types, injectable: Injectable, override: bool):
217+
values = {origin: injectable for origin in get_origins(*classes)}
217218

218-
if classes:
219-
event = ContainerDependenciesUpdated(self, classes)
219+
if values:
220+
event = ContainerDependenciesUpdated(self, values, override)
220221

221222
with self.notify(event):
222-
self.__data.update(
223-
(self.check_if_exists(cls), injectable) for cls in classes
224-
)
225-
226-
return self
223+
if not override:
224+
self.__check_if_exists(*values)
227225

228-
def check_if_exists(self, cls: type) -> type:
229-
if cls in self.__data:
230-
raise RuntimeError(
231-
f"An injectable already exists for the class `{format_type(cls)}`."
232-
)
226+
self.__data.update(values)
233227

234-
return cls
228+
return self
235229

236230
def unlock(self):
237231
for injectable in self.__injectables:
@@ -244,6 +238,13 @@ def add_listener(self, listener: EventListener):
244238
def notify(self, event: Event) -> ContextManager | ContextDecorator:
245239
return self.__channel.dispatch(event)
246240

241+
def __check_if_exists(self, *classes: type):
242+
for cls in classes:
243+
if cls in self.__data:
244+
raise RuntimeError(
245+
f"An injectable already exists for the class `{format_type(cls)}`."
246+
)
247+
247248

248249
"""
249250
Module
@@ -280,7 +281,7 @@ def __getitem__(self, cls: type[_T] | UnionType, /) -> Injectable[_T]:
280281
raise NoInjectable(cls)
281282

282283
def __setitem__(self, cls: type | UnionType, injectable: Injectable, /):
283-
self.update((cls,), injectable)
284+
self.update((cls,), injectable, override=True)
284285

285286
def __contains__(self, cls: type | UnionType, /) -> bool:
286287
return any(cls in broker for broker in self.__brokers)
@@ -304,22 +305,29 @@ def injectable(
304305
*,
305306
cls: type[Injectable] = NewInjectable,
306307
on: type | Types = None,
308+
override: bool = False,
307309
):
308310
def decorator(wp):
309311
factory = self.inject(wp, return_factory=True)
310312
injectable = cls(factory)
311313
classes = find_types(wp, on)
312-
self.update(classes, injectable)
314+
self.update(classes, injectable, override)
313315
return wp
314316

315317
return decorator(wrapped) if wrapped else decorator
316318

317319
singleton = partialmethod(injectable, cls=SingletonInjectable)
318320

319-
def set_constant(self, instance: _T, on: type | Types = None) -> _T:
321+
def set_constant(
322+
self,
323+
instance: _T,
324+
on: type | Types = None,
325+
*,
326+
override: bool = False,
327+
) -> _T:
320328
cls = type(instance)
321329

322-
@self.injectable(on=(cls, on))
330+
@self.injectable(on=(cls, on), override=override)
323331
def get_constant():
324332
return instance
325333

@@ -364,8 +372,8 @@ def get_instance(self, cls: type[_T], none: bool = True) -> _T | None:
364372
def get_lazy_instance(self, cls: type[_T]) -> Lazy[_T | None]:
365373
return Lazy(lambda: self.get_instance(cls))
366374

367-
def update(self, classes: Types, injectable: Injectable):
368-
self.__container.update(classes, injectable)
375+
def update(self, classes: Types, injectable: Injectable, override: bool = False):
376+
self.__container.update(classes, injectable, override)
369377
return self
370378

371379
def use(

tests/test_injectable.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,15 @@ class B(A):
166166
@injectable(on=A)
167167
class C(A):
168168
pass
169+
170+
def test_injectable_with_override(self):
171+
@injectable
172+
class A:
173+
pass
174+
175+
@injectable(on=A, override=True)
176+
class B(A):
177+
pass
178+
179+
a = get_instance(A)
180+
assert isinstance(a, B)

tests/test_singleton.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,15 @@ class B(A):
164164
@singleton(on=A)
165165
class C(A):
166166
pass
167+
168+
def test_injectable_with_override(self):
169+
@singleton
170+
class A:
171+
pass
172+
173+
@singleton(on=A, override=True)
174+
class B(A):
175+
pass
176+
177+
a = get_instance(A)
178+
assert isinstance(a, B)

0 commit comments

Comments
 (0)