From 42c4c615181be89dfbf0c80e0272e3b6f18ba3a7 Mon Sep 17 00:00:00 2001 From: Manos Pitsidianakis Date: Tue, 15 Sep 2020 01:17:32 +0300 Subject: [PATCH] melib/connections: impl tcp keepalive --- melib/src/connections.rs | 97 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 91 insertions(+), 6 deletions(-) diff --git a/melib/src/connections.rs b/melib/src/connections.rs index 6ce1872ad..0c28a13eb 100644 --- a/melib/src/connections.rs +++ b/melib/src/connections.rs @@ -22,6 +22,21 @@ //! Connections layers (TCP/fd/TLS/Deflate) to use with remote backends. #[cfg(feature = "deflate_compression")] use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression}; +#[cfg(any(target_os = "openbsd", target_os = "netbsd", target_os = "haiku"))] +use libc::SO_KEEPALIVE as KEEPALIVE_OPTION; +#[cfg(any(target_os = "macos", target_os = "ios"))] +use libc::TCP_KEEPALIVE as KEEPALIVE_OPTION; +#[cfg(not(any( + target_os = "openbsd", + target_os = "netbsd", + target_os = "haiku", + target_os = "macos", + target_os = "ios" +)))] +use libc::TCP_KEEPIDLE as KEEPALIVE_OPTION; +use libc::{self, c_int, c_void}; +use std::os::unix::io::AsRawFd; +use std::time::Duration; #[derive(Debug)] pub enum Connection { @@ -37,6 +52,18 @@ pub enum Connection { use Connection::*; +macro_rules! syscall { + ($fn: ident ( $($arg: expr),* $(,)* ) ) => {{ + #[allow(unused_unsafe)] + let res = unsafe { libc::$fn($($arg, )*) }; + if res == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(res) + } + }}; +} + impl Connection { pub const IO_BUF_SIZE: usize = 64 * 1024; #[cfg(feature = "deflate_compression")] @@ -74,7 +101,7 @@ impl Connection { } } - pub fn set_read_timeout(&self, dur: Option) -> std::io::Result<()> { + pub fn set_read_timeout(&self, dur: Option) -> std::io::Result<()> { match self { Tcp(ref t) => t.set_read_timeout(dur), #[cfg(feature = "tls")] @@ -85,7 +112,7 @@ impl Connection { } } - pub fn set_write_timeout(&self, dur: Option) -> std::io::Result<()> { + pub fn set_write_timeout(&self, dur: Option) -> std::io::Result<()> { match self { Tcp(ref t) => t.set_write_timeout(dur), #[cfg(feature = "tls")] @@ -95,6 +122,67 @@ impl Connection { Deflate { ref inner, .. } => inner.get_ref().get_ref().set_write_timeout(dur), } } + + pub fn keepalive(&self) -> std::io::Result> { + if let Fd(_) = self { + return Ok(None); + } + unsafe { + let raw: c_int = self.getsockopt(libc::SOL_SOCKET, libc::SO_KEEPALIVE)?; + if raw == 0 { + return Ok(None); + } + let secs: c_int = self.getsockopt(libc::IPPROTO_TCP, KEEPALIVE_OPTION)?; + Ok(Some(Duration::new(secs as u64, 0))) + } + } + + pub fn set_keepalive(&self, keepalive: Option) -> std::io::Result<()> { + if let Fd(_) = self { + return Ok(()); + } + unsafe { + self.setsockopt( + libc::SOL_SOCKET, + libc::SO_KEEPALIVE, + keepalive.is_some() as c_int, + )?; + if let Some(dur) = keepalive { + // TODO: checked cast here + self.setsockopt(libc::IPPROTO_TCP, KEEPALIVE_OPTION, dur.as_secs() as c_int)?; + } + Ok(()) + } + } + + unsafe fn setsockopt(&self, opt: c_int, val: c_int, payload: T) -> std::io::Result<()> + where + T: Copy, + { + let payload = &payload as *const T as *const c_void; + syscall!(setsockopt( + self.as_raw_fd(), + opt, + val, + payload, + std::mem::size_of::() as libc::socklen_t, + ))?; + Ok(()) + } + + unsafe fn getsockopt(&self, opt: c_int, val: c_int) -> std::io::Result { + let mut slot: T = std::mem::zeroed(); + let mut len = std::mem::size_of::() as libc::socklen_t; + syscall!(getsockopt( + self.as_raw_fd(), + opt, + val, + &mut slot as *mut _ as *mut _, + &mut len, + ))?; + assert_eq!(len as usize, std::mem::size_of::()); + Ok(slot) + } } impl Drop for Connection { @@ -191,10 +279,7 @@ pub fn lookup_ipv4(host: &str, port: u16) -> crate::Result use futures::future::{self, Either, Future}; -pub async fn timeout( - dur: Option, - f: impl Future, -) -> crate::Result { +pub async fn timeout(dur: Option, f: impl Future) -> crate::Result { futures::pin_mut!(f); if let Some(dur) = dur { match future::select(f, smol::Timer::after(dur)).await {