Skip to main content

haste_server/
server.rs

1use crate::{
2    auth_n::{self, certificates::get_certification_provider, middleware::jwt::User},
3    fhir_client::ServerCTX,
4    fhir_http::{HTTPBody, HTTPRequest, http_request_to_fhir_request},
5    mcp,
6    middleware::{
7        errors::{log_operationoutcome_errors, operation_outcome_error_handle},
8        security_headers::SecurityHeaderLayer,
9    },
10    openapi,
11    services::{AppState, ConfigError, create_services, get_pool},
12    static_assets::{create_static_server, root_asset_route},
13};
14use axum::{
15    Extension, Router, ServiceExt,
16    body::Body,
17    extract::{DefaultBodyLimit, OriginalUri, Path, State},
18    http::Request,
19    http::{HeaderName, HeaderValue, Method, Uri},
20    middleware::from_fn,
21    response::{IntoResponse, Response},
22    routing::{any, get, post},
23};
24use axum_client_ip::ClientIpSource;
25use haste_config::get_config;
26use haste_fhir_client::FHIRClient;
27use haste_fhir_model::r4::generated::terminology::IssueType;
28use haste_fhir_operation_error::OperationOutcomeError;
29use haste_fhir_search::SearchEngine;
30use haste_fhir_terminology::FHIRTerminology;
31use haste_jwt::{ProjectId, TenantId};
32use haste_repository::{Repository, types::SupportedFHIRVersions};
33use sentry::integrations::tower::NewSentryLayer;
34use serde::Deserialize;
35use std::{collections::HashMap, sync::Arc};
36use std::{net::SocketAddr, str::FromStr};
37use tower::{Layer, ServiceBuilder};
38use tower_http::normalize_path::NormalizePath;
39use tower_http::{
40    compression::CompressionLayer,
41    cors::{Any, CorsLayer},
42    normalize_path::NormalizePathLayer,
43    set_header::SetResponseHeaderLayer,
44    trace::TraceLayer,
45};
46use tower_sessions::{
47    Expiry, SessionManagerLayer,
48    cookie::{SameSite, time::Duration},
49};
50use tower_sessions_sqlx_store::PostgresStore;
51use tracing::{Instrument, Level, span};
52
53const SERVER_VERSION: &str = env!("CARGO_PKG_VERSION");
54
55#[derive(Deserialize)]
56struct FHIRHandlerPath {
57    tenant: TenantId,
58    project: ProjectId,
59    fhir_version: SupportedFHIRVersions,
60    fhir_location: Option<String>,
61}
62
63#[derive(Deserialize)]
64struct FHIRRootHandlerPath {
65    tenant: TenantId,
66    project: ProjectId,
67    fhir_version: SupportedFHIRVersions,
68}
69
70async fn fhir_handler<
71    Repo: Repository + Send + Sync + 'static,
72    Search: SearchEngine + Send + Sync + 'static,
73    Terminology: FHIRTerminology + Send + Sync + 'static,
74>(
75    user: Arc<User>,
76    method: Method,
77    uri: Uri,
78    path: FHIRHandlerPath,
79    state: Arc<AppState<Repo, Search, Terminology>>,
80    body: String,
81) -> Result<Response, OperationOutcomeError> {
82    let fhir_location = path.fhir_location.unwrap_or_default();
83    let method_str = method.to_string();
84    let span = span!(Level::ERROR, "FHIR-HTTP", method_str, fhir_location);
85    async {
86        let http_req = HTTPRequest::new(
87            method,
88            fhir_location,
89            HTTPBody::String(body),
90            uri.query()
91                .map(|q| {
92                    url::form_urlencoded::parse(q.as_bytes())
93                        .into_owned()
94                        .collect()
95                })
96                .unwrap_or_else(HashMap::new),
97        );
98
99        let fhir_request = http_request_to_fhir_request(SupportedFHIRVersions::R4, http_req)?;
100
101        let ctx = Arc::new(ServerCTX::new(
102            path.tenant,
103            path.project,
104            path.fhir_version,
105            user.clone(),
106            state.fhir_client.clone(),
107            state.rate_limit.clone(),
108        ));
109
110        let response = state.fhir_client.request(ctx, fhir_request).await?;
111
112        let http_response = response.into_response();
113        Ok(http_response)
114    }
115    .instrument(span)
116    .await
117}
118
119async fn fhir_root_handler<
120    Repo: Repository + Send + Sync + 'static,
121    Search: SearchEngine + Send + Sync + 'static,
122    Terminology: FHIRTerminology + Send + Sync + 'static,
123>(
124    method: Method,
125    Extension(user): Extension<Arc<User>>,
126    OriginalUri(uri): OriginalUri,
127    Path(path): Path<FHIRRootHandlerPath>,
128    State(state): State<Arc<AppState<Repo, Search, Terminology>>>,
129    body: String,
130) -> Result<Response, OperationOutcomeError> {
131    fhir_handler(
132        user,
133        method,
134        uri,
135        FHIRHandlerPath {
136            tenant: path.tenant,
137            project: path.project,
138            fhir_version: path.fhir_version,
139            fhir_location: None,
140        },
141        state,
142        body,
143    )
144    .await
145}
146
147async fn fhir_type_handler<
148    Repo: Repository + Send + Sync + 'static,
149    Search: SearchEngine + Send + Sync + 'static,
150    Terminology: FHIRTerminology + Send + Sync + 'static,
151>(
152    method: Method,
153    Extension(user): Extension<Arc<User>>,
154    OriginalUri(uri): OriginalUri,
155    Path(path): Path<FHIRHandlerPath>,
156    State(state): State<Arc<AppState<Repo, Search, Terminology>>>,
157    body: String,
158) -> Result<Response, OperationOutcomeError> {
159    fhir_handler(user, method, uri, path, state, body).await
160}
161
162pub async fn server() -> Result<NormalizePath<Router>, OperationOutcomeError> {
163    let config = get_config("environment".into());
164
165    // Setup Ip Source for determining client ip for login etc..
166    let ip_source = config
167        .get(crate::ServerEnvironmentVariables::IpSource)
168        .and_then(|ip_source| {
169            ClientIpSource::from_str(&ip_source).map_err(|_e| {
170                OperationOutcomeError::fatal(
171                    IssueType::Exception(None),
172                    format!("Invalid IP_SOURCE value: {}", ip_source),
173                )
174            })
175        })
176        .unwrap_or_else(|_e| ClientIpSource::ConnectInfo);
177
178    // Varify instantiates.
179    get_certification_provider();
180
181    let pool = get_pool(config.as_ref()).await;
182    let session_store = PostgresStore::new(pool.clone());
183    session_store.migrate().await.map_err(ConfigError::from)?;
184
185    let max_body_size = config
186        .get(crate::ServerEnvironmentVariables::MaxRequestBodySize)
187        .ok()
188        .and_then(|s| s.parse::<usize>().ok())
189        .unwrap_or(4 * 1024 * 1024);
190    let shared_state = create_services(config).await?;
191
192    let fhir_router = Router::new()
193        .route("/{fhir_version}", any(fhir_root_handler))
194        .route("/{fhir_version}/{*fhir_location}", any(fhir_type_handler));
195
196    let protected_resources_router = Router::new()
197        .nest("/fhir", fhir_router)
198        .route("/mcp", post(mcp::route::mcp_handler))
199        .layer(
200            ServiceBuilder::new()
201                .layer(axum::middleware::from_fn_with_state(
202                    shared_state.clone(),
203                    auth_n::middleware::basic_auth::basic_auth_middleware,
204                ))
205                .layer(axum::middleware::from_fn_with_state(
206                    shared_state.clone(),
207                    auth_n::middleware::jwt::token_verifcation,
208                ))
209                .layer(axum::middleware::from_fn(
210                    auth_n::middleware::project_access::project_access,
211                )),
212        );
213
214    let project_router = Router::new().merge(protected_resources_router).nest(
215        "/oidc",
216        auth_n::oidc::routes::create_router(shared_state.clone()),
217    );
218
219    let tenant_router = Router::new()
220        .nest("/auth", auth_n::tenant::routes::create_router())
221        .nest("/{project}/api/v1", project_router)
222        .layer(
223            // Relies on tenant for html so moving operation outcome error handling to here.
224            ServiceBuilder::new()
225                .layer(from_fn(operation_outcome_error_handle))
226                .layer(from_fn(log_operationoutcome_errors)),
227        );
228
229    let discovery_2_0_document_router = Router::new()
230        .route(
231            "/openid-configuration/w/{tenant}/{project}/{*resource}",
232            get(auth_n::oidc::routes::discovery::openid_configuration),
233        )
234        .route(
235            "/openid-configuration/w/{tenant}/{project}",
236            get(auth_n::oidc::routes::discovery::openid_configuration),
237        )
238        .route(
239            "/oauth-protected-resource/w/{tenant}/{project}/{*resource}",
240            get(auth_n::oidc::routes::discovery::oauth_protected_resource),
241        );
242
243    let app = Router::new()
244        .nest("/.well-known", discovery_2_0_document_router)
245        .nest(
246            "/auth",
247            auth_n::global::routes::create_router(shared_state.clone()),
248        )
249        .route("/openapi.json", get(openapi::openapi_document_handler))
250        .nest("/w/{tenant}", tenant_router)
251        .layer(
252            ServiceBuilder::new()
253                .layer(ip_source.into_extension())
254                .layer(NewSentryLayer::<Request<Body>>::new_from_top())
255                .layer(TraceLayer::new_for_http())
256                // 4mb by default.
257                .layer(DefaultBodyLimit::max(max_body_size))
258                .layer(CompressionLayer::new())
259                .layer(SecurityHeaderLayer::new())
260                .layer(SetResponseHeaderLayer::overriding(
261                    HeaderName::from_static("x-api-version"),
262                    HeaderValue::from_static(SERVER_VERSION),
263                ))
264                .layer(
265                    SessionManagerLayer::new(session_store)
266                        .with_secure(true)
267                        .with_same_site(SameSite::None)
268                        .with_expiry(Expiry::OnInactivity(Duration::days(3))),
269                )
270                .layer(
271                    CorsLayer::new()
272                        // allow `GET` and `POST` when accessing the resource
273                        .allow_methods(Any)
274                        // allow requests from any origin
275                        .allow_origin(Any)
276                        .allow_headers(Any),
277                ),
278        )
279        .with_state(shared_state)
280        .nest(root_asset_route().to_str().unwrap(), create_static_server());
281
282    Ok(NormalizePathLayer::trim_trailing_slash().layer(app))
283}
284
285pub async fn serve(port: u16) -> Result<(), OperationOutcomeError> {
286    let server = server().await?;
287
288    let addr = format!("0.0.0.0:{}", port);
289    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
290
291    tracing::info!("Server started");
292    axum::serve(
293        listener,
294        <tower_http::normalize_path::NormalizePath<Router> as ServiceExt<
295            axum::http::Request<Body>,
296        >>::into_make_service_with_connect_info::<SocketAddr>(server),
297    )
298    .await
299    .unwrap();
300
301    Ok(())
302}