Skip to content

Commit ce6e01e

Browse files
committed
Add support for subscribing to run tags
1 parent d2c69c7 commit ce6e01e

File tree

23 files changed

+532
-154
lines changed

23 files changed

+532
-154
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import { type ActionFunctionArgs, json } from "@remix-run/server-runtime";
2+
import { z } from "zod";
3+
import { $replica } from "~/db.server";
4+
import { permittedToReadRun, permittedToReadRunTag } from "~/services/accessControl.server";
5+
import { authenticateApiRequest } from "~/services/apiAuth.server";
6+
import { logger } from "~/services/logger.server";
7+
import { realtimeClient } from "~/services/realtimeClientGlobal.server";
8+
import { makeApiCors } from "~/utils/apiCors";
9+
10+
const ParamsSchema = z.object({
11+
tagName: z.string(),
12+
});
13+
14+
export async function loader({ request, params }: ActionFunctionArgs) {
15+
const apiCors = makeApiCors(request);
16+
17+
if (request.method.toUpperCase() === "OPTIONS") {
18+
return apiCors(json({}));
19+
}
20+
21+
// Authenticate the request
22+
const authenticationResult = await authenticateApiRequest(request, { allowJWT: true });
23+
24+
if (!authenticationResult) {
25+
return apiCors(json({ error: "Invalid or Missing API Key" }, { status: 401 }));
26+
}
27+
28+
const parsedParams = ParamsSchema.safeParse(params);
29+
30+
if (!parsedParams.success) {
31+
return apiCors(
32+
json(
33+
{ error: "Invalid request parameters", issues: parsedParams.error.issues },
34+
{ status: 400 }
35+
)
36+
);
37+
}
38+
39+
if (!permittedToReadRunTag(authenticationResult, parsedParams.data.tagName)) {
40+
return apiCors(json({ error: "Unauthorized" }, { status: 403 }));
41+
}
42+
43+
try {
44+
return realtimeClient.streamRunsWhere(
45+
request.url,
46+
authenticationResult.environment,
47+
`"runTags" @> ARRAY['${parsedParams.data.tagName}']`,
48+
apiCors
49+
);
50+
} catch (error) {
51+
if (error instanceof Response) {
52+
// Error responses from longPollingFetch
53+
return apiCors(error);
54+
} else if (error instanceof TypeError) {
55+
// Unexpected errors
56+
logger.error("Unexpected error in loader:", { error: error.message });
57+
return apiCors(new Response("An unexpected error occurred", { status: 500 }));
58+
} else {
59+
// Unknown errors
60+
logger.error("Unknown error occurred in loader, not Error", { error: JSON.stringify(error) });
61+
return apiCors(new Response("An unknown error occurred", { status: 500 }));
62+
}
63+
}
64+
}

apps/webapp/app/services/accessControl.server.ts

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,39 @@ export function permittedToReadRun(
4242
return false;
4343
}
4444

45+
export function permittedToReadRunTag(
46+
authenticationResult: ApiAuthenticationResult,
47+
tagName: string
48+
): boolean {
49+
if (authenticationResult.type === "PRIVATE") {
50+
return true;
51+
}
52+
53+
if (authenticationResult.type === "PUBLIC") {
54+
return true;
55+
}
56+
57+
if (!authenticationResult.claims) {
58+
return false;
59+
}
60+
61+
const parsedClaims = ClaimsSchema.safeParse(authenticationResult.claims);
62+
63+
if (!parsedClaims.success) {
64+
return false;
65+
}
66+
67+
if (parsedClaims.data.permissions?.includes("read:runs")) {
68+
return true;
69+
}
70+
71+
if (parsedClaims.data.permissions?.includes(`read:tags:${tagName}`)) {
72+
return true;
73+
}
74+
75+
return false;
76+
}
77+
4578
export function permittedToReadBatch(
4679
authenticationResult: ApiAuthenticationResult,
4780
batchId: string

apps/webapp/test/realtimeClient.test.ts

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,82 @@ describe("RealtimeClient", () => {
134134
expect(liveResponse2.status).toBe(429);
135135
}
136136
);
137+
138+
containerWithElectricTest(
139+
"Should support subscribing to a run tag",
140+
{ timeout: 30_000 },
141+
async ({ redis, electricOrigin, prisma }) => {
142+
const client = new RealtimeClient({
143+
electricOrigin,
144+
keyPrefix: "test:realtime",
145+
redis: redis.options,
146+
expiryTimeInSeconds: 5,
147+
cachedLimitProvider: {
148+
async getCachedLimit() {
149+
return 1;
150+
},
151+
},
152+
});
153+
154+
const organization = await prisma.organization.create({
155+
data: {
156+
title: "test-org",
157+
slug: "test-org",
158+
},
159+
});
160+
161+
const project = await prisma.project.create({
162+
data: {
163+
name: "test-project",
164+
slug: "test-project",
165+
organizationId: organization.id,
166+
externalRef: "test-project",
167+
},
168+
});
169+
170+
const environment = await prisma.runtimeEnvironment.create({
171+
data: {
172+
projectId: project.id,
173+
organizationId: organization.id,
174+
slug: "test",
175+
type: "DEVELOPMENT",
176+
shortcode: "1234",
177+
apiKey: "tr_dev_1234",
178+
pkApiKey: "pk_test_1234",
179+
},
180+
});
181+
182+
const run = await prisma.taskRun.create({
183+
data: {
184+
taskIdentifier: "test-task",
185+
friendlyId: "run_1234",
186+
payload: "{}",
187+
payloadType: "application/json",
188+
traceId: "trace_1234",
189+
spanId: "span_1234",
190+
queue: "test-queue",
191+
projectId: project.id,
192+
runtimeEnvironmentId: environment.id,
193+
runTags: ["test:tag:1234", "test:tag:5678"],
194+
},
195+
});
196+
197+
const response = await client.streamRunsWhere(
198+
"http://localhost:3000?offset=-1",
199+
environment,
200+
`"runTags" @> ARRAY['test:tag:1234']`
201+
);
202+
203+
const responseBody = await response.json();
204+
205+
const headers = Object.fromEntries(response.headers.entries());
206+
207+
const shapeId = headers["electric-shape-id"];
208+
const chunkOffset = headers["electric-chunk-last-offset"];
209+
210+
expect(response.status).toBe(200);
211+
expect(shapeId).toBeDefined();
212+
expect(chunkOffset).toBe("0_0");
213+
}
214+
);
137215
});

