diff --git a/example/Cargo.toml b/example/Cargo.toml index 50dbaf11..64d537fa 100644 --- a/example/Cargo.toml +++ b/example/Cargo.toml @@ -23,7 +23,7 @@ ctrlc = { version = "3.0", features = ["termination"] } tokio = { version = "1.0.1", features = ["signal", "time"] } async-trait = "0.1.42" rand = "0.8.5" - +clap = { version = "4.5.40", features = ["derive"] } [[example]] name = "client" diff --git a/example/async-client.rs b/example/async-client.rs index 7e3d93e8..39f0439d 100644 --- a/example/async-client.rs +++ b/example/async-client.rs @@ -11,7 +11,9 @@ use ttrpc::r#async::Client; #[tokio::main(flavor = "current_thread")] async fn main() { - let c = Client::connect(utils::SOCK_ADDR).await.unwrap(); + let sock_addr = utils::get_sock_addr(); + let c = Client::connect(sock_addr).await.unwrap(); + let hc = health_ttrpc::HealthClient::new(c.clone()); let ac = agent_ttrpc::AgentServiceClient::new(c); diff --git a/example/async-server.rs b/example/async-server.rs index e9acdd6e..8043c57f 100644 --- a/example/async-server.rs +++ b/example/async-server.rs @@ -58,10 +58,11 @@ async fn main() { let hservice = health_ttrpc::create_health(Arc::new(HealthService {})); let aservice = agent_ttrpc::create_agent_service(Arc::new(AgentService {})); - utils::remove_if_sock_exist(utils::SOCK_ADDR).unwrap(); + let sock_addr = utils::get_sock_addr(); + utils::remove_if_sock_exist(sock_addr).unwrap(); let mut server = Server::new() - .bind(utils::SOCK_ADDR) + .bind(sock_addr) .unwrap() .register_service(hservice) .register_service(aservice); diff --git a/example/async-stream-client.rs b/example/async-stream-client.rs index 12db16d4..7e6bbdb0 100644 --- a/example/async-stream-client.rs +++ b/example/async-stream-client.rs @@ -13,7 +13,9 @@ use ttrpc::r#async::Client; async fn main() { simple_logging::log_to_stderr(log::LevelFilter::Info); - let c = Client::connect(utils::SOCK_ADDR).await.unwrap(); + let sock_addr = utils::get_sock_addr(); + let c = Client::connect(sock_addr).await.unwrap(); + let sc = streaming_ttrpc::StreamingClient::new(c); let _now = std::time::Instant::now(); diff --git a/example/async-stream-server.rs b/example/async-stream-server.rs index 41e3f47b..005ff312 100644 --- a/example/async-stream-server.rs +++ b/example/async-stream-server.rs @@ -171,10 +171,12 @@ impl streaming_ttrpc::Streaming for StreamingService { async fn main() { simple_logging::log_to_stderr(LevelFilter::Info); let service = streaming_ttrpc::create_streaming(Arc::new(StreamingService {})); - utils::remove_if_sock_exist(utils::SOCK_ADDR).unwrap(); + + let sock_addr = utils::get_sock_addr(); + utils::remove_if_sock_exist(sock_addr).unwrap(); let mut server = Server::new() - .bind(utils::SOCK_ADDR) + .bind(sock_addr) .unwrap() .register_service(service); diff --git a/example/client.rs b/example/client.rs index 87b979fb..095c4fac 100644 --- a/example/client.rs +++ b/example/client.rs @@ -50,7 +50,9 @@ fn main() { fn connect_once() { simple_logging::log_to_stderr(LevelFilter::Trace); - let c = Client::connect(utils::SOCK_ADDR).unwrap(); + let sock_addr = utils::get_sock_addr(); + let c = Client::connect(sock_addr).unwrap(); + let hc = health_ttrpc::HealthClient::new(c.clone()); let ac = agent_ttrpc::AgentServiceClient::new(c); diff --git a/example/server.rs b/example/server.rs index 703c2cbc..37535486 100644 --- a/example/server.rs +++ b/example/server.rs @@ -58,9 +58,11 @@ fn main() { let hservice = health_ttrpc::create_health(Arc::new(HealthService {})); let aservice = agent_ttrpc::create_agent_service(Arc::new(AgentService {})); - utils::remove_if_sock_exist(utils::SOCK_ADDR).unwrap(); + let sock_addr = utils::get_sock_addr(); + utils::remove_if_sock_exist(sock_addr).unwrap(); + let mut server = ttrpc::Server::new() - .bind(utils::SOCK_ADDR) + .bind(sock_addr) .unwrap() .register_service(hservice) .register_service(aservice); diff --git a/example/utils.rs b/example/utils.rs index aa845bf9..10a7f3cd 100644 --- a/example/utils.rs +++ b/example/utils.rs @@ -1,14 +1,42 @@ #![allow(dead_code)] +use clap::Parser; +use log::warn; use std::io::Result; #[cfg(unix)] -pub const SOCK_ADDR: &str = r"unix:///tmp/ttrpc-test"; +pub const SOCK_ADDR_LOCAL: &str = r"unix:///tmp/ttrpc-test"; #[cfg(windows)] -pub const SOCK_ADDR: &str = r"\\.\pipe\ttrpc-test"; +pub const SOCK_ADDR_LOCAL: &str = r"\\.\pipe\ttrpc-test"; + +pub const SOCK_ADDR_TCP: &str = r"tcp://127.0.0.1:65500"; + +#[derive(Debug, Default, Parser)] +pub struct Cli { + #[arg(long = "tcp")] + #[arg(help = "Use a TCP socket instead of a local one")] + pub tcp: bool, +} + +pub fn get_sock_addr() -> &'static str { + let cli = Cli::parse(); + if cli.tcp { + if cfg!(windows) { + warn!("'--tcp' flag ignored; TCP sockets not supported on Windows"); + return SOCK_ADDR_LOCAL; + } + SOCK_ADDR_TCP + } else { + SOCK_ADDR_LOCAL + } +} #[cfg(unix)] pub fn remove_if_sock_exist(sock_addr: &str) -> Result<()> { + if sock_addr.starts_with("tcp://") { + return Ok(()); + } + let path = sock_addr .strip_prefix("unix://") .expect("socket address is not expected"); diff --git a/src/asynchronous/server.rs b/src/asynchronous/server.rs index bbf3329b..dca04e33 100644 --- a/src/asynchronous/server.rs +++ b/src/asynchronous/server.rs @@ -101,6 +101,15 @@ impl Server { Ok(self.add_listener(listener)) } + #[cfg(unix)] + /// # Safety + /// The file descriptor must represent a unix listener. + pub unsafe fn add_tcp_listener(self, fd: RawFd) -> Result { + let listener = Listener::from_raw_tcp_listener_fd(fd) + .map_err(err_to_others_err!(e, "from_raw_tcp_listener_fd error"))?; + Ok(self.add_listener(listener)) + } + #[cfg(any(target_os = "linux", target_os = "android"))] /// # Safety /// The file descriptor must represent a vsock listener. diff --git a/src/asynchronous/transport/mod.rs b/src/asynchronous/transport/mod.rs index 57621d45..f340c37f 100644 --- a/src/asynchronous/transport/mod.rs +++ b/src/asynchronous/transport/mod.rs @@ -22,6 +22,9 @@ macro_rules! io_other { #[cfg(unix)] mod unix; +#[cfg(unix)] +mod tcp; + #[cfg(any(target_os = "linux", target_os = "android"))] mod vsock; @@ -43,6 +46,11 @@ impl Listener { return Self::bind_unix(addr); } + #[cfg(unix)] + if let Some(addr) = addr.strip_prefix("tcp://") { + return Self::bind_tcp(addr); + } + #[cfg(any(target_os = "linux", target_os = "android"))] if let Some(addr) = addr.strip_prefix("vsock://") { return Self::bind_vsock(addr); @@ -70,6 +78,11 @@ impl Socket { return Self::connect_unix(addr).await; } + #[cfg(unix)] + if let Some(addr) = addr.strip_prefix("tcp://") { + return Self::connect_tcp(addr).await; + } + #[cfg(any(target_os = "linux", target_os = "android"))] if let Some(addr) = addr.strip_prefix("vsock://") { return Self::connect_vsock(addr).await; diff --git a/src/asynchronous/transport/tcp.rs b/src/asynchronous/transport/tcp.rs new file mode 100644 index 00000000..25eafd8d --- /dev/null +++ b/src/asynchronous/transport/tcp.rs @@ -0,0 +1,80 @@ +use std::convert::TryFrom; +use std::io::{Error as IoError, Result as IoResult}; +use std::os::fd::{FromRawFd as _, RawFd}; +use std::net::{ + SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream, +}; + +use async_stream::stream; +use tokio::net::{TcpListener, TcpStream}; + +use super::{Listener, Socket}; + +impl Listener { + pub fn bind_tcp(addr: impl AsRef) -> IoResult { + let addr = parse_tcp_addr(addr)?; + let listener = StdTcpListener::bind(addr)?; + Self::try_from(listener) + } + + /// # Safety + /// The file descriptor must represent a tcp listener. + pub unsafe fn from_raw_tcp_listener_fd(fd: std::os::fd::RawFd) -> IoResult { + let listener = unsafe { StdTcpListener::from_raw_fd(fd) }; + Self::try_from(listener) + } +} + +impl Socket { + pub async fn connect_tcp(addr: impl AsRef) -> IoResult { + let addr = parse_tcp_addr(addr)?; + let socket = StdTcpStream::connect(addr)?; + Self::try_from(socket) + } + + /// # Safety + /// The file descriptor must represent a tcp socket. + pub unsafe fn from_raw_tcp_socket_fd(fd: RawFd) -> IoResult { + let socket = unsafe { StdTcpStream::from_raw_fd(fd) }; + Self::try_from(socket) + } +} + +impl From for Listener { + fn from(listener: TcpListener) -> Self { + Self::new(stream! { + loop { + yield listener.accept().await.map(|(socket, _)| socket); + } + }) + } +} + +impl TryFrom for Listener { + type Error = IoError; + fn try_from(listener: StdTcpListener) -> IoResult { + listener.set_nonblocking(true)?; + Ok(Self::from(TcpListener::from_std(listener)?)) + } +} + +impl From for Socket { + fn from(socket: TcpStream) -> Self { + Self::new(socket) + } +} + +impl TryFrom for Socket { + type Error = IoError; + fn try_from(socket: StdTcpStream) -> IoResult { + socket.set_nonblocking(true)?; + Ok(Self::from(TcpStream::from_std(socket)?)) + } +} + +fn parse_tcp_addr(addr: impl AsRef) -> IoResult { + let addr = addr.as_ref(); + + addr.parse::() + .map_err(|e| io_other!("Failed to parse TCP address '{}': {}", addr, e)) +} diff --git a/src/common.rs b/src/common.rs index f7ae9ea2..abba6022 100644 --- a/src/common.rs +++ b/src/common.rs @@ -8,13 +8,15 @@ use nix::fcntl::{fcntl, FcntlArg, OFlag}; use nix::sys::socket::*; -use std::os::unix::io::RawFd; +use std::str::FromStr; +use std::{env, os::unix::io::RawFd}; use crate::error::{Error, Result}; #[derive(Debug, Clone, Copy, PartialEq)] pub(crate) enum Domain { Unix, + Tcp, #[cfg(any(target_os = "linux", target_os = "android"))] Vsock, } @@ -39,6 +41,10 @@ fn parse_sockaddr(addr: &str) -> Result<(Domain, &str)> { return Ok((Domain::Vsock, addr)); } + if let Some(addr) = addr.strip_prefix("tcp://") { + return Ok((Domain::Tcp, addr)); + } + Err(Error::Others(format!("Scheme {addr:?} is not supported"))) } @@ -53,6 +59,10 @@ fn parse_sockaddr(addr: &str) -> Result<(Domain, &str)> { return Ok((Domain::Unix, addr)); } + if let Some(addr) = addr.strip_prefix("tcp://") { + return Ok((Domain::Tcp, addr)); + } + Err(Error::Others(format!("Scheme {addr:?} is not supported"))) } @@ -83,8 +93,8 @@ fn make_addr(domain: Domain, sockaddr: &str) -> Result { UnixAddr::new(sockaddr).map_err(err_to_others_err!(e, "")) } } - Domain::Vsock => Err(Error::Others( - "function make_addr does not support create vsock socket".to_string(), + Domain::Vsock | Domain::Tcp => Err(Error::Others( + "function make_addr does not support create vsock/tcp socket".to_string(), )), } } @@ -130,7 +140,7 @@ fn parse_vscok(addr: &str) -> Result<(u32, u32)> { fn make_socket(sockaddr: &str) -> Result<(RawFd, Domain, Box)> { let (domain, sockaddrv) = parse_sockaddr(sockaddr)?; - let get_sock_addr = |domain, sockaddr| -> Result<(RawFd, Box)> { + let get_unix_addr = |domain, sockaddr| -> Result<(RawFd, Box)> { let fd = socket(AddressFamily::Unix, SockType::Stream, SOCK_CLOEXEC, None) .map_err(|e| Error::Socket(e.to_string()))?; @@ -141,9 +151,20 @@ fn make_socket(sockaddr: &str) -> Result<(RawFd, Domain, Box)> let sockaddr = make_addr(domain, sockaddr)?; Ok((fd, Box::new(sockaddr))) }; + let get_tcp_addr = |sockaddr: &str| -> Result<(RawFd, Box)> { + let fd = socket(AddressFamily::Inet, SockType::Stream, SOCK_CLOEXEC, None) + .map_err(|e| Error::Socket(e.to_string()))?; + + #[cfg(target_os = "macos")] + set_fd_close_exec(fd)?; + let sockaddr = SockaddrIn::from_str(sockaddr).map_err(err_to_others_err!(e, ""))?; + + Ok((fd, Box::new(sockaddr))) + }; let (fd, sockaddr): (i32, Box) = match domain { - Domain::Unix => get_sock_addr(domain, sockaddrv)?, + Domain::Unix => get_unix_addr(domain, sockaddrv)?, + Domain::Tcp => get_tcp_addr(sockaddrv)?, #[cfg(any(target_os = "linux", target_os = "android"))] Domain::Vsock => { let (cid, port) = parse_vscok(sockaddrv)?; @@ -162,9 +183,31 @@ fn make_socket(sockaddr: &str) -> Result<(RawFd, Domain, Box)> Ok((fd, domain, sockaddr)) } +fn set_socket_opts(fd: RawFd, domain: Domain, is_bind: bool) -> Result<()> { + if domain != Domain::Tcp { + return Ok(()); + } + + if is_bind { + setsockopt(fd, sockopt::ReusePort, &true)?; + } + + let tcp_nodelay_enabled = match env::var("TTRPC_TCP_NODELAY_ENABLED") { + Ok(val) if val == "1" || val.eq_ignore_ascii_case("true") => true, + Ok(val) if val == "0" || val.eq_ignore_ascii_case("false") => false, + _ => false, + }; + if tcp_nodelay_enabled { + setsockopt(fd, sockopt::TcpNoDelay, &true)?; + } + + Ok(()) +} + pub(crate) fn do_bind(sockaddr: &str) -> Result<(RawFd, Domain)> { let (fd, domain, sockaddr) = make_socket(sockaddr)?; + set_socket_opts(fd, domain, true)?; bind(fd, sockaddr.as_ref()).map_err(err_to_others_err!(e, ""))?; Ok((fd, domain)) @@ -172,8 +215,9 @@ pub(crate) fn do_bind(sockaddr: &str) -> Result<(RawFd, Domain)> { /// Creates a unix socket for client. pub(crate) unsafe fn client_connect(sockaddr: &str) -> Result { - let (fd, _, sockaddr) = make_socket(sockaddr)?; + let (fd, domain, sockaddr) = make_socket(sockaddr)?; + set_socket_opts(fd, domain, false)?; connect(fd, sockaddr.as_ref())?; Ok(fd) @@ -202,6 +246,12 @@ mod tests { true, ), ("abc:///run/c.sock", None, "", false), + ( + "tcp://127.0.0.1:65500", + Some(Domain::Tcp), + "127.0.0.1:65500", + true, + ), ] { let (input, domain, addr, success) = (i.0, i.1, i.2, i.3); let r = parse_sockaddr(input); @@ -229,6 +279,12 @@ mod tests { ("Vsock:///run/c.sock", None, "", false), ("unix://@/run/b.sock", None, "", false), ("abc:///run/c.sock", None, "", false), + ( + "tcp://127.0.0.1:65500", + Some(Domain::Tcp), + "127.0.0.1:65500", + true, + ), ] { let (input, domain, addr, success) = (i.0, i.1, i.2, i.3); let r = parse_sockaddr(input); diff --git a/tests/run-examples.rs b/tests/run-examples.rs index 459700f0..d3c04b23 100644 --- a/tests/run-examples.rs +++ b/tests/run-examples.rs @@ -4,12 +4,16 @@ use std::{ time::Duration, }; -fn run_example(server: &str, client: &str) -> Result<(), Box> { +fn run_example( + server: &str, + client: &str, + args: &[&str], +) -> Result<(), Box> { // start the server and give it a moment to start. - let mut server = do_run_example(server).spawn().unwrap(); + let mut server = do_run_example(server, args).spawn().unwrap(); std::thread::sleep(Duration::from_secs(2)); - let mut client = do_run_example(client).spawn().unwrap(); + let mut client = do_run_example(client, args).spawn().unwrap(); let mut client_succeeded = false; let start = std::time::Instant::now(); let timeout = Duration::from_secs(600); @@ -55,12 +59,15 @@ fn run_example(server: &str, client: &str) -> Result<(), Box Command { +fn do_run_example(example: &str, args: &[&str]) -> Command { let mut cmd = Command::new("cargo"); - cmd.arg("run") - .arg("--example") - .arg(example) - .stdout(std::process::Stdio::piped()) + cmd.arg("run").arg("--example").arg(example); + + if !args.is_empty() { + cmd.arg("--").args(args); + } + + cmd.stdout(std::process::Stdio::piped()) .stderr(std::process::Stdio::piped()) .current_dir("example"); cmd @@ -83,9 +90,18 @@ fn wait_with_output(name: &str, cmd: Child) { #[test] fn run_examples() -> Result<(), Box> { - run_example("server", "client")?; - run_example("async-server", "async-client")?; - run_example("async-stream-server", "async-stream-client")?; + // Local + run_example("server", "client", &[])?; + run_example("async-server", "async-client", &[])?; + run_example("async-stream-server", "async-stream-client", &[])?; + + // TCP + #[cfg(not(windows))] + { + run_example("server", "client", &["--tcp"])?; + run_example("async-server", "async-client", &["--tcp"])?; + run_example("async-stream-server", "async-stream-client", &["--tcp"])?; + } Ok(()) }