Skip to content

Commit

Permalink
Switch from using Extension to State
Browse files Browse the repository at this point in the history
  • Loading branch information
traxys committed Jan 21, 2025
1 parent 38447d2 commit 060ff8c
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 158 deletions.
69 changes: 29 additions & 40 deletions api/src/account.rs
Original file line number Diff line number Diff line change
@@ -1,51 +1,44 @@
use std::{sync::Arc, time::Duration};

use axum_extra::{
headers::{authorization::Bearer, Authorization},
TypedHeader,
};
use axum::{
extract::{self, FromRequestParts},
http::request::Parts,
routing::{get, post},
Extension, Json, Router,
Json, Router,
};
use axum_extra::{
headers::{authorization::Bearer, Authorization},
TypedHeader,
};
use jwt_simple::prelude::{Claims, MACLike, NoCustomClaims};
use kabalist_types::{
GetAccountNameResponse, LoginRequest, LoginResponse, RecoverPasswordRequest,
RecoverPasswordResponse, RecoveryInfoResponse, RegisterRequest, RegisterResponse,
};
use sqlx::PgPool;
use tokio_stream::StreamExt;
use uuid::Uuid;

use crate::{config::Config, ok_response::*, ErrResponse, Error, OkResponse, Rsp};
use crate::{ok_response::*, ErrResponse, Error, KabalistState, OkResponse, Rsp, State};

#[derive(Debug)]
pub(crate) struct User {
pub id: Uuid,
}

impl<S> FromRequestParts<S> for User
where
S: Send + Sync,
{
impl FromRequestParts<Arc<KabalistState>> for User {
type Rejection = Error;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let Extension(config) = Extension::<Arc<Config>>::from_request_parts(parts, state)
.await
.map_err(|e| {
tracing::error!("Could not fetch config extension: {:?}", e);
Error::Internal
})?;

async fn from_request_parts(
parts: &mut Parts,
state: &Arc<KabalistState>,
) -> Result<Self, Self::Rejection> {
let TypedHeader(Authorization(bearer)) =
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
.await
.map_err(|_| Error::MissingAuthorization)?;

let claims = config
let claims = state
.config
.jwt_secret
.0
.verify_token::<NoCustomClaims>(bearer.token(), None)?;
Expand All @@ -57,7 +50,7 @@ where
}
}

pub(crate) fn router() -> Router {
pub(crate) fn router() -> Router<Arc<KabalistState>> {
Router::new()
.route("/login", post(login))
.route("/register/{id}", post(register))
Expand All @@ -76,29 +69,25 @@ pub(crate) fn router() -> Router {
),
request_body = LoginRequest,
)]
#[tracing::instrument(skip(config, db))]
async fn login(
Extension(config): Extension<Arc<Config>>,
Extension(db): Extension<PgPool>,
Json(request): Json<LoginRequest>,
) -> Rsp<LoginResponse> {
#[tracing::instrument(skip(state))]
async fn login(state: State, Json(request): Json<LoginRequest>) -> Rsp<LoginResponse> {
let mut rsp = sqlx::query!(
"SELECT id FROM accounts WHERE name = $1::text::citext AND password = crypt($2, password)",
request.username,
request.password.0,
)
.fetch(&db);
.fetch(&state.0.pool);

let id = match rsp.next().await {
None => return Err(Error::UnknownAccount),
Some(Err(e)) => return Err(e.into()),
Some(Ok(id)) => id.id,
};

let mut claims = Claims::create(Duration::from_millis(config.exp as _).into());
let mut claims = Claims::create(Duration::from_millis(state.0.config.exp as _).into());
claims.subject = Some(id.to_string());

let token = config.jwt_secret.0.authenticate(claims)?;
let token = state.0.config.jwt_secret.0.authenticate(claims)?;

OkResponse::ok(LoginResponse { token })
}
Expand All @@ -116,13 +105,13 @@ async fn login(
),
request_body = RegisterRequest,
)]
#[tracing::instrument(skip(db))]
#[tracing::instrument(skip(state))]
async fn register(
Extension(db): Extension<PgPool>,
state: State,
extract::Path(id): extract::Path<Uuid>,
Json(req): Json<RegisterRequest>,
) -> Rsp<RegisterResponse> {
let mut tx = db.begin().await?;
let mut tx = state.0.pool.begin().await?;

let mut is_registered =
sqlx::query!("SELECT id FROM registrations WHERE id = $1", id).fetch(&mut *tx);
Expand Down Expand Up @@ -163,9 +152,9 @@ async fn register(
("id" = Uuid, Path, description = "Recovery ID"),
),
)]
#[tracing::instrument(skip(db))]
#[tracing::instrument(skip(state))]
async fn recovery_info(
Extension(db): Extension<PgPool>,
state: State,
extract::Path(id): extract::Path<Uuid>,
) -> Rsp<RecoveryInfoResponse> {
let username = sqlx::query!(
Expand All @@ -175,7 +164,7 @@ async fn recovery_info(
AND password_reset.account = accounts.id"#,
id
)
.fetch_one(&db)
.fetch_one(&state.0.pool)
.await?
.name;

Expand All @@ -199,11 +188,11 @@ async fn recovery_info(
request_body = RecoverPasswordRequest
)]
async fn recover_password(
Extension(db): Extension<PgPool>,
state: State,
extract::Path(id): extract::Path<Uuid>,
Json(request): Json<RecoverPasswordRequest>,
) -> Rsp<RecoverPasswordResponse> {
let mut tx = db.begin().await?;
let mut tx = state.0.pool.begin().await?;

let account = sqlx::query!(
"SELECT password_reset.account FROM password_reset WHERE id = $1",
Expand Down Expand Up @@ -245,12 +234,12 @@ async fn recover_password(
)
)]
async fn get_account_name(
Extension(db): Extension<PgPool>,
state: State,
_user: User,
extract::Path(id): extract::Path<Uuid>,
) -> Rsp<GetAccountNameResponse> {
let name = sqlx::query!("SELECT name::text FROM accounts WHERE id = $1", id)
.fetch_one(&db)
.fetch_one(&state.0.pool)
.await?
.name;

Expand Down
Loading

0 comments on commit 060ff8c

Please sign in to comment.