Skip to content

Commit 26472d3

Browse files
committed
rework suspend restore
1 parent e979dd9 commit 26472d3

File tree

16 files changed

+170
-124
lines changed

16 files changed

+170
-124
lines changed

apps/supervisor/src/env.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ const Env = z.object({
4545
// Used by the resource monitor
4646
OVERRIDE_CPU_TOTAL: z.coerce.number().optional(),
4747
OVERRIDE_MEMORY_TOTAL_GB: z.coerce.number().optional(),
48+
49+
// Kubernetes specific settings
50+
KUBERNETES_FORCE_ENABLED: BoolEnv.default(false),
51+
KUBERNETES_NAMESPACE: z.string().default("default"),
52+
EPHEMERAL_STORAGE_SIZE_LIMIT: z.string().default("10Gi"),
53+
EPHEMERAL_STORAGE_SIZE_REQUEST: z.string().default("2Gi"),
4854
});
4955

5056
export const env = Env.parse(stdEnv);

apps/supervisor/src/index.ts

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ import {
1313
} from "./resourceMonitor.js";
1414
import { KubernetesWorkloadManager } from "./workloadManager/kubernetes.js";
1515
import { DockerWorkloadManager } from "./workloadManager/docker.js";
16-
import { HttpServer, CheckpointClient } from "@trigger.dev/core/v3/serverOnly";
16+
import {
17+
HttpServer,
18+
CheckpointClient,
19+
isKubernetesEnvironment,
20+
} from "@trigger.dev/core/v3/serverOnly";
1721
import { createK8sApi, RUNTIME_ENV } from "./clients/kubernetes.js";
1822

