// haste_fhir_client/middleware.rs

1use std::{fmt::Debug, pin::Pin, sync::Arc};
2
/// Value threaded through every middleware layer for a single invocation.
///
/// Each layer receives this by value, may inspect or modify it, and hands the
/// (possibly updated) value back through the chain.
#[derive(Debug)]
pub struct Context<CTX, Request, Response>
where
    CTX: Debug,
    Request: Debug,
    Response: Debug,
{
    /// Caller-supplied per-invocation context.
    pub ctx: CTX,
    /// The request being processed.
    pub request: Request,
    /// The response, if any layer has produced one yet.
    pub response: Option<Response>,
}
9
/// Boxed, pinned, `Send` future returned by a middleware layer.
///
/// Resolves to the (possibly updated) value `C` on success, or to `E` on failure.
pub type MiddlewareOutput<C, E> =
    Pin<Box<dyn Future<Output = Result<C, E>> + Send + 'static>>;

/// Type-erased handle to the remainder of a middleware chain.
///
/// Calling it with a state value and a context runs the next layer, which may in
/// turn invoke the layers after it, and yields that layer's [`MiddlewareOutput`].
pub type Next<S, C, E> = Box<dyn Fn(S, C) -> MiddlewareOutput<C, E> + Send + Sync + 'static>;
14
15pub type MiddlewareChainOld<State, CTX, Request, Response, Error> = Box<
16    dyn Fn(
17            State,
18            Context<CTX, Request, Response>,
19            Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>>,
20        ) -> MiddlewareOutput<Context<CTX, Request, Response>, Error>
21        + Send
22        + Sync,
23>;
24
25pub trait MiddlewareChain<State, CTX: Debug, Request: Debug, Response: Debug, Error>:
26    Send + Sync
27{
28    fn call(
29        &self,
30        state: State,
31        ctx: Context<CTX, Request, Response>,
32        next: Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>>,
33    ) -> MiddlewareOutput<Context<CTX, Request, Response>, Error>;
34}
35
36pub struct Middleware<State, CTX: Debug, Request: Debug, Response: Debug, Error> {
37    _state: std::marker::PhantomData<State>,
38    _phantom: std::marker::PhantomData<CTX>,
39    _execute: Arc<Next<State, Context<CTX, Request, Response>, Error>>,
40}
41
42impl<
43    State: 'static + Send + Sync,
44    CTX: 'static + Send + Sync + Debug,
45    Request: 'static + Send + Sync + Debug,
46    Response: 'static + Send + Sync + Debug,
47    Error: 'static + Send + Sync,
48> Middleware<State, CTX, Request, Response, Error>
49{
50    pub fn new(
51        mut middleware: Vec<Box<dyn MiddlewareChain<State, CTX, Request, Response, Error>>>,
52    ) -> Self {
53        middleware.reverse();
54        let next: Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>> = middleware
55            .into_iter()
56            .fold(
57            None,
58            |prev_next: Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>>,
59             middleware: Box<dyn MiddlewareChain<State, CTX, Request, Response, Error>>| {
60                Some(Arc::new(Box::new(move |state, ctx| {
61                    middleware.call(state, ctx, prev_next.clone())
62                })))
63            },
64        );
65
66        Middleware {
67            _state: std::marker::PhantomData,
68            _phantom: std::marker::PhantomData,
69            _execute: next.unwrap(),
70        }
71    }
72
73    pub async fn call(
74        &self,
75        state: State,
76        ctx: CTX,
77        request: Request,
78    ) -> Result<Context<CTX, Request, Response>, Error> {
79        (self._execute)(
80            state,
81            Context {
82                ctx,
83                request,
84                response: None,
85            },
86        )
87        .await
88    }
89}
90
#[cfg(test)]
mod test {
    use super::*;

    /// Runs the rest of the chain when there is one; otherwise hands the context
    /// straight back. Lets each test layer describe only its own post-processing.
    async fn run_next(
        ctx: Context<(), usize, usize>,
        next: Option<Arc<Next<(), Context<(), usize, usize>, String>>>,
    ) -> Result<Context<(), usize, usize>, String> {
        match next {
            Some(next) => next((), ctx).await,
            None => Ok(ctx),
        }
    }

    /// Outermost layer: adds 1 to any response on the way back out.
    struct MiddlewareChain1 {}
    impl MiddlewareChain<(), (), usize, usize, String> for MiddlewareChain1 {
        fn call(
            &self,
            _state: (),
            ctx: Context<(), usize, usize>,
            next: Option<Arc<Next<(), Context<(), usize, usize>, String>>>,
        ) -> MiddlewareOutput<Context<(), usize, usize>, String> {
            Box::pin(async move {
                let mut ctx = run_next(ctx, next).await?;
                println!("Middleware 1 executed");
                ctx.response = ctx.response.map(|r| r + 1);
                Ok(ctx)
            })
        }
    }

    /// Middle layer: adds 2 to any response on the way back out.
    struct MiddlewareChain2 {}
    impl MiddlewareChain<(), (), usize, usize, String> for MiddlewareChain2 {
        fn call(
            &self,
            _state: (),
            ctx: Context<(), usize, usize>,
            next: Option<Arc<Next<(), Context<(), usize, usize>, String>>>,
        ) -> MiddlewareOutput<Context<(), usize, usize>, String> {
            Box::pin(async move {
                let mut ctx = run_next(ctx, next).await?;

                println!("Middleware 2 executed {:?}", ctx.response);
                ctx.response = ctx.response.map(|r| r + 2);
                Ok(ctx)
            })
        }
    }

    /// Innermost layer: seeds the response from the request when absent, then adds 3.
    struct MiddlewareChain3 {}
    impl MiddlewareChain<(), (), usize, usize, String> for MiddlewareChain3 {
        fn call(
            &self,
            _state: (),
            ctx: Context<(), usize, usize>,
            next: Option<Arc<Next<(), Context<(), usize, usize>, String>>>,
        ) -> MiddlewareOutput<Context<(), usize, usize>, String> {
            Box::pin(async move {
                let mut ctx = run_next(ctx, next).await?;

                ctx.response = Some(ctx.response.unwrap_or(ctx.request) + 3);
                Ok(ctx)
            })
        }
    }

    #[tokio::test]
    async fn test_middleware() {
        let test = Middleware::new(vec![
            Box::new(MiddlewareChain1 {}),
            Box::new(MiddlewareChain2 {}),
            Box::new(MiddlewareChain3 {}),
        ]);

        // Layer 3 seeds 42 + 3, then layers 2 and 1 add 2 and 1 on the way out.
        let ret = test.call((), (), 42).await.unwrap();
        assert_eq!(ret.request, 42);
        assert_eq!(Some(48), ret.response);
    }
}