diff --git a/omniqueue/src/backends/sqs.rs b/omniqueue/src/backends/sqs.rs
index f78dc85..79905b8 100644
--- a/omniqueue/src/backends/sqs.rs
+++ b/omniqueue/src/backends/sqs.rs
@@ -32,6 +32,25 @@ pub struct SqsConfig {
 
     /// Whether to override the AWS endpoint URL with the queue DSN.
     pub override_endpoint: bool,
+
+    /// Message system attributes to request when receiving messages.
+    /// If not specified, no attributes will be requested.
+    pub message_attribute_names: Vec<aws_sdk_sqs::types::MessageSystemAttributeName>,
+
+    /// Optional dead-letter queue configuration for filter failures.
+    /// When a message fails the filter `max_filter_failures` times, it will be sent to the DLQ.
+    /// This is separate from SQS's native redrive policy, which handles processing failures.
+    pub dlq_config: Option<DeadLetterQueueConfig>,
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct DeadLetterQueueConfig {
+    /// The URL of the dead-letter queue.
+    pub queue_url: String,
+
+    /// The maximum number of times a message can fail the filter
+    /// before being moved to the dead-letter queue.
+    pub max_filter_failures: usize,
 }
 
 #[derive(Clone, Debug)]
@@ -39,6 +58,8 @@ pub struct SqsConfigFull {
     queue_dsn: String,
     override_endpoint: bool,
     sqs_config: Option<aws_sdk_sqs::Config>,
+    message_attribute_names: Vec<aws_sdk_sqs::types::MessageSystemAttributeName>,
+    dlq_config: Option<DeadLetterQueueConfig>,
 }
 
 impl SqsConfigFull {
@@ -71,11 +92,15 @@ impl From for SqsConfigFull {
         let SqsConfig {
             queue_dsn,
             override_endpoint,
+            message_attribute_names,
+            dlq_config,
         } = cfg;
         Self {
             queue_dsn,
             override_endpoint,
             sqs_config: None,
+            message_attribute_names,
+            dlq_config,
         }
     }
 }
@@ -92,6 +117,8 @@ impl From for SqsConfigFull {
             queue_dsn: dsn,
             override_endpoint: false,
             sqs_config: None,
+            message_attribute_names: Vec::new(),
+            dlq_config: None,
         }
     }
 }
