use crate::AuthError;
use base64::prelude::*;
use futures_util::{
future,
future::{FutureExt, TryFutureExt}
};
use gotham::{
anyhow,
cookie::CookieJar,
handler::HandlerFuture,
hyper::header::{HeaderMap, HeaderName, AUTHORIZATION},
middleware::{cookie::CookieParser, Middleware, NewMiddleware},
prelude::*,
state::State
};
use jsonwebtoken::DecodingKey;
use serde::de::DeserializeOwned;
use std::{marker::PhantomData, panic::RefUnwindSafe, pin::Pin};
/// Validation rules applied when decoding a token (an alias for `jsonwebtoken::Validation`).
pub type AuthValidation = jsonwebtoken::Validation;

/// The outcome of authenticating a request, as determined by [`AuthMiddleware`] and stored in
/// the request state.
///
/// `T` is the type the token's claims are deserialized into.
///
/// `Clone` is derived; the derive adds a `T: Clone` bound on top of the declared
/// `T: Send + 'static` bounds.
#[derive(Clone, Debug, StateData)]
pub enum AuthStatus<T: Send + 'static> {
    /// The [`AuthHandler`] did not supply a secret, so the token could not be checked.
    Unknown,
    /// No token was found at the configured [`AuthSource`].
    Unauthenticated,
    /// A token was found but decoding or validating it failed.
    Invalid(jsonwebtoken::errors::Error),
    /// A token was found and validated; holds the decoded claims.
    Authenticated(T)
}
impl<T: Send + 'static> AuthStatus<T> {
    /// Converts the status into a `Result`.
    ///
    /// Returns the decoded claims when the request was authenticated. Otherwise returns an
    /// [`AuthError`] whose message says why authentication did not succeed.
    pub fn ok(self) -> Result<T, AuthError> {
        let error = match self {
            Self::Authenticated(data) => return Ok(data),
            Self::Unauthenticated => AuthError::new("Missing token"),
            Self::Invalid(cause) => AuthError::new(format!("Invalid token: {cause}")),
            Self::Unknown => AuthError::new("The authentication could not be determined")
        };
        Err(error)
    }
}
/// Where [`AuthMiddleware`] looks for the token.
///
/// The middleware also stores a clone of this value in the request state.
#[derive(Clone, Debug, StateData)]
pub enum AuthSource {
    /// Read the token from the cookie with the given name.
    Cookie(String),
    /// Use the complete value of the given header as the token.
    Header(HeaderName),
    /// Read the `Authorization` header and use its second whitespace-separated word as the
    /// token (e.g. `Bearer <token>`). The first word (the scheme) is not checked.
    AuthorizationHeader
}
/// Supplies the secret that [`AuthMiddleware`] uses to verify a token.
pub trait AuthHandler<Data> {
    /// Returns the secret for the current request, or `None` if none is available.
    ///
    /// If this returns `None`, the request is marked [`AuthStatus::Unknown`].
    ///
    /// `decode_data` decodes the token's claims on demand, so the secret can depend on them.
    /// These claims have NOT been signature-checked yet and must be treated as untrusted input.
    fn jwt_secret<F>(&self, state: &mut State, decode_data: F) -> Option<Vec<u8>>
    where
        F: FnOnce() -> Option<Data>;
}
/// An [`AuthHandler`] that returns the same fixed secret for every request.
#[derive(Clone)]
pub struct StaticAuthHandler {
    secret: Vec<u8>
}

/// Hand-written instead of derived so the secret bytes never end up in logs or panic
/// messages through `{:?}`. This includes `Debug` output of types that embed this handler.
impl std::fmt::Debug for StaticAuthHandler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StaticAuthHandler")
            .field("secret", &"<redacted>")
            .finish()
    }
}

impl StaticAuthHandler {
    /// Creates a handler that uses `secret`, taking ownership of it.
    pub fn from_vec(secret: Vec<u8>) -> Self {
        Self { secret }
    }

