Skip to content

Commit 113413b

Browse files
committed
manual checkpoints
1 parent 899ca3f commit 113413b

File tree

6 files changed

+229
-4
lines changed

6 files changed

+229
-4
lines changed

apps/coordinator/src/index.ts

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1492,6 +1492,95 @@ class TaskCoordinator {
14921492
});
14931493
}
14941494
}
1495+
case "/checkpoint/manual": {
  // Manually checkpoint a run over HTTP.
  //
  // Request body (JSON):
  //   { runId: string, restoreAtUnixTimeMs?: number, keepRunAlive?: boolean }
  //
  // Flow: validate the body, look up the run's socket, checkpoint and push,
  // send CHECKPOINT_CREATED (reason "MANUAL") to the platform, then ask the
  // run to exit unless the platform ack or the request says to keep it alive.
  try {
    const body = await getTextBody(req);
    const json = safeJsonParse(body);

    if (typeof json !== "object" || !json) {
      return reply.text("Invalid body", 400);
    }

    if (!("runId" in json) || typeof json.runId !== "string") {
      return reply.text("Missing or invalid: runId", 400);
    }

    const { runId } = json;

    // An empty string passes the typeof check above but cannot identify a run.
    if (!runId) {
      return reply.text("Missing runId", 400);
    }

    // Optional fields: a value of the wrong type is ignored and the default is used.
    let restoreAtUnixTimeMs: number | undefined;
    if ("restoreAtUnixTimeMs" in json && typeof json.restoreAtUnixTimeMs === "number") {
      restoreAtUnixTimeMs = json.restoreAtUnixTimeMs;
    }

    let keepRunAlive = false;
    if ("keepRunAlive" in json && typeof json.keepRunAlive === "boolean") {
      keepRunAlive = json.keepRunAlive;
    }

    const runSocket = await this.#getRunSocket(runId);
    if (!runSocket) {
      return reply.text("Run socket not found", 404);
    }

    const { data } = runSocket;

    console.log("Manual checkpoint", data);

    // Validate everything the platform message needs BEFORE checkpointing, so we
    // never create and push a checkpoint that cannot be reported afterwards.
    if (!data.attemptFriendlyId) {
      return reply.text("Socket data missing attemptFriendlyId", 500);
    }

    const checkpoint = await this.#checkpointer.checkpointAndPush({
      runId: data.runId,
      projectRef: data.projectRef,
      deploymentVersion: data.deploymentVersion,
      attemptNumber: data.attemptNumber ? parseInt(data.attemptNumber) : undefined,
    });

    if (!checkpoint) {
      return reply.text("Failed to checkpoint", 500);
    }

    const ack = await this.#platformSocket?.sendWithAck("CHECKPOINT_CREATED", {
      version: "v1",
      runId,
      attemptFriendlyId: data.attemptFriendlyId,
      docker: checkpoint.docker,
      location: checkpoint.location,
      reason: {
        type: "MANUAL",
        restoreAtUnixTimeMs,
      },
    });

    // Either side may ask for the run to stay up after the checkpoint.
    if (ack?.keepRunAlive || keepRunAlive) {
      return reply.json({
        message: `keeping run ${runId} alive after checkpoint`,
        checkpoint,
        requestJson: json,
      });
    }

    runSocket.emit("REQUEST_EXIT", {
      version: "v1",
    });

    return reply.json({
      message: `checkpoint created for run ${runId}`,
      checkpoint,
      requestJson: json,
    });
  } catch (error) {
    console.error("Manual checkpoint failed", error);

    // Error instances have no own enumerable properties, so serialising one
    // directly yields `{}`; pick the useful fields out explicitly instead.
    // NOTE(review): this still goes through reply.json, whose status code is not
    // visible here - confirm whether an error status (5xx) can be set for this path.
    return reply.json({
      message: `error`,
      error: error instanceof Error ? { name: error.name, message: error.message } : error,
    });
  }
}
14951584
default: {
14961585
return reply.empty(404);
14971586
}

apps/webapp/app/database-types.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,11 @@ export const RuntimeEnvironmentType = {
6969
DEVELOPMENT: "DEVELOPMENT",
7070
PREVIEW: "PREVIEW",
7171
} as const satisfies Record<RuntimeEnvironmentTypeType, RuntimeEnvironmentTypeType>;
72+
73+
/**
 * Type guard reporting whether `value` is one of the values of `TaskRunAttemptStatus`.
 *
 * NOTE(review): the predicate narrows to the object's key type while membership is
 * tested against its values; presumably keys and values coincide - confirm.
 */
