Skip to content

refactor: rename state to session, combine tool registration, and clearer dependencies #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Apr 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions eslint.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ export default defineConfig([
files,
rules: {
"@typescript-eslint/switch-exhaustiveness-check": "error",
"@typescript-eslint/no-non-null-assertion": "error",
},
},
// Ignore features specific to TypeScript resolved rules
Expand Down
19 changes: 10 additions & 9 deletions src/common/atlas/apiClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,30 +55,31 @@ export class ApiClient {
return this.accessToken?.token.access_token as string | undefined;
};

private authMiddleware = (apiClient: ApiClient): Middleware => ({
async onRequest({ request, schemaPath }) {
private authMiddleware: Middleware = {
onRequest: async ({ request, schemaPath }) => {
if (schemaPath.startsWith("/api/private/unauth") || schemaPath.startsWith("/api/oauth")) {
return undefined;
}

try {
const accessToken = await apiClient.getAccessToken();
const accessToken = await this.getAccessToken();
request.headers.set("Authorization", `Bearer ${accessToken}`);
return request;
} catch {
// ignore not availble tokens, API will return 401
}
},
});
private errorMiddleware = (): Middleware => ({
};

private readonly errorMiddleware: Middleware = {
async onResponse({ response }) {
if (!response.ok) {
throw await ApiClientError.fromResponse(response);
}
},
});
};

constructor(options?: ApiClientOptions) {
constructor(options: ApiClientOptions) {
const defaultOptions = {
baseUrl: "https://cloud.mongodb.com/",
userAgent: `AtlasMCP/${config.version} (${process.platform}; ${process.arch}; ${process.env.HOSTNAME || "unknown"})`,
Expand Down Expand Up @@ -107,9 +108,9 @@ export class ApiClient {
tokenPath: "/api/oauth/token",
},
});
this.client.use(this.authMiddleware(this));
this.client.use(this.authMiddleware);
}
this.client.use(this.errorMiddleware());
this.client.use(this.errorMiddleware);
}

