1use crate::{
2 auth_n::{self, certificates::get_certification_provider},
3 fhir_client::ServerCTX,
4 fhir_http::{HTTPBody, HTTPRequest, http_request_to_fhir_request},
5 mcp,
6 middleware::{
7 errors::{log_operationoutcome_errors, operation_outcome_error_handle},
8 security_headers::SecurityHeaderLayer,
9 },
10 openapi,
11 services::{AppState, ConfigError, create_services, get_pool},
12 static_assets::{create_static_server, root_asset_route},
13};
14use axum::{
15 Extension, Router, ServiceExt,
16 body::Body,
17 extract::{DefaultBodyLimit, OriginalUri, Path, State},
18 http::Request,
19 http::{HeaderName, HeaderValue, Method, Uri},
20 middleware::from_fn,
21 response::{IntoResponse, Response},
22 routing::{any, get, post},
23};
24use haste_config::get_config;
25use haste_fhir_client::FHIRClient;
26use haste_fhir_operation_error::OperationOutcomeError;
27use haste_fhir_search::SearchEngine;
28use haste_fhir_terminology::FHIRTerminology;
29use haste_jwt::{ProjectId, TenantId, claims::UserTokenClaims};
30use haste_repository::{Repository, types::SupportedFHIRVersions};
31use sentry::integrations::tower::NewSentryLayer;
32use serde::Deserialize;
33use std::{collections::HashMap, sync::Arc};
34use tower::{Layer, ServiceBuilder};
35use tower_http::normalize_path::NormalizePath;
36use tower_http::{
37 compression::CompressionLayer,
38 cors::{Any, CorsLayer},
39 normalize_path::NormalizePathLayer,
40 set_header::SetResponseHeaderLayer,
41 trace::TraceLayer,
42};
43use tower_sessions::{
44 Expiry, SessionManagerLayer,
45 cookie::{SameSite, time::Duration},
46};
47use tower_sessions_sqlx_store::PostgresStore;
48use tracing::{Instrument, Level, span};
49
50const SERVER_VERSION: &str = env!("CARGO_PKG_VERSION");
51
52#[derive(Deserialize)]
53struct FHIRHandlerPath {
54 tenant: TenantId,
55 project: ProjectId,
56 fhir_version: SupportedFHIRVersions,
57 fhir_location: Option<String>,
58}
59
60#[derive(Deserialize)]
61struct FHIRRootHandlerPath {
62 tenant: TenantId,
63 project: ProjectId,
64 fhir_version: SupportedFHIRVersions,
65}
66
67async fn fhir_handler<
68 Repo: Repository + Send + Sync + 'static,
69 Search: SearchEngine + Send + Sync + 'static,
70 Terminology: FHIRTerminology + Send + Sync + 'static,
71>(
72 claims: Arc<UserTokenClaims>,
73 method: Method,
74 uri: Uri,
75 path: FHIRHandlerPath,
76 state: Arc<AppState<Repo, Search, Terminology>>,
77 body: String,
78) -> Result<Response, OperationOutcomeError> {
79 let fhir_location = path.fhir_location.unwrap_or_default();
80 let method_str = method.to_string();
81 let span = span!(Level::ERROR, "FHIR-HTTP", method_str, fhir_location);
82 async {
83 let http_req = HTTPRequest::new(
84 method,
85 fhir_location,
86 HTTPBody::String(body),
87 uri.query()
88 .map(|q| {
89 url::form_urlencoded::parse(q.as_bytes())
90 .into_owned()
91 .collect()
92 })
93 .unwrap_or_else(HashMap::new),
94 );
95
96 let fhir_request = http_request_to_fhir_request(SupportedFHIRVersions::R4, http_req)?;
97
98 let ctx = Arc::new(ServerCTX::new(
99 path.tenant,
100 path.project,
101 path.fhir_version,
102 claims.clone(),
103 state.fhir_client.clone(),
104 state.rate_limit.clone(),
105 ));
106
107 let response = state.fhir_client.request(ctx, fhir_request).await?;
108
109 let http_response = response.into_response();
110 Ok(http_response)
111 }
112 .instrument(span)
113 .await
114}
115
116async fn fhir_root_handler<
117 Repo: Repository + Send + Sync + 'static,
118 Search: SearchEngine + Send + Sync + 'static,
119 Terminology: FHIRTerminology + Send + Sync + 'static,
120>(
121 method: Method,
122 Extension(user): Extension<Arc<UserTokenClaims>>,
123 OriginalUri(uri): OriginalUri,
124 Path(path): Path<FHIRRootHandlerPath>,
125 State(state): State<Arc<AppState<Repo, Search, Terminology>>>,
126 body: String,
127) -> Result<Response, OperationOutcomeError> {
128 fhir_handler(
129 user,
130 method,
131 uri,
132 FHIRHandlerPath {
133 tenant: path.tenant,
134 project: path.project,
135 fhir_version: path.fhir_version,
136 fhir_location: None,
137 },
138 state,
139 body,
140 )
141 .await
142}
143
144async fn fhir_type_handler<
145 Repo: Repository + Send + Sync + 'static,
146 Search: SearchEngine + Send + Sync + 'static,
147 Terminology: FHIRTerminology + Send + Sync + 'static,
148>(
149 method: Method,
150 Extension(user): Extension<Arc<UserTokenClaims>>,
151 OriginalUri(uri): OriginalUri,
152 Path(path): Path<FHIRHandlerPath>,
153 State(state): State<Arc<AppState<Repo, Search, Terminology>>>,
154 body: String,
155) -> Result<Response, OperationOutcomeError> {
156 fhir_handler(user, method, uri, path, state, body).await
157}
158
159pub async fn server() -> Result<NormalizePath<Router>, OperationOutcomeError> {
160 let config = get_config("environment".into());
161 get_certification_provider();
163
164 let pool = get_pool(config.as_ref()).await;
165 let session_store = PostgresStore::new(pool.clone());
166 session_store.migrate().await.map_err(ConfigError::from)?;
167
168 let max_body_size = config
169 .get(crate::ServerEnvironmentVariables::MaxRequestBodySize)
170 .ok()
171 .and_then(|s| s.parse::<usize>().ok())
172 .unwrap_or(4 * 1024 * 1024);
173 let shared_state = create_services(config).await?;
174
175 let fhir_router = Router::new()
176 .route("/{fhir_version}", any(fhir_root_handler))
177 .route("/{fhir_version}/{*fhir_location}", any(fhir_type_handler));
178
179 let protected_resources_router = Router::new()
180 .nest("/fhir", fhir_router)
181 .route("/mcp", post(mcp::route::mcp_handler))
182 .layer(
183 ServiceBuilder::new()
184 .layer(axum::middleware::from_fn_with_state(
185 shared_state.clone(),
186 auth_n::middleware::basic_auth::basic_auth_middleware,
187 ))
188 .layer(axum::middleware::from_fn_with_state(
189 shared_state.clone(),
190 auth_n::middleware::jwt::token_verifcation,
191 ))
192 .layer(axum::middleware::from_fn(
193 auth_n::middleware::project_access::project_access,
194 )),
195 );
196
197 let project_router = Router::new().merge(protected_resources_router).nest(
198 "/oidc",
199 auth_n::oidc::routes::create_router(shared_state.clone()),
200 );
201
202 let tenant_router = Router::new()
203 .nest("/{project}/api/v1", project_router)
204 .layer(
205 ServiceBuilder::new()
207 .layer(from_fn(operation_outcome_error_handle))
208 .layer(from_fn(log_operationoutcome_errors)),
209 );
210
211 let discovery_2_0_document_router = Router::new()
212 .route(
213 "/openid-configuration/w/{tenant}/{project}/{*resource}",
214 get(auth_n::oidc::routes::discovery::openid_configuration),
215 )
216 .route(
217 "/openid-configuration/w/{tenant}/{project}",
218 get(auth_n::oidc::routes::discovery::openid_configuration),
219 )
220 .route(
221 "/oauth-protected-resource/w/{tenant}/{project}/{*resource}",
222 get(auth_n::oidc::routes::discovery::oauth_protected_resource),
223 );
224
225 let app = Router::new()
226 .nest("/.well-known", discovery_2_0_document_router)
227 .nest(
228 "/auth",
229 auth_n::global::routes::create_router(shared_state.clone()),
230 )
231 .route("/openapi.json", get(openapi::openapi_document_handler))
232 .nest("/w/{tenant}", tenant_router)
233 .layer(
234 ServiceBuilder::new()
235 .layer(NewSentryLayer::<Request<Body>>::new_from_top())
236 .layer(TraceLayer::new_for_http())
237 .layer(DefaultBodyLimit::max(max_body_size))
239 .layer(CompressionLayer::new())
240 .layer(SecurityHeaderLayer::new())
241 .layer(SetResponseHeaderLayer::overriding(
242 HeaderName::from_static("x-api-version"),
243 HeaderValue::from_static(SERVER_VERSION),
244 ))
245 .layer(
246 SessionManagerLayer::new(session_store)
247 .with_secure(true)
248 .with_same_site(SameSite::None)
249 .with_expiry(Expiry::OnInactivity(Duration::days(3))),
250 )
251 .layer(
252 CorsLayer::new()
253 .allow_methods(Any)
255 .allow_origin(Any)
257 .allow_headers(Any),
258 ),
259 )
260 .with_state(shared_state)
261 .nest(root_asset_route().to_str().unwrap(), create_static_server());
262
263 Ok(NormalizePathLayer::trim_trailing_slash().layer(app))
264}
265
266pub async fn serve(port: u16) -> Result<(), OperationOutcomeError> {
267 let server = server().await?;
268
269 let addr = format!("0.0.0.0:{}", port);
270 let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
271
272 tracing::info!("Server started");
273 axum::serve(
274 listener,
275 <tower_http::normalize_path::NormalizePath<Router> as ServiceExt<
276 axum::http::Request<Body>,
277 >>::into_make_service(server),
278 )
279 .await
280 .unwrap();
281
282 Ok(())
283}