haste_server/auth_n/middleware/
basic_auth.rs1use crate::{
2 auth_n::oidc::{
3 error::{OIDCError, OIDCErrorCode},
4 routes::token::{
5 ClientCredentialsMethod, TOKEN_EXPIRATION, client_credentials_to_token_response,
6 },
7 schemas::token_body::{OAuth2TokenBody, OAuth2TokenBodyGrantType},
8 },
9 extract::{
10 basic_credentials::BasicCredentialsHeader,
11 path_tenant::{ProjectIdentifier, TenantIdentifier},
12 },
13 services::AppState,
14};
15use axum::{
16 extract::{Request, State},
17 middleware::Next,
18 response::Response,
19};
20use axum_extra::extract::Cached;
21use haste_fhir_search::SearchEngine;
22use haste_fhir_terminology::FHIRTerminology;
23use haste_jwt::{ProjectId, TenantId};
24use haste_repository::Repository;
25
26use std::{
27 sync::{Arc, LazyLock},
28 time::Duration,
29};
30
31#[derive(Hash, PartialEq, Eq)]
32struct CacheTokenKey(String);
33impl CacheTokenKey {
34 fn new(tenant: &TenantId, project: &ProjectId, client_id: &str, client_secret: &str) -> Self {
35 Self(format!(
36 "{}:{}:{}:{}",
37 tenant, project, client_id, client_secret
38 ))
39 }
40}
41
42static CACHED_BASIC_TOKENS: LazyLock<
44 moka::future::Cache<CacheTokenKey, String>,
46> = LazyLock::new(|| {
47 moka::future::Cache::builder()
48 .time_to_live(Duration::from_secs(TOKEN_EXPIRATION as u64 - 500))
50 .build()
51});
52
53pub async fn basic_auth_middleware<
54 Repo: Repository + Send + Sync + 'static,
55 Search: SearchEngine + Send + Sync + 'static,
56 Terminology: FHIRTerminology + Send + Sync + 'static,
57>(
58 Cached(TenantIdentifier { tenant }): Cached<TenantIdentifier>,
59 Cached(ProjectIdentifier { project }): Cached<ProjectIdentifier>,
60 State(state): State<Arc<AppState<Repo, Search, Terminology>>>,
61 BasicCredentialsHeader(credentials): BasicCredentialsHeader,
63 mut request: Request,
67 next: Next,
68) -> Result<Response, OIDCError> {
69 if let Some(credentials) = credentials {
70 if let Some(cached_token) = CACHED_BASIC_TOKENS
71 .get(&CacheTokenKey::new(
72 &tenant,
73 &project,
74 &credentials.0,
75 &credentials.1,
76 ))
77 .await
78 {
79 request.headers_mut().insert(
80 axum::http::header::AUTHORIZATION,
81 format!("Bearer {}", cached_token).parse().unwrap(),
82 );
83 } else {
84 let res = client_credentials_to_token_response(
85 state.as_ref(),
86 &tenant,
87 &project,
88 &None,
89 &OAuth2TokenBody {
90 client_id: credentials.0.clone(),
91 client_secret: Some(credentials.1.clone()),
92 code: None,
93 code_verifier: None,
94 grant_type: OAuth2TokenBodyGrantType::ClientCredentials,
95 redirect_uri: None,
96 refresh_token: None,
97 scope: None,
98 },
99 ClientCredentialsMethod::BasicAuth,
100 )
101 .await?;
102
103 let Some(id_token) = res.id_token else {
104 return Err(OIDCError::new(
105 OIDCErrorCode::AccessDenied,
106 Some("Failed to authorize client.".to_string()),
107 None,
108 ));
109 };
110
111 CACHED_BASIC_TOKENS
112 .insert(
113 CacheTokenKey::new(&tenant, &project, &credentials.0, &credentials.1),
114 id_token.clone(),
115 )
116 .await;
117
118 request.headers_mut().insert(
119 axum::http::header::AUTHORIZATION,
120 format!("Bearer {}", id_token).parse().unwrap(),
121 );
122 }
123 }
124
125 Ok(next.run(request).await)
126}