Skip to main content

haste_server/
server.rs

1use crate::{
2    auth_n::{self, certificates::get_certification_provider},
3    fhir_client::ServerCTX,
4    fhir_http::{HTTPBody, HTTPRequest, http_request_to_fhir_request},
5    mcp,
6    middleware::{
7        errors::{log_operationoutcome_errors, operation_outcome_error_handle},
8        security_headers::SecurityHeaderLayer,
9    },
10    openapi,
11    services::{AppState, ConfigError, create_services, get_pool},
12    static_assets::{create_static_server, root_asset_route},
13};
14use axum::{
15    Extension, Router, ServiceExt,
16    body::Body,
17    extract::{DefaultBodyLimit, OriginalUri, Path, State},
18    http::{HeaderName, HeaderValue, Method, Uri},
19    middleware::from_fn,
20    response::{IntoResponse, Response},
21    routing::{any, get, post},
22};
23use haste_config::get_config;
24use haste_fhir_client::FHIRClient;
25use haste_fhir_operation_error::OperationOutcomeError;
26use haste_fhir_search::SearchEngine;
27use haste_fhir_terminology::FHIRTerminology;
28use haste_jwt::{ProjectId, TenantId, claims::UserTokenClaims};
29use haste_repository::{Repository, types::SupportedFHIRVersions};
30use serde::Deserialize;
31use std::{collections::HashMap, sync::Arc};
32use tower::{Layer, ServiceBuilder};
33use tower_http::normalize_path::NormalizePath;
34use tower_http::{
35    compression::CompressionLayer,
36    cors::{Any, CorsLayer},
37    normalize_path::NormalizePathLayer,
38    set_header::SetResponseHeaderLayer,
39    trace::TraceLayer,
40};
41use tower_sessions::{
42    Expiry, SessionManagerLayer,
43    cookie::{SameSite, time::Duration},
44};
45use tower_sessions_sqlx_store::PostgresStore;
46use tracing::{Instrument, Level, span};
47
48const SERVER_VERSION: &str = env!("CARGO_PKG_VERSION");
49
50#[derive(Deserialize)]
51struct FHIRHandlerPath {
52    tenant: TenantId,
53    project: ProjectId,
54    fhir_version: SupportedFHIRVersions,
55    fhir_location: Option<String>,
56}
57
58#[derive(Deserialize)]
59struct FHIRRootHandlerPath {
60    tenant: TenantId,
61    project: ProjectId,
62    fhir_version: SupportedFHIRVersions,
63}
64
65async fn fhir_handler<
66    Repo: Repository + Send + Sync + 'static,
67    Search: SearchEngine + Send + Sync + 'static,
68    Terminology: FHIRTerminology + Send + Sync + 'static,
69>(
70    claims: Arc<UserTokenClaims>,
71    method: Method,
72    uri: Uri,
73    path: FHIRHandlerPath,
74    state: Arc<AppState<Repo, Search, Terminology>>,
75    body: String,
76) -> Result<Response, OperationOutcomeError> {
77    let fhir_location = path.fhir_location.unwrap_or_default();
78    let method_str = method.to_string();
79    let span = span!(Level::ERROR, "FHIR-HTTP", method_str, fhir_location);
80    async {
81        let http_req = HTTPRequest::new(
82            method,
83            fhir_location,
84            HTTPBody::String(body),
85            uri.query()
86                .map(|q| {
87                    url::form_urlencoded::parse(q.as_bytes())
88                        .into_owned()
89                        .collect()
90                })
91                .unwrap_or_else(HashMap::new),
92        );
93
94        let fhir_request = http_request_to_fhir_request(SupportedFHIRVersions::R4, http_req)?;
95
96        let ctx = Arc::new(ServerCTX::new(
97            path.tenant,
98            path.project,
99            path.fhir_version,
100            claims.clone(),
101            state.fhir_client.clone(),
102            state.rate_limit.clone(),
103        ));
104
105        let response = state.fhir_client.request(ctx, fhir_request).await?;
106
107        let http_response = response.into_response();
108        Ok(http_response)
109    }
110    .instrument(span)
111    .await
112}
113
114async fn fhir_root_handler<
115    Repo: Repository + Send + Sync + 'static,
116    Search: SearchEngine + Send + Sync + 'static,
117    Terminology: FHIRTerminology + Send + Sync + 'static,
118>(
119    method: Method,
120    Extension(user): Extension<Arc<UserTokenClaims>>,
121    OriginalUri(uri): OriginalUri,
122    Path(path): Path<FHIRRootHandlerPath>,
123    State(state): State<Arc<AppState<Repo, Search, Terminology>>>,
124    body: String,
125) -> Result<Response, OperationOutcomeError> {
126    fhir_handler(
127        user,
128        method,
129        uri,
130        FHIRHandlerPath {
131            tenant: path.tenant,
132            project: path.project,
133            fhir_version: path.fhir_version,
134            fhir_location: None,
135        },
136        state,
137        body,
138    )
139    .await
140}
141
142async fn fhir_type_handler<
143    Repo: Repository + Send + Sync + 'static,
144    Search: SearchEngine + Send + Sync + 'static,
145    Terminology: FHIRTerminology + Send + Sync + 'static,
146>(
147    method: Method,
148    Extension(user): Extension<Arc<UserTokenClaims>>,
149    OriginalUri(uri): OriginalUri,
150    Path(path): Path<FHIRHandlerPath>,
151    State(state): State<Arc<AppState<Repo, Search, Terminology>>>,
152    body: String,
153) -> Result<Response, OperationOutcomeError> {
154    fhir_handler(user, method, uri, path, state, body).await
155}
156
157pub async fn server() -> Result<NormalizePath<Router>, OperationOutcomeError> {
158    let config = get_config("environment".into());
159    // Varify instantiates.
160    get_certification_provider();
161
162    let pool = get_pool(config.as_ref()).await;
163    let session_store = PostgresStore::new(pool.clone());
164    session_store.migrate().await.map_err(ConfigError::from)?;
165
166    let max_body_size = config
167        .get(crate::ServerEnvironmentVariables::MaxRequestBodySize)
168        .ok()
169        .and_then(|s| s.parse::<usize>().ok())
170        .unwrap_or(4 * 1024 * 1024);
171    let shared_state = create_services(config).await?;
172
173    let fhir_router = Router::new()
174        .route("/{fhir_version}", any(fhir_root_handler))
175        .route("/{fhir_version}/{*fhir_location}", any(fhir_type_handler));
176
177    let protected_resources_router = Router::new()
178        .nest("/fhir", fhir_router)
179        .route("/mcp", post(mcp::route::mcp_handler))
180        .layer(
181            ServiceBuilder::new()
182                .layer(axum::middleware::from_fn_with_state(
183                    shared_state.clone(),
184                    auth_n::middleware::basic_auth::basic_auth_middleware,
185                ))
186                .layer(axum::middleware::from_fn_with_state(
187                    shared_state.clone(),
188                    auth_n::middleware::jwt::token_verifcation,
189                ))
190                .layer(axum::middleware::from_fn(
191                    auth_n::middleware::project_access::project_access,
192                )),
193        );
194
195    let project_router = Router::new().merge(protected_resources_router).nest(
196        "/oidc",
197        auth_n::oidc::routes::create_router(shared_state.clone()),
198    );
199
200    let tenant_router = Router::new()
201        .nest("/{project}/api/v1", project_router)
202        .layer(
203            // Relies on tenant for html so moving operation outcome error handling to here.
204            ServiceBuilder::new()
205                .layer(from_fn(operation_outcome_error_handle))
206                .layer(from_fn(log_operationoutcome_errors)),
207        );
208
209    let discovery_2_0_document_router = Router::new()
210        .route(
211            "/openid-configuration/w/{tenant}/{project}/{*resource}",
212            get(auth_n::oidc::routes::discovery::openid_configuration),
213        )
214        .route(
215            "/openid-configuration/w/{tenant}/{project}",
216            get(auth_n::oidc::routes::discovery::openid_configuration),
217        )
218        .route(
219            "/oauth-protected-resource/w/{tenant}/{project}/{*resource}",
220            get(auth_n::oidc::routes::discovery::oauth_protected_resource),
221        );
222
223    let app = Router::new()
224        .nest("/.well-known", discovery_2_0_document_router)
225        .nest(
226            "/auth",
227            auth_n::global::routes::create_router(shared_state.clone()),
228        )
229        .route("/openapi.json", get(openapi::openapi_document_handler))
230        .nest("/w/{tenant}", tenant_router)
231        .layer(
232            ServiceBuilder::new()
233                .layer(TraceLayer::new_for_http())
234                // 4mb by default.
235                .layer(DefaultBodyLimit::max(max_body_size))
236                .layer(CompressionLayer::new())
237                .layer(SecurityHeaderLayer::new())
238                .layer(SetResponseHeaderLayer::overriding(
239                    HeaderName::from_static("x-api-version"),
240                    HeaderValue::from_static(SERVER_VERSION),
241                ))
242                .layer(
243                    SessionManagerLayer::new(session_store)
244                        .with_secure(true)
245                        .with_same_site(SameSite::None)
246                        .with_expiry(Expiry::OnInactivity(Duration::days(3))),
247                )
248                .layer(
249                    CorsLayer::new()
250                        // allow `GET` and `POST` when accessing the resource
251                        .allow_methods(Any)
252                        // allow requests from any origin
253                        .allow_origin(Any)
254                        .allow_headers(Any),
255                ),
256        )
257        .with_state(shared_state)
258        .nest(root_asset_route().to_str().unwrap(), create_static_server());
259
260    Ok(NormalizePathLayer::trim_trailing_slash().layer(app))
261}
262
263pub async fn serve(port: u16) -> Result<(), OperationOutcomeError> {
264    let server = server().await?;
265
266    let addr = format!("0.0.0.0:{}", port);
267    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
268
269    tracing::info!("Server started");
270    axum::serve(
271        listener,
272        <tower_http::normalize_path::NormalizePath<Router> as ServiceExt<
273            axum::http::Request<Body>,
274        >>::into_make_service(server),
275    )
276    .await
277    .unwrap();
278
279    Ok(())
280}