public async getIpInfo(): Promise<{
Expand Down
5 changes: 4 additions & 1 deletion src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ function getLogPath(): string {
// to SNAKE_UPPER_CASE.
function getEnvConfig(): Partial<UserConfig> {
function setValue(obj: Record<string, unknown>, path: string[], value: string): void {
const currentField = path.shift()!;
const currentField = path.shift();
if (!currentField) {
return;
}
if (path.length === 0) {
const numberValue = Number(value);
if (!isNaN(numberValue)) {
Expand Down
27 changes: 19 additions & 8 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,30 @@
#!/usr/bin/env node

import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
import { Server } from "./server.js";
import logger from "./logger.js";
import { mongoLogId } from "mongodb-log-writer";
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import config from "./config.js";
import { Session } from "./session.js";
import { Server } from "./server.js";

export async function runServer() {
const server = new Server();
try {
const session = new Session();
const mcpServer = new McpServer({
name: "MongoDB Atlas",
version: config.version,
});

const server = new Server({
mcpServer,
session,
});

const transport = new StdioServerTransport();
await server.connect(transport);
}

runServer().catch((error) => {
logger.emergency(mongoLogId(1_000_004), "server", `Fatal error running server: ${error}`);
await server.connect(transport);
} catch (error: unknown) {
logger.emergency(mongoLogId(1_000_004), "server", `Fatal error running server: ${error as string}`);

process.exit(1);
});
}
41 changes: 21 additions & 20 deletions src/server.ts
Original file line number Diff line number Diff line change
@@ -1,39 +1,40 @@
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import defaultState, { State } from "./state.js";
import { Session } from "./session.js";
import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js";
import { registerAtlasTools } from "./tools/atlas/tools.js";
import { registerMongoDBTools } from "./tools/mongodb/index.js";
import config from "./config.js";
import { AtlasTools } from "./tools/atlas/tools.js";
import { MongoDbTools } from "./tools/mongodb/tools.js";
import logger, { initializeLogger } from "./logger.js";
import { mongoLogId } from "mongodb-log-writer";

export class Server {
state: State = defaultState;
private server?: McpServer;
public readonly session: Session;
private readonly mcpServer: McpServer;

constructor({ mcpServer, session }: { mcpServer: McpServer; session: Session }) {
this.mcpServer = mcpServer;
this.session = session;
}

async connect(transport: Transport) {
this.server = new McpServer({
name: "MongoDB Atlas",
version: config.version,
});
this.mcpServer.server.registerCapabilities({ logging: {} });

this.server.server.registerCapabilities({ logging: {} });
this.registerTools();

registerAtlasTools(this.server, this.state);
registerMongoDBTools(this.server, this.state);
await initializeLogger(this.mcpServer);

await initializeLogger(this.server);
await this.server.connect(transport);
await this.mcpServer.connect(transport);

logger.info(mongoLogId(1_000_004), "server", `Server started with transport ${transport.constructor.name}`);
}

async close(): Promise<void> {
try {
await this.state.serviceProvider?.close(true);
} catch {
// Ignore errors during service provider close
await this.session.close();
await this.mcpServer.close();
}

private registerTools() {
for (const tool of [...AtlasTools, ...MongoDbTools]) {
new tool(this.session).register(this.mcpServer);
}
await this.server?.close();
}
}
18 changes: 13 additions & 5 deletions src/state.ts → src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver
import { ApiClient } from "./common/atlas/apiClient.js";
import config from "./config.js";

export class State {
export class Session {
serviceProvider?: NodeDriverServiceProvider;
apiClient?: ApiClient;

ensureApiClient(): asserts this is { apiClient: ApiClient } {
ensureAuthenticated(): asserts this is { apiClient: ApiClient } {
if (!this.apiClient) {
if (!config.apiClientId || !config.apiClientSecret) {
throw new Error(
Expand All @@ -23,7 +23,15 @@ export class State {
});
}
}
}

const defaultState = new State();
export default defaultState;
async close(): Promise<void> {
if (this.serviceProvider) {
try {
await this.serviceProvider.close(true);
} catch (error) {
console.error("Error closing service provider:", error);
}
this.serviceProvider = undefined;
}
}
}
6 changes: 3 additions & 3 deletions src/tools/atlas/atlasTool.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { ToolBase } from "../tool.js";
import { State } from "../../state.js";
import { Session } from "../../session.js";

export abstract class AtlasToolBase extends ToolBase {
constructor(state: State) {
super(state);
constructor(protected readonly session: Session) {
super(session);
}
}
6 changes: 3 additions & 3 deletions src/tools/atlas/createAccessList.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export class CreateAccessListTool extends AtlasToolBase {
comment,
currentIpAddress,
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
this.state.ensureApiClient();
this.session.ensureAuthenticated();

if (!ipAddresses?.length && !cidrBlocks?.length && !currentIpAddress) {
throw new Error("One of ipAddresses, cidrBlocks, currentIpAddress must be provided.");
Expand All @@ -39,7 +39,7 @@ export class CreateAccessListTool extends AtlasToolBase {
}));

if (currentIpAddress) {
const currentIp = await this.state.apiClient.getIpInfo();
const currentIp = await this.session.apiClient.getIpInfo();
const input = {
groupId: projectId,
ipAddress: currentIp.currentIpv4Address,
Expand All @@ -56,7 +56,7 @@ export class CreateAccessListTool extends AtlasToolBase {

const inputs = [...ipInputs, ...cidrInputs];

await this.state.apiClient.createProjectIpAccessList({
await this.session.apiClient.createProjectIpAccessList({
params: {
path: {
groupId: projectId,
Expand Down
4 changes: 2 additions & 2 deletions src/tools/atlas/createDBUser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ export class CreateDBUserTool extends AtlasToolBase {
roles,
clusters,
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
this.state.ensureApiClient();
this.session.ensureAuthenticated();

const input = {
groupId: projectId,
Expand All @@ -53,7 +53,7 @@ export class CreateDBUserTool extends AtlasToolBase {
: undefined,
} as CloudDatabaseUser;

await this.state.apiClient.createDatabaseUser({
await this.session.apiClient.createDatabaseUser({
params: {
path: {
groupId: projectId,
Expand Down
4 changes: 2 additions & 2 deletions src/tools/atlas/createFreeCluster.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export class CreateFreeClusterTool extends AtlasToolBase {
};

protected async execute({ projectId, name, region }: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
this.state.ensureApiClient();
this.session.ensureAuthenticated();

const input = {
groupId: projectId,
Expand All @@ -38,7 +38,7 @@ export class CreateFreeClusterTool extends AtlasToolBase {
terminationProtectionEnabled: false,
} as unknown as ClusterDescription20240805;

await this.state.apiClient.createCluster({
await this.session.apiClient.createCluster({
params: {
path: {
groupId: projectId,
Expand Down
4 changes: 2 additions & 2 deletions src/tools/atlas/inspectAccessList.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ export class InspectAccessListTool extends AtlasToolBase {
};

protected async execute({ projectId }: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
this.state.ensureApiClient();
this.session.ensureAuthenticated();

const accessList = await this.state.apiClient.listProjectIpAccessLists({
const accessList = await this.session.apiClient.listProjectIpAccessLists({
params: {
path: {
groupId: projectId,
Expand Down
4 changes: 2 additions & 2 deletions src/tools/atlas/inspectCluster.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ export class InspectClusterTool extends AtlasToolBase {
};

protected async execute({ projectId, clusterName }: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
this.state.ensureApiClient();
this.session.ensureAuthenticated();

const cluster = await this.state.apiClient.getCluster({
const cluster = await this.session.apiClient.getCluster({
params: {
path: {
groupId: projectId,
Expand Down
8 changes: 4 additions & 4 deletions src/tools/atlas/listClusters.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ export class ListClustersTool extends AtlasToolBase {
};

protected async execute({ projectId }: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
this.state.ensureApiClient();
this.session.ensureAuthenticated();

if (!projectId) {
const data = await this.state.apiClient.listClustersForAllProjects();
const data = await this.session.apiClient.listClustersForAllProjects();

return this.formatAllClustersTable(data);
} else {
const project = await this.state.apiClient.getProject({
const project = await this.session.apiClient.getProject({
params: {
path: {
groupId: projectId,
Expand All @@ -31,7 +31,7 @@ export class ListClustersTool extends AtlasToolBase {
throw new Error(`Project with ID "${projectId}" not found.`);
}

const data = await this.state.apiClient.listClusters({
const data = await this.session.apiClient.listClusters({
params: {
path: {
groupId: project.id || "",
Expand Down
4 changes: 2 additions & 2 deletions src/tools/atlas/listDBUsers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ export class ListDBUsersTool extends AtlasToolBase {
};

protected async execute({ projectId }: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
this.state.ensureApiClient();
this.session.ensureAuthenticated();

const data = await this.state.apiClient.listDatabaseUsers({
const data = await this.session.apiClient.listDatabaseUsers({
params: {
path: {
groupId: projectId,
Expand Down
4 changes: 2 additions & 2 deletions src/tools/atlas/listProjects.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ export class ListProjectsTool extends AtlasToolBase {
protected argsShape = {};

protected async execute(): Promise<CallToolResult> {
this.state.ensureApiClient();
this.session.ensureAuthenticated();

const data = await this.state.apiClient.listProjects();
const data = await this.session.apiClient.listProjects();

if (!data?.results?.length) {
throw new Error("No projects found in your MongoDB Atlas account.");
Expand Down
29 changes: 10 additions & 19 deletions src/tools/atlas/tools.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import { ToolBase } from "../tool.js";
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import { State } from "../../state.js";
import { ListClustersTool } from "./listClusters.js";
import { ListProjectsTool } from "./listProjects.js";
import { InspectClusterTool } from "./inspectCluster.js";
Expand All @@ -10,19 +7,13 @@ import { InspectAccessListTool } from "./inspectAccessList.js";
import { ListDBUsersTool } from "./listDBUsers.js";
import { CreateDBUserTool } from "./createDBUser.js";

export function registerAtlasTools(server: McpServer, state: State) {
const tools: ToolBase[] = [
new ListClustersTool(state),
new ListProjectsTool(state),
new InspectClusterTool(state),
new CreateFreeClusterTool(state),
new CreateAccessListTool(state),
new InspectAccessListTool(state),
new ListDBUsersTool(state),
new CreateDBUserTool(state),
];

for (const tool of tools) {
tool.register(server);
}
}
export const AtlasTools = [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💯

ListClustersTool,
ListProjectsTool,
InspectClusterTool,
CreateFreeClusterTool,
CreateAccessListTool,
InspectAccessListTool,
ListDBUsersTool,
CreateDBUserTool,
];
2 changes: 1 addition & 1 deletion src/tools/mongodb/connect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ export class ConnectTool extends MongoDBToolBase {
throw new MongoDBError(ErrorCodes.InvalidParams, "Invalid connection options");
}

await this.connectToMongoDB(connectionString, this.state);
await this.connectToMongoDB(connectionString);

return {
content: [{ type: "text", text: `Successfully connected to ${connectionString}.` }],
Expand Down
Loading