1use std::{fmt::Debug, pin::Pin, sync::Arc};
2
/// Per-request data threaded through every middleware layer.
///
/// A layer receives the context by value, may read or modify any field, and
/// hands it on to the next layer (or returns it).
///
/// The type parameters carry no bounds here: `#[derive(Debug)]` only requires
/// `Debug` for its own impl, so non-`Debug` payloads can still be stored.
#[derive(Debug)]
pub struct Context<CTX, Request, Response> {
    /// Caller-supplied data for this invocation.
    pub ctx: CTX,
    /// The incoming request.
    pub request: Request,
    /// The response, once some layer has produced one.
    pub response: Option<Response>,
}
9
/// Boxed future produced by a middleware layer.
///
/// Resolves to the (possibly updated) context `C` or an error `E`. The future is
/// `Send + 'static`; `'static` is what `Box<dyn ...>` defaults to anyway and is
/// written out only to make that requirement visible.
pub type MiddlewareOutput<C, E> = Pin<Box<dyn Future<Output = Result<C, E>> + Send + 'static>>;

/// Callable handle to the remainder of a middleware chain.
///
/// Invoking it with state `S` and context `C` yields a [`MiddlewareOutput`].
pub type Next<S, C, E> = Box<dyn Fn(S, C) -> MiddlewareOutput<C, E> + Send + Sync>;
14
15pub type MiddlewareChainOld<State, CTX, Request, Response, Error> = Box<
16 dyn Fn(
17 State,
18 Context<CTX, Request, Response>,
19 Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>>,
20 ) -> MiddlewareOutput<Context<CTX, Request, Response>, Error>
21 + Send
22 + Sync,
23>;
24
25pub trait MiddlewareChain<State, CTX: Debug, Request: Debug, Response: Debug, Error>:
26 Send + Sync
27{
28 fn call(
29 &self,
30 state: State,
31 ctx: Context<CTX, Request, Response>,
32 next: Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>>,
33 ) -> MiddlewareOutput<Context<CTX, Request, Response>, Error>;
34}
35
36pub struct Middleware<State, CTX: Debug, Request: Debug, Response: Debug, Error> {
37 _state: std::marker::PhantomData<State>,
38 _phantom: std::marker::PhantomData<CTX>,
39 _execute: Arc<Next<State, Context<CTX, Request, Response>, Error>>,
40}
41
42impl<
43 State: 'static + Send + Sync,
44 CTX: 'static + Send + Sync + Debug,
45 Request: 'static + Send + Sync + Debug,
46 Response: 'static + Send + Sync + Debug,
47 Error: 'static + Send + Sync,
48> Middleware<State, CTX, Request, Response, Error>
49{
50 pub fn new(
51 mut middleware: Vec<Box<dyn MiddlewareChain<State, CTX, Request, Response, Error>>>,
52 ) -> Self {
53 middleware.reverse();
54 let next: Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>> = middleware
55 .into_iter()
56 .fold(
57 None,
58 |prev_next: Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>>,
59 middleware: Box<dyn MiddlewareChain<State, CTX, Request, Response, Error>>| {
60 Some(Arc::new(Box::new(move |state, ctx| {
61 middleware.call(state, ctx, prev_next.clone())
62 })))
63 },
64 );
65
66 Middleware {
67 _state: std::marker::PhantomData,
68 _phantom: std::marker::PhantomData,
69 _execute: next.unwrap(),
70 }
71 }
72
73 pub async fn call(
74 &self,
75 state: State,
76 ctx: CTX,
77 request: Request,
78 ) -> Result<Context<CTX, Request, Response>, Error> {
79 (self._execute)(
80 state,
81 Context {
82 ctx,
83 request,
84 response: None,
85 },
86 )
87 .await
88 }
89}
90
#[cfg(test)]
mod test {
    use super::*;

    /// Downstream-handle type shared by every test layer below.
    type TestNext = Option<Arc<Next<(), Context<(), usize, usize>, String>>>;

    /// Runs the rest of the chain when there is one; otherwise returns `ctx` as-is.
    async fn run_downstream(
        ctx: Context<(), usize, usize>,
        downstream: TestNext,
    ) -> Result<Context<(), usize, usize>, String> {
        match downstream {
            Some(handler) => handler((), ctx).await,
            None => Ok(ctx),
        }
    }

    /// Adds 1 to any response produced downstream.
    struct MiddlewareChain1 {}
    impl MiddlewareChain<(), (), usize, usize, String> for MiddlewareChain1 {
        fn call(
            &self,
            _state: (),
            ctx: Context<(), usize, usize>,
            next: TestNext,
        ) -> MiddlewareOutput<Context<(), usize, usize>, String> {
            Box::pin(async move {
                let mut ctx = run_downstream(ctx, next).await?;
                println!("Middleware 1 executed");
                ctx.response = ctx.response.map(|r| r + 1);
                Ok::<_, String>(ctx)
            })
        }
    }

    /// Logs the downstream response, then adds 2 to it.
    struct MiddlewareChain2 {}
    impl MiddlewareChain<(), (), usize, usize, String> for MiddlewareChain2 {
        fn call(
            &self,
            _state: (),
            ctx: Context<(), usize, usize>,
            next: TestNext,
        ) -> MiddlewareOutput<Context<(), usize, usize>, String> {
            Box::pin(async move {
                let mut ctx = run_downstream(ctx, next).await?;

                println!("Middleware 2 executed {:?}", ctx.response);
                ctx.response = ctx.response.map(|r| r + 2);
                Ok::<_, String>(ctx)
            })
        }
    }

    /// Adds 3 to the downstream response, or seeds it from `request + 3`.
    struct MiddlewareChain3 {}
    impl MiddlewareChain<(), (), usize, usize, String> for MiddlewareChain3 {
        fn call(
            &self,
            _state: (),
            ctx: Context<(), usize, usize>,
            next: TestNext,
        ) -> MiddlewareOutput<Context<(), usize, usize>, String> {
            Box::pin(async move {
                let mut ctx = run_downstream(ctx, next).await?;

                let seeded = Some(ctx.request + 3);
                ctx.response = ctx.response.map_or(seeded, |r| Some(r + 3));
                Ok::<_, String>(ctx)
            })
        }
    }

    #[tokio::test]
    async fn test_middleware() {
        // Layer 3 (innermost) seeds 42 + 3. Layers 2 and 1 then add 2 and 1 on
        // the way back out: 48.
        let pipeline = Middleware::new(vec![
            Box::new(MiddlewareChain1 {}),
            Box::new(MiddlewareChain2 {}),
            Box::new(MiddlewareChain3 {}),
        ]);

        let outcome = pipeline.call((), (), 42).await;
        assert_eq!(Some(48), outcome.unwrap().response);
    }
}