diff --git a/embassy-net/src/tcp.rs b/embassy-net/src/tcp.rs index 74eff9dae..be0e1a129 100644 --- a/embassy-net/src/tcp.rs +++ b/embassy-net/src/tcp.rs @@ -660,12 +660,25 @@ pub mod client { pub struct TcpClient<'d, D: Driver, const N: usize, const TX_SZ: usize = 1024, const RX_SZ: usize = 1024> { stack: &'d Stack, state: &'d TcpClientState, + socket_timeout: Option, } impl<'d, D: Driver, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpClient<'d, D, N, TX_SZ, RX_SZ> { /// Create a new `TcpClient`. pub fn new(stack: &'d Stack, state: &'d TcpClientState) -> Self { - Self { stack, state } + Self { + stack, + state, + socket_timeout: None, + } + } + + /// Set the timeout for each socket created by this `TcpClient`. + /// + /// If the timeout is set, the socket will be closed if no data is received for the + /// specified duration. + pub fn set_timeout(&mut self, timeout: Option) { + self.socket_timeout = timeout; } } @@ -691,6 +704,7 @@ pub mod client { }; let remote_endpoint = (addr, remote.port()); let mut socket = TcpConnection::new(&self.stack, self.state)?; + socket.socket.set_timeout(self.socket_timeout.clone()); socket .socket .connect(remote_endpoint)