haste_fhir_client/
middleware.rs

1use std::{pin::Pin, sync::Arc};
2
/// Value threaded through a middleware chain for a single call.
///
/// `Middleware::call` builds one with `response: None`; handlers may read the
/// request and set or rewrite the response before handing the context back.
///
/// `Debug` is derived so contexts (and `Result`s carrying them) can be printed
/// in diagnostics; the impl only exists when `CTX`, `Request` and `Response`
/// are themselves `Debug`.
#[derive(Debug)]
pub struct Context<CTX, Request, Response> {
    /// Caller-supplied per-call value.
    pub ctx: CTX,
    /// The request being processed.
    pub request: Request,
    /// The response produced so far, or `None` if no handler has set one yet.
    pub response: Option<Response>,
}
8
/// Boxed, `Send` future returned by every middleware step.
///
/// Resolves to the value `C` threaded through the chain, or to an `Error`.
/// (The generic is deliberately not called `Context`, which would shadow the
/// public `Context` struct inside this alias.)
pub type MiddlewareOutput<C, Error> =
    Pin<Box<dyn Future<Output = Result<C, Error>> + Send>>;

/// Type-erased continuation that runs the remainder of a middleware chain.
///
/// Takes the shared `State` and the current value `C` and returns a
/// [`MiddlewareOutput`]. It is `Send + Sync`, so it can be placed behind an
/// `Arc` and shared across threads.
pub type Next<State, C, Error> =
    Box<dyn Fn(State, C) -> MiddlewareOutput<C, Error> + Send + Sync>;
13
14pub type MiddlewareChainOld<State, CTX, Request, Response, Error> = Box<
15    dyn Fn(
16            State,
17            Context<CTX, Request, Response>,
18            Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>>,
19        ) -> MiddlewareOutput<Context<CTX, Request, Response>, Error>
20        + Send
21        + Sync,
22>;
23
24pub trait MiddlewareChain<State, CTX, Request, Response, Error>: Send + Sync {
25    fn call(
26        &self,
27        state: State,
28        ctx: Context<CTX, Request, Response>,
29        next: Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>>,
30    ) -> MiddlewareOutput<Context<CTX, Request, Response>, Error>;
31}
32
33pub struct Middleware<State, CTX, Request, Response, Error> {
34    _state: std::marker::PhantomData<State>,
35    _phantom: std::marker::PhantomData<CTX>,
36    _execute: Arc<Next<State, Context<CTX, Request, Response>, Error>>,
37}
38
39impl<
40    State: 'static + Send + Sync,
41    CTX: 'static + Send + Sync,
42    Request: 'static + Send + Sync,
43    Response: 'static + Send + Sync,
44    Error: 'static + Send + Sync,
45> Middleware<State, CTX, Request, Response, Error>
46{
47    pub fn new(
48        mut middleware: Vec<Box<dyn MiddlewareChain<State, CTX, Request, Response, Error>>>,
49    ) -> Self {
50        middleware.reverse();
51        let next: Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>> = middleware
52            .into_iter()
53            .fold(
54            None,
55            |prev_next: Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>>,
56             middleware: Box<dyn MiddlewareChain<State, CTX, Request, Response, Error>>| {
57                Some(Arc::new(Box::new(move |state, ctx| {
58                    middleware.call(state, ctx, prev_next.clone())
59                })))
60            },
61        );
62
63        Middleware {
64            _state: std::marker::PhantomData,
65            _phantom: std::marker::PhantomData,
66            _execute: next.unwrap(),
67        }
68    }
69
70    pub async fn call(
71        &self,
72        state: State,
73        ctx: CTX,
74        request: Request,
75    ) -> Result<Context<CTX, Request, Response>, Error> {
76        (self._execute)(
77            state,
78            Context {
79                ctx,
80                request,
81                response: None,
82            },
83        )
84        .await
85    }
86}
87
#[cfg(test)]
mod test {
    use super::*;

    /// Context type threaded through every test handler.
    type TestContext = Context<(), usize, usize>;
    /// Continuation handed to each test handler.
    type TestNext = Option<Arc<Next<(), TestContext, String>>>;

    /// Runs whatever follows in the chain, or hands `ctx` straight back when
    /// this handler is the innermost one.
    async fn run_rest(next: TestNext, ctx: TestContext) -> Result<TestContext, String> {
        match next {
            Some(next) => next((), ctx).await,
            None => Ok(ctx),
        }
    }

    /// Outermost handler in the test: adds 1 to the response produced further in.
    struct MiddlewareChain1 {}
    impl MiddlewareChain<(), (), usize, usize, String> for MiddlewareChain1 {
        fn call(
            &self,
            _state: (),
            ctx: TestContext,
            next: TestNext,
        ) -> MiddlewareOutput<TestContext, String> {
            Box::pin(async move {
                let mut ctx = run_rest(next, ctx).await?;
                println!("Middleware 1 executed");
                ctx.response = ctx.response.map(|r| r + 1);
                Ok(ctx)
            })
        }
    }

    /// Middle handler in the test: logs the inner response, then adds 2.
    struct MiddlewareChain2 {}
    impl MiddlewareChain<(), (), usize, usize, String> for MiddlewareChain2 {
        fn call(
            &self,
            _state: (),
            ctx: TestContext,
            next: TestNext,
        ) -> MiddlewareOutput<TestContext, String> {
            Box::pin(async move {
                let mut ctx = run_rest(next, ctx).await?;
                println!("Middleware 2 executed {:?}", ctx.response);
                ctx.response = ctx.response.map(|r| r + 2);
                Ok(ctx)
            })
        }
    }

    /// Innermost handler in the test: adds 3 to an existing response, or seeds
    /// the response from `request + 3` when none is set yet.
    struct MiddlewareChain3 {}
    impl MiddlewareChain<(), (), usize, usize, String> for MiddlewareChain3 {
        fn call(
            &self,
            _state: (),
            ctx: TestContext,
            next: TestNext,
        ) -> MiddlewareOutput<TestContext, String> {
            Box::pin(async move {
                let mut ctx = run_rest(next, ctx).await?;
                // The fallback is computed up front, matching an eager `map_or` default.
                let seeded = Some(ctx.request + 3);
                ctx.response = ctx.response.map_or(seeded, |r| Some(r + 3));
                Ok(ctx)
            })
        }
    }

    #[tokio::test]
    async fn test_middleware() {
        let handlers: Vec<Box<dyn MiddlewareChain<(), (), usize, usize, String>>> = vec![
            Box::new(MiddlewareChain1 {}),
            Box::new(MiddlewareChain2 {}),
            Box::new(MiddlewareChain3 {}),
        ];
        let pipeline = Middleware::new(handlers);

        // Innermost seeds 42 + 3 = 45, then +2 and +1 on the way back out.
        let outcome = pipeline.call((), (), 42).await;
        assert_eq!(Some(48), outcome.unwrap().response);
    }
}
173}