diff --git a/tests/ducktape/services/kafka.py b/tests/ducktape/services/kafka.py
index e44725d75..725fd3ceb 100644
--- a/tests/ducktape/services/kafka.py
+++ b/tests/ducktape/services/kafka.py
@@ -42,7 +42,7 @@ def verify_connection(self):
                              list(metadata.topics.keys()))
             return True
         except Exception as e:
-            self.logger.error("Failed to connect to Kafka at %s: %s", self.bootstrap_servers_str, e)
+            self.logger.error("Failed to connect to Kafka: %s", e)
             return False
 
     def create_topic(self, topic, partitions=1, replication_factor=1):
@@ -124,3 +124,30 @@ def wait_for_topic(self, topic_name, max_wait_time=30, initial_wait=0.1):
         self.logger.error("Timeout waiting for topic '%s' after %ds",
                           topic_name, max_wait_time)
         return False
+
+    def add_partitions(self, topic_name, new_partition_count):
+        """Add partitions to an existing topic (partition counts can only grow)"""
+        try:
+            from confluent_kafka.admin import AdminClient, NewPartitions
+
+            admin_client = AdminClient({'bootstrap.servers': self.bootstrap_servers_str})
+            metadata = admin_client.list_topics(timeout=10)
+
+            if topic_name not in metadata.topics:
+                raise ValueError(f"Topic {topic_name} does not exist")
+
+            current_partitions = len(metadata.topics[topic_name].partitions)
+            if new_partition_count <= current_partitions:
+                return  # Kafka cannot shrink topics; lower or equal counts are a no-op
+
+            # Add partitions
+            new_partitions = NewPartitions(topic=topic_name, new_total_count=new_partition_count)
+            fs = admin_client.create_partitions([new_partitions])
+
+            # Wait for completion (create_partitions returns {topic: future})
+            for f in fs.values():
+                f.result(timeout=30)
+
+        except Exception as e:
+            self.logger.error("Failed to add partitions to topic %s: %s", topic_name, e)
+            raise
diff --git a/tests/ducktape/test_consumer.py b/tests/ducktape/test_consumer.py
index 25930132b..821c07b13 100644
--- a/tests/ducktape/test_consumer.py
+++ b/tests/ducktape/test_consumer.py
@@ -8,10 +8,13 @@
 from ducktape.mark import matrix
 
 from tests.ducktape.services.kafka import KafkaClient
-from tests.ducktape.consumer_benchmark_metrics import (ConsumerMetricsCollector, ConsumerMetricsBounds,
-                                                       validate_consumer_metrics, print_consumer_metrics_report)
+from tests.ducktape.consumer_benchmark_metrics import (
+    ConsumerMetricsCollector, ConsumerMetricsBounds,
+    validate_consumer_metrics, print_consumer_metrics_report)
 from tests.ducktape.consumer_strategy import SyncConsumerStrategy, AsyncConsumerStrategy
 from confluent_kafka import Producer
+import asyncio
+import pytest
 
 
 class SimpleConsumerTest(Test):
@@ -33,9 +36,10 @@ def setup(self):
 
         self.logger.info("Successfully connected to Kafka")
 
-    def create_consumer(self, consumer_type, batch_size=10):
+    def create_consumer_strategy(self, consumer_type, group_id=None, batch_size=10):
         """Create appropriate consumer strategy based on type"""
-        group_id = f"test-group-{uuid.uuid4()}"  # Unique group ID for each test
+        if not group_id:
+            group_id = f"test-group-{uuid.uuid4()}"  # Unique group ID for each test
 
         if consumer_type == "sync":
             return SyncConsumerStrategy(
@@ -52,6 +56,10 @@ def create_consumer(self, consumer_type, batch_size=10):
                 batch_size
             )
 
+    def create_consumer(self, consumer_type, group_id=None, batch_size=10):
+        """Create a ready-to-use consumer from the matching strategy"""
+        return self.create_consumer_strategy(consumer_type, group_id, batch_size).create_consumer()
+
     def produce_test_messages(self, topic_name, num_messages):
         """Produce messages to topic for consumer tests"""
         producer = Producer({'bootstrap.servers': self.kafka.bootstrap_servers()})
@@ -74,13 +82,38 @@ def produce_test_messages(self, topic_name, num_messages):
         producer.flush(timeout=60)  # Final flush with longer timeout
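+        # flush(timeout) returns the number of messages still queued; a non-zero
+        # return here would make the success log below an overstatement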
         self.logger.info(f"Successfully produced {num_messages} messages")
 
+    # =========== Performance tests ===========
+
     @matrix(consumer_type=["sync", "async"], batch_size=[1, 5, 20])
     def test_basic_consume(self, consumer_type, batch_size):
         """Test basic message consumption with comprehensive metrics and bounds validation"""
+        self._run_consumer_performance_test(
+            consumer_type=consumer_type,
+            operation_type="consume",
+            batch_size=batch_size,
+        )
+
+    @matrix(consumer_type=["sync", "async"])
+    def test_basic_poll(self, consumer_type):
+        """Test basic message polling (single message) with comprehensive metrics and bounds validation"""
+        self._run_consumer_performance_test(
+            consumer_type=consumer_type,
+            operation_type="poll",
+        )
 
-        topic_name = f"test-{consumer_type}-consumer-topic"
+    def _run_consumer_performance_test(self, consumer_type, operation_type, batch_size=None):
+        """
+        Shared helper for consumer performance tests
+
+        Args:
+            consumer_type: "sync" or "async"
+            operation_type: "consume" or "poll"
+            batch_size: Messages per batch (default None); only used by the consume operation
+        """
+        topic_name = f"test-{consumer_type}-{operation_type}-topic"
         test_duration = 5.0  # 5 seconds
-        # TODO: clean up this magic number
         num_messages = 1500000  # 1.5M messages for sustained 5-second consumption at ~300K msg/s
 
         # Create topic
@@ -95,16 +128,20 @@ def test_basic_consume(self, consumer_type, batch_size):
         self.produce_test_messages(topic_name, num_messages)
 
         # Initialize metrics collection and bounds
-        metrics = ConsumerMetricsCollector(operation_type="consume")
+        metrics = ConsumerMetricsCollector(operation_type=operation_type)
         bounds = ConsumerMetricsBounds()
 
         # Create appropriate consumer strategy
-        strategy = self.create_consumer(consumer_type, batch_size)
+        strategy = self.create_consumer_strategy(consumer_type, batch_size=batch_size)
 
         # Assign metrics collector to strategy
         strategy.metrics = metrics
 
-        self.logger.info(f"Testing {consumer_type} consumer for {test_duration} seconds")
+        if operation_type == "consume":
+            operation_desc = f"{operation_type} (batch_size={batch_size})"
+        else:
+            operation_desc = f"{operation_type} (single messages)"
+        self.logger.info(f"Testing {consumer_type} consumer {operation_desc} for {test_duration} seconds")
 
         # Start metrics collection
         metrics.start()
@@ -112,11 +149,16 @@ def test_basic_consume(self, consumer_type, batch_size):
         # Container for consumed messages
         consumed_messages = []
 
-        # Run the test
+        # Run the test - choose method based on operation type
         start_time = time.time()
-        messages_consumed = strategy.consume_messages(
-            topic_name, test_duration, start_time, consumed_messages, timeout=0.1
-        )
+        if operation_type == "consume":
+            messages_consumed = strategy.consume_messages(
+                topic_name, test_duration, start_time, consumed_messages, timeout=0.1
+            )
+        else:  # poll
+            messages_consumed = strategy.poll_messages(
+                topic_name, test_duration, start_time, consumed_messages, timeout=0.1
+            )
 
         # Finalize metrics collection
         metrics.finalize()
@@ -128,9 +170,10 @@ def test_basic_consume(self, consumer_type, batch_size):
         # Print comprehensive metrics report
-        print_consumer_metrics_report(metrics_summary, is_valid, violations, consumer_type, batch_size)
+        # Poll runs pass batch_size=None; report them as single-message batches
+        print_consumer_metrics_report(metrics_summary, is_valid, violations, consumer_type, batch_size or 1)
 
-        # Get AIOConsumer built-in metrics for comparison (async only)
+        # Get AIOConsumer built-in metrics for comparison (async strategies only)
         final_metrics = strategy.get_final_metrics()
         if final_metrics:
             self.logger.info("=== AIOConsumer Built-in Metrics ===")
             for key, value in final_metrics.items():
@@ -147,73 +190,257 @@ def test_basic_consume(self, consumer_type, batch_size):
         if not is_valid:
             self.logger.warning("Performance bounds validation failed: %s", "; ".join(violations))
 
-        self.logger.info("Successfully completed basic consumption test with comprehensive metrics")
-
-    @matrix(consumer_type=["sync", "async"])
-    def test_basic_poll(self, consumer_type):
-        """Test basic message polling (single message) with comprehensive metrics and bounds validation"""
-
-        topic_name = f"test-{consumer_type}-poll-topic"
-        test_duration = 5.0  # 5 seconds
-        num_messages = 1500000  # 1.5M messages for sustained 5-second consumption
-
-        # Create topic
-        self.kafka.create_topic(topic_name, partitions=1, replication_factor=1)
-
-        # Wait for topic to be available with retry logic
-        topic_ready = self.kafka.wait_for_topic(topic_name, max_wait_time=30)
-        assert topic_ready, (f"Topic {topic_name} was not created within timeout. "
-                             f"Available topics: {self.kafka.list_topics()}")
-
-        # Produce test messages
-        self.produce_test_messages(topic_name, num_messages)
-
-        # Initialize metrics collection and bounds
-        metrics = ConsumerMetricsCollector(operation_type="poll")
-        bounds = ConsumerMetricsBounds()
+        self.logger.info(f"Successfully completed basic {operation_type} test with comprehensive metrics")
 
-        # Create appropriate consumer strategy
-        strategy = self.create_consumer(consumer_type)
+    # =========== Functional tests ===========
 
-        # Assign metrics collector to strategy
-        strategy.metrics = metrics
+    def test_async_consumer_joins_and_leaves_rebalance(self):
+        """Test rebalancing when consumer joins and then leaves the group"""
 
-        self.logger.info(f"Testing {consumer_type} consumer polling (single messages) for {test_duration} seconds")
+        async def async_rebalance_test():
+            topic_name = f"test-rebalance-{uuid.uuid4()}"
+            group_id = f"rebalance-group-{uuid.uuid4()}"  # Shared group ID
 
-        # Start metrics collection
-        metrics.start()
+            # Setup
+            self._setup_topic_with_messages(topic_name, partitions=2, messages=10)
 
-        # Container for consumed messages
-        consumed_messages = []
+            # Create consumers with shared group ID
+            consumer1 = self.create_consumer("async", group_id)
+            consumer2 = self.create_consumer("async", group_id)
 
-        # Run the test using poll_messages instead of consume_messages
-        start_time = time.time()
-        messages_consumed = strategy.poll_messages(
-            topic_name, test_duration, start_time, consumed_messages, timeout=0.1
-        )
+            # Track rebalance events
+            rebalance_events = []
 
-        # Finalize metrics collection
-        metrics.finalize()
+            async def track_rebalance(consumer, partitions):
+                rebalance_events.append(len(partitions))
+                await consumer.assign(partitions)
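+
+            # NOTE: the async on_assign callback is assumed to complete the
+            # assignment itself, hence the explicit consumer.assign() above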
-        # Get comprehensive metrics summary
-        metrics_summary = metrics.get_summary()
-        is_valid, violations = validate_consumer_metrics(metrics_summary, bounds)
-
-        # Print comprehensive metrics report
-        print_consumer_metrics_report(metrics_summary, is_valid, violations, consumer_type, 1)
-
-        # Enhanced assertions using metrics
-        assert messages_consumed > 0, "No messages were consumed"
-        assert len(consumed_messages) > 0, "No messages were collected"
-        assert metrics_summary['messages_consumed'] > 0, "No messages were consumed (metrics)"
-        assert metrics_summary['consumption_rate_msg_per_sec'] > 0, \
-            f"Consumption rate too low: {metrics_summary['consumption_rate_msg_per_sec']:.2f} msg/s"
+            try:
+                # Phase 1: Consumer1 joins (should get all partitions)
+                await consumer1.subscribe([topic_name], on_assign=track_rebalance)
+                await self._wait_for_assignment(consumer1, expected_partitions=2)
+                assert len(rebalance_events) == 1
 
-        # Validate against performance bounds
-        if not is_valid:
-            self.logger.warning("Performance bounds validation failed: %s", "; ".join(violations))
+                # Phase 2: Consumer2 joins (should split partitions)
+                await consumer2.subscribe([topic_name], on_assign=track_rebalance)
+                await self._wait_for_balanced_assignment([consumer1, consumer2], total_partitions=2)
+                assert len(rebalance_events) >= 2
 
-        self.logger.info("Successfully completed basic poll test with comprehensive metrics")
+                # Phase 3: Consumer2 leaves (consumer1 should get all partitions back)
+                await consumer2.close()
+                await self._wait_for_assignment(consumer1, expected_partitions=2)
+                assert len(rebalance_events) >= 3
+
+                # Verify functionality
+                self.produce_test_messages(topic_name, num_messages=1)
+                msg = await consumer1.poll(timeout=5.0)
+                assert msg is not None, "Consumer should receive fresh message"
+
+            finally:
+                await consumer1.close()
+                # consumer2 is closed in Phase 3 on success; also close it here so a
+                # failed assertion cannot leak it (ignore errors from double-closing)
+                try:
+                    await consumer2.close()
+                except Exception:
+                    pass
+
+        asyncio.run(async_rebalance_test())
+
+    def test_async_topic_partition_changes_rebalance(self):
+        """Test rebalancing when partitions are added to existing topic"""
+
+        async def async_topic_change_test():
+            topic_name = f"test-topic-changes-{uuid.uuid4()}"
+            group_id = f"topic-changes-group-{uuid.uuid4()}"  # Shared group ID
+
+            # Setup: Create topic with 2 partitions initially
+            self.kafka.create_topic(topic_name, partitions=2, replication_factor=1)
+            topic_ready = self.kafka.wait_for_topic(topic_name, max_wait_time=30)
+            assert topic_ready, f"Topic {topic_name} was not created"
+            self.produce_test_messages(topic_name, num_messages=10)
+
+            # Create consumers with shared group ID
+            consumer1 = self.create_consumer("async", group_id)
+            consumer2 = self.create_consumer("async", group_id)
+
+            # Track rebalance events
+            rebalance_events = []
+
+            async def track_rebalance(consumer, partitions):
+                rebalance_events.append(len(partitions))
+                await consumer.assign(partitions)
+
+            # Both consumers join - should get 1 partition each (2 total)
+            await consumer1.subscribe([topic_name], on_assign=track_rebalance)
+            await consumer2.subscribe([topic_name], on_assign=track_rebalance)
+
+            # Wait for initial rebalance
+            for _ in range(10):
+                await consumer1.poll(timeout=1.0)
+                await consumer2.poll(timeout=1.0)
+
+                assignment1 = await consumer1.assignment()
+                assignment2 = await consumer2.assignment()
+
+                if len(assignment1) > 0 and len(assignment2) > 0:
+                    break
+                await asyncio.sleep(1.0)
+
+            # Verify initial state: 2 partitions total, 1 each
+            assignment1_initial = await consumer1.assignment()
+            assignment2_initial = await consumer2.assignment()
+            total_partitions_initial = len(assignment1_initial) + len(assignment2_initial)
+
+            assert total_partitions_initial == 2, \
+                f"Should have 2 total partitions initially, got {total_partitions_initial}"
+            assert len(rebalance_events) >= 2, \
+                f"Should have at least 2 rebalance events, got {len(rebalance_events)}"
+
+            # Add partitions to existing topic (2 -> 4 partitions)
+            self.kafka.add_partitions(topic_name, new_partition_count=4)
+
+            # Produce messages to new partitions to trigger metadata refresh
+            self.produce_test_messages(topic_name, num_messages=5)
+
+            # Force rebalance by creating a new consumer that joins the group
+            # This will trigger metadata refresh and rebalancing for all consumers
+            consumer3 = self.create_consumer("async", group_id)
+            await consumer3.subscribe([topic_name], on_assign=track_rebalance)
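+
+            # librdkafka delivers rebalance callbacks from inside poll(), so every
+            # consumer must keep polling before a new assignment can take effect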
+            # Poll all consumers until they detect new partitions and rebalance
+            consumers = [consumer1, consumer2, consumer3]
+            for _ in range(30):
+                # Poll all consumers concurrently
+                await asyncio.gather(*[c.poll(timeout=1.0) for c in consumers])
+
+                # Check total partitions across all consumers
+                assignments = await asyncio.gather(*[c.assignment() for c in consumers])
+                total_partitions_current = sum(len(assignment) for assignment in assignments)
+
+                # Rebalance complete when total partitions = 4 (distributed among 3 consumers)
+                if total_partitions_current == 4:
+                    break
+                await asyncio.sleep(0.5)
+
+            # Verify final state: 4 partitions total distributed among 3 consumers
+            assignment1_final = await consumer1.assignment()
+            assignment2_final = await consumer2.assignment()
+            assignment3_final = await consumer3.assignment()
+            total_partitions_final = (len(assignment1_final) +
+                                      len(assignment2_final) +
+                                      len(assignment3_final))
+
+            assert total_partitions_final == 4, \
+                f"Should have 4 total partitions after adding, got {total_partitions_final}"
+            # With 3 consumers and 4 partitions, distribution should be roughly 1-2 partitions per consumer
+            assert len(assignment1_final) >= 1, \
+                f"Consumer 1 should have at least 1 partition, got {len(assignment1_final)}"
+            assert len(assignment2_final) >= 1, \
+                f"Consumer 2 should have at least 1 partition, got {len(assignment2_final)}"
+            assert len(assignment3_final) >= 1, \
+                f"Consumer 3 should have at least 1 partition, got {len(assignment3_final)}"
+            assert len(rebalance_events) >= 5, \
+                ("Should have at least 5 rebalance events after partition addition and consumer3 join, "
+                 f"got {len(rebalance_events)}")
+
+            # Verify consumers can still consume from all partitions
+            msg1 = await consumer1.poll(timeout=5.0)
+            msg2 = await consumer2.poll(timeout=5.0)
+            msg3 = await consumer3.poll(timeout=5.0)
+            messages_received = sum([1 for msg in [msg1, msg2, msg3] if msg is not None])
+            assert messages_received > 0, "Consumers should receive messages from new partitions"
+
+            # Clean up
+            await consumer1.close()
+            await consumer2.close()
+            await consumer3.close()
+
+        asyncio.run(async_topic_change_test())
+
+    def test_async_callback_exception_behavior(self):
+        """Test current behavior: callback exceptions propagate and fail the consumer"""
+
+        async def async_callback_test():
+            topic_name = f"test-callback-exception-{uuid.uuid4()}"
+            group_id = f"callback-exception-group-{uuid.uuid4()}"
+            # Setup
+            self._setup_topic_with_messages(topic_name, partitions=2, messages=10)
+            consumer = self.create_consumer("async", group_id)
+
+            # Track callback calls and create failing callback
+            callback_calls = []
+
+            async def failing_callback(consumer_obj, partitions):
+                callback_calls.append("called")
+                raise ValueError("Simulated callback failure")
+
+            try:
+                # Subscribe with failing callback
+                await consumer.subscribe([topic_name], on_assign=failing_callback)
+
+                # Current behavior: callback exception should propagate and crash poll()
+                with pytest.raises(ValueError, match="Simulated callback failure"):
+                    await consumer.poll(timeout=10.0)
+
+                # Verify callback was called before the crash
+                assert len(callback_calls) == 1, "Callback should have been called before crash"
+
+            finally:
+                # Consumer may be in an unusable state after the exception
+                try:
+                    await consumer.close()
+                except Exception:
+                    pass  # Ignore cleanup errors after crash
+
+        asyncio.run(async_callback_test())
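+
+    # =========== Helpers ===========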
+
+    def _setup_topic_with_messages(self, topic_name, partitions=2, messages=10):
+        """Helper: Create topic and produce test messages"""
+        self.kafka.create_topic(topic_name, partitions=partitions, replication_factor=1)
+        assert self.kafka.wait_for_topic(topic_name, max_wait_time=30)
+        self.produce_test_messages(topic_name, num_messages=messages)
+
+    async def _wait_for_assignment(self, consumer, expected_partitions, max_wait=15):
+        """Helper: Poll until the consumer holds the expected partition count (up to ~2*max_wait seconds)"""
+        for _ in range(max_wait):
+            await consumer.poll(timeout=1.0)
+            assignment = await consumer.assignment()
+            if len(assignment) == expected_partitions:
+                return
+            await asyncio.sleep(1.0)
+
+        assignment = await consumer.assignment()
+        assert len(assignment) == expected_partitions, \
+            f"Expected {expected_partitions} partitions, got {len(assignment)}"
+
+    async def _wait_for_balanced_assignment(self, consumers, total_partitions, max_wait=15):
+        """Helper: Wait until the consumers split all partitions between them (up to ~2*max_wait seconds)"""
+        for _ in range(max_wait):
+            for consumer in consumers:
+                await consumer.poll(timeout=1.0)
+
+            assignments = [await c.assignment() for c in consumers]
+            assigned_count = sum(len(a) for a in assignments)
+
+            if assigned_count == total_partitions and all(len(a) > 0 for a in assignments):
+                return
+            await asyncio.sleep(1.0)
+
+        assignments = [await c.assignment() for c in consumers]
+        assigned_count = sum(len(a) for a in assignments)
+        assert assigned_count == total_partitions, \
+            f"Expected {total_partitions} total partitions, got {assigned_count}"
+        assert all(len(a) > 0 for a in assignments), \
+            "Every consumer should own at least one partition"
 
     def teardown(self):
         """Clean up test environment"""
diff --git a/tests/test_AIOConsumer.py b/tests/test_AIOConsumer.py
index fb6337d29..bd4753b49 100644
--- a/tests/test_AIOConsumer.py
+++ b/tests/test_AIOConsumer.py
@@ -5,12 +5,12 @@
 import concurrent.futures
 from unittest.mock import Mock, patch
 
-from confluent_kafka import TopicPartition
+from confluent_kafka import TopicPartition, KafkaError, KafkaException
 from confluent_kafka.aio._AIOConsumer import AIOConsumer
 
 
 class TestAIOConsumer:
-    """Unit tests for AIOConsumer class - focusing on async wrapper logic."""
+    """Unit tests for AIOConsumer class."""
 
     @pytest.fixture
     def mock_consumer(self):
@@ -124,3 +124,48 @@ async def test_multiple_concurrent_operations(self, mock_consumer, mock_common,
         results = await asyncio.gather(*tasks)
         assert len(results) == 3
         assert all(result is not None for result in results)
+
+    @pytest.mark.asyncio
+    async def test_concurrent_operations_error_handling(self, mock_consumer, mock_common, basic_config):
+        """Test concurrent async operations handle errors gracefully."""
+        # Mock: 2 poll calls fail, assignment succeeds
+        mock_consumer.return_value.poll.side_effect = [
+            KafkaException(KafkaError(KafkaError._TRANSPORT)),
+            KafkaException(KafkaError(KafkaError._TRANSPORT))
+        ]
+        mock_consumer.return_value.assignment.return_value = []
+
+        consumer = AIOConsumer(basic_config)
+
+        # Run concurrent operations
+        tasks = [
+            consumer.poll(timeout=0.1),
+            consumer.poll(timeout=0.1),
+            consumer.assignment()
+        ]
+
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Verify results
+        assert len(results) == 3
+        assert isinstance(results[0], KafkaException)
+        assert isinstance(results[1], KafkaException)
+        assert results[2] == []
+
+    @pytest.mark.asyncio
+    async def test_network_error_handling(self, mock_consumer, mock_common, basic_config):
+        """Test AIOConsumer handles network errors gracefully."""
+        mock_consumer.return_value.poll.side_effect = KafkaException(
+            KafkaError(KafkaError._TRANSPORT, "Network timeout")
+        )
+
+        consumer = AIOConsumer(basic_config)
+
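+        # The wrapped Consumer.poll runs on AIOConsumer's executor thread; the
+        # KafkaException raised there should surface through the awaited call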
+        with pytest.raises(KafkaException) as exc_info:
+            await consumer.poll(timeout=1.0)
+
+        assert exc_info.value.args[0].code() == KafkaError._TRANSPORT
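+        # KafkaError._TRANSPORT is librdkafka's local transport-failure code;
+        # asserting on the code rather than the message keeps the test version-stable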