diff --git a/melib/src/backends/imap/connection.rs b/melib/src/backends/imap/connection.rs index e9a67665..6ad569fc 100644 --- a/melib/src/backends/imap/connection.rs +++ b/melib/src/backends/imap/connection.rs @@ -64,6 +64,10 @@ impl MailboxSelection { } } +async fn try_await(cl: impl Future> + Send) -> Result<()> { + cl.await +} + #[derive(Debug)] pub struct ImapConnection { pub stream: Result, @@ -91,7 +95,9 @@ impl ImapStream { if server_conf.danger_accept_invalid_certs { connector.danger_accept_invalid_certs(true); } - let connector = connector.build()?; + let connector = connector + .build() + .chain_err_kind(crate::error::ErrorKind::Network)?; let addr = if let Ok(a) = lookup_ipv4(path, server_conf.server_port) { a @@ -102,21 +108,27 @@ impl ImapStream { ))); }; - let mut socket = AsyncWrapper::new(Connection::Tcp(TcpStream::connect_timeout( - &addr, - std::time::Duration::new(4, 0), - )?))?; + let mut socket = AsyncWrapper::new(Connection::Tcp( + TcpStream::connect_timeout(&addr, std::time::Duration::new(4, 0)) + .chain_err_kind(crate::error::ErrorKind::Network)?, + )) + .chain_err_kind(crate::error::ErrorKind::Network)?; if server_conf.use_starttls { let mut buf = vec![0; 1024]; match server_conf.protocol { - ImapProtocol::IMAP => { - socket - .write_all(format!("M{} STARTTLS\r\n", cmd_id).as_bytes()) - .await? - } + ImapProtocol::IMAP => socket + .write_all(format!("M{} STARTTLS\r\n", cmd_id).as_bytes()) + .await + .chain_err_kind(crate::error::ErrorKind::Network)?, ImapProtocol::ManageSieve => { - socket.read(&mut buf).await?; - socket.write_all(b"STARTTLS\r\n").await?; + socket + .read(&mut buf) + .await + .chain_err_kind(crate::error::ErrorKind::Network)?; + socket + .write_all(b"STARTTLS\r\n") + .await + .chain_err_kind(crate::error::ErrorKind::Network)?; } } let mut response = String::with_capacity(1024); @@ -124,7 +136,10 @@ impl ImapStream { let now = std::time::Instant::now(); while now.elapsed().as_secs() < 3 { - let len = socket.read(&mut buf).await?; + let len = socket + .read(&mut buf) + .await + .chain_err_kind(crate::error::ErrorKind::Network)?; response.push_str(unsafe { std::str::from_utf8_unchecked(&buf[0..len]) }); match server_conf.protocol { ImapProtocol::IMAP => { @@ -157,7 +172,9 @@ impl ImapStream { { // FIXME: This is blocking - let socket = socket.into_inner()?; + let socket = socket + .into_inner() + .chain_err_kind(crate::error::ErrorKind::Network)?; let mut conn_result = connector.connect(path, socket); if let Err(native_tls::HandshakeError::WouldBlock(midhandshake_stream)) = conn_result @@ -173,12 +190,15 @@ impl ImapStream { midhandshake_stream = Some(stream); } p => { - p?; + p.chain_err_kind(crate::error::ErrorKind::Network)?; } } } } - AsyncWrapper::new(Connection::Tls(conn_result?))? + AsyncWrapper::new(Connection::Tls( + conn_result.chain_err_kind(crate::error::ErrorKind::Network)?, + )) + .chain_err_kind(crate::error::ErrorKind::Network)? } } else { let addr = if let Ok(a) = lookup_ipv4(path, server_conf.server_port) { @@ -189,10 +209,11 @@ impl ImapStream { &path ))); }; - AsyncWrapper::new(Connection::Tcp(TcpStream::connect_timeout( - &addr, - std::time::Duration::new(4, 0), - )?))? + AsyncWrapper::new(Connection::Tcp( + TcpStream::connect_timeout(&addr, std::time::Duration::new(4, 0)) + .chain_err_kind(crate::error::ErrorKind::Network)?, + )) + .chain_err_kind(crate::error::ErrorKind::Network)? }; let mut res = String::with_capacity(8 * 1024); let mut ret = ImapStream { @@ -256,7 +277,8 @@ impl ImapStream { return Err(MeliError::new(format!( "Could not connect to {}: server does not accept logins [LOGINDISABLED]", &server_conf.server_hostname - ))); + )) + .set_err_kind(crate::error::ErrorKind::Authentication)); } let mut capabilities = None; @@ -287,7 +309,8 @@ impl ImapStream { return Err(MeliError::new(format!( "Could not connect. Server replied with '{}'", l[tag_start.len()..].trim() - ))); + )) + .set_err_kind(crate::error::ErrorKind::Authentication)); } should_break = true; } @@ -365,7 +388,7 @@ impl ImapStream { continue; } Err(e) => { - return Err(MeliError::from(e)); + return Err(MeliError::from(e).set_err_kind(crate::error::ErrorKind::Network)); } } } @@ -381,42 +404,66 @@ impl ImapStream { } pub async fn send_command(&mut self, command: &[u8]) -> Result<()> { - let command = command.trim(); - match self.protocol { - ImapProtocol::IMAP => { - self.stream.write_all(b"M").await?; - self.stream - .write_all(self.cmd_id.to_string().as_bytes()) - .await?; - self.stream.write_all(b" ").await?; - self.cmd_id += 1; + if let Err(err) = try_await(async move { + let command = command.trim(); + match self.protocol { + ImapProtocol::IMAP => { + self.stream.write_all(b"M").await?; + self.stream + .write_all(self.cmd_id.to_string().as_bytes()) + .await?; + self.stream.write_all(b" ").await?; + self.cmd_id += 1; + } + ImapProtocol::ManageSieve => {} } - ImapProtocol::ManageSieve => {} - } - self.stream.write_all(command).await?; - self.stream.write_all(b"\r\n").await?; - match self.protocol { - ImapProtocol::IMAP => { - debug!("sent: M{} {}", self.cmd_id - 1, unsafe { - std::str::from_utf8_unchecked(command) - }); + self.stream.write_all(command).await?; + self.stream.write_all(b"\r\n").await?; + match self.protocol { + ImapProtocol::IMAP => { + debug!("sent: M{} {}", self.cmd_id - 1, unsafe { + std::str::from_utf8_unchecked(command) + }); + } + ImapProtocol::ManageSieve => {} } - ImapProtocol::ManageSieve => {} + Ok(()) + }) + .await + { + Err(err.set_err_kind(crate::error::ErrorKind::Network)) + } else { + Ok(()) } - Ok(()) } pub async fn send_literal(&mut self, data: &[u8]) -> Result<()> { - self.stream.write_all(data).await?; - self.stream.write_all(b"\r\n").await?; - Ok(()) + if let Err(err) = try_await(async move { + self.stream.write_all(data).await?; + self.stream.write_all(b"\r\n").await?; + Ok(()) + }) + .await + { + Err(err.set_err_kind(crate::error::ErrorKind::Network)) + } else { + Ok(()) + } } pub async fn send_raw(&mut self, raw: &[u8]) -> Result<()> { - self.stream.write_all(raw).await?; - self.stream.write_all(b"\r\n").await?; - Ok(()) + if let Err(err) = try_await(async move { + self.stream.write_all(raw).await?; + self.stream.write_all(b"\r\n").await?; + Ok(()) + }) + .await + { + Err(err.set_err_kind(crate::error::ErrorKind::Network)) + } else { + Ok(()) + } } } @@ -525,18 +572,39 @@ impl ImapConnection { } pub async fn send_command(&mut self, command: &[u8]) -> Result<()> { - self.stream.as_mut()?.send_command(command).await?; - Ok(()) + if let Err(err) = + try_await(async { self.stream.as_mut()?.send_command(command).await }).await + { + if err.kind.is_network() { + self.connect().await?; + } + Err(err) + } else { + Ok(()) + } } pub async fn send_literal(&mut self, data: &[u8]) -> Result<()> { - self.stream.as_mut()?.send_literal(data).await?; - Ok(()) + if let Err(err) = try_await(async { self.stream.as_mut()?.send_literal(data).await }).await + { + if err.kind.is_network() { + self.connect().await?; + } + Err(err) + } else { + Ok(()) + } } pub async fn send_raw(&mut self, raw: &[u8]) -> Result<()> { - self.stream.as_mut()?.send_raw(raw).await?; - Ok(()) + if let Err(err) = try_await(async { self.stream.as_mut()?.send_raw(raw).await }).await { + if err.kind.is_network() { + self.connect().await?; + } + Err(err) + } else { + Ok(()) + } } pub async fn select_mailbox( diff --git a/melib/src/connections.rs b/melib/src/connections.rs index 8d9a21f6..ffbc88a4 100644 --- a/melib/src/connections.rs +++ b/melib/src/connections.rs @@ -150,5 +150,6 @@ pub fn lookup_ipv4(host: &str, port: u16) -> crate::Result } } - Err(crate::error::MeliError::new("Cannot lookup address")) + Err(crate::error::MeliError::new("Cannot lookup address") + .set_kind(crate::error::ErrorKind::Network)) } diff --git a/melib/src/error.rs b/melib/src/error.rs index bdb5032c..f4ebb5eb 100644 --- a/melib/src/error.rs +++ b/melib/src/error.rs @@ -34,17 +34,42 @@ use std::sync::Arc; pub type Result = result::Result; +#[derive(Debug, Copy, PartialEq, Clone)] +pub enum ErrorKind { + None, + Authentication, + Network, +} + +impl ErrorKind { + pub fn is_network(&self) -> bool { + match self { + ErrorKind::Network => true, + _ => false, + } + } + + pub fn is_authentication(&self) -> bool { + match self { + ErrorKind::Authentication => true, + _ => false, + } + } +} + #[derive(Debug, Clone)] pub struct MeliError { pub summary: Option>, pub details: Cow<'static, str>, pub source: Option>, + pub kind: ErrorKind, } pub trait IntoMeliError { fn set_err_summary(self, msg: M) -> MeliError where M: Into>; + fn set_err_kind(self, kind: ErrorKind) -> MeliError; } pub trait ResultIntoMeliError { @@ -52,6 +77,8 @@ pub trait ResultIntoMeliError { where F: Fn() -> M, M: Into>; + + fn chain_err_kind(self, kind: ErrorKind) -> Result; } impl> IntoMeliError for I { @@ -63,6 +90,12 @@ impl> IntoMeliError for I { let err: MeliError = self.into(); err.set_summary(msg) } + + #[inline] + fn set_err_kind(self, kind: ErrorKind) -> MeliError { + let err: MeliError = self.into(); + err.set_kind(kind) + } } impl> ResultIntoMeliError for std::result::Result { @@ -74,6 +107,11 @@ impl> ResultIntoMeliError for std::result::Result { self.map_err(|err| err.set_err_summary(msg_fn())) } + + #[inline] + fn chain_err_kind(self, kind: ErrorKind) -> Result { + self.map_err(|err| err.set_err_kind(kind)) + } } impl MeliError { @@ -85,6 +123,7 @@ impl MeliError { summary: None, details: msg.into(), source: None, + kind: ErrorKind::None, } } @@ -107,6 +146,11 @@ impl MeliError { self.source = new_val; self } + + pub fn set_kind(mut self, new_val: ErrorKind) -> MeliError { + self.kind = new_val; + self + } } impl fmt::Display for MeliError {