1923
class ManagedSupervisor {
@@ -25,7 +29,7 @@ class ManagedSupervisor {
2529
private readonly resourceMonitor: ResourceMonitor;
2630
private readonly checkpointClient?: CheckpointClient;
2731

28-
private readonly isKubernetes = RUNTIME_ENV === "kubernetes";
32+
private readonly isKubernetes = isKubernetesEnvironment(env.KUBERNETES_FORCE_ENABLED);
2933
private readonly warmStartUrl = env.TRIGGER_WARM_START_URL;
3034

3135
constructor() {
@@ -94,6 +98,7 @@ class ManagedSupervisor {
9498
this.checkpointClient = new CheckpointClient({
9599
apiUrl: new URL(env.TRIGGER_CHECKPOINT_URL),
96100
workerClient: this.workerSession.httpClient,
101+
orchestrator: this.isKubernetes ? "KUBERNETES" : "DOCKER",
97102
});
98103
}
99104

@@ -127,7 +132,9 @@ class ManagedSupervisor {
127132
return;
128133
}
129134

130-
if (message.checkpoint) {
135+
const { checkpoint, ...rest } = message;
136+
137+
if (checkpoint) {
131138
this.logger.log("[ManagedWorker] Restoring run", { runId: message.run.id });
132139

133140
if (!this.checkpointClient) {
@@ -139,7 +146,10 @@ class ManagedSupervisor {
139146
const didRestore = await this.checkpointClient.restoreRun({
140147
runFriendlyId: message.run.friendlyId,
141148
snapshotFriendlyId: message.snapshot.friendlyId,
142-
checkpoint: message.checkpoint,
149+
body: {
150+
...rest,
151+
checkpoint,
152+
},
143153
});
144154

145155
if (didRestore) {

apps/supervisor/src/workloadManager/kubernetes.ts

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,14 @@ import type { EnvironmentType, MachinePreset } from "@trigger.dev/core/v3";
99
import { env } from "../env.js";
1010
import { type K8sApi, createK8sApi, type k8s } from "../clients/kubernetes.js";
1111

12-
const POD_EPHEMERAL_STORAGE_SIZE_LIMIT = process.env.POD_EPHEMERAL_STORAGE_SIZE_LIMIT || "10Gi";
13-
const POD_EPHEMERAL_STORAGE_SIZE_REQUEST = process.env.POD_EPHEMERAL_STORAGE_SIZE_REQUEST || "2Gi";
14-
1512
type ResourceQuantities = {
1613
[K in "cpu" | "memory" | "ephemeral-storage"]?: string;
1714
};
1815

1916
export class KubernetesWorkloadManager implements WorkloadManager {
2017
private readonly logger = new SimpleStructuredLogger("kubernetes-workload-provider");
2118
private k8s: K8sApi;
22-
private namespace = "default";
19+
private namespace = env.KUBERNETES_NAMESPACE;
2320

2421
constructor(private opts: WorkloadManagerOptions) {
2522
this.k8s = createK8sApi();
@@ -205,13 +202,13 @@ export class KubernetesWorkloadManager implements WorkloadManager {
205202

206203
get #defaultResourceRequests(): ResourceQuantities {
207204
return {
208-
"ephemeral-storage": POD_EPHEMERAL_STORAGE_SIZE_REQUEST,
205+
"ephemeral-storage": env.EPHEMERAL_STORAGE_SIZE_REQUEST,
209206
};
210207
}
211208

212209
get #defaultResourceLimits(): ResourceQuantities {
213210
return {
214-
"ephemeral-storage": POD_EPHEMERAL_STORAGE_SIZE_LIMIT,
211+
"ephemeral-storage": env.EPHEMERAL_STORAGE_SIZE_LIMIT,
215212
};
216213
}
217214

apps/supervisor/src/workloadServer/index.ts

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ export class WorkloadServer extends EventEmitter<WorkloadServerEvents> {
9494
this.websocketServer = this.createWebsocketServer();
9595
}
9696

97-
private runnerIdFromRequest(req: IncomingMessage): string | undefined {
98-
const value = req.headers[WORKLOAD_HEADERS.RUNNER_ID];
97+
private headerValueFromRequest(req: IncomingMessage, headerName: string): string | undefined {
98+
const value = req.headers[headerName];
9999

100100
if (Array.isArray(value)) {
101101
return value[0];
@@ -104,6 +104,22 @@ export class WorkloadServer extends EventEmitter<WorkloadServerEvents> {
104104
return value;
105105
}
106106

107+
/**
 * Reads the runner id from the request's `WORKLOAD_HEADERS.RUNNER_ID` header.
 *
 * Delegates to `headerValueFromRequest`, so the result is `undefined` when the
 * header is absent.
 */
private runnerIdFromRequest(req: IncomingMessage): string | undefined {
108+
return this.headerValueFromRequest(req, WORKLOAD_HEADERS.RUNNER_ID);
109+
}
110+
111+
/**
 * Reads the deployment id from the request's `WORKLOAD_HEADERS.DEPLOYMENT_ID` header.
 *
 * Delegates to `headerValueFromRequest`, so the result is `undefined` when the
 * header is absent.
 */
private deploymentIdFromRequest(req: IncomingMessage): string | undefined {
112+
return this.headerValueFromRequest(req, WORKLOAD_HEADERS.DEPLOYMENT_ID);
113+
}
114+
115+
/**
 * Reads the deployment version from the request's
 * `WORKLOAD_HEADERS.DEPLOYMENT_VERSION` header.
 *
 * Delegates to `headerValueFromRequest`, so the result is `undefined` when the
 * header is absent. The suspend route rejects requests where this is missing.
 */
private deploymentVersionFromRequest(req: IncomingMessage): string | undefined {
116+
return this.headerValueFromRequest(req, WORKLOAD_HEADERS.DEPLOYMENT_VERSION);
117+
}
118+
119+
/**
 * Reads the project ref from the request's `WORKLOAD_HEADERS.PROJECT_REF` header.
 *
 * Delegates to `headerValueFromRequest`, so the result is `undefined` when the
 * header is absent. The suspend route rejects requests where this is missing.
 */
private projectRefFromRequest(req: IncomingMessage): string | undefined {
120+
return this.headerValueFromRequest(req, WORKLOAD_HEADERS.PROJECT_REF);
121+
}
122+
107123
private createHttpServer({ host, port }: { host: string; port: number }) {
108124
return new HttpServer({ port, host })
109125
.route(
@@ -213,8 +229,10 @@ export class WorkloadServer extends EventEmitter<WorkloadServerEvents> {
213229
}
214230

215231
const runnerId = this.runnerIdFromRequest(req);
232+
const deploymentVersion = this.deploymentVersionFromRequest(req);
233+
const projectRef = this.projectRefFromRequest(req);
216234

217-
if (!runnerId) {
235+
if (!runnerId || !deploymentVersion || !projectRef) {
218236
console.error("Invalid headers for suspend request", {
219237
...params,
220238
headers: req.headers,
@@ -241,8 +259,13 @@ export class WorkloadServer extends EventEmitter<WorkloadServerEvents> {
241259
const suspendResult = await this.checkpointClient.suspendRun({
242260
runFriendlyId: params.runFriendlyId,
243261
snapshotFriendlyId: params.snapshotFriendlyId,
244-
containerId: runnerId,
245-
runnerId,
262+
body: {
263+
runnerId,
264+
runId: params.runFriendlyId,
265+
snapshotId: params.snapshotFriendlyId,
266+
projectRef,
267+
deploymentVersion,
268+
},
246269
});
247270

248271
if (!suspendResult) {

packages/cli-v3/src/entryPoints/managed-run-controller.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,10 @@ class ManagedRunController {
149149

150150
this.httpClient = new WorkloadHttpClient({
151151
workerApiUrl: this.workerApiUrl,
152-
deploymentId: env.TRIGGER_DEPLOYMENT_ID,
153152
runnerId: env.TRIGGER_RUNNER_ID,
153+
deploymentId: env.TRIGGER_DEPLOYMENT_ID,
154+
deploymentVersion: env.TRIGGER_DEPLOYMENT_VERSION,
155+
projectRef: env.TRIGGER_PROJECT_REF,
154156
});
155157

156158
if (env.TRIGGER_WARM_START_URL) {

packages/core/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@
191191
"nanoid": "^3.3.4",
192192
"socket.io": "4.7.4",
193193
"socket.io-client": "4.7.5",
194+
"std-env": "^3.8.1",
194195
"superjson": "^2.2.1",
195196
"tinyexec": "^0.3.2",
196197
"zod": "3.23.8",

packages/core/src/utils.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ export function assertExhaustive(x: never): never {
22
throw new Error("Unexpected object: " + x);
33
}
44

5-
export async function tryCatch<T, E = Error>(promise: Promise<T>): Promise<[T, null] | [null, E]> {
5+
/**
 * Awaits `promise` and reports the outcome as a tuple instead of throwing.
 *
 * The `+` lines of this diff make the tuple error-first:
 * - fulfilled: `[null, data]`
 * - rejected:  `[error, null]`
 *
 * The `-` lines show the previous data-first order, so existing callers that
 * destructure `[data, error]` must swap the two positions.
 *
 * NOTE: the caught value is cast to `E` without any runtime check.
 */
export async function tryCatch<T, E = Error>(promise: Promise<T>): Promise<[null, T] | [E, null]> {
66
try {
77
const data = await promise;
8-
return [data, null];
8+
return [null, data];
99
} catch (error) {
10-
return [null, error as E];
10+
return [error as E, null];
1111
}
1212
}

packages/core/src/v3/runEngineWorker/consts.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,6 @@ export const WORKER_HEADERS = {
88
/**
 * Names of the HTTP headers that identify a workload on its requests.
 *
 * `getDefaultWorkloadHeaders` fills them from the workload client options.
 * This commit adds the deployment version and project ref headers.
 */
export const WORKLOAD_HEADERS = {
99
DEPLOYMENT_ID: "x-trigger-workload-deployment-id",
1010
RUNNER_ID: "x-trigger-workload-runner-id",
11+
DEPLOYMENT_VERSION: "x-trigger-workload-deployment-version",
12+
PROJECT_REF: "x-trigger-workload-project-ref",
1113
};
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
/**
 * Options shared by workload clients.
 *
 * `runnerId`, `deploymentId`, `deploymentVersion` and `projectRef` are the
 * values `getDefaultWorkloadHeaders` sends as the `WORKLOAD_HEADERS`.
 * All fields are required strings after this commit.
 */
export type WorkloadClientCommonOptions = {
22
workerApiUrl: string;
3-
deploymentId: string;
43
runnerId: string;
4+
deploymentId: string;
5+
deploymentVersion: string;
6+
projectRef: string;
57
};

packages/core/src/v3/runEngineWorker/workload/util.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,7 @@ export function getDefaultWorkloadHeaders(
88
return createHeaders({
99
[WORKLOAD_HEADERS.DEPLOYMENT_ID]: options.deploymentId,
1010
[WORKLOAD_HEADERS.RUNNER_ID]: options.runnerId,
11+
[WORKLOAD_HEADERS.DEPLOYMENT_VERSION]: options.deploymentVersion,
12+
[WORKLOAD_HEADERS.PROJECT_REF]: options.projectRef,
1113
});
1214
}

packages/core/src/v3/schemas/checkpoints.ts

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { CheckpointType } from "./runEngine.js";
1+
import { CheckpointType, DequeuedMessage } from "./runEngine.js";
22
import z from "zod";
33

44
const CallbackUrl = z
@@ -8,20 +8,12 @@ const CallbackUrl = z
88

99
/**
 * Request body for asking the checkpoint service to suspend a run.
 *
 * After this commit (`+` lines) the body identifies the run (`runId`,
 * `snapshotId`), the runner (`runnerId`) and the deployment (`projectRef`,
 * `deploymentVersion`). The `-` lines (`containerId`, `simulate`,
 * `leaveRunning`, `callbacks`) are removed.
 *
 * NOTE(review): the supervisor's suspend call visible in this commit does not
 * pass `type`; presumably the checkpoint client fills it in — confirm.
 */
export const CheckpointServiceSuspendRequestBody = z.object({
1010
type: CheckpointType,
11-
containerId: z.string(),
12-
simulate: z.boolean().optional(),
13-
leaveRunning: z.boolean().optional(),
11+
runId: z.string(),
12+
snapshotId: z.string(),
13+
runnerId: z.string(),
14+
projectRef: z.string(),
15+
deploymentVersion: z.string(),
1416
reason: z.string().optional(),
15-
callbacks: z
16-
.object({
17-
/** These headers will sent to all callbacks */
18-
headers: z.record(z.string()).optional(),
19-
/** This will be hit before suspending the container. Suspension will proceed unless we receive an error response. */
20-
preSuspend: CallbackUrl.optional(),
21-
/** This will be hit after suspending or failure to suspend the container */
22-
completion: CallbackUrl.optional(),
23-
})
24-
.optional(),
2517
});
2618

2719
export type CheckpointServiceSuspendRequestBody = z.infer<
@@ -39,16 +31,7 @@ export type CheckpointServiceSuspendResponseBody = z.infer<
3931
typeof CheckpointServiceSuspendResponseBody
4032
>;
4133

42-
export const CheckpointServiceRestoreRequestBody = z.discriminatedUnion("type", [
43-
z.object({
44-
type: z.literal(CheckpointType.Enum.DOCKER),
45-
containerId: z.string(),
46-
}),
47-
z.object({
48-
type: z.literal(CheckpointType.Enum.KUBERNETES),
49-
containerId: z.string(),
50-
}),
51-
]);
34+
/**
 * Request body for asking the checkpoint service to restore a run.
 *
 * This is the full `DequeuedMessage` schema with `checkpoint` made mandatory
 * (it is optional on `DequeuedMessage`). It replaces the previous
 * `containerId`-based discriminated union shown in the `-` lines above.
 */
export const CheckpointServiceRestoreRequestBody = DequeuedMessage.required({ checkpoint: true });
5235

5336
export type CheckpointServiceRestoreRequestBody = z.infer<
5437
typeof CheckpointServiceRestoreRequestBody

packages/core/src/v3/schemas/runEngine.ts

Lines changed: 48 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -123,52 +123,6 @@ export const ExecutionResult = z.object({
123123

124124
export type ExecutionResult = z.infer<typeof ExecutionResult>;
125125

126-
/** This is sent to a Worker when a run is dequeued (a new run or continuing run) */
127-
export const DequeuedMessage = z.object({
128-
version: z.literal("1"),
129-
snapshot: ExecutionSnapshot,
130-
dequeuedAt: z.coerce.date(),
131-
image: z.string().optional(),
132-
checkpoint: z
133-
.object({
134-
id: z.string(),
135-
type: z.string(),
136-
location: z.string(),
137-
reason: z.string().nullish(),
138-
})
139-
.optional(),
140-
completedWaitpoints: z.array(CompletedWaitpoint),
141-
backgroundWorker: z.object({
142-
id: z.string(),
143-
friendlyId: z.string(),
144-
version: z.string(),
145-
}),
146-
deployment: z.object({
147-
id: z.string().optional(),
148-
friendlyId: z.string().optional(),
149-
}),
150-
run: z.object({
151-
id: z.string(),
152-
friendlyId: z.string(),
153-
isTest: z.boolean(),
154-
machine: MachinePreset,
155-
attemptNumber: z.number(),
156-
masterQueue: z.string(),
157-
traceContext: z.record(z.unknown()),
158-
}),
159-
environment: z.object({
160-
id: z.string(),
161-
type: EnvironmentType,
162-
}),
163-
organization: z.object({
164-
id: z.string(),
165-
}),
166-
project: z.object({
167-
id: z.string(),
168-
}),
169-
});
170-
export type DequeuedMessage = z.infer<typeof DequeuedMessage>;
171-
172126
/** The response to the Worker when starting an attempt */
173127
export const StartRunAttemptResult = ExecutionResult.and(
174128
z.object({
@@ -256,3 +210,51 @@ export const MachineResources = z.object({
256210
memory: z.number(),
257211
});
258212
export type MachineResources = z.infer<typeof MachineResources>;
213+
214+
/**
 * Checkpoint details carried on a `DequeuedMessage` (its optional `checkpoint` field).
 *
 * `type` is validated against `CheckpointType`; `reason` may be a string,
 * `null` or omitted. `imageRef` is a required string.
 * NOTE(review): `location` and `imageRef` presumably point at the stored
 * checkpoint and the image to restore from — confirm with the producer.
 */
export const DequeueMessageCheckpoint = z.object({
215+
id: z.string(),
216+
type: CheckpointType,
217+
location: z.string(),
218+
imageRef: z.string(),
219+
reason: z.string().nullish(),
220+
});
221+
export type DequeueMessageCheckpoint = z.infer<typeof DequeueMessageCheckpoint>;
222+
223+
/**
 * This is sent to a Worker when a run is dequeued (a new run or continuing run).
 *
 * `checkpoint` is optional here and validated by `DequeueMessageCheckpoint`.
 * `image` and both `deployment` fields are optional as well.
 * `dequeuedAt` is coerced to a `Date`, so serialized date values are accepted.
 */
224+
export const DequeuedMessage = z.object({
225+
version: z.literal("1"),
226+
snapshot: ExecutionSnapshot,
227+
dequeuedAt: z.coerce.date(),
228+
image: z.string().optional(),
229+
checkpoint: DequeueMessageCheckpoint.optional(),
230+
completedWaitpoints: z.array(CompletedWaitpoint),
231+
backgroundWorker: z.object({
232+
id: z.string(),
233+
friendlyId: z.string(),
234+
version: z.string(),
235+
}),
236+
deployment: z.object({
237+
id: z.string().optional(),
238+
friendlyId: z.string().optional(),
239+
}),
240+
run: z.object({
241+
id: z.string(),
242+
friendlyId: z.string(),
243+
isTest: z.boolean(),
244+
machine: MachinePreset,
245+
attemptNumber: z.number(),
246+
masterQueue: z.string(),
247+
traceContext: z.record(z.unknown()),
248+
}),
249+
environment: z.object({
250+
id: z.string(),
251+
type: EnvironmentType,
252+
}),
253+
organization: z.object({
254+
id: z.string(),
255+
}),
256+
project: z.object({
257+
id: z.string(),
258+
}),
259+
});
260+
export type DequeuedMessage = z.infer<typeof DequeuedMessage>;

0 commit comments

Comments
 (0)