Skip to content

Commit 2ec1d1c

Browse files
authored
refactoring: ⚡️ Inject self
1 parent 0320d25 commit 2ec1d1c

File tree

4 files changed

+146
-36
lines changed

4 files changed

+146
-36
lines changed

injection/common/lazy.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from collections.abc import Callable, Iterator, Mapping
2-
from contextlib import suppress
32
from types import MappingProxyType
4-
from typing import Any, Generic, TypeVar
3+
from typing import Generic, TypeVar
54

65
from injection.common.tools.threading import thread_lock
76

@@ -13,35 +12,32 @@
1312

1413

1514
class Lazy(Generic[_T]):
16-
__slots__ = ("is_set", "__generator")
15+
__slots__ = ("__cache", "__is_set")
1716

1817
def __init__(self, factory: Callable[[], _T]):
19-
def generator() -> Iterator[_T]:
20-
nonlocal factory
21-
22-
with thread_lock:
23-
value = factory()
24-
self.is_set = True
25-
del factory
18+
self.__setup_cache(factory)
2619

27-
while True:
28-
yield value
20+
def __invert__(self) -> _T:
21+
return next(self.__cache)
2922

30-
self.is_set = False
31-
self.__generator = generator()
23+
@property
24+
def is_set(self) -> bool:
25+
return self.__is_set
3226

33-
def __invert__(self) -> _T:
34-
return next(self.__generator)
27+
def __setup_cache(self, factory: Callable[[], _T]):
28+
def new_cache() -> Iterator[_T]:
29+
with thread_lock:
30+
self.__is_set = True
3531

36-
def __call__(self) -> _T:
37-
return ~self
32+
nonlocal factory
33+
cached = factory()
34+
del factory
3835

39-
def __setattr__(self, name: str, value: Any, /):
40-
with suppress(AttributeError):
41-
if self.is_set:
42-
raise TypeError(f"`{self}` is frozen.")
36+
while True:
37+
yield cached
4338

44-
return super().__setattr__(name, value)
39+
self.__cache = new_cache()
40+
self.__is_set = False
4541

4642

4743
class LazyMapping(Mapping[_K, _V]):

injection/common/queue.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from abc import abstractmethod
2+
from collections import deque
3+
from collections.abc import Iterator
4+
from dataclasses import dataclass, field
5+
from typing import NoReturn, Protocol, TypeVar
6+
7+
from injection.common.tools.threading import thread_lock
8+
9+
__all__ = ("LimitedQueue",)
10+
11+
_T = TypeVar("_T")
12+
13+
14+
class Queue(Iterator[_T], Protocol):
15+
__slots__ = ()
16+
17+
@abstractmethod
18+
def add(self, item: _T):
19+
raise NotImplementedError
20+
21+
22+
@dataclass(repr=False, frozen=True, slots=True)
23+
class SimpleQueue(Queue[_T]):
24+
__items: deque[_T] = field(default_factory=deque, init=False)
25+
26+
def __next__(self) -> _T:
27+
try:
28+
return self.__items.popleft()
29+
except IndexError as exc:
30+
raise StopIteration from exc
31+
32+
def add(self, item: _T):
33+
self.__items.append(item)
34+
return self
35+
36+
37+
class NoQueue(Queue[_T]):
38+
__slots__ = ()
39+
40+
def __bool__(self) -> bool:
41+
return False
42+
43+
def __next__(self) -> NoReturn:
44+
raise StopIteration
45+
46+
def add(self, item: _T) -> NoReturn:
47+
raise TypeError("Queue doesn't exist.")
48+
49+
50+
@dataclass(repr=False, slots=True)
51+
class LimitedQueue(Queue[_T]):
52+
__queue: Queue[_T] = field(default_factory=SimpleQueue)
53+
54+
def __next__(self) -> _T:
55+
if not self.__queue:
56+
raise StopIteration
57+
58+
try:
59+
return next(self.__queue)
60+
except StopIteration as exc:
61+
with thread_lock:
62+
self.__queue = NoQueue()
63+
64+
raise exc
65+
66+
def add(self, item: _T):
67+
self.__queue.add(item)
68+
return self

