Skip to content
Merged
1 change: 1 addition & 0 deletions src/common/logger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ export const LogId = {
toolUpdateFailure: mongoLogId(1_005_001),
resourceUpdateFailure: mongoLogId(1_005_002),
updateToolMetadata: mongoLogId(1_005_003),
toolValidationError: mongoLogId(1_005_004),

streamableHttpTransportStarted: mongoLogId(1_006_001),
streamableHttpTransportSessionCloseFailure: mongoLogId(1_006_002),
Expand Down
47 changes: 6 additions & 41 deletions src/common/search/embeddingsProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import { embedMany } from "ai";
import type { UserConfig } from "../config.js";
import assert from "assert";
import { createFetch } from "@mongodb-js/devtools-proxy-support";
import { z } from "zod";
import {
type EmbeddingParameters,
type VoyageEmbeddingParameters,
type VoyageModels,
zVoyageAPIParameters,
} from "../../tools/mongodb/mongodbSchemas.js";

type EmbeddingsInput = string;
type Embeddings = number[] | unknown[];
export type EmbeddingParameters = {
inputType: "query" | "document";
};

export interface EmbeddingsProvider<
SupportedModels extends string,
Expand All @@ -23,40 +25,6 @@ export interface EmbeddingsProvider<
): Promise<Embeddings[]>;
}

export const zVoyageModels = z
.enum(["voyage-3-large", "voyage-3.5", "voyage-3.5-lite", "voyage-code-3"])
.default("voyage-3-large");

// Zod does not undestand JS boxed numbers (like Int32) as integer literals,
// so we preprocess them to unwrap them so Zod understands them.
function unboxNumber(v: unknown): number {
if (v && typeof v === "object" && typeof v.valueOf === "function") {
const n = Number(v.valueOf());
if (!Number.isNaN(n)) return n;
}
return v as number;
}

export const zVoyageEmbeddingParameters = z.object({
outputDimension: z
.preprocess(
unboxNumber,
z.union([z.literal(256), z.literal(512), z.literal(1024), z.literal(2048), z.literal(4096)])
)
.optional()
.default(1024),
outputDtype: z.enum(["float", "int8", "uint8", "binary", "ubinary"]).optional().default("float"),
});

const zVoyageAPIParameters = zVoyageEmbeddingParameters
.extend({
inputType: z.enum(["query", "document"]),
})
.strip();

type VoyageModels = z.infer<typeof zVoyageModels>;
type VoyageEmbeddingParameters = z.infer<typeof zVoyageEmbeddingParameters> & EmbeddingParameters;

