Skip to content

Commit 3d1aaa6

Browse files
committed
Refactoring OAuthErrors
This makes it possible to parse them from JSON, using OAUTH_ERRORS Invalidating credentials & retrying when server OAuth errors occur Updated existing tests Added some initial test coverage refactored to avoid recursion as recommended
1 parent 048bc4f commit 3d1aaa6

File tree

6 files changed

+500
-88
lines changed

6 files changed

+500
-88
lines changed

src/client/auth.test.ts

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import {
55
refreshAuthorization,
66
registerClient,
77
} from "./auth.js";
8+
import {ServerError} from "../server/auth/errors.js";
89

910
// Mock fetch globally
1011
const mockFetch = jest.fn();
@@ -112,25 +113,23 @@ describe("OAuth Authorization", () => {
112113
});
113114

114115
it("throws on non-404 errors", async () => {
115-
mockFetch.mockResolvedValueOnce({
116-
ok: false,
117-
status: 500,
118-
});
116+
mockFetch.mockResolvedValueOnce(new Response(null, { status: 500 }));
119117

120118
await expect(
121119
discoverOAuthMetadata("https://auth.example.com")
122120
).rejects.toThrow("HTTP 500");
123121
});
124122

125123
it("validates metadata schema", async () => {
126-
mockFetch.mockResolvedValueOnce({
127-
ok: true,
128-
status: 200,
129-
json: async () => ({
130-
// Missing required fields
131-
issuer: "https://auth.example.com",
132-
}),
133-
});
124+
mockFetch.mockResolvedValueOnce(
125+
Response.json(
126+
{
127+
// Missing required fields
128+
issuer: "https://auth.example.com",
129+
},
130+
{ status: 200 }
131+
)
132+
);
134133

135134
await expect(
136135
discoverOAuthMetadata("https://auth.example.com")
@@ -321,10 +320,12 @@ describe("OAuth Authorization", () => {
321320
});
322321

323322
it("throws on error response", async () => {
324-
mockFetch.mockResolvedValueOnce({
325-
ok: false,
326-
status: 400,
327-
});
323+
mockFetch.mockResolvedValueOnce(
324+
Response.json(
325+
new ServerError("Token exchange failed").toResponseObject(),
326+
{ status: 400 }
327+
)
328+
);
328329

329330
await expect(
330331
exchangeAuthorization("https://auth.example.com", {
@@ -403,10 +404,12 @@ describe("OAuth Authorization", () => {
403404
});
404405

405406
it("throws on error response", async () => {
406-
mockFetch.mockResolvedValueOnce({
407-
ok: false,
408-
status: 400,
409-
});
407+
mockFetch.mockResolvedValueOnce(
408+
Response.json(
409+
new ServerError("Token refresh failed").toResponseObject(),
410+
{ status: 400 }
411+
)
412+
);
410413

411414
await expect(
412415
refreshAuthorization("https://auth.example.com", {
@@ -491,10 +494,12 @@ describe("OAuth Authorization", () => {
491494
});
492495

493496
it("throws on error response", async () => {
494-
mockFetch.mockResolvedValueOnce({
495-
ok: false,
496-
status: 400,
497-
});
497+
mockFetch.mockResolvedValueOnce(
498+
Response.json(
499+
new ServerError("Dynamic client registration failed").toResponseObject(),
500+
{ status: 400 }
501+
)
502+
);
498503

499504
await expect(
500505
registerClient("https://auth.example.com", {

src/client/auth.ts

Lines changed: 94 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,26 @@
11
import pkceChallenge from "pkce-challenge";
2-
import { LATEST_PROTOCOL_VERSION } from "../types.js";
3-
import type { OAuthClientMetadata, OAuthClientInformation, OAuthTokens, OAuthMetadata, OAuthClientInformationFull } from "../shared/auth.js";
4-
import { OAuthClientInformationFullSchema, OAuthMetadataSchema, OAuthTokensSchema } from "../shared/auth.js";
2+
import {LATEST_PROTOCOL_VERSION} from "../types.js";
3+
import type {
4+
OAuthClientInformation,
5+
OAuthClientInformationFull,
6+
OAuthClientMetadata,
7+
OAuthMetadata,
8+
OAuthTokens
9+
} from "../shared/auth.js";
10+
import {
11+
OAuthClientInformationFullSchema,
12+
OAuthErrorResponseSchema,
13+
OAuthMetadataSchema,
14+
OAuthTokensSchema
15+
} from "../shared/auth.js";
16+
import {
17+
InvalidClientError,
18+
InvalidGrantError,
19+
OAUTH_ERRORS,
20+
OAuthError,
21+
ServerError,
22+
UnauthorizedClientError
23+
} from "../server/auth/errors.js";
524

625
/**
726
* Implements an end-to-end OAuth client to be used with one MCP server.
@@ -66,6 +85,13 @@ export interface OAuthClientProvider {
6685
* the authorization result.
6786
*/
6887
codeVerifier(): string | Promise<string>;
88+
89+
/**
90+
* If implemented, provides a way for the client to invalidate (e.g. delete) the specified
91+
* credentials, in the case where the server has indicated that they are no longer valid.
92+
* This avoids requiring the user to intervene manually.
93+
*/
94+
invalidateCredentials?(scope: 'all' | 'client' | 'tokens' | 'verifier'): void | Promise<void>;
6995
}
7096

7197
export type AuthResult = "AUTHORIZED" | "REDIRECT";
@@ -76,6 +102,33 @@ export class UnauthorizedError extends Error {
76102
}
77103
}
78104

105+
/**
106+
* Parses an OAuth error response from a string or Response object.
107+
*
108+
* If the input is a standard OAuth2.0 error response, it will be parsed according to the spec
109+
* and an instance of the appropriate OAuthError subclass will be returned.
110+
* If parsing fails, it falls back to a generic ServerError that includes
111+
* the response status (if available) and original content.
112+
*
113+
* @param input - A Response object or string containing the error response
114+
* @returns A Promise that resolves to an OAuthError instance
115+
*/
116+
export async function parseErrorResponse(input: Response | string): Promise<OAuthError> {
117+
const statusCode = input instanceof Response ? input.status : undefined;
118+
const body = input instanceof Response ? await input.text() : input;
119+
120+
try {
121+
const result = OAuthErrorResponseSchema.parse(JSON.parse(body));
122+
const { error, error_description, error_uri } = result;
123+
const errorClass = OAUTH_ERRORS[error] || ServerError;
124+
return new errorClass(error_description || '', error_uri);
125+
} catch (error) {
126+
// Not a valid OAuth error response, but try to inform the user of the raw data anyway
127+
const errorMessage = `${statusCode ? `HTTP ${statusCode}: ` : ''}Invalid OAuth error response: ${error}. Raw body: ${body}`;
128+
return new ServerError(errorMessage);
129+
}
130+
}
131+
79132
/**
80133
* Orchestrates the full auth flow with a server.
81134
*
@@ -84,7 +137,30 @@ export class UnauthorizedError extends Error {
84137
*/
85138
export async function auth(
86139
provider: OAuthClientProvider,
87-
{ serverUrl, authorizationCode }: { serverUrl: string | URL, authorizationCode?: string }): Promise<AuthResult> {
140+
options: { serverUrl: string | URL, authorizationCode?: string },
141+
): Promise<AuthResult> {
142+
try {
143+
return await authInternal(provider, options);
144+
} catch (error) {
145+
// Handle recoverable error types by invalidating credentials and retrying
146+
if (error instanceof InvalidClientError || error instanceof UnauthorizedClientError) {
147+
await provider.invalidateCredentials?.('all');
148+
return await authInternal(provider, options);
149+
} else if (error instanceof InvalidGrantError) {
150+
await provider.invalidateCredentials?.('tokens');
151+
return await authInternal(provider, options);
152+
}
153+
154+
// Throw otherwise
155+
throw error
156+
}
157+
}
158+
159+
async function authInternal(
160+
provider: OAuthClientProvider,
161+
options: { serverUrl: string | URL, authorizationCode?: string },
162+
): Promise<AuthResult> {
163+
const { serverUrl, authorizationCode } = options;
88164
const metadata = await discoverOAuthMetadata(serverUrl);
89165

90166
// Handle client registration if needed
@@ -119,7 +195,7 @@ export async function auth(
119195
});
120196

121197
await provider.saveTokens(tokens);
122-
return "AUTHORIZED";
198+
return "AUTHORIZED"
123199
}
124200

125201
const tokens = await provider.tokens();
@@ -135,14 +211,20 @@ export async function auth(
135211
});
136212

137213
await provider.saveTokens(newTokens);
138-
return "AUTHORIZED";
214+
return "AUTHORIZED"
139215
} catch (error) {
140-
console.error("Could not refresh OAuth tokens:", error);
216+
// If this is a ServerError, or an unknown type, log it out and try to continue. Otherwise, escalate so we can fix things and retry.
217+
if (!(error instanceof OAuthError) || error instanceof ServerError) {
218+
console.error("Could not refresh OAuth tokens:", error);
219+
} else {
220+
console.warn(`OAuth token refresh failed: ${JSON.stringify(error.toResponseObject())}`);
221+
throw error;
222+
}
141223
}
142224
}
143225

144226
// Start new authorization flow
145-
const { authorizationUrl, codeVerifier } = await startAuthorization(serverUrl, {
227+
const {authorizationUrl, codeVerifier} = await startAuthorization(serverUrl, {
146228
metadata,
147229
clientInformation,
148230
redirectUrl: provider.redirectUrl,
@@ -151,7 +233,7 @@ export async function auth(
151233

152234
await provider.saveCodeVerifier(codeVerifier);
153235
await provider.redirectToAuthorization(authorizationUrl);
154-
return "REDIRECT";
236+
return "REDIRECT"
155237
}
156238

157239
/**
@@ -316,7 +398,7 @@ export async function exchangeAuthorization(
316398
});
317399

318400
if (!response.ok) {
319-
throw new Error(`Token exchange failed: HTTP ${response.status}`);
401+
throw await parseErrorResponse(response);
320402
}
321403

322404
return OAuthTokensSchema.parse(await response.json());
@@ -375,7 +457,7 @@ export async function refreshAuthorization(
375457
});
376458

377459
if (!response.ok) {
378-
throw new Error(`Token refresh failed: HTTP ${response.status}`);
460+
throw await parseErrorResponse(response);
379461
}
380462

381463
return OAuthTokensSchema.parse(await response.json());
@@ -415,7 +497,7 @@ export async function registerClient(
415497
});
416498

417499
if (!response.ok) {
418-
throw new Error(`Dynamic client registration failed: HTTP ${response.status}`);
500+
throw await parseErrorResponse(response);
419501
}
420502

421503
return OAuthClientInformationFullSchema.parse(await response.json());

0 commit comments

Comments
 (0)