injection/core/module.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
from injection.common.event import Event, EventChannel, EventListener
4040
from injection.common.lazy import Lazy, LazyMapping
41+
from injection.common.queue import LimitedQueue
4142
from injection.common.tools.threading import (
4243
frozen_collection,
4344
synchronized,
@@ -418,9 +419,14 @@ def decorator(wp):
418419
wp.__init__ = self.inject(wp.__init__)
419420
return wp
420421

421-
wrapper = InjectedFunction(wp).update(self)
422-
self.add_listener(wrapper)
423-
return wrapper
422+
function = InjectedFunction(wp)
423+
424+
@function.setup
425+
def listen():
426+
function.update(self)
427+
self.add_listener(function)
428+
429+
return function
424430

425431
return decorator(wrapped) if wrapped else decorator
426432

@@ -612,22 +618,36 @@ class Arguments(NamedTuple):
612618

613619

614620
class InjectedFunction(EventListener):
615-
__slots__ = ("__dict__", "__wrapper", "__dependencies", "__owner")
621+
__slots__ = (
622+
"__dict__",
623+
"__signature__",
624+
"__dependencies",
625+
"__owner",
626+
"__setup_queue",
627+
"__wrapper",
628+
)
616629

617630
def __init__(self, wrapped: Callable[..., Any], /):
618631
update_wrapper(self, wrapped)
619-
self.__signature__ = Lazy[Signature](
620-
lambda: inspect.signature(wrapped, eval_str=True)
621-
)
622632

623633
@wraps(wrapped)
624634
def wrapper(*args, **kwargs):
635+
self.__consume_setup_queue()
625636
args, kwargs = self.bind(args, kwargs)
626637
return wrapped(*args, **kwargs)
627638

628639
self.__wrapper = wrapper
629640
self.__dependencies = Dependencies.empty()
630641
self.__owner = None
642+
self.__setup_queue = LimitedQueue[Callable[[], Any]]()
643+
self.setup(
644+
lambda: self.__set_signature(
645+
inspect.signature(
646+
wrapped,
647+
eval_str=True,
648+
)
649+
)
650+
)
631651

632652
def __repr__(self) -> str:
633653
return repr(self.__wrapper)
@@ -638,7 +658,7 @@ def __str__(self) -> str:
638658
def __call__(self, /, *args, **kwargs) -> Any:
639659
return self.__wrapper(*args, **kwargs)
640660

641-
def __get__(self, instance: object | None, owner: type):
661+
def __get__(self, instance: object = None, owner: type = None):
642662
if instance is None:
643663
return self
644664

@@ -647,7 +667,7 @@ def __get__(self, instance: object | None, owner: type):
647667
def __set_name__(self, owner: type, name: str):
648668
if self.__dependencies.are_resolved:
649669
raise TypeError(
650-
"`__set_name__` is called after dependencies have been resolved."
670+
"Function owner must be assigned before dependencies are resolved."
651671
)
652672

653673
if self.__owner:
@@ -657,7 +677,7 @@ def __set_name__(self, owner: type, name: str):
657677

658678
@property
659679
def signature(self) -> Signature:
660-
return self.__signature__()
680+
return self.__signature__
661681

662682
def bind(
663683
self,
@@ -671,9 +691,9 @@ def bind(
671691
return Arguments(args, kwargs)
672692

673693
bound = self.signature.bind_partial(*args, **kwargs)
674-
dependencies = self.__dependencies.arguments
675-
bound.arguments = dependencies | bound.arguments
676-
694+
bound.arguments = (
695+
bound.arguments | self.__dependencies.arguments | bound.arguments
696+
)
677697
return Arguments(bound.args, bound.kwargs)
678698

679699
def update(self, module: Module):
@@ -686,6 +706,13 @@ def update(self, module: Module):
686706

687707
return self
688708

709+
def setup(self, wrapped: Callable[[], Any] = None, /):
710+
def decorator(wp):
711+
self.__setup_queue.add(wp)
712+
return wp
713+
714+
return decorator(wrapped) if wrapped else decorator
715+
689716
@singledispatchmethod
690717
def on_event(self, event: Event, /):
691718
pass
@@ -695,3 +722,15 @@ def on_event(self, event: Event, /):
695722
def _(self, event: ModuleEvent, /) -> ContextManager:
696723
yield
697724
self.update(event.on_module)
725+
726+
def __consume_setup_queue(self):
727+
for function in self.__setup_queue:
728+
function()
729+
730+
return self
731+
732+
def __set_signature(self, signature: Signature):
733+
with thread_lock:
734+
self.__signature__ = signature
735+
736+
return self

tests/test_inject.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ def test_inject_with_annotated_and_union(self):
5353
def test_inject_with_optional(self):
5454
self.assert_inject(Optional[SomeInjectable])
5555

56+
def test_inject_with_no_parameter(self):
57+
@inject
58+
def my_function():
59+
pass
60+
61+
my_function()
62+
5663
def test_inject_with_positional_only_parameter(self):
5764
@inject
5865
def my_function(instance: SomeInjectable, /, **kw):

0 commit comments

Comments
 (0)