Skip to content

Commit 2f5fc17

Browse files
Merge pull request #166 from allenzhou101/request-auth-info
Include Authorization Info in Tool Calls (and request handlers generally)
2 parents dd254f9 + b30a77e commit 2f5fc17

File tree

6 files changed

+76
-23
lines changed

6 files changed

+76
-23
lines changed

src/inMemory.test.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { InMemoryTransport } from "./inMemory.js";
22
import { JSONRPCMessage } from "./types.js";
3+
import { AuthInfo } from "./server/auth/types.js";
34

45
describe("InMemoryTransport", () => {
56
let clientTransport: InMemoryTransport;
@@ -35,6 +36,32 @@ describe("InMemoryTransport", () => {
3536
expect(receivedMessage).toEqual(message);
3637
});
3738

39+
test("should send message with auth info from client to server", async () => {
40+
const message: JSONRPCMessage = {
41+
jsonrpc: "2.0",
42+
method: "test",
43+
id: 1,
44+
};
45+
46+
const authInfo: AuthInfo = {
47+
token: "test-token",
48+
clientId: "test-client",
49+
scopes: ["read", "write"],
50+
expiresAt: Date.now() / 1000 + 3600,
51+
};
52+
53+
let receivedMessage: JSONRPCMessage | undefined;
54+
let receivedAuthInfo: AuthInfo | undefined;
55+
serverTransport.onmessage = (msg, extra) => {
56+
receivedMessage = msg;
57+
receivedAuthInfo = extra?.authInfo;
58+
};
59+
60+
await clientTransport.send(message, { authInfo });
61+
expect(receivedMessage).toEqual(message);
62+
expect(receivedAuthInfo).toEqual(authInfo);
63+
});
64+
3865
test("should send message from server to client", async () => {
3966
const message: JSONRPCMessage = {
4067
jsonrpc: "2.0",

src/inMemory.ts

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
import { Transport } from "./shared/transport.js";
2-
import { JSONRPCMessage } from "./types.js";
2+
import { JSONRPCMessage, RequestId } from "./types.js";
3+
import { AuthInfo } from "./server/auth/types.js";
4+
5+
interface QueuedMessage {
6+
message: JSONRPCMessage;
7+
extra?: { authInfo?: AuthInfo };
8+
}
39

410
/**
511
* In-memory transport for creating clients and servers that talk to each other within the same process.
612
*/
713
export class InMemoryTransport implements Transport {
814
private _otherTransport?: InMemoryTransport;
9-
private _messageQueue: JSONRPCMessage[] = [];
15+
private _messageQueue: QueuedMessage[] = [];
1016

1117
onclose?: () => void;
1218
onerror?: (error: Error) => void;
13-
onmessage?: (message: JSONRPCMessage) => void;
19+
onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void;
1420
sessionId?: string;
1521

1622
/**
@@ -27,10 +33,8 @@ export class InMemoryTransport implements Transport {
2733
async start(): Promise<void> {
2834
// Process any messages that were queued before start was called
2935
while (this._messageQueue.length > 0) {
30-
const message = this._messageQueue.shift();
31-
if (message) {
32-
this.onmessage?.(message);
33-
}
36+
const queuedMessage = this._messageQueue.shift()!;
37+
this.onmessage?.(queuedMessage.message, queuedMessage.extra);
3438
}
3539
}
3640

@@ -41,15 +45,19 @@ export class InMemoryTransport implements Transport {
4145
this.onclose?.();
4246
}
4347

44-
async send(message: JSONRPCMessage): Promise<void> {
48+
/**
49+
* Sends a message with optional auth info.
50+
* This is useful for testing authentication scenarios.
51+
*/
52+
async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId, authInfo?: AuthInfo }): Promise<void> {
4553
if (!this._otherTransport) {
4654
throw new Error("Not connected");
4755
}
4856

4957
if (this._otherTransport.onmessage) {
50-
this._otherTransport.onmessage(message);
58+
this._otherTransport.onmessage(message, { authInfo: options?.authInfo });
5159
} else {
52-
this._otherTransport._messageQueue.push(message);
60+
this._otherTransport._messageQueue.push({ message, extra: { authInfo: options?.authInfo } });
5361
}
5462
}
5563
}

src/server/auth/types.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,10 @@ export interface AuthInfo {
2121
* When the token expires (in seconds since epoch).
2222
*/
2323
expiresAt?: number;
24+
25+
/**
26+
* Additional data associated with the token.
27+
* This field should be used for any additional data that needs to be attached to the auth info.
28+
*/
29+
extra?: Record<string, unknown>;
2430
}

src/server/sse.ts

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { Transport } from "../shared/transport.js";
44
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
55
import getRawBody from "raw-body";
66
import contentType from "content-type";
7+
import { AuthInfo } from "./auth/types.js";
78
import { URL } from 'url';
89

910
const MAXIMUM_MESSAGE_SIZE = "4mb";
@@ -19,7 +20,7 @@ export class SSEServerTransport implements Transport {
1920

2021
onclose?: () => void;
2122
onerror?: (error: Error) => void;
22-
onmessage?: (message: JSONRPCMessage) => void;
23+
onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void;
2324

2425
/**
2526
* Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`.
@@ -76,7 +77,7 @@ export class SSEServerTransport implements Transport {
7677
* This should be called when a POST request is made to send a message to the server.
7778
*/
7879
async handlePostMessage(
79-
req: IncomingMessage,
80+
req: IncomingMessage & { auth?: AuthInfo },
8081
res: ServerResponse,
8182
parsedBody?: unknown,
8283
): Promise<void> {
@@ -85,6 +86,7 @@ export class SSEServerTransport implements Transport {
8586
res.writeHead(500).end(message);
8687
throw new Error(message);
8788
}
89+
const authInfo: AuthInfo | undefined = req.auth;
8890

8991
let body: string | unknown;
9092
try {
@@ -104,7 +106,7 @@ export class SSEServerTransport implements Transport {
104106
}
105107

106108
try {
107-
await this.handleMessage(typeof body === 'string' ? JSON.parse(body) : body);
109+
await this.handleMessage(typeof body === 'string' ? JSON.parse(body) : body, { authInfo });
108110
} catch {
109111
res.writeHead(400).end(`Invalid message: ${body}`);
110112
return;
@@ -116,7 +118,7 @@ export class SSEServerTransport implements Transport {
116118
/**
117119
* Handle a client message, regardless of how it arrived. This can be used to inform the server of messages that arrive via a means different than HTTP POST.
118120
*/
119-
async handleMessage(message: unknown): Promise<void> {
121+
async handleMessage(message: unknown, extra?: { authInfo?: AuthInfo }): Promise<void> {
120122
let parsedMessage: JSONRPCMessage;
121123
try {
122124
parsedMessage = JSONRPCMessageSchema.parse(message);
@@ -125,7 +127,7 @@ export class SSEServerTransport implements Transport {
125127
throw error;
126128
}
127129

128-
this.onmessage?.(parsedMessage);
130+
this.onmessage?.(parsedMessage, extra);
129131
}
130132

131133
async close(): Promise<void> {

src/shared/protocol.ts

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import {
2323
ServerCapabilities,
2424
} from "../types.js";
2525
import { Transport } from "./transport.js";
26+
import { AuthInfo } from "../server/auth/types.js";
2627

2728
/**
2829
* Callback for progress notifications.
@@ -109,6 +110,11 @@ export type RequestHandlerExtra<SendRequestT extends Request,
109110
*/
110111
signal: AbortSignal;
111112

113+
/**
114+
* Information about a validated access token, provided to request handlers.
115+
*/
116+
authInfo?: AuthInfo;
117+
112118
/**
113119
* The session ID from the transport, if available.
114120
*/
@@ -274,11 +280,11 @@ export abstract class Protocol<
274280
this._onerror(error);
275281
};
276282

277-
this._transport.onmessage = (message) => {
283+
this._transport.onmessage = (message, extra) => {
278284
if (isJSONRPCResponse(message) || isJSONRPCError(message)) {
279285
this._onresponse(message);
280286
} else if (isJSONRPCRequest(message)) {
281-
this._onrequest(message);
287+
this._onrequest(message, extra);
282288
} else if (isJSONRPCNotification(message)) {
283289
this._onnotification(message);
284290
} else {
@@ -326,7 +332,7 @@ export abstract class Protocol<
326332
);
327333
}
328334

329-
private _onrequest(request: JSONRPCRequest): void {
335+
private _onrequest(request: JSONRPCRequest, extra?: { authInfo?: AuthInfo }): void {
330336
const handler =
331337
this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler;
332338

@@ -351,20 +357,20 @@ export abstract class Protocol<
351357
const abortController = new AbortController();
352358
this._requestHandlerAbortControllers.set(request.id, abortController);
353359

354-
// Create extra object with both abort signal and sessionId from transport
355-
const extra: RequestHandlerExtra<SendRequestT, SendNotificationT> = {
360+
const fullExtra: RequestHandlerExtra<SendRequestT, SendNotificationT> = {
356361
signal: abortController.signal,
357362
sessionId: this._transport?.sessionId,
358363
sendNotification:
359364
(notification) =>
360365
this.notification(notification, { relatedRequestId: request.id }),
361366
sendRequest: (r, resultSchema, options?) =>
362-
this.request(r, resultSchema, { ...options, relatedRequestId: request.id })
367+
this.request(r, resultSchema, { ...options, relatedRequestId: request.id }),
368+
authInfo: extra?.authInfo,
363369
};
364370

365371
// Starting with Promise.resolve() puts any synchronous errors into the monad as well.
366372
Promise.resolve()
367-
.then(() => handler(request, extra))
373+
.then(() => handler(request, fullExtra))
368374
.then(
369375
(result) => {
370376
if (abortController.signal.aborted) {

src/shared/transport.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { AuthInfo } from "../server/auth/types.js";
12
import { JSONRPCMessage, RequestId } from "../types.js";
23

34
/**
@@ -41,8 +42,11 @@ export interface Transport {
4142

4243
/**
4344
* Callback for when a message (request or response) is received over the connection.
45+
*
46+
* Includes the authInfo if the transport is authenticated.
47+
*
4448
*/
45-
onmessage?: (message: JSONRPCMessage) => void;
49+
onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void;
4650

4751
/**
4852
* The session ID generated for this connection.

0 commit comments

Comments
 (0)