Skip to content

Rate limit API requests and changed SQS reading speed #969

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions apps/webapp/app/entry.server.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,4 @@ const sqsEventConsumer = singleton("sqsEventConsumer", getSharedSqsEventConsumer
export { wss } from "./v3/handleWebsockets.server";
export { socketIo } from "./v3/handleSocketIo.server";
export { registryProxy } from "./v3/registryProxy.server";
export { apiRateLimiter } from "./services/apiRateLimit.server";
15 changes: 14 additions & 1 deletion apps/webapp/app/env.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ const EnvironmentSchema = z.object({
AWS_SQS_SECRET_ACCESS_KEY: z.string().optional(),
/** Optional. Only used if you use the apps/proxy */
AWS_SQS_QUEUE_URL: z.string().optional(),
AWS_SQS_BATCH_SIZE: z.coerce.number().int().optional().default(10),
AWS_SQS_BATCH_SIZE: z.coerce.number().int().optional().default(1),
AWS_SQS_WAIT_TIME_MS: z.coerce.number().int().optional().default(100),
DISABLE_SSE: z.string().optional(),

// Redis options
Expand All @@ -68,6 +69,18 @@ const EnvironmentSchema = z.object({
TUNNEL_HOST: z.string().optional(),
TUNNEL_SECRET_KEY: z.string().optional(),

//API Rate limiting
/**
* @example "60s"
* @example "1m"
* @example "1h"
* @example "1d"
* @example "1000ms"
* @example "1000s"
*/
API_RATE_LIMIT_WINDOW: z.string().default("60s"),
API_RATE_LIMIT_MAX: z.coerce.number().int().default(600),

//v3
V3_ENABLED: z.string().default("false"),
OTLP_EXPORTER_TRACES_URL: z.string().optional(),
Expand Down
179 changes: 179 additions & 0 deletions apps/webapp/app/services/apiRateLimit.server.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import { Ratelimit } from "@upstash/ratelimit";
import { Request as ExpressRequest, Response as ExpressResponse, NextFunction } from "express";
import Redis, { RedisOptions } from "ioredis";
import { createHash } from "node:crypto";
import { env } from "~/env.server";
import { logger } from "./logger.server";

function createRedisRateLimitClient(
redisOptions: RedisOptions
): ConstructorParameters<typeof Ratelimit>[0]["redis"] {
const redis = new Redis(redisOptions);

return {
sadd: async <TData>(key: string, ...members: TData[]): Promise<number> => {
return redis.sadd(key, members as (string | number | Buffer)[]);
},
eval: <TArgs extends unknown[], TData = unknown>(
...args: [script: string, keys: string[], args: TArgs]
): Promise<TData> => {
const script = args[0];
const keys = args[1];
const argsArray = args[2];
return redis.eval(
script,
keys.length,
...keys,
...(argsArray as (string | Buffer | number)[])
) as Promise<TData>;
},
};
}

type Options = {
log?: {
requests?: boolean;
rejections?: boolean;
};
redis: RedisOptions;
keyPrefix: string;
pathMatchers: (RegExp | string)[];
limiter: ConstructorParameters<typeof Ratelimit>[0]["limiter"];
};

//returns an Express middleware that rate limits using the Bearer token in the Authorization header
export function authorizationRateLimitMiddleware({
redis,
keyPrefix,
limiter,
pathMatchers,
log = {
rejections: true,
requests: true,
},
}: Options) {
const rateLimiter = new Ratelimit({
redis: createRedisRateLimitClient(redis),
limiter: limiter,
ephemeralCache: new Map(),
analytics: false,
prefix: keyPrefix,
});

return async (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => {
if (log.requests) {
logger.info(`RateLimiter (${keyPrefix}): request to ${req.path}`);
}

//first check if any of the pathMatchers match the request path
const path = req.path;
if (
!pathMatchers.some((matcher) =>
matcher instanceof RegExp ? matcher.test(path) : path === matcher
)
) {
if (log.requests) {
logger.info(`RateLimiter (${keyPrefix}): didn't match ${req.path}`);
}
return next();
}

if (log.requests) {
logger.info(`RateLimiter (${keyPrefix}): matched ${req.path}`);
}

const authorizationValue = req.headers.authorization;
if (!authorizationValue) {
if (log.requests) {
logger.info(`RateLimiter (${keyPrefix}): no key`);
}
res.setHeader("Content-Type", "application/problem+json");
return res
.status(401)
.send(
JSON.stringify(
{
title: "Unauthorized",
status: 401,
type: "https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/401",
detail: "No authorization header provided",
},
null,
2
)
);
}

const hash = createHash("sha256");
hash.update(authorizationValue);
const hashedAuthorizationValue = hash.digest("hex");

const { success, pending, limit, reset, remaining } = await rateLimiter.limit(
hashedAuthorizationValue
);

res.set("x-ratelimit-limit", limit.toString());
res.set("x-ratelimit-remaining", remaining.toString());
res.set("x-ratelimit-reset", reset.toString());

if (success) {
if (log.requests) {
logger.info(`RateLimiter (${keyPrefix}): under rate limit`, {
limit,
reset,
remaining,
hashedAuthorizationValue,
});
}
return next();
}

if (log.rejections) {
logger.warn(`RateLimiter (${keyPrefix}): rate limit exceeded`, {
limit,
reset,
remaining,
pending,
hashedAuthorizationValue,
});
}

res.setHeader("Content-Type", "application/problem+json");
return res.status(429).send(
JSON.stringify(
{
title: "Rate Limit Exceeded",
status: 429,
type: "https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/429",
detail: `Rate limit exceeded ${remaining}/${limit} requests remaining. Retry after ${reset} seconds.`,
reset: reset,
limit: limit,
},
null,
2
)
);
};
}

type Duration = Parameters<typeof Ratelimit.slidingWindow>[1];

export const apiRateLimiter = authorizationRateLimitMiddleware({
keyPrefix: "ratelimit:api",
redis: {
port: env.REDIS_PORT,
host: env.REDIS_HOST,
username: env.REDIS_USERNAME,
password: env.REDIS_PASSWORD,
enableAutoPipelining: true,
...(env.REDIS_TLS_DISABLED === "true" ? {} : { tls: {} }),
},
limiter: Ratelimit.slidingWindow(env.API_RATE_LIMIT_MAX, env.API_RATE_LIMIT_WINDOW as Duration),
pathMatchers: [/^\/api/],
log: {
rejections: true,
requests: false,
},
});

export type RateLimitMiddleware = ReturnType<typeof authorizationRateLimitMiddleware>;
13 changes: 7 additions & 6 deletions apps/webapp/app/services/events/sqsEventConsumer.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import { Consumer } from "sqs-consumer";
import { PrismaClientOrTransaction, prisma } from "~/db.server";
import { logger, trace } from "../logger.server";
import { Message, SQSClient } from "@aws-sdk/client-sqs";
import { authenticateApiKey } from "../apiAuth.server";
import { SendEventBodySchema } from "@trigger.dev/core";
import { Consumer } from "sqs-consumer";
import { z } from "zod";
import { fromZodError } from "zod-validation-error";
import { IngestSendEvent } from "./ingestSendEvent.server";
import { PrismaClientOrTransaction, prisma } from "~/db.server";
import { env } from "~/env.server";
import { singleton } from "~/utils/singleton";
import { authenticateApiKey } from "../apiAuth.server";
import { logger, trace } from "../logger.server";
import { IngestSendEvent } from "./ingestSendEvent.server";

type SqsEventConsumerOptions = {
queueUrl: string;
Expand All @@ -17,6 +16,7 @@ type SqsEventConsumerOptions = {
region: string;
accessKeyId: string;
secretAccessKey: string;
pollingWaitTimeMs: number;
};

const messageSchema = SendEventBodySchema.extend({
Expand Down Expand Up @@ -137,6 +137,7 @@ export function getSharedSqsEventConsumer() {
const consumer = new SqsEventConsumer(undefined, {
queueUrl: env.AWS_SQS_QUEUE_URL,
batchSize: env.AWS_SQS_BATCH_SIZE,
pollingWaitTimeMs: env.AWS_SQS_WAIT_TIME_MS,
region: env.AWS_SQS_REGION,
accessKeyId: env.AWS_SQS_ACCESS_KEY_ID,
secretAccessKey: env.AWS_SQS_SECRET_ACCESS_KEY,
Expand Down
1 change: 1 addition & 0 deletions apps/webapp/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
"@trigger.dev/yalt": "workspace:*",
"@types/pg": "8.6.6",
"@uiw/react-codemirror": "^4.19.5",
"@upstash/ratelimit": "^1.0.1",
"@whatwg-node/fetch": "^0.9.14",
"class-variance-authority": "^0.5.2",
"clsx": "^1.2.1",
Expand Down
6 changes: 4 additions & 2 deletions apps/webapp/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { broadcastDevReady, logDevReady } from "@remix-run/server-runtime";
import type { Server as IoServer } from "socket.io";
import type { Server as EngineServer } from "engine.io";
import { RegistryProxy } from "~/v3/registryProxy.server";
import { RateLimitMiddleware, apiRateLimiter } from "~/services/apiRateLimit.server";

const app = express();

Expand Down Expand Up @@ -39,6 +40,7 @@ if (process.env.HTTP_SERVER_DISABLED !== "true") {
const socketIo: { io: IoServer } | undefined = build.entry.module.socketIo;
const wss: WebSocketServer | undefined = build.entry.module.wss;
const registryProxy: RegistryProxy | undefined = build.entry.module.registryProxy;
const apiRateLimiter: RateLimitMiddleware = build.entry.module.apiRateLimiter;

if (registryProxy && process.env.ENABLE_REGISTRY_PROXY === "true") {
console.log(`🐳 Enabling container registry proxy to ${registryProxy.origin}`);
Expand Down Expand Up @@ -69,6 +71,8 @@ if (process.env.HTTP_SERVER_DISABLED !== "true") {
});

if (process.env.DASHBOARD_AND_API_DISABLED !== "true") {
app.use(apiRateLimiter);

app.all(
"*",
// @ts-ignore
Expand All @@ -84,8 +88,6 @@ if (process.env.HTTP_SERVER_DISABLED !== "true") {
});
}



const server = app.listen(port, () => {
console.log(`✅ server ready: http://localhost:${port} [NODE_ENV: ${MODE}]`);

Expand Down
37 changes: 31 additions & 6 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.