@@ -145,6 +172,8 @@ impl QueueBackend for SqsBackend {
         let consumer = SqsConsumer {
             client,
             queue_dsn: cfg.queue_dsn,
+            message_attribute_names: cfg.message_attribute_names,
+            dlq_config: cfg.dlq_config,
         };
 
         Ok((producer, consumer))
@@ -169,6 +198,8 @@ impl QueueBackend for SqsBackend {
         let consumer = SqsConsumer {
             client,
             queue_dsn: cfg.queue_dsn,
+            message_attribute_names: cfg.message_attribute_names,
+            dlq_config: cfg.dlq_config,
         };
 
         Ok(consumer)
@@ -408,10 +439,12 @@ impl crate::ScheduledQueueProducer for SqsProducer {
 pub struct SqsConsumer {
     client: Client,
     queue_dsn: String,
+    message_attribute_names: Vec<aws_sdk_sqs::types::MessageSystemAttributeName>,
+    dlq_config: Option<DeadLetterQueueConfig>,
 }
 
 impl SqsConsumer {
-    fn wrap_message(&self, message: &Message) -> Delivery {
+    pub fn wrap_message(&self, message: &Message) -> Delivery {
         Delivery::new(
             message.body().unwrap_or_default().as_bytes().to_owned(),
             SqsAcker {
@@ -423,12 +456,178 @@ impl SqsConsumer {
         )
     }
 
+    /// Handles a message that was rejected by a consumer-side filter.
+    ///
+    /// Without a DLQ configuration the message is made visible again immediately. With one,
+    /// the message is re-sent to the main queue with an incremented `FilterFailCount`
+    /// attribute, or moved to the DLQ once it has failed the filter `max_filter_failures`
+    /// times.
+    ///
+    /// The original message attributes are carried over in both cases, so filters keep seeing
+    /// them on later attempts. The original message is only deleted after its replacement has
+    /// been accepted; if that send fails, the message is left alone and becomes visible again
+    /// once its visibility timeout expires. Errors are otherwise ignored (best effort).
+    async fn handle_filter_rejection(&self, message: &Message) {
+        let receipt_handle = match message.receipt_handle() {
+            Some(receipt_handle) => receipt_handle,
+            // Nothing can be done with a message that has no receipt handle.
+            None => return,
+        };
+
+        let (dlq_config, body) = match (&self.dlq_config, message.body()) {
+            (Some(dlq_config), Some(body)) => (dlq_config, body),
+            _ => {
+                // No DLQ config - just return to queue
+                let _ = self
+                    .client
+                    .change_message_visibility()
+                    .queue_url(&self.queue_dsn)
+                    .receipt_handle(receipt_handle)
+                    .visibility_timeout(0)
+                    .send()
+                    .boxed()
+                    .await;
+                return;
+            }
+        };
+
+        // Get current filter failure count from message attributes
+        let filter_fail_count = message
+            .message_attributes()
+            .and_then(|attrs| attrs.get("FilterFailCount"))
+            .and_then(|attr| attr.string_value())
+            .and_then(|v| v.parse::<usize>().ok())
+            .unwrap_or(0);
+
+        let new_count = filter_fail_count.saturating_add(1);
+
+        // Start from the original attributes so that they survive the round trip.
+        let mut attributes = message.message_attributes().cloned().unwrap_or_default();
+
+        let target_queue_url = if new_count >= dlq_config.max_filter_failures {
+            // Send to DLQ (without FilterFailCount attribute)
+            attributes.remove("FilterFailCount");
+            &dlq_config.queue_url
+        } else {
+            // Re-send with incremented filter fail count
+            attributes.insert(
+                "FilterFailCount".to_string(),
+                aws_sdk_sqs::types::MessageAttributeValue::builder()
+                    .data_type("Number")
+                    .string_value(new_count.to_string())
+                    .build()
+                    .expect("data_type is set"),
+            );
+            &self.queue_dsn
+        };
+
+        let sent = self
+            .client
+            .send_message()
+            .queue_url(target_queue_url)
+            .message_body(body)
+            .set_message_attributes(if attributes.is_empty() {
+                None
+            } else {
+                Some(attributes)
+            })
+            .send()
+            .boxed()
+            .await
+            .is_ok();
+
+        if sent {
+            // Delete the original only now that its replacement is safely enqueued
+            let _ = self
+                .client
+                .delete_message()
+                .queue_url(&self.queue_dsn)
+                .receipt_handle(receipt_handle)
+                .send()
+                .boxed()
+                .await;
+        }
+    }
+
+    /// Receives a single message and checks if it matches the given filter predicate.
+    ///
+    /// This will receive one message from SQS. If the message matches the filter,
+    /// it is returned. If it doesn't match, it is nacked (with DLQ handling if configured)
+    /// and a `NoData` error is returned.
+    ///
+    /// # Arguments
+    /// * `filter` - A predicate function that receives the SQS Message and returns true if it
+    ///   should be consumed
+    /// * `timeout` - Maximum duration to wait for a message to arrive (capped at 20 seconds,
+    ///   the SQS long-polling maximum)
+    ///
+    /// # Example
+    /// ```no_run
+    /// # use omniqueue::backends::sqs::SqsConsumer;
+    /// # use std::time::Duration;
+    /// # async fn example(consumer: SqsConsumer) {
+    /// let delivery = consumer.receive_with_filter(
+    ///     |msg| {
+    ///         msg.message_attributes()
+    ///             .and_then(|attrs| attrs.get("version"))
+    ///             .and_then(|attr| attr.string_value())
+    ///             .map(|v| v == "2.0")
+    ///             .unwrap_or(false)
+    ///     },
+    ///     Duration::from_secs(30)
+    /// ).await;
+    /// # }
+    /// ```
+    pub async fn receive_with_filter<F>(&self, filter: F, timeout: Duration) -> Result<Delivery>
+    where
+        F: Fn(&Message) -> bool,
+    {
+        let wait_time = timeout.min(Duration::from_secs(20)); // SQS max wait time
+
+        let mut req = self
+            .client
+            .receive_message()
+            .set_max_number_of_messages(Some(1))
+            .set_wait_time_seconds(Some(
+                wait_time
+                    .as_secs()
+                    .try_into()
+                    .map_err(QueueError::generic)?,
+            ))
+            .queue_url(&self.queue_dsn);
+
+        for attr in &self.message_attribute_names {
+            req = req.message_system_attribute_names(attr.clone());
+        }
+
+        // Request all message attributes for filtering
+        req = req.message_attribute_names("All".to_string());
+
+        let out = req.send().boxed().await.map_err(aws_to_queue_error)?;
+
+        match out.messages().first() {
+            Some(message) if filter(message) => Ok(self.wrap_message(message)),
+            Some(message) => {
+                // Message doesn't match filter - track filter failures
+                self.handle_filter_rejection(message).await;
+                Err(QueueError::NoData)
+            }
+            None => Err(QueueError::NoData),
+        }
+    }
+
     pub async fn receive(&self) -> Result<Delivery> {
-        let out = self
+        let mut req = self
             .client
             .receive_message()
             .set_max_number_of_messages(Some(1))
-            .queue_url(&self.queue_dsn)
+            .queue_url(&self.queue_dsn);
+
+        for attr in &self.message_attribute_names {
+            req = req.message_system_attribute_names(attr.clone());
+        }
+
+        let out = req
             .send()
             // Segment the async state machine. send future is >5kb at the time of writing.
             .boxed()
@@ -447,14 +646,20 @@ impl SqsConsumer {
         max_messages: usize,
         deadline: Duration,
     ) -> Result<Vec<Delivery>> {
-        let out = self
+        let mut req = self
             .client
             .receive_message()
             .set_wait_time_seconds(Some(
                 deadline.as_secs().try_into().map_err(QueueError::generic)?,
             ))
             .set_max_number_of_messages(Some(max_messages.try_into().map_err(QueueError::generic)?))
-            .queue_url(&self.queue_dsn)
+            .queue_url(&self.queue_dsn);
+
+        for attr in &self.message_attribute_names {
+            req = req.message_system_attribute_names(attr.clone());
+        }
+
+        let out = req
             .send()
             // Segment the async state machine. send future is >5kb at the time of writing.
             .boxed()
@@ -466,6 +671,100 @@ impl SqsConsumer {
             .map(|message| -> Result<Delivery> { Ok(self.wrap_message(message)) })
             .collect::<Result<Vec<_>, _>>()
     }
+
+    /// Receives multiple messages that match the given filter predicate.
+    ///
+    /// This will poll SQS repeatedly until the requested number of matching messages
+    /// are found, the timeout is reached, or a poll returns no messages at all. Messages
+    /// that don't match the filter are nacked (with DLQ handling if configured).
+    ///
+    /// # Arguments
+    /// * `filter` - A predicate function that receives the SQS Message and returns true if it
+    ///   should be consumed
+    /// * `max_messages` - Maximum number of matching messages to return
+    /// * `timeout` - Maximum duration to keep polling for matching messages
+    ///
+    /// # Example
+    /// ```no_run
+    /// # use omniqueue::backends::sqs::SqsConsumer;
+    /// # use std::time::Duration;
+    /// # async fn example(consumer: SqsConsumer) {
+    /// let deliveries = consumer.receive_all_with_filter(
+    ///     |msg| {
+    ///         msg.message_attributes()
+    ///             .and_then(|attrs| attrs.get("priority"))
+    ///             .and_then(|attr| attr.string_value())
+    ///             .map(|v| v == "high")
+    ///             .unwrap_or(false)
+    ///     },
+    ///     10,
+    ///     Duration::from_secs(30)
+    /// ).await;
+    /// # }
+    /// ```
+    pub async fn receive_all_with_filter<F>(
+        &self,
+        filter: F,
+        max_messages: usize,
+        timeout: Duration,
+    ) -> Result<Vec<Delivery>>
+    where
+        F: Fn(&Message) -> bool,
+    {
+        let deadline = std::time::Instant::now() + timeout;
+        let mut results = Vec::new();
+
+        while results.len() < max_messages {
+            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
+            if remaining.is_zero() {
+                break;
+            }
+
+            let wait_time = remaining.min(Duration::from_secs(20)); // SQS max wait time
+            let batch_size = (max_messages - results.len()).min(10); // SQS max batch size
+
+            let mut req = self
+                .client
+                .receive_message()
+                .set_max_number_of_messages(Some(
+                    batch_size.try_into().map_err(QueueError::generic)?,
+                ))
+                .set_wait_time_seconds(Some(
+                    wait_time
+                        .as_secs()
+                        .try_into()
+                        .map_err(QueueError::generic)?,
+                ))
+                .queue_url(&self.queue_dsn);
+
+            for attr in &self.message_attribute_names {
+                req = req.message_system_attribute_names(attr.clone());
+            }
+
+            // Request all message attributes for filtering
+            req = req.message_attribute_names("All".to_string());
+
+            let out = req.send().boxed().await.map_err(aws_to_queue_error)?;
+
+            let messages = out.messages();
+            if messages.is_empty() {
+                // No more messages available
+                break;
+            }
+
+            for message in messages {
+                if filter(message) {
+                    // Message matches filter, add to results
+                    results.push(self.wrap_message(message));
+                } else {
+                    // Message doesn't match filter - track filter failures
+                    self.handle_filter_rejection(message).await;
+                }
+            }
+        }
+
+        Ok(results)
+    }
 }
 
 impl crate::QueueConsumer for SqsConsumer {
diff --git a/omniqueue/tests/it/sqs.rs b/omniqueue/tests/it/sqs.rs
index 642f450..be97bcf 100644
--- a/omniqueue/tests/it/sqs.rs
+++ b/omniqueue/tests/it/sqs.rs
@@ -42,6 +42,8 @@ async fn make_test_queue() -> QueueBuilder {
     let config = SqsConfig {
         queue_dsn: format!("{ROOT_URL}/queue/{queue_name}"),
         override_endpoint: true,
+        message_attribute_names: Vec::new(),
+        dlq_config: None,
     };
 
     SqsBackend::builder(config)
@@ -221,3 +223,303 @@ async fn test_scheduled() {
     assert!(now.elapsed() < delay * 2);
     assert_eq!(Some(payload1), delivery.payload_serde_json().unwrap());
 }
+
+#[tokio::test]
+async fn test_receive_with_filter() {
+    use aws_sdk_sqs::types::{MessageAttributeValue, MessageSystemAttributeName};
+
+    let config = aws_config::from_env().endpoint_url(ROOT_URL).load().await;
+    let client = Client::new(&config);
+
+    let queue_name: String = std::iter::repeat_with(fastrand::alphanumeric)
+        .take(8)
+        .collect();
+    client
+        .create_queue()
+        .queue_name(&queue_name)
+        .send()
+        .await
+        .unwrap();
+
+    let queue_url = format!("{ROOT_URL}/queue/{queue_name}");
+
+    // Send messages with different attributes
+    client
+        .send_message()
+        .queue_url(&queue_url)
+        .message_body("message1")
+        .message_attributes(
+            "version",
+            MessageAttributeValue::builder()
+                .data_type("String")
+                .string_value("1.0")
+                .build()
+                .unwrap(),
+        )
+        .send()
+        .await
+        .unwrap();
+
+    client
+        .send_message()
+        .queue_url(&queue_url)
+        .message_body("message2")
+        .message_attributes(
+            "version",
+            MessageAttributeValue::builder()
+                .data_type("String")
+                .string_value("2.0")
+                .build()
+                .unwrap(),
+        )
+        .send()
+        .await
+        .unwrap();
+
+    client
+        .send_message()
+        .queue_url(&queue_url)
+        .message_body("message3")
+        .message_attributes(
+            "version",
+            MessageAttributeValue::builder()
+                .data_type("String")
+                .string_value("2.0")
+                .build()
+                .unwrap(),
+        )
+        .send()
+        .await
+        .unwrap();
+
+    let config = omniqueue::backends::sqs::SqsConfig {
+        queue_dsn: queue_url,
+        override_endpoint: true,
+        message_attribute_names: vec![MessageSystemAttributeName::ApproximateReceiveCount],
+        dlq_config: None,
+    };
+
+    let (_p, c) = omniqueue::backends::sqs::SqsBackend::builder(config)
+        .build_pair()
+        .await
+        .unwrap();
+
+    // Filter for version = "2.0"
+    let delivery = c
+        .receive_with_filter(
+            |msg| {
+                msg.message_attributes()
+                    .and_then(|attrs| attrs.get("version"))
+                    .and_then(|attr| attr.string_value())
+                    .map(|v| v == "2.0")
+                    .unwrap_or(false)
+            },
+            Duration::from_secs(10),
+        )
+        .await
+        .unwrap();
+
+    let body = String::from_utf8(delivery.borrow_payload().unwrap().to_vec()).unwrap();
+    assert!(body == "message2" || body == "message3");
+    delivery.ack().await.unwrap();
+}
+
+#[tokio::test]
+async fn test_receive_all_with_filter() {
+    use aws_sdk_sqs::types::{MessageAttributeValue, MessageSystemAttributeName};
+
+    let config = aws_config::from_env().endpoint_url(ROOT_URL).load().await;
+    let client = Client::new(&config);
+
+    let queue_name: String = std::iter::repeat_with(fastrand::alphanumeric)
+        .take(8)
+        .collect();
+    client
+        .create_queue()
+        .queue_name(&queue_name)
+        .send()
+        .await
+        .unwrap();
+
+    let queue_url = format!("{ROOT_URL}/queue/{queue_name}");
+
+    // Send 5 messages: 2 with priority=high, 3 with priority=low
+    for i in 0..5 {
+        let priority = if i < 2 { "high" } else { "low" };
+        client
+            .send_message()
+            .queue_url(&queue_url)
+            .message_body(format!("message{}", i))
+            .message_attributes(
+                "priority",
+                MessageAttributeValue::builder()
+                    .data_type("String")
+                    .string_value(priority)
+                    .build()
+                    .unwrap(),
+            )
+            .send()
+            .await
+            .unwrap();
+    }
+
+    let config = omniqueue::backends::sqs::SqsConfig {
+        queue_dsn: queue_url,
+        override_endpoint: true,
+        message_attribute_names: vec![MessageSystemAttributeName::ApproximateReceiveCount],
+        dlq_config: None,
+    };
+
+    let (_p, c) = omniqueue::backends::sqs::SqsBackend::builder(config)
+        .build_pair()
+        .await
+        .unwrap();
+
+    // Filter for priority = "high"
+    let deliveries = c
+        .receive_all_with_filter(
+            |msg| {
+                msg.message_attributes()
+                    .and_then(|attrs| attrs.get("priority"))
+                    .and_then(|attr| attr.string_value())
+                    .map(|v| v == "high")
+                    .unwrap_or(false)
+            },
+            10,
+            Duration::from_secs(10),
+        )
+        .await
+        .unwrap();
+
+    assert_eq!(deliveries.len(), 2);
+    for delivery in deliveries {
+        delivery.ack().await.unwrap();
+    }
+}
+
+#[tokio::test]
+async fn test_filter_with_dlq() {
+    use aws_sdk_sqs::types::{MessageAttributeValue, MessageSystemAttributeName};
+    use omniqueue::backends::sqs::DeadLetterQueueConfig;
+
+    let config = aws_config::from_env().endpoint_url(ROOT_URL).load().await;
+    let client = Client::new(&config);
+
+    // Create main queue
+    let main_queue_name: String = std::iter::repeat_with(fastrand::alphanumeric)
+        .take(8)
+        .collect();
+    client
+        .create_queue()
+        .queue_name(&main_queue_name)
+        .send()
+        .await
+        .unwrap();
+
+    // Create DLQ
+    let dlq_name = format!("{}_dlq", main_queue_name);
+    client
+        .create_queue()
+        .queue_name(&dlq_name)
+        .send()
+        .await
+        .unwrap();
+
+    let main_queue_url = format!("{ROOT_URL}/queue/{main_queue_name}");
+    let dlq_url = format!("{ROOT_URL}/queue/{dlq_name}");
+
+    // Send a message that won't match the filter
+    client
+        .send_message()
+        .queue_url(&main_queue_url)
+        .message_body("wrong_version_message")
+        .message_attributes(
+            "version",
+            MessageAttributeValue::builder()
+                .data_type("String")
+                .string_value("1.0")
+                .build()
+                .unwrap(),
+        )
+        .send()
+        .await
+        .unwrap();
+
+    let max_filter_failures = 3;
+
+    let config = omniqueue::backends::sqs::SqsConfig {
+        queue_dsn: main_queue_url.clone(),
+        override_endpoint: true,
+        message_attribute_names: vec![MessageSystemAttributeName::ApproximateReceiveCount],
+        dlq_config: Some(DeadLetterQueueConfig {
+            queue_url: dlq_url.clone(),
+            max_filter_failures,
+        }),
+    };
+
+    let (_p, c) = omniqueue::backends::sqs::SqsBackend::builder(config)
+        .build_pair()
+        .await
+        .unwrap();
+
+    // Helper to check message count in a queue
+    async fn check_queue_count(client: &Client, queue_url: &str, expected_count: i32) {
+        tokio::time::sleep(Duration::from_millis(500)).await;
+        let attrs = client
+            .get_queue_attributes()
+            .queue_url(queue_url)
+            .attribute_names(aws_sdk_sqs::types::QueueAttributeName::ApproximateNumberOfMessages)
+            .send()
+            .await
+            .unwrap();
+
+        let count = attrs
+            .attributes()
+            .and_then(|a| {
+                a.get(&aws_sdk_sqs::types::QueueAttributeName::ApproximateNumberOfMessages)
+            })
+            .and_then(|v| v.parse::<i32>().ok())
+            .unwrap_or(0);
+
+        assert_eq!(
+            count, expected_count,
+            "Queue {} has {} messages, expected {}",
+            queue_url, count, expected_count
+        );
+    }
+
+    // Try to receive messages with filter for version="2.0"
+    // The message with version="1.0" should be rejected and eventually go to DLQ
+    for i in 0..max_filter_failures {
+        println!("Filter attempt {}/{}", i + 1, max_filter_failures);
+
+        let result = c
+            .receive_with_filter(
+                |msg| {
+                    msg.message_attributes()
+                        .and_then(|attrs| attrs.get("version"))
+                        .and_then(|attr| attr.string_value())
+                        .map(|v| v == "2.0")
+                        .unwrap_or(false)
+                },
+                Duration::from_secs(3),
+            )
+            .await;
+
+        // Should timeout since no matching message
+        assert!(
+            result.is_err(),
+            "Expected no matching message on attempt {}",
+            i + 1
+        );
+
+        // Wait for message to be re-queued
+        tokio::time::sleep(Duration::from_millis(500)).await;
+    }
+
+    // After max_filter_failures attempts, message should be in DLQ
+    check_queue_count(&client, &dlq_url, 1).await;
+
+    // Verify message is no longer in main queue
+    check_queue_count(&client, &main_queue_url, 0).await;
+}