diff --git a/actix-limitation/CHANGES.md b/actix-limitation/CHANGES.md index 88ef3ed38f..51e9d45613 100644 --- a/actix-limitation/CHANGES.md +++ b/actix-limitation/CHANGES.md @@ -1,6 +1,8 @@ # Changes -## Unreleased +## Unreleased - 2023-xx-xx + +- Added optional scopes to the middleware enabling use of multiple Limiters by passing an `HashMap` to the Http server `app_data` ## 0.5.1 diff --git a/actix-limitation/Cargo.toml b/actix-limitation/Cargo.toml index 90e6e9c554..c92e813b11 100644 --- a/actix-limitation/Cargo.toml +++ b/actix-limitation/Cargo.toml @@ -38,3 +38,4 @@ actix-session = { version = "0.8", optional = true } actix-web = "4" static_assertions = "1" uuid = { version = "1", features = ["v4"] } +pretty_env_logger = "0.5" \ No newline at end of file diff --git a/actix-limitation/examples/README.md b/actix-limitation/examples/README.md new file mode 100644 index 0000000000..5e06ce61c6 --- /dev/null +++ b/actix-limitation/examples/README.md @@ -0,0 +1,39 @@ +# Examples + +We leverage redis to store state of the ratelimiting. +So you will need to have a redis instance available on localhost. + +You can start this redis instance with Docker: +``` +docker run -d -p 6379:6379 --name limiter-redis redis +# Clean up: you can rm the docker this way +# docker rm -f limiter-redis +``` + + +## scoped_limiters + +This example present how to use multiple limiters. +This allow different configurations and the ability to scope them. + +### Starting the example server + +```bash +RUST_LOG=debug cargo run --example scoped_limiters +``` +> RUST_LOG=debug is used to print logs, see crate pretty_env_logger for more details. + +### Testing with curl + +```bash +curl -X PUT localhost:8080/scoped/sms -v +``` +first request should work fine +doing a second request within 60 seconds should yield `HTTP/1.1 429 Too Many Requests` +after 60 seconds you should be able to make 1 request again + + +```bash +curl localhost:8080 +``` +This route should work 30 times, or 29 if you previously requested the /scoped/sms route diff --git a/actix-limitation/examples/scoped_limiters.rs b/actix-limitation/examples/scoped_limiters.rs new file mode 100644 index 0000000000..714422a0f2 --- /dev/null +++ b/actix-limitation/examples/scoped_limiters.rs @@ -0,0 +1,82 @@ +use std::{collections::HashMap, time::Duration}; + +use actix_limitation::{Limiter, RateLimiter}; +use actix_web::{dev::ServiceRequest, get, put, web, App, HttpServer, Responder}; +use redis::Client; + +#[get("/")] +async fn index() -> impl Responder { + "index" +} + +#[put("/sms")] +async fn send_sms() -> impl Responder { + "sending an expensive sms" +} + +#[actix_web::main] +async fn main() -> std::io::Result<()> { + pretty_env_logger::init(); + + // Create an Hashmap to store the multiples [Limiter](Limiter) + let mut limiters = HashMap::new(); + + // Create and connect a redis Client. + let redis_client = Client::open("redis://127.0.0.1/").expect("creation of the redis client"); + + // Create a default limiter + let default_limiter = Limiter::builder_with_redis_client(redis_client.clone()) + // specifying with key_by that we take the user IP address as a identifier. + .key_by(|req: &ServiceRequest| { + req.connection_info() + .realip_remote_addr() + .map(|ip| ip.to_string()) + }) + // Allowing a maximum of 30 requests per minute + .limit(30) + .period(Duration::from_secs(60)) + .build() + .unwrap(); + limiters.insert("default", default_limiter); + + let scope_limiter = Limiter::builder_with_redis_client(redis_client) + .key_by(|req: &ServiceRequest| { + req.connection_info() + .realip_remote_addr() + // ⚠️ we prepend "scoped" to the key in order to isolate this count from the default count + // + // If we were using the same key, a request to this route would always return too many requests + // in this context because the default limiter at the root would be reached first and would count 1 before we check for this. + // To mitigate this issue you could also specify a different namespace with the redis_client passed as parameter: `redis://127.0.0.1/2` + .map(|ip| format!("scoped-{}", ip)) + }) + // Allowing only 1 request per minute + .limit(1) + .period(Duration::from_secs(60)) + .build() + .unwrap(); + limiters.insert("scoped", scope_limiter); + + // Passing this limiters as app_data so it can be accessed by the middleware. + let limiters = web::Data::new(limiters); + HttpServer::new(move || { + App::new() + // Using the default limiter for all the routes + // ⚠️ This limiter will count and apply the limits before the one in "/scoped" + .wrap(RateLimiter::scoped("default")) + .app_data(limiters.clone()) + .service( + web::scope("/scoped") + // Wrapping only for this scope the scoped limiter + .wrap(RateLimiter::scoped("scoped")) + // This route will only be available 1 time every minutes + // Note: the root limiter default will also limit this route + .service(send_sms), + ) + // This route is only limited by the default limiter + .service(index) + }) + .bind(("127.0.0.1", 8080))? + .run() + .await +} diff --git a/actix-limitation/src/builder.rs b/actix-limitation/src/builder.rs index d6053f5ef9..8b265bd148 100644 --- a/actix-limitation/src/builder.rs +++ b/actix-limitation/src/builder.rs @@ -7,10 +7,19 @@ use redis::Client; use crate::{errors::Error, GetArcBoxKeyFn, Limiter}; +/// [RedisConnectionKind] is used to define which connection parameter for the Redis server will be passed +/// It can be an Url or a Client +/// This is done so we can use the same client for multiple Limiters +#[derive(Debug, Clone)] +pub enum RedisConnectionKind { + Url(String), + Client(Client), +} + /// Rate limiter builder. #[derive(Debug)] pub struct Builder { - pub(crate) redis_url: String, + pub(crate) redis_connection: RedisConnectionKind, pub(crate) limit: usize, pub(crate) period: Duration, pub(crate) get_key_fn: Option, @@ -96,8 +105,12 @@ impl Builder { closure }; + let client = match &self.redis_connection { + RedisConnectionKind::Url(url) => Client::open(url.as_str())?, + RedisConnectionKind::Client(client) => client.clone(), + }; Ok(Limiter { - client: Client::open(self.redis_url.as_str())?, + client, limit: self.limit, period: self.period, get_key_fn: get_key, @@ -109,12 +122,25 @@ impl Builder { mod tests { use super::*; + /// Implementing partial Eq to check if builder assigned the redis connection url correctly + /// We can't / shouldn't compare Redis clients, thus the method panic if we try to compare them + impl PartialEq for RedisConnectionKind { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Url(l_url), Self::Url(r_url)) => l_url == r_url, + _ => { + panic!("RedisConnectionKind PartialEq is only implemented for Url") + } + } + } + } + #[test] fn test_create_builder() { - let redis_url = "redis://127.0.0.1"; + let redis_connection = RedisConnectionKind::Url("redis://127.0.0.1".to_string()); let period = Duration::from_secs(10); let builder = Builder { - redis_url: redis_url.to_owned(), + redis_connection: redis_connection.clone(), limit: 100, period, get_key_fn: Some(Arc::new(|_| None)), @@ -123,7 +149,7 @@ mod tests { session_key: Cow::Owned("rate-api".to_string()), }; - assert_eq!(builder.redis_url, redis_url); + assert_eq!(builder.redis_connection, redis_connection); assert_eq!(builder.limit, 100); assert_eq!(builder.period, period); #[cfg(feature = "session")] @@ -133,10 +159,10 @@ mod tests { #[test] fn test_create_limiter() { - let redis_url = "redis://127.0.0.1"; + let redis_connection = RedisConnectionKind::Url("redis://127.0.0.1".to_string()); let period = Duration::from_secs(20); let mut builder = Builder { - redis_url: redis_url.to_owned(), + redis_connection, limit: 100, period: Duration::from_secs(10), get_key_fn: Some(Arc::new(|_| None)), @@ -154,10 +180,10 @@ mod tests { #[test] #[should_panic = "Redis URL did not parse"] fn test_create_limiter_error() { - let redis_url = "127.0.0.1"; + let redis_connection = RedisConnectionKind::Url("127.0.0.1".to_string()); let period = Duration::from_secs(20); let mut builder = Builder { - redis_url: redis_url.to_owned(), + redis_connection, limit: 100, period: Duration::from_secs(10), get_key_fn: Some(Arc::new(|_| None)), diff --git a/actix-limitation/src/lib.rs b/actix-limitation/src/lib.rs index 27d79d7b0e..04f0be9380 100644 --- a/actix-limitation/src/lib.rs +++ b/actix-limitation/src/lib.rs @@ -7,7 +7,7 @@ //! ``` //! //! ```no_run -//! use std::{sync::Arc, time::Duration}; +//! use std::{time::Duration}; //! use actix_web::{dev::ServiceRequest, get, web, App, HttpServer, Responder}; //! use actix_session::SessionExt as _; //! use actix_limitation::{Limiter, RateLimiter}; @@ -23,8 +23,8 @@ //! Limiter::builder("redis://127.0.0.1") //! .key_by(|req: &ServiceRequest| { //! req.get_session() -//! .get(&"session-id") -//! .unwrap_or_else(|_| req.cookie(&"rate-api-id").map(|c| c.to_string())) +//! .get("session-id") +//! .unwrap_or_else(|_| req.cookie("rate-api-id").map(|c| c.to_string())) //! }) //! .limit(5000) //! .period(Duration::from_secs(3600)) // 60 minutes @@ -54,6 +54,7 @@ use std::{borrow::Cow, fmt, sync::Arc, time::Duration}; use actix_web::dev::ServiceRequest; +use builder::RedisConnectionKind; use redis::Client; mod builder; @@ -107,11 +108,27 @@ impl Limiter { /// Construct rate limiter builder with defaults. /// /// See [`redis-rs` docs](https://docs.rs/redis/0.21/redis/#connection-parameters) on connection - /// parameters for how to set the Redis URL. + /// redis_url parameter is used for connecting to Redis. #[must_use] pub fn builder(redis_url: impl Into) -> Builder { Builder { - redis_url: redis_url.into(), + redis_connection: RedisConnectionKind::Url(redis_url.into()), + limit: DEFAULT_REQUEST_LIMIT, + period: Duration::from_secs(DEFAULT_PERIOD_SECS), + get_key_fn: None, + cookie_name: Cow::Borrowed(DEFAULT_COOKIE_NAME), + #[cfg(feature = "session")] + session_key: Cow::Borrowed(DEFAULT_SESSION_KEY), + } + } + + /// Construct rate limiter builder with defaults. + /// + /// parameters for how to set the Redis URL. + #[must_use] + pub fn builder_with_redis_client(redis_client: Client) -> Builder { + Builder { + redis_connection: RedisConnectionKind::Client(redis_client), limit: DEFAULT_REQUEST_LIMIT, period: Duration::from_secs(DEFAULT_PERIOD_SECS), get_key_fn: None, @@ -177,4 +194,75 @@ mod tests { assert_eq!(limiter.limit, 5000); assert_eq!(limiter.period, Duration::from_secs(3600)); } + + use std::{collections::HashMap, time::Duration}; + + use actix_web::{ + dev::ServiceRequest, get, http::StatusCode, test as actix_test, web, Responder, + }; + + #[actix_web::test] + async fn test_create_scoped_limiter() { + #[get("/")] + async fn index() -> impl Responder { + "index" + } + + let mut limiters = HashMap::new(); + let redis_client = + Client::open("redis://127.0.0.1/2").expect("unable to create redis client"); + limiters.insert( + "default", + Limiter::builder_with_redis_client(redis_client.clone()) + .key_by(|_req: &ServiceRequest| Some("something_default".to_string())) + .limit(5000) + .period(Duration::from_secs(60)) + .build() + .unwrap(), + ); + limiters.insert( + "scoped", + Limiter::builder_with_redis_client(redis_client) + .key_by(|_req: &ServiceRequest| Some("something_scoped".to_string())) + .limit(1) + .period(Duration::from_secs(60)) + .build() + .unwrap(), + ); + let limiters = web::Data::new(limiters); + + let app = actix_web::test::init_service( + actix_web::App::new() + .wrap(RateLimiter::scoped("default")) + .app_data(limiters.clone()) + .service( + web::scope("/scoped") + .wrap(RateLimiter::scoped("scoped")) + .service(index), + ) + .service(index), + ) + .await; + + for _ in 0..3 { + let req = actix_test::TestRequest::get().uri("/").to_request(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK, "{:#?}", resp); + } + for request_count in 0..3 { + let req = actix_test::TestRequest::get().uri("/scoped/").to_request(); + let resp = actix_test::call_service(&app, req).await; + + assert_eq!( + resp.status(), + if request_count > 0 { + StatusCode::TOO_MANY_REQUESTS + } else { + StatusCode::OK + }, + "{:#?}", + resp + ); + } + } } diff --git a/actix-limitation/src/middleware.rs b/actix-limitation/src/middleware.rs index 290f075c1e..c968204d0a 100644 --- a/actix-limitation/src/middleware.rs +++ b/actix-limitation/src/middleware.rs @@ -1,4 +1,4 @@ -use std::{future::Future, pin::Pin, rc::Rc}; +use std::{collections::HashMap, future::Future, pin::Pin, rc::Rc}; use actix_utils::future::{ok, Ready}; use actix_web::{ @@ -11,9 +11,23 @@ use actix_web::{ use crate::{Error as LimitationError, Limiter}; /// Rate limit middleware. +/// +/// Use the `scope` variable to define multiple limiter #[derive(Debug, Default)] #[non_exhaustive] -pub struct RateLimiter; +pub struct RateLimiter { + /// Used to define multiple limiter, with different configurations + /// + /// WARNING: When used (not None) the middleware will expect a `HashMap` in the actix-web `app_data` + pub scope: Option<&'static str>, +} + +impl RateLimiter { + /// Construct the rate limiter with a scope + pub fn scoped(scope: &'static str) -> Self { + RateLimiter { scope: Some(scope) } + } +} impl Transform for RateLimiter where @@ -30,6 +44,7 @@ where fn new_transform(&self, service: S) -> Self::Future { ok(RateLimiterMiddleware { service: Rc::new(service), + scope: self.scope, }) } } @@ -38,6 +53,7 @@ where #[derive(Debug)] pub struct RateLimiterMiddleware { service: Rc, + scope: Option<&'static str>, } impl Service for RateLimiterMiddleware @@ -55,10 +71,21 @@ where fn call(&self, req: ServiceRequest) -> Self::Future { // A mis-configuration of the Actix App will result in a **runtime** failure, so the expect // method description is important context for the developer. - let limiter = req - .app_data::>() - .expect("web::Data should be set in app data for RateLimiter middleware") - .clone(); + let limiter = if let Some(scope) = self.scope { + let limiters = req.app_data::>>().expect( + "web::Data> should be set in app data for RateLimiter middleware", + ); + limiters + .get(scope) + .unwrap_or_else(|| panic!("Unable to find defined limiter with scope: {}", scope)) + .clone() + } else { + let limiter = req + .app_data::>() + .expect("web::Data should be set in app data for RateLimiter middleware"); + // Deref to get the Limiter + (***limiter).clone() + }; let key = (limiter.get_key_fn)(&req); let service = Rc::clone(&self.service);