// haste_server/server.rs

1use crate::{
2    auth_n,
3    fhir_client::ServerCTX,
4    fhir_http::{HTTPBody, HTTPRequest, http_request_to_fhir_request},
5    mcp,
6    middleware::errors::{log_operationoutcome_errors, operation_outcome_error_handle},
7    services::{AppState, ConfigError, create_services, get_pool},
8    static_assets::{create_static_server, root_asset_route},
9};
10use axum::{
11    Extension, Router, ServiceExt,
12    body::Body,
13    extract::{DefaultBodyLimit, OriginalUri, Path, State},
14    http::{Method, Uri},
15    middleware::from_fn,
16    response::{IntoResponse, Response},
17    routing::{any, get, post},
18};
19use haste_config::get_config;
20use haste_fhir_client::FHIRClient;
21use haste_fhir_operation_error::OperationOutcomeError;
22use haste_fhir_search::SearchEngine;
23use haste_fhir_terminology::FHIRTerminology;
24use haste_jwt::{ProjectId, TenantId, claims::UserTokenClaims};
25use haste_repository::{Repository, types::SupportedFHIRVersions};
26use serde::Deserialize;
27use std::{collections::HashMap, sync::Arc, time::Instant};
28use tower::{Layer, ServiceBuilder};
29use tower_http::normalize_path::NormalizePath;
30use tower_http::{
31    compression::CompressionLayer,
32    cors::{Any, CorsLayer},
33    normalize_path::NormalizePathLayer,
34};
35use tower_sessions::{
36    Expiry, SessionManagerLayer,
37    cookie::{SameSite, time::Duration},
38};
39use tower_sessions_sqlx_store::PostgresStore;
40use tracing::{Instrument, Level, info, span};
41
42#[derive(Deserialize)]
43struct FHIRHandlerPath {
44    tenant: TenantId,
45    project: ProjectId,
46    fhir_version: SupportedFHIRVersions,
47    fhir_location: Option<String>,
48}
49
50#[derive(Deserialize)]
51struct FHIRRootHandlerPath {
52    tenant: TenantId,
53    project: ProjectId,
54    fhir_version: SupportedFHIRVersions,
55}
56
57async fn fhir_handler<
58    Repo: Repository + Send + Sync + 'static,
59    Search: SearchEngine + Send + Sync + 'static,
60    Terminology: FHIRTerminology + Send + Sync + 'static,
61>(
62    claims: Arc<UserTokenClaims>,
63    method: Method,
64    uri: Uri,
65    path: FHIRHandlerPath,
66    state: Arc<AppState<Repo, Search, Terminology>>,
67    body: String,
68) -> Result<Response, OperationOutcomeError> {
69    let start = Instant::now();
70    let fhir_location = path.fhir_location.unwrap_or_default();
71    let method_str = method.to_string();
72    let span = span!(Level::ERROR, "FHIR-HTTP", method_str, fhir_location);
73    async {
74        let http_req = HTTPRequest::new(
75            method,
76            fhir_location,
77            HTTPBody::String(body),
78            uri.query()
79                .map(|q| {
80                    url::form_urlencoded::parse(q.as_bytes())
81                        .into_owned()
82                        .collect()
83                })
84                .unwrap_or_else(HashMap::new),
85        );
86
87        let fhir_request = http_request_to_fhir_request(SupportedFHIRVersions::R4, http_req)?;
88
89        let ctx = Arc::new(ServerCTX::new(
90            path.tenant,
91            path.project,
92            path.fhir_version,
93            claims.clone(),
94            state.fhir_client.clone(),
95        ));
96
97        let response = state.fhir_client.request(ctx, fhir_request).await?;
98
99        info!("Request processed in {:?}", start.elapsed());
100
101        let http_response = response.into_response();
102        Ok(http_response)
103    }
104    .instrument(span)
105    .await
106}
107
108async fn fhir_root_handler<
109    Repo: Repository + Send + Sync + 'static,
110    Search: SearchEngine + Send + Sync + 'static,
111    Terminology: FHIRTerminology + Send + Sync + 'static,
112>(
113    method: Method,
114    Extension(user): Extension<Arc<UserTokenClaims>>,
115    OriginalUri(uri): OriginalUri,
116    Path(path): Path<FHIRRootHandlerPath>,
117    State(state): State<Arc<AppState<Repo, Search, Terminology>>>,
118    body: String,
119) -> Result<Response, OperationOutcomeError> {
120    fhir_handler(
121        user,
122        method,
123        uri,
124        FHIRHandlerPath {
125            tenant: path.tenant,
126            project: path.project,
127            fhir_version: path.fhir_version,
128            fhir_location: None,
129        },
130        state,
131        body,
132    )
133    .await
134}
135
136async fn fhir_type_handler<
137    Repo: Repository + Send + Sync + 'static,
138    Search: SearchEngine + Send + Sync + 'static,
139    Terminology: FHIRTerminology + Send + Sync + 'static,
140>(
141    method: Method,
142    Extension(user): Extension<Arc<UserTokenClaims>>,
143    OriginalUri(uri): OriginalUri,
144    Path(path): Path<FHIRHandlerPath>,
145    State(state): State<Arc<AppState<Repo, Search, Terminology>>>,
146    body: String,
147) -> Result<Response, OperationOutcomeError> {
148    fhir_handler(user, method, uri, path, state, body).await
149}
150
151pub async fn server() -> Result<NormalizePath<Router>, OperationOutcomeError> {
152    let config = get_config("environment".into());
153    auth_n::certificates::create_certifications(&*config).unwrap();
154    let subscriber = tracing_subscriber::FmtSubscriber::new();
155    tracing::subscriber::set_global_default(subscriber).unwrap();
156
157    let pool = get_pool(config.as_ref()).await;
158    let session_store = PostgresStore::new(pool.clone());
159    session_store.migrate().await.map_err(ConfigError::from)?;
160
161    let max_body_size = config
162        .get(crate::ServerEnvironmentVariables::MaxRequestBodySize)
163        .ok()
164        .and_then(|s| s.parse::<usize>().ok())
165        .unwrap_or(4 * 1024 * 1024);
166    let shared_state = create_services(config).await?;
167
168    let fhir_router = Router::new()
169        .route("/{fhir_version}", any(fhir_root_handler))
170        .route("/{fhir_version}/{*fhir_location}", any(fhir_type_handler));
171
172    let protected_resources_router = Router::new()
173        .nest("/fhir", fhir_router)
174        .route("/mcp", post(mcp::route::mcp_handler))
175        .layer(
176            ServiceBuilder::new()
177                .layer(axum::middleware::from_fn_with_state(
178                    shared_state.clone(),
179                    auth_n::middleware::basic_auth::basic_auth_middleware,
180                ))
181                .layer(axum::middleware::from_fn_with_state(
182                    shared_state.clone(),
183                    auth_n::middleware::jwt::token_verifcation,
184                ))
185                .layer(axum::middleware::from_fn(
186                    auth_n::middleware::project_access::project_access,
187                )),
188        );
189
190    let project_router = Router::new().merge(protected_resources_router).nest(
191        "/oidc",
192        auth_n::oidc::routes::create_router(shared_state.clone()),
193    );
194
195    let tenant_router = Router::new()
196        .nest("/{project}/api/v1", project_router)
197        .layer(
198            // Relies on tenant for html so moving operation outcome error handling to here.
199            ServiceBuilder::new()
200                .layer(from_fn(operation_outcome_error_handle))
201                .layer(from_fn(log_operationoutcome_errors)),
202        );
203
204    let discovery_2_0_document_router = Router::new()
205        .route(
206            "/openid-configuration/w/{tenant}/{project}/{*resource}",
207            get(auth_n::oidc::routes::discovery::openid_configuration),
208        )
209        .route(
210            "/openid-configuration/w/{tenant}/{project}",
211            get(auth_n::oidc::routes::discovery::openid_configuration),
212        )
213        .route(
214            "/oauth-protected-resource/w/{tenant}/{project}/{*resource}",
215            get(auth_n::oidc::routes::discovery::oauth_protected_resource),
216        );
217
218    let app = Router::new()
219        .nest("/.well-known", discovery_2_0_document_router)
220        .nest(
221            "/global",
222            auth_n::global::routes::create_router(shared_state.clone()),
223        )
224        .nest("/w/{tenant}", tenant_router)
225        .layer(
226            ServiceBuilder::new()
227                // 4mb by default.
228                .layer(DefaultBodyLimit::max(max_body_size))
229                .layer(CompressionLayer::new())
230                .layer(
231                    SessionManagerLayer::new(session_store)
232                        .with_secure(true)
233                        .with_same_site(SameSite::None)
234                        .with_expiry(Expiry::OnInactivity(Duration::days(3))),
235                )
236                .layer(
237                    CorsLayer::new()
238                        // allow `GET` and `POST` when accessing the resource
239                        .allow_methods(Any)
240                        // allow requests from any origin
241                        .allow_origin(Any)
242                        .allow_headers(Any),
243                ),
244        )
245        .with_state(shared_state)
246        .nest(root_asset_route().to_str().unwrap(), create_static_server());
247
248    Ok(NormalizePathLayer::trim_trailing_slash().layer(app))
249}
250
251pub async fn serve(port: u16) -> Result<(), OperationOutcomeError> {
252    let server = server().await?;
253
254    let addr = format!("0.0.0.0:{}", port);
255    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
256
257    tracing::info!("Server started");
258    axum::serve(
259        listener,
260        <tower_http::normalize_path::NormalizePath<Router> as ServiceExt<
261            axum::http::Request<Body>,
262        >>::into_make_service(server),
263    )
264    .await
265    .unwrap();
266
267    Ok(())
268}