fix: restore memory search output dimensionality

This commit is contained in:
Marc J Saint-jour 2026-03-12 20:04:16 -04:00
parent c2421ec120
commit 9fe29139a0

View File

@ -5,8 +5,6 @@ import {
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
import { parseGeminiAuth } from "../infra/gemini-auth.js";
import type { SsrFPolicy } from "../infra/net/ssrf.js";
import type { EmbeddingInput } from "./embedding-inputs.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
@ -19,7 +17,6 @@ export type GeminiEmbeddingClient = {
model: string;
modelPath: string;
apiKeys: string[];
outputDimensionality?: number;
};
const DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta";
@ -27,111 +24,6 @@ export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";
const GEMINI_MAX_INPUT_TOKENS: Record<string, number> = {
"text-embedding-004": 2048,
};
// --- gemini-embedding-2-preview support ---
export const GEMINI_EMBEDDING_2_MODELS = new Set([
"gemini-embedding-2-preview",
// Add the GA model name here once released.
]);
const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072;
const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const;
export type GeminiTaskType =
| "RETRIEVAL_QUERY"
| "RETRIEVAL_DOCUMENT"
| "SEMANTIC_SIMILARITY"
| "CLASSIFICATION"
| "CLUSTERING"
| "QUESTION_ANSWERING"
| "FACT_VERIFICATION";
export type GeminiTextPart = { text: string };
export type GeminiInlinePart = {
inlineData: { mimeType: string; data: string };
};
export type GeminiPart = GeminiTextPart | GeminiInlinePart;
export type GeminiEmbeddingRequest = {
content: { parts: GeminiPart[] };
taskType: GeminiTaskType;
outputDimensionality?: number;
model?: string;
};
export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest;
/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */
export function buildGeminiTextEmbeddingRequest(params: {
text: string;
taskType: GeminiTaskType;
outputDimensionality?: number;
modelPath?: string;
}): GeminiTextEmbeddingRequest {
return buildGeminiEmbeddingRequest({
input: { text: params.text },
taskType: params.taskType,
outputDimensionality: params.outputDimensionality,
modelPath: params.modelPath,
});
}
export function buildGeminiEmbeddingRequest(params: {
input: EmbeddingInput;
taskType: GeminiTaskType;
outputDimensionality?: number;
modelPath?: string;
}): GeminiEmbeddingRequest {
const request: GeminiEmbeddingRequest = {
content: {
parts: params.input.parts?.map((part) =>
part.type === "text"
? ({ text: part.text } satisfies GeminiTextPart)
: ({
inlineData: { mimeType: part.mimeType, data: part.data },
} satisfies GeminiInlinePart),
) ?? [{ text: params.input.text }],
},
taskType: params.taskType,
};
if (params.modelPath) {
request.model = params.modelPath;
}
if (params.outputDimensionality != null) {
request.outputDimensionality = params.outputDimensionality;
}
return request;
}
/**
* Returns true if the given model name is a gemini-embedding-2 variant that
* supports `outputDimensionality` and extended task types.
*/
export function isGeminiEmbedding2Model(model: string): boolean {
return GEMINI_EMBEDDING_2_MODELS.has(model);
}
/**
* Validate and return the `outputDimensionality` for gemini-embedding-2 models.
* Returns `undefined` for older models (they don't support the param).
*/
export function resolveGeminiOutputDimensionality(
model: string,
requested?: number,
): number | undefined {
if (!isGeminiEmbedding2Model(model)) {
return undefined;
}
if (requested == null) {
return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS;
}
const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS;
if (!valid.includes(requested)) {
throw new Error(
`Invalid outputDimensionality ${requested} for ${model}. Valid values: ${valid.join(", ")}`,
);
}
return requested;
}
function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
const trimmed = resolveMemorySecretInputString({
value: remoteApiKey,
@ -146,7 +38,7 @@ function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
return trimmed;
}
export function normalizeGeminiModel(model: string): string {
function normalizeGeminiModel(model: string): string {
const trimmed = model.trim();
if (!trimmed) {
return DEFAULT_GEMINI_EMBEDDING_MODEL;
@ -161,46 +53,6 @@ export function normalizeGeminiModel(model: string): string {
return withoutPrefix;
}
async function fetchGeminiEmbeddingPayload(params: {
client: GeminiEmbeddingClient;
endpoint: string;
body: unknown;
}): Promise<{
embedding?: { values?: number[] };
embeddings?: Array<{ values?: number[] }>;
}> {
return await executeWithApiKeyRotation({
provider: "google",
apiKeys: params.client.apiKeys,
execute: async (apiKey) => {
const authHeaders = parseGeminiAuth(apiKey);
const headers = {
...authHeaders.headers,
...params.client.headers,
};
return await withRemoteHttpResponse({
url: params.endpoint,
ssrfPolicy: params.client.ssrfPolicy,
init: {
method: "POST",
headers,
body: JSON.stringify(params.body),
},
onResponse: async (res) => {
if (!res.ok) {
const text = await res.text();
throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
}
return (await res.json()) as {
embedding?: { values?: number[] };
embeddings?: Array<{ values?: number[] }>;
};
},
});
},
});
}
function normalizeGeminiBaseUrl(raw: string): string {
const trimmed = raw.replace(/\/+$/, "");
const openAiIndex = trimmed.indexOf("/openai");
@ -221,53 +73,76 @@ export async function createGeminiEmbeddingProvider(
const baseUrl = client.baseUrl.replace(/\/$/, "");
const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`;
const batchUrl = `${baseUrl}/${client.modelPath}:batchEmbedContents`;
const isV2 = isGeminiEmbedding2Model(client.model);
const outputDimensionality = client.outputDimensionality;
const fetchWithGeminiAuth = async (apiKey: string, endpoint: string, body: unknown) => {
const authHeaders = parseGeminiAuth(apiKey);
const headers = {
...authHeaders.headers,
...client.headers,
};
const payload = await withRemoteHttpResponse({
url: endpoint,
ssrfPolicy: client.ssrfPolicy,
init: {
method: "POST",
headers,
body: JSON.stringify(body),
},
onResponse: async (res) => {
if (!res.ok) {
const text = await res.text();
throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
}
return (await res.json()) as {
embedding?: { values?: number[] };
embeddings?: Array<{ values?: number[] }>;
};
},
});
return payload;
};
const embedQuery = async (text: string): Promise<number[]> => {
if (!text.trim()) {
return [];
}
const payload = await fetchGeminiEmbeddingPayload({
client,
endpoint: embedUrl,
body: buildGeminiTextEmbeddingRequest({
text,
taskType: options.taskType ?? "RETRIEVAL_QUERY",
outputDimensionality: isV2 ? outputDimensionality : undefined,
}),
const payload = await executeWithApiKeyRotation({
provider: "google",
apiKeys: client.apiKeys,
execute: (apiKey) =>
fetchWithGeminiAuth(apiKey, embedUrl, {
content: { parts: [{ text }] },
taskType: "RETRIEVAL_QUERY",
...(typeof options.outputDimensionality === "number"
? { outputDimensionality: options.outputDimensionality }
: {}),
}),
});
return sanitizeAndNormalizeEmbedding(payload.embedding?.values ?? []);
};
const embedBatchInputs = async (inputs: EmbeddingInput[]): Promise<number[][]> => {
if (inputs.length === 0) {
return [];
}
const payload = await fetchGeminiEmbeddingPayload({
client,
endpoint: batchUrl,
body: {
requests: inputs.map((input) =>
buildGeminiEmbeddingRequest({
input,
modelPath: client.modelPath,
taskType: options.taskType ?? "RETRIEVAL_DOCUMENT",
outputDimensionality: isV2 ? outputDimensionality : undefined,
}),
),
},
});
const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : [];
return inputs.map((_, index) => sanitizeAndNormalizeEmbedding(embeddings[index]?.values ?? []));
return payload.embedding?.values ?? [];
};
const embedBatch = async (texts: string[]): Promise<number[][]> => {
return await embedBatchInputs(
texts.map((text) => ({
text,
})),
);
if (texts.length === 0) {
return [];
}
const requests = texts.map((text) => ({
model: client.modelPath,
content: { parts: [{ text }] },
taskType: "RETRIEVAL_DOCUMENT",
...(typeof options.outputDimensionality === "number"
? { outputDimensionality: options.outputDimensionality }
: {}),
}));
const payload = await executeWithApiKeyRotation({
provider: "google",
apiKeys: client.apiKeys,
execute: (apiKey) =>
fetchWithGeminiAuth(apiKey, batchUrl, {
requests,
}),
});
const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : [];
return texts.map((_, index) => embeddings[index]?.values ?? []);
};
return {
@ -277,7 +152,6 @@ export async function createGeminiEmbeddingProvider(
maxInputTokens: GEMINI_MAX_INPUT_TOKENS[client.model],
embedQuery,
embedBatch,
embedBatchInputs,
},
client,
};
@ -315,18 +189,13 @@ export async function resolveGeminiEmbeddingClient(
});
const model = normalizeGeminiModel(options.model);
const modelPath = buildGeminiModelPath(model);
const outputDimensionality = resolveGeminiOutputDimensionality(
model,
options.outputDimensionality,
);
debugEmbeddingsLog("memory embeddings: gemini client", {
rawBaseUrl,
baseUrl,
model,
modelPath,
outputDimensionality,
embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
});
return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys, outputDimensionality };
return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys };
}