    /// Creates a handler that uses a copy of `secret`.
    ///
    /// Despite the name, this accepts any byte slice; arrays coerce to slices.
    pub fn from_array(secret: &[u8]) -> Self {
        Self::from_vec(secret.to_vec())
    }
}
impl<T> AuthHandler<T> for StaticAuthHandler {
    /// Ignores both the request state and the claims and returns a copy of the stored secret.
    fn jwt_secret<F>(&self, _state: &mut State, _decode_data: F) -> Option<Vec<u8>>
    where
        F: FnOnce() -> Option<T>
    {
        let secret = self.secret.clone();
        Some(secret)
    }
}
/// Gotham middleware that authenticates requests using JSON Web Tokens.
///
/// For each request it does the following:
/// - reads a token from `source`;
/// - asks `handler` for the verification secret;
/// - decodes the claims as `Data` using `validation`;
/// - stores the outcome in the request state as an [`AuthStatus`].
#[derive(Debug)]
pub struct AuthMiddleware<Data, Handler> {
    /// Where the token is read from.
    source: AuthSource,
    /// Rules passed to `jsonwebtoken::decode`.
    validation: AuthValidation,
    /// Supplies the secret used to verify the token.
    handler: Handler,
    /// Records the claims type without storing a value of it.
    _data: PhantomData<Data>
}
impl<Data, Handler> Clone for AuthMiddleware<Data, Handler>
where
Handler: Clone
{
fn clone(&self) -> Self {
Self {
source: self.source.clone(),
validation: self.validation.clone(),
handler: self.handler.clone(),
_data: self._data
}
}
}
impl<Data, Handler> AuthMiddleware<Data, Handler>
where
Data: DeserializeOwned + Send,
Handler: AuthHandler<Data> + Default
{
pub fn from_source(source: AuthSource) -> Self {
Self {
source,
validation: Default::default(),
handler: Default::default(),
_data: Default::default()
}
}
}
impl<Data, Handler> AuthMiddleware<Data, Handler>
where
    Data: DeserializeOwned + Send,
    Handler: AuthHandler<Data>
{
    /// Creates a middleware with the following configuration:
    /// - tokens are read from `source`;
    /// - tokens are validated with `validation`;
    /// - the verification secret comes from `handler`.
    pub fn new(source: AuthSource, validation: AuthValidation, handler: Handler) -> Self {
        Self {
            source,
            validation,
            handler,
            _data: Default::default()
        }
    }

    /// Extracts the raw token from the request according to `self.source`.
    ///
    /// Returns `None` in either of these cases:
    /// - the configured cookie or header is absent;
    /// - the header value cannot be read as a string.
    fn extract_token(&self, state: &State) -> Option<String> {
        match &self.source {
            // Prefer a `CookieJar` already stored in the state; otherwise fall back to parsing
            // the cookies with `CookieParser`.
            AuthSource::Cookie(name) => CookieJar::try_borrow_from(state)
                .map(|jar| jar.get(name).map(|cookie| cookie.value().to_owned()))
                .unwrap_or_else(|| {
                    CookieParser::from_state(state)
                        .get(name)
                        .map(|cookie| cookie.value().to_owned())
                }),
            AuthSource::Header(name) => HeaderMap::try_borrow_from(state)
                .and_then(|map| map.get(name))
                .and_then(|header| header.to_str().ok())
                .map(|value| value.to_owned()),
            AuthSource::AuthorizationHeader => HeaderMap::try_borrow_from(state)
                .and_then(|map| map.get(AUTHORIZATION))
                .and_then(|header| header.to_str().ok())
                // `<scheme> <token>`: take the second word; the scheme itself is not checked.
                .and_then(|value| value.split_whitespace().nth(1))
                .map(|value| value.to_owned())
        }
    }

    /// Determines the [`AuthStatus`] of the current request.
    ///
    /// The handler receives `state` mutably while it chooses the secret.
    fn auth_status(&self, state: &mut State) -> AuthStatus<Data> {
        let token = match self.extract_token(state) {
            Some(token) => token,
            None => return AuthStatus::Unauthenticated
        };

        // The handler may inspect the claims before the signature is verified. The payload is
        // UNVERIFIED at this point; `jwt_secret` implementations must treat it as untrusted.
        let secret = self.handler.jwt_secret(state, || {
            let payload = token.split('.').nth(1)?;
            let raw = BASE64_URL_SAFE_NO_PAD.decode(payload).ok()?;
            // Deserialize straight into `Data`. A trailing `?` here would make inference pick
            // `Option<Data>` as the deserialization target instead.
            serde_json::from_slice(&raw).ok()
        });
        let secret = match secret {
            Some(secret) => secret,
            None => return AuthStatus::Unknown
        };

        match jsonwebtoken::decode::<Data>(
            &token,
            &DecodingKey::from_secret(&secret),
            &self.validation
        ) {
            Ok(data) => AuthStatus::Authenticated(data.claims),
            Err(e) => AuthStatus::Invalid(e)
        }
    }
}
impl<Data, Handler> Middleware for AuthMiddleware<Data, Handler>
where
Data: DeserializeOwned + Send + 'static,
Handler: AuthHandler<Data>
{
fn call<Chain>(self, mut state: State, chain: Chain) -> Pin<Box<HandlerFuture>>
where
Chain: FnOnce(State) -> Pin<Box<HandlerFuture>>
{
state.put(self.source.clone());
let status = self.auth_status(&mut state);
state.put(status);
chain(state)
.and_then(|(state, res)| future::ok((state, res)))
.boxed()
}
}
impl<Data, Handler> NewMiddleware for AuthMiddleware<Data, Handler>
where
    Self: Clone + Middleware + Sync + RefUnwindSafe
{
    type Instance = Self;

    /// Produces a new instance by cloning `self`. This never returns an error.
    fn new_middleware(&self) -> anyhow::Result<Self> {
        Ok(self.clone())
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use gotham::{cookie::Cookie, hyper::header::COOKIE};
    use jsonwebtoken::errors::ErrorKind;
    use std::fmt::Debug;

    const JWT_SECRET: &[u8; 32] = b"Lyzsfnta0cdxyF0T9y6VGxp3jpgoMUuW";
    /// Signed with `JWT_SECRET`; its claims equal `TestData::default()`.
    const VALID_TOKEN: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJtc3JkMCIsInN1YiI6ImdvdGhhbS1yZXN0ZnVsIiwiaWF0IjoxNTc3ODM2ODAwLCJleHAiOjQxMDI0NDQ4MDB9.8h8Ax-nnykqEQ62t7CxmM3ja6NzUQ4L0MLOOzddjLKk";
    /// Signed with `JWT_SECRET`, but its `exp` claim lies in the past.
    const EXPIRED_TOKEN: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJtc3JkMCIsInN1YiI6ImdvdGhhbS1yZXN0ZnVsIiwiaWF0IjoxNTc3ODM2ODAwLCJleHAiOjE1Nzc4MzcxMDB9.eV1snaGLYrJ7qUoMk74OvBY3WUU9M0Je5HTU2xtX1v0";
    /// Header and payload only; the signature segment is missing.
    const INVALID_TOKEN: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJtc3JkMCIsInN1YiI6ImdvdGhhbS1yZXN0ZnVsIiwiaWF0IjoxNTc3ODM2ODAwLCJleHAiOjQxMDI0NDQ4MDB9";

    #[derive(Debug, Deserialize, PartialEq)]
    struct TestData {
        iss: String,
        sub: String,
        iat: u64,
        exp: u64
    }

    impl Default for TestData {
        fn default() -> Self {
            Self {
                iss: "msrd0".to_owned(),
                sub: "gotham-restful".to_owned(),
                iat: 1577836800,
                exp: 4102444800
            }
        }
    }

    /// Stores a `HeaderMap` containing `Authorization: Bearer <token>` in `state`.
    fn put_bearer_token(state: &mut State, token: &str) {
        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, format!("Bearer {token}").parse().unwrap());
        state.put(headers);
    }

    /// Handler that never supplies a secret.
    #[derive(Default)]
    struct NoneAuthHandler;

    impl<T> AuthHandler<T> for NoneAuthHandler {
        fn jwt_secret<F: FnOnce() -> Option<T>>(
            &self,
            _state: &mut State,
            _decode_data: F
        ) -> Option<Vec<u8>> {
            None
        }
    }

    #[test]
    fn test_auth_middleware_none_secret() {
        let middleware = <AuthMiddleware<TestData, NoneAuthHandler>>::from_source(
            AuthSource::AuthorizationHeader
        );
        State::with_new(|state| {
            put_bearer_token(state, VALID_TOKEN);
            let status = middleware.auth_status(state);
            match status {
                AuthStatus::Unknown => {},
                _ => panic!("Expected AuthStatus::Unknown, got {status:?}")
            };
        });
    }

    /// Handler that checks the pre-decoded claims before returning the real secret.
    #[derive(Default)]
    struct TestAssertingHandler;

    impl<T> AuthHandler<T> for TestAssertingHandler
    where
        T: Debug + Default + PartialEq
    {
        fn jwt_secret<F: FnOnce() -> Option<T>>(
            &self,
            _state: &mut State,
            decode_data: F
        ) -> Option<Vec<u8>> {
            assert_eq!(decode_data(), Some(T::default()));
            Some(JWT_SECRET.to_vec())
        }
    }

    #[test]
    fn test_auth_middleware_decode_data() {
        let middleware = <AuthMiddleware<TestData, TestAssertingHandler>>::from_source(
            AuthSource::AuthorizationHeader
        );
        State::with_new(|state| {
            put_bearer_token(state, VALID_TOKEN);
            let status = middleware.auth_status(state);
            match status {
                AuthStatus::Authenticated(data) => assert_eq!(data, TestData::default()),
                _ => panic!("Expected AuthStatus::Authenticated, got {status:?}")
            };
        });
    }

    fn new_middleware<T>(source: AuthSource) -> AuthMiddleware<T, StaticAuthHandler>
    where
        T: DeserializeOwned + Send
    {
        AuthMiddleware::new(
            source,
            Default::default(),
            StaticAuthHandler::from_array(JWT_SECRET)
        )
    }

    #[test]
    fn test_auth_middleware_no_token() {
        let middleware = new_middleware::<TestData>(AuthSource::AuthorizationHeader);
        State::with_new(|state| {
            let status = middleware.auth_status(state);
            match status {
                AuthStatus::Unauthenticated => {},
                _ => panic!("Expected AuthStatus::Unauthenticated, got {status:?}")
            };
        });
    }

    #[test]
    fn test_auth_middleware_expired_token() {
        let middleware = new_middleware::<TestData>(AuthSource::AuthorizationHeader);
        State::with_new(|state| {
            put_bearer_token(state, EXPIRED_TOKEN);
            let status = middleware.auth_status(state);
            match status {
                AuthStatus::Invalid(err) if *err.kind() == ErrorKind::ExpiredSignature => {},
                _ => panic!(
                    "Expected AuthStatus::Invalid(..) with ErrorKind::ExpiredSignature, got {status:?}"
                )
            };
        });
    }

    #[test]
    fn test_auth_middleware_invalid_token() {
        let middleware = new_middleware::<TestData>(AuthSource::AuthorizationHeader);
        State::with_new(|state| {
            put_bearer_token(state, INVALID_TOKEN);
            let status = middleware.auth_status(state);
            match status {
                AuthStatus::Invalid(err) if *err.kind() == ErrorKind::InvalidToken => {},
                _ => panic!(
                    "Expected AuthStatus::Invalid(..) with ErrorKind::InvalidToken, got {status:?}"
                )
            };
        });
    }

    #[test]
    fn test_auth_middleware_auth_header_token() {
        let middleware = new_middleware::<TestData>(AuthSource::AuthorizationHeader);
        State::with_new(|state| {
            put_bearer_token(state, VALID_TOKEN);
            let status = middleware.auth_status(state);
            match status {
                AuthStatus::Authenticated(data) => assert_eq!(data, TestData::default()),
                _ => panic!("Expected AuthStatus::Authenticated, got {status:?}")
            };
        })
    }

    #[test]
    fn test_auth_middleware_header_token() {
        let header_name = "x-znoiprwmvfexju";
        let middleware =
            new_middleware::<TestData>(AuthSource::Header(HeaderName::from_static(header_name)));
        State::with_new(|state| {
            let mut headers = HeaderMap::new();
            headers.insert(header_name, VALID_TOKEN.parse().unwrap());
            state.put(headers);
            let status = middleware.auth_status(state);
            match status {
                AuthStatus::Authenticated(data) => assert_eq!(data, TestData::default()),
                _ => panic!("Expected AuthStatus::Authenticated, got {status:?}")
            };
        })
    }

    #[test]
    fn test_auth_middleware_cookie_token() {
        let cookie_name = "znoiprwmvfexju";
        let middleware = new_middleware::<TestData>(AuthSource::Cookie(cookie_name.to_owned()));
        State::with_new(|state| {
            let mut jar = CookieJar::new();
            jar.add_original(Cookie::new(cookie_name, VALID_TOKEN));
            state.put(jar);
            let status = middleware.auth_status(state);
            match status {
                AuthStatus::Authenticated(data) => assert_eq!(data, TestData::default()),
                _ => panic!("Expected AuthStatus::Authenticated, got {status:?}")
            };
        })
    }

    #[test]
    fn test_auth_middleware_cookie_no_jar() {
        let cookie_name = "znoiprwmvfexju";
        let middleware = new_middleware::<TestData>(AuthSource::Cookie(cookie_name.to_owned()));
        State::with_new(|state| {
            let mut headers = HeaderMap::new();
            headers.insert(
                COOKIE,
                format!("{cookie_name}={VALID_TOKEN}").parse().unwrap()
            );
            state.put(headers);
            let status = middleware.auth_status(state);
            match status {
                AuthStatus::Authenticated(data) => assert_eq!(data, TestData::default()),
                _ => panic!("Expected AuthStatus::Authenticated, got {status:?}")
            };
        })
    }
}