@@ -1,8 +1,19 @@
 import sys
+import uuid
 from logging import getLogger
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Optional, TypeVar
-
-from redis.asyncio import BlockingConnectionPool, Connection, Redis
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    Awaitable,
+    Callable,
+    Dict,
+    Optional,
+    TypeVar,
+)
+
+from redis.asyncio import BlockingConnectionPool, Connection, Redis, ResponseError
+from taskiq import AckableMessage
 from taskiq.abc.broker import AsyncBroker
 from taskiq.abc.result_backend import AsyncResultBackend
 from taskiq.message import BrokerMessage
@@ -132,3 +143,110 @@ async def listen(self) -> AsyncGenerator[bytes, None]:
             except ConnectionError as exc:
                 logger.warning("Redis connection error: %s", exc)
                 continue
+
+
+class RedisStreamBroker(BaseRedisBroker):
+    """
+    Redis broker that uses streams for task distribution.
+
+    You can read more about streams here:
+    https://redis.io/docs/latest/develop/data-types/streams
+
+    This broker supports acknowledgment of messages.
+    """
+
+    def __init__(
+        self,
+        url: str,
+        queue_name: str = "taskiq",
+        max_connection_pool_size: Optional[int] = None,
+        consumer_group_name: str = "taskiq",
+        consumer_name: Optional[str] = None,
+        consumer_id: str = "$",
+        mkstream: bool = True,
+        xread_block: int = 10000,
+        additional_streams: Optional[Dict[str, str]] = None,
+        **connection_kwargs: Any,
+    ) -> None:
+        super().__init__(
+            url,
+            task_id_generator=None,
+            result_backend=None,
+            queue_name=queue_name,
+            max_connection_pool_size=max_connection_pool_size,
+            **connection_kwargs,
+        )
+        self.consumer_group_name = consumer_group_name
+        self.consumer_name = consumer_name or str(uuid.uuid4())
+        self.consumer_id = consumer_id
+        self.mkstream = mkstream
+        self.block = xread_block
+        self.additional_streams = additional_streams or {}
+
+    async def _declare_consumer_group(self) -> None:
+        """
+        Declare consumer group.
+
+        Required for proper work of the broker.
+        """
+        streams = {self.queue_name, *self.additional_streams.keys()}
+        async with Redis(connection_pool=self.connection_pool) as redis_conn:
+            for stream_name in streams:
+                try:
+                    await redis_conn.xgroup_create(
+                        stream_name,
+                        self.consumer_group_name,
+                        id=self.consumer_id,
+                        mkstream=self.mkstream,
+                    )
+                except ResponseError as err:
+                    logger.debug(err)
+
+    async def startup(self) -> None:
+        """Declare consumer group on startup."""
+        await super().startup()
+        await self._declare_consumer_group()
+
+    async def kick(self, message: BrokerMessage) -> None:
+        """
+        Put a message into the stream.
+
+        This method appends a new entry to the stream.
+
+        :param message: message to append.
+        """
+        async with Redis(connection_pool=self.connection_pool) as redis_conn:
+            await redis_conn.xadd(self.queue_name, {b"data": message.message})
+
+    def _ack_generator(self, id: str) -> Callable[[], Awaitable[None]]:
+        async def _ack() -> None:
+            async with Redis(connection_pool=self.connection_pool) as redis_conn:
+                await redis_conn.xack(
+                    self.queue_name,
+                    self.consumer_group_name,
+                    id,
+                )
+
+        return _ack
+
+    async def listen(self) -> AsyncGenerator[AckableMessage, None]:
+        """Listen to incoming messages."""
+        async with Redis(connection_pool=self.connection_pool) as redis_conn:
+            while True:
+                fetched = await redis_conn.xreadgroup(
+                    self.consumer_group_name,
+                    self.consumer_name,
+                    {
+                        self.queue_name: ">",
+                        **self.additional_streams,
+                    },
+                    block=self.block,
+                    noack=False,
+                )
+                for _, msg_list in fetched:
+                    for msg_id, msg in msg_list:
+                        logger.debug("Received message: %s", msg)
+                        yield AckableMessage(
+                            data=msg[b"data"],
+                            ack=self._ack_generator(msg_id),
+                        )
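
For context, here is a minimal usage sketch of the broker added above. It assumes `RedisStreamBroker` is exported from the `taskiq_redis` package and that a Redis server is reachable at the given URL; the task name and the stream/group names are illustrative and not part of this diff.

```python
import asyncio

from taskiq_redis import RedisStreamBroker  # assumed export path

broker = RedisStreamBroker(
    url="redis://localhost:6379",
    queue_name="my_stream",          # stream key that kick() writes to via XADD
    consumer_group_name="my_group",  # consumer group declared by startup()
)


@broker.task
async def double(value: int) -> int:
    return value * 2


async def main() -> None:
    await broker.startup()  # declares the consumer group (XGROUP CREATE ... MKSTREAM)
    await double.kiq(21)    # appends a task message to the stream
    await broker.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```

A worker process (for example `taskiq worker my_module:broker`) would then receive messages through `listen()` and acknowledge each one via the returned `AckableMessage` once the task has run.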
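
The broker builds on four stream commands: XGROUP CREATE (`startup`), XADD (`kick`), XREADGROUP (`listen`), and XACK (the ack callback). Below is a rough sketch of that flow using redis-py's asyncio client directly, with illustrative key, group, and consumer names.

```python
import asyncio

from redis.asyncio import Redis, ResponseError


async def main() -> None:
    async with Redis.from_url("redis://localhost:6379") as redis:
        # startup(): create the consumer group, ignoring "group already exists" errors.
        try:
            await redis.xgroup_create("demo_stream", "demo_group", id="$", mkstream=True)
        except ResponseError as err:
            print(err)

        # kick(): append an entry to the stream.
        await redis.xadd("demo_stream", {b"data": b"hello"})

        # listen(): read new entries on behalf of the group ...
        fetched = await redis.xreadgroup(
            "demo_group",
            "worker-1",
            {"demo_stream": ">"},
            block=1000,
        )
        for _, messages in fetched:
            for msg_id, fields in messages:
                print(msg_id, fields[b"data"])
                # ... and acknowledge each entry once it has been processed.
                await redis.xack("demo_stream", "demo_group", msg_id)


asyncio.run(main())
```

Until an entry is XACKed it stays in the group's pending list, which is what makes the at-least-once acknowledgment semantics of this broker possible.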