openclaw/src/memory/manager.mistral-provider.test.ts
nico-hoff 3eec79bd6c
feat(memory): add Ollama embedding provider (#26349)
Merged via /review-pr -> /prepare-pr -> /merge-pr.

Prepared head SHA: ac413865431614c352c3b29f2dfccc5593f0605a
Co-authored-by: nico-hoff <43175972+nico-hoff@users.noreply.github.com>
Co-authored-by: gumadeiras <5599352+gumadeiras@users.noreply.github.com>
Reviewed-by: @gumadeiras
2026-03-02 20:56:40 -05:00

197 lines
7.0 KiB
TypeScript

import fs from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import type { OpenClawConfig } from "../config/config.js";
import { DEFAULT_OLLAMA_EMBEDDING_MODEL } from "./embeddings-ollama.js";
import type {
EmbeddingProvider,
EmbeddingProviderResult,
MistralEmbeddingClient,
OllamaEmbeddingClient,
OpenAiEmbeddingClient,
} from "./embeddings.js";
import { getMemorySearchManager, type MemoryIndexManager } from "./index.js";
const { createEmbeddingProviderMock } = vi.hoisted(() => ({
createEmbeddingProviderMock: vi.fn(),
}));
vi.mock("./embeddings.js", () => ({
createEmbeddingProvider: createEmbeddingProviderMock,
}));
vi.mock("./sqlite-vec.js", () => ({
loadSqliteVecExtension: async () => ({ ok: false, error: "sqlite-vec disabled in tests" }),
}));
function createProvider(id: string): EmbeddingProvider {
return {
id,
model: `${id}-model`,
embedQuery: async () => [0.1, 0.2, 0.3],
embedBatch: async (texts: string[]) => texts.map(() => [0.1, 0.2, 0.3]),
};
}
function buildConfig(params: {
workspaceDir: string;
indexPath: string;
provider: "openai" | "mistral";
fallback?: "none" | "mistral" | "ollama";
}): OpenClawConfig {
return {
agents: {
defaults: {
workspace: params.workspaceDir,
memorySearch: {
provider: params.provider,
model: params.provider === "mistral" ? "mistral/mistral-embed" : "text-embedding-3-small",
fallback: params.fallback ?? "none",
store: { path: params.indexPath, vector: { enabled: false } },
sync: { watch: false, onSessionStart: false, onSearch: false },
query: { minScore: 0, hybrid: { enabled: false } },
},
},
list: [{ id: "main", default: true }],
},
} as OpenClawConfig;
}
describe("memory manager mistral provider wiring", () => {
let workspaceDir = "";
let indexPath = "";
let manager: MemoryIndexManager | null = null;
beforeEach(async () => {
createEmbeddingProviderMock.mockReset();
workspaceDir = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-memory-mistral-"));
indexPath = path.join(workspaceDir, "index.sqlite");
await fs.mkdir(path.join(workspaceDir, "memory"), { recursive: true });
await fs.writeFile(path.join(workspaceDir, "MEMORY.md"), "test");
});
afterEach(async () => {
if (manager) {
await manager.close();
manager = null;
}
if (workspaceDir) {
await fs.rm(workspaceDir, { recursive: true, force: true });
workspaceDir = "";
indexPath = "";
}
});
it("stores mistral client when mistral provider is selected", async () => {
const mistralClient: MistralEmbeddingClient = {
baseUrl: "https://api.mistral.ai/v1",
headers: { authorization: "Bearer test-key" },
model: "mistral-embed",
};
const providerResult: EmbeddingProviderResult = {
requestedProvider: "mistral",
provider: createProvider("mistral"),
mistral: mistralClient,
};
createEmbeddingProviderMock.mockResolvedValueOnce(providerResult);
const cfg = buildConfig({ workspaceDir, indexPath, provider: "mistral" });
const result = await getMemorySearchManager({ cfg, agentId: "main" });
if (!result.manager) {
throw new Error(`manager missing: ${result.error ?? "no error provided"}`);
}
manager = result.manager as unknown as MemoryIndexManager;
const internal = manager as unknown as { mistral?: MistralEmbeddingClient };
expect(internal.mistral).toBe(mistralClient);
});
it("stores mistral client after fallback activation", async () => {
const openAiClient: OpenAiEmbeddingClient = {
baseUrl: "https://api.openai.com/v1",
headers: { authorization: "Bearer openai-key" },
model: "text-embedding-3-small",
};
const mistralClient: MistralEmbeddingClient = {
baseUrl: "https://api.mistral.ai/v1",
headers: { authorization: "Bearer mistral-key" },
model: "mistral-embed",
};
createEmbeddingProviderMock.mockResolvedValueOnce({
requestedProvider: "openai",
provider: createProvider("openai"),
openAi: openAiClient,
} as EmbeddingProviderResult);
createEmbeddingProviderMock.mockResolvedValueOnce({
requestedProvider: "mistral",
provider: createProvider("mistral"),
mistral: mistralClient,
} as EmbeddingProviderResult);
const cfg = buildConfig({ workspaceDir, indexPath, provider: "openai", fallback: "mistral" });
const result = await getMemorySearchManager({ cfg, agentId: "main" });
if (!result.manager) {
throw new Error(`manager missing: ${result.error ?? "no error provided"}`);
}
manager = result.manager as unknown as MemoryIndexManager;
const internal = manager as unknown as {
activateFallbackProvider: (reason: string) => Promise<boolean>;
openAi?: OpenAiEmbeddingClient;
mistral?: MistralEmbeddingClient;
};
const activated = await internal.activateFallbackProvider("forced test");
expect(activated).toBe(true);
expect(internal.openAi).toBeUndefined();
expect(internal.mistral).toBe(mistralClient);
});
it("uses default ollama model when activating ollama fallback", async () => {
const openAiClient: OpenAiEmbeddingClient = {
baseUrl: "https://api.openai.com/v1",
headers: { authorization: "Bearer openai-key" },
model: "text-embedding-3-small",
};
const ollamaClient: OllamaEmbeddingClient = {
baseUrl: "http://127.0.0.1:11434",
headers: {},
model: DEFAULT_OLLAMA_EMBEDDING_MODEL,
embedBatch: async (texts: string[]) => texts.map(() => [0.1, 0.2, 0.3]),
};
createEmbeddingProviderMock.mockResolvedValueOnce({
requestedProvider: "openai",
provider: createProvider("openai"),
openAi: openAiClient,
} as EmbeddingProviderResult);
createEmbeddingProviderMock.mockResolvedValueOnce({
requestedProvider: "ollama",
provider: createProvider("ollama"),
ollama: ollamaClient,
} as EmbeddingProviderResult);
const cfg = buildConfig({ workspaceDir, indexPath, provider: "openai", fallback: "ollama" });
const result = await getMemorySearchManager({ cfg, agentId: "main" });
if (!result.manager) {
throw new Error(`manager missing: ${result.error ?? "no error provided"}`);
}
manager = result.manager as unknown as MemoryIndexManager;
const internal = manager as unknown as {
activateFallbackProvider: (reason: string) => Promise<boolean>;
openAi?: OpenAiEmbeddingClient;
ollama?: OllamaEmbeddingClient;
};
const activated = await internal.activateFallbackProvider("forced ollama fallback");
expect(activated).toBe(true);
expect(internal.openAi).toBeUndefined();
expect(internal.ollama).toBe(ollamaClient);
const fallbackCall = createEmbeddingProviderMock.mock.calls[1]?.[0] as
| { provider?: string; model?: string }
| undefined;
expect(fallbackCall?.provider).toBe("ollama");
expect(fallbackCall?.model).toBe(DEFAULT_OLLAMA_EMBEDDING_MODEL);
});
});