Skip to main content

haste_server/auth_n/middleware/
basic_auth.rs

1use crate::{
2    auth_n::oidc::{
3        error::{OIDCError, OIDCErrorCode},
4        routes::token::{
5            ClientCredentialsMethod, TOKEN_EXPIRATION, client_credentials_to_token_response,
6        },
7        schemas::token_body::{OAuth2TokenBody, OAuth2TokenBodyGrantType},
8    },
9    extract::{
10        basic_credentials::BasicCredentialsHeader,
11        path_tenant::{ProjectIdentifier, TenantIdentifier},
12    },
13    services::AppState,
14};
15use axum::{
16    extract::{Request, State},
17    middleware::Next,
18    response::Response,
19};
20use axum_extra::extract::Cached;
21use haste_fhir_search::SearchEngine;
22use haste_fhir_terminology::FHIRTerminology;
23use haste_jwt::{ProjectId, TenantId};
24use haste_repository::Repository;
25
26use std::{
27    sync::{Arc, LazyLock},
28    time::Duration,
29};
30
31#[derive(Hash, PartialEq, Eq)]
32struct CacheTokenKey(String);
33impl CacheTokenKey {
34    fn new(tenant: &TenantId, project: &ProjectId, client_id: &str, client_secret: &str) -> Self {
35        Self(format!(
36            "{}:{}:{}:{}",
37            tenant, project, client_id, client_secret
38        ))
39    }
40}
41
42// Token creation is expensive so caching for performance.
43static CACHED_BASIC_TOKENS: LazyLock<
44    // Tenant, Project, ClientId, ClientSecret
45    moka::future::Cache<CacheTokenKey, String>,
46> = LazyLock::new(|| {
47    moka::future::Cache::builder()
48        // Set as slightly less than the token expiration to ensure tokens are refreshed before they expire.
49        .time_to_live(Duration::from_secs(TOKEN_EXPIRATION as u64 - 500))
50        .build()
51});
52
53pub async fn basic_auth_middleware<
54    Repo: Repository + Send + Sync + 'static,
55    Search: SearchEngine + Send + Sync + 'static,
56    Terminology: FHIRTerminology + Send + Sync + 'static,
57>(
58    Cached(TenantIdentifier { tenant }): Cached<TenantIdentifier>,
59    Cached(ProjectIdentifier { project }): Cached<ProjectIdentifier>,
60    State(state): State<Arc<AppState<Repo, Search, Terminology>>>,
61    // run the `HeaderMap` extractor
62    BasicCredentialsHeader(credentials): BasicCredentialsHeader,
63    // you can also add more extractors here but the last
64    // extractor must implement `FromRequest` which
65    // `Request` does
66    mut request: Request,
67    next: Next,
68) -> Result<Response, OIDCError> {
69    if let Some(credentials) = credentials {
70        if let Some(cached_token) = CACHED_BASIC_TOKENS
71            .get(&CacheTokenKey::new(
72                &tenant,
73                &project,
74                &credentials.0,
75                &credentials.1,
76            ))
77            .await
78        {
79            request.headers_mut().insert(
80                axum::http::header::AUTHORIZATION,
81                format!("Bearer {}", cached_token).parse().unwrap(),
82            );
83        } else {
84            let res = client_credentials_to_token_response(
85                state.as_ref(),
86                &tenant,
87                &project,
88                &None,
89                &OAuth2TokenBody {
90                    client_id: credentials.0.clone(),
91                    client_secret: Some(credentials.1.clone()),
92                    code: None,
93                    code_verifier: None,
94                    grant_type: OAuth2TokenBodyGrantType::ClientCredentials,
95                    redirect_uri: None,
96                    refresh_token: None,
97                    scope: None,
98                },
99                ClientCredentialsMethod::BasicAuth,
100            )
101            .await?;
102
103            let Some(id_token) = res.id_token else {
104                return Err(OIDCError::new(
105                    OIDCErrorCode::AccessDenied,
106                    Some("Failed to authorize client.".to_string()),
107                    None,
108                ));
109            };
110
111            CACHED_BASIC_TOKENS
112                .insert(
113                    CacheTokenKey::new(&tenant, &project, &credentials.0, &credentials.1),
114                    id_token.clone(),
115                )
116                .await;
117
118            request.headers_mut().insert(
119                axum::http::header::AUTHORIZATION,
120                format!("Bearer {}", id_token).parse().unwrap(),
121            );
122        }
123    }
124
125    Ok(next.run(request).await)
126}