docker/docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ services:
6060
- 6379:6379
6161

6262
electric:
63-
image: electricsql/electric:0.7.3
63+
image: electricsql/electric:0.7.5
6464
restart: always
6565
environment:
6666
DATABASE_URL: postgresql://postgres:postgres@database:5432/postgres?sslmode=disable

internal-packages/testcontainers/src/utils.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ export async function createElectricContainer(
5555
network.getName()
5656
)}:5432/${postgresContainer.getDatabase()}?sslmode=disable`;
5757

58-
const container = await new GenericContainer("electricsql/electric:0.7.3")
58+
const container = await new GenericContainer("electricsql/electric:0.7.5")
5959
.withExposedPorts(3000)
6060
.withNetwork(network)
6161
.withEnvironment({

packages/core/src/v3/apiClient/index.ts

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ import {
5454
UpdateEnvironmentVariableParams,
5555
} from "./types.js";
5656
import { generateJWT } from "../jwt.js";
57-
import { TriggerJwtOptions } from "../types/tasks.js";
57+
import { AnyRunTypes, TriggerJwtOptions } from "../types/tasks.js";
5858

5959
export type {
6060
CreateEnvironmentVariableParams,
@@ -590,15 +590,22 @@ export class ApiClient {
590590
);
591591
}
592592

593-
subscribeToRunChanges<TPayload = any, TOutput = any>(runId: string) {
594-
return runShapeStream<TPayload, TOutput>(`${this.baseUrl}/realtime/v1/runs/${runId}`, {
593+
subscribeToRunChanges<TRunTypes extends AnyRunTypes>(runId: string) {
594+
return runShapeStream<TRunTypes>(`${this.baseUrl}/realtime/v1/runs/${runId}`, {
595595
closeOnComplete: true,
596596
headers: this.#getRealtimeHeaders(),
597597
});
598598
}
599599

600-
subscribeToBatchChanges<TPayload = any, TOutput = any>(batchId: string) {
601-
return runShapeStream<TPayload, TOutput>(`${this.baseUrl}/realtime/v1/batches/${batchId}`, {
600+
subscribeToRunTag<TRunTypes extends AnyRunTypes>(tag: string) {
601+
return runShapeStream<TRunTypes>(`${this.baseUrl}/realtime/v1/tags/${tag}`, {
602+
closeOnComplete: false,
603+
headers: this.#getRealtimeHeaders(),
604+
});
605+
}
606+
607+
subscribeToBatchChanges<TRunTypes extends AnyRunTypes>(batchId: string) {
608+
return runShapeStream<TRunTypes>(`${this.baseUrl}/realtime/v1/batches/${batchId}`, {
602609
closeOnComplete: false,
603610
headers: this.#getRealtimeHeaders(),
604611
});

packages/core/src/v3/apiClient/runStream.ts

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,47 @@
11
import { DeserializedJson } from "../../schemas/json.js";
22
import { RunStatus, SubscribeRunRawShape } from "../schemas/api.js";
33
import { SerializedError } from "../schemas/common.js";
4-
import { AnyTask, TaskOutput, TaskPayload } from "../types/tasks.js";
4+
import { AnyRunTypes, AnyTask, InferRunTypes } from "../types/tasks.js";
55
import {
66
conditionallyImportAndParsePacket,
77
IOPacket,
88
parsePacket,
99
} from "../utils/ioSerialization.js";
1010
import { AsyncIterableStream, createAsyncIterableStream, zodShapeStream } from "./stream.js";
1111

12-
export type RunShape<TPayload = any, TOutput = any> = {
13-
id: string;
14-
createdAt: Date;
15-
updatedAt: Date;
16-
taskIdentifier: string;
17-
number: number;
18-
status: RunStatus;
19-
durationMs: number;
20-
costInCents: number;
21-
baseCostInCents: number;
22-
payload: TPayload;
23-
tags: string[];
24-
idempotencyKey?: string;
25-
expiredAt?: Date;
26-
ttl?: string;
27-
finishedAt?: Date;
28-
startedAt?: Date;
29-
delayedUntil?: Date;
30-
queuedAt?: Date;
31-
metadata?: Record<string, DeserializedJson>;
32-
error?: SerializedError;
33-
output?: TOutput;
34-
isTest: boolean;
35-
};
12+
export type RunShape<TRunTypes extends AnyRunTypes> = TRunTypes extends AnyRunTypes
13+
? {
14+
id: string;
15+
taskIdentifier: TRunTypes["taskIdentifier"];
16+
payload: TRunTypes["payload"];
17+
output?: TRunTypes["output"];
18+
createdAt: Date;
19+
updatedAt: Date;
20+
number: number;
21+
status: RunStatus;
22+
durationMs: number;
23+
costInCents: number;
24+
baseCostInCents: number;
25+
tags: string[];
26+
idempotencyKey?: string;
27+
expiredAt?: Date;
28+
ttl?: string;
29+
finishedAt?: Date;
30+
startedAt?: Date;
31+
delayedUntil?: Date;
32+
queuedAt?: Date;
33+
metadata?: Record<string, DeserializedJson>;
34+
error?: SerializedError;
35+
isTest: boolean;
36+
}
37+
: never;
3638

37-
export type AnyRunShape = RunShape<any, any>;
39+
export type AnyRunShape = RunShape<AnyRunTypes>;
3840

39-
export type TaskRunShape<TTask extends AnyTask> = RunShape<TaskPayload<TTask>, TaskOutput<TTask>>;
41+
export type TaskRunShape<TTask extends AnyTask> = RunShape<InferRunTypes<TTask>>;
4042

41-
export type RunStreamCallback<TPayload = any, TOutput = any> = (
42-
run: RunShape<TPayload, TOutput>
43+
export type RunStreamCallback<TRunTypes extends AnyRunTypes> = (
44+
run: RunShape<TRunTypes>
4345
) => void | Promise<void>;
4446

4547
export type RunShapeStreamOptions = {
@@ -48,19 +50,19 @@ export type RunShapeStreamOptions = {
4850
closeOnComplete?: boolean;
4951
};
5052

51-
export function runShapeStream<TPayload = any, TOutput = any>(
53+
export function runShapeStream<TRunTypes extends AnyRunTypes>(
5254
url: string,
5355
options?: RunShapeStreamOptions
54-
): RunSubscription<TPayload, TOutput> {
55-
const subscription = new RunSubscription<TPayload, TOutput>(url, options);
56+
): RunSubscription<TRunTypes> {
57+
const subscription = new RunSubscription<TRunTypes>(url, options);
5658

5759
return subscription.init();
5860
}
5961

60-
export class RunSubscription<TPayload = any, TOutput = any> {
62+
export class RunSubscription<TRunTypes extends AnyRunTypes> {
6163
private abortController: AbortController;
6264
private unsubscribeShape: () => void;
63-
private stream: AsyncIterableStream<RunShape<TPayload, TOutput>>;
65+
private stream: AsyncIterableStream<RunShape<TRunTypes>>;
6466
private packetCache = new Map<string, any>();
6567

6668
constructor(
@@ -117,15 +119,15 @@ export class RunSubscription<TPayload = any, TOutput = any> {
117119
this.unsubscribeShape?.();
118120
}
119121

120-
[Symbol.asyncIterator](): AsyncIterator<RunShape<TPayload, TOutput>> {
122+
[Symbol.asyncIterator](): AsyncIterator<RunShape<TRunTypes>> {
121123
return this.stream[Symbol.asyncIterator]();
122124
}
123125

124-
getReader(): ReadableStreamDefaultReader<RunShape<TPayload, TOutput>> {
126+
getReader(): ReadableStreamDefaultReader<RunShape<TRunTypes>> {
125127
return this.stream.getReader();
126128
}
127129

128-
private async transformRunShape(row: SubscribeRunRawShape): Promise<RunShape> {
130+
private async transformRunShape(row: SubscribeRunRawShape): Promise<RunShape<TRunTypes>> {
129131
const payloadPacket = row.payloadType
130132
? ({ data: row.payload ?? undefined, dataType: row.payloadType } satisfies IOPacket)
131133
: undefined;
@@ -183,7 +185,7 @@ export class RunSubscription<TPayload = any, TOutput = any> {
183185
error: row.error ?? undefined,
184186
isTest: row.isTest,
185187
metadata,
186-
};
188+
} as RunShape<TRunTypes>;
187189
}
188190
}
189191

0 commit comments

Comments
 (0)