web: add redirect to previous page after login with ?next= parameter

grcov
Manos Pitsidianakis 2023-04-15 13:35:12 +03:00
parent 71a18f31e4
commit 8fa4c910c1
Signed by: Manos Pitsidianakis
GPG Key ID: 7729C7707F7E09D0
6 changed files with 317 additions and 27 deletions

3
Cargo.lock generated
View File

@ -1519,9 +1519,11 @@ name = "mailpot-web"
version = "0.0.0+2023-04-07"
dependencies = [
"axum",
"axum-extra",
"axum-login",
"axum-sessions",
"chrono",
"dyn-clone",
"eyre",
"http",
"lazy_static",
@ -1533,6 +1535,7 @@ dependencies = [
"serde_json",
"tempfile",
"tokio",
"tower-http 0.3.5",
]
[[package]]

View File

@ -16,9 +16,11 @@ path = "src/main.rs"
[dependencies]
axum = { version = "^0.6" }
axum-extra = { version = "^0.7" }
axum-login = { version = "^0.5" }
axum-sessions = { version = "^0.5" }
chrono = { version = "^0.4" }
dyn-clone = { version = "^1" }
eyre = { version = "0.6" }
http = "0.2"
lazy_static = "^1.4"
@ -30,3 +32,4 @@ serde = { version = "^1", features = ["derive", ] }
serde_json = "^1"
tempfile = { version = "^3.5" }
tokio = { version = "1", features = ["full"] }
tower-http = { version = "^0.3" }

View File

@ -77,6 +77,7 @@ pub struct AuthFormPayload {
pub async fn ssh_signin(
mut session: WritableSession,
Query(next): Query<Next>,
auth: AuthContext,
State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
@ -87,7 +88,17 @@ pub async fn ssh_signin(
}) {
return err.into_response();
}
return Redirect::to(&format!("{}/settings/", state.root_url_prefix)).into_response();
return next
.or_else(|| format!("{}/settings/", state.root_url_prefix))
.into_response();
}
if next.next.is_some() {
if let Err(err) = session.add_message(Message {
message: "You need to be logged in to access this page.".into(),
level: Level::Info,
}) {
return err.into_response();
};
}
let now: i64 = chrono::offset::Utc::now().timestamp();
@ -154,6 +165,7 @@ pub async fn ssh_signin(
pub async fn ssh_signin_post(
mut session: WritableSession,
Query(next): Query<Next>,
mut auth: AuthContext,
Form(payload): Form<AuthFormPayload>,
state: Arc<AppState>,
@ -163,10 +175,7 @@ pub async fn ssh_signin_post(
message: "You are already logged in.".into(),
level: Level::Info,
})?;
return Ok(Redirect::to(&format!(
"{}/settings/",
state.root_url_prefix
)));
return Ok(next.or_else(|| format!("{}/settings/", state.root_url_prefix)));
}
let now: i64 = chrono::offset::Utc::now().timestamp();
@ -178,7 +187,15 @@ pub async fn ssh_signin_post(
message: "The token has expired. Please retry.".into(),
level: Level::Error,
})?;
return Ok(Redirect::to(&format!("{}/login/", state.root_url_prefix)));
return Ok(Redirect::to(&format!(
"{}/login/{}",
state.root_url_prefix,
if let Some(ref next) = next.next {
next.as_str()
} else {
""
}
)));
} else {
tok
}
@ -187,7 +204,15 @@ pub async fn ssh_signin_post(
message: "The token has expired. Please retry.".into(),
level: Level::Error,
})?;
return Ok(Redirect::to(&format!("{}/login/", state.root_url_prefix)));
return Ok(Redirect::to(&format!(
"{}/login/{}",
state.root_url_prefix,
if let Some(ref next) = next.next {
next.as_str()
} else {
""
}
)));
};
drop(session);
@ -229,10 +254,7 @@ pub async fn ssh_signin_post(
auth.login(&user)
.await
.map_err(|err| ResponseError::new(err.to_string(), StatusCode::BAD_REQUEST))?;
Ok(Redirect::to(&format!(
"{}/settings/",
state.root_url_prefix
)))
Ok(next.or_else(|| format!("{}/settings/", state.root_url_prefix)))
}
#[derive(Debug, Clone, Default)]
@ -360,6 +382,215 @@ pub async fn logout_handler(mut auth: AuthContext, State(state): State<Arc<AppSt
Redirect::to(&format!("{}/settings/", state.root_url_prefix))
}
pub mod auth_request {
use super::*;
use std::marker::PhantomData;
use std::ops::RangeBounds;
use axum::body::HttpBody;
use dyn_clone::DynClone;
use tower_http::auth::AuthorizeRequest;
trait RoleBounds<Role>: DynClone + Send + Sync {
fn contains(&self, role: Option<Role>) -> bool;
}
impl<T, Role> RoleBounds<Role> for T
where
Role: PartialOrd + PartialEq,
T: RangeBounds<Role> + Clone + Send + Sync,
{
fn contains(&self, role: Option<Role>) -> bool {
if let Some(role) = role {
RangeBounds::contains(self, &role)
} else {
role.is_none()
}
}
}
/// Type that performs login authorization.
///
/// See [`RequireAuthorizationLayer::login`] for more details.
pub struct Login<UserId, User, ResBody, Role = ()> {
login_url: Option<Arc<Cow<'static, str>>>,
redirect_field_name: Option<Arc<Cow<'static, str>>>,
role_bounds: Box<dyn RoleBounds<Role>>,
_user_id_type: PhantomData<UserId>,
_user_type: PhantomData<User>,
_body_type: PhantomData<fn() -> ResBody>,
}
impl<UserId, User, ResBody, Role> Clone for Login<UserId, User, ResBody, Role> {
fn clone(&self) -> Self {
Self {
login_url: self.login_url.clone(),
redirect_field_name: self.redirect_field_name.clone(),
role_bounds: dyn_clone::clone_box(&*self.role_bounds),
_user_id_type: PhantomData,
_user_type: PhantomData,
_body_type: PhantomData,
}
}
}
impl<UserId, User, ReqBody, ResBody, Role> AuthorizeRequest<ReqBody>
for Login<UserId, User, ResBody, Role>
where
Role: PartialOrd + PartialEq + Clone + Send + Sync + 'static,
User: AuthUser<UserId, Role>,
ResBody: HttpBody + Default,
{
type ResponseBody = ResBody;
fn authorize(
&mut self,
request: &mut Request<ReqBody>,
) -> Result<(), Response<Self::ResponseBody>> {
let user = request
.extensions()
.get::<Option<User>>()
.expect("Auth extension missing. Is the auth layer installed?");
match user {
Some(user) if self.role_bounds.contains(user.get_role()) => {
let user = user.clone();
request.extensions_mut().insert(user);
Ok(())
}
_ => {
let unauthorized_response = if let Some(ref login_url) = self.login_url {
let url: Cow<'static, str> =
if let Some(ref next) = self.redirect_field_name {
format!(
"{login_url}?{next}={}",
percent_encoding::utf8_percent_encode(
request.uri().path(),
percent_encoding::CONTROLS
)
)
.into()
} else {
login_url.as_ref().clone()
};
Response::builder()
.status(http::StatusCode::TEMPORARY_REDIRECT)
.header(http::header::LOCATION, url.as_ref())
.body(Default::default())
.unwrap()
} else {
Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.body(Default::default())
.unwrap()
};
Err(unauthorized_response)
}
}
}
}
/// A wrapper around [`tower_http::auth::RequireAuthorizationLayer`] which
/// provides login authorization.
pub struct RequireAuthorizationLayer<UserId, User, Role = ()>(UserId, User, Role);
impl<UserId, User, Role> RequireAuthorizationLayer<UserId, User, Role>
where
Role: PartialOrd + PartialEq + Clone + Send + Sync + 'static,
User: AuthUser<UserId, Role>,
{
/// Authorizes requests by requiring a logged in user, otherwise it rejects
/// with [`http::StatusCode::UNAUTHORIZED`].
pub fn login<ResBody>(
) -> tower_http::auth::RequireAuthorizationLayer<Login<UserId, User, ResBody, Role>>
where
ResBody: HttpBody + Default,
{
tower_http::auth::RequireAuthorizationLayer::custom(Login::<_, _, _, _> {
login_url: None,
redirect_field_name: None,
role_bounds: Box::new(..),
_user_id_type: PhantomData,
_user_type: PhantomData,
_body_type: PhantomData,
})
}
/// Authorizes requests by requiring a logged in user to have a specific
/// range of roles, otherwise it rejects with
/// [`http::StatusCode::UNAUTHORIZED`].
pub fn login_with_role<ResBody>(
role_bounds: impl RangeBounds<Role> + Clone + Send + Sync + 'static,
) -> tower_http::auth::RequireAuthorizationLayer<Login<UserId, User, ResBody, Role>>
where
ResBody: HttpBody + Default,
{
tower_http::auth::RequireAuthorizationLayer::custom(Login::<_, _, _, _> {
login_url: None,
redirect_field_name: None,
role_bounds: Box::new(role_bounds),
_user_id_type: PhantomData,
_user_type: PhantomData,
_body_type: PhantomData,
})
}
/// Authorizes requests by requiring a logged in user, otherwise it redirects to the
/// provided login URL.
///
/// If `redirect_field_name` is set to a value, the login page will receive the path it was
/// redirected from in the URI query part. For example, attempting to visit a protected path
/// `/protected` would redirect you to `/login?next=/protected` allowing you to know how to
/// return the visitor to their requested page.
pub fn login_or_redirect<ResBody>(
login_url: Arc<Cow<'static, str>>,
redirect_field_name: Option<Arc<Cow<'static, str>>>,
) -> tower_http::auth::RequireAuthorizationLayer<Login<UserId, User, ResBody, Role>>
where
ResBody: HttpBody + Default,
{
tower_http::auth::RequireAuthorizationLayer::custom(Login::<_, _, _, _> {
login_url: Some(login_url),
redirect_field_name,
role_bounds: Box::new(..),
_user_id_type: PhantomData,
_user_type: PhantomData,
_body_type: PhantomData,
})
}
/// Authorizes requests by requiring a logged in user to have a specific
/// range of roles, otherwise it redirects to the
/// provided login URL.
///
/// If `redirect_field_name` is set to a value, the login page will receive the path it was
/// redirected from in the URI query part. For example, attempting to visit a protected path
/// `/protected` would redirect you to `/login?next=/protected` allowing you to know how to
/// return the visitor to their requested page.
pub fn login_with_role_or_redirect<ResBody>(
role_bounds: impl RangeBounds<Role> + Clone + Send + Sync + 'static,
login_url: Arc<Cow<'static, str>>,
redirect_field_name: Option<Arc<Cow<'static, str>>>,
) -> tower_http::auth::RequireAuthorizationLayer<Login<UserId, User, ResBody, Role>>
where
ResBody: HttpBody + Default,
{
tower_http::auth::RequireAuthorizationLayer::custom(Login::<_, _, _, _> {
login_url: Some(login_url),
redirect_field_name,
role_bounds: Box::new(role_bounds),
_user_id_type: PhantomData,
_user_type: PhantomData,
_body_type: PhantomData,
})
}
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -18,14 +18,14 @@
*/
pub use axum::{
extract::{Path, State},
extract::{Path, Query, State},
handler::Handler,
response::{Html, IntoResponse, Redirect},
routing::{get, post},
Extension, Form, Router,
};
pub use http::{Request, Response, StatusCode};
pub use axum_extra::routing::RouterExt;
pub use axum_login::{
memory_store::MemoryStore as AuthMemoryStore, secrecy::SecretVec, AuthLayer, AuthUser,
@ -40,7 +40,9 @@ pub use axum_sessions::{
pub type AuthContext =
axum_login::extractors::AuthContext<i64, auth::User, Arc<AppState>, auth::Role>;
pub type RequireAuth = RequireAuthorizationLayer<i64, auth::User, auth::Role>;
pub type RequireAuth = auth::auth_request::RequireAuthorizationLayer<i64, auth::User, auth::Role>;
pub use http::{Request, Response, StatusCode};
use chrono::Datelike;
use minijinja::value::{Object, Value};

View File

@ -49,46 +49,63 @@ async fn main() {
let auth_layer = AuthLayer::new(shared_state.clone(), &secret);
let login_url = Arc::new(format!("{}/login/", shared_state.root_url_prefix).into());
let app = Router::new()
.route("/", get(root))
.route("/lists/:pk/", get(list))
.route("/lists/:pk/:msgid/", get(list_post))
.route("/lists/:pk/edit/", get(list_edit))
.route("/help/", get(help))
.route(
.route_with_tsr("/lists/:pk/", get(list))
.route_with_tsr("/lists/:pk/:msgid/", get(list_post))
.route_with_tsr("/lists/:pk/edit/", get(list_edit))
.route_with_tsr("/help/", get(help))
.route_with_tsr(
"/login/",
get(auth::ssh_signin).post({
let shared_state = Arc::clone(&shared_state);
move |session, auth, body| auth::ssh_signin_post(session, auth, body, shared_state)
move |session, query, auth, body| {
auth::ssh_signin_post(session, query, auth, body, shared_state)
}
}),
)
.route("/logout/", get(logout_handler))
.route(
.route_with_tsr("/logout/", get(logout_handler))
.route_with_tsr(
"/settings/",
get({
let shared_state = Arc::clone(&shared_state);
move |session, user| settings(session, user, shared_state)
}
.layer(RequireAuth::login()))
.layer(RequireAuth::login_or_redirect(
Arc::clone(&login_url),
Some(Arc::new("next".into())),
)))
.post(
{
let shared_state = Arc::clone(&shared_state);
move |session, auth, body| settings_post(session, auth, body, shared_state)
}
.layer(RequireAuth::login()),
.layer(RequireAuth::login_or_redirect(
Arc::clone(&login_url),
Some(Arc::new("next".into())),
)),
),
)
.route(
.route_with_tsr(
"/settings/list/:pk/",
get(user_list_subscription)
.layer(RequireAuth::login_with_role(Role::User..))
.layer(RequireAuth::login_with_role_or_redirect(
Role::User..,
Arc::clone(&login_url),
Some(Arc::new("next".into())),
))
.post({
let shared_state = Arc::clone(&shared_state);
move |session, path, user, body| {
user_list_subscription_post(session, path, user, body, shared_state)
}
})
.layer(RequireAuth::login_with_role(Role::User..)),
.layer(RequireAuth::login_with_role_or_redirect(
Role::User..,
Arc::clone(&login_url),
Some(Arc::new("next".into())),
)),
)
.layer(auth_layer)
.layer(session_layer)

View File

@ -276,3 +276,37 @@ impl<'de> serde::Deserialize<'de> for IntPOST {
deserializer.deserialize_any(IntVisitor)
}
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Next {
#[serde(default, deserialize_with = "empty_string_as_none")]
pub next: Option<String>,
}
impl Next {
#[inline]
pub fn or_else(self, cl: impl FnOnce() -> String) -> Redirect {
if let Some(next) = self.next {
Redirect::to(&next)
} else {
Redirect::to(&cl())
}
}
}
/// Serde deserialization decorator to map empty Strings to None,
fn empty_string_as_none<'de, D, T>(de: D) -> Result<Option<T>, D::Error>
where
D: serde::Deserializer<'de>,
T: std::str::FromStr,
T::Err: std::fmt::Display,
{
use serde::Deserialize;
let opt = Option::<String>::deserialize(de)?;
match opt.as_deref() {
None | Some("") => Ok(None),
Some(s) => std::str::FromStr::from_str(s)
.map_err(serde::de::Error::custom)
.map(Some),
}
}