Skip to content

Commit 2965234

Browse files
committed
Cancel runs and attempts, in prod and dev
1 parent 766b1ee commit 2965234

29 files changed

+784
-129
lines changed

apps/coordinator/package.json

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
"version": "0.0.1",
55
"description": "",
66
"main": "dist/index.cjs",
7-
"type": "module",
87
"scripts": {
98
"build": "npm run build:bundle",
109
"build:bundle": "esbuild src/index.ts --bundle --outfile=dist/index.mjs --platform=node --format=esm --target=esnext --banner:js=\"const require = createRequire(import.meta.url);\"",
@@ -31,4 +30,4 @@
3130
"tsx": "^4.7.0",
3231
"typescript": "^5.3.3"
3332
}
34-
}
33+
}

apps/coordinator/src/index.ts

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,16 @@ class TaskCoordinator {
301301

302302
taskSocket.emit("RESUME_AFTER_DURATION", message);
303303
},
304+
// Forwards a cancellation request to the socket looked up for the message's attempt id.
REQUEST_ATTEMPT_CANCELLATION: async (message) => {
  const { attemptId } = message;
  const attemptSocket = await this.#getAttemptSocket(attemptId);

  if (attemptSocket) {
    // Pass the message through unchanged.
    attemptSocket.emit("REQUEST_ATTEMPT_CANCELLATION", message);
    return;
  }

  // No socket was found for this attempt: log and drop the request.
  logger.log("Socket for attempt not found", { attemptId });
},
304314
},
305315
});
306316

@@ -385,11 +395,21 @@ class TaskCoordinator {
385395

386396
if (!executionAck) {
387397
logger.error("no execution ack", { attemptId: socket.data.attemptId });
398+
399+
socket.emit("REQUEST_EXIT", {
400+
version: "v1",
401+
});
402+
388403
return;
389404
}
390405

391406
if (!executionAck.success) {
392407
logger.error("execution unsuccessful", { attemptId: socket.data.attemptId });
408+
409+
socket.emit("REQUEST_EXIT", {
410+
version: "v1",
411+
});
412+
393413
return;
394414
}
395415

apps/coordinator/tsconfig.json

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
"strict": true,
1010
"skipLibCheck": true,
1111
"paths": {
12-
"@trigger.dev/core/v3": ["../core/src/v3"],
13-
"@trigger.dev/core/v3/*": ["../core/src/v3/*"],
14-
"@trigger.dev/core-apps": ["../core-apps/src"],
15-
"@trigger.dev/core-apps/*": ["../core-apps/src/*"]
12+
"@trigger.dev/core/v3": ["../../packages/core/src/v3"],
13+
"@trigger.dev/core/v3/*": ["../../packages/core/src/v3/*"],
14+
"@trigger.dev/core-apps": ["../../packages/core-apps/src"],
15+
"@trigger.dev/core-apps/*": ["../../packages/core-apps/src/*"]
1616
}
1717
}
1818
}

apps/docker-provider/package.json

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
"version": "0.0.1",
55
"description": "",
66
"main": "dist/index.cjs",
7-
"type": "module",
87
"scripts": {
98
"build": "npm run build:bundle",
109
"build:bundle": "esbuild src/index.ts --bundle --outfile=dist/index.mjs --platform=node --format=esm --target=esnext --banner:js=\"const require = createRequire(import.meta.url);\"",
@@ -29,4 +28,4 @@
2928
"tsx": "^4.7.0",
3029
"typescript": "^5.3.3"
3130
}
32-
}
31+
}

