haste_repository/pg/
rate_limit.rs1use std::{pin::Pin, sync::LazyLock};
2
3use crate::pg::{
4 PGConnection,
5 utilities::{commit_transaction, create_transaction},
6};
7use haste_rate_limit::{RateLimit, RateLimitError};
8use moka::future::{Cache, CacheBuilder};
9use sqlx::{Acquire, Postgres};
10
/// Locally cached view of a rate key's budget, stored in the in-process cache.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RateLimitState {
    /// Points consumed so far, either computed locally or reported by the database.
    Count(i32),
    /// The most recent check reported the limit as exceeded.
    Max,
}
16
17static MEMORY: LazyLock<Cache<String, RateLimitState>> = LazyLock::new(
18 || {
20 CacheBuilder::new(10_000)
21 .time_to_idle(std::time::Duration::from_secs(30))
22 .build()
23 },
24);
25
26fn _check_rate_limit_remote<'a, 'c, Connection: Acquire<'c, Database = Postgres> + Send + 'a>(
27 connection: Connection,
28 rate_key: &'a str,
29 max: i32,
30 points: i32,
31 window_in_seconds: i32,
32) -> impl Future<Output = Result<i32, haste_rate_limit::RateLimitError>> + Send + 'a {
33 async move {
34 let mut conn = connection
35 .acquire()
36 .await
37 .map_err(|_e| RateLimitError::Error("could not acquire connection".to_string()))?;
38
39 let result: i32 = sqlx::query!(
40 "SELECT check_rate_limit($1, $2, $3, $4) as current_limit",
41 rate_key as &str,
42 max as i32,
43 points as i32,
44 window_in_seconds as i32,
45 )
46 .fetch_one(&mut *conn)
47 .await
48 .map_err(|_e| RateLimitError::Exceeded)?
49 .current_limit
50 .unwrap_or(0);
51
52 Ok(result)
53 }
54}
55
56async fn check_rate_limit_remote<'a>(
57 pg: PGConnection,
58 rate_key: &'a str,
59 max: i32,
60 points: i32,
61 window_in_seconds: i32,
62) -> Result<i32, haste_rate_limit::RateLimitError> {
63 match &pg {
64 PGConnection::Pool(_pool, _) => {
65 let tx = create_transaction(&pg, true)
66 .await
67 .map_err(|e| RateLimitError::Error(e.to_string()))?;
68 let res = {
69 let mut conn = tx.lock().await;
70 let res =
71 _check_rate_limit_remote(&mut *conn, rate_key, max, points, window_in_seconds)
72 .await?;
73 res
74 };
75 commit_transaction(tx)
76 .await
77 .map_err(|e| RateLimitError::Error(e.to_string()))?;
78 Ok(res)
79 }
80 PGConnection::Transaction(tx, _) => {
81 let mut tx = tx.lock().await;
82 let res = _check_rate_limit_remote(&mut *tx, rate_key, max, points, window_in_seconds)
83 .await?;
84 Ok(res)
85 }
86 }
87}
88
89fn check_rate_limit<'a>(
90 connection: PGConnection,
91 rate_key: &'a str,
92 max: i32,
93 points: i32,
94 window_in_seconds: i32,
95) -> impl Future<Output = Result<i32, haste_rate_limit::RateLimitError>> + Send + 'a {
96 async move {
97 if let Some(current) = MEMORY.get(rate_key).await {
99 let cloned_key = rate_key.to_string();
100 tokio::spawn(async move {
103 let result = check_rate_limit_remote(
104 connection,
105 &cloned_key,
106 max,
107 points,
108 window_in_seconds,
109 )
110 .await;
111
112 if let Ok(points) = result {
113 MEMORY
114 .insert(cloned_key, RateLimitState::Count(points))
115 .await;
116 } else if let Err(e) = result {
117 match e {
118 RateLimitError::Exceeded => {
119 MEMORY.insert(cloned_key, RateLimitState::Max).await;
121 }
122 RateLimitError::Error(e) => {
123 println!("Error checking rate limit: {:?}", e);
124 }
125 }
126 }
127 });
128
129 match current {
130 RateLimitState::Count(current) => {
131 let current_score = current + points;
132
133 if current_score > max {
134 Err(RateLimitError::Exceeded)
135 } else {
136 MEMORY
137 .insert(rate_key.to_string(), RateLimitState::Count(current_score))
138 .await;
139 Ok(current_score)
140 }
141 }
142 RateLimitState::Max => Err(RateLimitError::Exceeded),
143 }
144 } else {
145 let result =
146 check_rate_limit_remote(connection, rate_key, max, points, window_in_seconds)
147 .await?;
148
149 MEMORY
150 .insert(rate_key.to_string(), RateLimitState::Count(result))
151 .await;
152
153 Ok(result)
154 }
155 }
156}
157
158impl RateLimit for PGConnection {
159 fn check<'a>(
162 &'a self,
163 rate_key: &'a str,
164 max: i32,
165 points: i32,
166 window_in_seconds: i32,
167 ) -> Pin<Box<dyn Future<Output = Result<i32, haste_rate_limit::RateLimitError>> + Send + 'a>>
168 {
169 let connection = self.clone();
170 Box::pin(async move {
171 let res =
172 check_rate_limit(connection, rate_key, max, points, window_in_seconds).await?;
173 Ok(res)
174 })
175 }
176}