diff --git a/src/memory/embeddings-gemini.ts b/src/memory/embeddings-gemini.ts index ab028241ed8..5786ead73fd 100644 --- a/src/memory/embeddings-gemini.ts +++ b/src/memory/embeddings-gemini.ts @@ -5,8 +5,6 @@ import { import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js"; import { parseGeminiAuth } from "../infra/gemini-auth.js"; import type { SsrFPolicy } from "../infra/net/ssrf.js"; -import type { EmbeddingInput } from "./embedding-inputs.js"; -import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js"; import { debugEmbeddingsLog } from "./embeddings-debug.js"; import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js"; @@ -19,7 +17,6 @@ export type GeminiEmbeddingClient = { model: string; modelPath: string; apiKeys: string[]; - outputDimensionality?: number; }; const DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"; @@ -27,111 +24,6 @@ export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001"; const GEMINI_MAX_INPUT_TOKENS: Record = { "text-embedding-004": 2048, }; - -// --- gemini-embedding-2-preview support --- - -export const GEMINI_EMBEDDING_2_MODELS = new Set([ - "gemini-embedding-2-preview", - // Add the GA model name here once released. -]); - -const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072; -const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const; - -export type GeminiTaskType = - | "RETRIEVAL_QUERY" - | "RETRIEVAL_DOCUMENT" - | "SEMANTIC_SIMILARITY" - | "CLASSIFICATION" - | "CLUSTERING" - | "QUESTION_ANSWERING" - | "FACT_VERIFICATION"; - -export type GeminiTextPart = { text: string }; -export type GeminiInlinePart = { - inlineData: { mimeType: string; data: string }; -}; -export type GeminiPart = GeminiTextPart | GeminiInlinePart; -export type GeminiEmbeddingRequest = { - content: { parts: GeminiPart[] }; - taskType: GeminiTaskType; - outputDimensionality?: number; - model?: string; -}; -export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest; - -/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */ -export function buildGeminiTextEmbeddingRequest(params: { - text: string; - taskType: GeminiTaskType; - outputDimensionality?: number; - modelPath?: string; -}): GeminiTextEmbeddingRequest { - return buildGeminiEmbeddingRequest({ - input: { text: params.text }, - taskType: params.taskType, - outputDimensionality: params.outputDimensionality, - modelPath: params.modelPath, - }); -} - -export function buildGeminiEmbeddingRequest(params: { - input: EmbeddingInput; - taskType: GeminiTaskType; - outputDimensionality?: number; - modelPath?: string; -}): GeminiEmbeddingRequest { - const request: GeminiEmbeddingRequest = { - content: { - parts: params.input.parts?.map((part) => - part.type === "text" - ? ({ text: part.text } satisfies GeminiTextPart) - : ({ - inlineData: { mimeType: part.mimeType, data: part.data }, - } satisfies GeminiInlinePart), - ) ?? [{ text: params.input.text }], - }, - taskType: params.taskType, - }; - if (params.modelPath) { - request.model = params.modelPath; - } - if (params.outputDimensionality != null) { - request.outputDimensionality = params.outputDimensionality; - } - return request; -} - -/** - * Returns true if the given model name is a gemini-embedding-2 variant that - * supports `outputDimensionality` and extended task types. - */ -export function isGeminiEmbedding2Model(model: string): boolean { - return GEMINI_EMBEDDING_2_MODELS.has(model); -} - -/** - * Validate and return the `outputDimensionality` for gemini-embedding-2 models. - * Returns `undefined` for older models (they don't support the param). - */ -export function resolveGeminiOutputDimensionality( - model: string, - requested?: number, -): number | undefined { - if (!isGeminiEmbedding2Model(model)) { - return undefined; - } - if (requested == null) { - return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS; - } - const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS; - if (!valid.includes(requested)) { - throw new Error( - `Invalid outputDimensionality ${requested} for ${model}. Valid values: ${valid.join(", ")}`, - ); - } - return requested; -} function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined { const trimmed = resolveMemorySecretInputString({ value: remoteApiKey, @@ -146,7 +38,7 @@ function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined { return trimmed; } -export function normalizeGeminiModel(model: string): string { +function normalizeGeminiModel(model: string): string { const trimmed = model.trim(); if (!trimmed) { return DEFAULT_GEMINI_EMBEDDING_MODEL; @@ -161,46 +53,6 @@ export function normalizeGeminiModel(model: string): string { return withoutPrefix; } -async function fetchGeminiEmbeddingPayload(params: { - client: GeminiEmbeddingClient; - endpoint: string; - body: unknown; -}): Promise<{ - embedding?: { values?: number[] }; - embeddings?: Array<{ values?: number[] }>; -}> { - return await executeWithApiKeyRotation({ - provider: "google", - apiKeys: params.client.apiKeys, - execute: async (apiKey) => { - const authHeaders = parseGeminiAuth(apiKey); - const headers = { - ...authHeaders.headers, - ...params.client.headers, - }; - return await withRemoteHttpResponse({ - url: params.endpoint, - ssrfPolicy: params.client.ssrfPolicy, - init: { - method: "POST", - headers, - body: JSON.stringify(params.body), - }, - onResponse: async (res) => { - if (!res.ok) { - const text = await res.text(); - throw new Error(`gemini embeddings failed: ${res.status} ${text}`); - } - return (await res.json()) as { - embedding?: { values?: number[] }; - embeddings?: Array<{ values?: number[] }>; - }; - }, - }); - }, - }); -} - function normalizeGeminiBaseUrl(raw: string): string { const trimmed = raw.replace(/\/+$/, ""); const openAiIndex = trimmed.indexOf("/openai"); @@ -221,53 +73,76 @@ export async function createGeminiEmbeddingProvider( const baseUrl = client.baseUrl.replace(/\/$/, ""); const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`; const batchUrl = `${baseUrl}/${client.modelPath}:batchEmbedContents`; - const isV2 = isGeminiEmbedding2Model(client.model); - const outputDimensionality = client.outputDimensionality; + + const fetchWithGeminiAuth = async (apiKey: string, endpoint: string, body: unknown) => { + const authHeaders = parseGeminiAuth(apiKey); + const headers = { + ...authHeaders.headers, + ...client.headers, + }; + const payload = await withRemoteHttpResponse({ + url: endpoint, + ssrfPolicy: client.ssrfPolicy, + init: { + method: "POST", + headers, + body: JSON.stringify(body), + }, + onResponse: async (res) => { + if (!res.ok) { + const text = await res.text(); + throw new Error(`gemini embeddings failed: ${res.status} ${text}`); + } + return (await res.json()) as { + embedding?: { values?: number[] }; + embeddings?: Array<{ values?: number[] }>; + }; + }, + }); + return payload; + }; const embedQuery = async (text: string): Promise => { if (!text.trim()) { return []; } - const payload = await fetchGeminiEmbeddingPayload({ - client, - endpoint: embedUrl, - body: buildGeminiTextEmbeddingRequest({ - text, - taskType: options.taskType ?? "RETRIEVAL_QUERY", - outputDimensionality: isV2 ? outputDimensionality : undefined, - }), + const payload = await executeWithApiKeyRotation({ + provider: "google", + apiKeys: client.apiKeys, + execute: (apiKey) => + fetchWithGeminiAuth(apiKey, embedUrl, { + content: { parts: [{ text }] }, + taskType: "RETRIEVAL_QUERY", + ...(typeof options.outputDimensionality === "number" + ? { outputDimensionality: options.outputDimensionality } + : {}), + }), }); - return sanitizeAndNormalizeEmbedding(payload.embedding?.values ?? []); - }; - - const embedBatchInputs = async (inputs: EmbeddingInput[]): Promise => { - if (inputs.length === 0) { - return []; - } - const payload = await fetchGeminiEmbeddingPayload({ - client, - endpoint: batchUrl, - body: { - requests: inputs.map((input) => - buildGeminiEmbeddingRequest({ - input, - modelPath: client.modelPath, - taskType: options.taskType ?? "RETRIEVAL_DOCUMENT", - outputDimensionality: isV2 ? outputDimensionality : undefined, - }), - ), - }, - }); - const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : []; - return inputs.map((_, index) => sanitizeAndNormalizeEmbedding(embeddings[index]?.values ?? [])); + return payload.embedding?.values ?? []; }; const embedBatch = async (texts: string[]): Promise => { - return await embedBatchInputs( - texts.map((text) => ({ - text, - })), - ); + if (texts.length === 0) { + return []; + } + const requests = texts.map((text) => ({ + model: client.modelPath, + content: { parts: [{ text }] }, + taskType: "RETRIEVAL_DOCUMENT", + ...(typeof options.outputDimensionality === "number" + ? { outputDimensionality: options.outputDimensionality } + : {}), + })); + const payload = await executeWithApiKeyRotation({ + provider: "google", + apiKeys: client.apiKeys, + execute: (apiKey) => + fetchWithGeminiAuth(apiKey, batchUrl, { + requests, + }), + }); + const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : []; + return texts.map((_, index) => embeddings[index]?.values ?? []); }; return { @@ -277,7 +152,6 @@ export async function createGeminiEmbeddingProvider( maxInputTokens: GEMINI_MAX_INPUT_TOKENS[client.model], embedQuery, embedBatch, - embedBatchInputs, }, client, }; @@ -315,18 +189,13 @@ export async function resolveGeminiEmbeddingClient( }); const model = normalizeGeminiModel(options.model); const modelPath = buildGeminiModelPath(model); - const outputDimensionality = resolveGeminiOutputDimensionality( - model, - options.outputDimensionality, - ); debugEmbeddingsLog("memory embeddings: gemini client", { rawBaseUrl, baseUrl, model, modelPath, - outputDimensionality, embedEndpoint: `${baseUrl}/${modelPath}:embedContent`, batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`, }); - return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys, outputDimensionality }; + return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys }; }