web: add redirect to previous page after login with ?next= parameter
parent
71a18f31e4
commit
8fa4c910c1
|
@ -1519,9 +1519,11 @@ name = "mailpot-web"
|
|||
version = "0.0.0+2023-04-07"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"axum-extra",
|
||||
"axum-login",
|
||||
"axum-sessions",
|
||||
"chrono",
|
||||
"dyn-clone",
|
||||
"eyre",
|
||||
"http",
|
||||
"lazy_static",
|
||||
|
@ -1533,6 +1535,7 @@ dependencies = [
|
|||
"serde_json",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tower-http 0.3.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
@ -16,9 +16,11 @@ path = "src/main.rs"
|
|||
|
||||
[dependencies]
|
||||
axum = { version = "^0.6" }
|
||||
axum-extra = { version = "^0.7" }
|
||||
axum-login = { version = "^0.5" }
|
||||
axum-sessions = { version = "^0.5" }
|
||||
chrono = { version = "^0.4" }
|
||||
dyn-clone = { version = "^1" }
|
||||
eyre = { version = "0.6" }
|
||||
http = "0.2"
|
||||
lazy_static = "^1.4"
|
||||
|
@ -30,3 +32,4 @@ serde = { version = "^1", features = ["derive", ] }
|
|||
serde_json = "^1"
|
||||
tempfile = { version = "^3.5" }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tower-http = { version = "^0.3" }
|
||||
|
|
253
web/src/auth.rs
253
web/src/auth.rs
|
@ -77,6 +77,7 @@ pub struct AuthFormPayload {
|
|||
|
||||
pub async fn ssh_signin(
|
||||
mut session: WritableSession,
|
||||
Query(next): Query<Next>,
|
||||
auth: AuthContext,
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
|
@ -87,7 +88,17 @@ pub async fn ssh_signin(
|
|||
}) {
|
||||
return err.into_response();
|
||||
}
|
||||
return Redirect::to(&format!("{}/settings/", state.root_url_prefix)).into_response();
|
||||
return next
|
||||
.or_else(|| format!("{}/settings/", state.root_url_prefix))
|
||||
.into_response();
|
||||
}
|
||||
if next.next.is_some() {
|
||||
if let Err(err) = session.add_message(Message {
|
||||
message: "You need to be logged in to access this page.".into(),
|
||||
level: Level::Info,
|
||||
}) {
|
||||
return err.into_response();
|
||||
};
|
||||
}
|
||||
|
||||
let now: i64 = chrono::offset::Utc::now().timestamp();
|
||||
|
@ -154,6 +165,7 @@ pub async fn ssh_signin(
|
|||
|
||||
pub async fn ssh_signin_post(
|
||||
mut session: WritableSession,
|
||||
Query(next): Query<Next>,
|
||||
mut auth: AuthContext,
|
||||
Form(payload): Form<AuthFormPayload>,
|
||||
state: Arc<AppState>,
|
||||
|
@ -163,10 +175,7 @@ pub async fn ssh_signin_post(
|
|||
message: "You are already logged in.".into(),
|
||||
level: Level::Info,
|
||||
})?;
|
||||
return Ok(Redirect::to(&format!(
|
||||
"{}/settings/",
|
||||
state.root_url_prefix
|
||||
)));
|
||||
return Ok(next.or_else(|| format!("{}/settings/", state.root_url_prefix)));
|
||||
}
|
||||
|
||||
let now: i64 = chrono::offset::Utc::now().timestamp();
|
||||
|
@ -178,7 +187,15 @@ pub async fn ssh_signin_post(
|
|||
message: "The token has expired. Please retry.".into(),
|
||||
level: Level::Error,
|
||||
})?;
|
||||
return Ok(Redirect::to(&format!("{}/login/", state.root_url_prefix)));
|
||||
return Ok(Redirect::to(&format!(
|
||||
"{}/login/{}",
|
||||
state.root_url_prefix,
|
||||
if let Some(ref next) = next.next {
|
||||
next.as_str()
|
||||
} else {
|
||||
""
|
||||
}
|
||||
)));
|
||||
} else {
|
||||
tok
|
||||
}
|
||||
|
@ -187,7 +204,15 @@ pub async fn ssh_signin_post(
|
|||
message: "The token has expired. Please retry.".into(),
|
||||
level: Level::Error,
|
||||
})?;
|
||||
return Ok(Redirect::to(&format!("{}/login/", state.root_url_prefix)));
|
||||
return Ok(Redirect::to(&format!(
|
||||
"{}/login/{}",
|
||||
state.root_url_prefix,
|
||||
if let Some(ref next) = next.next {
|
||||
next.as_str()
|
||||
} else {
|
||||
""
|
||||
}
|
||||
)));
|
||||
};
|
||||
|
||||
drop(session);
|
||||
|
@ -229,10 +254,7 @@ pub async fn ssh_signin_post(
|
|||
auth.login(&user)
|
||||
.await
|
||||
.map_err(|err| ResponseError::new(err.to_string(), StatusCode::BAD_REQUEST))?;
|
||||
Ok(Redirect::to(&format!(
|
||||
"{}/settings/",
|
||||
state.root_url_prefix
|
||||
)))
|
||||
Ok(next.or_else(|| format!("{}/settings/", state.root_url_prefix)))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
|
@ -360,6 +382,215 @@ pub async fn logout_handler(mut auth: AuthContext, State(state): State<Arc<AppSt
|
|||
Redirect::to(&format!("{}/settings/", state.root_url_prefix))
|
||||
}
|
||||
|
||||
pub mod auth_request {
|
||||
use super::*;
|
||||
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::RangeBounds;
|
||||
|
||||
use axum::body::HttpBody;
|
||||
use dyn_clone::DynClone;
|
||||
use tower_http::auth::AuthorizeRequest;
|
||||
|
||||
trait RoleBounds<Role>: DynClone + Send + Sync {
|
||||
fn contains(&self, role: Option<Role>) -> bool;
|
||||
}
|
||||
|
||||
impl<T, Role> RoleBounds<Role> for T
|
||||
where
|
||||
Role: PartialOrd + PartialEq,
|
||||
T: RangeBounds<Role> + Clone + Send + Sync,
|
||||
{
|
||||
fn contains(&self, role: Option<Role>) -> bool {
|
||||
if let Some(role) = role {
|
||||
RangeBounds::contains(self, &role)
|
||||
} else {
|
||||
role.is_none()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Type that performs login authorization.
|
||||
///
|
||||
/// See [`RequireAuthorizationLayer::login`] for more details.
|
||||
pub struct Login<UserId, User, ResBody, Role = ()> {
|
||||
login_url: Option<Arc<Cow<'static, str>>>,
|
||||
redirect_field_name: Option<Arc<Cow<'static, str>>>,
|
||||
role_bounds: Box<dyn RoleBounds<Role>>,
|
||||
_user_id_type: PhantomData<UserId>,
|
||||
_user_type: PhantomData<User>,
|
||||
_body_type: PhantomData<fn() -> ResBody>,
|
||||
}
|
||||
|
||||
impl<UserId, User, ResBody, Role> Clone for Login<UserId, User, ResBody, Role> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
login_url: self.login_url.clone(),
|
||||
redirect_field_name: self.redirect_field_name.clone(),
|
||||
role_bounds: dyn_clone::clone_box(&*self.role_bounds),
|
||||
_user_id_type: PhantomData,
|
||||
_user_type: PhantomData,
|
||||
_body_type: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<UserId, User, ReqBody, ResBody, Role> AuthorizeRequest<ReqBody>
|
||||
for Login<UserId, User, ResBody, Role>
|
||||
where
|
||||
Role: PartialOrd + PartialEq + Clone + Send + Sync + 'static,
|
||||
User: AuthUser<UserId, Role>,
|
||||
ResBody: HttpBody + Default,
|
||||
{
|
||||
type ResponseBody = ResBody;
|
||||
|
||||
fn authorize(
|
||||
&mut self,
|
||||
request: &mut Request<ReqBody>,
|
||||
) -> Result<(), Response<Self::ResponseBody>> {
|
||||
let user = request
|
||||
.extensions()
|
||||
.get::<Option<User>>()
|
||||
.expect("Auth extension missing. Is the auth layer installed?");
|
||||
|
||||
match user {
|
||||
Some(user) if self.role_bounds.contains(user.get_role()) => {
|
||||
let user = user.clone();
|
||||
request.extensions_mut().insert(user);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
_ => {
|
||||
let unauthorized_response = if let Some(ref login_url) = self.login_url {
|
||||
let url: Cow<'static, str> =
|
||||
if let Some(ref next) = self.redirect_field_name {
|
||||
format!(
|
||||
"{login_url}?{next}={}",
|
||||
percent_encoding::utf8_percent_encode(
|
||||
request.uri().path(),
|
||||
percent_encoding::CONTROLS
|
||||
)
|
||||
)
|
||||
.into()
|
||||
} else {
|
||||
login_url.as_ref().clone()
|
||||
};
|
||||
Response::builder()
|
||||
.status(http::StatusCode::TEMPORARY_REDIRECT)
|
||||
.header(http::header::LOCATION, url.as_ref())
|
||||
.body(Default::default())
|
||||
.unwrap()
|
||||
} else {
|
||||
Response::builder()
|
||||
.status(http::StatusCode::UNAUTHORIZED)
|
||||
.body(Default::default())
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
Err(unauthorized_response)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper around [`tower_http::auth::RequireAuthorizationLayer`] which
|
||||
/// provides login authorization.
|
||||
pub struct RequireAuthorizationLayer<UserId, User, Role = ()>(UserId, User, Role);
|
||||
|
||||
impl<UserId, User, Role> RequireAuthorizationLayer<UserId, User, Role>
|
||||
where
|
||||
Role: PartialOrd + PartialEq + Clone + Send + Sync + 'static,
|
||||
User: AuthUser<UserId, Role>,
|
||||
{
|
||||
/// Authorizes requests by requiring a logged in user, otherwise it rejects
|
||||
/// with [`http::StatusCode::UNAUTHORIZED`].
|
||||
pub fn login<ResBody>(
|
||||
) -> tower_http::auth::RequireAuthorizationLayer<Login<UserId, User, ResBody, Role>>
|
||||
where
|
||||
ResBody: HttpBody + Default,
|
||||
{
|
||||
tower_http::auth::RequireAuthorizationLayer::custom(Login::<_, _, _, _> {
|
||||
login_url: None,
|
||||
redirect_field_name: None,
|
||||
role_bounds: Box::new(..),
|
||||
_user_id_type: PhantomData,
|
||||
_user_type: PhantomData,
|
||||
_body_type: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// Authorizes requests by requiring a logged in user to have a specific
|
||||
/// range of roles, otherwise it rejects with
|
||||
/// [`http::StatusCode::UNAUTHORIZED`].
|
||||
pub fn login_with_role<ResBody>(
|
||||
role_bounds: impl RangeBounds<Role> + Clone + Send + Sync + 'static,
|
||||
) -> tower_http::auth::RequireAuthorizationLayer<Login<UserId, User, ResBody, Role>>
|
||||
where
|
||||
ResBody: HttpBody + Default,
|
||||
{
|
||||
tower_http::auth::RequireAuthorizationLayer::custom(Login::<_, _, _, _> {
|
||||
login_url: None,
|
||||
redirect_field_name: None,
|
||||
role_bounds: Box::new(role_bounds),
|
||||
_user_id_type: PhantomData,
|
||||
_user_type: PhantomData,
|
||||
_body_type: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// Authorizes requests by requiring a logged in user, otherwise it redirects to the
|
||||
/// provided login URL.
|
||||
///
|
||||
/// If `redirect_field_name` is set to a value, the login page will receive the path it was
|
||||
/// redirected from in the URI query part. For example, attempting to visit a protected path
|
||||
/// `/protected` would redirect you to `/login?next=/protected` allowing you to know how to
|
||||
/// return the visitor to their requested page.
|
||||
pub fn login_or_redirect<ResBody>(
|
||||
login_url: Arc<Cow<'static, str>>,
|
||||
redirect_field_name: Option<Arc<Cow<'static, str>>>,
|
||||
) -> tower_http::auth::RequireAuthorizationLayer<Login<UserId, User, ResBody, Role>>
|
||||
where
|
||||
ResBody: HttpBody + Default,
|
||||
{
|
||||
tower_http::auth::RequireAuthorizationLayer::custom(Login::<_, _, _, _> {
|
||||
login_url: Some(login_url),
|
||||
redirect_field_name,
|
||||
role_bounds: Box::new(..),
|
||||
_user_id_type: PhantomData,
|
||||
_user_type: PhantomData,
|
||||
_body_type: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// Authorizes requests by requiring a logged in user to have a specific
|
||||
/// range of roles, otherwise it redirects to the
|
||||
/// provided login URL.
|
||||
///
|
||||
/// If `redirect_field_name` is set to a value, the login page will receive the path it was
|
||||
/// redirected from in the URI query part. For example, attempting to visit a protected path
|
||||
/// `/protected` would redirect you to `/login?next=/protected` allowing you to know how to
|
||||
/// return the visitor to their requested page.
|
||||
pub fn login_with_role_or_redirect<ResBody>(
|
||||
role_bounds: impl RangeBounds<Role> + Clone + Send + Sync + 'static,
|
||||
login_url: Arc<Cow<'static, str>>,
|
||||
redirect_field_name: Option<Arc<Cow<'static, str>>>,
|
||||
) -> tower_http::auth::RequireAuthorizationLayer<Login<UserId, User, ResBody, Role>>
|
||||
where
|
||||
ResBody: HttpBody + Default,
|
||||
{
|
||||
tower_http::auth::RequireAuthorizationLayer::custom(Login::<_, _, _, _> {
|
||||
login_url: Some(login_url),
|
||||
redirect_field_name,
|
||||
role_bounds: Box::new(role_bounds),
|
||||
_user_id_type: PhantomData,
|
||||
_user_type: PhantomData,
|
||||
_body_type: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -18,14 +18,14 @@
|
|||
*/
|
||||
|
||||
pub use axum::{
|
||||
extract::{Path, State},
|
||||
extract::{Path, Query, State},
|
||||
handler::Handler,
|
||||
response::{Html, IntoResponse, Redirect},
|
||||
routing::{get, post},
|
||||
Extension, Form, Router,
|
||||
};
|
||||
|
||||
pub use http::{Request, Response, StatusCode};
|
||||
pub use axum_extra::routing::RouterExt;
|
||||
|
||||
pub use axum_login::{
|
||||
memory_store::MemoryStore as AuthMemoryStore, secrecy::SecretVec, AuthLayer, AuthUser,
|
||||
|
@ -40,7 +40,9 @@ pub use axum_sessions::{
|
|||
pub type AuthContext =
|
||||
axum_login::extractors::AuthContext<i64, auth::User, Arc<AppState>, auth::Role>;
|
||||
|
||||
pub type RequireAuth = RequireAuthorizationLayer<i64, auth::User, auth::Role>;
|
||||
pub type RequireAuth = auth::auth_request::RequireAuthorizationLayer<i64, auth::User, auth::Role>;
|
||||
|
||||
pub use http::{Request, Response, StatusCode};
|
||||
|
||||
use chrono::Datelike;
|
||||
use minijinja::value::{Object, Value};
|
||||
|
|
|
@ -49,46 +49,63 @@ async fn main() {
|
|||
|
||||
let auth_layer = AuthLayer::new(shared_state.clone(), &secret);
|
||||
|
||||
let login_url = Arc::new(format!("{}/login/", shared_state.root_url_prefix).into());
|
||||
let app = Router::new()
|
||||
.route("/", get(root))
|
||||
.route("/lists/:pk/", get(list))
|
||||
.route("/lists/:pk/:msgid/", get(list_post))
|
||||
.route("/lists/:pk/edit/", get(list_edit))
|
||||
.route("/help/", get(help))
|
||||
.route(
|
||||
.route_with_tsr("/lists/:pk/", get(list))
|
||||
.route_with_tsr("/lists/:pk/:msgid/", get(list_post))
|
||||
.route_with_tsr("/lists/:pk/edit/", get(list_edit))
|
||||
.route_with_tsr("/help/", get(help))
|
||||
.route_with_tsr(
|
||||
"/login/",
|
||||
get(auth::ssh_signin).post({
|
||||
let shared_state = Arc::clone(&shared_state);
|
||||
move |session, auth, body| auth::ssh_signin_post(session, auth, body, shared_state)
|
||||
move |session, query, auth, body| {
|
||||
auth::ssh_signin_post(session, query, auth, body, shared_state)
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route("/logout/", get(logout_handler))
|
||||
.route(
|
||||
.route_with_tsr("/logout/", get(logout_handler))
|
||||
.route_with_tsr(
|
||||
"/settings/",
|
||||
get({
|
||||
let shared_state = Arc::clone(&shared_state);
|
||||
move |session, user| settings(session, user, shared_state)
|
||||
}
|
||||
.layer(RequireAuth::login()))
|
||||
.layer(RequireAuth::login_or_redirect(
|
||||
Arc::clone(&login_url),
|
||||
Some(Arc::new("next".into())),
|
||||
)))
|
||||
.post(
|
||||
{
|
||||
let shared_state = Arc::clone(&shared_state);
|
||||
move |session, auth, body| settings_post(session, auth, body, shared_state)
|
||||
}
|
||||
.layer(RequireAuth::login()),
|
||||
.layer(RequireAuth::login_or_redirect(
|
||||
Arc::clone(&login_url),
|
||||
Some(Arc::new("next".into())),
|
||||
)),
|
||||
),
|
||||
)
|
||||
.route(
|
||||
.route_with_tsr(
|
||||
"/settings/list/:pk/",
|
||||
get(user_list_subscription)
|
||||
.layer(RequireAuth::login_with_role(Role::User..))
|
||||
.layer(RequireAuth::login_with_role_or_redirect(
|
||||
Role::User..,
|
||||
Arc::clone(&login_url),
|
||||
Some(Arc::new("next".into())),
|
||||
))
|
||||
.post({
|
||||
let shared_state = Arc::clone(&shared_state);
|
||||
move |session, path, user, body| {
|
||||
user_list_subscription_post(session, path, user, body, shared_state)
|
||||
}
|
||||
})
|
||||
.layer(RequireAuth::login_with_role(Role::User..)),
|
||||
.layer(RequireAuth::login_with_role_or_redirect(
|
||||
Role::User..,
|
||||
Arc::clone(&login_url),
|
||||
Some(Arc::new("next".into())),
|
||||
)),
|
||||
)
|
||||
.layer(auth_layer)
|
||||
.layer(session_layer)
|
||||
|
|
|
@ -276,3 +276,37 @@ impl<'de> serde::Deserialize<'de> for IntPOST {
|
|||
deserializer.deserialize_any(IntVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct Next {
|
||||
#[serde(default, deserialize_with = "empty_string_as_none")]
|
||||
pub next: Option<String>,
|
||||
}
|
||||
|
||||
impl Next {
|
||||
#[inline]
|
||||
pub fn or_else(self, cl: impl FnOnce() -> String) -> Redirect {
|
||||
if let Some(next) = self.next {
|
||||
Redirect::to(&next)
|
||||
} else {
|
||||
Redirect::to(&cl())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Serde deserialization decorator to map empty Strings to None,
|
||||
fn empty_string_as_none<'de, D, T>(de: D) -> Result<Option<T>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
T: std::str::FromStr,
|
||||
T::Err: std::fmt::Display,
|
||||
{
|
||||
use serde::Deserialize;
|
||||
let opt = Option::<String>::deserialize(de)?;
|
||||
match opt.as_deref() {
|
||||
None | Some("") => Ok(None),
|
||||
Some(s) => std::str::FromStr::from_str(s)
|
||||
.map_err(serde::de::Error::custom)
|
||||
.map(Some),
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue