Skip to main content

haste_server/auth_n/middleware/
jwt.rs

1use crate::{
2    auth_n::certificates, extract::bearer_token::AuthBearer, route_path::project_path,
3    services::AppState,
4};
5use axum::{
6    extract::{OriginalUri, Request, State},
7    http::{HeaderMap, StatusCode, Uri},
8    middleware::Next,
9    response::{IntoResponse as _, Response},
10};
11use derivative::Derivative;
12use haste_fhir_model::r4::generated::terminology::IssueType;
13use haste_fhir_operation_error::OperationOutcomeError;
14use haste_fhir_search::SearchEngine;
15use haste_fhir_terminology::FHIRTerminology;
16use haste_jwt::{ProjectId, TenantId, claims::UserTokenClaims};
17use haste_repository::Repository;
18use jsonwebtoken::Validation;
19use std::{
20    path::PathBuf,
21    sync::{Arc, LazyLock},
22};
23use url::Url;
24
25#[derive(Derivative)]
26#[derivative(Debug)]
27pub struct User {
28    #[allow(dead_code)]
29    #[derivative(Debug = "ignore")]
30    pub token: Option<String>,
31    pub claims: haste_jwt::claims::UserTokenClaims,
32}
33
34static VALIDATION_CONFIG: LazyLock<Validation> = LazyLock::new(|| {
35    let mut config = Validation::new(jsonwebtoken::Algorithm::RS256);
36    config.validate_aud = false;
37    config
38});
39
40fn validate_jwt(token: &str) -> Result<UserTokenClaims, StatusCode> {
41    let header = jsonwebtoken::decode_header(token).map_err(|_| StatusCode::UNAUTHORIZED)?;
42
43    let cert_provider = certificates::get_certification_provider();
44
45    let decoding_key = cert_provider
46        .decoding_key(&header.kid.unwrap_or_else(|| "".to_string()).as_str())
47        .map_err(|_| StatusCode::UNAUTHORIZED)?;
48
49    let result = jsonwebtoken::decode::<UserTokenClaims>(
50        token,
51        &decoding_key.decoding_key,
52        &*VALIDATION_CONFIG,
53    )
54    .map_err(|_| StatusCode::UNAUTHORIZED)?;
55
56    Ok(result.claims)
57}
58
59pub fn derive_well_known_openid_configuration_url(
60    api_url: &str,
61    tenant: &TenantId,
62    project: &ProjectId,
63) -> Result<Url, OperationOutcomeError> {
64    let path = PathBuf::from("/.well-known/openid-configuration");
65
66    if let Ok(api_url) = Url::parse(&api_url) {
67        api_url
68            .join(
69                path.join(project_path(tenant, project).strip_prefix("/").unwrap())
70                    .to_str()
71                    .unwrap_or_default(),
72            )
73            .map_err(|e| {
74                tracing::error!("Failed to derive well-known URL: {:?}", e);
75                OperationOutcomeError::error(
76                    IssueType::Invalid(None),
77                    "Invalid API URL configured".to_string(),
78                )
79            })
80    } else {
81        Err(OperationOutcomeError::error(
82            IssueType::Invalid(None),
83            "Invalid API URL configured".to_string(),
84        ))
85    }
86}
87
88pub fn derive_protected_resource_metadata_url(
89    resource_uri: &Uri,
90    api_url: &str,
91) -> Result<Url, OperationOutcomeError> {
92    let path = PathBuf::from("/.well-known/oauth-protected-resource");
93    if let Ok(api_url) = Url::parse(&api_url) {
94        let tenant_url = api_url
95            .join(
96                path.join(resource_uri.path().strip_prefix("/").unwrap_or_default())
97                    .to_str()
98                    .unwrap_or_default(),
99            )
100            .map_err(|e| {
101                tracing::error!("Failed to derive well-known URL: {:?}", e);
102                OperationOutcomeError::error(
103                    IssueType::Invalid(None),
104                    "Invalid API URL configured".to_string(),
105                )
106            })?;
107
108        Ok(tenant_url)
109    } else {
110        Err(OperationOutcomeError::error(
111            IssueType::Invalid(None),
112            "Invalid API URL configured".to_string(),
113        ))
114    }
115}
116
117fn invalid_jwt_response(uri: &Uri, api_url: &str, status_code: StatusCode) -> Response {
118    tracing::warn!(
119        "Invalid JWT token provided in request sending '{}'",
120        status_code
121    );
122
123    let Ok(protected_resource_metadata_url) = derive_protected_resource_metadata_url(uri, api_url)
124    else {
125        return (status_code).into_response();
126    };
127
128    let mut headers = HeaderMap::new();
129    headers.insert(
130        axum::http::header::WWW_AUTHENTICATE,
131        format!(
132            r#"Bearer resource_metadata="{}""#,
133            protected_resource_metadata_url.to_string()
134        )
135        .parse()
136        .unwrap(),
137    );
138    (status_code, headers).into_response()
139}
140
141pub async fn token_verifcation<
142    Repo: Repository + Send + Sync + 'static,
143    Search: SearchEngine + Send + Sync + 'static,
144    Terminology: FHIRTerminology + Send + Sync + 'static,
145>(
146    State(state): State<Arc<AppState<Repo, Search, Terminology>>>,
147    // run the `HeaderMap` extractor
148    AuthBearer(token): AuthBearer,
149    // you can also add more extractors here but the last
150    // extractor must implement `FromRequest` which
151    // `Request` does
152    OriginalUri(uri): OriginalUri,
153    mut request: Request,
154    next: Next,
155) -> Result<Response, Response> {
156    let Some(token) = token else {
157        return Err(invalid_jwt_response(
158            &uri,
159            &state
160                .config
161                .get(crate::ServerEnvironmentVariables::APIURI)
162                .unwrap_or_default(),
163            StatusCode::UNAUTHORIZED,
164        ));
165    };
166
167    match validate_jwt(&token) {
168        Ok(claims) => {
169            request.extensions_mut().insert(Arc::new(User {
170                token: Some(token),
171                claims,
172            }));
173            Ok(next.run(request).await)
174        }
175        Err(status_code) => match status_code {
176            StatusCode::UNAUTHORIZED => Err(invalid_jwt_response(
177                &uri,
178                &state
179                    .config
180                    .get(crate::ServerEnvironmentVariables::APIURI)
181                    .unwrap_or_default(),
182                status_code,
183            )),
184            _ => Err((status_code).into_response()),
185        },
186    }
187}