|
2 | 2 | import uuid |
3 | 3 |
|
4 | 4 | import pytest |
5 | | -from aio_pika import Channel, Message |
| 5 | +from aio_pika import Channel, ExchangeType, Message |
6 | 6 | from aio_pika.exceptions import QueueEmpty |
7 | 7 | from taskiq import AckableMessage, BrokerMessage |
8 | 8 | from taskiq.utils import maybe_awaitable |
9 | 9 |
|
10 | 10 | from taskiq_aio_pika.broker import AioPikaBroker |
11 | 11 |
|
12 | 12 |
|
13 | | -async def get_first_task(broker: AioPikaBroker) -> AckableMessage: # type: ignore |
| 13 | +async def get_first_task(broker: AioPikaBroker) -> AckableMessage: |
14 | 14 | """ |
15 | 15 | Get first message from the queue. |
16 | 16 |
|
17 | 17 | :param broker: async message broker. |
18 | 18 | :return: first message from listen method |
19 | 19 | """ |
20 | | - async for message in broker.listen(): # noqa: RET503 |
| 20 | + async for message in broker.listen(): |
21 | 21 | return message |
| 22 | + return None # type: ignore |
22 | 23 |
|
23 | 24 |
|
24 | 25 | @pytest.mark.anyio |
@@ -219,3 +220,50 @@ async def test_delayed_message_with_plugin( |
219 | 220 | await asyncio.sleep(2) |
220 | 221 |
|
221 | 222 | assert await main_queue.get() |
| 223 | + |
| 224 | + |
| 225 | +@pytest.mark.anyio |
| 226 | +async def test_direct_kick( |
| 227 | + broker: AioPikaBroker, |
| 228 | + test_channel: Channel, |
| 229 | + queue_name: str, |
| 230 | + exchange_name: str, |
| 231 | +) -> None: |
| 232 | + """ |
| 233 | + Test that messages are published and read correctly. |
| 234 | +
|
| 235 | + We kick the message and then try to listen to the queue, |
| 236 | + and check that message we got is the same as we sent. |
| 237 | + """ |
| 238 | + queue = await test_channel.get_queue(queue_name) |
| 239 | + exchange = await test_channel.get_exchange(exchange_name) |
| 240 | + await queue.delete() |
| 241 | + await exchange.delete() |
| 242 | + |
| 243 | + broker._declare_exchange = True |
| 244 | + broker._exchange_type = ExchangeType.DIRECT |
| 245 | + broker._routing_key = "direct_routing_key" |
| 246 | + |
| 247 | + await broker.startup() |
| 248 | + |
| 249 | + await test_channel.get_queue(queue_name, ensure=True) |
| 250 | + await test_channel.get_exchange(exchange_name, ensure=True) |
| 251 | + |
| 252 | + task_id = uuid.uuid4().hex |
| 253 | + task_name = uuid.uuid4().hex |
| 254 | + |
| 255 | + sent = BrokerMessage( |
| 256 | + task_id=task_id, |
| 257 | + task_name=task_name, |
| 258 | + message=b"my_msg", |
| 259 | + labels={ |
| 260 | + "label1": "val1", |
| 261 | + }, |
| 262 | + ) |
| 263 | + |
| 264 | + await broker.kick(sent) |
| 265 | + |
| 266 | + message = await asyncio.wait_for(get_first_task(broker), timeout=0.4) |
| 267 | + |
| 268 | + assert message.data == sent.message |
| 269 | + await maybe_awaitable(message.ack()) |
0 commit comments