1
- import { createCache , DefaultStatefulContext , Namespace , Cache as UnkeyCache } from "@unkey/cache" ;
2
- import { MemoryStore } from "@unkey/cache/stores" ;
3
- import { Ratelimit } from "@upstash/ratelimit" ;
4
- import { Request as ExpressRequest , Response as ExpressResponse , NextFunction } from "express" ;
5
- import { RedisOptions } from "ioredis" ;
6
- import { createHash } from "node:crypto" ;
7
- import { z } from "zod" ;
8
1
import { env } from "~/env.server" ;
9
2
import { authenticateAuthorizationHeader } from "./apiAuth.server" ;
10
- import { logger } from "./logger.server" ;
11
- import { createRedisRateLimitClient , Duration , RateLimiter } from "./rateLimiter.server" ;
12
- import { RedisCacheStore } from "./unkey/redisCacheStore.server" ;
13
-
14
- const DurationSchema = z . custom < Duration > ( ( value ) => {
15
- if ( typeof value !== "string" ) {
16
- throw new Error ( "Duration must be a string" ) ;
17
- }
18
-
19
- return value as Duration ;
20
- } ) ;
21
-
22
- export const RateLimitFixedWindowConfig = z . object ( {
23
- type : z . literal ( "fixedWindow" ) ,
24
- window : DurationSchema ,
25
- tokens : z . number ( ) ,
26
- } ) ;
27
-
28
- export type RateLimitFixedWindowConfig = z . infer < typeof RateLimitFixedWindowConfig > ;
29
-
30
- export const RateLimitSlidingWindowConfig = z . object ( {
31
- type : z . literal ( "slidingWindow" ) ,
32
- window : DurationSchema ,
33
- tokens : z . number ( ) ,
34
- } ) ;
35
-
36
- export type RateLimitSlidingWindowConfig = z . infer < typeof RateLimitSlidingWindowConfig > ;
37
-
38
- export const RateLimitTokenBucketConfig = z . object ( {
39
- type : z . literal ( "tokenBucket" ) ,
40
- refillRate : z . number ( ) ,
41
- interval : DurationSchema ,
42
- maxTokens : z . number ( ) ,
43
- } ) ;
44
-
45
- export type RateLimitTokenBucketConfig = z . infer < typeof RateLimitTokenBucketConfig > ;
46
-
47
- export const RateLimiterConfig = z . discriminatedUnion ( "type" , [
48
- RateLimitFixedWindowConfig ,
49
- RateLimitSlidingWindowConfig ,
50
- RateLimitTokenBucketConfig ,
51
- ] ) ;
52
-
53
- export type RateLimiterConfig = z . infer < typeof RateLimiterConfig > ;
54
-
55
- type LimitConfigOverrideFunction = ( authorizationValue : string ) => Promise < unknown > ;
56
-
57
- type Options = {
58
- redis ?: RedisOptions ;
59
- keyPrefix : string ;
60
- pathMatchers : ( RegExp | string ) [ ] ;
61
- pathWhiteList ?: ( RegExp | string ) [ ] ;
62
- defaultLimiter : RateLimiterConfig ;
63
- limiterConfigOverride ?: LimitConfigOverrideFunction ;
64
- limiterCache ?: {
65
- fresh : number ;
66
- stale : number ;
67
- } ;
68
- log ?: {
69
- requests ?: boolean ;
70
- rejections ?: boolean ;
71
- limiter ?: boolean ;
72
- } ;
73
- } ;
74
-
75
- async function resolveLimitConfig (
76
- authorizationValue : string ,
77
- hashedAuthorizationValue : string ,
78
- defaultLimiter : RateLimiterConfig ,
79
- cache : UnkeyCache < { limiter : RateLimiterConfig } > ,
80
- logsEnabled : boolean ,
81
- limiterConfigOverride ?: LimitConfigOverrideFunction
82
- ) : Promise < RateLimiterConfig > {
83
- if ( ! limiterConfigOverride ) {
84
- return defaultLimiter ;
85
- }
86
-
87
- if ( logsEnabled ) {
88
- logger . info ( "RateLimiter: checking for override" , {
89
- authorizationValue : hashedAuthorizationValue ,
90
- defaultLimiter,
91
- } ) ;
92
- }
93
-
94
- const cacheResult = await cache . limiter . swr ( hashedAuthorizationValue , async ( key ) => {
95
- const override = await limiterConfigOverride ( authorizationValue ) ;
96
-
97
- if ( ! override ) {
98
- if ( logsEnabled ) {
99
- logger . info ( "RateLimiter: no override found" , {
100
- authorizationValue,
101
- defaultLimiter,
102
- } ) ;
103
- }
104
-
105
- return defaultLimiter ;
106
- }
107
-
108
- const parsedOverride = RateLimiterConfig . safeParse ( override ) ;
109
-
110
- if ( ! parsedOverride . success ) {
111
- logger . error ( "Error parsing rate limiter override" , {
112
- override,
113
- errors : parsedOverride . error . errors ,
114
- } ) ;
115
-
116
- return defaultLimiter ;
117
- }
118
-
119
- if ( logsEnabled && parsedOverride . data ) {
120
- logger . info ( "RateLimiter: override found" , {
121
- authorizationValue,
122
- defaultLimiter,
123
- override : parsedOverride . data ,
124
- } ) ;
125
- }
126
-
127
- return parsedOverride . data ;
128
- } ) ;
129
-
130
- return cacheResult . val ?? defaultLimiter ;
131
- }
132
-
133
- //returns an Express middleware that rate limits using the Bearer token in the Authorization header
134
- export function authorizationRateLimitMiddleware ( {
135
- redis,
136
- keyPrefix,
137
- defaultLimiter,
138
- pathMatchers,
139
- pathWhiteList = [ ] ,
140
- log = {
141
- rejections : true ,
142
- requests : true ,
143
- } ,
144
- limiterCache,
145
- limiterConfigOverride,
146
- } : Options ) {
147
- const ctx = new DefaultStatefulContext ( ) ;
148
- const memory = new MemoryStore ( { persistentMap : new Map ( ) } ) ;
149
- const redisCacheStore = new RedisCacheStore ( {
150
- connection : {
151
- keyPrefix : `cache:${ keyPrefix } :rate-limit-cache:` ,
152
- ...redis ,
153
- } ,
154
- } ) ;
155
-
156
- // This cache holds the rate limit configuration for each org, so we don't have to fetch it every request
157
- const cache = createCache ( {
158
- limiter : new Namespace < RateLimiterConfig > ( ctx , {
159
- stores : [ memory , redisCacheStore ] ,
160
- fresh : limiterCache ?. fresh ?? 30_000 ,
161
- stale : limiterCache ?. stale ?? 60_000 ,
162
- } ) ,
163
- } ) ;
164
-
165
- const redisClient = createRedisRateLimitClient (
166
- redis ?? {
167
- port : env . REDIS_PORT ,
168
- host : env . REDIS_HOST ,
169
- username : env . REDIS_USERNAME ,
170
- password : env . REDIS_PASSWORD ,
171
- enableAutoPipelining : true ,
172
- ...( env . REDIS_TLS_DISABLED === "true" ? { } : { tls : { } } ) ,
173
- }
174
- ) ;
175
-
176
- return async ( req : ExpressRequest , res : ExpressResponse , next : NextFunction ) => {
177
- if ( log . requests ) {
178
- logger . info ( `RateLimiter (${ keyPrefix } ): request to ${ req . path } ` ) ;
179
- }
180
-
181
- // allow OPTIONS requests
182
- if ( req . method . toUpperCase ( ) === "OPTIONS" ) {
183
- return next ( ) ;
184
- }
185
-
186
- //first check if any of the pathMatchers match the request path
187
- const path = req . path ;
188
- if (
189
- ! pathMatchers . some ( ( matcher ) =>
190
- matcher instanceof RegExp ? matcher . test ( path ) : path === matcher
191
- )
192
- ) {
193
- if ( log . requests ) {
194
- logger . info ( `RateLimiter (${ keyPrefix } ): didn't match ${ req . path } ` ) ;
195
- }
196
- return next ( ) ;
197
- }
198
-
199
- // Check if the path matches any of the whitelisted paths
200
- if (
201
- pathWhiteList . some ( ( matcher ) =>
202
- matcher instanceof RegExp ? matcher . test ( path ) : path === matcher
203
- )
204
- ) {
205
- if ( log . requests ) {
206
- logger . info ( `RateLimiter (${ keyPrefix } ): whitelisted ${ req . path } ` ) ;
207
- }
208
- return next ( ) ;
209
- }
210
-
211
- if ( log . requests ) {
212
- logger . info ( `RateLimiter (${ keyPrefix } ): matched ${ req . path } ` ) ;
213
- }
214
-
215
- const authorizationValue = req . headers . authorization ;
216
- if ( ! authorizationValue ) {
217
- if ( log . requests ) {
218
- logger . info ( `RateLimiter (${ keyPrefix } ): no key` , { headers : req . headers , url : req . url } ) ;
219
- }
220
- res . setHeader ( "Content-Type" , "application/problem+json" ) ;
221
- return res . status ( 401 ) . send (
222
- JSON . stringify (
223
- {
224
- title : "Unauthorized" ,
225
- status : 401 ,
226
- type : "https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/401" ,
227
- detail : "No authorization header provided" ,
228
- error : "No authorization header provided" ,
229
- } ,
230
- null ,
231
- 2
232
- )
233
- ) ;
234
- }
235
-
236
- const hash = createHash ( "sha256" ) ;
237
- hash . update ( authorizationValue ) ;
238
- const hashedAuthorizationValue = hash . digest ( "hex" ) ;
239
-
240
- const limiterConfig = await resolveLimitConfig (
241
- authorizationValue ,
242
- hashedAuthorizationValue ,
243
- defaultLimiter ,
244
- cache ,
245
- typeof log . limiter === "boolean" ? log . limiter : false ,
246
- limiterConfigOverride
247
- ) ;
248
-
249
- const limiter =
250
- limiterConfig . type === "fixedWindow"
251
- ? Ratelimit . fixedWindow ( limiterConfig . tokens , limiterConfig . window )
252
- : limiterConfig . type === "tokenBucket"
253
- ? Ratelimit . tokenBucket (
254
- limiterConfig . refillRate ,
255
- limiterConfig . interval ,
256
- limiterConfig . maxTokens
257
- )
258
- : Ratelimit . slidingWindow ( limiterConfig . tokens , limiterConfig . window ) ;
259
-
260
- const rateLimiter = new RateLimiter ( {
261
- redisClient,
262
- keyPrefix,
263
- limiter,
264
- logSuccess : log . requests ,
265
- logFailure : log . rejections ,
266
- } ) ;
267
-
268
- const { success, limit, reset, remaining } = await rateLimiter . limit ( hashedAuthorizationValue ) ;
269
-
270
- const $remaining = Math . max ( 0 , remaining ) ; // remaining can be negative if the user has exceeded the limit, so clamp it to 0
271
-
272
- res . set ( "x-ratelimit-limit" , limit . toString ( ) ) ;
273
- res . set ( "x-ratelimit-remaining" , $remaining . toString ( ) ) ;
274
- res . set ( "x-ratelimit-reset" , reset . toString ( ) ) ;
275
-
276
- if ( success ) {
277
- return next ( ) ;
278
- }
279
-
280
- res . setHeader ( "Content-Type" , "application/problem+json" ) ;
281
- const secondsUntilReset = Math . max ( 0 , ( reset - new Date ( ) . getTime ( ) ) / 1000 ) ;
282
- return res . status ( 429 ) . send (
283
- JSON . stringify (
284
- {
285
- title : "Rate Limit Exceeded" ,
286
- status : 429 ,
287
- type : "https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/429" ,
288
- detail : `Rate limit exceeded ${ $remaining } /${ limit } requests remaining. Retry in ${ secondsUntilReset } seconds.` ,
289
- reset,
290
- limit,
291
- remaining,
292
- secondsUntilReset,
293
- error : `Rate limit exceeded ${ $remaining } /${ limit } requests remaining. Retry in ${ secondsUntilReset } seconds.` ,
294
- } ,
295
- null ,
296
- 2
297
- )
298
- ) ;
299
- } ;
300
- }
3
+ import { authorizationRateLimitMiddleware } from "./authorizationRateLimitMiddleware.server" ;
4
+ import { Duration } from "./rateLimiter.server" ;
301
5
302
6
export const apiRateLimiter = authorizationRateLimitMiddleware ( {
303
7
keyPrefix : "api" ,
@@ -312,16 +16,24 @@ export const apiRateLimiter = authorizationRateLimitMiddleware({
312
16
stale : 60_000 * 20 , // Date is stale after 20 minutes
313
17
} ,
314
18
limiterConfigOverride : async ( authorizationValue ) => {
315
- // TODO: we need to add an option to "allowJWT" auth and then handle this differently
316
19
const authenticatedEnv = await authenticateAuthorizationHeader ( authorizationValue , {
317
20
allowPublicKey : true ,
21
+ allowJWT : true ,
318
22
} ) ;
319
23
320
24
if ( ! authenticatedEnv ) {
321
25
return ;
322
26
}
323
27
324
- return authenticatedEnv . environment . organization . apiRateLimiterConfig ;
28
+ if ( authenticatedEnv . type === "PUBLIC_JWT" ) {
29
+ return {
30
+ type : "fixedWindow" ,
31
+ window : env . API_RATE_LIMIT_JWT_WINDOW ,
32
+ tokens : env . API_RATE_LIMIT_JWT_TOKENS ,
33
+ } ;
34
+ } else {
35
+ return authenticatedEnv . environment . organization . apiRateLimiterConfig ;
36
+ }
325
37
} ,
326
38
pathMatchers : [ / ^ \/ a p i / ] ,
327
39
// Allow /api/v1/tasks/:id/callback/:secret
0 commit comments