1
+ from typing import AsyncIterator , Iterator
2
+
1
3
import pytest
2
4
3
5
from taskiq_faststream .utils import resolve_msg
4
6
5
7
6
8
@pytest .mark .anyio
7
9
async def test_regular () -> None :
8
- assert await resolve_msg ("msg" ) == "msg"
10
+ async for m in resolve_msg ("msg" ):
11
+ assert m == "msg"
9
12
10
13
11
14
@pytest .mark .anyio
12
15
async def test_sync_callable () -> None :
13
- assert await resolve_msg (lambda : "msg" ) == "msg"
16
+ async for m in resolve_msg (lambda : "msg" ):
17
+ assert m == "msg"
14
18
15
19
16
20
@pytest .mark .anyio
17
21
async def test_async_callable () -> None :
18
22
async def gen_msg () -> str :
19
23
return "msg"
20
24
21
- assert await resolve_msg (gen_msg ) == "msg"
25
+ async for m in resolve_msg (gen_msg ):
26
+ assert m == "msg"
22
27
23
28
24
29
@pytest .mark .anyio
@@ -30,7 +35,8 @@ def __init__(self) -> None:
30
35
def __call__ (self ) -> str :
31
36
return "msg"
32
37
33
- assert await resolve_msg (C ()) == "msg"
38
+ async for m in resolve_msg (C ()):
39
+ assert m == "msg"
34
40
35
41
36
42
@pytest .mark .anyio
@@ -42,4 +48,23 @@ def __init__(self) -> None:
42
48
async def __call__ (self ) -> str :
43
49
return "msg"
44
50
45
- assert await resolve_msg (C ()) == "msg"
51
+ async for m in resolve_msg (C ()):
52
+ assert m == "msg"
53
+
54
+
55
+ @pytest .mark .anyio
56
+ async def test_async_generator () -> None :
57
+ async def get_msg () -> AsyncIterator [str ]:
58
+ yield "msg"
59
+
60
+ async for m in resolve_msg (get_msg ):
61
+ assert m == "msg"
62
+
63
+
64
+ @pytest .mark .anyio
65
+ async def test_sync_generator () -> None :
66
+ def get_msg () -> Iterator [str ]:
67
+ yield "msg"
68
+
69
+ async for m in resolve_msg (get_msg ):
70
+ assert m == "msg"
0 commit comments