fix: restore memory search output dimensionality
This commit is contained in:
parent
c2421ec120
commit
9fe29139a0
@ -5,8 +5,6 @@ import {
|
||||
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
|
||||
import { parseGeminiAuth } from "../infra/gemini-auth.js";
|
||||
import type { SsrFPolicy } from "../infra/net/ssrf.js";
|
||||
import type { EmbeddingInput } from "./embedding-inputs.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import { debugEmbeddingsLog } from "./embeddings-debug.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
|
||||
@ -19,7 +17,6 @@ export type GeminiEmbeddingClient = {
|
||||
model: string;
|
||||
modelPath: string;
|
||||
apiKeys: string[];
|
||||
outputDimensionality?: number;
|
||||
};
|
||||
|
||||
const DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta";
|
||||
@ -27,111 +24,6 @@ export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";
|
||||
const GEMINI_MAX_INPUT_TOKENS: Record<string, number> = {
|
||||
"text-embedding-004": 2048,
|
||||
};
|
||||
|
||||
// --- gemini-embedding-2-preview support ---
|
||||
|
||||
export const GEMINI_EMBEDDING_2_MODELS = new Set([
|
||||
"gemini-embedding-2-preview",
|
||||
// Add the GA model name here once released.
|
||||
]);
|
||||
|
||||
const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072;
|
||||
const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const;
|
||||
|
||||
export type GeminiTaskType =
|
||||
| "RETRIEVAL_QUERY"
|
||||
| "RETRIEVAL_DOCUMENT"
|
||||
| "SEMANTIC_SIMILARITY"
|
||||
| "CLASSIFICATION"
|
||||
| "CLUSTERING"
|
||||
| "QUESTION_ANSWERING"
|
||||
| "FACT_VERIFICATION";
|
||||
|
||||
export type GeminiTextPart = { text: string };
|
||||
export type GeminiInlinePart = {
|
||||
inlineData: { mimeType: string; data: string };
|
||||
};
|
||||
export type GeminiPart = GeminiTextPart | GeminiInlinePart;
|
||||
export type GeminiEmbeddingRequest = {
|
||||
content: { parts: GeminiPart[] };
|
||||
taskType: GeminiTaskType;
|
||||
outputDimensionality?: number;
|
||||
model?: string;
|
||||
};
|
||||
export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest;
|
||||
|
||||
/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */
|
||||
export function buildGeminiTextEmbeddingRequest(params: {
|
||||
text: string;
|
||||
taskType: GeminiTaskType;
|
||||
outputDimensionality?: number;
|
||||
modelPath?: string;
|
||||
}): GeminiTextEmbeddingRequest {
|
||||
return buildGeminiEmbeddingRequest({
|
||||
input: { text: params.text },
|
||||
taskType: params.taskType,
|
||||
outputDimensionality: params.outputDimensionality,
|
||||
modelPath: params.modelPath,
|
||||
});
|
||||
}
|
||||
|
||||
export function buildGeminiEmbeddingRequest(params: {
|
||||
input: EmbeddingInput;
|
||||
taskType: GeminiTaskType;
|
||||
outputDimensionality?: number;
|
||||
modelPath?: string;
|
||||
}): GeminiEmbeddingRequest {
|
||||
const request: GeminiEmbeddingRequest = {
|
||||
content: {
|
||||
parts: params.input.parts?.map((part) =>
|
||||
part.type === "text"
|
||||
? ({ text: part.text } satisfies GeminiTextPart)
|
||||
: ({
|
||||
inlineData: { mimeType: part.mimeType, data: part.data },
|
||||
} satisfies GeminiInlinePart),
|
||||
) ?? [{ text: params.input.text }],
|
||||
},
|
||||
taskType: params.taskType,
|
||||
};
|
||||
if (params.modelPath) {
|
||||
request.model = params.modelPath;
|
||||
}
|
||||
if (params.outputDimensionality != null) {
|
||||
request.outputDimensionality = params.outputDimensionality;
|
||||
}
|
||||
return request;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the given model name is a gemini-embedding-2 variant that
|
||||
* supports `outputDimensionality` and extended task types.
|
||||
*/
|
||||
export function isGeminiEmbedding2Model(model: string): boolean {
|
||||
return GEMINI_EMBEDDING_2_MODELS.has(model);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate and return the `outputDimensionality` for gemini-embedding-2 models.
|
||||
* Returns `undefined` for older models (they don't support the param).
|
||||
*/
|
||||
export function resolveGeminiOutputDimensionality(
|
||||
model: string,
|
||||
requested?: number,
|
||||
): number | undefined {
|
||||
if (!isGeminiEmbedding2Model(model)) {
|
||||
return undefined;
|
||||
}
|
||||
if (requested == null) {
|
||||
return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS;
|
||||
}
|
||||
const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS;
|
||||
if (!valid.includes(requested)) {
|
||||
throw new Error(
|
||||
`Invalid outputDimensionality ${requested} for ${model}. Valid values: ${valid.join(", ")}`,
|
||||
);
|
||||
}
|
||||
return requested;
|
||||
}
|
||||
function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
|
||||
const trimmed = resolveMemorySecretInputString({
|
||||
value: remoteApiKey,
|
||||
@ -146,7 +38,7 @@ function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
|
||||
return trimmed;
|
||||
}
|
||||
|
||||
export function normalizeGeminiModel(model: string): string {
|
||||
function normalizeGeminiModel(model: string): string {
|
||||
const trimmed = model.trim();
|
||||
if (!trimmed) {
|
||||
return DEFAULT_GEMINI_EMBEDDING_MODEL;
|
||||
@ -161,46 +53,6 @@ export function normalizeGeminiModel(model: string): string {
|
||||
return withoutPrefix;
|
||||
}
|
||||
|
||||
async function fetchGeminiEmbeddingPayload(params: {
|
||||
client: GeminiEmbeddingClient;
|
||||
endpoint: string;
|
||||
body: unknown;
|
||||
}): Promise<{
|
||||
embedding?: { values?: number[] };
|
||||
embeddings?: Array<{ values?: number[] }>;
|
||||
}> {
|
||||
return await executeWithApiKeyRotation({
|
||||
provider: "google",
|
||||
apiKeys: params.client.apiKeys,
|
||||
execute: async (apiKey) => {
|
||||
const authHeaders = parseGeminiAuth(apiKey);
|
||||
const headers = {
|
||||
...authHeaders.headers,
|
||||
...params.client.headers,
|
||||
};
|
||||
return await withRemoteHttpResponse({
|
||||
url: params.endpoint,
|
||||
ssrfPolicy: params.client.ssrfPolicy,
|
||||
init: {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify(params.body),
|
||||
},
|
||||
onResponse: async (res) => {
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
|
||||
}
|
||||
return (await res.json()) as {
|
||||
embedding?: { values?: number[] };
|
||||
embeddings?: Array<{ values?: number[] }>;
|
||||
};
|
||||
},
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function normalizeGeminiBaseUrl(raw: string): string {
|
||||
const trimmed = raw.replace(/\/+$/, "");
|
||||
const openAiIndex = trimmed.indexOf("/openai");
|
||||
@ -221,53 +73,76 @@ export async function createGeminiEmbeddingProvider(
|
||||
const baseUrl = client.baseUrl.replace(/\/$/, "");
|
||||
const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`;
|
||||
const batchUrl = `${baseUrl}/${client.modelPath}:batchEmbedContents`;
|
||||
const isV2 = isGeminiEmbedding2Model(client.model);
|
||||
const outputDimensionality = client.outputDimensionality;
|
||||
|
||||
const fetchWithGeminiAuth = async (apiKey: string, endpoint: string, body: unknown) => {
|
||||
const authHeaders = parseGeminiAuth(apiKey);
|
||||
const headers = {
|
||||
...authHeaders.headers,
|
||||
...client.headers,
|
||||
};
|
||||
const payload = await withRemoteHttpResponse({
|
||||
url: endpoint,
|
||||
ssrfPolicy: client.ssrfPolicy,
|
||||
init: {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify(body),
|
||||
},
|
||||
onResponse: async (res) => {
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
|
||||
}
|
||||
return (await res.json()) as {
|
||||
embedding?: { values?: number[] };
|
||||
embeddings?: Array<{ values?: number[] }>;
|
||||
};
|
||||
},
|
||||
});
|
||||
return payload;
|
||||
};
|
||||
|
||||
const embedQuery = async (text: string): Promise<number[]> => {
|
||||
if (!text.trim()) {
|
||||
return [];
|
||||
}
|
||||
const payload = await fetchGeminiEmbeddingPayload({
|
||||
client,
|
||||
endpoint: embedUrl,
|
||||
body: buildGeminiTextEmbeddingRequest({
|
||||
text,
|
||||
taskType: options.taskType ?? "RETRIEVAL_QUERY",
|
||||
outputDimensionality: isV2 ? outputDimensionality : undefined,
|
||||
}),
|
||||
const payload = await executeWithApiKeyRotation({
|
||||
provider: "google",
|
||||
apiKeys: client.apiKeys,
|
||||
execute: (apiKey) =>
|
||||
fetchWithGeminiAuth(apiKey, embedUrl, {
|
||||
content: { parts: [{ text }] },
|
||||
taskType: "RETRIEVAL_QUERY",
|
||||
...(typeof options.outputDimensionality === "number"
|
||||
? { outputDimensionality: options.outputDimensionality }
|
||||
: {}),
|
||||
}),
|
||||
});
|
||||
return sanitizeAndNormalizeEmbedding(payload.embedding?.values ?? []);
|
||||
};
|
||||
|
||||
const embedBatchInputs = async (inputs: EmbeddingInput[]): Promise<number[][]> => {
|
||||
if (inputs.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const payload = await fetchGeminiEmbeddingPayload({
|
||||
client,
|
||||
endpoint: batchUrl,
|
||||
body: {
|
||||
requests: inputs.map((input) =>
|
||||
buildGeminiEmbeddingRequest({
|
||||
input,
|
||||
modelPath: client.modelPath,
|
||||
taskType: options.taskType ?? "RETRIEVAL_DOCUMENT",
|
||||
outputDimensionality: isV2 ? outputDimensionality : undefined,
|
||||
}),
|
||||
),
|
||||
},
|
||||
});
|
||||
const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : [];
|
||||
return inputs.map((_, index) => sanitizeAndNormalizeEmbedding(embeddings[index]?.values ?? []));
|
||||
return payload.embedding?.values ?? [];
|
||||
};
|
||||
|
||||
const embedBatch = async (texts: string[]): Promise<number[][]> => {
|
||||
return await embedBatchInputs(
|
||||
texts.map((text) => ({
|
||||
text,
|
||||
})),
|
||||
);
|
||||
if (texts.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const requests = texts.map((text) => ({
|
||||
model: client.modelPath,
|
||||
content: { parts: [{ text }] },
|
||||
taskType: "RETRIEVAL_DOCUMENT",
|
||||
...(typeof options.outputDimensionality === "number"
|
||||
? { outputDimensionality: options.outputDimensionality }
|
||||
: {}),
|
||||
}));
|
||||
const payload = await executeWithApiKeyRotation({
|
||||
provider: "google",
|
||||
apiKeys: client.apiKeys,
|
||||
execute: (apiKey) =>
|
||||
fetchWithGeminiAuth(apiKey, batchUrl, {
|
||||
requests,
|
||||
}),
|
||||
});
|
||||
const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : [];
|
||||
return texts.map((_, index) => embeddings[index]?.values ?? []);
|
||||
};
|
||||
|
||||
return {
|
||||
@ -277,7 +152,6 @@ export async function createGeminiEmbeddingProvider(
|
||||
maxInputTokens: GEMINI_MAX_INPUT_TOKENS[client.model],
|
||||
embedQuery,
|
||||
embedBatch,
|
||||
embedBatchInputs,
|
||||
},
|
||||
client,
|
||||
};
|
||||
@ -315,18 +189,13 @@ export async function resolveGeminiEmbeddingClient(
|
||||
});
|
||||
const model = normalizeGeminiModel(options.model);
|
||||
const modelPath = buildGeminiModelPath(model);
|
||||
const outputDimensionality = resolveGeminiOutputDimensionality(
|
||||
model,
|
||||
options.outputDimensionality,
|
||||
);
|
||||
debugEmbeddingsLog("memory embeddings: gemini client", {
|
||||
rawBaseUrl,
|
||||
baseUrl,
|
||||
model,
|
||||
modelPath,
|
||||
outputDimensionality,
|
||||
embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
|
||||
batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
|
||||
});
|
||||
return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys, outputDimensionality };
|
||||
return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys };
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user