Skip to content

Commit b608fc9

Browse files
committed
fix: use union types and reuse argument definitions
1 parent a0bbbe9 commit b608fc9

File tree

5 files changed

+89
-53
lines changed

5 files changed

+89
-53
lines changed

eslint.config.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ export default defineConfig([
1313
files: ["src/**/*.ts"],
1414
rules: {
1515
"@typescript-eslint/no-non-null-assertion": "error",
16+
"@typescript-eslint/switch-exhaustiveness-check": "error",
1617
},
1718
},
1819
globalIgnores(["node_modules", "dist"]),

src/tools/mongodb/metadata/explain.ts

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ import { DbOperationArgs, DbOperationType, MongoDBToolBase } from "../mongodbToo
33
import { ToolArgs } from "../../tool.js";
44
import { z } from "zod";
55
import { ExplainVerbosity, Document } from "mongodb";
6+
import { AggregateArgs } from "../read/aggregate.js";
7+
import { FindArgs } from "../read/find.js";
8+
import { CountArgs } from "../read/count.js";
69

710
export class ExplainTool extends MongoDBToolBase {
811
protected name = "explain";
@@ -11,42 +14,65 @@ export class ExplainTool extends MongoDBToolBase {
1114

1215
protected argsShape = {
1316
...DbOperationArgs,
14-
method: z.enum(["aggregate", "find"]).describe("The method to run"),
15-
methodArguments: z
16-
.object({
17-
aggregatePipeline: z
18-
.array(z.object({}).passthrough())
19-
.optional()
20-
.describe("aggregate - array of aggregation stages to execute"),
21-
22-
findQuery: z.object({}).passthrough().optional().describe("find - The query to run"),
23-
findProjection: z.object({}).passthrough().optional().describe("find - The projection to apply"),
24-
})
25-
.describe("The arguments for the method"),
17+
method: z
18+
.array(
19+
z.union([
20+
z.object({
21+
name: z.literal("aggregate"),
22+
arguments: z.object(AggregateArgs),
23+
}),
24+
z.object({
25+
name: z.literal("find"),
26+
arguments: z.object(FindArgs),
27+
}),
28+
z.object({
29+
name: z.literal("count"),
30+
arguments: z.object(CountArgs),
31+
}),
32+
])
33+
)
34+
.describe("The method and its arguments to run"),
2635
};
2736

2837
protected operationType: DbOperationType = "metadata";
2938

39+
static readonly defaultVerbosity = ExplainVerbosity.queryPlanner;
40+
3041
protected async execute({
3142
database,
3243
collection,
33-
method,
34-
methodArguments,
44+
method: methods,
3545
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
3646
const provider = this.ensureConnected();
47+
const method = methods[0];
48+
49+
if (!method) {
50+
throw new Error("No method provided");
51+
}
3752

3853
let result: Document;
39-
switch (method) {
54+
switch (method.name) {
4055
case "aggregate": {
41-
result = await provider.aggregate(database, collection).explain();
56+
const { pipeline, limit } = method.arguments;
57+
result = await provider
58+
.aggregate(database, collection, pipeline)
59+
.limit(limit)
60+
.explain(ExplainTool.defaultVerbosity);
4261
break;
4362
}
4463
case "find": {
45-
const query = methodArguments.findQuery ?? {};
46-
const projection = methodArguments.findProjection ?? {};
64+
const { filter, ...rest } = method.arguments;
4765
result = await provider
48-
.find(database, collection, query, { projection })
49-
.explain(ExplainVerbosity.queryPlanner);
66+
.find(database, collection, filter, { ...rest })
67+
.explain(ExplainTool.defaultVerbosity);
68+
break;
69+
}
70+
case "count": {
71+
const { query } = method.arguments;
72+
// This helper doesn't have explain() command but does have the argument explain
73+
result = (await provider.count(database, collection, query, {
74+
explain: ExplainTool.defaultVerbosity,
75+
})) as unknown as Document;
5076
break;
5177
}
5278
default:

src/tools/mongodb/read/aggregate.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,30 @@
11
import { z } from "zod";
22
import { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
3-
import { DbOperationType, MongoDBToolBase } from "../mongodbTool.js";
3+
import { DbOperationArgs, DbOperationType, MongoDBToolBase } from "../mongodbTool.js";
44
import { ToolArgs } from "../../tool.js";
55

6+
export const AggregateArgs = {
7+
pipeline: z.array(z.object({}).passthrough()).describe("An array of aggregation stages to execute"),
8+
limit: z.number().optional().default(10).describe("The maximum number of documents to return"),
9+
};
10+
611
export class AggregateTool extends MongoDBToolBase {
712
protected name = "aggregate";
813
protected description = "Run an aggregation against a MongoDB collection";
914
protected argsShape = {
10-
collection: z.string().describe("Collection name"),
11-
database: z.string().describe("Database name"),
12-
pipeline: z.array(z.object({}).passthrough()).describe("An array of aggregation stages to execute"),
13-
limit: z.number().optional().default(10).describe("The maximum number of documents to return"),
15+
...DbOperationArgs,
16+
...AggregateArgs,
1417
};
1518
protected operationType: DbOperationType = "read";
1619

1720
protected async execute({
1821
database,
1922
collection,
2023
pipeline,
24+
limit,
2125
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
2226
const provider = this.ensureConnected();
23-
const documents = await provider.aggregate(database, collection, pipeline).toArray();
27+
const documents = await provider.aggregate(database, collection, pipeline).limit(limit).toArray();
2428

2529
const content: Array<{ text: string; type: "text" }> = [
2630
{

src/tools/mongodb/read/count.ts

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,22 @@ import { DbOperationArgs, DbOperationType, MongoDBToolBase } from "../mongodbToo
33
import { ToolArgs } from "../../tool.js";
44
import { z } from "zod";
55

6+
export const CountArgs = {
7+
query: z
8+
.object({})
9+
.passthrough()
10+
.optional()
11+
.describe(
12+
"The query filter to count documents. Matches the syntax of the filter argument of db.collection.count()"
13+
),
14+
};
15+
616
export class CountTool extends MongoDBToolBase {
717
protected name = "count";
818
protected description = "Gets the number of documents in a MongoDB collection";
919
protected argsShape = {
1020
...DbOperationArgs,
11-
query: z
12-
.object({})
13-
.passthrough()
14-
.optional()
15-
.describe(
16-
"The query filter to count documents. Matches the syntax of the filter argument of db.collection.count()"
17-
),
21+
...CountArgs,
1822
};
1923

2024
protected operationType: DbOperationType = "metadata";

src/tools/mongodb/read/find.ts

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,33 @@
11
import { z } from "zod";
22
import { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
3-
import { DbOperationType, MongoDBToolBase } from "../mongodbTool.js";
3+
import { DbOperationArgs, DbOperationType, MongoDBToolBase } from "../mongodbTool.js";
44
import { ToolArgs } from "../../tool.js";
55
import { SortDirection } from "mongodb";
66

7+
export const FindArgs = {
8+
filter: z
9+
.object({})
10+
.passthrough()
11+
.optional()
12+
.describe("The query filter, matching the syntax of the query argument of db.collection.find()"),
13+
projection: z
14+
.object({})
15+
.passthrough()
16+
.optional()
17+
.describe("The projection, matching the syntax of the projection argument of db.collection.find()"),
18+
limit: z.number().optional().default(10).describe("The maximum number of documents to return"),
19+
sort: z
20+
.record(z.string(), z.custom<SortDirection>())
21+
.optional()
22+
.describe("A document, describing the sort order, matching the syntax of the sort argument of cursor.sort()"),
23+
};
24+
725
export class FindTool extends MongoDBToolBase {
826
protected name = "find";
927
protected description = "Run a find query against a MongoDB collection";
1028
protected argsShape = {
11-
collection: z.string().describe("Collection name"),
12-
database: z.string().describe("Database name"),
13-
filter: z
14-
.object({})
15-
.passthrough()
16-
.optional()
17-
.describe("The query filter, matching the syntax of the query argument of db.collection.find()"),
18-
projection: z
19-
.object({})
20-
.passthrough()
21-
.optional()
22-
.describe("The projection, matching the syntax of the projection argument of db.collection.find()"),
23-
limit: z.number().optional().default(10).describe("The maximum number of documents to return"),
24-
sort: z
25-
.record(z.string(), z.custom<SortDirection>())
26-
.optional()
27-
.describe(
28-
"A document, describing the sort order, matching the syntax of the sort argument of cursor.sort()"
29-
),
29+
...DbOperationArgs,
30+
...FindArgs,
3031
};
3132
protected operationType: DbOperationType = "read";
3233

0 commit comments

Comments
 (0)