class VoyageEmbeddingsProvider implements EmbeddingsProvider<VoyageModels, VoyageEmbeddingParameters> {
private readonly voyage: VoyageProvider;

Expand Down Expand Up @@ -105,6 +73,3 @@ export function getEmbeddingsProvider(

return undefined;
}

export const zSupportedEmbeddingParameters = zVoyageEmbeddingParameters.extend({ model: zVoyageModels });
export type SupportedEmbeddingParameters = z.infer<typeof zSupportedEmbeddingParameters>;
3 changes: 2 additions & 1 deletion src/common/search/vectorSearchEmbeddingsManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ import type { ConnectionManager } from "../connectionManager.js";
import z from "zod";
import { ErrorCodes, MongoDBError } from "../errors.js";
import { getEmbeddingsProvider } from "./embeddingsProvider.js";
import type { EmbeddingParameters, SupportedEmbeddingParameters } from "./embeddingsProvider.js";
import type { EmbeddingParameters } from "../../tools/mongodb/mongodbSchemas.js";
import { formatUntrustedData } from "../../tools/tool.js";
import type { Similarity } from "../schemas.js";
import type { SupportedEmbeddingParameters } from "../../tools/mongodb/mongodbSchemas.js";

export const quantizationEnum = z.enum(["none", "scalar", "binary"]);
export type Quantization = z.infer<typeof quantizationEnum>;
Expand Down
90 changes: 90 additions & 0 deletions src/helpers/assertVectorSearchFilterFieldsAreIndexed.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Based on -

import type z from "zod";
import { ErrorCodes, MongoDBError } from "../common/errors.js";
import type { VectorSearchStage } from "../tools/mongodb/mongodbSchemas.js";
import { type CompositeLogger, LogId } from "../common/logger.js";

// https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#mongodb-vector-search-pre-filter
const ALLOWED_LOGICAL_OPERATORS = ["$not", "$nor", "$and", "$or"];

export type VectorSearchIndex = {
name: string;
latestDefinition: {
fields: Array<
| {
type: "vector";
}
| {
type: "filter";
path: string;
}
>;
};
};

export function assertVectorSearchFilterFieldsAreIndexed({
searchIndexes,
pipeline,
logger,
}: {
searchIndexes: VectorSearchIndex[];
pipeline: Record<string, unknown>[];
logger: CompositeLogger;
}): void {
const searchIndexesWithFilterFields = searchIndexes.reduce<Record<string, string[]>>(
(indexFieldMap, searchIndex) => {
const filterFields = searchIndex.latestDefinition.fields
.map<string | undefined>((field) => {
return field.type === "filter" ? field.path : undefined;
})
.filter((filterField) => filterField !== undefined);

indexFieldMap[searchIndex.name] = filterFields;
return indexFieldMap;
},
{}
);
for (const stage of pipeline) {
if ("$vectorSearch" in stage) {
const { $vectorSearch: vectorSearchStage } = stage as z.infer<typeof VectorSearchStage>;
const allowedFilterFields = searchIndexesWithFilterFields[vectorSearchStage.index];
if (!allowedFilterFields) {
logger.warning({
id: LogId.toolValidationError,
context: "aggregate tool",
message: `Could not assert if filter fields are indexed - No filter fields found for index ${vectorSearchStage.index}`,
});
return;
}

const filterFieldsInStage = collectFieldsFromVectorSearchFilter(vectorSearchStage.filter);
const filterFieldsNotIndexed = filterFieldsInStage.filter((field) => !allowedFilterFields.includes(field));
if (filterFieldsNotIndexed.length) {
throw new MongoDBError(
ErrorCodes.AtlasVectorSearchInvalidQuery,
`Vector search stage contains filter on fields that are not indexed by index ${vectorSearchStage.index} - ${filterFieldsNotIndexed.join(", ")}`
);
}
}
}
}

export function collectFieldsFromVectorSearchFilter(filter: unknown): string[] {
if (!filter || typeof filter !== "object" || !Object.keys(filter).length) {
return [];
}

const collectedFields = Object.entries(filter).reduce<string[]>((collectedFields, [maybeField, fieldMQL]) => {
if (ALLOWED_LOGICAL_OPERATORS.includes(maybeField) && Array.isArray(fieldMQL)) {
return fieldMQL.flatMap((mql) => collectFieldsFromVectorSearchFilter(mql));
}

if (!ALLOWED_LOGICAL_OPERATORS.includes(maybeField)) {
collectedFields.push(maybeField);
}
return collectedFields;
}, []);

return Array.from(new Set(collectedFields));
}
2 changes: 1 addition & 1 deletion src/tools/mongodb/create/insertMany.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
import { type ToolArgs, type OperationType, formatUntrustedData } from "../../tool.js";
import { zEJSON } from "../../args.js";
import { type Document } from "bson";
import { zSupportedEmbeddingParameters } from "../../../common/search/embeddingsProvider.js";
import { zSupportedEmbeddingParameters } from "../mongodbSchemas.js";
import { ErrorCodes, MongoDBError } from "../../../common/errors.js";

const zSupportedEmbeddingParametersWithInput = zSupportedEmbeddingParameters.extend({
Expand Down
86 changes: 86 additions & 0 deletions src/tools/mongodb/mongodbSchemas.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import z from "zod";
import { zEJSON } from "../args.js";

export const zVoyageModels = z
.enum(["voyage-3-large", "voyage-3.5", "voyage-3.5-lite", "voyage-code-3"])
.default("voyage-3-large");

// Zod does not undestand JS boxed numbers (like Int32) as integer literals,
// so we preprocess them to unwrap them so Zod understands them.
function unboxNumber(v: unknown): number {
if (v && typeof v === "object" && typeof v.valueOf === "function") {
const n = Number(v.valueOf());
if (!Number.isNaN(n)) return n;
}
return v as number;
}

export const zVoyageEmbeddingParameters = z.object({
outputDimension: z
.preprocess(
unboxNumber,
z.union([z.literal(256), z.literal(512), z.literal(1024), z.literal(2048), z.literal(4096)])
)
.optional()
.default(1024),
outputDtype: z.enum(["float", "int8", "uint8", "binary", "ubinary"]).optional().default("float"),
});

export const zVoyageAPIParameters = zVoyageEmbeddingParameters
.extend({
inputType: z.enum(["query", "document"]),
})
.strip();

export type VoyageModels = z.infer<typeof zVoyageModels>;
export type VoyageEmbeddingParameters = z.infer<typeof zVoyageEmbeddingParameters> & EmbeddingParameters;

export type EmbeddingParameters = {
inputType: "query" | "document";
};

export const zSupportedEmbeddingParameters = zVoyageEmbeddingParameters.extend({ model: zVoyageModels });
export type SupportedEmbeddingParameters = z.infer<typeof zSupportedEmbeddingParameters>;

export const AnyVectorSearchStage = zEJSON();
export const VectorSearchStage = z.object({
$vectorSearch: z
.object({
exact: z
.boolean()
.optional()
.default(false)
.describe(
"When true, uses an ENN algorithm, otherwise uses ANN. Using ENN is not compatible with numCandidates, in that case, numCandidates must be left empty."
),
index: z.string().describe("Name of the index, as retrieved from the `collection-indexes` tool."),
path: z
.string()
.describe(
"Field, in dot notation, where to search. There must be a vector search index for that field. Note to LLM: When unsure, use the 'collection-indexes' tool to validate that the field is indexed with a vector search index."
),
queryVector: z
.union([z.string(), z.array(z.number())])
.describe(
"The content to search for. The embeddingParameters field is mandatory if the queryVector is a string, in that case, the tool generates the embedding automatically using the provided configuration."
),
numCandidates: z
.number()
.int()
.positive()
.optional()
.describe("Number of candidates for the ANN algorithm. Mandatory when exact is false."),
limit: z.number().int().positive().optional().default(10),
filter: zEJSON()
.optional()
.describe(
"MQL filter that can only use filter fields from the index definition. Note to LLM: If unsure, use the `collection-indexes` tool to learn which fields can be used for filtering."
),
embeddingParameters: zSupportedEmbeddingParameters
.optional()
.describe(
"The embedding model and its parameters to use to generate embeddings before searching. It is mandatory if queryVector is a string value. Note to LLM: If unsure, ask the user before providing one."
),
})
.passthrough(),
});
59 changes: 13 additions & 46 deletions src/tools/mongodb/read/aggregate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,55 +11,15 @@ import { ErrorCodes, MongoDBError } from "../../../common/errors.js";
import { collectCursorUntilMaxBytesLimit } from "../../../helpers/collectCursorUntilMaxBytes.js";
import { operationWithFallback } from "../../../helpers/operationWithFallback.js";
import { AGG_COUNT_MAX_TIME_MS_CAP, ONE_MB, CURSOR_LIMITS_TO_LLM_TEXT } from "../../../helpers/constants.js";
import { zEJSON } from "../../args.js";
import { LogId } from "../../../common/logger.js";
import { zSupportedEmbeddingParameters } from "../../../common/search/embeddingsProvider.js";

const AnyStage = zEJSON();
const VectorSearchStage = z.object({
$vectorSearch: z
.object({
exact: z
.boolean()
.optional()
.default(false)
.describe(
"When true, uses an ENN algorithm, otherwise uses ANN. Using ENN is not compatible with numCandidates, in that case, numCandidates must be left empty."
),
index: z.string().describe("Name of the index, as retrieved from the `collection-indexes` tool."),
path: z
.string()
.describe(
"Field, in dot notation, where to search. There must be a vector search index for that field. Note to LLM: When unsure, use the 'collection-indexes' tool to validate that the field is indexed with a vector search index."
),
queryVector: z
.union([z.string(), z.array(z.number())])
.describe(
"The content to search for. The embeddingParameters field is mandatory if the queryVector is a string, in that case, the tool generates the embedding automatically using the provided configuration."
),
numCandidates: z
.number()
.int()
.positive()
.optional()
.describe("Number of candidates for the ANN algorithm. Mandatory when exact is false."),
limit: z.number().int().positive().optional().default(10),
filter: zEJSON()
.optional()
.describe(
"MQL filter that can only use filter fields from the index definition. Note to LLM: If unsure, use the `collection-indexes` tool to learn which fields can be used for filtering."
),
embeddingParameters: zSupportedEmbeddingParameters
.optional()
.describe(
"The embedding model and its parameters to use to generate embeddings before searching. It is mandatory if queryVector is a string value. Note to LLM: If unsure, ask the user before providing one."
),
})
.passthrough(),
});
import { AnyVectorSearchStage, VectorSearchStage } from "../mongodbSchemas.js";
import {
assertVectorSearchFilterFieldsAreIndexed,
type VectorSearchIndex,
} from "../../../helpers/assertVectorSearchFilterFieldsAreIndexed.js";

export const AggregateArgs = {
pipeline: z.array(z.union([AnyStage, VectorSearchStage])).describe(
pipeline: z.array(z.union([AnyVectorSearchStage, VectorSearchStage])).describe(
`An array of aggregation stages to execute.
\`$vectorSearch\` **MUST** be the first stage of the pipeline, or the first stage of a \`$unionWith\` subpipeline.
### Usage Rules for \`$vectorSearch\`
Expand Down Expand Up @@ -97,6 +57,13 @@ export class AggregateTool extends MongoDBToolBase {
try {
const provider = await this.ensureConnected();
await this.assertOnlyUsesPermittedStages(pipeline);
if (await this.session.isSearchSupported()) {
assertVectorSearchFilterFieldsAreIndexed({
searchIndexes: (await provider.getSearchIndexes(database, collection)) as VectorSearchIndex[],
pipeline,
logger: this.session.logger,
});
}

// Check if aggregate operation uses an index if enabled
if (this.config.indexCheck) {
Expand Down
Loading
Loading