Skip to content

Commit a005723

Browse files
authored
feat: ✨ Mark AsyncInjectedFunction as coroutine function
1 parent c7e81c8 commit a005723

File tree

5 files changed

+11
-64
lines changed

5 files changed

+11
-64
lines changed

injection/_core/common/asynchronous.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
from abc import abstractmethod
22
from collections.abc import Awaitable, Callable, Generator
33
from dataclasses import dataclass
4-
from typing import Any, NoReturn, Protocol, override, runtime_checkable
4+
from typing import Any, NoReturn, Protocol, runtime_checkable
55

66

77
@dataclass(repr=False, eq=False, frozen=True, slots=True)
88
class SimpleAwaitable[T](Awaitable[T]):
99
callable: Callable[..., Awaitable[T]]
1010

11-
@override
1211
def __await__(self) -> Generator[Any, Any, T]:
1312
return self.callable().__await__()
1413

@@ -30,11 +29,9 @@ def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
3029
class AsyncCaller[**P, T](Caller[P, T]):
3130
callable: Callable[P, Awaitable[T]]
3231

33-
@override
3432
async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
3533
return await self.callable(*args, **kwargs)
3634

37-
@override
3835
def call(self, /, *args: P.args, **kwargs: P.kwargs) -> NoReturn:
3936
raise RuntimeError(
4037
"Synchronous call isn't supported for an asynchronous Callable."
@@ -45,10 +42,8 @@ def call(self, /, *args: P.args, **kwargs: P.kwargs) -> NoReturn:
4542
class SyncCaller[**P, T](Caller[P, T]):
4643
callable: Callable[P, T]
4744

48-
@override
4945
async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
5046
return self.callable(*args, **kwargs)
5147

52-
@override
5348
def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
5449
return self.callable(*args, **kwargs)

injection/_core/common/invertible.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import abstractmethod
22
from collections.abc import Callable
33
from dataclasses import dataclass
4-
from typing import Protocol, override, runtime_checkable
4+
from typing import Protocol, runtime_checkable
55

66

77
@runtime_checkable
@@ -15,6 +15,5 @@ def __invert__(self) -> T:
1515
class SimpleInvertible[T](Invertible[T]):
1616
callable: Callable[..., T]
1717

18-
@override
1918
def __invert__(self) -> T:
2019
return self.callable()

injection/_core/common/lazy.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from collections.abc import Callable, Iterator, Mapping
22
from types import MappingProxyType
3-
from typing import override
43

54
from injection._core.common.invertible import Invertible
65

@@ -14,7 +13,6 @@ class Lazy[T](Invertible[T]):
1413
def __init__(self, factory: Callable[..., T]) -> None:
1514
self.__setup_cache(factory)
1615

17-
@override
1816
def __invert__(self) -> T:
1917
return next(self.__iterator)
2018

@@ -44,15 +42,12 @@ class LazyMapping[K, V](Mapping[K, V]):
4442
def __init__(self, iterator: Iterator[tuple[K, V]]) -> None:
4543
self.__lazy = Lazy(lambda: MappingProxyType(dict(iterator)))
4644

47-
@override
4845
def __getitem__(self, key: K, /) -> V:
4946
return (~self.__lazy)[key]
5047

51-
@override
5248
def __iter__(self) -> Iterator[K]:
5349
yield from ~self.__lazy
5450

55-
@override
5651
def __len__(self) -> int:
5752
return len(~self.__lazy)
5853

injection/_core/injectables.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections.abc import MutableMapping
33
from contextlib import suppress
44
from dataclasses import dataclass
5-
from typing import Any, ClassVar, NoReturn, Protocol, override, runtime_checkable
5+
from typing import Any, ClassVar, NoReturn, Protocol, runtime_checkable
66

77
from injection._core.common.asynchronous import Caller
88
from injection._core.common.threading import synchronized
@@ -37,11 +37,9 @@ class BaseInjectable[T](Injectable[T], ABC):
3737
class SimpleInjectable[T](BaseInjectable[T]):
3838
__slots__ = ()
3939

40-
@override
4140
async def aget_instance(self) -> T:
4241
return await self.factory.acall()
4342

44-
@override
4543
def get_instance(self) -> T:
4644
return self.factory.call()
4745

@@ -56,15 +54,12 @@ def cache(self) -> MutableMapping[str, Any]:
5654
return self.__dict__
5755

5856
@property
59-
@override
6057
def is_locked(self) -> bool:
6158
return self.__key in self.cache
6259

63-
@override
6460
def unlock(self) -> None:
6561
self.cache.clear()
6662

67-
@override
6863
async def aget_instance(self) -> T:
6964
with suppress(KeyError):
7065
return self.__check_instance()
@@ -75,7 +70,6 @@ async def aget_instance(self) -> T:
7570

7671
return instance
7772

78-
@override
7973
def get_instance(self) -> T:
8074
with suppress(KeyError):
8175
return self.__check_instance()
@@ -97,10 +91,8 @@ def __set_instance(self, value: T) -> None:
9791
class ShouldBeInjectable[T](Injectable[T]):
9892
cls: type[T]
9993

100-
@override
10194
async def aget_instance(self) -> T:
10295
return self.get_instance()
10396

104-
@override
10597
def get_instance(self) -> NoReturn:
10698
raise InjectionError(f"`{self.cls}` should be an injectable.")

injection/_core/module.py

Lines changed: 8 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import asyncio
4-
import inspect
54
from abc import ABC, abstractmethod
65
from collections import OrderedDict
76
from collections.abc import (
@@ -17,7 +16,8 @@
1716
from dataclasses import dataclass, field
1817
from enum import StrEnum
1918
from functools import partialmethod, singledispatchmethod, update_wrapper
20-
from inspect import Signature, isclass, iscoroutinefunction
19+
from inspect import Signature, isclass, iscoroutinefunction, markcoroutinefunction
20+
from inspect import signature as inspect_signature
2121
from logging import Logger, getLogger
2222
from queue import Empty, Queue
2323
from types import MethodType
@@ -29,9 +29,7 @@
2929
NamedTuple,
3030
Protocol,
3131
Self,
32-
TypeGuard,
3332
overload,
34-
override,
3533
runtime_checkable,
3634
)
3735
from uuid import uuid4
@@ -76,7 +74,6 @@ class LocatorDependenciesUpdated[T](LocatorEvent):
7674
classes: Collection[InputType[T]]
7775
mode: Mode
7876

79-
@override
8077
def __str__(self) -> str:
8178
length = len(self.classes)
8279
formatted_types = ", ".join(f"`{cls}`" for cls in self.classes)
@@ -95,7 +92,6 @@ class ModuleEvent(Event, ABC):
9592
class ModuleEventProxy(ModuleEvent):
9693
event: Event
9794

98-
@override
9995
def __str__(self) -> str:
10096
return f"`{self.module}` has propagated an event: {self.origin}"
10197

@@ -116,7 +112,6 @@ class ModuleAdded(ModuleEvent):
116112
module_added: Module
117113
priority: Priority
118114

119-
@override
120115
def __str__(self) -> str:
121116
return f"`{self.module}` now uses `{self.module_added}`."
122117

@@ -125,7 +120,6 @@ def __str__(self) -> str:
125120
class ModuleRemoved(ModuleEvent):
126121
module_removed: Module
127122

128-
@override
129123
def __str__(self) -> str:
130124
return f"`{self.module}` no longer uses `{self.module_removed}`."
131125

@@ -135,7 +129,6 @@ class ModulePriorityUpdated(ModuleEvent):
135129
module_updated: Module
136130
priority: Priority
137131

138-
@override
139132
def __str__(self) -> str:
140133
return (
141134
f"In `{self.module}`, the priority `{self.priority}` "
@@ -242,7 +235,6 @@ class Locator(Broker):
242235

243236
static_hooks: ClassVar[LocatorHooks[Any]] = LocatorHooks.default()
244237

245-
@override
246238
def __getitem__[T](self, cls: InputType[T], /) -> Injectable[T]:
247239
for input_class in self.__standardize_inputs((cls,)):
248240
try:
@@ -254,15 +246,13 @@ def __getitem__[T](self, cls: InputType[T], /) -> Injectable[T]:
254246

255247
raise NoInjectable(cls)
256248

257-
@override
258249
def __contains__(self, cls: InputType[Any], /) -> bool:
259250
return any(
260251
input_class in self.__records
261252
for input_class in self.__standardize_inputs((cls,))
262253
)
263254

264255
@property
265-
@override
266256
def is_locked(self) -> bool:
267257
return any(injectable.is_locked for injectable in self.__injectables)
268258

@@ -284,15 +274,13 @@ def update[T](self, updater: Updater[T]) -> Self:
284274

285275
return self
286276

287-
@override
288277
@synchronized()
289278
def unlock(self) -> Self:
290279
for injectable in self.__injectables:
291280
injectable.unlock()
292281

293282
return self
294283

295-
@override
296284
async def all_ready(self) -> None:
297285
for injectable in self.__injectables:
298286
await injectable.aget_instance()
@@ -387,20 +375,17 @@ class Module(Broker, EventListener):
387375
def __post_init__(self) -> None:
388376
self.__locator.add_listener(self)
389377

390-
@override
391378
def __getitem__[T](self, cls: InputType[T], /) -> Injectable[T]:
392379
for broker in self.__brokers:
393380
with suppress(KeyError):
394381
return broker[cls]
395382

396383
raise NoInjectable(cls)
397384

398-
@override
399385
def __contains__(self, cls: InputType[Any], /) -> bool:
400386
return any(cls in broker for broker in self.__brokers)
401387

402388
@property
403-
@override
404389
def is_locked(self) -> bool:
405390
return any(broker.is_locked for broker in self.__brokers)
406391

@@ -695,15 +680,13 @@ def change_priority(self, module: Module, priority: Priority | PriorityStr) -> S
695680

696681
return self
697682

698-
@override
699683
@synchronized()
700684
def unlock(self) -> Self:
701685
for broker in self.__brokers:
702686
broker.unlock()
703687

704688
return self
705689

706-
@override
707690
async def all_ready(self) -> None:
708691
for broker in self.__brokers:
709692
await broker.all_ready()
@@ -720,7 +703,6 @@ def remove_listener(self, listener: EventListener) -> Self:
720703
self.__channel.remove_listener(listener)
721704
return self
722705

723-
@override
724706
def on_event(self, event: Event, /) -> ContextManager[None] | None:
725707
self_event = ModuleEventProxy(self, event)
726708
return self.dispatch(self_event)
@@ -890,7 +872,7 @@ def signature(self) -> Signature:
890872
return self.__signature
891873

892874
with synchronized():
893-
signature = inspect.signature(self.wrapped, eval_str=True)
875+
signature = inspect_signature(self.wrapped, eval_str=True)
894876
self.__signature = signature
895877

896878
return signature
@@ -915,13 +897,11 @@ def bind(
915897
additional_arguments = self.__dependencies.get_arguments()
916898
return self.__bind(args, kwargs, additional_arguments)
917899

918-
@override
919900
async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
920901
self.__setup()
921902
arguments = await self.abind(args, kwargs)
922903
return self.wrapped(*arguments.args, **arguments.kwargs)
923904

924-
@override
925905
def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
926906
self.__setup()
927907
arguments = self.bind(args, kwargs)
@@ -957,7 +937,6 @@ def decorator(wp: Callable[_P, _T]) -> Callable[_P, _T]:
957937
return decorator(wrapped) if wrapped else decorator
958938

959939
@singledispatchmethod
960-
@override
961940
def on_event(self, event: Event, /) -> ContextManager[None] | None: # type: ignore[override]
962941
return None
963942

@@ -1014,11 +993,9 @@ def __init__(self, metadata: InjectMetadata[P, T]) -> None:
1014993
update_wrapper(self, metadata.wrapped)
1015994
self.__inject_metadata__ = metadata
1016995

1017-
@override
1018996
def __repr__(self) -> str: # pragma: no cover
1019997
return repr(self.__inject_metadata__.wrapped)
1020998

1021-
@override
1022999
def __str__(self) -> str: # pragma: no cover
10231000
return str(self.__inject_metadata__.wrapped)
10241001

@@ -1043,34 +1020,23 @@ def __set_name__(self, owner: type, name: str) -> None:
10431020
class AsyncInjectedFunction[**P, T](InjectedFunction[P, Awaitable[T]]):
10441021
__slots__ = ()
10451022

1046-
@override
1023+
def __init__(self, metadata: InjectMetadata[P, Awaitable[T]]) -> None:
1024+
super().__init__(metadata)
1025+
markcoroutinefunction(self)
1026+
10471027
async def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
10481028
return await (await self.__inject_metadata__.acall(*args, **kwargs))
10491029

10501030

10511031
class SyncInjectedFunction[**P, T](InjectedFunction[P, T]):
10521032
__slots__ = ()
10531033

1054-
@override
10551034
def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
10561035
return self.__inject_metadata__.call(*args, **kwargs)
10571036

10581037

1059-
def _is_coroutine_function[**P, T](
1060-
function: Callable[P, T] | Callable[P, Awaitable[T]],
1061-
) -> TypeGuard[Callable[P, Awaitable[T]]]:
1062-
if iscoroutinefunction(function):
1063-
return True
1064-
1065-
elif isclass(function):
1066-
return False
1067-
1068-
call = getattr(function, "__call__", None)
1069-
return iscoroutinefunction(call)
1070-
1071-
10721038
def _get_caller[**P, T](function: Callable[P, T]) -> Caller[P, T]:
1073-
if _is_coroutine_function(function):
1039+
if iscoroutinefunction(function):
10741040
return AsyncCaller(function)
10751041

10761042
elif isinstance(function, InjectedFunction):

0 commit comments

Comments
 (0)