Skip to content

Commit 51eb8d5

Browse files
committed
Add tests for the rate limit middleware and add custom JWT rate limits
1 parent b9c8102 commit 51eb8d5

File tree

7 files changed

+839
-304
lines changed

7 files changed

+839
-304
lines changed

apps/webapp/app/env.server.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ const EnvironmentSchema = z.object({
3131
REMIX_APP_PORT: z.string().optional(),
3232
LOGIN_ORIGIN: z.string().default("http://localhost:3030"),
3333
APP_ORIGIN: z.string().default("http://localhost:3030"),
34-
ELECTRIC_ORIGIN: z.string(),
34+
ELECTRIC_ORIGIN: z.string().default("http://localhost:3060"),
3535
APP_ENV: z.string().default(process.env.NODE_ENV),
3636
SERVICE_NAME: z.string().default("trigger.dev webapp"),
3737
SECRET_STORE: SecretStoreOptionsSchema.default("DATABASE"),
@@ -105,6 +105,9 @@ const EnvironmentSchema = z.object({
105105
API_RATE_LIMIT_REJECTION_LOGS_ENABLED: z.string().default("1"),
106106
API_RATE_LIMIT_LIMITER_LOGS_ENABLED: z.string().default("0"),
107107

108+
API_RATE_LIMIT_JWT_WINDOW: z.string().default("1m"),
109+
API_RATE_LIMIT_JWT_TOKENS: z.coerce.number().int().default(60),
110+
108111
//Realtime rate limiting
109112
/**
110113
* @example "60s"

apps/webapp/app/services/apiAuth.server.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,18 @@ export async function authenticateApiKey(
103103

104104
export async function authenticateAuthorizationHeader(
105105
authorization: string,
106-
{ allowPublicKey = false }: { allowPublicKey?: boolean } = {}
106+
{
107+
allowPublicKey = false,
108+
allowJWT = false,
109+
}: { allowPublicKey?: boolean; allowJWT?: boolean } = {}
107110
): Promise<ApiAuthenticationResult | undefined> {
108111
const apiKey = getApiKeyFromHeader(authorization);
109112

110113
if (!apiKey) {
111114
return;
112115
}
113116

114-
return authenticateApiKey(apiKey, { allowPublicKey });
117+
return authenticateApiKey(apiKey, { allowPublicKey, allowJWT });
115118
}
116119

117120
export function isPublicApiKey(key: string) {

apps/webapp/app/services/apiRateLimit.server.ts

Lines changed: 12 additions & 300 deletions
Original file line numberDiff line numberDiff line change
@@ -1,303 +1,7 @@
1-
import { createCache, DefaultStatefulContext, Namespace, Cache as UnkeyCache } from "@unkey/cache";
2-
import { MemoryStore } from "@unkey/cache/stores";
3-
import { Ratelimit } from "@upstash/ratelimit";
4-
import { Request as ExpressRequest, Response as ExpressResponse, NextFunction } from "express";
5-
import { RedisOptions } from "ioredis";
6-
import { createHash } from "node:crypto";
7-
import { z } from "zod";
81
import { env } from "~/env.server";
92
import { authenticateAuthorizationHeader } from "./apiAuth.server";
10-
import { logger } from "./logger.server";
11-
import { createRedisRateLimitClient, Duration, RateLimiter } from "./rateLimiter.server";
12-
import { RedisCacheStore } from "./unkey/redisCacheStore.server";
13-
14-
const DurationSchema = z.custom<Duration>((value) => {
15-
if (typeof value !== "string") {
16-
throw new Error("Duration must be a string");
17-
}
18-
19-
return value as Duration;
20-
});
21-
22-
export const RateLimitFixedWindowConfig = z.object({
23-
type: z.literal("fixedWindow"),
24-
window: DurationSchema,
25-
tokens: z.number(),
26-
});
27-
28-
export type RateLimitFixedWindowConfig = z.infer<typeof RateLimitFixedWindowConfig>;
29-
30-
export const RateLimitSlidingWindowConfig = z.object({
31-
type: z.literal("slidingWindow"),
32-
window: DurationSchema,
33-
tokens: z.number(),
34-
});
35-
36-
export type RateLimitSlidingWindowConfig = z.infer<typeof RateLimitSlidingWindowConfig>;
37-
38-
export const RateLimitTokenBucketConfig = z.object({
39-
type: z.literal("tokenBucket"),
40-
refillRate: z.number(),
41-
interval: DurationSchema,
42-
maxTokens: z.number(),
43-
});
44-
45-
export type RateLimitTokenBucketConfig = z.infer<typeof RateLimitTokenBucketConfig>;
46-
47-
export const RateLimiterConfig = z.discriminatedUnion("type", [
48-
RateLimitFixedWindowConfig,
49-
RateLimitSlidingWindowConfig,
50-
RateLimitTokenBucketConfig,
51-
]);
52-
53-
export type RateLimiterConfig = z.infer<typeof RateLimiterConfig>;
54-
55-
type LimitConfigOverrideFunction = (authorizationValue: string) => Promise<unknown>;
56-
57-
type Options = {
58-
redis?: RedisOptions;
59-
keyPrefix: string;
60-
pathMatchers: (RegExp | string)[];
61-
pathWhiteList?: (RegExp | string)[];
62-
defaultLimiter: RateLimiterConfig;
63-
limiterConfigOverride?: LimitConfigOverrideFunction;
64-
limiterCache?: {
65-
fresh: number;
66-
stale: number;
67-
};
68-
log?: {
69-
requests?: boolean;
70-
rejections?: boolean;
71-
limiter?: boolean;
72-
};
73-
};
74-
75-
async function resolveLimitConfig(
76-
authorizationValue: string,
77-
hashedAuthorizationValue: string,
78-
defaultLimiter: RateLimiterConfig,
79-
cache: UnkeyCache<{ limiter: RateLimiterConfig }>,
80-
logsEnabled: boolean,
81-
limiterConfigOverride?: LimitConfigOverrideFunction
82-
): Promise<RateLimiterConfig> {
83-
if (!limiterConfigOverride) {
84-
return defaultLimiter;
85-
}
86-
87-
if (logsEnabled) {
88-
logger.info("RateLimiter: checking for override", {
89-
authorizationValue: hashedAuthorizationValue,
90-
defaultLimiter,
91-
});
92-
}
93-
94-
const cacheResult = await cache.limiter.swr(hashedAuthorizationValue, async (key) => {
95-
const override = await limiterConfigOverride(authorizationValue);
96-
97-
if (!override) {
98-
if (logsEnabled) {
99-
logger.info("RateLimiter: no override found", {
100-
authorizationValue,
101-
defaultLimiter,
102-
});
103-
}
104-
105-
return defaultLimiter;
106-
}
107-
108-
const parsedOverride = RateLimiterConfig.safeParse(override);
109-
110-
if (!parsedOverride.success) {
111-
logger.error("Error parsing rate limiter override", {
112-
override,
113-
errors: parsedOverride.error.errors,
114-
});
115-
116-
return defaultLimiter;
117-
}
118-
119-
if (logsEnabled && parsedOverride.data) {
120-
logger.info("RateLimiter: override found", {
121-
authorizationValue,
122-
defaultLimiter,
123-
override: parsedOverride.data,
124-
});
125-
}
126-
127-
return parsedOverride.data;
128-
});
129-
130-
return cacheResult.val ?? defaultLimiter;
131-
}
132-
133-
//returns an Express middleware that rate limits using the Bearer token in the Authorization header
134-
export function authorizationRateLimitMiddleware({
135-
redis,
136-
keyPrefix,
137-
defaultLimiter,
138-
pathMatchers,
139-
pathWhiteList = [],
140-
log = {
141-
rejections: true,
142-
requests: true,
143-
},
144-
limiterCache,
145-
limiterConfigOverride,
146-
}: Options) {
147-
const ctx = new DefaultStatefulContext();
148-
const memory = new MemoryStore({ persistentMap: new Map() });
149-
const redisCacheStore = new RedisCacheStore({
150-
connection: {
151-
keyPrefix: `cache:${keyPrefix}:rate-limit-cache:`,
152-
...redis,
153-
},
154-
});
155-
156-
// This cache holds the rate limit configuration for each org, so we don't have to fetch it every request
157-
const cache = createCache({
158-
limiter: new Namespace<RateLimiterConfig>(ctx, {
159-
stores: [memory, redisCacheStore],
160-
fresh: limiterCache?.fresh ?? 30_000,
161-
stale: limiterCache?.stale ?? 60_000,
162-
}),
163-
});
164-
165-
const redisClient = createRedisRateLimitClient(
166-
redis ?? {
167-
port: env.REDIS_PORT,
168-
host: env.REDIS_HOST,
169-
username: env.REDIS_USERNAME,
170-
password: env.REDIS_PASSWORD,
171-
enableAutoPipelining: true,
172-
...(env.REDIS_TLS_DISABLED === "true" ? {} : { tls: {} }),
173-
}
174-
);
175-
176-
return async (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => {
177-
if (log.requests) {
178-
logger.info(`RateLimiter (${keyPrefix}): request to ${req.path}`);
179-
}
180-
181-
// allow OPTIONS requests
182-
if (req.method.toUpperCase() === "OPTIONS") {
183-
return next();
184-
}
185-
186-
//first check if any of the pathMatchers match the request path
187-
const path = req.path;
188-
if (
189-
!pathMatchers.some((matcher) =>
190-
matcher instanceof RegExp ? matcher.test(path) : path === matcher
191-
)
192-
) {
193-
if (log.requests) {
194-
logger.info(`RateLimiter (${keyPrefix}): didn't match ${req.path}`);
195-
}
196-
return next();
197-
}
198-
199-
// Check if the path matches any of the whitelisted paths
200-
if (
201-
pathWhiteList.some((matcher) =>
202-
matcher instanceof RegExp ? matcher.test(path) : path === matcher
203-
)
204-
) {
205-
if (log.requests) {
206-
logger.info(`RateLimiter (${keyPrefix}): whitelisted ${req.path}`);
207-
}
208-
return next();
209-
}
210-
211-
if (log.requests) {
212-
logger.info(`RateLimiter (${keyPrefix}): matched ${req.path}`);
213-
}
214-
215-
const authorizationValue = req.headers.authorization;
216-
if (!authorizationValue) {
217-
if (log.requests) {
218-
logger.info(`RateLimiter (${keyPrefix}): no key`, { headers: req.headers, url: req.url });
219-
}
220-
res.setHeader("Content-Type", "application/problem+json");
221-
return res.status(401).send(
222-
JSON.stringify(
223-
{
224-
title: "Unauthorized",
225-
status: 401,
226-
type: "https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/401",
227-
detail: "No authorization header provided",
228-
error: "No authorization header provided",
229-
},
230-
null,
231-
2
232-
)
233-
);
234-
}
235-
236-
const hash = createHash("sha256");
237-
hash.update(authorizationValue);
238-
const hashedAuthorizationValue = hash.digest("hex");
239-
240-
const limiterConfig = await resolveLimitConfig(
241-
authorizationValue,
242-
hashedAuthorizationValue,
243-
defaultLimiter,
244-
cache,
245-
typeof log.limiter === "boolean" ? log.limiter : false,
246-
limiterConfigOverride
247-
);
248-
249-
const limiter =
250-
limiterConfig.type === "fixedWindow"
251-
? Ratelimit.fixedWindow(limiterConfig.tokens, limiterConfig.window)
252-
: limiterConfig.type === "tokenBucket"
253-
? Ratelimit.tokenBucket(
254-
limiterConfig.refillRate,
255-
limiterConfig.interval,
256-
limiterConfig.maxTokens
257-
)
258-
: Ratelimit.slidingWindow(limiterConfig.tokens, limiterConfig.window);
259-
260-
const rateLimiter = new RateLimiter({
261-
redisClient,
262-
keyPrefix,
263-
limiter,
264-
logSuccess: log.requests,
265-
logFailure: log.rejections,
266-
});
267-
268-
const { success, limit, reset, remaining } = await rateLimiter.limit(hashedAuthorizationValue);
269-
270-
const $remaining = Math.max(0, remaining); // remaining can be negative if the user has exceeded the limit, so clamp it to 0
271-
272-
res.set("x-ratelimit-limit", limit.toString());
273-
res.set("x-ratelimit-remaining", $remaining.toString());
274-
res.set("x-ratelimit-reset", reset.toString());
275-
276-
if (success) {
277-
return next();
278-
}
279-
280-
res.setHeader("Content-Type", "application/problem+json");
281-
const secondsUntilReset = Math.max(0, (reset - new Date().getTime()) / 1000);
282-
return res.status(429).send(
283-
JSON.stringify(
284-
{
285-
title: "Rate Limit Exceeded",
286-
status: 429,
287-
type: "https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/429",
288-
detail: `Rate limit exceeded ${$remaining}/${limit} requests remaining. Retry in ${secondsUntilReset} seconds.`,
289-
reset,
290-
limit,
291-
remaining,
292-
secondsUntilReset,
293-
error: `Rate limit exceeded ${$remaining}/${limit} requests remaining. Retry in ${secondsUntilReset} seconds.`,
294-
},
295-
null,
296-
2
297-
)
298-
);
299-
};
300-
}
3+
import { authorizationRateLimitMiddleware } from "./authorizationRateLimitMiddleware.server";
4+
import { Duration } from "./rateLimiter.server";
3015

3026
export const apiRateLimiter = authorizationRateLimitMiddleware({
3037
keyPrefix: "api",
@@ -312,16 +16,24 @@ export const apiRateLimiter = authorizationRateLimitMiddleware({
31216
stale: 60_000 * 20, // Date is stale after 20 minutes
31317
},
31418
limiterConfigOverride: async (authorizationValue) => {
315-
// TODO: we need to add an option to "allowJWT" auth and then handle this differently
31619
const authenticatedEnv = await authenticateAuthorizationHeader(authorizationValue, {
31720
allowPublicKey: true,
21+
allowJWT: true,
31822
});
31923

32024
if (!authenticatedEnv) {
32125
return;
32226
}
32327

324-
return authenticatedEnv.environment.organization.apiRateLimiterConfig;
28+
if (authenticatedEnv.type === "PUBLIC_JWT") {
29+
return {
30+
type: "fixedWindow",
31+
window: env.API_RATE_LIMIT_JWT_WINDOW,
32+
tokens: env.API_RATE_LIMIT_JWT_TOKENS,
33+
};
34+
} else {
35+
return authenticatedEnv.environment.organization.apiRateLimiterConfig;
36+
}
32537
},
32638
pathMatchers: [/^\/api/],
32739
// Allow /api/v1/tasks/:id/callback/:secret

0 commit comments

Comments
 (0)