haste_repository/pg/
rate_limit.rs1use std::pin::Pin;
2
3use crate::pg::{
4 PGConnection,
5 utilities::{commit_transaction, create_transaction},
6};
7use haste_rate_limit::{RateLimit, RateLimitError};
8use sqlx::{Acquire, Postgres};
9
10fn check_rate_limit<'a, 'c, Connection: Acquire<'c, Database = Postgres> + Send + 'a>(
11 connection: Connection,
12 rate_key: &'a str,
13 max: i32,
14 points: i32,
15 window_in_seconds: i32,
16) -> impl Future<Output = Result<i32, haste_rate_limit::RateLimitError>> + Send + 'a {
17 async move {
18 let mut conn = connection
19 .acquire()
20 .await
21 .map_err(|_e| RateLimitError::Error("could not acquire connection".to_string()))?;
22
23 let result: i32 = sqlx::query!(
24 "SELECT check_rate_limit($1, $2, $3, $4) as current_limit",
25 rate_key as &str,
26 max as i32,
27 points as i32,
28 window_in_seconds as i32,
29 )
30 .fetch_one(&mut *conn)
31 .await
32 .map_err(|_e| RateLimitError::Exceeded)?
33 .current_limit
34 .unwrap_or(0);
35
36 Ok(result)
37 }
38}
39
40impl RateLimit for PGConnection {
41 fn check<'a>(
44 &'a self,
45 rate_key: &'a str,
46 max: i32,
47 points: i32,
48 window_in_seconds: i32,
49 ) -> Pin<Box<dyn Future<Output = Result<i32, haste_rate_limit::RateLimitError>> + Send + 'a>>
50 {
51 Box::pin(async move {
52 match self {
53 PGConnection::Pool(_pool, _) => {
54 let tx = create_transaction(self, true)
55 .await
56 .map_err(|e| RateLimitError::Error(e.to_string()))?;
57 let res = {
58 let mut conn = tx.lock().await;
59 let res =
60 check_rate_limit(&mut *conn, rate_key, max, points, window_in_seconds)
61 .await?;
62 res
63 };
64 commit_transaction(tx)
65 .await
66 .map_err(|e| RateLimitError::Error(e.to_string()))?;
67 Ok(res)
68 }
69 PGConnection::Transaction(tx, _) => {
70 let mut tx = tx.lock().await;
71 let res = check_rate_limit(&mut *tx, rate_key, max, points, window_in_seconds)
72 .await?;
73 Ok(res)
74 }
75 }
76 })
77 }
78}