Skip to main content

haste_repository/pg/
rate_limit.rs

1use std::pin::Pin;
2
3use crate::pg::{
4    PGConnection,
5    utilities::{commit_transaction, create_transaction},
6};
7use haste_rate_limit::{RateLimit, RateLimitError};
8use sqlx::{Acquire, Postgres};
9
10fn check_rate_limit<'a, 'c, Connection: Acquire<'c, Database = Postgres> + Send + 'a>(
11    connection: Connection,
12    rate_key: &'a str,
13    max: i32,
14    points: i32,
15    window_in_seconds: i32,
16) -> impl Future<Output = Result<i32, haste_rate_limit::RateLimitError>> + Send + 'a {
17    async move {
18        let mut conn = connection
19            .acquire()
20            .await
21            .map_err(|_e| RateLimitError::Error("could not acquire connection".to_string()))?;
22
23        let result: i32 = sqlx::query!(
24            "SELECT check_rate_limit($1, $2, $3, $4) as current_limit",
25            rate_key as &str,
26            max as i32,
27            points as i32,
28            window_in_seconds as i32,
29        )
30        .fetch_one(&mut *conn)
31        .await
32        .map_err(|_e| RateLimitError::Exceeded)?
33        .current_limit
34        .unwrap_or(0);
35
36        Ok(result)
37    }
38}
39
40impl RateLimit for PGConnection {
41    /// Returns the current points after the operation.
42    /// Note use of box and pin so can satisfy dynamic dispatch requirements.
43    fn check<'a>(
44        &'a self,
45        rate_key: &'a str,
46        max: i32,
47        points: i32,
48        window_in_seconds: i32,
49    ) -> Pin<Box<dyn Future<Output = Result<i32, haste_rate_limit::RateLimitError>> + Send + 'a>>
50    {
51        Box::pin(async move {
52            match self {
53                PGConnection::Pool(_pool, _) => {
54                    let tx = create_transaction(self, true)
55                        .await
56                        .map_err(|e| RateLimitError::Error(e.to_string()))?;
57                    let res = {
58                        let mut conn = tx.lock().await;
59                        let res =
60                            check_rate_limit(&mut *conn, rate_key, max, points, window_in_seconds)
61                                .await?;
62                        res
63                    };
64                    commit_transaction(tx)
65                        .await
66                        .map_err(|e| RateLimitError::Error(e.to_string()))?;
67                    Ok(res)
68                }
69                PGConnection::Transaction(tx, _) => {
70                    let mut tx = tx.lock().await;
71                    let res = check_rate_limit(&mut *tx, rate_key, max, points, window_in_seconds)
72                        .await?;
73                    Ok(res)
74                }
75            }
76        })
77    }
78}