diff --git a/.gitignore b/.gitignore index 4115ea0ec..87223fc92 100644 --- a/.gitignore +++ b/.gitignore @@ -34,4 +34,5 @@ venv_examples .coverage **/coverage.xml **/test-report.xml -.ducktape/ +.ducktape +results diff --git a/aio_producer_simple_diagram.md b/aio_producer_simple_diagram.md new file mode 100644 index 000000000..38778bce2 --- /dev/null +++ b/aio_producer_simple_diagram.md @@ -0,0 +1,47 @@ +# AIOProducer Simple Class Diagram + +``` +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ AIOProducer │ │BufferTimeout │ │ MessageBatch │ +│ (Orchestrator) │ │ Manager │ │ (Value Object) │ +├─────────────────┤ ├─────────────────┤ ├─────────────────┤ +│ • produce() │───▶│ • timeout │ │ • immutable │ +│ • flush() │ │ monitoring │ │ • type safe │ +│ • orchestrate │ │ • mark_activity │ │ • clean data │ +└─────────────────┘ └─────────────────┘ └─────────────────┘ + │ │ ▲ + ▼ ▼ │ +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ProducerBatch │ │ProducerBatch │ │CallbackManager │ +│ Manager │───▶│ Executor │ │ (Unified Mgmt) │ +├─────────────────┤ ├─────────────────┤ ├─────────────────┤ +│ • create_batches│ │ • execute_batch │ │ • sync/async │ +│ • group topics │ │ • thread pool │ │ • object pool │ +│ • manage buffer │ │ • poll results │ │ • event loop │ +└─────────────────┘ └─────────────────┘ └─────────────────┘ + │ │ + ▼ ▼ +┌─────────────────┐ ┌─────────────────┐ +│create_message │ │ ReusableMessage │ +│ _batch() │ │ Callback Pool │ +├─────────────────┤ ├─────────────────┤ +│ • factory func │ │ • pooled objs │ +│ • type safety │◄─────────────────────────│ • auto-return │ +│ • validation │ │ • thread safe │ +└─────────────────┘ └─────────────────┘ +``` + +## Architecture Summary + +**7 Components Total:** +- **1 Orchestrator**: AIOProducer (main API) +- **3 Core Services**: ProducerBatchManager, ProducerBatchExecutor, BufferTimeoutManager +- **1 Unified Manager**: CallbackManager (merged handler + pool) +- **2 Data Objects**: MessageBatch + factory function + +**Key Benefits:** +- ✅ Single responsibility per component +- ✅ Clean dependency injection +- ✅ Unified callback management +- ✅ Immutable data structures +- ✅ Performance optimized pooling diff --git a/src/confluent_kafka/aio/_AIOConsumer.py b/src/confluent_kafka/aio/_AIOConsumer.py index c5c2b1084..6b9961707 100644 --- a/src/confluent_kafka/aio/_AIOConsumer.py +++ b/src/confluent_kafka/aio/_AIOConsumer.py @@ -15,7 +15,6 @@ import asyncio import concurrent.futures import confluent_kafka -import functools import confluent_kafka.aio._common as _common @@ -41,14 +40,7 @@ def __init__(self, consumer_conf, max_workers=2, executor=None): self._consumer = confluent_kafka.Consumer(consumer_conf) async def _call(self, blocking_task, *args, **kwargs): - return (await asyncio.gather( - asyncio.get_running_loop().run_in_executor(self.executor, - functools.partial( - blocking_task, - *args, - **kwargs)) - - ))[0] + return await _common.async_call(self.executor, blocking_task, *args, **kwargs) def _wrap_callback(self, loop, callback, edit_args=None, edit_kwargs=None): def ret(*args, **kwargs): diff --git a/src/confluent_kafka/aio/_AIOProducer.py b/src/confluent_kafka/aio/_AIOProducer.py deleted file mode 100644 index eedd95496..000000000 --- a/src/confluent_kafka/aio/_AIOProducer.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2025 Confluent Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import concurrent.futures -import confluent_kafka -from confluent_kafka import KafkaException as _KafkaException -import functools -import confluent_kafka.aio._common as _common - - -class AIOProducer: - def __init__(self, producer_conf, max_workers=1, - executor=None, auto_poll=True): - if executor is not None: - self.executor = executor - else: - self.executor = concurrent.futures.ThreadPoolExecutor( - max_workers=max_workers) - loop = asyncio.get_event_loop() - wrap_common_callbacks = _common.wrap_common_callbacks - wrap_common_callbacks(loop, producer_conf) - - self._producer = confluent_kafka.Producer(producer_conf) - self._running = False - if auto_poll: - self._running = True - self._running_loop = asyncio.create_task(self._loop()) - - async def stop(self): - if self._running: - self._running = False - await self._running_loop - - async def _loop(self): - while self._running: - await self.poll(1.0) - - async def poll(self, *args, **kwargs): - await self._call(self._producer.poll, *args, **kwargs) - - async def _call(self, blocking_task, *args, **kwargs): - return (await asyncio.gather( - asyncio.get_running_loop().run_in_executor(self.executor, - functools.partial( - blocking_task, - *args, - **kwargs)) - - ))[0] - - async def produce(self, *args, **kwargs): - loop = asyncio.get_event_loop() - result = loop.create_future() - - def on_delivery(err, msg): - if err: - loop.call_soon_threadsafe(result.set_exception, - _KafkaException(err)) - else: - loop.call_soon_threadsafe(result.set_result, msg) - - kwargs['on_delivery'] = on_delivery - await self._call(self._producer.produce, *args, **kwargs) - return await result - - async def init_transactions(self, *args, **kwargs): - return await self._call(self._producer.init_transactions, - *args, **kwargs) - - async def begin_transaction(self, *args, **kwargs): - return await self._call(self._producer.begin_transaction, - *args, **kwargs) - - async def send_offsets_to_transaction(self, *args, **kwargs): - return await self._call(self._producer.send_offsets_to_transaction, - *args, **kwargs) - - async def commit_transaction(self, *args, **kwargs): - return await self._call(self._producer.commit_transaction, - *args, **kwargs) - - async def abort_transaction(self, *args, **kwargs): - return await self._call(self._producer.abort_transaction, - *args, **kwargs) - - async def flush(self, *args, **kwargs): - return await self._call(self._producer.flush, *args, **kwargs) - - async def purge(self, *args, **kwargs): - return await self._call(self._producer.purge, *args, **kwargs) - - async def set_sasl_credentials(self, *args, **kwargs): - return await self._call(self._producer.set_sasl_credentials, - *args, **kwargs) - - async def list_topics(self, *args, **kwargs): - return await self._call(self._producer.list_topics, *args, **kwargs) diff --git a/src/confluent_kafka/aio/__init__.py b/src/confluent_kafka/aio/__init__.py index 1edc80b8a..0bcac2836 100644 --- a/src/confluent_kafka/aio/__init__.py +++ b/src/confluent_kafka/aio/__init__.py @@ -13,5 +13,5 @@ # limitations under the License. 
from ._AIOConsumer import AIOConsumer -from ._AIOProducer import AIOProducer +from .producer import AIOProducer __all__ = ['AIOConsumer', 'AIOProducer'] diff --git a/src/confluent_kafka/aio/_common.py b/src/confluent_kafka/aio/_common.py index 56aa8c8fc..0e3d733c9 100644 --- a/src/confluent_kafka/aio/_common.py +++ b/src/confluent_kafka/aio/_common.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +import functools class AsyncLogger: @@ -48,6 +49,26 @@ def wrap_conf_logger(loop, conf): conf['logger'] = AsyncLogger(loop, conf['logger']) +async def async_call(executor, blocking_task, *args, **kwargs): + """Helper function for blocking operations that need ThreadPool execution + + Args: + executor: ThreadPoolExecutor to use for blocking operations + blocking_task: The blocking function to execute + *args, **kwargs: Arguments to pass to the blocking function + + Returns: + Result of the blocking function execution + """ + return (await asyncio.gather( + asyncio.get_running_loop().run_in_executor(executor, + functools.partial( + blocking_task, + *args, + **kwargs)) + ))[0] + + def wrap_common_callbacks(loop, conf): wrap_conf_callback(loop, conf, 'error_cb') wrap_conf_callback(loop, conf, 'throttle_cb') diff --git a/src/confluent_kafka/aio/producer/_AIOProducer.py b/src/confluent_kafka/aio/producer/_AIOProducer.py new file mode 100644 index 000000000..902429dab --- /dev/null +++ b/src/confluent_kafka/aio/producer/_AIOProducer.py @@ -0,0 +1,280 @@ +# Copyright 2025 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
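Before the implementation, a minimal usage sketch of the public API this file introduces. The broker address and topic name are placeholders; `batch_size` and `buffer_timeout` are shown at their documented defaults.

```python
import asyncio

from confluent_kafka.aio import AIOProducer


async def main():
    # 'localhost:9092' and 'demo-topic' are placeholder values
    producer = AIOProducer({'bootstrap.servers': 'localhost:9092'},
                           batch_size=1000, buffer_timeout=5.0)
    try:
        # produce() buffers the message and returns an asyncio.Future;
        # awaiting that future waits for the delivery report
        delivery_future = await producer.produce('demo-topic',
                                                 value=b'payload', key=b'k1')
        await producer.flush()       # drain local buffer, wait for delivery
        msg = await delivery_future  # resolved by the delivery callback
        print(msg.topic(), msg.partition())
    finally:
        await producer.close()       # graceful shutdown sequence


asyncio.run(main())
```

Note that `produce()` no longer accepts `headers` in batch mode; passing them raises `NotImplementedError` (see below).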
+ +import asyncio +import concurrent.futures +import logging + +import confluent_kafka + +import confluent_kafka.aio._common as _common +from confluent_kafka.aio.producer._producer_batch_processor import ProducerBatchManager +from confluent_kafka.aio.producer._kafka_batch_executor import ProducerBatchExecutor +from confluent_kafka.aio.producer._buffer_timeout_manager import BufferTimeoutManager + + +logger = logging.getLogger(__name__) + + +class AIOProducer: + + # ======================================================================== + # INITIALIZATION AND LIFECYCLE MANAGEMENT + # ======================================================================== + + def __init__(self, producer_conf, max_workers=4, executor=None, batch_size=1000, buffer_timeout=5.0): + if executor is not None: + self.executor = executor + else: + self.executor = concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers) + # Store the event loop for async operations + self._loop = asyncio.get_running_loop() + + wrap_common_callbacks = _common.wrap_common_callbacks + wrap_common_callbacks(self._loop, producer_conf) + + self._producer = confluent_kafka.Producer(producer_conf) + + # Batching configuration + self._batch_size = batch_size + + # Producer state management + self._is_closed = False # Track if producer is closed + + # Initialize Kafka batch executor for handling Kafka operations + self._kafka_executor = ProducerBatchExecutor(self._producer, self.executor) + + # Initialize batch processor for message batching and processing + self._batch_processor = ProducerBatchManager(self._kafka_executor) + + # Initialize buffer timeout manager for timeout handling + self._buffer_timeout_manager = BufferTimeoutManager( + self._batch_processor, self._kafka_executor, buffer_timeout) + if buffer_timeout > 0: + self._buffer_timeout_manager.start_timeout_monitoring() + + async def close(self): + """Close the producer and cleanup resources + + This method performs a graceful shutdown sequence to ensure all resources + are properly cleaned up and no messages are lost: + + 1. **Signal Shutdown**: Sets the closed flag to signal the timeout task to stop + 2. **Cancel Timeout Task**: Immediately cancels the buffer timeout monitoring task + 3. **Flush Remaining Messages**: Flushes any buffered messages to ensure delivery + 4. **Shutdown ThreadPool**: Waits for all pending ThreadPool operations to complete + 5. **Cleanup**: Ensures the underlying librdkafka producer is properly closed. The shutdown + is designed to be safe and non-blocking for the asyncio event loop + while ensuring all pending operations complete before the producer is closed. + + Raises: + Exception: May raise exceptions from buffer flushing, but these are logged + and don't prevent the cleanup process from completing. 
+ """ + # Set closed flag to signal timeout task to stop + self._is_closed = True + + # Stop the buffer timeout monitoring task + self._buffer_timeout_manager.stop_timeout_monitoring() + + # Flush any remaining messages in the buffer + if not self._batch_processor.is_buffer_empty(): + try: + await self._flush_buffer() + # Update buffer activity since we just flushed + self._buffer_timeout_manager.mark_activity() + except Exception: + logger.error("Error flushing buffer", exc_info=True) + raise + + # Shutdown the ThreadPool executor and wait for any remaining tasks to complete + # This ensures that all pending poll(), flush(), and other blocking operations + # finish before the producer is considered fully closed + if hasattr(self, 'executor'): + # executor.shutdown(wait=True) is a blocking call that: + # - Prevents new tasks from being submitted to the ThreadPool + # - Waits for all currently executing and queued tasks to complete + # - Returns only when all worker threads have finished + # + # We run this in a separate thread (using None as executor) to avoid + # blocking the asyncio event loop during the potentially long shutdown wait + await asyncio.get_running_loop().run_in_executor( + None, self.executor.shutdown, True + ) + + def __del__(self): + """Cleanup method called during garbage collection + + This ensures that the timeout task is properly cancelled even if + close() wasn't explicitly called. + """ + if hasattr(self, '_is_closed'): + self._is_closed = True + if hasattr(self, '_buffer_timeout_manager'): + self._buffer_timeout_manager.stop_timeout_monitoring() + + # ======================================================================== + # CORE PRODUCER OPERATIONS - Main public API + # ======================================================================== + + async def poll(self, timeout=0, *args, **kwargs): + """Processes delivery callbacks from librdkafka - blocking behavior depends on timeout + + This method triggers any pending delivery reports that have been + queued by librdkafka when messages are delivered or fail to deliver. + + Args: + timeout: Timeout in seconds for waiting for callbacks: + - 0 = non-blocking, return immediately after processing available callbacks + - >0 = block up to timeout seconds waiting for new callbacks to arrive + - -1 = block indefinitely until callbacks are available + + Returns: + Number of callbacks processed during this call + """ + return await self._call(self._producer.poll, timeout, *args, **kwargs) + + async def produce(self, topic, value=None, key=None, *args, **kwargs): + """Batched produce: Accumulates messages in buffer and flushes when threshold reached + + Args: + topic: Kafka topic name (required) + value: Message payload (optional) + key: Message key (optional) + *args, **kwargs: Additional parameters like partition, timestamp, headers + + Returns: + asyncio.Future: Future that resolves to the delivered message or raises exception on failure + """ + result = asyncio.get_running_loop().create_future() + + msg_data = { + 'topic': topic, + 'value': value, + 'key': key + } + + # Add optional parameters to message data + if 'partition' in kwargs: + msg_data['partition'] = kwargs['partition'] + if 'timestamp' in kwargs: + msg_data['timestamp'] = kwargs['timestamp'] + if 'headers' in kwargs: + # Headers are not supported in batch mode due to librdkafka API limitations. + # Use individual synchronous produce() calls if headers are required. + raise NotImplementedError( + "Headers are not supported in AIOProducer batch mode. 
" + "Use the synchronous Producer.produce() method if headers are required." + ) + + self._batch_processor.add_message(msg_data, result) + + self._buffer_timeout_manager.mark_activity() + + # Check if we should flush the buffer + if self._batch_processor.get_buffer_size() >= self._batch_size: + await self._flush_buffer() + + return result + + async def flush(self, *args, **kwargs): + """Waits until all messages are delivered or timeout + + This method performs a complete flush: + 1. Flushes any buffered messages from local buffer to librdkafka + 2. Waits for librdkafka to deliver/acknowledge all messages + """ + # First, flush any remaining messages in the buffer for all topics + if not self._batch_processor.is_buffer_empty(): + await self._flush_buffer() + # Update buffer activity since we just flushed + self._buffer_timeout_manager.mark_activity() + + # Then flush the underlying producer and wait for delivery confirmation + return await self._call(self._producer.flush, *args, **kwargs) + + async def purge(self, *args, **kwargs): + """Purges messages from internal queues - may block during cleanup""" + # Clear local message buffer and futures + self._batch_processor.clear_buffer() + + # Update buffer activity since we cleared the buffer + self._buffer_timeout_manager.mark_activity() + + return await self._call(self._producer.purge, *args, **kwargs) + + async def list_topics(self, *args, **kwargs): + return await self._call(self._producer.list_topics, *args, **kwargs) + + # ======================================================================== + # TRANSACTION OPERATIONS - Kafka transaction support + # ======================================================================== + + async def init_transactions(self, *args, **kwargs): + """Network call to initialize transactions""" + return await self._call(self._producer.init_transactions, + *args, **kwargs) + + async def begin_transaction(self, *args, **kwargs): + """Network call to begin transaction""" + return await self._call(self._producer.begin_transaction, + *args, **kwargs) + + async def send_offsets_to_transaction(self, *args, **kwargs): + """Network call to send offsets to transaction""" + return await self._call(self._producer.send_offsets_to_transaction, + *args, **kwargs) + + async def commit_transaction(self, *args, **kwargs): + """Network call to commit transaction""" + return await self._call(self._producer.commit_transaction, + *args, **kwargs) + + async def abort_transaction(self, *args, **kwargs): + """Network call to abort transaction""" + return await self._call(self._producer.abort_transaction, + *args, **kwargs) + + # ======================================================================== + # AUTHENTICATION AND SECURITY + # ======================================================================== + + async def set_sasl_credentials(self, *args, **kwargs): + """Authentication operation that may involve network calls""" + return await self._call(self._producer.set_sasl_credentials, + *args, **kwargs) + + # ======================================================================== + # BATCH PROCESSING OPERATIONS - Delegated to BatchProcessor + # ======================================================================== + + async def _flush_buffer(self, target_topic=None): + """Flush the current message buffer using clean batch processing workflow + + This method demonstrates the new architecture where AIOProducer simply + orchestrates the workflow between components: + 1. 
BatchProcessor creates immutable MessageBatch objects + 2. ProducerBatchExecutor executes each batch + 3. BufferTimeoutManager handles activity tracking + """ + await self._batch_processor.flush_buffer(target_topic) + + # ======================================================================== + # UTILITY METHODS - Helper functions and internal utilities + # ======================================================================== + + async def _call(self, blocking_task, *args, **kwargs): + """Helper method for blocking operations that need ThreadPool execution""" + return await _common.async_call(self.executor, blocking_task, *args, **kwargs) diff --git a/src/confluent_kafka/aio/producer/__init__.py b/src/confluent_kafka/aio/producer/__init__.py new file mode 100644 index 000000000..de0218f4f --- /dev/null +++ b/src/confluent_kafka/aio/producer/__init__.py @@ -0,0 +1,40 @@ +# Copyright 2025 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Confluent Kafka AIOProducer Module + +This module contains all the components for the async Kafka producer: + +Core Components: +- AIOProducer: Main async producer with clean architecture +- ProducerBatchProcessor: Message batching and organization +- KafkaBatchExecutor: Kafka operations and thread pool management +- BufferTimeoutManager: Timeout monitoring and automatic flushing + +Data Structures: +- MessageBatch: Immutable value object for batch data + +Architecture Benefits: +✅ Single Responsibility: Each component has one clear purpose +✅ Clean Interfaces: Well-defined boundaries between components +✅ Immutable Data: MessageBatch objects prevent accidental mutations +✅ Better Testing: Components can be tested independently +✅ Maintainable: Clear separation makes changes safer +""" + +from ._AIOProducer import AIOProducer + +# Export the main public API +__all__ = ['AIOProducer'] diff --git a/src/confluent_kafka/aio/producer/_buffer_timeout_manager.py b/src/confluent_kafka/aio/producer/_buffer_timeout_manager.py new file mode 100644 index 000000000..399eb4f4d --- /dev/null +++ b/src/confluent_kafka/aio/producer/_buffer_timeout_manager.py @@ -0,0 +1,138 @@ +# Copyright 2025 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
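The monitor task defined below derives its polling cadence from the configured timeout rather than using a fixed sleep. A standalone sketch of that calculation, reproducing the examples from the inline comments (timeout/2, clamped to the range 0.1s-1.0s):

```python
def check_interval(timeout: float) -> float:
    """Adaptive cadence for the timeout monitor: timeout/2, clamped to [0.1s, 1.0s]."""
    return max(0.1, min(1.0, timeout / 2))


# Matches the examples given in the monitor loop's comments:
assert check_interval(0.1) == 0.1   # very short timeouts are checked every 100ms
assert check_interval(1.0) == 0.5
assert check_interval(5.0) == 1.0
assert check_interval(30.0) == 1.0  # never slower than once per second
```

This keeps short timeouts responsive without making long timeouts spin the event loop.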
+ +import asyncio +import logging +import time +import weakref + +logger = logging.getLogger(__name__) + + +class BufferTimeoutManager: + """Manages buffer timeout and activity tracking for message batching + + This class is responsible for: + - Monitoring buffer inactivity and triggering automatic flushes + - Tracking buffer activity timestamps + - Managing background timeout monitoring tasks + - Coordinating between batch processor and executor for timeout flushes + """ + + def __init__(self, batch_processor, kafka_executor, timeout): + """Initialize the buffer timeout manager + + Args: + batch_processor: ProducerBatchManager instance for creating batches + kafka_executor: ProducerBatchExecutor instance for executing batches + timeout: Timeout in seconds for buffer inactivity (0 disables timeout) + """ + self._batch_processor = batch_processor + self._kafka_executor = kafka_executor + self._timeout = timeout + self._last_activity = time.time() + self._timeout_task = None + self._running = False + + def start_timeout_monitoring(self): + """Start the background task that monitors buffer inactivity + + Creates an async task that runs in the background and periodically checks + if messages have been sitting in the buffer for too long without being flushed. + + Key design decisions: + 1. **Weak Reference**: Uses weakref.ref(self) to prevent circular references + 2. **Self-Canceling**: The task stops itself if the manager is garbage collected + 3. **Adaptive Check Interval**: Uses timeout to determine check frequency + """ + if not self._timeout or self._timeout <= 0: + return # Timeout disabled + + self._running = True + self._timeout_task = asyncio.create_task(self._monitor_timeout()) + + def stop_timeout_monitoring(self): + """Stop and cleanup the buffer timeout monitoring task""" + self._running = False + if self._timeout_task and not self._timeout_task.done(): + self._timeout_task.cancel() + self._timeout_task = None + + def mark_activity(self): + """Update the timestamp of the last buffer activity + + This method should be called whenever: + 1. Messages are added to the buffer (in produce()) + 2. Buffer is manually flushed + 3. Buffer is purged/cleared + """ + self._last_activity = time.time() + + async def _monitor_timeout(self): + """Monitor buffer timeout in background task + + This method runs continuously in the background, checking for buffer inactivity + and triggering flushes when the timeout threshold is exceeded. 
+ """ + # Use weak reference to avoid circular reference and allow garbage collection + manager_ref = weakref.ref(self) + + while True: + # Check interval should be proportional to buffer timeout for efficiency + manager = manager_ref() + if manager is None or not manager._running: + break + + # Calculate adaptive check interval: timeout/2 with bounds + # Examples: 0.1s→0.1s, 1s→0.5s, 5s→1.0s, 30s→1.0s + check_interval = max(0.1, min(1.0, manager._timeout / 2)) + await asyncio.sleep(check_interval) + + # Re-check manager after sleep + manager = manager_ref() + if manager is None or not manager._running: + break + + # Check if buffer has been inactive for too long + time_since_activity = time.time() - manager._last_activity + if (time_since_activity >= manager._timeout and + not manager._batch_processor.is_buffer_empty()): + + try: + # Flush the buffer due to timeout + await manager._flush_buffer_due_to_timeout() + # Update activity since we just flushed + manager.mark_activity() + except Exception: + logger.error("Error flushing buffer due to timeout", exc_info=True) + # Re-raise all exceptions - don't swallow any errors + raise + + async def _flush_buffer_due_to_timeout(self): + """Flush buffer due to timeout by coordinating batch processor and executor + + This method handles the complete timeout flush workflow: + 1. Create batches from the batch processor + 2. Execute each batch via the Kafka executor + 3. Clear the processed messages from the buffer + """ + # Create batches from current buffer + batches = self._batch_processor.create_batches() + + # Execute all batches + for batch in batches: + await self._kafka_executor.execute_batch(batch.topic, batch.messages) + + # Clear the buffer since all messages were processed + self._batch_processor.clear_buffer() diff --git a/src/confluent_kafka/aio/producer/_kafka_batch_executor.py b/src/confluent_kafka/aio/producer/_kafka_batch_executor.py new file mode 100644 index 000000000..42bf833fa --- /dev/null +++ b/src/confluent_kafka/aio/producer/_kafka_batch_executor.py @@ -0,0 +1,110 @@ +# Copyright 2025 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import logging + +logger = logging.getLogger(__name__) + + +class ProducerBatchExecutor: + """Executes Kafka batch operations via thread pool + + This class is responsible for: + - Executing produce_batch operations against confluent_kafka.Producer + - Handling partial batch failures from librdkafka + - Managing thread pool execution to avoid blocking the event loop + - Processing delivery callbacks for successful messages + """ + + def __init__(self, producer, executor): + """Initialize the Kafka batch executor + + Args: + producer: confluent_kafka.Producer instance for Kafka operations + executor: ThreadPoolExecutor for running blocking operations + """ + self._producer = producer + self._executor = executor + + async def execute_batch(self, topic, batch_messages): + """Execute a batch operation via thread pool + + This method handles the complete batch execution workflow: + 1. Execute produce_batch in thread pool to avoid blocking event loop + 2. Handle partial failures that occur during produce_batch + 3. Poll for delivery reports of successful messages + + Args: + topic: Target topic for the batch + batch_messages: List of prepared messages with callbacks assigned + + Returns: + Result from producer.poll() indicating number of delivery reports processed + + Raises: + Exception: Any exception from the batch operation is propagated + """ + def _produce_batch_and_poll(): + """Helper function to run in thread pool + + This function encapsulates all the blocking Kafka operations: + - Call produce_batch with individual message callbacks + - Handle partial batch failures for messages that fail immediately + - Poll for delivery reports to trigger callbacks for successful messages + """ + # Call produce_batch with individual callbacks (no batch callback) + # Convert tuple to list since produce_batch expects a list + messages_list = list(batch_messages) if isinstance(batch_messages, tuple) else batch_messages + self._producer.produce_batch(topic, messages_list) + + # Handle partial batch failures: Check for messages that failed during produce_batch + # These messages have their msgstates destroyed in Producer.c and won't get callbacks + # from librdkafka, so we need to manually invoke their callbacks + self._handle_partial_failures(messages_list) + + # Immediately poll to process delivery callbacks for successful messages + poll_result = self._producer.poll(0) + + return poll_result + + # Execute in thread pool to avoid blocking event loop + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self._executor, _produce_batch_and_poll) + + def _handle_partial_failures(self, batch_messages): + """Handle messages that failed during produce_batch + + When produce_batch encounters messages that fail immediately (e.g., message too large, + invalid topic, etc.), librdkafka destroys their msgstates and won't call their callbacks. + We detect these failures by checking for '_error' in the message dict (set by Producer.c) + and manually invoke the simple future-resolving callbacks. 
+ + Args: + batch_messages: List of message dictionaries that were passed to produce_batch + """ + for msg_dict in batch_messages: + if '_error' in msg_dict: + # This message failed during produce_batch - its callback won't be called by librdkafka + callback = msg_dict.get('callback') + if callback: + # Extract the error from the message dict (set by Producer.c) + error = msg_dict['_error'] + # Manually invoke the callback with the error + # Note: msg is None since the message failed before being queued + try: + callback(error, None) + except Exception: + logger.warning("Exception in callback during partial failure handling", exc_info=True) + raise diff --git a/src/confluent_kafka/aio/producer/_message_batch.py b/src/confluent_kafka/aio/producer/_message_batch.py new file mode 100644 index 000000000..0d9432d23 --- /dev/null +++ b/src/confluent_kafka/aio/producer/_message_batch.py @@ -0,0 +1,63 @@ +# Copyright 2025 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import NamedTuple, Sequence, Any, Optional +import asyncio + + +# Create immutable MessageBatch value object using modern typing +class MessageBatch(NamedTuple): + """Immutable batch of messages for Kafka production + + This represents a group of messages destined for the same topic, + along with their associated futures for delivery confirmation. + """ + topic: str # Target topic for this batch + messages: Sequence[dict] # Prepared message dictionaries + futures: Sequence[asyncio.Future] # Futures to resolve on delivery + + @property + def size(self) -> int: + """Get the number of messages in this batch""" + return len(self.messages) + + @property + def info(self) -> str: + """Get a string representation of batch info""" + return f"MessageBatch(topic='{self.topic}', size={len(self.messages)})" + + +def create_message_batch(topic: str, + messages: Sequence[dict], + futures: Sequence[asyncio.Future], + callbacks: Optional[Any] = None) -> MessageBatch: + """Create an immutable MessageBatch from sequences + + This factory function converts mutable sequences into an immutable MessageBatch object. + Uses tuples internally for immutability while accepting any sequence type as input. + + Args: + topic: Target topic name + messages: Sequence of prepared message dictionaries + futures: Sequence of asyncio.Future objects + callbacks: Deprecated parameter, ignored for backwards compatibility + + Returns: + MessageBatch: Immutable batch object + """ + return MessageBatch( + topic=topic, + messages=tuple(messages) if not isinstance(messages, tuple) else messages, + futures=tuple(futures) if not isinstance(futures, tuple) else futures + ) diff --git a/src/confluent_kafka/aio/producer/_producer_batch_processor.py b/src/confluent_kafka/aio/producer/_producer_batch_processor.py new file mode 100644 index 000000000..24f93d0ba --- /dev/null +++ b/src/confluent_kafka/aio/producer/_producer_batch_processor.py @@ -0,0 +1,251 @@ +# Copyright 2025 Confluent Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging + +from confluent_kafka import KafkaException as _KafkaException +from confluent_kafka.aio.producer._message_batch import create_message_batch + +logger = logging.getLogger(__name__) + + +class ProducerBatchManager: + """Handles batching and processing of Kafka messages for AIOProducer + + This class encapsulates all the logic for: + - Grouping messages by topic + - Managing message buffers and futures + - Creating simple future-resolving callbacks + - Executing batch operations via librdkafka + """ + + def __init__(self, kafka_executor): + """Initialize the batch processor + + Args: + kafka_executor: KafkaBatchExecutor instance for Kafka operations + """ + self._kafka_executor = kafka_executor + self._message_buffer = [] + self._buffer_futures = [] + + def add_message(self, msg_data, future): + """Add a message to the batch buffer + + Args: + msg_data: Dictionary containing message data + future: asyncio.Future to resolve when message is delivered + """ + self._message_buffer.append(msg_data) + self._buffer_futures.append(future) + + def get_buffer_size(self): + """Get the current number of messages in the buffer""" + return len(self._message_buffer) + + def is_buffer_empty(self): + """Check if the buffer is empty""" + return len(self._message_buffer) == 0 + + def clear_buffer(self): + """Clear the entire buffer""" + self._message_buffer.clear() + self._buffer_futures.clear() + + def create_batches(self, target_topic=None): + """Create MessageBatch objects from the current buffer + + Args: + target_topic: Optional topic to create batches for (None for all topics) + + Returns: + List[MessageBatch]: List of immutable MessageBatch objects + """ + if self.is_buffer_empty(): + return [] + + topic_groups = self._group_messages_by_topic() + batches = [] + + for topic, group_data in topic_groups.items(): + if target_topic is None or topic == target_topic: + # Prepare batch messages + batch_messages = self._prepare_batch_messages(group_data['messages']) + + # Assign simple future-resolving callbacks to messages + self._assign_future_callbacks(batch_messages, group_data['futures']) + + # Create immutable MessageBatch object + batch = create_message_batch( + topic=topic, + messages=batch_messages, + futures=group_data['futures'], + callbacks=None # No user callbacks anymore + ) + batches.append(batch) + + return batches + + def _clear_topic_from_buffer(self, target_topic): + """Remove messages for a specific topic from the buffer + + Args: + target_topic: Topic to remove from buffer + """ + messages_to_keep = [] + futures_to_keep = [] + + for i, msg_data in enumerate(self._message_buffer): + if msg_data['topic'] != target_topic: + messages_to_keep.append(msg_data) + futures_to_keep.append(self._buffer_futures[i]) + + self._message_buffer = messages_to_keep + self._buffer_futures = futures_to_keep + + async def flush_buffer(self, target_topic=None): + """Flush the current message buffer using produce_batch + + Args: + 
target_topic: Optional topic to flush (None for all topics) + + Returns: + None + """ + if self.is_buffer_empty(): + return + + # Create batches for processing + batches = self.create_batches(target_topic) + + # Clear processed messages from buffer + if target_topic is None: + # Clear entire buffer + self.clear_buffer() + else: + # Clear only messages for the target topic + self._clear_topic_from_buffer(target_topic) + + # Execute each batch + for batch in batches: + try: + # Execute batch using the Kafka executor + await self._kafka_executor.execute_batch(batch.topic, batch.messages) + + except Exception as e: + # Handle batch failure by failing all unresolved futures for this batch + self._handle_batch_failure(e, batch.futures) + # Re-raise the exception so caller knows the batch operation failed + raise + + def _group_messages_by_topic(self): + """Group buffered messages by topic for batch processing + + This function efficiently organizes the mixed-topic message buffer into + topic-specific groups, since librdkafka's produce_batch requires separate + calls for each topic. + + Algorithm: + - Single O(n) pass through message buffer + - Groups related data (messages, futures, callbacks) by topic + - Maintains index relationships between buffer arrays + + Returns: + dict: Topic groups with structure: + { + 'topic_name': { + 'messages': [msg_data1, msg_data2, ...], # Message dictionaries + 'futures': [future1, future2, ...], # Corresponding asyncio.Future objects + } + } + """ + topic_groups = {} + + # Iterate through buffer once - O(n) complexity + for i, msg_data in enumerate(self._message_buffer): + topic = msg_data['topic'] + + # Create new topic group if this is first message for this topic + if topic not in topic_groups: + topic_groups[topic] = { + 'messages': [], # Message data for produce_batch + 'futures': [], # Futures to resolve on delivery + } + + # Add message and related data to appropriate topic group + # Note: All arrays stay synchronized by index + topic_groups[topic]['messages'].append(msg_data) + topic_groups[topic]['futures'].append(self._buffer_futures[i]) + + return topic_groups + + def _prepare_batch_messages(self, messages): + """Prepare messages for produce_batch by removing internal fields + + Args: + messages: List of message dictionaries + + Returns: + List of cleaned message dictionaries ready for produce_batch + """ + batch_messages = [] + for msg_data in messages: + # Create a shallow copy and remove fields not needed by produce_batch + batch_msg = copy.copy(msg_data) + batch_msg.pop('topic', None) # Remove topic since it's passed separately + batch_messages.append(batch_msg) + + return batch_messages + + def _assign_future_callbacks(self, batch_messages, futures): + """Assign simple future-resolving callbacks to each message in the batch + + Args: + batch_messages: List of message dictionaries for produce_batch + futures: List of asyncio.Future objects to resolve + """ + for i, batch_msg in enumerate(batch_messages): + future = futures[i] + + def create_simple_callback(fut): + """Create a simple callback that only resolves the future""" + def simple_callback(err, msg): + if err: + if not fut.done(): + fut.set_exception(_KafkaException(err)) + else: + if not fut.done(): + fut.set_result(msg) + return simple_callback + + # Assign the simple callback to this message + batch_msg['callback'] = create_simple_callback(future) + + def _handle_batch_failure(self, exception, batch_futures): + """Handle batch operation failure by failing all unresolved futures + + When 
a batch operation fails before any individual callbacks are invoked, + we need to fail all futures for this batch since none of the per-message + callbacks will be called by librdkafka. + + Args: + exception: The exception that caused the batch to fail + batch_futures: List of futures for this batch + """ + # Fail all futures since no individual callbacks will be invoked + for future in batch_futures: + # Only set exception if future isn't already done + if not future.done(): + future.set_exception(exception) diff --git a/tests/ducktape/producer_strategy.py b/tests/ducktape/producer_strategy.py new file mode 100644 index 000000000..57a447659 --- /dev/null +++ b/tests/ducktape/producer_strategy.py @@ -0,0 +1,318 @@ +""" +Producer strategies for testing sync and async Kafka producers. + +This module contains strategy classes that encapsulate the different producer +implementations (sync vs async) with consistent interfaces for testing. +""" +import time +import asyncio +from confluent_kafka import Producer + + +class ProducerStrategy: + """Base class for producer strategies""" + + def __init__(self, bootstrap_servers, logger): + self.bootstrap_servers = bootstrap_servers + self.logger = logger + self.metrics = None + + def create_producer(self): + raise NotImplementedError() + + def produce_messages( + self, + topic_name, + test_duration, + start_time, + message_formatter, + delivered_container, + failed_container=None): + raise NotImplementedError() + + def _get_base_config(self): + """Get shared Kafka producer configuration optimized for low-latency, high-throughput""" + return { + 'bootstrap.servers': self.bootstrap_servers, + 'queue.buffering.max.messages': 1000000, # 1M messages (sufficient) + 'queue.buffering.max.kbytes': 1048576, # 1GB (default) + 'batch.size': 65536, # 64KB batches (increased for better efficiency) + 'batch.num.messages': 50000, # 50K messages per batch (up from 10K default) + 'message.max.bytes': 2097152, # 2MB max message size (up from ~1MB default) + 'linger.ms': 1, # Wait 1ms for batching (low latency) + 'compression.type': 'lz4', # Fast compression + 'acks': 1, # Wait for leader only (faster) + 'retries': 3, # Retry failed sends + 'delivery.timeout.ms': 30000, # 30s delivery timeout + 'max.in.flight.requests.per.connection': 5 # Pipeline requests + } + + def _log_configuration(self, config, producer_type, extra_params=None): + """Log producer configuration for validation""" + separator = "=" * (len(f"{producer_type.upper()} PRODUCER CONFIGURATION") + 6) + self.logger.info(f"=== {producer_type.upper()} PRODUCER CONFIGURATION ===") + for key, value in config.items(): + self.logger.info(f"{key}: {value}") + + if extra_params: + for key, value in extra_params.items(): + self.logger.info(f"{key}: {value}") + + self.logger.info(separator) + + def _print_timing_metrics(self, producer_type, produce_times, poll_times, flush_time): + """Print code path timing metrics""" + avg_produce_time = sum(produce_times) / len(produce_times) * 1000 if produce_times else 0 + avg_poll_time = sum(poll_times) / len(poll_times) * 1000 if poll_times else 0 + flush_time_ms = flush_time * 1000 + + separator = "=" * (len(f"{producer_type.upper()} PRODUCER CODE PATH TIMING") + 6) + print(f"\n=== {producer_type.upper()} PRODUCER CODE PATH TIMING ===") + print(f"Time to call {'AIO' if producer_type == 'ASYNC' else ''}Producer.produce(): {avg_produce_time:.4f}ms") + print(f"Time to call {'AIO' if producer_type == 'ASYNC' else ''}Producer.poll(): {avg_poll_time:.4f}ms") + print(f"Time to call 
{'AIO' if producer_type == 'ASYNC' else ''}Producer.flush(): {flush_time_ms:.4f}ms") + print(f"Total produce() calls: {len(produce_times)}") + print(f"Total poll() calls: {len(poll_times)}") + print(separator) + + +class SyncProducerStrategy(ProducerStrategy): + def create_producer(self, config_overrides=None): + config = self._get_base_config() + + # Apply any test-specific overrides + if config_overrides: + config.update(config_overrides) + + producer = Producer(config) + + # Log the configuration for validation + self._log_configuration(config, "SYNC") + + return producer + + def produce_messages( + self, + topic_name, + test_duration, + start_time, + message_formatter, + delivered_container, + failed_container=None): + config_overrides = getattr(self, 'config_overrides', None) + producer = self.create_producer(config_overrides) + messages_sent = 0 + send_times = {} # Track send times for latency calculation + + # Temporary metrics for timing sections + produce_times = [] + poll_times = [] + flush_time = 0 + + def delivery_callback(err, msg): + if err: + if failed_container is not None: + failed_container.append(err) + if self.metrics: + self.metrics.record_failed(topic=msg.topic() if msg else topic_name, + partition=msg.partition() if msg else 0) + else: + delivered_container.append(msg) + if self.metrics: + # Calculate latency if we have send time + msg_key = msg.key().decode('utf-8', errors='replace') if msg.key() else 'unknown' + latency_ms = 0.0 + if msg_key in send_times: + latency_ms = (time.time() - send_times[msg_key]) * 1000 + del send_times[msg_key] + + self.metrics.record_delivered(latency_ms, topic=msg.topic(), partition=msg.partition()) + + while time.time() - start_time < test_duration: + message_value, message_key = message_formatter(messages_sent) + try: + # Track send time for latency calculation + if self.metrics: + send_times[message_key] = time.time() + message_size = len(message_value.encode('utf-8')) + len(message_key.encode('utf-8')) + self.metrics.record_sent(message_size, topic=topic_name, partition=0) + + # Produce message + produce_start = time.time() + producer.produce( + topic=topic_name, + value=message_value, + key=message_key, + on_delivery=delivery_callback + ) + produce_times.append(time.time() - produce_start) + messages_sent += 1 + + # Use configured polling interval (default to 50 if not set) + poll_interval = getattr(self, 'poll_interval', 50) + + if messages_sent % poll_interval == 0: + poll_start = time.time() + producer.poll(0) + poll_times.append(time.time() - poll_start) + if self.metrics: + self.metrics.record_poll() + + except Exception as e: + if failed_container is not None: + failed_container.append(e) + if self.metrics: + self.metrics.record_failed(topic=topic_name, partition=0) + self.logger.error(f"Failed to produce message {messages_sent}: {e}") + + # Flush producer + flush_start = time.time() + producer.flush(timeout=30) + flush_time = time.time() - flush_start + + # Print timing metrics + self._print_timing_metrics("SYNC", produce_times, poll_times, flush_time) + + return messages_sent + + def get_final_metrics(self): + """Return final metrics summary for the sync producer""" + if self.metrics: + return self.metrics + return None + + +class AsyncProducerStrategy(ProducerStrategy): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._producer_instance = None + + def create_producer(self, config_overrides=None): + from confluent_kafka.aio import AIOProducer + # Enable logging for AIOProducer + import 
logging + logging.basicConfig(level=logging.INFO) + + config = self._get_base_config() + + # Apply any test-specific overrides + if config_overrides: + config.update(config_overrides) + + # Get producer configuration from strategy attributes + max_workers = getattr(self, 'max_workers', 4) + batch_size = getattr(self, 'batch_size', 1000) # Optimal batch size for low latency + + # Use updated defaults with configurable parameters + self._producer_instance = AIOProducer(config, max_workers=max_workers, batch_size=batch_size) + + # Log the configuration for validation + extra_params = {'max_workers': max_workers, 'batch_size': batch_size} + self._log_configuration(config, "ASYNC", extra_params) + + return self._producer_instance + + def produce_messages( + self, + topic_name, + test_duration, + start_time, + message_formatter, + delivered_container, + failed_container=None): + + async def async_produce(): + config_overrides = getattr(self, 'config_overrides', None) + producer = self.create_producer(config_overrides) + messages_sent = 0 + pending_futures = [] + send_times = {} # Track send times for latency calculation + + # Temporary metrics for timing sections + produce_times = [] + poll_times = [] + flush_time = 0 + + # Pre-create shared metrics callback to avoid closure creation overhead + if self.metrics: + def shared_metrics_callback(err, msg): + if not err: + # Calculate latency if we have send time + msg_key = msg.key().decode('utf-8', errors='replace') if msg.key() else 'unknown' + latency_ms = 0.0 + if msg_key in send_times: + latency_ms = (time.time() - send_times[msg_key]) * 1000 + del send_times[msg_key] + self.metrics.record_delivered(latency_ms, topic=msg.topic(), partition=msg.partition()) + + while time.time() - start_time < test_duration: + message_value, message_key = message_formatter(messages_sent) + try: + # Record sent message for metrics and track send time + if self.metrics: + message_size = len(message_value.encode('utf-8')) + len(message_key.encode('utf-8')) + self.metrics.record_sent(message_size, topic=topic_name, partition=0) + send_times[message_key] = time.time() # Track send time for latency + + # Produce message + produce_start = time.time() + delivery_future = await producer.produce( + topic=topic_name, + value=message_value, + key=message_key + ) + produce_times.append(time.time() - produce_start) + pending_futures.append((delivery_future, message_key)) # Store delivery future + messages_sent += 1 + + except Exception as e: + if failed_container is not None: + failed_container.append(e) + if self.metrics: + self.metrics.record_failed(topic=topic_name, partition=0) + self.logger.error(f"Failed to produce message {messages_sent}: {e}") + + # Flush producer + flush_start = time.time() + await producer.flush(timeout=30) + flush_time = time.time() - flush_start + + # Wait for all pending futures to complete (for delivery confirmation only) + for delivery_future, message_key in pending_futures: + try: + msg = await delivery_future + delivered_container.append(msg) + + # Record delivery metrics (replaces the old callback approach) + if self.metrics: + # Calculate latency if we have send time + latency_ms = 0.0 + if message_key in send_times: + latency_ms = (time.time() - send_times[message_key]) * 1000 + del send_times[message_key] + + self.metrics.record_delivered(latency_ms, topic=msg.topic(), partition=msg.partition()) + + except Exception as e: + if failed_container is not None: + failed_container.append(e) + if self.metrics: + 
self.metrics.record_failed(topic=topic_name, partition=0) + self.logger.error(f"Failed to deliver message with key {message_key}: {e}") + + # Print timing metrics + self._print_timing_metrics("ASYNC", produce_times, poll_times, flush_time) + + # Close producer to ensure clean shutdown + await producer.close() + + return messages_sent + + loop = asyncio.get_event_loop() + return loop.run_until_complete(async_produce()) + + def get_final_metrics(self): + """Return final metrics summary for the async producer""" + if self.metrics: + return self.metrics + return None diff --git a/tests/ducktape/test_producer.py b/tests/ducktape/test_producer.py index e2a6ab651..230136baa 100644 --- a/tests/ducktape/test_producer.py +++ b/tests/ducktape/test_producer.py @@ -4,11 +4,11 @@ """ import time from ducktape.tests.test import Test -from ducktape.mark import matrix, parametrize +from ducktape.mark import matrix from tests.ducktape.services.kafka import KafkaClient from tests.ducktape.benchmark_metrics import MetricsCollector, MetricsBounds, validate_metrics, print_metrics_report -from confluent_kafka import Producer +from tests.ducktape.producer_strategy import SyncProducerStrategy, AsyncProducerStrategy class SimpleProducerTest(Test): @@ -30,10 +30,22 @@ def setUp(self): self.logger.info("Successfully connected to Kafka") - def test_basic_produce(self): + def create_producer(self, producer_type, config_overrides=None): + """Create appropriate producer strategy based on type""" + if producer_type == "sync": + strategy = SyncProducerStrategy(self.kafka.bootstrap_servers(), self.logger) + else: # async + strategy = AsyncProducerStrategy(self.kafka.bootstrap_servers(), self.logger) + + # Store config overrides for later use in create_producer + strategy.config_overrides = config_overrides + return strategy + + @matrix(producer_type=["sync", "async"]) + def test_basic_produce(self, producer_type): """Test basic message production with comprehensive metrics and bounds validation""" - topic_name = "test-topic" + topic_name = f"test-{producer_type}-topic" test_duration = 5.0 # 5 seconds # Create topic @@ -48,77 +60,31 @@ def test_basic_produce(self): metrics = MetricsCollector() bounds = MetricsBounds() - # Configure producer - producer_config = { - 'bootstrap.servers': self.kafka.bootstrap_servers(), - 'client.id': 'ducktape-test-producer' - } - - self.logger.info("Creating producer with config: %s", producer_config) - producer = Producer(producer_config) - - # Enhanced delivery callback with metrics tracking - send_times = {} # Track send times for latency calculation - - def delivery_callback(err, msg): - """Delivery report callback with metrics tracking""" - if err is not None: - self.logger.error("Message delivery failed: %s", err) - metrics.record_failed(topic=msg.topic() if msg else topic_name, - partition=msg.partition() if msg else 0) - else: - # Calculate actual latency if we have send time - msg_key = msg.key().decode('utf-8', errors='replace') if msg.key() else 'unknown' - if msg_key in send_times: - latency_ms = (time.time() - send_times[msg_key]) * 1000 - del send_times[msg_key] # Clean up - else: - latency_ms = 0.0 # Default latency if timing info not available - - metrics.record_delivered(latency_ms, topic=msg.topic(), partition=msg.partition()) + # Create appropriate producer strategy + strategy = self.create_producer(producer_type) + + # Assign metrics collector to strategy + strategy.metrics = metrics + + self.logger.info(f"Testing {producer_type} producer for {test_duration} seconds") # 
Start metrics collection metrics.start() - # Time-based message production with metrics - self.logger.info("Producing messages with metrics for %.1f seconds to topic %s", test_duration, topic_name) + # Message formatter + def message_formatter(msg_num): + return f"Test message {msg_num}", f"key-{msg_num}" + + # Containers for results + delivered_messages = [] + failed_messages = [] + + # Run the test start_time = time.time() - messages_sent = 0 - - while time.time() - start_time < test_duration: - message_value = f"Test message {messages_sent}" - message_key = f"key-{messages_sent}" - - try: - # Record message being sent with metrics - message_size = len(message_value.encode('utf-8')) + len(message_key.encode('utf-8')) - metrics.record_sent(message_size, topic=topic_name, partition=0) - - # Track send time for latency calculation - send_times[message_key] = time.time() - - producer.produce( - topic=topic_name, - value=message_value, - key=message_key, - callback=delivery_callback - ) - messages_sent += 1 - - # Poll frequently to trigger delivery callbacks and record poll operations - if messages_sent % 100 == 0: - producer.poll(0) - metrics.record_poll() - - except BufferError: - # Record buffer full events and poll - metrics.record_buffer_full() - producer.poll(0.001) - continue - - # Flush to ensure all messages are sent - self.logger.info("Flushing producer...") - producer.flush(timeout=30) + messages_sent = strategy.produce_messages( + topic_name, test_duration, start_time, message_formatter, + delivered_messages, failed_messages + ) # Finalize metrics collection metrics.finalize() @@ -128,15 +94,15 @@ def delivery_callback(err, msg): is_valid, violations = validate_metrics(metrics_summary, bounds) # Print comprehensive metrics report - self.logger.info("Basic production test with metrics completed:") + self.logger.info(f"=== {producer_type.upper()} PRODUCER METRICS REPORT ===") print_metrics_report(metrics_summary, is_valid, violations) # Enhanced assertions using metrics - assert messages_sent > 0, "No messages were sent during test duration" - assert metrics_summary['messages_delivered'] > 0, "No messages were delivered" + assert messages_sent > 0, "No messages were sent" + assert len(delivered_messages) > 0, "No messages were delivered" + assert metrics_summary['messages_delivered'] > 0, "No messages were delivered (metrics)" assert metrics_summary['send_throughput_msg_per_sec'] > 10, \ - f"Send throughput too low: {metrics_summary['send_throughput_msg_per_sec']:.2f} msg/s " \ - f"(expected > 10 msg/s)" + f"Send throughput too low: {metrics_summary['send_throughput_msg_per_sec']:.2f} msg/s" # Validate against performance bounds if not is_valid: @@ -145,13 +111,11 @@ def delivery_callback(err, msg): self.logger.info("Successfully completed basic production test with comprehensive metrics") - @parametrize(test_duration=2) - @parametrize(test_duration=5) - @parametrize(test_duration=10) - def test_produce_multiple_batches(self, test_duration): + @matrix(producer_type=["sync", "async"], test_duration=[2, 5, 10]) + def test_produce_multiple_batches(self, producer_type, test_duration): """Test batch throughput with comprehensive metrics and bounds validation""" - topic_name = f"batch-test-topic-{test_duration}s" + topic_name = f"{producer_type}-batch-test-topic-{test_duration}s" # Create topic self.kafka.create_topic(topic_name, partitions=2, replication_factor=1) @@ -165,92 +129,62 @@ def test_produce_multiple_batches(self, test_duration): bounds = MetricsBounds() # Adjust bounds for 
different test durations if test_duration <= 2: - bounds.min_throughput_msg_per_sec = 500.0 # Lower threshold for short tests - - # Configure producer with batch settings - producer_config = { - 'bootstrap.servers': self.kafka.bootstrap_servers(), - 'client.id': f'batch-test-producer-{test_duration}s', - 'batch.size': 1000, # Small batch size for testing - 'linger.ms': 100 # Small linger time for testing - } - - producer = Producer(producer_config) - - # Enhanced delivery callback with metrics - send_times = {} - - def delivery_callback(err, msg): - if err is None: - # Calculate latency - msg_key = msg.key().decode('utf-8', errors='replace') if msg.key() else 'unknown' - if msg_key in send_times: - latency_ms = (time.time() - send_times[msg_key]) * 1000 - del send_times[msg_key] - else: - latency_ms = 0.0 # Default for batch processing - - metrics.record_delivered(latency_ms, topic=msg.topic(), partition=msg.partition()) - else: - self.logger.error("Delivery failed: %s", err) - metrics.record_failed(topic=msg.topic() if msg else topic_name, - partition=msg.partition() if msg else 0) + bounds.min_throughput_msg_per_sec = 50.0 # Lower threshold for short tests + + # Create appropriate producer strategy + strategy = self.create_producer(producer_type) + + # Assign metrics collector to strategy + strategy.metrics = metrics + + self.logger.info(f"Testing {producer_type} producer with batches for {test_duration} seconds") # Start metrics collection metrics.start() - # Time-based batch message production with metrics - self.logger.info("Producing batches with metrics for %d seconds", test_duration) + # Message formatter for batch test + def message_formatter(msg_num): + return f"Batch message {msg_num}", f"batch-key-{msg_num}" + + # Containers for results + delivered_messages = [] + failed_messages = [] + + # Run the test start_time = time.time() - messages_sent = 0 - - while time.time() - start_time < test_duration: - try: - message_value = f"Batch message {messages_sent}" - message_key = f"batch-key-{messages_sent % 10}" # Use modulo for key distribution - partition = messages_sent % 2 # Distribute across partitions - - # Record metrics - message_size = len(message_value.encode('utf-8')) + len(message_key.encode('utf-8')) - metrics.record_sent(message_size, topic=topic_name, partition=partition) - send_times[message_key] = time.time() - - producer.produce( - topic=topic_name, - value=message_value, - key=message_key, - callback=delivery_callback - ) - messages_sent += 1 - - # Poll occasionally to trigger callbacks and record polls - if messages_sent % 100 == 0: - producer.poll(0) - metrics.record_poll() - - except BufferError: - # Record buffer full events - metrics.record_buffer_full() - producer.poll(0.001) - continue - - # Final flush - producer.flush(timeout=30) - - # Finalize metrics + messages_sent = strategy.produce_messages( + topic_name, test_duration, start_time, message_formatter, + delivered_messages, failed_messages + ) + + # Finalize metrics collection metrics.finalize() # Get comprehensive metrics summary metrics_summary = metrics.get_summary() is_valid, violations = validate_metrics(metrics_summary, bounds) + # Get AIOProducer built-in metrics for comparison (async only) + final_metrics = strategy.get_final_metrics() + # Print comprehensive metrics report - self.logger.info("Batch production test (%ds) with metrics completed:", test_duration) + self.logger.info(f"=== {producer_type.upper()} BATCH TEST ({test_duration}s) METRICS REPORT ===") print_metrics_report(metrics_summary, 
is_valid, violations) + if final_metrics: + # Get the actual metrics dictionary + producer_metrics_summary = final_metrics.get_summary() + if producer_metrics_summary: + self.logger.info("=== Producer Built-in Metrics ===") + self.logger.info(f"Runtime: {producer_metrics_summary['duration_seconds']:.2f}s") + self.logger.info(f"Success Rate: {producer_metrics_summary['success_rate']:.3f}") + self.logger.info(f"Throughput: {producer_metrics_summary['send_throughput_msg_per_sec']:.1f} msg/sec") + self.logger.info(f"Latency: Avg={producer_metrics_summary['avg_latency_ms']:.1f}ms") + # Enhanced assertions using metrics - assert messages_sent > 0, "No messages were sent during test duration" - assert metrics_summary['messages_delivered'] > 0, "No messages were delivered" + assert messages_sent > 0, "No messages were sent" + assert len(delivered_messages) > 0, "No messages were delivered" + assert metrics_summary['messages_delivered'] > 0, "No messages were delivered (metrics)" assert metrics_summary['send_throughput_msg_per_sec'] > 10, \ f"Send throughput too low: {metrics_summary['send_throughput_msg_per_sec']:.2f} msg/s" @@ -262,11 +196,11 @@ def delivery_callback(err, msg): self.logger.info("Successfully completed %ds batch production test with comprehensive metrics", test_duration) - @matrix(compression_type=['none', 'gzip', 'snappy']) - def test_produce_with_compression(self, compression_type): + @matrix(producer_type=["sync", "async"], compression_type=['none', 'gzip', 'snappy']) + def test_produce_with_compression(self, producer_type, compression_type): """Test compression throughput with comprehensive metrics and bounds validation""" - topic_name = f"compression-test-{compression_type}" + topic_name = f"{producer_type}-compression-test-{compression_type}" test_duration = 5.0 # 5 seconds # Create topic @@ -279,92 +213,90 @@ def test_produce_with_compression(self, compression_type): # Initialize metrics collection and bounds metrics = MetricsCollector() bounds = MetricsBounds() - # Adjust bounds for compression tests (may be slower) + # Adjust bounds for compression tests (may be slower with large messages) bounds.min_throughput_msg_per_sec = 5.0 # Lower threshold for large messages bounds.max_p95_latency_ms = 5000.0 # Allow higher latency for compression - # Configure producer with compression - producer_config = { - 'bootstrap.servers': self.kafka.bootstrap_servers(), - 'client.id': f'compression-test-{compression_type}', - 'compression.type': compression_type - } + # Create appropriate producer strategy with compression config + compression_config = {} + if compression_type != 'none': + compression_config['compression.type'] = compression_type + + # Configure polling intervals based on compression type and producer type + if producer_type == 'async': + polling_config = { + 'gzip': 10, # Poll every 10 messages for gzip (frequent) + 'snappy': 50, # Poll every 50 messages for snappy (moderate) + 'none': 100 # Poll every 100 messages for none (standard) + } + else: # sync + # Sync producers need more frequent polling to prevent buffer overflow as throughput is very high + polling_config = { + 'gzip': 5, # Poll every 5 messages for gzip (most frequent) + 'snappy': 25, # Poll every 25 messages for snappy (moderate) + 'none': 50 # Poll every 50 messages for none (standard) + } + poll_interval = polling_config.get(compression_type, 50 if producer_type == 'sync' else 100) + + strategy = self.create_producer(producer_type, compression_config) + strategy.poll_interval = poll_interval + + # Assign 
metrics collector to strategy + strategy.metrics = metrics + + self.logger.info( + f"Testing {producer_type} producer with {compression_type} compression for {test_duration} seconds") + self.logger.info(f"Using polling interval: {poll_interval} messages per poll") - producer = Producer(producer_config) + # Start metrics collection + metrics.start() # Create larger messages to test compression effectiveness large_message = "x" * 1000 # 1KB message - send_times = {} - - def delivery_callback(err, msg): - if err is None: - # Calculate latency - msg_key = msg.key().decode('utf-8', errors='replace') if msg.key() else 'unknown' - if msg_key in send_times: - latency_ms = (time.time() - send_times[msg_key]) * 1000 - del send_times[msg_key] - else: - latency_ms = 0.0 # Default for compression processing - - metrics.record_delivered(latency_ms, topic=msg.topic(), partition=msg.partition()) - else: - metrics.record_failed(topic=msg.topic() if msg else topic_name, - partition=msg.partition() if msg else 0) - # Start metrics collection - metrics.start() + # Message formatter for compression test + def message_formatter(msg_num): + return f"{large_message}-{msg_num}", f"comp-key-{msg_num}" + + # Containers for results + delivered_messages = [] + failed_messages = [] - # Time-based message production with compression and metrics - self.logger.info("Producing messages with %s compression and metrics for %.1f seconds", - compression_type, test_duration) + # Run the test start_time = time.time() - messages_sent = 0 - - while time.time() - start_time < test_duration: - try: - message_value = f"{large_message}-{messages_sent}" - message_key = f"comp-key-{messages_sent}" - - # Record metrics - message_size = len(message_value.encode('utf-8')) + len(message_key.encode('utf-8')) - metrics.record_sent(message_size, topic=topic_name, partition=0) - send_times[message_key] = time.time() - - producer.produce( - topic=topic_name, - value=message_value, - key=message_key, - callback=delivery_callback - ) - messages_sent += 1 - - # Poll frequently to prevent buffer overflow and record polls - if messages_sent % 10 == 0: - producer.poll(0) - metrics.record_poll() - - except BufferError: - # Record buffer full events - metrics.record_buffer_full() - producer.poll(0.001) - continue - - producer.flush(timeout=30) - - # Finalize metrics + messages_sent = strategy.produce_messages( + topic_name, test_duration, start_time, message_formatter, + delivered_messages, failed_messages + ) + + # Finalize metrics collection metrics.finalize() # Get comprehensive metrics summary metrics_summary = metrics.get_summary() is_valid, violations = validate_metrics(metrics_summary, bounds) + # Get AIOProducer built-in metrics for comparison (async only) + final_metrics = strategy.get_final_metrics() + # Print comprehensive metrics report - self.logger.info("Compression test (%s) with metrics completed:", compression_type) + self.logger.info(f"=== {producer_type.upper()} COMPRESSION TEST ({compression_type}) METRICS REPORT ===") print_metrics_report(metrics_summary, is_valid, violations) + if final_metrics: + # Get the actual metrics dictionary + producer_metrics_summary = final_metrics.get_summary() + if producer_metrics_summary: + self.logger.info("=== Producer Built-in Metrics ===") + self.logger.info(f"Runtime: {producer_metrics_summary['duration_seconds']:.2f}s") + self.logger.info(f"Success Rate: {producer_metrics_summary['success_rate']:.3f}") + self.logger.info(f"Throughput: {producer_metrics_summary['send_throughput_msg_per_sec']:.1f} 
msg/sec") + self.logger.info(f"Latency: Avg={producer_metrics_summary['avg_latency_ms']:.1f}ms") + # Enhanced assertions using metrics - assert messages_sent > 0, "No messages were sent during test duration" - assert metrics_summary['messages_delivered'] > 0, "No messages were delivered" + assert messages_sent > 0, "No messages were sent" + assert len(delivered_messages) > 0, "No messages were delivered" + assert metrics_summary['messages_delivered'] > 0, "No messages were delivered (metrics)" assert metrics_summary['send_throughput_msg_per_sec'] > 5, \ f"Send throughput too low for {compression_type}: " \ f"{metrics_summary['send_throughput_msg_per_sec']:.2f} msg/s" diff --git a/tests/test_AIOConsumer.py b/tests/test_AIOConsumer.py index 443de0c84..7bcd715a8 100644 --- a/tests/test_AIOConsumer.py +++ b/tests/test_AIOConsumer.py @@ -22,6 +22,9 @@ def mock_consumer(self): def mock_common(self): """Mock the _common module callback wrapping.""" with patch('confluent_kafka.aio._AIOConsumer._common') as mock: + async def mock_async_call(executor, blocking_task, *args, **kwargs): + return blocking_task(*args, **kwargs) + mock.async_call.side_effect = mock_async_call yield mock @pytest.fixture diff --git a/tests/test_AIOProducer.py b/tests/test_AIOProducer.py index f45a3b502..b9cc51396 100644 --- a/tests/test_AIOProducer.py +++ b/tests/test_AIOProducer.py @@ -1,12 +1,16 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import pytest +""" +Unit tests for AIOProducer class. + +""" import asyncio import concurrent.futures +import pytest from unittest.mock import Mock, patch from confluent_kafka import KafkaError, KafkaException -from confluent_kafka.aio._AIOProducer import AIOProducer +from confluent_kafka.aio.producer._AIOProducer import AIOProducer class TestAIOProducer: @@ -14,82 +18,71 @@ class TestAIOProducer: @pytest.fixture def mock_producer(self): - """Mock the underlying confluent_kafka.Producer.""" - with patch('confluent_kafka.aio._AIOProducer.confluent_kafka.Producer') as mock: + with patch('confluent_kafka.aio.producer._AIOProducer.confluent_kafka.Producer') as mock: yield mock @pytest.fixture def mock_common(self): - """Mock the _common module callback wrapping.""" - with patch('confluent_kafka.aio._AIOProducer._common') as mock: + with patch('confluent_kafka.aio.producer._AIOProducer._common') as mock: + async def mock_async_call(executor, blocking_task, *args, **kwargs): + return blocking_task(*args, **kwargs) + mock.async_call.side_effect = mock_async_call yield mock @pytest.fixture def basic_config(self): - """Basic producer configuration.""" return {'bootstrap.servers': 'localhost:9092'} @pytest.mark.asyncio async def test_constructor_behavior(self, mock_producer, mock_common, basic_config): - """Test constructor creates producer with correct configuration and behavior.""" custom_executor = concurrent.futures.ThreadPoolExecutor(max_workers=5) try: - # Test 1: Constructor with custom executor and auto_poll=False producer1 = AIOProducer( basic_config, - max_workers=3, # Should be ignored since executor is provided - executor=custom_executor, - auto_poll=False + max_workers=3, + executor=custom_executor ) - # Test actual object state, not mock calls assert producer1.executor is custom_executor assert producer1.executor._max_workers == 5 - assert producer1._running is False - assert not hasattr(producer1, '_running_loop') - assert hasattr(producer1, '_producer') # Should have underlying producer + assert producer1._is_closed is False + assert hasattr(producer1, 
'_buffer_timeout_manager') + assert hasattr(producer1, '_producer') - # Test 2: Constructor with default executor and auto_poll=True - producer2 = AIOProducer(basic_config, max_workers=2, auto_poll=True) + producer2 = AIOProducer(basic_config, max_workers=2, batch_size=500) - # Test executor was created with correct max_workers assert producer2.executor is not custom_executor assert isinstance(producer2.executor, concurrent.futures.ThreadPoolExecutor) assert producer2.executor._max_workers == 2 - # Test auto-polling was started - assert producer2._running is True - assert hasattr(producer2, '_running_loop') - assert not producer2._running_loop.done() + assert producer2._batch_size == 500 + assert producer2._is_closed is False + assert hasattr(producer2, '_buffer_timeout_manager') - # Clean up - await producer2.stop() - assert producer2._running is False + await producer2.close() + assert producer2._is_closed is True finally: custom_executor.shutdown(wait=True) @pytest.mark.asyncio - async def test_stop_method(self, mock_producer, mock_common, basic_config): - """Test stop method functionality.""" - # Test stopping running producer + async def test_close_method(self, mock_producer, mock_common, basic_config): producer = AIOProducer(basic_config) - assert producer._running is True + assert producer._is_closed is False - await producer.stop() - assert producer._running is False + await producer.close() + assert producer._is_closed is True - # Test stopping non-running producer - producer2 = AIOProducer(basic_config, auto_poll=False) - assert producer2._running is False + producer2 = AIOProducer(basic_config) + assert producer2._is_closed is False - await producer2.stop() # Should not raise exception - assert producer2._running is False + await producer2.close() + await producer2.close() + assert producer2._is_closed is True @pytest.mark.asyncio async def test_call_method_executor_usage(self, mock_producer, mock_common, basic_config): - """Test that _call method uses ThreadPoolExecutor for async-to-sync bridging.""" - producer = AIOProducer(basic_config, auto_poll=False) + producer = AIOProducer(basic_config) mock_method = Mock(return_value="test_result") result = await producer._call(mock_method, "arg1", kwarg1="value1") @@ -99,159 +92,145 @@ async def test_call_method_executor_usage(self, mock_producer, mock_common, basi @pytest.mark.asyncio async def test_produce_success(self, mock_producer, mock_common, basic_config): - """Test successful message production.""" - producer = AIOProducer(basic_config, auto_poll=False) - + producer = AIOProducer(basic_config, batch_size=1) mock_msg = Mock() - loop = asyncio.get_event_loop() - def sync_produce(*args, **kwargs): - callback = kwargs.get('on_delivery') - if callback: - # Simulate successful delivery - callback would normally come from poll() - loop.call_soon_threadsafe(callback, None, mock_msg) + async def mock_flush_buffer(target_topic=None): + batches = producer._batch_processor.create_batches(target_topic) + for batch in batches: + for future in batch.futures: + if not future.done(): + future.set_result(mock_msg) + producer._batch_processor.clear_buffer() - mock_producer.return_value.produce.side_effect = sync_produce - result = await producer.produce(topic="test_topic", value="test_value") + with patch.object(producer, '_flush_buffer', side_effect=mock_flush_buffer): + result_future = await producer.produce(topic="test_topic", value="test_value") + result = await result_future + assert result is mock_msg - assert result is mock_msg + await 
producer.close() @pytest.mark.asyncio async def test_produce_error(self, mock_producer, mock_common, basic_config): - """Test message production error handling through real ThreadPoolExecutor.""" - producer = AIOProducer(basic_config, auto_poll=False) + producer = AIOProducer(basic_config, batch_size=1) + mock_error = KafkaError(KafkaError._MSG_TIMED_OUT) - # Capture the main event loop to use from thread - main_loop = asyncio.get_event_loop() + async def mock_flush_buffer(target_topic=None): + batches = producer._batch_processor.create_batches(target_topic) + for batch in batches: + for future in batch.futures: + if not future.done(): + future.set_exception(KafkaException(mock_error)) + producer._batch_processor.clear_buffer() - def sync_produce(*args, **kwargs): - callback = kwargs.get('on_delivery') - if callback: - mock_error = KafkaError(KafkaError._MSG_TIMED_OUT) - # Simulate error delivery - callback would normally come from poll() - main_loop.call_soon_threadsafe(callback, mock_error, None) + with patch.object(producer, '_flush_buffer', side_effect=mock_flush_buffer): + result_future = await producer.produce(topic="test_topic", value="test_value") - mock_producer.return_value.produce.side_effect = sync_produce + with pytest.raises(KafkaException): + await result_future - # This should go through real _call() method and ThreadPoolExecutor - with pytest.raises(KafkaException): - await producer.produce(topic="test_topic", value="test_value") - - # Verify the sync producer.produce was actually called - mock_producer.return_value.produce.assert_called_once() + await producer.close() @pytest.mark.asyncio async def test_produce_with_delayed_callback(self, mock_producer, mock_common, basic_config): - """Test that Future properly waits for delayed delivery callback through real ThreadPoolExecutor.""" - producer = AIOProducer(basic_config, auto_poll=False) + """Test that Future properly waits for delayed delivery callback with batching.""" + producer = AIOProducer(basic_config, batch_size=2) # Need 2 messages to trigger flush + + batch_called = asyncio.Event() + captured_messages = None + + def mock_produce_batch(topic, messages): + nonlocal captured_messages + captured_messages = messages + batch_called.set() + # Don't call callbacks immediately - simulate real async behavior - produce_called = asyncio.Event() - captured_callback = None + mock_producer.return_value.produce_batch.side_effect = mock_produce_batch + mock_producer.return_value.poll.return_value = 1 - def sync_produce(*args, **kwargs): - nonlocal captured_callback - # This runs in real ThreadPoolExecutor - captured_callback = kwargs.get('on_delivery') - produce_called.set() - # Don't call callback immediately - simulate real async behavior + # Start first produce - won't trigger flush yet, but will return a Future + first_future = await producer.produce(topic="test", value="test1") - mock_producer.return_value.produce.side_effect = sync_produce + # The Future should be pending (not resolved yet) + assert not first_future.done() - # Start produce but don't await yet - this tests real Future behavior - produce_task = asyncio.create_task( - producer.produce(topic="test", value="test") - ) + # Add second message to trigger batch flush + await producer.produce(topic="test", value="test2") - # Wait for the ThreadPoolExecutor to execute the produce call - await asyncio.wait_for(produce_called.wait(), timeout=2.0) + # Wait for the batch operation to be called + await asyncio.wait_for(batch_called.wait(), timeout=2.0) - # Produce was 
called but Future should still be waiting for callback - assert captured_callback is not None - assert not produce_task.done() + # Batch was called and should have captured messages with callbacks + assert captured_messages is not None + assert len(captured_messages) == 2 + assert not first_future.done() # Simulate delayed delivery callback (like from background polling) mock_msg = Mock() mock_msg.topic.return_value = "test" - mock_msg.value.return_value = b"test" - captured_callback(None, mock_msg) + mock_msg.value.return_value = b"test1" + + # Call callback for first message (index 0) + first_callback = captured_messages[0]['callback'] + first_callback(None, mock_msg) # Now the Future should resolve - result = await produce_task + result = await first_future assert result == mock_msg - @pytest.mark.asyncio - async def test_auto_polling_background_loop(self, mock_producer, mock_common, basic_config): - """Test that auto-polling runs continuously in the background.""" - mock_producer.return_value.poll.return_value = 0 - - # Track polling with event - multiple_polls = asyncio.Event() - poll_count = 0 + await producer.close() - def poll_tracker(*args, **kwargs): - nonlocal poll_count - poll_count += 1 - if poll_count >= 5: - multiple_polls.set() - return 0 - - mock_producer.return_value.poll.side_effect = poll_tracker - - # Create producer with auto-polling enabled - producer = AIOProducer(basic_config, auto_poll=True) + @pytest.mark.asyncio + async def test_buffer_timeout_background_task(self, mock_producer, mock_common, basic_config): + """Test that buffer timeout task runs continuously in the background.""" + # Create producer with short timeout for testing + producer = AIOProducer(basic_config, buffer_timeout=0.1) - try: - assert producer._running is True - assert hasattr(producer, '_running_loop') + # Test that timeout task is created and running + assert producer._buffer_timeout_manager._timeout_task is not None + assert not producer._buffer_timeout_manager._timeout_task.done() + assert producer._is_closed is False - # Wait for continuous polling (proves both start and continuity) - await asyncio.wait_for(multiple_polls.wait(), timeout=2.0) - assert poll_count >= 5 + # Wait a bit to ensure task is running + await asyncio.sleep(0.05) + assert not producer._buffer_timeout_manager._timeout_task.done() - finally: - # Stop the producer - await producer.stop() + # Close the producer + await producer.close() - # Verify polling stops - final_count = poll_count + # Verify task stops and producer is closed + assert producer._is_closed is True await asyncio.sleep(0.1) # Grace period for cleanup - assert poll_count - final_count <= 1 # Allow one final poll + assert (producer._buffer_timeout_manager._timeout_task is None or + producer._buffer_timeout_manager._timeout_task.done()) @pytest.mark.asyncio async def test_multiple_concurrent_produce(self, mock_producer, mock_common, basic_config): - """Test multiple concurrent produce operations through real ThreadPoolExecutor.""" - producer = AIOProducer(basic_config, auto_poll=False, max_workers=3) + """Test multiple concurrent produce operations with batching.""" + producer = AIOProducer(basic_config, max_workers=3, batch_size=1) # Force immediate flush completed_produces = [] - produce_call_count = 0 - - # Capture the main event loop to use from thread - main_loop = asyncio.get_event_loop() + batch_call_count = 0 - def sync_produce(*args, **kwargs): - nonlocal produce_call_count - produce_call_count += 1 + def mock_produce_batch(topic, messages): + 
nonlocal batch_call_count + batch_call_count += 1 - # This runs in ThreadPoolExecutor - simulate real produce call - callback = kwargs.get('on_delivery') - topic = kwargs.get('topic') - value = kwargs.get('value') - - if callback: - # Create mock message for this specific produce call + # Simulate successful delivery for each message in batch + for i, msg_data in enumerate(messages): mock_msg = Mock() mock_msg.topic.return_value = topic - mock_msg.value.return_value = value.encode() if isinstance(value, str) else value - - # Simulate successful delivery from background thread - def deliver_callback(): - completed_produces.append((topic, value)) - callback(None, mock_msg) + mock_msg.value.return_value = ( + msg_data['value'].encode() if isinstance( + msg_data['value'], str) else msg_data['value']) - # Simulate slight delay like real produce would have - main_loop.call_later(0.01, deliver_callback) + completed_produces.append((topic, msg_data['value'])) + # Call the individual message callback + if 'callback' in msg_data: + msg_data['callback'](None, mock_msg) - mock_producer.return_value.produce.side_effect = sync_produce + mock_producer.return_value.produce_batch.side_effect = mock_produce_batch + mock_producer.return_value.poll.return_value = 1 # Start multiple produce operations concurrently tasks = [ @@ -265,10 +244,299 @@ def deliver_callback(): # Verify all operations completed assert len(results) == 3 assert all(result is not None for result in results) - assert produce_call_count == 3 + assert batch_call_count == 3 # Each message triggers its own batch due to batch_size=1 assert len(completed_produces) == 3 # Verify all messages were produced produced_values = [value for topic, value in completed_produces] expected_values = ["msg0", "msg1", "msg2"] assert sorted(produced_values) == sorted(expected_values) + + await producer.close() + + @pytest.mark.asyncio + async def test_constructor_new_implementation(self, mock_producer, mock_common, basic_config): + producer1 = AIOProducer(basic_config) + assert producer1._batch_size == 1000 + assert isinstance(producer1.executor, concurrent.futures.ThreadPoolExecutor) + assert hasattr(producer1, '_loop') + assert hasattr(producer1, '_buffer_timeout_manager') + assert producer1._batch_processor.is_buffer_empty() + assert producer1._is_closed is False + await producer1.close() + + custom_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8) + try: + producer2 = AIOProducer( + basic_config, + executor=custom_executor, + batch_size=500, + buffer_timeout=10.0 + ) + assert producer2.executor is custom_executor + assert producer2._batch_size == 500 + assert hasattr(producer2, '_buffer_timeout_manager') + await producer2.close() + finally: + custom_executor.shutdown(wait=True) + + @pytest.mark.asyncio + async def test_lifecycle_management_new_implementation(self, mock_producer, mock_common, basic_config): + """Test lifecycle management for current implementation.""" + + # Test close method with messages in buffer + producer = AIOProducer(basic_config) + + # Add some messages to buffer + with patch.object(producer, '_flush_buffer'): + await producer.produce('test', 'msg1') + assert producer._batch_processor.get_buffer_size() == 1 + + # Test close method + await producer.close() + assert producer._is_closed is True + assert (producer._buffer_timeout_manager._timeout_task is None or + producer._buffer_timeout_manager._timeout_task.done()) + + @pytest.mark.asyncio + async def test_buffer_timeout_task_management(self, mock_producer, mock_common, 
basic_config): + """Test timeout task lifecycle and weak references.""" + + # Test task creation and configuration + producer = AIOProducer(basic_config, buffer_timeout=1.0) + assert producer._buffer_timeout_manager._timeout_task is not None + assert not producer._buffer_timeout_manager._timeout_task.done() + assert producer._buffer_timeout_manager._timeout == 1.0 + assert producer._is_closed is False + + # Test task stops on close + await producer.close() + assert producer._is_closed is True + assert (producer._buffer_timeout_manager._timeout_task is None or + producer._buffer_timeout_manager._timeout_task.done()) + + @pytest.mark.asyncio + async def test_buffer_timeout_behavior(self, mock_producer, mock_common, basic_config): + """Test buffer activity tracking and timeout triggers.""" + + # Test buffer activity tracking + producer = AIOProducer(basic_config) + initial_time = producer._buffer_timeout_manager._last_activity + assert initial_time > 0 + + # Activity updates on produce + await asyncio.sleep(0.01) # Ensure time difference + with patch.object(producer, '_flush_buffer'): # Prevent auto-flush + await producer.produce('test', 'msg1') + assert producer._buffer_timeout_manager._last_activity > initial_time + + await producer.close() + + @pytest.mark.asyncio + async def test_poll_method_new_implementation(self, mock_producer, mock_common, basic_config): + """Test poll method with different timeout scenarios.""" + producer = AIOProducer(basic_config) + + # Test timeout=0 (non-blocking) + with patch.object(producer, '_call') as mock_call: + await producer.poll(timeout=0) + mock_call.assert_called_with(producer._producer.poll, 0) + + # Test positive timeout (blocking via ThreadPool) + with patch.object(producer, '_call') as mock_call: + await producer.poll(timeout=5) + mock_call.assert_called_with(producer._producer.poll, 5) + + await producer.close() + + @pytest.mark.asyncio + async def test_produce_method_batching(self, mock_producer, mock_common, basic_config): + """Test produce method with batching behavior.""" + producer = AIOProducer(basic_config, batch_size=3) + + # Test basic produce adds to buffer + with patch.object(producer, '_flush_buffer') as mock_flush: + future1 = await producer.produce('topic1', 'value1', key='key1') + assert producer._batch_processor.get_buffer_size() == 1 + assert len(producer._batch_processor._buffer_futures) == 1 + assert isinstance(future1, asyncio.Future) + mock_flush.assert_not_called() # Should not flush yet + + # Test batch size trigger (3rd message should trigger flush) + with patch.object(producer, '_flush_buffer') as mock_flush: + await producer.produce('topic1', 'value2') # 2nd message + await producer.produce('topic1', 'value3') # 3rd message - should trigger flush + mock_flush.assert_called() + + await producer.close() + + @pytest.mark.asyncio + async def test_flush_and_purge_methods_new_implementation(self, mock_producer, mock_common, basic_config): + """Test flush and purge methods for current implementation.""" + producer = AIOProducer(basic_config) + + # Add messages to buffer + with patch.object(producer, '_flush_buffer'): # Prevent auto-flush + await producer.produce('test', 'msg1') + await producer.produce('test', 'msg2') + assert producer._batch_processor.get_buffer_size() == 2 + + # Test purge clears buffers + with patch.object(producer, '_call') as mock_call: + await producer.purge() + mock_call.assert_called_with(producer._producer.purge) + + assert producer._batch_processor.get_buffer_size() == 0 + assert 
len(producer._batch_processor._buffer_futures) == 0 + + await producer.close() + + @pytest.mark.asyncio + async def test_group_messages_by_topic(self, mock_producer, mock_common, basic_config): + """Test message grouping by topic for batch processing.""" + producer = AIOProducer(basic_config) + + # Test empty buffer + groups = producer._batch_processor._group_messages_by_topic() + assert groups == {} + + # Add mixed topic messages + producer._batch_processor._message_buffer = [ + {'topic': 'topic1', 'value': 'msg1', 'user_callback': None}, + {'topic': 'topic2', 'value': 'msg2', 'user_callback': Mock()}, + {'topic': 'topic1', 'value': 'msg3', 'user_callback': None}, + ] + producer._batch_processor._buffer_futures = [Mock(), Mock(), Mock()] + + groups = producer._batch_processor._group_messages_by_topic() + + # Test grouping correctness + assert len(groups) == 2 + assert 'topic1' in groups and 'topic2' in groups + assert len(groups['topic1']['messages']) == 2 # msg1, msg3 + assert len(groups['topic2']['messages']) == 1 # msg2 + + await producer.close() + + @pytest.mark.asyncio + async def test_error_handling_new_implementation(self, mock_producer, mock_common, basic_config): + """Test error handling in current implementation.""" + producer = AIOProducer(basic_config) + + # Test batch error propagation + producer._batch_processor._message_buffer = [{'topic': 'test', 'value': 'msg', 'user_callback': None}] + producer._batch_processor._buffer_futures = [asyncio.Future()] + + with patch.object(producer._batch_processor, 'flush_buffer', side_effect=Exception("Batch failed")): + with pytest.raises(Exception, match="Batch failed"): + await producer._flush_buffer() + + await producer.close() + + @pytest.mark.asyncio + async def test_future_based_usage_pattern(self, mock_producer, mock_common, basic_config): + """Test the recommended Future-based usage pattern instead of callbacks.""" + producer = AIOProducer(basic_config, batch_size=1) + + # Mock successful delivery + mock_msg = Mock() + mock_msg.topic.return_value = "test-topic" + mock_msg.value.return_value = b"test-value" + + async def mock_flush_buffer(target_topic=None): + batches = producer._batch_processor.create_batches(target_topic) + for batch in batches: + for future in batch.futures: + if not future.done(): + future.set_result(mock_msg) + producer._batch_processor.clear_buffer() + + with patch.object(producer, '_flush_buffer', side_effect=mock_flush_buffer): + # Recommended usage: await the Future returned by produce() + future = await producer.produce(topic="test-topic", value="test-value") + result = await future + + # Verify the result + assert result is mock_msg + assert result.topic() == "test-topic" + assert result.value() == b"test-value" + + await producer.close() + + @pytest.mark.asyncio + async def test_future_based_error_handling(self, mock_producer, mock_common, basic_config): + """Test Future-based error handling pattern.""" + producer = AIOProducer(basic_config, batch_size=1) + + # Mock delivery error + mock_error = KafkaException(KafkaError(KafkaError._MSG_TIMED_OUT)) + + async def mock_flush_buffer(target_topic=None): + batches = producer._batch_processor.create_batches(target_topic) + for batch in batches: + for future in batch.futures: + if not future.done(): + future.set_exception(mock_error) + producer._batch_processor.clear_buffer() + + with patch.object(producer, '_flush_buffer', side_effect=mock_flush_buffer): + # Recommended usage: handle exceptions via Future + future = await producer.produce(topic="test-topic", 
value="test-value") + + with pytest.raises(KafkaException): + await future + + await producer.close() + + @pytest.mark.asyncio + async def test_future_based_concurrent_usage(self, mock_producer, mock_common, basic_config): + """Test Future-based concurrent usage pattern.""" + producer = AIOProducer(basic_config, batch_size=1) + + # Mock successful delivery + mock_msg = Mock() + mock_msg.topic.return_value = "test-topic" + mock_msg.value.return_value = b"test-value" + + async def mock_flush_buffer(target_topic=None): + batches = producer._batch_processor.create_batches(target_topic) + for batch in batches: + for future in batch.futures: + if not future.done(): + future.set_result(mock_msg) + producer._batch_processor.clear_buffer() + + with patch.object(producer, '_flush_buffer', side_effect=mock_flush_buffer): + # Recommended usage: collect Futures and await them together + futures = [] + for i in range(5): + future = await producer.produce(topic="test-topic", value=f"test-value-{i}") + futures.append(future) + + # Wait for all deliveries to complete + results = await asyncio.gather(*futures) + + # Verify all results + assert len(results) == 5 + for result in results: + assert result is mock_msg + + await producer.close() + + @pytest.mark.asyncio + async def test_edge_cases_batching(self, mock_producer, mock_common, basic_config): + """Test edge cases in batching behavior.""" + producer = AIOProducer(basic_config, batch_size=100) + + # Test large batch handling + with patch.object(producer, '_flush_buffer') as mock_flush: + large_batch_tasks = [ + producer.produce('test', f'msg{i}') + for i in range(150) # Exceeds batch_size + ] + + # Should trigger flush automatically at 100 + await asyncio.gather(*large_batch_tasks) + assert mock_flush.call_count >= 1 # At least one flush + + await producer.close() diff --git a/tests/test_kafka_batch_executor.py b/tests/test_kafka_batch_executor.py new file mode 100644 index 000000000..4ae6c1607 --- /dev/null +++ b/tests/test_kafka_batch_executor.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +""" +Unit tests for the KafkaBatchExecutor class (_kafka_batch_executor.py) + +This module tests the KafkaBatchExecutor class to ensure proper +Kafka batch execution and partial failure handling. 
+""" + +from confluent_kafka.aio.producer._kafka_batch_executor import ProducerBatchExecutor as KafkaBatchExecutor +import confluent_kafka +import asyncio +import unittest +from unittest.mock import Mock, patch +import sys +import os +import concurrent.futures + +# Add src to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + + +class TestKafkaBatchExecutor(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) + self.mock_producer = Mock(spec=confluent_kafka.Producer) + self.kafka_executor = KafkaBatchExecutor(self.mock_producer, self.executor) + + def tearDown(self): + self.executor.shutdown(wait=True) + self.loop.close() + + def test_initialization(self): + self.assertEqual(self.kafka_executor._producer, self.mock_producer) + self.assertEqual(self.kafka_executor._executor, self.executor) + + def test_execute_batch_success(self): + async def async_test(): + batch_messages = [ + {'value': 'test1', 'callback': Mock()}, + {'value': 'test2', 'callback': Mock()}, + ] + + self.mock_producer.produce_batch.return_value = None + self.mock_producer.poll.return_value = 2 + + result = await self.kafka_executor.execute_batch('test-topic', batch_messages) + + self.mock_producer.produce_batch.assert_called_once_with('test-topic', batch_messages) + self.mock_producer.poll.assert_called_once_with(0) + self.assertEqual(result, 2) + + self.loop.run_until_complete(async_test()) + + def test_partial_failure_handling(self): + async def async_test(): + callback1 = Mock() + callback2 = Mock() + batch_messages = [ + {'value': 'test1', 'callback': callback1}, + {'value': 'test2', 'callback': callback2, '_error': 'MSG_SIZE_TOO_LARGE'}, + ] + + self.mock_producer.produce_batch.return_value = None + self.mock_producer.poll.return_value = 1 + + result = await self.kafka_executor.execute_batch('test-topic', batch_messages) + + self.mock_producer.produce_batch.assert_called_once_with('test-topic', batch_messages) + self.mock_producer.poll.assert_called_once_with(0) + + callback1.assert_not_called() + callback2.assert_called_once_with('MSG_SIZE_TOO_LARGE', None) + + self.assertEqual(result, 1) + + self.loop.run_until_complete(async_test()) + + def test_batch_execution_exception(self): + async def async_test(): + batch_messages = [{'value': 'test1', 'callback': Mock()}] + + self.mock_producer.produce_batch.side_effect = Exception("Kafka error") + + with self.assertRaises(Exception) as context: + await self.kafka_executor.execute_batch('test-topic', batch_messages) + + self.assertEqual(str(context.exception), "Kafka error") + self.mock_producer.produce_batch.assert_called_once_with('test-topic', batch_messages) + + self.loop.run_until_complete(async_test()) + + def test_callback_exception_handling(self): + async def async_test(): + failing_callback = Mock(side_effect=Exception("Callback error")) + + batch_messages = [ + {'value': 'test1', 'callback': failing_callback, '_error': 'TEST_ERROR'}, + ] + + self.mock_producer.produce_batch.return_value = None + self.mock_producer.poll.return_value = 0 + + # Expect the callback exception to be raised + with self.assertRaises(Exception) as context: + await self.kafka_executor.execute_batch('test-topic', batch_messages) + + # Verify the callback was called before the exception + failing_callback.assert_called_once_with('TEST_ERROR', None) + self.assertEqual(str(context.exception), "Callback error") + + 
+    def test_thread_pool_execution(self):
+        async def async_test():
+            batch_messages = [{'value': 'test1', 'callback': Mock()}]
+
+            with patch.object(self.loop, 'run_in_executor') as mock_run_in_executor:
+                future_result = self.loop.create_future()
+                future_result.set_result(1)
+                mock_run_in_executor.return_value = future_result
+
+                result = await self.kafka_executor.execute_batch('test-topic', batch_messages)
+
+                mock_run_in_executor.assert_called_once()
+                args = mock_run_in_executor.call_args
+                self.assertEqual(args[0][0], self.executor)
+                self.assertTrue(callable(args[0][1]))
+
+                self.assertEqual(result, 1)
+
+        self.loop.run_until_complete(async_test())
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_producer_batch_processor.py b/tests/test_producer_batch_processor.py
new file mode 100644
index 000000000..b6f091b19
--- /dev/null
+++ b/tests/test_producer_batch_processor.py
@@ -0,0 +1,406 @@
+#!/usr/bin/env python3
+"""
+Unit tests for the ProducerBatchManager class (_producer_batch_processor.py)
+
+This module tests the batch manager (aliased below as ProducerBatchProcessor)
+to ensure proper message batching, topic grouping, and future management.
+"""
+
+import asyncio
+import concurrent.futures
+import os
+import sys
+import unittest
+from unittest.mock import Mock, patch
+
+# Add src to path before importing confluent_kafka so the in-repo
+# sources are picked up
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
+
+import confluent_kafka
+from confluent_kafka.aio.producer._kafka_batch_executor import ProducerBatchExecutor as KafkaBatchExecutor
+from confluent_kafka.aio.producer._AIOProducer import AIOProducer
+from confluent_kafka.aio.producer._producer_batch_processor import ProducerBatchManager as ProducerBatchProcessor
+
+
+class TestProducerBatchProcessor(unittest.TestCase):
+    """Test cases for ProducerBatchProcessor class"""
+
+    def setUp(self):
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+
+        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
+
+        self.producer_config = {
+            'bootstrap.servers': 'localhost:9092',
+            'client.id': 'test-producer',
+            'message.timeout.ms': 100,
+            'queue.buffering.max.messages': 1000,
+            'api.version.request': False,
+        }
+
+        self.confluent_kafka_producer = confluent_kafka.Producer(self.producer_config)
+
+        self.kafka_executor = KafkaBatchExecutor(self.confluent_kafka_producer, self.executor)
+        self.batch_processor = ProducerBatchProcessor(self.kafka_executor)
+
+        async def create_aio_producer():
+            return AIOProducer(self.producer_config, executor=self.executor)
+
+        self.aio_producer = self.loop.run_until_complete(create_aio_producer())
+
+    def tearDown(self):
+        try:
+            self.loop.run_until_complete(self.aio_producer.close())
+        except Exception:
+            pass
+
+        try:
+            self.confluent_kafka_producer.flush(timeout=1)
+        except Exception:
+            pass
+
+        try:
+            # ThreadPoolExecutor.shutdown() accepts no timeout argument
+            self.executor.shutdown(wait=True)
+        except Exception:
+            pass
+
+        self.loop.close()
+
+    def test_basic_functionality(self):
+        self.assertEqual(self.batch_processor.get_buffer_size(), 0)
+        self.assertTrue(self.batch_processor.is_buffer_empty())
+
+        future1 = Mock()
+        future2 = Mock()
+        msg_data1 = {'topic': 'topic1', 'value': 'test1', 'key': 'key1'}
+        msg_data2 = {'topic': 'topic2', 'value': 'test2', 'key': 'key2'}
+
+        self.batch_processor.add_message(msg_data1, future1)
+        self.batch_processor.add_message(msg_data2, future2)
+
+        self.assertEqual(self.batch_processor.get_buffer_size(), 2)
+        self.assertFalse(self.batch_processor.is_buffer_empty())
+
+        
self.batch_processor.clear_buffer() + + self.assertEqual(self.batch_processor.get_buffer_size(), 0) + self.assertTrue(self.batch_processor.is_buffer_empty()) + + def test_group_messages_by_topic(self): + future1 = Mock() + future2 = Mock() + future3 = Mock() + + msg1 = {'topic': 'topic1', 'value': 'test1', 'user_callback': Mock()} + msg2 = {'topic': 'topic2', 'value': 'test2'} + msg3 = {'topic': 'topic1', 'value': 'test3', 'user_callback': Mock()} + + self.batch_processor.add_message(msg1, future1) + self.batch_processor.add_message(msg2, future2) + self.batch_processor.add_message(msg3, future3) + + topic_groups = self.batch_processor._group_messages_by_topic() + + self.assertEqual(len(topic_groups), 2) + + self.assertIn('topic1', topic_groups) + topic1_group = topic_groups['topic1'] + self.assertEqual(len(topic1_group['messages']), 2) + self.assertEqual(len(topic1_group['futures']), 2) + self.assertEqual(topic1_group['futures'][0], future1) + self.assertEqual(topic1_group['futures'][1], future3) + + self.assertIn('topic2', topic_groups) + topic2_group = topic_groups['topic2'] + self.assertEqual(len(topic2_group['messages']), 1) + self.assertEqual(len(topic2_group['futures']), 1) + self.assertEqual(topic2_group['futures'][0], future2) + + def test_prepare_batch_messages(self): + messages = [ + {'topic': 'test', 'value': 'test1', 'user_callback': Mock(), 'key': 'key1'}, + {'topic': 'test', 'value': 'test2', 'partition': 1}, + ] + + batch_messages = self.batch_processor._prepare_batch_messages(messages) + + self.assertEqual(len(batch_messages), 2) + + self.assertNotIn('topic', batch_messages[0]) + self.assertIn('value', batch_messages[0]) + self.assertIn('key', batch_messages[0]) + + self.assertNotIn('topic', batch_messages[1]) + self.assertIn('value', batch_messages[1]) + self.assertIn('partition', batch_messages[1]) + + def test_assign_future_callbacks(self): + batch_messages = [ + {'value': 'test1'}, + {'value': 'test2'}, + ] + futures = [Mock(), Mock()] + + self.batch_processor._assign_future_callbacks(batch_messages, futures) + + self.assertIn('callback', batch_messages[0]) + self.assertIn('callback', batch_messages[1]) + + def test_handle_batch_failure(self): + """Test handling batch failures""" + futures = [Mock(), Mock()] + futures[0].done.return_value = False + futures[1].done.return_value = True # Already done + + exception = RuntimeError("Batch failed") + + # Handle batch failure + self.batch_processor._handle_batch_failure( + exception, futures + ) + + # Verify first future got exception (not already done) + futures[0].set_exception.assert_called_once_with(exception) + + # Verify second future was not modified (already done) + futures[1].set_exception.assert_not_called() + + # Note: For real AIOProducer, the user callback is invoked directly by _handle_user_callback + + def test_flush_empty_buffer(self): + """Test flushing empty buffer is no-op""" + async def async_test(): + await self.batch_processor.flush_buffer() + self.assertTrue(self.batch_processor.is_buffer_empty()) + + self.loop.run_until_complete(async_test()) + + def test_flush_buffer_with_messages(self): + """Test successful buffer flush with messages""" + async def async_test(): + future1 = self.loop.create_future() + future2 = self.loop.create_future() + msg1 = {'topic': 'topic1', 'value': 'test1'} + msg2 = {'topic': 'topic1', 'value': 'test2'} + + self.batch_processor.add_message(msg1, future1) + self.batch_processor.add_message(msg2, future2) + + success_future = self.loop.create_future() + 
success_future.set_result("success") + + with patch.object(self.loop, 'run_in_executor', return_value=success_future): + await self.batch_processor.flush_buffer() + + self.assertTrue(self.batch_processor.is_buffer_empty()) + + self.loop.run_until_complete(async_test()) + + def test_flush_buffer_selective_topic(self): + """Test selective topic flushing""" + async def async_test(): + future3 = self.loop.create_future() + future4 = self.loop.create_future() + msg3 = {'topic': 'topic1', 'value': 'test3'} + msg4 = {'topic': 'topic2', 'value': 'test4'} + + self.batch_processor.add_message(msg3, future3) + self.batch_processor.add_message(msg4, future4) + + success_future = self.loop.create_future() + success_future.set_result("success") + + with patch.object(self.loop, 'run_in_executor', return_value=success_future): + await self.batch_processor.flush_buffer(target_topic='topic1') + + self.assertEqual(self.batch_processor.get_buffer_size(), 1) + + self.loop.run_until_complete(async_test()) + + def test_flush_buffer_exception_handling(self): + """Test exception handling during buffer flush""" + async def async_test(): + future = self.loop.create_future() + msg = {'topic': 'topic1', 'value': 'test'} + self.batch_processor.add_message(msg, future) + + exception = RuntimeError("Batch execution failed") + + with patch.object(self.loop, 'run_in_executor', side_effect=exception): + with self.assertRaises(RuntimeError): + await self.batch_processor.flush_buffer() + + self.assertTrue(self.batch_processor.is_buffer_empty()) + + self.loop.run_until_complete(async_test()) + + def test_kafka_executor_integration(self): + """Test executing a batch operation via KafkaBatchExecutor""" + async def async_test(): + batch_messages = [ + {'value': 'test1', 'callback': Mock()}, + {'value': 'test2', 'callback': Mock()}, + ] + + # Mock the executor to return a completed future + future_result = self.loop.create_future() + future_result.set_result("poll_result") + + with patch.object(self.loop, 'run_in_executor', return_value=future_result) as mock_run_in_executor: + result = await self.kafka_executor.execute_batch('test-topic', batch_messages) + + # Verify run_in_executor was called + mock_run_in_executor.assert_called_once() + self.assertEqual(result, "poll_result") + + self.loop.run_until_complete(async_test()) + + def _create_mixed_topic_messages(self): + """Helper to create messages across multiple topics""" + messages_data = [] + futures = [] + user_callbacks = [] + + for i in range(4): + future = Mock() + future.done.return_value = False + user_callback = Mock() + msg_data = { + 'topic': f'topic{i % 2}', + 'value': f'unique_value_{i}', + 'key': f'unique_key_{i}', + 'user_callback': user_callback + } + + self.batch_processor.add_message(msg_data, future) + messages_data.append(msg_data) + futures.append(future) + user_callbacks.append(user_callback) + + return messages_data, futures, user_callbacks + + def _add_alternating_topic_messages(self): + """Helper to add messages alternating between two topics""" + futures = [] + for i in range(5): + future = self.loop.create_future() + msg_data = { + 'topic': f'topic{i % 2}', + 'value': f'test{i}', + 'key': f'key{i}' + } + self.batch_processor.add_message(msg_data, future) + futures.append(future) + return futures + + def test_batch_cycle_buffer_state(self): + """Test buffer state during batch cycle""" + self._add_alternating_topic_messages() + self.assertEqual(self.batch_processor.get_buffer_size(), 5) + self.assertFalse(self.batch_processor.is_buffer_empty()) + + def 
test_batch_cycle_topic_grouping(self): + """Test topic grouping in batch cycle""" + self._add_alternating_topic_messages() + topic_groups = self.batch_processor._group_messages_by_topic() + + self.assertEqual(len(topic_groups), 2) + self.assertIn('topic0', topic_groups) + self.assertIn('topic1', topic_groups) + self.assertEqual(len(topic_groups['topic0']['messages']), 3) + self.assertEqual(len(topic_groups['topic1']['messages']), 2) + + def test_batch_cycle_message_preparation(self): + """Test message preparation in batch cycle""" + self._add_alternating_topic_messages() + topic_groups = self.batch_processor._group_messages_by_topic() + + batch_messages = self.batch_processor._prepare_batch_messages( + topic_groups['topic0']['messages'] + ) + + self.assertEqual(len(batch_messages), 3) + for batch_msg in batch_messages: + self.assertNotIn('topic', batch_msg) + self.assertIn('value', batch_msg) + self.assertIn('key', batch_msg) + + def test_batch_message_preparation_with_mixed_sizes(self): + """Test batch message preparation with mixed message sizes""" + # Create test messages with different sizes + messages = [ + {'topic': 'test-topic', 'value': 'small message'}, + {'topic': 'test-topic', 'value': 'x' * (5 * 1024 * 1024)}, # Large message + {'topic': 'test-topic', 'value': 'another small'}, + ] + futures = [asyncio.Future(), asyncio.Future(), asyncio.Future()] + + for msg, future in zip(messages, futures): + self.batch_processor.add_message(msg, future) + + topic_groups = self.batch_processor._group_messages_by_topic() + topic_data = topic_groups['test-topic'] + batch_messages = self.batch_processor._prepare_batch_messages(topic_data['messages']) + + self.assertEqual(len(batch_messages), 3) + large_msg = next((msg for msg in batch_messages if len(str(msg.get('value', ''))) > 1000), None) + self.assertIsNotNone(large_msg) + + def test_future_based_usage_pattern(self): + """Test the recommended Future-based usage pattern instead of callbacks.""" + # Create test messages without user callbacks + messages = [ + {'topic': 'test-topic', 'value': 'test1', 'key': 'key1'}, + {'topic': 'test-topic', 'value': 'test2', 'key': 'key2'}, + ] + futures = [asyncio.Future(), asyncio.Future()] + + # Add messages to batch processor + for msg, future in zip(messages, futures): + self.batch_processor.add_message(msg, future) + + # Verify messages are in buffer + self.assertEqual(self.batch_processor.get_buffer_size(), 2) + + # Simulate successful delivery by resolving futures + mock_msg1 = Mock() + mock_msg1.topic.return_value = 'test-topic' + mock_msg1.value.return_value = b'test1' + + mock_msg2 = Mock() + mock_msg2.topic.return_value = 'test-topic' + mock_msg2.value.return_value = b'test2' + + # Applications should await these futures to get delivery results + futures[0].set_result(mock_msg1) + futures[1].set_result(mock_msg2) + + # Verify futures are resolved + self.assertTrue(futures[0].done()) + self.assertTrue(futures[1].done()) + self.assertEqual(futures[0].result(), mock_msg1) + self.assertEqual(futures[1].result(), mock_msg2) + + def test_future_based_error_handling(self): + """Test Future-based error handling pattern.""" + # Create test message + message = {'topic': 'test-topic', 'value': 'test', 'key': 'key'} + future = asyncio.Future() + + # Add message to batch processor + self.batch_processor.add_message(message, future) + + # Simulate delivery error by setting exception on future + mock_error = RuntimeError("Delivery failed") + future.set_exception(mock_error) + + # Verify future is resolved with 
exception + self.assertTrue(future.done()) + with self.assertRaises(RuntimeError): + future.result() + + +if __name__ == '__main__': + # Run all tests + unittest.main(verbosity=2)
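
Note: the Future-based usage pattern that these tests exercise looks roughly like the sketch below from application code. This is an illustrative sketch based only on the AIOProducer API visible in this diff (produce() returning an awaitable delivery future, close() for shutdown, and the batch_size/buffer_timeout constructor options); the topic name and bootstrap server are placeholders.

```python
import asyncio

from confluent_kafka.aio.producer._AIOProducer import AIOProducer


async def main():
    # batch_size/buffer_timeout mirror the options exercised by the tests above
    producer = AIOProducer({'bootstrap.servers': 'localhost:9092'},
                           batch_size=100, buffer_timeout=1.0)
    try:
        # produce() buffers the message and returns a Future that resolves
        # with the delivered Message, or raises KafkaException on failure
        futures = [await producer.produce('example-topic', value=f'msg-{i}')
                   for i in range(5)]
        for msg in await asyncio.gather(*futures):
            print(msg.topic(), msg.value())
    finally:
        await producer.close()


asyncio.run(main())
```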