Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 01cbfe7

Browse files
fix: add guards against possible memory overflow in find and aggregate tools MCP-21 (#536)
1 parent c10955a commit 01cbfe7

File tree

17 files changed

+1116
-78
lines changed

17 files changed

+1116
-78
lines changed

‎src/common/config.ts‎

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import levenshtein from "ts-levenshtein";
99

1010
// From: https://github.com/mongodb-js/mongosh/blob/main/packages/cli-repl/src/arg-parser.ts
1111
const OPTIONS = {
12+
number: ["maxDocumentsPerQuery", "maxBytesPerQuery"],
1213
string: [
1314
"apiBaseUrl",
1415
"apiClientId",
@@ -98,6 +99,7 @@ const OPTIONS = {
9899

99100
interface Options {
100101
string: string[];
102+
number: string[];
101103
boolean: string[];
102104
array: string[];
103105
alias: Record<string, string>;
@@ -106,6 +108,7 @@ interface Options {
106108

107109
export const ALL_CONFIG_KEYS = new Set(
108110
(OPTIONS.string as readonly string[])
111+
.concat(OPTIONS.number)
109112
.concat(OPTIONS.array)
110113
.concat(OPTIONS.boolean)
111114
.concat(Object.keys(OPTIONS.alias))
@@ -175,6 +178,8 @@ export interface UserConfig extends CliOptions {
175178
loggers: Array<"stderr" | "disk" | "mcp">;
176179
idleTimeoutMs: number;
177180
notificationTimeoutMs: number;
181+
maxDocumentsPerQuery: number;
182+
maxBytesPerQuery: number;
178183
atlasTemporaryDatabaseUserLifetimeMs: number;
179184
}
180185

@@ -202,6 +207,8 @@ export const defaultUserConfig: UserConfig = {
202207
idleTimeoutMs: 10 * 60 * 1000, // 10 minutes
203208
notificationTimeoutMs: 9 * 60 * 1000, // 9 minutes
204209
httpHeaders: {},
210+
maxDocumentsPerQuery: 100, // By default, we only fetch a maximum 100 documents per query / aggregation
211+
maxBytesPerQuery: 16 * 1024 * 1024, // By default, we only return ~16 mb of data per query / aggregation
205212
atlasTemporaryDatabaseUserLifetimeMs: 4 * 60 * 60 * 1000, // 4 hours
206213
};
207214

‎src/common/logger.ts‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ export const LogId = {
4444
mongodbConnectFailure: mongoLogId(1_004_001),
4545
mongodbDisconnectFailure: mongoLogId(1_004_002),
4646
mongodbConnectTry: mongoLogId(1_004_003),
47+
mongodbCursorCloseError: mongoLogId(1_004_004),
4748

4849
toolUpdateFailure: mongoLogId(1_005_001),
4950
resourceUpdateFailure: mongoLogId(1_005_002),
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import { calculateObjectSize } from "bson";
2+
import type { AggregationCursor, FindCursor } from "mongodb";
3+
4+
/**
 * Derives the effective byte limit for a tool response from the limit
 * requested through the tool call and the limit configured on the server.
 *
 * @param toolResponseBytesLimit - Limit requested by the tool call. Nullish,
 * zero, negative or NaN values mean "no limit requested by the tool".
 * @param configuredMaxBytesPerQuery - The server's configured
 * maxBytesPerQuery. It is stringified and parsed as a base-10 integer; values
 * that do not parse, or that parse to zero or a negative number, mean "no
 * limit configured".
 * @returns The smaller of the applicable limits together with the name of the
 * limit that produced it. When neither limit is applicable the result is
 * `{ cappedBy: undefined, limit: 0 }`, where `0` stands for "unlimited".
 */
export function getResponseBytesLimit(
    toolResponseBytesLimit: number | undefined | null,
    configuredMaxBytesPerQuery: unknown
): {
    cappedBy: "config.maxBytesPerQuery" | "tool.responseBytesLimit" | undefined;
    limit: number;
} {
    const configuredLimit: number = parseInt(String(configuredMaxBytesPerQuery), 10);

    // Setting configured maxBytesPerQuery to negative, zero or nullish is
    // equivalent to disabling the max limit applied on documents.
    const configuredLimitIsNotApplicable = Number.isNaN(configuredLimit) || configuredLimit <= 0;

    // The tool limit only applies when it is a strictly positive number. The
    // check is phrased as `> 0` (rather than rejecting `<= 0`) on purpose: NaN
    // fails every comparison, so it is treated as "not applicable" instead of
    // leaking out as `limit: NaN`, which would make every later
    // `size >= limit` comparison false and silently disable the guard.
    const toolLimit =
        typeof toolResponseBytesLimit === "number" && toolResponseBytesLimit > 0
            ? toolResponseBytesLimit
            : undefined;

    if (configuredLimitIsNotApplicable) {
        return toolLimit === undefined
            ? { cappedBy: undefined, limit: 0 }
            : { cappedBy: "tool.responseBytesLimit", limit: toolLimit };
    }

    if (toolLimit === undefined) {
        return { cappedBy: "config.maxBytesPerQuery", limit: configuredLimit };
    }

    // Both limits apply: the smaller one wins; on a tie the tool limit is
    // reported as the cap.
    return {
        cappedBy: configuredLimit < toolLimit ? "config.maxBytesPerQuery" : "tool.responseBytesLimit",
        limit: Math.min(toolLimit, configuredLimit),
    };
}
39+
40+
/**
 * Guard rail against accidental memory overflow on the MCP server.
 *
 * Documents are pulled from the cursor one at a time and buffered until
 * adding the next document would reach the derived byte limit for the tool
 * call. The derived limit combines the limit passed through the tool's
 * interface with the server's configured maxBytesPerQuery. When the derived
 * limit is zero or negative, the whole cursor is drained without any cap.
 */
export async function collectCursorUntilMaxBytesLimit<T = unknown>({
    cursor,
    toolResponseBytesLimit,
    configuredMaxBytesPerQuery,
    abortSignal,
}: {
    cursor: FindCursor<T> | AggregationCursor<T>;
    toolResponseBytesLimit: number | undefined | null;
    configuredMaxBytesPerQuery: unknown;
    abortSignal?: AbortSignal;
}): Promise<{ cappedBy: "config.maxBytesPerQuery" | "tool.responseBytesLimit" | undefined; documents: T[] }> {
    const effectiveLimit = getResponseBytesLimit(toolResponseBytesLimit, configuredMaxBytesPerQuery);
    const byteBudget = effectiveLimit.limit;

    // Both config.maxBytesPerQuery and tool.responseBytesLimit may be nullish
    // or negative, in which case no byte limit applies to the response.
    if (byteBudget <= 0) {
        return {
            cappedBy: effectiveLimit.cappedBy,
            documents: await cursor.toArray(),
        };
    }

    const collected: T[] = [];
    let bytesSoFar = 0;
    let hitLimit = false;

    // Stop as soon as the caller aborts, the cursor runs dry, or the next
    // document would reach the byte budget.
    while (!abortSignal?.aborted) {
        const candidate = await cursor.tryNext();
        if (!candidate) {
            break;
        }

        const candidateSize = calculateObjectSize(candidate);
        if (bytesSoFar + candidateSize >= byteBudget) {
            hitLimit = true;
            break;
        }

        collected.push(candidate);
        bytesSoFar += candidateSize;
    }

    return {
        cappedBy: hitLimit ? effectiveLimit.cappedBy : undefined,
        documents: collected,
    };
}

‎src/helpers/constants.ts‎

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/**
 * Upper bound for the maxTimeMS applied to the count-documents query issued
 * alongside a find.
 *
 * Kept relatively small: the count is expected to finish no later than the
 * retrieval of the document batch, so that it does not hold the final
 * response back.
 */
export const QUERY_COUNT_MAX_TIME_MS_CAP: number = 10 * 1_000;

/**
 * Upper bound for the maxTimeMS applied when counting the documents an
 * aggregation would produce.
 */
export const AGG_COUNT_MAX_TIME_MS_CAP: number = 60 * 1_000;

/** Number of bytes in one mebibyte. */
export const ONE_MB: number = 1024 * 1024;

/**
 * Maps each limit that can be applied to a cursor to the text sent to the
 * LLM to describe that limit.
 */
export const CURSOR_LIMITS_TO_LLM_TEXT = {
    "config.maxDocumentsPerQuery": "server's configured - maxDocumentsPerQuery",
    "config.maxBytesPerQuery": "server's configured - maxBytesPerQuery",
    "tool.responseBytesLimit": "tool's parameter - responseBytesLimit",
} as const;

‎src/helpers/operationWithFallback.ts‎

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
type OperationCallback<OperationResult> = () => Promise<OperationResult>;

/**
 * Runs the given asynchronous operation and resolves with its result. If the
 * operation throws or rejects, the error is discarded and the supplied
 * fallback value is returned instead.
 *
 * @param performOperation - The operation to attempt.
 * @param fallback - Value to resolve with when the operation fails.
 * @returns The operation's result on success, otherwise `fallback`.
 */
export async function operationWithFallback<OperationResult, FallbackValue>(
    performOperation: OperationCallback<OperationResult>,
    fallback: FallbackValue
): Promise<OperationResult | FallbackValue> {
    let outcome: OperationResult | FallbackValue = fallback;
    try {
        outcome = await performOperation();
    } catch {
        // Best effort by design: any failure leaves the fallback in place.
    }
    return outcome;
}

‎src/tools/mongodb/read/aggregate.ts‎

Lines changed: 135 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
11
import { z } from "zod";
2+
import type { AggregationCursor } from "mongodb";
23
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
4+
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
35
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
4-
import type { ToolArgs, OperationType } from "../../tool.js";
6+
import type { ToolArgs, OperationType,ToolExecutionContext } from "../../tool.js";
57
import { formatUntrustedData } from "../../tool.js";
68
import { checkIndexUsage } from "../../../helpers/indexCheck.js";
7-
import { EJSON } from "bson";
9+
import { type Document, EJSON } from "bson";
810
import { ErrorCodes, MongoDBError } from "../../../common/errors.js";
11+
import { collectCursorUntilMaxBytesLimit } from "../../../helpers/collectCursorUntilMaxBytes.js";
12+
import { operationWithFallback } from "../../../helpers/operationWithFallback.js";
13+
import { AGG_COUNT_MAX_TIME_MS_CAP, ONE_MB, CURSOR_LIMITS_TO_LLM_TEXT } from "../../../helpers/constants.js";
914
import { zEJSON } from "../../args.js";
15+
import { LogId } from "../../../common/logger.js";
1016

1117
/**
 * Argument shape specific to the aggregate tool: the pipeline to run and an
 * optional cap on the size of the response, defaulting to ONE_MB.
 */
export const AggregateArgs = {
    pipeline: z.array(zEJSON()).describe("An array of aggregation stages to execute"),
    responseBytesLimit: z
        .number()
        .optional()
        .default(ONE_MB)
        .describe(
            "The maximum number of bytes to return in the response. This value is capped by the server’s configured maxBytesPerQuery and cannot be exceeded. " +
                'Note to LLM: If the entire aggregation result is required, use the "export" tool instead of increasing this limit.'
        ),
};
};
1424

1525
export class AggregateTool extends MongoDBToolBase {
@@ -21,32 +31,80 @@ export class AggregateTool extends MongoDBToolBase {
2131
};
2232
public operationType: OperationType = "read";
2333

24-
protected async execute({
25-
database,
26-
collection,
27-
pipeline,
28-
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
29-
const provider = await this.ensureConnected();
34+
/**
 * Runs the aggregation and returns its documents, subject to the configured
 * maxDocumentsPerQuery (applied as a trailing $limit stage) and the derived
 * byte limit. The uncapped result count is computed in parallel so the
 * response can state whether results were truncated. The cursor is closed
 * (best effort, not awaited) whether or not the call succeeds.
 */
protected async execute(
    { database, collection, pipeline, responseBytesLimit }: ToolArgs<typeof this.argsShape>,
    { signal }: ToolExecutionContext
): Promise<CallToolResult> {
    let aggregationCursor: AggregationCursor | undefined = undefined;
    try {
        const provider = await this.ensureConnected();

        this.assertOnlyUsesPermittedStages(pipeline);

        // Check if aggregate operation uses an index if enabled
        if (this.config.indexCheck) {
            await checkIndexUsage(provider, database, collection, "aggregate", async () => {
                const explainableCursor = provider.aggregate(
                    database,
                    collection,
                    pipeline,
                    {},
                    { writeConcern: undefined }
                );
                return explainableCursor.explain("queryPlanner");
            });
        }

        // Append a $limit stage only when a positive document cap is configured.
        const cappedResultsPipeline =
            this.config.maxDocumentsPerQuery > 0
                ? [...pipeline, { $limit: this.config.maxDocumentsPerQuery }]
                : [...pipeline];
        aggregationCursor = provider.aggregate(database, collection, cappedResultsPipeline);

        // The count runs against the original (uncapped) pipeline, in
        // parallel with collecting the capped results.
        const countPromise = this.countAggregationResultDocuments({ provider, database, collection, pipeline });
        const collectPromise = collectCursorUntilMaxBytesLimit({
            cursor: aggregationCursor,
            configuredMaxBytesPerQuery: this.config.maxBytesPerQuery,
            toolResponseBytesLimit: responseBytesLimit,
            abortSignal: signal,
        });
        const [totalDocuments, cursorResults] = await Promise.all([countPromise, collectPromise]);

        const appliedLimits: (keyof typeof CURSOR_LIMITS_TO_LLM_TEXT)[] = [];

        // If the aggregation would have produced more documents than the
        // configured maxDocumentsPerQuery, the results were certainly capped.
        if (
            this.config.maxDocumentsPerQuery > 0 &&
            !!totalDocuments &&
            totalDocuments > this.config.maxDocumentsPerQuery
        ) {
            appliedLimits.push("config.maxDocumentsPerQuery");
        }
        if (cursorResults.cappedBy) {
            appliedLimits.push(cursorResults.cappedBy);
        }

        const message = this.generateMessage({
            aggResultsCount: totalDocuments,
            documents: cursorResults.documents,
            appliedLimits,
        });
        const serializedDocuments =
            cursorResults.documents.length > 0 ? EJSON.stringify(cursorResults.documents) : undefined;

        return {
            content: formatUntrustedData(message, serializedDocuments),
        };
    } finally {
        if (aggregationCursor) {
            void this.safeCloseCursor(aggregationCursor);
        }
    }
}
97+
98+
/**
 * Closes the given cursor without ever throwing: a failure to close is
 * logged as a warning instead of being propagated.
 */
private async safeCloseCursor(cursor: AggregationCursor<unknown>): Promise<void> {
    try {
        await cursor.close();
    } catch (closeError) {
        const reason = closeError instanceof Error ? closeError.message : String(closeError);
        this.session.logger.warning({
            id: LogId.mongodbCursorCloseError,
            context: "aggregate tool",
            message: `Error when closing the cursor - ${reason}`,
        });
    }
}
51109

52110
private assertOnlyUsesPermittedStages(pipeline: Record<string, unknown>[]): void {
@@ -70,4 +128,57 @@ export class AggregateTool extends MongoDBToolBase {
70128
}
71129
}
72130
}
131+
132+
/**
 * Counts how many documents the given pipeline would produce by running it
 * with a trailing $count stage, bounded by AGG_COUNT_MAX_TIME_MS_CAP.
 *
 * @returns The count; `0` when the count aggregation does not yield exactly
 * one document carrying a numeric `totalDocuments`; `undefined` when the
 * count aggregation fails.
 */
private async countAggregationResultDocuments({
    provider,
    database,
    collection,
    pipeline,
}: {
    provider: NodeDriverServiceProvider;
    database: string;
    collection: string;
    pipeline: Document[];
}): Promise<number | undefined> {
    const countingPipeline = [...pipeline, { $count: "totalDocuments" }];

    const runCount = async (): Promise<number | undefined> => {
        const countDocuments = await provider
            .aggregate(database, collection, countingPipeline)
            .maxTimeMS(AGG_COUNT_MAX_TIME_MS_CAP)
            .toArray();

        // Anything other than exactly one result document is reported as 0.
        if (countDocuments.length !== 1) {
            return 0;
        }

        const onlyDocument: unknown = countDocuments[0];
        if (!onlyDocument || typeof onlyDocument !== "object" || !("totalDocuments" in onlyDocument)) {
            return 0;
        }

        return typeof onlyDocument.totalDocuments === "number" ? onlyDocument.totalDocuments : 0;
    };

    return await operationWithFallback(runCount, undefined);
}
162+
163+
/**
 * Builds the human/LLM readable summary of an aggregation response: how many
 * documents the aggregation produced, how many are being returned and, when
 * any limits were applied, which ones.
 */
private generateMessage({
    aggResultsCount,
    documents,
    appliedLimits,
}: {
    aggResultsCount: number | undefined;
    documents: unknown[];
    appliedLimits: (keyof typeof CURSOR_LIMITS_TO_LLM_TEXT)[];
}): string {
    const countText = aggResultsCount === undefined ? "indeterminable number of" : aggResultsCount;
    const summary = `The aggregation resulted in ${countText} documents. Returning ${documents.length} documents`;

    // Without applied limits the summary simply ends with a full stop.
    if (appliedLimits.length === 0) {
        return `${summary}.`;
    }

    const limitDescriptions = appliedLimits.map((limit) => CURSOR_LIMITS_TO_LLM_TEXT[limit]).join(", ");
    return (
        `${summary} while respecting the applied limits of ${limitDescriptions}. ` +
        'Note to LLM: If the entire query result is required then use "export" tool to export the query results.'
    );
}
73184
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /