Skip to main content

haste_server/
server.rs

1use crate::{
2    auth_n::{self, certificates::get_certification_provider},
3    fhir_client::ServerCTX,
4    fhir_http::{HTTPBody, HTTPRequest, http_request_to_fhir_request},
5    mcp,
6    middleware::{
7        errors::{log_operationoutcome_errors, operation_outcome_error_handle},
8        security_headers::SecurityHeaderLayer,
9    },
10    openapi,
11    services::{AppState, ConfigError, create_services, get_pool},
12    static_assets::{create_static_server, root_asset_route},
13};
14use axum::{
15    Extension, Router, ServiceExt,
16    body::Body,
17    extract::{DefaultBodyLimit, OriginalUri, Path, State},
18    http::Request,
19    http::{HeaderName, HeaderValue, Method, Uri},
20    middleware::from_fn,
21    response::{IntoResponse, Response},
22    routing::{any, get, post},
23};
24use haste_config::get_config;
25use haste_fhir_client::FHIRClient;
26use haste_fhir_operation_error::OperationOutcomeError;
27use haste_fhir_search::SearchEngine;
28use haste_fhir_terminology::FHIRTerminology;
29use haste_jwt::{ProjectId, TenantId, claims::UserTokenClaims};
30use haste_repository::{Repository, types::SupportedFHIRVersions};
31use sentry::integrations::tower::NewSentryLayer;
32use serde::Deserialize;
33use std::{collections::HashMap, sync::Arc};
34use tower::{Layer, ServiceBuilder};
35use tower_http::normalize_path::NormalizePath;
36use tower_http::{
37    compression::CompressionLayer,
38    cors::{Any, CorsLayer},
39    normalize_path::NormalizePathLayer,
40    set_header::SetResponseHeaderLayer,
41    trace::TraceLayer,
42};
43use tower_sessions::{
44    Expiry, SessionManagerLayer,
45    cookie::{SameSite, time::Duration},
46};
47use tower_sessions_sqlx_store::PostgresStore;
48use tracing::{Instrument, Level, span};
49
50const SERVER_VERSION: &str = env!("CARGO_PKG_VERSION");
51
52#[derive(Deserialize)]
53struct FHIRHandlerPath {
54    tenant: TenantId,
55    project: ProjectId,
56    fhir_version: SupportedFHIRVersions,
57    fhir_location: Option<String>,
58}
59
60#[derive(Deserialize)]
61struct FHIRRootHandlerPath {
62    tenant: TenantId,
63    project: ProjectId,
64    fhir_version: SupportedFHIRVersions,
65}
66
67async fn fhir_handler<
68    Repo: Repository + Send + Sync + 'static,
69    Search: SearchEngine + Send + Sync + 'static,
70    Terminology: FHIRTerminology + Send + Sync + 'static,
71>(
72    claims: Arc<UserTokenClaims>,
73    method: Method,
74    uri: Uri,
75    path: FHIRHandlerPath,
76    state: Arc<AppState<Repo, Search, Terminology>>,
77    body: String,
78) -> Result<Response, OperationOutcomeError> {
79    let fhir_location = path.fhir_location.unwrap_or_default();
80    let method_str = method.to_string();
81    let span = span!(Level::ERROR, "FHIR-HTTP", method_str, fhir_location);
82    async {
83        let http_req = HTTPRequest::new(
84            method,
85            fhir_location,
86            HTTPBody::String(body),
87            uri.query()
88                .map(|q| {
89                    url::form_urlencoded::parse(q.as_bytes())
90                        .into_owned()
91                        .collect()
92                })
93                .unwrap_or_else(HashMap::new),
94        );
95
96        let fhir_request = http_request_to_fhir_request(SupportedFHIRVersions::R4, http_req)?;
97
98        let ctx = Arc::new(ServerCTX::new(
99            path.tenant,
100            path.project,
101            path.fhir_version,
102            claims.clone(),
103            state.fhir_client.clone(),
104            state.rate_limit.clone(),
105        ));
106
107        let response = state.fhir_client.request(ctx, fhir_request).await?;
108
109        let http_response = response.into_response();
110        Ok(http_response)
111    }
112    .instrument(span)
113    .await
114}
115
116async fn fhir_root_handler<
117    Repo: Repository + Send + Sync + 'static,
118    Search: SearchEngine + Send + Sync + 'static,
119    Terminology: FHIRTerminology + Send + Sync + 'static,
120>(
121    method: Method,
122    Extension(user): Extension<Arc<UserTokenClaims>>,
123    OriginalUri(uri): OriginalUri,
124    Path(path): Path<FHIRRootHandlerPath>,
125    State(state): State<Arc<AppState<Repo, Search, Terminology>>>,
126    body: String,
127) -> Result<Response, OperationOutcomeError> {
128    fhir_handler(
129        user,
130        method,
131        uri,
132        FHIRHandlerPath {
133            tenant: path.tenant,
134            project: path.project,
135            fhir_version: path.fhir_version,
136            fhir_location: None,
137        },
138        state,
139        body,
140    )
141    .await
142}
143
144async fn fhir_type_handler<
145    Repo: Repository + Send + Sync + 'static,
146    Search: SearchEngine + Send + Sync + 'static,
147    Terminology: FHIRTerminology + Send + Sync + 'static,
148>(
149    method: Method,
150    Extension(user): Extension<Arc<UserTokenClaims>>,
151    OriginalUri(uri): OriginalUri,
152    Path(path): Path<FHIRHandlerPath>,
153    State(state): State<Arc<AppState<Repo, Search, Terminology>>>,
154    body: String,
155) -> Result<Response, OperationOutcomeError> {
156    fhir_handler(user, method, uri, path, state, body).await
157}
158
159pub async fn server() -> Result<NormalizePath<Router>, OperationOutcomeError> {
160    let config = get_config("environment".into());
161    // Varify instantiates.
162    get_certification_provider();
163
164    let pool = get_pool(config.as_ref()).await;
165    let session_store = PostgresStore::new(pool.clone());
166    session_store.migrate().await.map_err(ConfigError::from)?;
167
168    let max_body_size = config
169        .get(crate::ServerEnvironmentVariables::MaxRequestBodySize)
170        .ok()
171        .and_then(|s| s.parse::<usize>().ok())
172        .unwrap_or(4 * 1024 * 1024);
173    let shared_state = create_services(config).await?;
174
175    let fhir_router = Router::new()
176        .route("/{fhir_version}", any(fhir_root_handler))
177        .route("/{fhir_version}/{*fhir_location}", any(fhir_type_handler));
178
179    let protected_resources_router = Router::new()
180        .nest("/fhir", fhir_router)
181        .route("/mcp", post(mcp::route::mcp_handler))
182        .layer(
183            ServiceBuilder::new()
184                .layer(axum::middleware::from_fn_with_state(
185                    shared_state.clone(),
186                    auth_n::middleware::basic_auth::basic_auth_middleware,
187                ))
188                .layer(axum::middleware::from_fn_with_state(
189                    shared_state.clone(),
190                    auth_n::middleware::jwt::token_verifcation,
191                ))
192                .layer(axum::middleware::from_fn(
193                    auth_n::middleware::project_access::project_access,
194                )),
195        );
196
197    let project_router = Router::new().merge(protected_resources_router).nest(
198        "/oidc",
199        auth_n::oidc::routes::create_router(shared_state.clone()),
200    );
201
202    let tenant_router = Router::new()
203        .nest("/{project}/api/v1", project_router)
204        .layer(
205            // Relies on tenant for html so moving operation outcome error handling to here.
206            ServiceBuilder::new()
207                .layer(from_fn(operation_outcome_error_handle))
208                .layer(from_fn(log_operationoutcome_errors)),
209        );
210
211    let discovery_2_0_document_router = Router::new()
212        .route(
213            "/openid-configuration/w/{tenant}/{project}/{*resource}",
214            get(auth_n::oidc::routes::discovery::openid_configuration),
215        )
216        .route(
217            "/openid-configuration/w/{tenant}/{project}",
218            get(auth_n::oidc::routes::discovery::openid_configuration),
219        )
220        .route(
221            "/oauth-protected-resource/w/{tenant}/{project}/{*resource}",
222            get(auth_n::oidc::routes::discovery::oauth_protected_resource),
223        );
224
225    let app = Router::new()
226        .nest("/.well-known", discovery_2_0_document_router)
227        .nest(
228            "/auth",
229            auth_n::global::routes::create_router(shared_state.clone()),
230        )
231        .route("/openapi.json", get(openapi::openapi_document_handler))
232        .nest("/w/{tenant}", tenant_router)
233        .layer(
234            ServiceBuilder::new()
235                .layer(NewSentryLayer::<Request<Body>>::new_from_top())
236                .layer(TraceLayer::new_for_http())
237                // 4mb by default.
238                .layer(DefaultBodyLimit::max(max_body_size))
239                .layer(CompressionLayer::new())
240                .layer(SecurityHeaderLayer::new())
241                .layer(SetResponseHeaderLayer::overriding(
242                    HeaderName::from_static("x-api-version"),
243                    HeaderValue::from_static(SERVER_VERSION),
244                ))
245                .layer(
246                    SessionManagerLayer::new(session_store)
247                        .with_secure(true)
248                        .with_same_site(SameSite::None)
249                        .with_expiry(Expiry::OnInactivity(Duration::days(3))),
250                )
251                .layer(
252                    CorsLayer::new()
253                        // allow `GET` and `POST` when accessing the resource
254                        .allow_methods(Any)
255                        // allow requests from any origin
256                        .allow_origin(Any)
257                        .allow_headers(Any),
258                ),
259        )
260        .with_state(shared_state)
261        .nest(root_asset_route().to_str().unwrap(), create_static_server());
262
263    Ok(NormalizePathLayer::trim_trailing_slash().layer(app))
264}
265
266pub async fn serve(port: u16) -> Result<(), OperationOutcomeError> {
267    let server = server().await?;
268
269    let addr = format!("0.0.0.0:{}", port);
270    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
271
272    tracing::info!("Server started");
273    axum::serve(
274        listener,
275        <tower_http::normalize_path::NormalizePath<Router> as ServiceExt<
276            axum::http::Request<Body>,
277        >>::into_make_service(server),
278    )
279    .await
280    .unwrap();
281
282    Ok(())
283}