apps/docker-provider/tsconfig.json

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
"strict": true,
88
"skipLibCheck": true,
99
"paths": {
10-
"@trigger.dev/core/v3": ["../core/src/v3"],
11-
"@trigger.dev/core/v3/*": ["../core/src/v3/*"],
12-
"@trigger.dev/core-apps": ["../core-apps/src"],
13-
"@trigger.dev/core-apps/*": ["../core-apps/src/*"]
10+
"@trigger.dev/core/v3": ["../../packages/core/src/v3"],
11+
"@trigger.dev/core/v3/*": ["../../packages/core/src/v3/*"],
12+
"@trigger.dev/core-apps": ["../../packages/core-apps/src"],
13+
"@trigger.dev/core-apps/*": ["../../packages/core-apps/src/*"]
1414
}
1515
}
1616
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import type { ActionFunctionArgs } from "@remix-run/server-runtime";
2+
import { json } from "@remix-run/server-runtime";
3+
import { z } from "zod";
4+
import { prisma } from "~/db.server";
5+
import { authenticateApiRequest } from "~/services/apiAuth.server";
6+
import { CancelTaskRunService } from "~/v3/services/cancelTaskRun.server";
7+
8+
/** Route params accepted by the action below: `runParam` must be a string. */
const ParamsSchema = z.object({ runParam: z.string() });
11+
12+
/**
 * Cancels the task run identified by the `runParam` route parameter.
 *
 * Responses:
 * - 405 when the request method is not POST
 * - 401 when `authenticateApiRequest` yields no result
 * - 400 when the route params fail `ParamsSchema` validation
 * - 404 when no task run has the given friendly id
 * - 500 when `CancelTaskRunService` throws
 * - 200 once the service call has completed
 */
export async function action({ request, params }: ActionFunctionArgs) {
  // Only POST is accepted. Respond with a real Response (like every other branch)
  // so the 405 status actually reaches the client; a bare `{ status, body }` object
  // is not a Response and does not set the HTTP status.
  if (request.method.toUpperCase() !== "POST") {
    return json({ error: "Method Not Allowed" }, { status: 405, headers: { Allow: "POST" } });
  }

  // Authenticate the request
  const authenticationResult = await authenticateApiRequest(request);

  if (!authenticationResult) {
    return json({ error: "Invalid or Missing API Key" }, { status: 401 });
  }

  const parsed = ParamsSchema.safeParse(params);

  if (!parsed.success) {
    return json({ error: "Invalid or Missing runId" }, { status: 400 });
  }

  const { runParam } = parsed.data;

  // NOTE(review): the lookup is keyed on the friendly id alone and
  // `authenticationResult` is not consulted afterwards — confirm whether the run
  // should also be restricted to the authenticated caller's environment.
  const taskRun = await prisma.taskRun.findUnique({
    where: {
      friendlyId: runParam,
    },
  });

  if (!taskRun) {
    return json({ error: "Run not found" }, { status: 404 });
  }

  const service = new CancelTaskRunService();

  try {
    await service.call(taskRun);
  } catch (error) {
    // NOTE(review): the underlying error is discarded here; consider logging it
    // so failed cancellations can be diagnosed.
    return json({ error: "Internal Server Error" }, { status: 500 });
  }

  return json({ message: "Run cancelled" }, { status: 200 });
}

apps/webapp/app/v3/authenticatedSocketConnection.server.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import { AuthenticatedEnvironment } from "~/services/apiAuth.server";
1010
import { logger } from "~/services/logger.server";
1111
import { DevQueueConsumer } from "./marqs/devQueueConsumer.server";
1212
import type { WebSocket, MessageEvent, CloseEvent, ErrorEvent } from "ws";
13+
import { env } from "~/env.server";
1314