export function isTaskRunAttemptStatus(value: string): value is keyof typeof TaskRunAttemptStatus {
  const knownStatuses: readonly string[] = Object.values(TaskRunAttemptStatus);
  return knownStatuses.includes(value);
}
76+
77+
/**
 * Type guard reporting whether `value` is one of the values of `TaskRunStatus`.
 *
 * NOTE(review): the predicate narrows to the object's key type while membership is
 * tested against its values; presumably keys and values coincide - confirm.
 */
export function isTaskRunStatus(value: string): value is keyof typeof TaskRunStatus {
  const knownStatuses: readonly string[] = Object.values(TaskRunStatus);
  return knownStatuses.includes(value);
}

apps/webapp/app/v3/services/createCheckpoint.server.ts

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { CoordinatorToPlatformMessages } from "@trigger.dev/core/v3";
1+
import { CoordinatorToPlatformMessages, ManualCheckpointMetadata } from "@trigger.dev/core/v3";
22
import type { InferSocketMessageSchema } from "@trigger.dev/core/v3/zodSocket";
33
import type { Checkpoint, CheckpointRestoreEvent } from "@trigger.dev/database";
44
import { logger } from "~/services/logger.server";
@@ -101,6 +101,19 @@ export class CreateCheckpointService extends BaseService {
101101
// setTimeout(resolve, waitSeconds * 1000);
102102
// });
103103

104+
let metadata: string;
105+
106+
if (params.reason.type === "MANUAL") {
107+
metadata = JSON.stringify({
108+
...params.reason,
109+
attemptId: attempt.id,
110+
previousAttemptStatus: attempt.status,
111+
previousRunStatus: attempt.taskRun.status,
112+
} satisfies ManualCheckpointMetadata);
113+
} else {
114+
metadata = JSON.stringify(params.reason);
115+
}
116+
104117
const checkpoint = await this._prisma.checkpoint.create({
105118
data: {
106119
friendlyId: generateFriendlyId("checkpoint"),
@@ -112,7 +125,7 @@ export class CreateCheckpointService extends BaseService {
112125
location: params.location,
113126
type: params.docker ? "DOCKER" : "KUBERNETES",
114127
reason: params.reason.type,
115-
metadata: JSON.stringify(params.reason),
128+
metadata,
116129
imageRef,
117130
},
118131
});
@@ -138,7 +151,17 @@ export class CreateCheckpointService extends BaseService {
138151
let checkpointEvent: CheckpointRestoreEvent | undefined;
139152

140153
switch (reason.type) {
154+
case "MANUAL":
141155
case "WAIT_FOR_DURATION": {
156+
let restoreAtUnixTimeMs: number;
157+
158+
if (reason.type === "MANUAL") {
159+
// Restore immediately if not specified, useful for live migration
160+
restoreAtUnixTimeMs = reason.restoreAtUnixTimeMs ?? Date.now();
161+
} else {
162+
restoreAtUnixTimeMs = reason.now + reason.ms;
163+
}
164+
142165
checkpointEvent = await eventService.checkpoint({
143166
checkpointId: checkpoint.id,
144167
});
@@ -151,7 +174,7 @@ export class CreateCheckpointService extends BaseService {
151174
resumableAttemptId: attempt.id,
152175
checkpointEventId: checkpointEvent.id,
153176
},
154-
reason.now + reason.ms
177+
restoreAtUnixTimeMs
155178
);
156179

157180
return {

apps/webapp/app/v3/services/createCheckpointRestoreEvent.server.ts

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1-
import type { CheckpointRestoreEvent, CheckpointRestoreEventType } from "@trigger.dev/database";
1+
import type {
2+
Checkpoint,
3+
CheckpointRestoreEvent,
4+
CheckpointRestoreEventType,
5+
} from "@trigger.dev/database";
26
import { logger } from "~/services/logger.server";
37
import { BaseService } from "./baseService.server";
8+
import { ManualCheckpointMetadata } from "@trigger.dev/core/v3";
9+
import { isTaskRunAttemptStatus, isTaskRunStatus, TaskRunAttemptStatus } from "~/database-types";
10+
import { safeJsonParse } from "~/utils/json";
411

512
interface CheckpointRestoreEventCallParams {
613
checkpointId: string;
@@ -39,6 +46,13 @@ export class CreateCheckpointRestoreEventService extends BaseService {
3946
return;
4047
}
4148

49+
if (params.type === "RESTORE" && checkpoint.reason === "MANUAL") {
50+
const manualRestoreSuccess = await this.#handleManualCheckpointRestore(checkpoint);
51+
if (!manualRestoreSuccess) {
52+
return;
53+
}
54+
}
55+
4256
logger.debug(`Creating checkpoint/restore event`, { params });
4357

4458
let taskRunDependencyId: string | undefined;
@@ -99,4 +113,81 @@ export class CreateCheckpointRestoreEventService extends BaseService {
99113

100114
return checkpointEvent;
101115
}
116+
117+
/**
 * Writes the attempt and run statuses recorded in a manual checkpoint's metadata
 * back onto the attempt and its task run.
 *
 * @param checkpoint Checkpoint whose `metadata` string is parsed and validated
 *   against `ManualCheckpointMetadata`.
 * @returns `true` once the update has been applied; `false` when the metadata or
 *   either stored status fails validation, or when the update throws (the failure is logged).
 */
async #handleManualCheckpointRestore(checkpoint: Checkpoint): Promise<boolean> {
  const rawMetadata = checkpoint.metadata ? safeJsonParse(checkpoint.metadata) : undefined;

  // The statuses to restore are the ones saved in the checkpoint metadata
  const parsed = ManualCheckpointMetadata.safeParse(rawMetadata);

  if (!parsed.success) {
    logger.error("Invalid metadata", { metadata: parsed });
    return false;
  }

  const { attemptId, previousAttemptStatus, previousRunStatus } = parsed.data;

  // The schema only guarantees strings, so narrow both to known status values.
  if (!isTaskRunAttemptStatus(previousAttemptStatus)) {
    logger.error("Invalid previous attempt status", { previousAttemptStatus });
    return false;
  }

  if (!isTaskRunStatus(previousRunStatus)) {
    logger.error("Invalid previous run status", { previousRunStatus });
    return false;
  }

  try {
    // Single update: attempt status plus a nested update of the parent run's status.
    const { id, status, taskRun } = await this._prisma.taskRunAttempt.update({
      where: { id: attemptId },
      data: {
        status: previousAttemptStatus,
        taskRun: {
          update: {
            data: { status: previousRunStatus },
          },
        },
      },
      select: {
        id: true,
        status: true,
        taskRun: {
          select: { id: true, status: true },
        },
      },
    });

    logger.debug("Set post resume statuses after manual checkpoint", {
      run: { id: taskRun.id, status: taskRun.status },
      attempt: { id, status },
    });

    return true;
  } catch (error) {
    // Flatten Error instances so name/message/stack are logged explicitly.
    const errorDetails =
      error instanceof Error
        ? { name: error.name, message: error.message, stack: error.stack }
        : error;

    logger.error("Failed to set post resume statuses", { error: errorDetails });
    return false;
  }
}
102193
}

packages/core/src/v3/schemas/messages.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,11 @@ export const CoordinatorToPlatformMessages = {
479479
type: z.literal("RETRYING_AFTER_FAILURE"),
480480
attemptNumber: z.number(),
481481
}),
482+
z.object({
483+
type: z.literal("MANUAL"),
484+
/** If unspecified it will be restored immediately, e.g. for live migration */
485+
restoreAtUnixTimeMs: z.number().optional(),
486+
}),
482487
]),
483488
}),
484489
callback: z.object({

packages/core/src/v3/schemas/schemas.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,3 +250,12 @@ export const TaskRunExecutionLazyAttemptPayload = z.object({
250250
});
251251

252252
export type TaskRunExecutionLazyAttemptPayload = z.infer<typeof TaskRunExecutionLazyAttemptPayload>;
253+
254+
/** Field validators for {@link ManualCheckpointMetadata}; statuses are validated as plain strings only. */
const manualCheckpointMetadataShape = {
  /** NOT a friendly ID */
  attemptId: z.string(),
  /** Status of the run, as a string, to be restored later */
  previousRunStatus: z.string(),
  /** Status of the attempt, as a string, to be restored later */
  previousAttemptStatus: z.string(),
};

/**
 * Metadata describing a manual checkpoint: the attempt it belongs to and the
 * run/attempt statuses to restore.
 */
export const ManualCheckpointMetadata = z.object(manualCheckpointMetadataShape);

/** Parsed form of {@link ManualCheckpointMetadata}. */
export type ManualCheckpointMetadata = z.infer<typeof ManualCheckpointMetadata>;

0 commit comments

Comments
 (0)