diff --git a/tests/ducktape/consumer_strategy.py b/tests/ducktape/consumer_strategy.py
index 5f43b868e..85d6e66e2 100644
--- a/tests/ducktape/consumer_strategy.py
+++ b/tests/ducktape/consumer_strategy.py
@@ -38,7 +38,6 @@ def create_consumer(self):
             'enable.auto.commit': 'true',
             'auto.commit.interval.ms': '5000'
         }
 
-        consumer = Consumer(config)
         return consumer
 
@@ -172,6 +171,7 @@ def create_consumer(self):
         }
 
         self._consumer_instance = AIOConsumer(config, max_workers=20)
+        return self._consumer_instance
 
     def get_final_metrics(self):
diff --git a/tests/ducktape/run_ducktape_test.py b/tests/ducktape/run_ducktape_test.py
index 8ba5f06b6..e3ae047cc 100755
--- a/tests/ducktape/run_ducktape_test.py
+++ b/tests/ducktape/run_ducktape_test.py
@@ -26,6 +26,10 @@ def get_test_info(test_type):
         'producer_sr': {
             'file': 'test_producer_with_schema_registry.py',
             'description': 'Producer with Schema Registry Tests'
+        },
+        'transactions': {
+            'file': 'test_transactions.py',
+            'description': 'Transactional Producer and Consumer Tests'
         }
     }
     return test_info.get(test_type)
@@ -131,21 +135,21 @@ def run_all_tests(args):
     """Run all available test types"""
     test_types = ['producer', 'consumer', 'producer_sr']
     overall_success = True
-    
+
     print("Confluent Kafka Python - All Ducktape Tests")
     print(f"Timestamp: {datetime.now().isoformat()}")
     print("=" * 70)
-    
+
     for test_type in test_types:
         print(f"\n{'='*20} Running {test_type.upper()} Tests {'='*20}")
-        
+
         # Create a new args object for this test type
         test_args = argparse.Namespace(
             test_type=test_type,
             test_method=args.test_method,
             debug=args.debug
         )
-        
+
         # Run the specific test type
         result = run_single_test_type(test_args)
         if result != 0:
@@ -153,7 +157,7 @@ def run_all_tests(args):
             print(f"\nāŒ {test_type.upper()} tests failed!")
         else:
             print(f"\nāœ… {test_type.upper()} tests passed!")
-    
+
     print(f"\n{'='*70}")
     if overall_success:
         print("šŸŽ‰ All tests completed successfully!")
@@ -166,7 +170,7 @@
 def main():
     """Run the ducktape test based on specified type"""
     parser = argparse.ArgumentParser(description="Confluent Kafka Python - Ducktape Test Runner")
-    parser.add_argument('test_type', nargs='?', choices=['producer', 'consumer', 'producer_sr'],
+    parser.add_argument('test_type', nargs='?', choices=['producer', 'consumer', 'producer_sr', 'transactions'],
                         help='Type of test to run (default: run all tests)')
     parser.add_argument('test_method', nargs='?',
                         help='Specific test method to run (optional)')
@@ -184,4 +188,4 @@ def main():
 
 
 if __name__ == "__main__":
-    sys.exit(main())
\ No newline at end of file
+    sys.exit(main())
diff --git a/tests/ducktape/test_transactions.py b/tests/ducktape/test_transactions.py
new file mode 100644
index 000000000..4ebfaaea6
--- /dev/null
+++ b/tests/ducktape/test_transactions.py
@@ -0,0 +1,263 @@
+"""
+Ducktape tests for transactions in async producer and consumer.
+"""
+import uuid
+import asyncio
+from ducktape.tests.test import Test
+
+from confluent_kafka.aio._AIOProducer import AIOProducer
+from tests.ducktape.consumer_strategy import AsyncConsumerStrategy, SyncConsumerStrategy
+from tests.ducktape.services.kafka import KafkaClient
+from confluent_kafka import Producer
+
+
+class TransactionsTest(Test):
+    def __init__(self, test_context):
+        super(TransactionsTest, self).__init__(test_context=test_context)
+        self.kafka = KafkaClient(test_context, bootstrap_servers="localhost:9092")
+
+    def setUp(self):
+        if not self.kafka.verify_connection():
+            raise Exception("Cannot connect to Kafka at localhost:9092. "
+                            "Please ensure Kafka is running.")
+
+    def _new_topic(self, partitions: int = 1) -> str:
+        topic = f"tx-{uuid.uuid4()}"
+        self.kafka.create_topic(topic, partitions=partitions, replication_factor=1)
+        assert self.kafka.wait_for_topic(topic, max_wait_time=30)
+        return topic
+
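+    # Transactional producers need a unique, stable 'transactional.id';
+    # librdkafka requires idempotence for transactions (it is implied by
+    # setting 'transactional.id'), and init_transactions() must complete
+    # before the first begin_transaction().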
Please ensure Kafka is running.") + + def _new_topic(self, partitions: int = 1) -> str: + topic = f"tx-{uuid.uuid4()}" + self.kafka.create_topic(topic, partitions=partitions, replication_factor=1) + assert self.kafka.wait_for_topic(topic, max_wait_time=30) + return topic + + def _create_transactional_producer(self, producer_type): + """Create transactional producer based on type""" + # TODO: use producer_strategy to simplify code once it's merged + if producer_type == "sync": + config = { + 'bootstrap.servers': self.kafka.bootstrap_servers(), + 'batch.size': 65536, + 'linger.ms': 1, + 'compression.type': 'lz4', + 'transactional.id': f'sync-tx-producer-{uuid.uuid4()}', + 'acks': 'all', + 'enable.idempotence': True + } + return Producer(config) + else: # async + config = { + 'bootstrap.servers': self.kafka.bootstrap_servers(), + 'batch.size': 65536, + 'linger.ms': 1, + 'compression.type': 'lz4', + 'transactional.id': f'async-tx-producer-{uuid.uuid4()}', + 'acks': 'all', + 'enable.idempotence': True, + } + return AIOProducer(config, max_workers=10) + + def _create_transactional_consumer(self, consumer_type, group_id=None): + """Create read_committed consumer based on type using strategy pattern""" + if group_id is None: + group_id = f'tx-consumer-{uuid.uuid4()}' + + # Overrides for transactional consumers + overrides = { + 'isolation.level': 'read_committed', + 'enable.auto.commit': False, + 'auto.offset.reset': 'earliest', + } + + if consumer_type == "sync": + strategy = SyncConsumerStrategy(self.kafka.bootstrap_servers(), group_id, self.logger) + return strategy.create_consumer(config_overrides=overrides) + else: # async + strategy = AsyncConsumerStrategy(self.kafka.bootstrap_servers(), group_id, self.logger) + return strategy.create_consumer(config_overrides=overrides) + + async def _transactional_produce(self, producer, topic: str, values: list[str], partition: int | None = None): + """Produce values within an active transaction, polling to drive delivery.""" + for v in values: + kwargs = {'topic': topic, 'value': v} + if partition is not None: + kwargs['partition'] = partition + await producer.produce(**kwargs) + await producer.poll(0.0) + + async def _seed_topic(self, topic: str, values: list[str]): + """Seed a topic using transactional producer.""" + producer = self._create_transactional_producer("sync") + producer.init_transactions() + producer.begin_transaction() + for v in values: + producer.produce(topic, value=v) + producer.poll(0.0) + producer.commit_transaction() + producer.flush() + + # =========== Functional tests (async) =========== + + def test_commit_transaction(self): + """Committed transactional messages must be visible to read_committed consumer.""" + async def run(): + topic = self._new_topic() + producer = self._create_transactional_producer("async") + consumer = self._create_transactional_consumer("async") + try: + await producer.init_transactions() + await producer.begin_transaction() + await self._transactional_produce(producer, topic, [f'c{i}' for i in range(5)]) + await producer.commit_transaction() + + await consumer.subscribe([topic]) + seen = [] + for _ in range(20): + msg = await consumer.poll(timeout=1.0) + if not msg or msg.error(): + continue + seen.append(msg.value().decode('utf-8')) + if len(seen) >= 5: + break + assert len(seen) == 5, f"expected 5 committed messages, got {len(seen)}" + finally: + await consumer.close() + asyncio.run(run()) + + def test_abort_transaction_then_retry_commit(self): + """Aborted messages must be invisible, and retrying with 
+    def test_abort_transaction_then_retry_commit(self):
+        """Aborted messages must be invisible; retrying in a new
+        transaction must make only the retried values visible."""
+        async def run():
+            topic = self._new_topic()
+            producer = self._create_transactional_producer("async")
+            consumer = self._create_transactional_consumer("async")
+            try:
+                await producer.init_transactions()
+
+                # Abort case
+                await producer.begin_transaction()
+                await self._transactional_produce(producer, topic, [f'a{i}' for i in range(3)])
+                await producer.abort_transaction()
+
+                await consumer.subscribe([topic])
+                aborted_seen = []
+                for _ in range(10):
+                    msg = await consumer.poll(timeout=1.0)
+                    if not msg or msg.error():
+                        continue
+                    val = msg.value().decode('utf-8')
+                    if val.startswith('a'):
+                        aborted_seen.append(val)
+                assert not aborted_seen, f"aborted messages should be invisible, saw {aborted_seen}"
+
+                # Retry-commit flow
+                await producer.begin_transaction()
+                retry_vals = [f'r{i}' for i in range(3)]
+                await self._transactional_produce(producer, topic, retry_vals)
+                await producer.commit_transaction()
+
+                # Verify only the retry values appear
+                seen = []
+                for _ in range(20):
+                    msg = await consumer.poll(timeout=1.0)
+                    if not msg or msg.error():
+                        continue
+                    val = msg.value().decode('utf-8')
+                    seen.append(val)
+                    if all(rv in seen for rv in retry_vals):
+                        break
+                assert all(rv in seen for rv in retry_vals), f"expected retry values {retry_vals}, saw {seen}"
+                assert all(not s.startswith('a') for s in seen), f"should not see aborted values, saw {seen}"
+            finally:
+                await consumer.close()
+        asyncio.run(run())
+
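+    # Consume-transform-produce (exactly-once) pattern: committing the
+    # consumer's offsets through the producer's transaction makes the output
+    # records and the input offsets commit or abort together.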
+    def test_send_offsets_to_transaction(self):
+        """Offsets are committed atomically with produced records using send_offsets_to_transaction."""
+        async def run():
+            input_topic = self._new_topic()
+            output_topic = self._new_topic()
+
+            # Seed input
+            input_vals = [f'in{i}' for i in range(5)]
+            await self._seed_topic(input_topic, input_vals)
+
+            producer = self._create_transactional_producer("async")
+            consumer = self._create_transactional_consumer("async")
+            try:
+                await consumer.subscribe([input_topic])
+
+                # Consume a small batch from input
+                consumed = []
+                for _ in range(20):
+                    msg = await consumer.poll(timeout=1.0)
+                    if not msg or msg.error():
+                        continue
+                    consumed.append(msg)
+                    if len(consumed) >= 3:
+                        break
+                assert consumed, "expected to consume at least 1 message from input"
+
+                # Begin transaction: produce results and commit consumer offsets atomically
+                await producer.init_transactions()
+                await producer.begin_transaction()
+
+                out_vals = [f'out:{m.value().decode("utf-8")}' for m in consumed]
+                await self._transactional_produce(producer, output_topic, out_vals)
+
+                assignment = await consumer.assignment()
+                positions = await consumer.position(assignment)
+                group_metadata = await consumer.consumer_group_metadata()
+
+                await producer.send_offsets_to_transaction(positions, group_metadata)
+                await producer.commit_transaction()
+
+                # Verify the output topic has the results
+                out_consumer = self._create_transactional_consumer("async")
+                try:
+                    await out_consumer.subscribe([output_topic])
+                    seen = []
+                    for _ in range(20):
+                        msg = await out_consumer.poll(timeout=1.0)
+                        if not msg or msg.error():
+                            continue
+                        seen.append(msg.value().decode('utf-8'))
+                        if len(seen) >= len(out_vals):
+                            break
+                    assert set(seen) == set(out_vals), f"expected {out_vals}, saw {seen}"
+                finally:
+                    await out_consumer.close()
+
+                # Verify committed offsets advanced to the current positions
+                committed = await consumer.committed(assignment)
+                for pos, comm in zip(positions, committed):
+                    assert comm.offset >= pos.offset, f"committed {comm.offset} < position {pos.offset}"
+            finally:
+                await consumer.close()
+        asyncio.run(run())
+
+    def test_commit_multiple_topics_partitions(self):
+        """A single transaction must commit atomically across multiple topics and partitions."""
+        async def run():
+            topic_a = self._new_topic(partitions=2)
+            topic_b = self._new_topic(partitions=1)
+
+            producer = self._create_transactional_producer("async")
+            consumer = self._create_transactional_consumer("async")
+            try:
+                await producer.init_transactions()
+                await producer.begin_transaction()
+                # Produce to both partitions of topic A, and to topic B
+                await self._transactional_produce(producer, topic_a, ["a0-p0", "a1-p0"], partition=0)
+                await self._transactional_produce(producer, topic_a, ["a0-p1", "a1-p1"], partition=1)
+                await self._transactional_produce(producer, topic_b, ["b0", "b1"])  # default partitioner
+                await producer.commit_transaction()
+
+                await consumer.subscribe([topic_a, topic_b])
+                expected = {"a0-p0", "a1-p0", "a0-p1", "a1-p1", "b0", "b1"}
+                seen = set()
+                for _ in range(30):
+                    msg = await consumer.poll(timeout=1.0)
+                    if not msg or msg.error():
+                        continue
+                    seen.add(msg.value().decode('utf-8'))
+                    if seen == expected:
+                        break
+                assert seen == expected, f"expected {expected}, saw {seen}"
+            finally:
+                await consumer.close()
+        asyncio.run(run())
diff --git a/tests/test_AIOConsumer.py b/tests/test_AIOConsumer.py
index 322ae5a50..443de0c84 100644
--- a/tests/test_AIOConsumer.py
+++ b/tests/test_AIOConsumer.py
@@ -149,7 +149,7 @@ async def test_concurrent_operations_error_handling(self, mock_consumer, mock_co
         assert len(results) == 2
         assert isinstance(results[0], KafkaException)
         assert isinstance(results[1], KafkaException)
-    
+
     @pytest.mark.asyncio
     async def test_network_error_handling(self, mock_consumer, mock_common, basic_config):
         """Test AIOConsumer handles network errors gracefully."""