diff --git a/examples/rtu-client.rs b/examples/rtu-client.rs index 9896dadc..94c60ffe 100644 --- a/examples/rtu-client.rs +++ b/examples/rtu-client.rs @@ -3,25 +3,45 @@ //! Asynchronous RTU client example -#[tokio::main(flavor = "current_thread")] -async fn main() -> Result<(), Box> { - use tokio_serial::SerialStream; +use tokio_modbus::{prelude::*, Address, Quantity, Slave}; +use tokio_serial::SerialStream; + +const SERIAL_PATH: &str = "/dev/ttyUSB0"; - use tokio_modbus::prelude::*; +const BAUD_RATE: u32 = 19_200; - let tty_path = "/dev/ttyUSB0"; - let slave = Slave(0x17); +const SERVER: Slave = Slave(0x17); - let builder = tokio_serial::new(tty_path, 19200); - let port = SerialStream::open(&builder).unwrap(); +const SENSOR_ADDRESS: Address = 0x082B; - let mut ctx = rtu::attach_slave(port, slave); - println!("Reading a sensor value"); - let rsp = ctx.read_holding_registers(0x082B, 2).await??; - println!("Sensor value is: {rsp:?}"); +const SENSOR_QUANTITY: Quantity = 2; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + let builder = tokio_serial::new(SERIAL_PATH, BAUD_RATE); + let transport = SerialStream::open(&builder).unwrap(); + + let mut client = rtu::Client::new(transport); + + println!("Reading sensor values (request/response using the low-level API"); + let request = Request::ReadHoldingRegisters(SENSOR_ADDRESS, SENSOR_QUANTITY); + let request_context = client.send_request(SERVER, request).await?; + let Response::ReadHoldingRegisters(values) = client.recv_response(request_context).await?? + else { + // The response variant will always match its corresponding request variant if successful. + unreachable!(); + }; + println!("Sensor responded with: {values:?}"); + + println!("Reading sensor values (call) using the high-level API"); + let mut client_context = client::Context::from(rtu::ClientContext::new(client, SERVER).boxed()); + let values = client_context + .read_holding_registers(SENSOR_ADDRESS, SENSOR_QUANTITY) + .await??; + println!("Sensor responded with: {values:?}"); println!("Disconnecting"); - ctx.disconnect().await?; + client_context.disconnect().await?; Ok(()) } diff --git a/src/client/mod.rs b/src/client/mod.rs index 79e6bf01..f27a685b 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -339,6 +339,28 @@ impl Writer for Context { } } +#[cfg(any(feature = "rtu", feature = "tcp"))] +pub(crate) async fn disconnect_framed( + framed: tokio_util::codec::Framed, +) -> std::io::Result<()> +where + T: tokio::io::AsyncWrite + Unpin, +{ + use tokio::io::AsyncWriteExt as _; + + framed + .into_inner() + .shutdown() + .await + .or_else(|err| match err.kind() { + std::io::ErrorKind::NotConnected | std::io::ErrorKind::BrokenPipe => { + // Already disconnected. + Ok(()) + } + _ => Err(err), + }) +} + #[cfg(test)] mod tests { use crate::{Error, Result}; diff --git a/src/client/rtu.rs b/src/client/rtu.rs index 330482ee..a2c10f93 100644 --- a/src/client/rtu.rs +++ b/src/client/rtu.rs @@ -3,11 +3,25 @@ //! RTU client connections +use std::io; + +use futures_util::{SinkExt as _, StreamExt as _}; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::Framed; + +use crate::{ + codec::rtu::ClientCodec, + frame::{ + rtu::{Header, RequestAdu}, + RequestPdu, + }, + slave::SlaveContext, + FunctionCode, Request, Response, Result, Slave, +}; -use super::*; +use super::{disconnect_framed, Context}; -/// Connect to no particular Modbus slave device for sending +/// Connect to no particular _Modbus_ slave device for sending /// broadcast messages. pub fn attach(transport: T) -> Context where @@ -16,13 +30,240 @@ where attach_slave(transport, Slave::broadcast()) } -/// Connect to any kind of Modbus slave device. +/// Connect to any kind of _Modbus_ slave device. pub fn attach_slave(transport: T, slave: Slave) -> Context where T: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let client = crate::service::rtu::Client::new(transport, slave); + let client = Client::new(transport); + let context = ClientContext::new(client, slave); Context { - client: Box::new(client), + client: Box::new(context), + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct RequestContext { + pub(crate) function_code: FunctionCode, + pub(crate) header: Header, +} + +impl RequestContext { + #[must_use] + pub const fn function_code(&self) -> FunctionCode { + self.function_code + } +} + +/// _Modbus_ RTU client. +#[derive(Debug)] +pub struct Client { + framed: Framed, +} + +impl Client +where + T: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new(transport: T) -> Self { + let framed = Framed::new(transport, ClientCodec::default()); + Self { framed } + } + + pub async fn disconnect(self) -> io::Result<()> { + let Self { framed } = self; + disconnect_framed(framed).await + } + + pub async fn call(&mut self, server: Slave, request: Request<'_>) -> Result { + let request_context = self.send_request(server, request).await?; + self.recv_response(request_context).await + } + + pub async fn send_request( + &mut self, + server: Slave, + request: Request<'_>, + ) -> io::Result { + let request_adu = request_adu(server, request); + self.send_request_adu(request_adu).await + } + + async fn send_request_adu( + &mut self, + request_adu: RequestAdu<'_>, + ) -> io::Result { + let request_context = request_adu.context(); + + self.framed.read_buffer_mut().clear(); + self.framed.send(request_adu).await?; + + Ok(request_context) + } + + pub async fn recv_response(&mut self, request_context: RequestContext) -> Result { + let response_adu = self + .framed + .next() + .await + .unwrap_or_else(|| Err(io::Error::from(io::ErrorKind::BrokenPipe)))?; + + response_adu.try_into_response(request_context) + } +} + +/// _Modbus_ RTU client with (server) context and connection state. +/// +/// Client that invokes methods (request/response) on a single or many (broadcast) server(s). +/// +/// The server can be switched between method calls. +#[derive(Debug)] +pub struct ClientContext { + client: Option>, + server: Slave, +} + +impl ClientContext { + pub fn new(client: Client, server: Slave) -> Self { + Self { + client: Some(client), + server, + } + } + + #[must_use] + pub const fn is_connected(&self) -> bool { + self.client.is_some() + } + + #[must_use] + pub const fn server(&self) -> Slave { + self.server + } + + pub fn set_server(&mut self, server: Slave) { + self.server = server; + } +} + +impl ClientContext +where + T: AsyncWrite + Unpin, +{ + pub async fn disconnect(&mut self) -> io::Result<()> { + let Some(client) = self.client.take() else { + // Already disconnected. + return Ok(()); + }; + disconnect_framed(client.framed).await + } +} + +impl ClientContext +where + T: AsyncRead + AsyncWrite + Unpin, +{ + pub async fn call(&mut self, request: Request<'_>) -> Result { + log::debug!("Call {request:?}"); + + let Some(client) = &mut self.client else { + return Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected").into()); + }; + + client.call(self.server, request).await + } +} + +impl ClientContext +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + #[must_use] + pub fn boxed(self) -> Box { + Box::new(self) + } +} + +impl SlaveContext for ClientContext { + fn set_slave(&mut self, slave: Slave) { + self.set_server(slave); + } +} + +#[async_trait::async_trait] +impl crate::client::Client for ClientContext +where + T: AsyncRead + AsyncWrite + Send + Unpin, +{ + async fn call(&mut self, req: Request<'_>) -> Result { + self.call(req).await + } + + async fn disconnect(&mut self) -> io::Result<()> { + self.disconnect().await + } +} + +fn request_adu<'a, R>(server: Slave, request_pdu: R) -> RequestAdu<'a> +where + R: Into>, +{ + let hdr = Header { slave: server }; + let pdu = request_pdu.into(); + RequestAdu { hdr, pdu } +} + +#[cfg(test)] +mod tests { + use core::{ + pin::Pin, + task::{Context, Poll}, + }; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, Result}; + + use crate::Error; + + use super::*; + + #[derive(Debug)] + struct MockTransport; + + impl Unpin for MockTransport {} + + impl AsyncRead for MockTransport { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &mut ReadBuf<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl AsyncWrite for MockTransport { + fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll> { + Poll::Ready(Ok(2)) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + unimplemented!() + } + } + + #[tokio::test] + async fn handle_broken_pipe() { + let transport = MockTransport; + let client = Client::new(transport); + let mut context = ClientContext::new(client, Slave::broadcast()); + let res = context.call(Request::ReadCoils(0x00, 5)).await; + assert!(res.is_err()); + let err = res.err().unwrap(); + assert!( + matches!(err, Error::Transport(err) if err.kind() == std::io::ErrorKind::BrokenPipe) + ); } } diff --git a/src/client/tcp.rs b/src/client/tcp.rs index 1e613a5d..4d61e9f4 100644 --- a/src/client/tcp.rs +++ b/src/client/tcp.rs @@ -3,14 +3,26 @@ //! TCP client connections -use std::{fmt, io, net::SocketAddr}; +use std::{io, net::SocketAddr}; +use futures_util::{SinkExt as _, StreamExt as _}; use tokio::{ io::{AsyncRead, AsyncWrite}, net::TcpStream, }; +use tokio_util::codec::Framed; -use super::*; +use crate::{ + codec::tcp::ClientCodec, + frame::{ + tcp::{Header, RequestAdu, ResponseAdu, TransactionId, UnitId}, + verify_response_header, RequestPdu, ResponsePdu, + }, + slave::SlaveContext, + ExceptionResponse, ProtocolError, Request, Response, Result, Slave, +}; + +use super::{disconnect_framed, Context}; /// Establish a direct connection to a Modbus TCP coupler. pub async fn connect(socket_addr: SocketAddr) -> io::Result { @@ -31,7 +43,7 @@ pub async fn connect_slave(socket_addr: SocketAddr, slave: Slave) -> io::Result< /// The connection could either be an ordinary [`TcpStream`] or a TLS connection. pub fn attach(transport: T) -> Context where - T: AsyncRead + AsyncWrite + Send + Unpin + fmt::Debug + 'static, + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { attach_slave(transport, Slave::tcp_device()) } @@ -41,10 +53,215 @@ where /// The connection could either be an ordinary [`TcpStream`] or a TLS connection. pub fn attach_slave(transport: T, slave: Slave) -> Context where - T: AsyncRead + AsyncWrite + Send + Unpin + fmt::Debug + 'static, + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { - let client = crate::service::tcp::Client::new(transport, slave); + let client = Client::new(transport, slave); Context { client: Box::new(client), } } + +const INITIAL_TRANSACTION_ID: TransactionId = 0; + +#[derive(Debug)] +struct TransactionIdGenerator { + next_transaction_id: TransactionId, +} + +impl TransactionIdGenerator { + const fn new() -> Self { + Self { + next_transaction_id: INITIAL_TRANSACTION_ID, + } + } + + fn next(&mut self) -> TransactionId { + let next_transaction_id = self.next_transaction_id; + self.next_transaction_id = next_transaction_id.wrapping_add(1); + next_transaction_id + } +} + +/// Modbus TCP client +#[derive(Debug)] +pub(crate) struct Client { + framed: Option>, + transaction_id_generator: TransactionIdGenerator, + unit_id: UnitId, +} + +impl Client +where + T: AsyncRead + AsyncWrite + Unpin, +{ + pub(crate) fn new(transport: T, slave: Slave) -> Self { + let framed = Framed::new(transport, ClientCodec::new()); + let transaction_id_generator = TransactionIdGenerator::new(); + let unit_id: UnitId = slave.into(); + Self { + framed: Some(framed), + transaction_id_generator, + unit_id, + } + } + + fn next_request_hdr(&mut self, unit_id: UnitId) -> Header { + let transaction_id = self.transaction_id_generator.next(); + Header { + transaction_id, + unit_id, + } + } + + fn next_request_adu<'a, R>(&mut self, req: R) -> RequestAdu<'a> + where + R: Into>, + { + RequestAdu { + hdr: self.next_request_hdr(self.unit_id), + pdu: req.into(), + } + } + + fn framed(&mut self) -> io::Result<&mut Framed> { + let Some(framed) = &mut self.framed else { + return Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected")); + }; + Ok(framed) + } + + pub(crate) async fn call(&mut self, req: Request<'_>) -> Result { + log::debug!("Call {req:?}"); + + let req_function_code = req.function_code(); + let req_adu = self.next_request_adu(req); + let req_hdr = req_adu.hdr; + + let framed = self.framed()?; + + framed.read_buffer_mut().clear(); + framed.send(req_adu).await?; + + let res_adu = framed.next().await.ok_or_else(io::Error::last_os_error)??; + let ResponseAdu { + hdr: res_hdr, + pdu: res_pdu, + } = res_adu; + let ResponsePdu(result) = res_pdu; + + // Match headers of request and response. + if let Err(message) = verify_response_header(&req_hdr, &res_hdr) { + return Err(ProtocolError::HeaderMismatch { message, result }.into()); + } + + // Match function codes of request and response. + let rsp_function_code = match &result { + Ok(response) => response.function_code(), + Err(ExceptionResponse { function, .. }) => *function, + }; + if req_function_code != rsp_function_code { + return Err(ProtocolError::FunctionCodeMismatch { + request: req_function_code, + result, + } + .into()); + } + + Ok(result.map_err( + |ExceptionResponse { + function: _, + exception, + }| exception, + )) + } + + async fn disconnect(&mut self) -> io::Result<()> { + let Some(framed) = self.framed.take() else { + // Already disconnected. + return Ok(()); + }; + disconnect_framed(framed).await + } +} + +impl SlaveContext for Client { + fn set_slave(&mut self, slave: Slave) { + self.unit_id = slave.into(); + } +} + +#[async_trait::async_trait] +impl crate::client::Client for Client +where + T: AsyncRead + AsyncWrite + Send + Unpin, +{ + async fn call(&mut self, req: Request<'_>) -> Result { + self.call(req).await + } + + async fn disconnect(&mut self) -> io::Result<()> { + self.disconnect().await + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validate_same_headers() { + // Given + let req_hdr = Header { + unit_id: 0, + transaction_id: 42, + }; + let rsp_hdr = Header { + unit_id: 0, + transaction_id: 42, + }; + + // When + let result = verify_response_header(&req_hdr, &rsp_hdr); + + // Then + assert!(result.is_ok()); + } + + #[test] + fn invalid_validate_not_same_unit_id() { + // Given + let req_hdr = Header { + unit_id: 0, + transaction_id: 42, + }; + let rsp_hdr = Header { + unit_id: 5, + transaction_id: 42, + }; + + // When + let result = verify_response_header(&req_hdr, &rsp_hdr); + + // Then + assert!(result.is_err()); + } + + #[test] + fn invalid_validate_not_same_transaction_id() { + // Given + let req_hdr = Header { + unit_id: 0, + transaction_id: 42, + }; + let rsp_hdr = Header { + unit_id: 0, + transaction_id: 86, + }; + + // When + let result = verify_response_header(&req_hdr, &rsp_hdr); + + // Then + assert!(result.is_err()); + } +} diff --git a/src/codec/rtu.rs b/src/codec/rtu.rs index 3da768ab..0e53a59e 100644 --- a/src/codec/rtu.rs +++ b/src/codec/rtu.rs @@ -10,7 +10,7 @@ use tokio_util::codec::{Decoder, Encoder}; use crate::{ bytes::{Buf, BufMut, Bytes, BytesMut}, frame::{rtu::*, MEI_TYPE_READ_DEVICE_IDENTIFICATION}, - slave::SlaveId, + Slave, SlaveId, }; use super::{encode_request_pdu, request_pdu_size, RequestPdu}; @@ -324,7 +324,9 @@ impl Decoder for ClientCodec { return Ok(None); }; - let hdr = Header { slave_id }; + let hdr = Header { + slave: Slave(slave_id), + }; // Decoding of the PDU is unlikely to fail due // to transmission errors, because the frame's bytes @@ -349,7 +351,9 @@ impl Decoder for ServerCodec { return Ok(None); }; - let hdr = Header { slave_id }; + let hdr = Header { + slave: Slave(slave_id), + }; // Decoding of the PDU is unlikely to fail due // to transmission errors, because the frame's bytes @@ -375,7 +379,7 @@ impl<'a> Encoder> for ClientCodec { let buf_offset = buf.len(); let request_pdu_size = request_pdu_size(&request)?; buf.reserve(request_pdu_size + 3); - buf.put_u8(hdr.slave_id); + buf.put_u8(hdr.slave.into()); encode_request_pdu(buf, &request); let crc = calc_crc(&buf[buf_offset..]); write_crc(buf, crc); @@ -395,7 +399,7 @@ impl Encoder for ServerCodec { let buf_offset = buf.len(); let response_result_pdu_size = super::response_result_pdu_size(&pdu_res)?; buf.reserve(response_result_pdu_size + 3); - buf.put_u8(hdr.slave_id); + buf.put_u8(hdr.slave.into()); super::encode_response_result_pdu(buf, &pdu_res); let crc = calc_crc(&buf[buf_offset..]); write_crc(buf, crc); @@ -738,7 +742,7 @@ mod tests { ); let ResponseAdu { hdr, pdu } = codec.decode(&mut buf).unwrap().unwrap(); assert_eq!(buf.len(), 1); - assert_eq!(hdr.slave_id, 0x01); + assert_eq!(hdr.slave, Slave(0x01)); if let Ok(Response::ReadHoldingRegisters(data)) = pdu.into() { assert_eq!(data.len(), 2); assert_eq!(data, vec![0x8902, 0x42C7]); @@ -769,7 +773,7 @@ mod tests { ); let ResponseAdu { hdr, pdu } = codec.decode(&mut buf).unwrap().unwrap(); assert_eq!(buf.len(), 1); - assert_eq!(hdr.slave_id, 0x01); + assert_eq!(hdr.slave, Slave(0x01)); if let Ok(Response::ReadHoldingRegisters(data)) = pdu.into() { assert_eq!(data.len(), 2); assert_eq!(data, vec![0x8902, 0x42C7]); @@ -807,7 +811,9 @@ mod tests { let req = Request::ReadHoldingRegisters(0x082b, 2); let pdu = req.into(); let slave_id = 0x01; - let hdr = Header { slave_id }; + let hdr = Header { + slave: Slave(slave_id), + }; let adu = RequestAdu { hdr, pdu }; codec.encode(adu, &mut buf).unwrap(); @@ -823,7 +829,9 @@ mod tests { let req = Request::ReadHoldingRegisters(0x082b, 2); let pdu = req.into(); let slave_id = 0x01; - let hdr = Header { slave_id }; + let hdr = Header { + slave: Slave(slave_id), + }; let adu = RequestAdu { hdr, pdu }; let mut buf = BytesMut::with_capacity(40); #[allow(unsafe_code)] diff --git a/src/frame/mod.rs b/src/frame/mod.rs index b04f1040..77f9c623 100644 --- a/src/frame/mod.rs +++ b/src/frame/mod.rs @@ -21,7 +21,7 @@ pub(crate) const MEI_TYPE_READ_DEVICE_IDENTIFICATION: u8 = 0x0E; /// A Modbus function code. /// /// All function codes as defined by the protocol specification V1.1b3. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum FunctionCode { /// 01 (0x01) Read Coils. ReadCoils, @@ -309,7 +309,7 @@ impl Request<'_> { #[derive(Debug, Clone, PartialEq, Eq)] pub struct SlaveRequest<'a> { /// Slave id from the request - pub slave: crate::slave::SlaveId, + pub slave: crate::SlaveId, /// A `Request` enum pub request: Request<'a>, } @@ -804,6 +804,24 @@ impl error::Error for ExceptionResponse { } } +/// Check that `req_hdr` is the same `Header` as `rsp_hdr`. +/// +/// # Errors +/// +/// If the 2 headers are different, an error message with the details will be returned. +#[cfg(any(feature = "rtu", feature = "tcp"))] +pub(crate) fn verify_response_header( + req_hdr: &H, + rsp_hdr: &H, +) -> Result<(), String> { + if req_hdr != rsp_hdr { + return Err(format!( + "expected/request = {req_hdr:?}, actual/response = {rsp_hdr:?}" + )); + } + Ok(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/frame/rtu.rs b/src/frame/rtu.rs index 19100bd7..6135024d 100644 --- a/src/frame/rtu.rs +++ b/src/frame/rtu.rs @@ -3,11 +3,11 @@ use super::*; -use crate::slave::SlaveId; +use crate::{client::rtu::RequestContext, ProtocolError, Result, Slave}; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub(crate) struct Header { - pub(crate) slave_id: SlaveId, + pub(crate) slave: Slave, } #[derive(Debug, Clone)] @@ -16,12 +16,60 @@ pub struct RequestAdu<'a> { pub(crate) pdu: RequestPdu<'a>, } +impl RequestAdu<'_> { + pub(crate) fn context(&self) -> RequestContext { + RequestContext { + function_code: self.pdu.0.function_code(), + header: self.hdr, + } + } +} + #[derive(Debug, Clone)] pub(crate) struct ResponseAdu { pub(crate) hdr: Header, pub(crate) pdu: ResponsePdu, } +impl ResponseAdu { + pub(crate) fn try_into_response(self, request_context: RequestContext) -> Result { + let RequestContext { + function_code: req_function_code, + header: req_hdr, + } = request_context; + + let ResponseAdu { + hdr: rsp_hdr, + pdu: rsp_pdu, + } = self; + let ResponsePdu(result) = rsp_pdu; + + if let Err(message) = verify_response_header(&req_hdr, &rsp_hdr) { + return Err(ProtocolError::HeaderMismatch { message, result }.into()); + } + + // Match function codes of request and response. + let rsp_function_code = match &result { + Ok(response) => response.function_code(), + Err(ExceptionResponse { function, .. }) => *function, + }; + if req_function_code != rsp_function_code { + return Err(ProtocolError::FunctionCodeMismatch { + request: req_function_code, + result, + } + .into()); + } + + Ok(result.map_err( + |ExceptionResponse { + function: _, + exception, + }| exception, + )) + } +} + impl<'a> From> for Request<'a> { fn from(from: RequestAdu<'a>) -> Self { from.pdu.into() @@ -31,9 +79,41 @@ impl<'a> From> for Request<'a> { #[cfg(feature = "server")] impl<'a> From> for SlaveRequest<'a> { fn from(from: RequestAdu<'a>) -> Self { + let RequestAdu { hdr, pdu } = from; Self { - slave: from.hdr.slave_id, - request: from.pdu.into(), + slave: hdr.slave.into(), + request: pdu.into(), } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validate_same_headers() { + // Given + let req_hdr = Header { slave: Slave(0) }; + let rsp_hdr = Header { slave: Slave(0) }; + + // When + let result = verify_response_header(&req_hdr, &rsp_hdr); + + // Then + assert!(result.is_ok()); + } + + #[test] + fn invalid_validate_not_same_slave_id() { + // Given + let req_hdr = Header { slave: Slave(0) }; + let rsp_hdr = Header { slave: Slave(5) }; + + // When + let result = verify_response_header(&req_hdr, &rsp_hdr); + + // Then + assert!(result.is_err()); + } +} diff --git a/src/lib.rs b/src/lib.rs index a397923b..2f8a1291 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -61,7 +61,5 @@ pub use self::frame::{ /// 2. [`ExceptionCode`]: An error occurred on the _Modbus_ server. pub type Result = std::result::Result, Error>; -mod service; - #[cfg(feature = "server")] pub mod server; diff --git a/src/service/mod.rs b/src/service/mod.rs deleted file mode 100644 index 3bb5facc..00000000 --- a/src/service/mod.rs +++ /dev/null @@ -1,43 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2017-2025 slowtec GmbH -// SPDX-License-Identifier: MIT OR Apache-2.0 - -#[cfg(feature = "rtu")] -pub(crate) mod rtu; - -#[cfg(feature = "tcp")] -pub(crate) mod tcp; - -#[cfg(any(feature = "rtu", feature = "tcp"))] -async fn disconnect(framed: tokio_util::codec::Framed) -> std::io::Result<()> -where - T: tokio::io::AsyncWrite + Unpin, -{ - use tokio::io::AsyncWriteExt as _; - - framed - .into_inner() - .shutdown() - .await - .or_else(|err| match err.kind() { - std::io::ErrorKind::NotConnected | std::io::ErrorKind::BrokenPipe => { - // Already disconnected. - Ok(()) - } - _ => Err(err), - }) -} - -/// Check that `req_hdr` is the same `Header` as `rsp_hdr`. -/// -/// # Errors -/// -/// If the 2 headers are different, an error message with the details will be returned. -#[cfg(any(feature = "rtu", feature = "tcp"))] -fn verify_response_header(req_hdr: &H, rsp_hdr: &H) -> Result<(), String> { - if req_hdr != rsp_hdr { - return Err(format!( - "expected/request = {req_hdr:?}, actual/response = {rsp_hdr:?}" - )); - } - Ok(()) -} diff --git a/src/service/rtu.rs b/src/service/rtu.rs deleted file mode 100644 index 82bc63dc..00000000 --- a/src/service/rtu.rs +++ /dev/null @@ -1,216 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2017-2025 slowtec GmbH -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use std::io; - -use futures_util::{SinkExt as _, StreamExt as _}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::codec::Framed; - -use crate::{ - codec, - frame::{rtu::*, *}, - slave::*, - ProtocolError, Result, -}; - -use super::{disconnect, verify_response_header}; - -/// Modbus RTU client -#[derive(Debug)] -pub(crate) struct Client { - framed: Option>, - slave_id: SlaveId, -} - -impl Client -where - T: AsyncRead + AsyncWrite + Unpin, -{ - pub(crate) fn new(transport: T, slave: Slave) -> Self { - let framed = Framed::new(transport, codec::rtu::ClientCodec::default()); - let slave_id = slave.into(); - Self { - slave_id, - framed: Some(framed), - } - } - - fn framed(&mut self) -> io::Result<&mut Framed> { - let Some(framed) = &mut self.framed else { - return Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected")); - }; - Ok(framed) - } - - fn next_request_adu<'a, R>(&self, req: R) -> RequestAdu<'a> - where - R: Into>, - { - let slave_id = self.slave_id; - let hdr = Header { slave_id }; - let pdu = req.into(); - RequestAdu { hdr, pdu } - } - - async fn call(&mut self, req: Request<'_>) -> Result { - log::debug!("Call {req:?}"); - - let req_function_code = req.function_code(); - let req_adu = self.next_request_adu(req); - let req_hdr = req_adu.hdr; - - let framed = self.framed()?; - - framed.read_buffer_mut().clear(); - framed.send(req_adu).await?; - - let res_adu = framed - .next() - .await - .unwrap_or_else(|| Err(io::Error::from(io::ErrorKind::BrokenPipe)))?; - let ResponseAdu { - hdr: res_hdr, - pdu: res_pdu, - } = res_adu; - let ResponsePdu(result) = res_pdu; - - // Match headers of request and response. - if let Err(message) = verify_response_header(&req_hdr, &res_hdr) { - return Err(ProtocolError::HeaderMismatch { message, result }.into()); - } - - // Match function codes of request and response. - let rsp_function_code = match &result { - Ok(response) => response.function_code(), - Err(ExceptionResponse { function, .. }) => *function, - }; - if req_function_code != rsp_function_code { - return Err(ProtocolError::FunctionCodeMismatch { - request: req_function_code, - result, - } - .into()); - } - - Ok(result.map_err( - |ExceptionResponse { - function: _, - exception, - }| exception, - )) - } - - async fn disconnect(&mut self) -> io::Result<()> { - let Some(framed) = self.framed.take() else { - // Already disconnected. - return Ok(()); - }; - disconnect(framed).await - } -} - -impl SlaveContext for Client { - fn set_slave(&mut self, slave: Slave) { - self.slave_id = slave.into(); - } -} - -#[async_trait::async_trait] -impl crate::client::Client for Client -where - T: AsyncRead + AsyncWrite + Send + Unpin, -{ - async fn call(&mut self, req: Request<'_>) -> Result { - self.call(req).await - } - - async fn disconnect(&mut self) -> io::Result<()> { - self.disconnect().await - } -} - -#[cfg(test)] -mod tests { - - use core::{ - pin::Pin, - task::{Context, Poll}, - }; - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, Result}; - - use crate::{ - service::{rtu::Header, verify_response_header}, - Error, - }; - - #[test] - fn validate_same_headers() { - // Given - let req_hdr = Header { slave_id: 0 }; - let rsp_hdr = Header { slave_id: 0 }; - - // When - let result = verify_response_header(&req_hdr, &rsp_hdr); - - // Then - assert!(result.is_ok()); - } - - #[test] - fn invalid_validate_not_same_slave_id() { - // Given - let req_hdr = Header { slave_id: 0 }; - let rsp_hdr = Header { slave_id: 5 }; - - // When - let result = verify_response_header(&req_hdr, &rsp_hdr); - - // Then - assert!(result.is_err()); - } - - #[derive(Debug)] - struct MockTransport; - - impl Unpin for MockTransport {} - - impl AsyncRead for MockTransport { - fn poll_read( - self: Pin<&mut Self>, - _: &mut Context<'_>, - _: &mut ReadBuf<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - } - - impl AsyncWrite for MockTransport { - fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll> { - Poll::Ready(Ok(2)) - } - - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - unimplemented!() - } - } - - #[tokio::test] - async fn handle_broken_pipe() { - let transport = MockTransport; - let mut client = - crate::service::rtu::Client::new(transport, crate::service::rtu::Slave::broadcast()); - let res = client - .call(crate::service::rtu::Request::ReadCoils(0x00, 5)) - .await; - assert!(res.is_err()); - let err = res.err().unwrap(); - assert!( - matches!(err, Error::Transport(err) if err.kind() == std::io::ErrorKind::BrokenPipe) - ); - } -} diff --git a/src/service/tcp.rs b/src/service/tcp.rs deleted file mode 100644 index ff4647a2..00000000 --- a/src/service/tcp.rs +++ /dev/null @@ -1,226 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2017-2025 slowtec GmbH -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use std::io; - -use futures_util::{SinkExt as _, StreamExt as _}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::codec::Framed; - -use crate::{ - codec, - frame::{ - tcp::{Header, RequestAdu, ResponseAdu, TransactionId, UnitId}, - RequestPdu, ResponsePdu, - }, - service::verify_response_header, - slave::*, - ExceptionResponse, ProtocolError, Request, Response, Result, -}; - -use super::disconnect; - -const INITIAL_TRANSACTION_ID: TransactionId = 0; - -#[derive(Debug)] -struct TransactionIdGenerator { - next_transaction_id: TransactionId, -} - -impl TransactionIdGenerator { - const fn new() -> Self { - Self { - next_transaction_id: INITIAL_TRANSACTION_ID, - } - } - - fn next(&mut self) -> TransactionId { - let next_transaction_id = self.next_transaction_id; - self.next_transaction_id = next_transaction_id.wrapping_add(1); - next_transaction_id - } -} - -/// Modbus TCP client -#[derive(Debug)] -pub(crate) struct Client { - framed: Option>, - transaction_id_generator: TransactionIdGenerator, - unit_id: UnitId, -} - -impl Client -where - T: AsyncRead + AsyncWrite + Unpin, -{ - pub(crate) fn new(transport: T, slave: Slave) -> Self { - let framed = Framed::new(transport, codec::tcp::ClientCodec::new()); - let transaction_id_generator = TransactionIdGenerator::new(); - let unit_id: UnitId = slave.into(); - Self { - framed: Some(framed), - transaction_id_generator, - unit_id, - } - } - - fn next_request_hdr(&mut self, unit_id: UnitId) -> Header { - let transaction_id = self.transaction_id_generator.next(); - Header { - transaction_id, - unit_id, - } - } - - fn next_request_adu<'a, R>(&mut self, req: R) -> RequestAdu<'a> - where - R: Into>, - { - RequestAdu { - hdr: self.next_request_hdr(self.unit_id), - pdu: req.into(), - } - } - - fn framed(&mut self) -> io::Result<&mut Framed> { - let Some(framed) = &mut self.framed else { - return Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected")); - }; - Ok(framed) - } - - pub(crate) async fn call(&mut self, req: Request<'_>) -> Result { - log::debug!("Call {req:?}"); - - let req_function_code = req.function_code(); - let req_adu = self.next_request_adu(req); - let req_hdr = req_adu.hdr; - - let framed = self.framed()?; - - framed.read_buffer_mut().clear(); - framed.send(req_adu).await?; - - let res_adu = framed.next().await.ok_or_else(io::Error::last_os_error)??; - let ResponseAdu { - hdr: res_hdr, - pdu: res_pdu, - } = res_adu; - let ResponsePdu(result) = res_pdu; - - // Match headers of request and response. - if let Err(message) = verify_response_header(&req_hdr, &res_hdr) { - return Err(ProtocolError::HeaderMismatch { message, result }.into()); - } - - // Match function codes of request and response. - let rsp_function_code = match &result { - Ok(response) => response.function_code(), - Err(ExceptionResponse { function, .. }) => *function, - }; - if req_function_code != rsp_function_code { - return Err(ProtocolError::FunctionCodeMismatch { - request: req_function_code, - result, - } - .into()); - } - - Ok(result.map_err( - |ExceptionResponse { - function: _, - exception, - }| exception, - )) - } - - async fn disconnect(&mut self) -> io::Result<()> { - let Some(framed) = self.framed.take() else { - // Already disconnected. - return Ok(()); - }; - disconnect(framed).await - } -} - -impl SlaveContext for Client { - fn set_slave(&mut self, slave: Slave) { - self.unit_id = slave.into(); - } -} - -#[async_trait::async_trait] -impl crate::client::Client for Client -where - T: AsyncRead + AsyncWrite + Send + Unpin, -{ - async fn call(&mut self, req: Request<'_>) -> Result { - self.call(req).await - } - - async fn disconnect(&mut self) -> io::Result<()> { - self.disconnect().await - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn validate_same_headers() { - // Given - let req_hdr = Header { - unit_id: 0, - transaction_id: 42, - }; - let rsp_hdr = Header { - unit_id: 0, - transaction_id: 42, - }; - - // When - let result = verify_response_header(&req_hdr, &rsp_hdr); - - // Then - assert!(result.is_ok()); - } - - #[test] - fn invalid_validate_not_same_unit_id() { - // Given - let req_hdr = Header { - unit_id: 0, - transaction_id: 42, - }; - let rsp_hdr = Header { - unit_id: 5, - transaction_id: 42, - }; - - // When - let result = verify_response_header(&req_hdr, &rsp_hdr); - - // Then - assert!(result.is_err()); - } - - #[test] - fn invalid_validate_not_same_transaction_id() { - // Given - let req_hdr = Header { - unit_id: 0, - transaction_id: 42, - }; - let rsp_hdr = Header { - unit_id: 0, - transaction_id: 86, - }; - - // When - let result = verify_response_header(&req_hdr, &rsp_hdr); - - // Then - assert!(result.is_err()); - } -} diff --git a/src/slave.rs b/src/slave.rs index 8963edb1..87699756 100644 --- a/src/slave.rs +++ b/src/slave.rs @@ -9,7 +9,8 @@ use std::{fmt, num::ParseIntError, str::FromStr}; pub type SlaveId = u8; /// A single byte for addressing Modbus slave devices. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] pub struct Slave(pub SlaveId); impl Slave { @@ -100,7 +101,7 @@ impl FromStr for Slave { impl fmt::Display for Slave { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} (0x{:0>2X})", self.0, self.0) + write!(f, "{} (0x{:02X})", self.0, self.0) } } @@ -147,9 +148,9 @@ mod tests { } #[test] - fn format() { - assert!(format!("{}", Slave(123)).contains("123")); - assert!(format!("{}", Slave(0x7B)).contains("0x7B")); - assert!(!format!("{}", Slave(0x7B)).contains("0x7b")); + fn display() { + assert_eq!("0 (0x00)", Slave(0).to_string()); + assert_eq!("123 (0x7B)", Slave(123).to_string()); + assert_eq!("123 (0x7B)", Slave(0x7B).to_string()); } }