1415
export class AuthenticatedSocketConnection {
1516
public id: string;
@@ -26,6 +27,10 @@ export class AuthenticatedSocketConnection {
2627
schema: serverWebsocketMessages,
2728
sender: async (message) => {
2829
return new Promise((resolve, reject) => {
30+
// `ws.OPEN` is the constant naming the OPEN state and is always truthy, so
// `!ws.OPEN` can never be true. Compare the socket's current `readyState` against
// it to reject sends on a socket that is not open.
if (ws.readyState !== ws.OPEN) {
  return reject(new Error("Websocket is not open"));
}
33+
2934
ws.send(JSON.stringify(message), {}, (err) => {
3035
if (err) {
3136
reject(err);
@@ -84,6 +89,8 @@ export class AuthenticatedSocketConnection {
8489
}
8590

8691
async #handleClose(ev: CloseEvent) {
92+
logger.debug("[AuthenticatedSocketConnection] Websocket closed", { ev });
93+
8794
await this._consumer.stop();
8895

8996
this.onClose.post(ev);

apps/webapp/app/v3/eventRepository.server.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ export type TraceAttributes = Partial<
4444
CreatableEvent,
4545
| "attemptId"
4646
| "isError"
47+
| "isCancelled"
4748
| "runId"
4849
| "runIsTest"
4950
| "output"
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import { z } from "zod";
2+
import { singleton } from "~/utils/singleton";
3+
import { ZodPubSub, ZodSubscriber } from "../utils/zodPubSub.server";
4+
import { env } from "~/env.server";
5+
6+
// Shape of a CANCEL_ATTEMPT message; `version` falls back to "v1" when omitted.
const CancelAttemptMessage = z.object({
  version: z.literal("v1").default("v1"),
  backgroundWorkerId: z.string(),
  attemptId: z.string(),
  taskRunId: z.string(),
});

// Message types mapped to their zod schemas.
const messageCatalog = { CANCEL_ATTEMPT: CancelAttemptMessage };

/** Subscriber type parameterised by `messageCatalog`. */
export type DevSubscriber = ZodSubscriber<typeof messageCatalog>;

/** Pub/sub instance obtained through `singleton` under the "devPubSub" key. */
export const devPubSub = singleton("devPubSub", initializeDevPubSub);
18+
19+
/**
 * Creates the `ZodPubSub` used for dev messages.
 *
 * Redis connection settings are read from `env`; an empty `tls` option object is
 * included unless `REDIS_TLS_DISABLED` is the string "true".
 */
function initializeDevPubSub() {
  return new ZodPubSub({
    redis: {
      port: env.REDIS_PORT,
      host: env.REDIS_HOST,
      username: env.REDIS_USERNAME,
      password: env.REDIS_PASSWORD,
      enableAutoPipelining: true,
      ...(env.REDIS_TLS_DISABLED === "true" ? {} : { tls: {} }),
    },
    // Reuse the module-level catalog instead of a second copy of the same schema,
    // so the runtime schema cannot drift from the one `DevSubscriber` is typed from.
    schema: messageCatalog,
  });
}

apps/webapp/app/v3/marqs/devQueueConsumer.server.ts

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import { marqs } from "../marqs.server";
1717
import { CancelAttemptService } from "../services/cancelAttempt.server";
1818
import { CompleteAttemptService } from "../services/completeAttempt.server";
1919
import { attributesFromAuthenticatedEnv } from "../tracer.server";
20+
import { DevSubscriber, devPubSub } from "./devPubSub.server";
2021

2122
const tracer = trace.getTracer("devQueueConsumer");
2223

@@ -36,9 +37,11 @@ export type DevQueueConsumerOptions = {
3637

3738
export class DevQueueConsumer {
3839
private _backgroundWorkers: Map<string, BackgroundWorkerWithTasks> = new Map();
40+
private _backgroundWorkerSubscriber: Map<string, DevSubscriber> = new Map();
3941
private _deprecatedWorkers: Map<string, BackgroundWorkerWithTasks> = new Map();
4042
private _enabled = false;
41-
private _options: Required<DevQueueConsumerOptions>;
43+
private _maximumItemsPerTrace: number;
44+
private _traceTimeoutSeconds: number;
4245
private _perTraceCountdown: number | undefined;
4346
private _lastNewTrace: Date | undefined;
4447
private _currentSpanContext: Context | undefined;
@@ -51,12 +54,10 @@ export class DevQueueConsumer {
5154
constructor(
5255
public env: AuthenticatedEnvironment,
5356
private _sender: ZodMessageSender<typeof serverWebsocketMessages>,
54-
options: DevQueueConsumerOptions = {}
57+
private _options: DevQueueConsumerOptions = {}
5558
) {
56-
this._options = {
57-
maximumItemsPerTrace: options.maximumItemsPerTrace ?? 1_000, // 1k items per trace
58-
traceTimeoutSeconds: options.traceTimeoutSeconds ?? 60, // 60 seconds
59-
};
59+
this._traceTimeoutSeconds = _options.traceTimeoutSeconds ?? 60;
60+
this._maximumItemsPerTrace = _options.maximumItemsPerTrace ?? 1_000;
6061
}
6162

6263
// This method is called when a background worker is deprecated and will no longer be used unless a run is locked to it
@@ -87,6 +88,21 @@ export class DevQueueConsumer {
8788

8889
logger.debug("Registered background worker", { backgroundWorker: backgroundWorker.id });
8990

91+
const subscriber = await devPubSub.subscribe(`backgroundWorker:${backgroundWorker.id}:*`);
92+
93+
subscriber.on("CANCEL_ATTEMPT", async (message) => {
94+
await this._sender.send("BACKGROUND_WORKER_MESSAGE", {
95+
backgroundWorkerId: backgroundWorker.friendlyId,
96+
data: {
97+
type: "CANCEL_ATTEMPT",
98+
taskAttemptId: message.attemptId,
99+
taskRunId: message.taskRunId,
100+
},
101+
});
102+
});
103+
104+
this._backgroundWorkerSubscriber.set(backgroundWorker.id, subscriber);
105+
90106
// Start reading from the queue if we haven't already
91107
this.#enable();
92108
}
@@ -133,6 +149,16 @@ export class DevQueueConsumer {
133149

134150
// We need to cancel all the in progress task run attempts and ack the messages so they will stop processing
135151
await this.#cancelInProgressAttempts(reason);
152+
153+
// We need to unsubscribe from the background worker channels
154+
for (const [id, subscriber] of this._backgroundWorkerSubscriber) {
155+
logger.debug("Unsubscribing from background worker channel", { id });
156+
157+
await subscriber.stopListening();
158+
this._backgroundWorkerSubscriber.delete(id);
159+
160+
logger.debug("Unsubscribed from background worker channel", { id });
161+
}
136162
}
137163

138164
async #cancelInProgressAttempts(reason: string) {
@@ -144,6 +170,10 @@ export class DevQueueConsumer {
144170

145171
this._inProgressAttempts.clear();
146172

173+
logger.debug("Cancelling in progress attempts", {
174+
attempts: Array.from(inProgressAttempts.keys()),
175+
});
176+
147177
for (const [attemptId, messageId] of inProgressAttempts) {
148178
await this.#cancelInProgressAttempt(attemptId, messageId, service, cancelledAt, reason);
149179
}
@@ -156,6 +186,8 @@ export class DevQueueConsumer {
156186
cancelledAt: Date,
157187
reason: string
158188
) {
189+
logger.debug("Cancelling in progress attempt", { attemptId, messageId });
190+
159191
try {
160192
await cancelAttemptService.call(attemptId, messageId, cancelledAt, reason, this.env);
161193
} catch (e) {
@@ -189,7 +221,7 @@ export class DevQueueConsumer {
189221
// Check if the trace has expired
190222
if (
191223
this._perTraceCountdown === 0 ||
192-
Date.now() - this._lastNewTrace!.getTime() > this._options.traceTimeoutSeconds * 1000 ||
224+
Date.now() - this._lastNewTrace!.getTime() > this._traceTimeoutSeconds * 1000 ||
193225
this._currentSpanContext === undefined ||
194226
this._endSpanInNextIteration
195227
) {
@@ -366,6 +398,7 @@ export class DevQueueConsumer {
366398
backgroundWorkerTaskId: backgroundTask.id,
367399
status: "EXECUTING" as const,
368400
queueId: queue.id,
401+
runtimeEnvironmentId: this.env.id,
369402
},
370403
});
371404

@@ -442,6 +475,11 @@ export class DevQueueConsumer {
442475
},
443476
});
444477

478+
logger.debug("Saving the in progress attempt", {
479+
taskRunAttempt: taskRunAttempt.id,
480+
messageId: message.messageId,
481+
});
482+
445483
this._inProgressAttempts.set(taskRunAttempt.friendlyId, message.messageId);
446484
} catch (e) {
447485
if (e instanceof Error) {

0 commit comments

Comments
 (0)