1use std::{pin::Pin, sync::Arc};
2
/// Per-call data threaded through every layer of a middleware pipeline.
///
/// Layers take the context by value and return it (possibly modified), either
/// directly or after passing it through the layers that follow them.
#[derive(Debug, Clone)]
pub struct Context<CTX, Request, Response> {
    /// Caller-supplied data passed in alongside the request.
    pub ctx: CTX,
    /// The incoming request.
    pub request: Request,
    /// The response, once some layer has produced one; `None` until then.
    pub response: Option<Response>,
}
8
/// Boxed, `Send` future returned by every middleware layer.
///
/// Resolves to the (possibly updated) context `C`, or to an error `E`.
/// `Future` is fully qualified so this does not depend on the edition-2024 prelude.
pub type MiddlewareOutput<C, E> =
    Pin<Box<dyn std::future::Future<Output = Result<C, E>> + Send>>;

/// Type-erased handle to the next layer of a middleware chain.
///
/// Calling it with a state value `S` and a context `C` invokes that layer and
/// yields its [`MiddlewareOutput`].
pub type Next<S, C, E> = Box<dyn Fn(S, C) -> MiddlewareOutput<C, E> + Send + Sync>;
13
14pub type MiddlewareChainOld<State, CTX, Request, Response, Error> = Box<
15 dyn Fn(
16 State,
17 Context<CTX, Request, Response>,
18 Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>>,
19 ) -> MiddlewareOutput<Context<CTX, Request, Response>, Error>
20 + Send
21 + Sync,
22>;
23
24pub trait MiddlewareChain<State, CTX, Request, Response, Error>: Send + Sync {
25 fn call(
26 &self,
27 state: State,
28 ctx: Context<CTX, Request, Response>,
29 next: Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>>,
30 ) -> MiddlewareOutput<Context<CTX, Request, Response>, Error>;
31}
32
33pub struct Middleware<State, CTX, Request, Response, Error> {
34 _state: std::marker::PhantomData<State>,
35 _phantom: std::marker::PhantomData<CTX>,
36 _execute: Arc<Next<State, Context<CTX, Request, Response>, Error>>,
37}
38
39impl<
40 State: 'static + Send + Sync,
41 CTX: 'static + Send + Sync,
42 Request: 'static + Send + Sync,
43 Response: 'static + Send + Sync,
44 Error: 'static + Send + Sync,
45> Middleware<State, CTX, Request, Response, Error>
46{
47 pub fn new(
48 mut middleware: Vec<Box<dyn MiddlewareChain<State, CTX, Request, Response, Error>>>,
49 ) -> Self {
50 middleware.reverse();
51 let next: Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>> = middleware
52 .into_iter()
53 .fold(
54 None,
55 |prev_next: Option<Arc<Next<State, Context<CTX, Request, Response>, Error>>>,
56 middleware: Box<dyn MiddlewareChain<State, CTX, Request, Response, Error>>| {
57 Some(Arc::new(Box::new(move |state, ctx| {
58 middleware.call(state, ctx, prev_next.clone())
59 })))
60 },
61 );
62
63 Middleware {
64 _state: std::marker::PhantomData,
65 _phantom: std::marker::PhantomData,
66 _execute: next.unwrap(),
67 }
68 }
69
70 pub async fn call(
71 &self,
72 state: State,
73 ctx: CTX,
74 request: Request,
75 ) -> Result<Context<CTX, Request, Response>, Error> {
76 (self._execute)(
77 state,
78 Context {
79 ctx,
80 request,
81 response: None,
82 },
83 )
84 .await
85 }
86}
87
#[cfg(test)]
mod test {
    use super::*;

    /// Context type used by every test layer.
    type Ctx = Context<(), usize, usize>;
    /// Handle to the rest of the chain, as seen by the test layers.
    type NextFn = Arc<Next<(), Ctx, String>>;

    /// Runs the remainder of the chain if there is one; otherwise returns
    /// `ctx` unchanged.
    async fn run_next(ctx: Ctx, next: Option<NextFn>) -> Result<Ctx, String> {
        match next {
            Some(next) => next((), ctx).await,
            None => Ok(ctx),
        }
    }

    /// Adds 1 to the response produced by the layers after it.
    struct MiddlewareChain1 {}
    impl MiddlewareChain<(), (), usize, usize, String> for MiddlewareChain1 {
        fn call(
            &self,
            _state: (),
            ctx: Ctx,
            next: Option<NextFn>,
        ) -> MiddlewareOutput<Ctx, String> {
            Box::pin(async move {
                let mut ctx = run_next(ctx, next).await?;
                println!("Middleware 1 executed");
                ctx.response = ctx.response.map(|r| r + 1);
                Ok(ctx)
            })
        }
    }

    /// Adds 2 to the response produced by the layers after it.
    struct MiddlewareChain2 {}
    impl MiddlewareChain<(), (), usize, usize, String> for MiddlewareChain2 {
        fn call(
            &self,
            _state: (),
            ctx: Ctx,
            next: Option<NextFn>,
        ) -> MiddlewareOutput<Ctx, String> {
            Box::pin(async move {
                let mut ctx = run_next(ctx, next).await?;
                println!("Middleware 2 executed {:?}", ctx.response);
                ctx.response = ctx.response.map(|r| r + 2);
                Ok(ctx)
            })
        }
    }

    /// Adds 3 to the response, seeding it from the request if no later layer
    /// produced one.
    struct MiddlewareChain3 {}
    impl MiddlewareChain<(), (), usize, usize, String> for MiddlewareChain3 {
        fn call(
            &self,
            _state: (),
            ctx: Ctx,
            next: Option<NextFn>,
        ) -> MiddlewareOutput<Ctx, String> {
            Box::pin(async move {
                let mut ctx = run_next(ctx, next).await?;
                ctx.response = Some(ctx.response.unwrap_or(ctx.request) + 3);
                Ok(ctx)
            })
        }
    }

    /// Fails immediately without running the layers after it.
    struct Failing {}
    impl MiddlewareChain<(), (), usize, usize, String> for Failing {
        fn call(
            &self,
            _state: (),
            _ctx: Ctx,
            _next: Option<NextFn>,
        ) -> MiddlewareOutput<Ctx, String> {
            Box::pin(async { Err::<Ctx, String>("boom".to_string()) })
        }
    }

    #[tokio::test]
    async fn test_middleware() {
        let test = Middleware::new(vec![
            Box::new(MiddlewareChain1 {}),
            Box::new(MiddlewareChain2 {}),
            Box::new(MiddlewareChain3 {}),
        ]);

        // 42 -> +3 (layer 3) -> +2 (layer 2) -> +1 (layer 1)
        let ret = test.call((), (), 42).await;
        assert_eq!(Some(48), ret.unwrap().response);
    }

    #[tokio::test]
    async fn test_middleware_error_short_circuits() {
        let test = Middleware::new(vec![
            Box::new(MiddlewareChain1 {}),
            Box::new(Failing {}),
            Box::new(MiddlewareChain3 {}),
        ]);

        let ret = test.call((), (), 42).await;
        assert_eq!(Some("boom".to_string()), ret.err());
    }
}