Skip to main content

haste_repository/pg/
rate_limit.rs

1use std::{pin::Pin, sync::LazyLock};
2
3use crate::pg::{
4    PGConnection,
5    utilities::{commit_transaction, create_transaction},
6};
7use haste_rate_limit::{RateLimit, RateLimitError};
8use moka::future::{Cache, CacheBuilder};
9use sqlx::{Acquire, Postgres};
10
/// Rate-limit state remembered for a key in the in-process cache.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RateLimitState {
    /// Points consumed so far, as last computed locally or reported by the database.
    Count(i32),
    /// The limit was reported as exceeded; requests for the key are rejected from the cache
    /// until the entry is overwritten by a refresh or evicted.
    Max,
}
16
17static MEMORY: LazyLock<Cache<String, RateLimitState>> = LazyLock::new(
18    // Cache entries live for 30 seconds, after which they will be automatically evicted.
19    || {
20        CacheBuilder::new(10_000)
21            .time_to_idle(std::time::Duration::from_secs(30))
22            .build()
23    },
24);
25
26fn _check_rate_limit_remote<'a, 'c, Connection: Acquire<'c, Database = Postgres> + Send + 'a>(
27    connection: Connection,
28    rate_key: &'a str,
29    max: i32,
30    points: i32,
31    window_in_seconds: i32,
32) -> impl Future<Output = Result<i32, haste_rate_limit::RateLimitError>> + Send + 'a {
33    async move {
34        let mut conn = connection
35            .acquire()
36            .await
37            .map_err(|_e| RateLimitError::Error("could not acquire connection".to_string()))?;
38
39        let result: i32 = sqlx::query!(
40            "SELECT check_rate_limit($1, $2, $3, $4) as current_limit",
41            rate_key as &str,
42            max as i32,
43            points as i32,
44            window_in_seconds as i32,
45        )
46        .fetch_one(&mut *conn)
47        .await
48        .map_err(|_e| RateLimitError::Exceeded)?
49        .current_limit
50        .unwrap_or(0);
51
52        Ok(result)
53    }
54}
55
56async fn check_rate_limit_remote<'a>(
57    pg: PGConnection,
58    rate_key: &'a str,
59    max: i32,
60    points: i32,
61    window_in_seconds: i32,
62) -> Result<i32, haste_rate_limit::RateLimitError> {
63    match &pg {
64        PGConnection::Pool(_pool, _) => {
65            let tx = create_transaction(&pg, true)
66                .await
67                .map_err(|e| RateLimitError::Error(e.to_string()))?;
68            let res = {
69                let mut conn = tx.lock().await;
70                let res =
71                    _check_rate_limit_remote(&mut *conn, rate_key, max, points, window_in_seconds)
72                        .await?;
73                res
74            };
75            commit_transaction(tx)
76                .await
77                .map_err(|e| RateLimitError::Error(e.to_string()))?;
78            Ok(res)
79        }
80        PGConnection::Transaction(tx, _) => {
81            let mut tx = tx.lock().await;
82            let res = _check_rate_limit_remote(&mut *tx, rate_key, max, points, window_in_seconds)
83                .await?;
84            Ok(res)
85        }
86    }
87}
88
89fn check_rate_limit<'a>(
90    connection: PGConnection,
91    rate_key: &'a str,
92    max: i32,
93    points: i32,
94    window_in_seconds: i32,
95) -> impl Future<Output = Result<i32, haste_rate_limit::RateLimitError>> + Send + 'a {
96    async move {
97        // First check in-memory cache
98        if let Some(current) = MEMORY.get(rate_key).await {
99            let cloned_key = rate_key.to_string();
100            // Run background task to update the cache asynchronously without blocking the main request flow.
101            // This allows us to have a fast response time while still keeping the cache reasonably up to date.
102            tokio::spawn(async move {
103                let result = check_rate_limit_remote(
104                    connection,
105                    &cloned_key,
106                    max,
107                    points,
108                    window_in_seconds,
109                )
110                .await;
111
112                if let Ok(points) = result {
113                    MEMORY
114                        .insert(cloned_key, RateLimitState::Count(points))
115                        .await;
116                } else if let Err(e) = result {
117                    match e {
118                        RateLimitError::Exceeded => {
119                            // If the rate limit is exceeded, we can set the in-memory cache to max to prevent further requests from hitting the database until the cache expires.
120                            MEMORY.insert(cloned_key, RateLimitState::Max).await;
121                        }
122                        RateLimitError::Error(e) => {
123                            println!("Error checking rate limit: {:?}", e);
124                        }
125                    }
126                }
127            });
128
129            match current {
130                RateLimitState::Count(current) => {
131                    let current_score = current + points;
132
133                    if current_score > max {
134                        Err(RateLimitError::Exceeded)
135                    } else {
136                        MEMORY
137                            .insert(rate_key.to_string(), RateLimitState::Count(current_score))
138                            .await;
139                        Ok(current_score)
140                    }
141                }
142                RateLimitState::Max => Err(RateLimitError::Exceeded),
143            }
144        } else {
145            let result =
146                check_rate_limit_remote(connection, rate_key, max, points, window_in_seconds)
147                    .await?;
148
149            MEMORY
150                .insert(rate_key.to_string(), RateLimitState::Count(result))
151                .await;
152
153            Ok(result)
154        }
155    }
156}
157
158impl RateLimit for PGConnection {
159    /// Returns the current points after the operation.
160    /// Note use of box and pin so can satisfy dynamic dispatch requirements.
161    fn check<'a>(
162        &'a self,
163        rate_key: &'a str,
164        max: i32,
165        points: i32,
166        window_in_seconds: i32,
167    ) -> Pin<Box<dyn Future<Output = Result<i32, haste_rate_limit::RateLimitError>> + Send + 'a>>
168    {
169        let connection = self.clone();
170        Box::pin(async move {
171            let res =
172                check_rate_limit(connection, rate_key, max, points, window_in_seconds).await?;
173            Ok(res)
174        })
175    }
176}