Skip to content

Commit 9523366

Browse files
committed
Handle revalidating JWT tokens
1 parent 3bf07ad commit 9523366

File tree

5 files changed

+136
-38
lines changed

5 files changed

+136
-38
lines changed

apps/webapp/app/routes/api.v1.usage.ingest.ts

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { ActionFunctionArgs } from "@remix-run/server-runtime";
22
import { MachinePresetName } from "@trigger.dev/core/v3";
33
import { z } from "zod";
44
import { prisma } from "~/db.server";
5-
import { validateJWTToken } from "~/services/apiAuth.server";
5+
import { validateJWTTokenAndRenew } from "~/services/apiAuth.server";
66
import { logger } from "~/services/logger.server";
77
import { workerQueue } from "~/services/worker.server";
88
import { machinePresetFromName } from "~/v3/machinePresets.server";
@@ -26,16 +26,12 @@ export async function action({ request }: ActionFunctionArgs) {
2626
return { status: 405, body: "Method Not Allowed" };
2727
}
2828

29-
const jwt = request.headers.get("x-trigger-jwt");
29+
const jwtResult = await validateJWTTokenAndRenew(request, JWTPayloadSchema);
3030

31-
if (!jwt) {
31+
if (!jwtResult) {
3232
return { status: 401, body: "Unauthorized" };
3333
}
3434

35-
logger.debug("Validating JWT", { jwt });
36-
37-
const jwtPayload = await validateJWTToken(jwt, JWTPayloadSchema);
38-
3935
const rawJson = await request.json();
4036

4137
const json = BodySchema.safeParse(rawJson);
@@ -46,16 +42,16 @@ export async function action({ request }: ActionFunctionArgs) {
4642
return { status: 400, body: "Bad Request" };
4743
}
4844

49-
const preset = machinePresetFromName(jwtPayload.machine_preset as MachinePresetName);
45+
const preset = machinePresetFromName(jwtResult.payload.machine_preset as MachinePresetName);
5046

51-
logger.debug("Validated JWT", { jwtPayload, json: json.data, preset });
47+
logger.debug("[/api/v1/usage/ingest] Reporting usage", { jwtResult, json: json.data, preset });
5248

53-
if (json.data.durationMs > 10) {
49+
if (json.data.durationMs > 0) {
5450
const costInCents = json.data.durationMs * preset.centsPerMs;
5551

5652
await prisma.taskRun.update({
5753
where: {
58-
id: jwtPayload.run_id,
54+
id: jwtResult.payload.run_id,
5955
},
6056
data: {
6157
usageDurationMs: {
@@ -71,7 +67,7 @@ export async function action({ request }: ActionFunctionArgs) {
7167
await reportUsageEvent({
7268
source: "webapp",
7369
type: "usage",
74-
subject: jwtPayload.org_id,
70+
subject: jwtResult.payload.org_id,
7571
data: {
7672
durationMs: json.data.durationMs,
7773
costInCents: String(costInCents),
@@ -81,7 +77,7 @@ export async function action({ request }: ActionFunctionArgs) {
8177
logger.error("Failed to report usage event, enqueing v3.reportUsage", { error: e });
8278

8379
await workerQueue.enqueue("v3.reportUsage", {
84-
orgId: jwtPayload.org_id,
80+
orgId: jwtResult.payload.org_id,
8581
data: {
8682
costInCents: String(costInCents),
8783
},
@@ -94,5 +90,8 @@ export async function action({ request }: ActionFunctionArgs) {
9490

9591
return new Response(null, {
9692
status: 200,
93+
headers: {
94+
"x-trigger-jwt": jwtResult.jwt,
95+
},
9796
});
9897
}

apps/webapp/app/services/apiAuth.server.ts

Lines changed: 104 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ import {
1313
import { prisma } from "~/db.server";
1414
import { json } from "@remix-run/server-runtime";
1515
import { findProjectByRef } from "~/models/project.server";
16-
import { SignJWT, jwtVerify } from "jose";
16+
import { SignJWT, jwtVerify, errors } from "jose";
1717
import { env } from "~/env.server";
18+
import { logger } from "./logger.server";
1819

1920
type Optional<T, K extends keyof T> = Prettify<Omit<T, K> & Partial<Pick<T, K>>>;
2021

@@ -213,40 +214,126 @@ export async function authenticatedEnvironmentForAuthentication(
213214
}
214215
}
215216

217+
const JWT_SECRET = new TextEncoder().encode(env.SESSION_SECRET);
218+
const JWT_ALGORITHM = "HS256";
219+
const DEFAULT_JWT_EXPIRATION_IN_MS = 1000 * 60 * 60; // 1 hour
220+
216221
export async function generateJWTTokenForEnvironment(
217222
environment: RuntimeEnvironment,
218223
payload: Record<string, string>
219224
) {
220-
const secret = new TextEncoder().encode(env.SESSION_SECRET);
221-
222-
const alg = "HS256";
223-
224225
const jwt = await new SignJWT({
225226
environment_id: environment.id,
226227
org_id: environment.organizationId,
227228
project_id: environment.projectId,
228229
...payload,
229230
})
230-
.setProtectedHeader({ alg })
231+
.setProtectedHeader({ alg: JWT_ALGORITHM })
231232
.setIssuedAt()
232233
.setIssuer("https://id.trigger.dev")
233234
.setAudience("https://api.trigger.dev")
234-
.setExpirationTime("24h")
235-
.sign(secret);
235+
.setExpirationTime(calculateJWTExpiration())
236+
.sign(JWT_SECRET);
236237

237238
return jwt;
238239
}
239240

240-
export async function validateJWTToken<T extends z.ZodTypeAny>(
241-
jwt: string,
241+
export async function validateJWTTokenAndRenew<T extends z.ZodTypeAny>(
242+
request: Request,
242243
payloadSchema: T
243-
): Promise<z.infer<T>> {
244-
const secret = new TextEncoder().encode(env.SESSION_SECRET);
244+
): Promise<{ payload: z.infer<T>; jwt: string } | undefined> {
245+
try {
246+
const jwt = request.headers.get("x-trigger-jwt");
247+
248+
if (!jwt) {
249+
logger.debug("Missing JWT token in request", {
250+
headers: Object.fromEntries(request.headers),
251+
});
252+
253+
return;
254+
}
255+
256+
const { payload: rawPayload } = await jwtVerify(jwt, JWT_SECRET, {
257+
issuer: "https://id.trigger.dev",
258+
audience: "https://api.trigger.dev",
259+
});
260+
261+
const payload = payloadSchema.safeParse(rawPayload);
262+
263+
if (!payload.success) {
264+
logger.error("Failed to validate JWT", { payload: rawPayload, issues: payload.error.issues });
265+
266+
return;
267+
}
268+
269+
const renewedJwt = await renewJWTToken(payload.data);
270+
271+
return {
272+
payload: payload.data,
273+
jwt: renewedJwt,
274+
};
275+
} catch (error) {
276+
if (error instanceof errors.JWTExpired) {
277+
// Now we need to try and renew the token using the API key auth
278+
const authenticatedEnv = await authenticateApiRequest(request);
279+
280+
if (!authenticatedEnv) {
281+
logger.error("Failed to renew JWT token, missing or invalid Authorization header", {
282+
error: error.message,
283+
});
284+
285+
return;
286+
}
287+
288+
const payload = payloadSchema.safeParse(error.payload);
245289

246-
const { payload, protectedHeader } = await jwtVerify(jwt, secret, {
247-
issuer: "https://id.trigger.dev",
248-
audience: "https://api.trigger.dev",
249-
});
290+
if (!payload.success) {
291+
logger.error("Failed to parse jwt payload after expired", {
292+
payload: error.payload,
293+
issues: payload.error.issues,
294+
});
295+
296+
return;
297+
}
298+
299+
const renewedJwt = await generateJWTTokenForEnvironment(authenticatedEnv.environment, {
300+
...payload.data,
301+
});
302+
303+
logger.debug("Renewed JWT token from Authorization header API Key", {
304+
environment: authenticatedEnv.environment,
305+
payload: payload.data,
306+
});
307+
308+
return {
309+
payload: payload.data,
310+
jwt: renewedJwt,
311+
};
312+
}
313+
314+
logger.error("Failed to validate JWT token", { error });
315+
}
316+
}
317+
318+
async function renewJWTToken(payload: Record<string, string>) {
319+
const jwt = await new SignJWT(payload)
320+
.setProtectedHeader({ alg: JWT_ALGORITHM })
321+
.setIssuedAt()
322+
.setIssuer("https://id.trigger.dev")
323+
.setAudience("https://api.trigger.dev")
324+
.setExpirationTime(calculateJWTExpiration())
325+
.sign(JWT_SECRET);
326+
327+
return jwt;
328+
}
329+
330+
function calculateJWTExpiration() {
331+
if (env.PROD_USAGE_HEARTBEAT_INTERVAL_MS) {
332+
return (
333+
(Date.now() + Math.max(DEFAULT_JWT_EXPIRATION_IN_MS, env.PROD_USAGE_HEARTBEAT_INTERVAL_MS)) /
334+
1000
335+
);
336+
}
250337

251-
return payloadSchema.parse(payload);
338+
return (Date.now() + DEFAULT_JWT_EXPIRATION_IN_MS) / 1000;
252339
}

apps/webapp/app/v3/openMeter.server.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ export async function reportUsageEvent(event: UsageEvent) {
2929

3030
const url = `${env.USAGE_OPEN_METER_BASE_URL}/api/v1/events`;
3131

32-
logger.debug("Reporting usage event", { url, body });
32+
logger.debug("Reporting usage event to OpenMeter", { url, body });
3333

3434
const response = await fetch(url, {
3535
method: "POST",
@@ -42,6 +42,6 @@ export async function reportUsageEvent(event: UsageEvent) {
4242
});
4343

4444
if (!response.ok) {
45-
throw new Error(`Failed to report usage event: ${response.status} ${response.statusText}`);
45+
logger.error(`Failed to report usage event: ${response.status} ${response.statusText}`);
4646
}
4747
}

packages/core/src/v3/usage/usageClient.ts

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import { apiClientManager } from "../apiClientManager-api";
2+
13
export type UsageClientOptions = {
24
token: string;
35
baseUrl: string;
@@ -10,20 +12,29 @@ export type UsageEvent = {
1012
export class UsageClient {
1113
constructor(
1214
private readonly url: string,
13-
private readonly jwt: string
15+
private jwt: string
1416
) {}
1517

1618
async sendUsageEvent(event: UsageEvent): Promise<void> {
1719
try {
18-
await fetch(this.url, {
20+
const response = await fetch(this.url, {
1921
method: "POST",
2022
body: JSON.stringify(event),
2123
headers: {
2224
"content-type": "application/json",
2325
"x-trigger-jwt": this.jwt,
2426
accept: "application/json",
27+
authorization: `Bearer ${apiClientManager.accessToken}`, // this is used to renew the JWT
2528
},
2629
});
30+
31+
if (response.ok) {
32+
const renewedJwt = response.headers.get("x-trigger-jwt");
33+
34+
if (renewedJwt) {
35+
this.jwt = renewedJwt;
36+
}
37+
}
2738
} catch (error) {
2839
console.error(`Failed to send usage event: ${error}`);
2940
}

references/v3-catalog/src/trigger/longRunning.ts

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@ import { logger, task, wait } from "@trigger.dev/sdk/v3";
33
export const longRunning = task({
44
id: "long-running",
55
run: async (payload: { message: string }, { ctx }) => {
6-
logger.info("Long running payloadddd", { payload });
6+
logger.info("Long running", { payload });
77

8-
// Wait for 3 minutes
9-
await new Promise((resolve) => setTimeout(resolve, 5000));
8+
await new Promise((resolve) => setTimeout(resolve, 20000));
9+
10+
await wait.for({ seconds: 10 });
1011

11-
await wait.for({ seconds: 5 });
12+
await new Promise((resolve) => setTimeout(resolve, 20000));
1213
},
1314
});
1415

0 commit comments

Comments
 (0)