Skip to main content

haste_server/auth_n/middleware/jwt.rs

1use crate::{
2    auth_n::certificates, extract::bearer_token::AuthBearer, route_path::project_path,
3    services::AppState,
4};
5use axum::{
6    extract::{OriginalUri, Request, State},
7    http::{HeaderMap, StatusCode, Uri},
8    middleware::Next,
9    response::{IntoResponse as _, Response},
10};
11use haste_fhir_model::r4::generated::terminology::IssueType;
12use haste_fhir_operation_error::OperationOutcomeError;
13use haste_fhir_search::SearchEngine;
14use haste_fhir_terminology::FHIRTerminology;
15use haste_jwt::{ProjectId, TenantId, claims::UserTokenClaims};
16use haste_repository::Repository;
17use jsonwebtoken::Validation;
18use std::{
19    path::PathBuf,
20    sync::{Arc, LazyLock},
21};
22use url::Url;
23
24pub struct User {
25    #[allow(dead_code)]
26    pub token: Option<String>,
27    pub claims: haste_jwt::claims::UserTokenClaims,
28}
29
30static VALIDATION_CONFIG: LazyLock<Validation> = LazyLock::new(|| {
31    let mut config = Validation::new(jsonwebtoken::Algorithm::RS256);
32    config.validate_aud = false;
33    config
34});
35
36fn validate_jwt(token: &str) -> Result<UserTokenClaims, StatusCode> {
37    let header = jsonwebtoken::decode_header(token).map_err(|_| StatusCode::UNAUTHORIZED)?;
38
39    let cert_provider = certificates::get_certification_provider();
40
41    let decoding_key = cert_provider
42        .decoding_key(&header.kid.unwrap_or_else(|| "".to_string()).as_str())
43        .map_err(|_| StatusCode::UNAUTHORIZED)?;
44
45    let result = jsonwebtoken::decode::<UserTokenClaims>(
46        token,
47        &decoding_key.decoding_key,
48        &*VALIDATION_CONFIG,
49    )
50    .map_err(|_| StatusCode::UNAUTHORIZED)?;
51
52    Ok(result.claims)
53}
54
55pub fn derive_well_known_openid_configuration_url(
56    api_url: &str,
57    tenant: &TenantId,
58    project: &ProjectId,
59) -> Result<Url, OperationOutcomeError> {
60    let path = PathBuf::from("/.well-known/openid-configuration");
61
62    if let Ok(api_url) = Url::parse(&api_url) {
63        api_url
64            .join(
65                path.join(project_path(tenant, project).strip_prefix("/").unwrap())
66                    .to_str()
67                    .unwrap_or_default(),
68            )
69            .map_err(|e| {
70                tracing::error!("Failed to derive well-known URL: {:?}", e);
71                OperationOutcomeError::error(
72                    IssueType::Invalid(None),
73                    "Invalid API URL configured".to_string(),
74                )
75            })
76    } else {
77        Err(OperationOutcomeError::error(
78            IssueType::Invalid(None),
79            "Invalid API URL configured".to_string(),
80        ))
81    }
82}
83
84pub fn derive_protected_resource_metadata_url(
85    resource_uri: &Uri,
86    api_url: &str,
87) -> Result<Url, OperationOutcomeError> {
88    let path = PathBuf::from("/.well-known/oauth-protected-resource");
89    if let Ok(api_url) = Url::parse(&api_url) {
90        let tenant_url = api_url
91            .join(
92                path.join(resource_uri.path().strip_prefix("/").unwrap_or_default())
93                    .to_str()
94                    .unwrap_or_default(),
95            )
96            .map_err(|e| {
97                tracing::error!("Failed to derive well-known URL: {:?}", e);
98                OperationOutcomeError::error(
99                    IssueType::Invalid(None),
100                    "Invalid API URL configured".to_string(),
101                )
102            })?;
103
104        Ok(tenant_url)
105    } else {
106        Err(OperationOutcomeError::error(
107            IssueType::Invalid(None),
108            "Invalid API URL configured".to_string(),
109        ))
110    }
111}
112
113fn invalid_jwt_response(uri: &Uri, api_url: &str, status_code: StatusCode) -> Response {
114    tracing::warn!(
115        "Invalid JWT token provided in request sending '{}'",
116        status_code
117    );
118
119    let Ok(protected_resource_metadata_url) = derive_protected_resource_metadata_url(uri, api_url)
120    else {
121        return (status_code).into_response();
122    };
123
124    let mut headers = HeaderMap::new();
125    headers.insert(
126        axum::http::header::WWW_AUTHENTICATE,
127        format!(
128            r#"Bearer resource_metadata="{}""#,
129            protected_resource_metadata_url.to_string()
130        )
131        .parse()
132        .unwrap(),
133    );
134    (status_code, headers).into_response()
135}
136
137pub async fn token_verifcation<
138    Repo: Repository + Send + Sync + 'static,
139    Search: SearchEngine + Send + Sync + 'static,
140    Terminology: FHIRTerminology + Send + Sync + 'static,
141>(
142    State(state): State<Arc<AppState<Repo, Search, Terminology>>>,
143    // run the `HeaderMap` extractor
144    AuthBearer(token): AuthBearer,
145    // you can also add more extractors here but the last
146    // extractor must implement `FromRequest` which
147    // `Request` does
148    OriginalUri(uri): OriginalUri,
149    mut request: Request,
150    next: Next,
151) -> Result<Response, Response> {
152    let Some(token) = token else {
153        return Err(invalid_jwt_response(
154            &uri,
155            &state
156                .config
157                .get(crate::ServerEnvironmentVariables::APIURI)
158                .unwrap_or_default(),
159            StatusCode::UNAUTHORIZED,
160        ));
161    };
162
163    match validate_jwt(&token) {
164        Ok(claims) => {
165            request.extensions_mut().insert(Arc::new(User {
166                token: Some(token),
167                claims,
168            }));
169            Ok(next.run(request).await)
170        }
171        Err(status_code) => match status_code {
172            StatusCode::UNAUTHORIZED => Err(invalid_jwt_response(
173                &uri,
174                &state
175                    .config
176                    .get(crate::ServerEnvironmentVariables::APIURI)
177                    .unwrap_or_default(),
178                status_code,
179            )),
180            _ => Err((status_code).into_response()),
181        },
182    }
183}