openclaw/src/memory/manager.mistral-provider.test.ts
2026-03-18 15:36:32 +00:00

208 lines
7.4 KiB
TypeScript

import fs from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import type { OpenClawConfig } from "../config/config.js";
import { DEFAULT_OLLAMA_EMBEDDING_MODEL } from "./embeddings-ollama.js";
import type {
EmbeddingProvider,
EmbeddingProviderResult,
MistralEmbeddingClient,
OllamaEmbeddingClient,
OpenAiEmbeddingClient,
} from "./embeddings.js";
import type { MemoryIndexManager } from "./index.js";
// `vi.hoisted` is evaluated before the hoisted `vi.mock` factories below, so this
// mock function already exists when the "./embeddings.js" factory references it.
const { createEmbeddingProviderMock } = vi.hoisted(() => ({
createEmbeddingProviderMock: vi.fn(),
}));
// Replace the embeddings module: each test queues the provider result it wants via
// `createEmbeddingProviderMock.mockResolvedValueOnce(...)`. Only `createEmbeddingProvider`
// is provided here; the `import type` names from this module are erased at compile time.
vi.mock("./embeddings.js", () => ({
createEmbeddingProvider: createEmbeddingProviderMock,
}));
// Stub the sqlite-vec loader so it always reports failure ("disabled in tests").
// NOTE(review): presumably this keeps the manager from loading the native vector
// extension; buildConfig also sets `store.vector.enabled: false`.
vi.mock("./sqlite-vec.js", () => ({
loadSqliteVecExtension: async () => ({ ok: false, error: "sqlite-vec disabled in tests" }),
}));
/** Module type of "./index.js"; used to type the exports bound lazily below. */
type MemoryIndexModule = typeof import("./index.js");
// Both bindings are assigned once in the suite's `beforeAll` via `await import("./index.js")`.
let closeAllMemorySearchManagers: MemoryIndexModule["closeAllMemorySearchManagers"];
let getMemorySearchManager: MemoryIndexModule["getMemorySearchManager"];
/**
 * Builds a stub embedding provider for tests.
 *
 * @param id - Provider id; the stub's model name is derived as `<id>-model`.
 * @returns A provider whose query embedding and every batch entry is a fresh
 *   copy of the fixed vector `[0.1, 0.2, 0.3]`.
 */
function createProvider(id: string): EmbeddingProvider {
  const fixedVector = (): number[] => [0.1, 0.2, 0.3];
  const stub: EmbeddingProvider = {
    id,
    model: `${id}-model`,
    embedQuery: async () => fixedVector(),
    embedBatch: async (texts: string[]) => texts.map(() => fixedVector()),
  };
  return stub;
}
/**
 * Builds a minimal config for a single default "main" agent, with memory search
 * pointed at `params.provider`.
 *
 * The model is "mistral/mistral-embed" for mistral and "text-embedding-3-small"
 * otherwise. The fallback defaults to "none". Vector storage, file watching,
 * all sync triggers and hybrid queries are disabled, and `minScore` is 0.
 */
function buildConfig(params: {
  workspaceDir: string;
  indexPath: string;
  provider: "openai" | "mistral";
  fallback?: "none" | "mistral" | "ollama";
}): OpenClawConfig {
  const { workspaceDir, indexPath, provider } = params;
  const fallback = params.fallback ?? "none";
  const model = provider === "mistral" ? "mistral/mistral-embed" : "text-embedding-3-small";

  // Kept inline so the object literal is contextually typed by the assertion.
  return {
    agents: {
      defaults: {
        workspace: workspaceDir,
        memorySearch: {
          provider,
          model,
          fallback,
          store: { path: indexPath, vector: { enabled: false } },
          sync: { watch: false, onSessionStart: false, onSearch: false },
          query: { minScore: 0, hybrid: { enabled: false } },
        },
      },
      list: [{ id: "main", default: true }],
    },
  } as OpenClawConfig;
}
describe("memory manager mistral provider wiring", () => {
  let workspaceDir = "";
  let indexPath = "";
  // Manager created by the current test (via createManager); closed in afterEach.
  let manager: MemoryIndexManager | null = null;

  beforeAll(async () => {
    // Load the module under test once for the whole suite.
    ({ getMemorySearchManager, closeAllMemorySearchManagers } = await import("./index.js"));
  });

  beforeEach(async () => {
    vi.clearAllMocks();
    // mockReset (unlike clear) also drops any unconsumed mockResolvedValueOnce values.
    createEmbeddingProviderMock.mockReset();
    workspaceDir = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-memory-mistral-"));
    indexPath = path.join(workspaceDir, "index.sqlite");
    await fs.mkdir(path.join(workspaceDir, "memory"), { recursive: true });
    await fs.writeFile(path.join(workspaceDir, "MEMORY.md"), "test");
  });

  afterEach(async () => {
    const current = manager;
    manager = null;
    try {
      try {
        await current?.close();
      } finally {
        // Runs even if closing this test's manager failed, so no managers the
        // module still holds (presumably cached) leak into the next test.
        await closeAllMemorySearchManagers();
      }
    } finally {
      // Always remove the temp workspace, even when a close step threw.
      if (workspaceDir) {
        await fs.rm(workspaceDir, { recursive: true, force: true });
        workspaceDir = "";
        indexPath = "";
      }
    }
  });

  /**
   * Resolves the memory search manager for the "main" agent and registers it
   * for cleanup in afterEach.
   *
   * @throws Error including the reported error when no manager was returned.
   */
  async function createManager(cfg: OpenClawConfig): Promise<MemoryIndexManager> {
    const result = await getMemorySearchManager({ cfg, agentId: "main" });
    if (!result.manager) {
      throw new Error(`manager missing: ${result.error ?? "no error provided"}`);
    }
    const created = result.manager as unknown as MemoryIndexManager;
    manager = created;
    return created;
  }

  it("stores mistral client when mistral provider is selected", async () => {
    const mistralClient: MistralEmbeddingClient = {
      baseUrl: "https://api.mistral.ai/v1",
      headers: { authorization: "Bearer test-key" },
      model: "mistral-embed",
    };
    const providerResult: EmbeddingProviderResult = {
      requestedProvider: "mistral",
      provider: createProvider("mistral"),
      mistral: mistralClient,
    };
    createEmbeddingProviderMock.mockResolvedValueOnce(providerResult);

    const created = await createManager(
      buildConfig({ workspaceDir, indexPath, provider: "mistral" }),
    );

    // Reach into the manager's private state to verify the stored client.
    const internal = created as unknown as { mistral?: MistralEmbeddingClient };
    expect(internal.mistral).toBe(mistralClient);
  });

  it("stores mistral client after fallback activation", async () => {
    const openAiClient: OpenAiEmbeddingClient = {
      baseUrl: "https://api.openai.com/v1",
      headers: { authorization: "Bearer openai-key" },
      model: "text-embedding-3-small",
    };
    const mistralClient: MistralEmbeddingClient = {
      baseUrl: "https://api.mistral.ai/v1",
      headers: { authorization: "Bearer mistral-key" },
      model: "mistral-embed",
    };
    // First call: primary (openai) provider. Second call: fallback (mistral) provider.
    createEmbeddingProviderMock.mockResolvedValueOnce({
      requestedProvider: "openai",
      provider: createProvider("openai"),
      openAi: openAiClient,
    } as EmbeddingProviderResult);
    createEmbeddingProviderMock.mockResolvedValueOnce({
      requestedProvider: "mistral",
      provider: createProvider("mistral"),
      mistral: mistralClient,
    } as EmbeddingProviderResult);

    const created = await createManager(
      buildConfig({ workspaceDir, indexPath, provider: "openai", fallback: "mistral" }),
    );

    const internal = created as unknown as {
      activateFallbackProvider: (reason: string) => Promise<boolean>;
      openAi?: OpenAiEmbeddingClient;
      mistral?: MistralEmbeddingClient;
    };
    const activated = await internal.activateFallbackProvider("forced test");
    expect(activated).toBe(true);
    expect(internal.openAi).toBeUndefined();
    expect(internal.mistral).toBe(mistralClient);

    // The fallback request must target mistral (parity with the ollama test below).
    const fallbackCall = createEmbeddingProviderMock.mock.calls[1]?.[0] as
      | { provider?: string }
      | undefined;
    expect(fallbackCall?.provider).toBe("mistral");
  });

  it("uses default ollama model when activating ollama fallback", async () => {
    const openAiClient: OpenAiEmbeddingClient = {
      baseUrl: "https://api.openai.com/v1",
      headers: { authorization: "Bearer openai-key" },
      model: "text-embedding-3-small",
    };
    const ollamaClient: OllamaEmbeddingClient = {
      baseUrl: "http://127.0.0.1:11434",
      headers: {},
      model: DEFAULT_OLLAMA_EMBEDDING_MODEL,
      embedBatch: async (texts: string[]) => texts.map(() => [0.1, 0.2, 0.3]),
    };
    // First call: primary (openai) provider. Second call: fallback (ollama) provider.
    createEmbeddingProviderMock.mockResolvedValueOnce({
      requestedProvider: "openai",
      provider: createProvider("openai"),
      openAi: openAiClient,
    } as EmbeddingProviderResult);
    createEmbeddingProviderMock.mockResolvedValueOnce({
      requestedProvider: "ollama",
      provider: createProvider("ollama"),
      ollama: ollamaClient,
    } as EmbeddingProviderResult);

    const created = await createManager(
      buildConfig({ workspaceDir, indexPath, provider: "openai", fallback: "ollama" }),
    );

    const internal = created as unknown as {
      activateFallbackProvider: (reason: string) => Promise<boolean>;
      openAi?: OpenAiEmbeddingClient;
      ollama?: OllamaEmbeddingClient;
    };
    const activated = await internal.activateFallbackProvider("forced ollama fallback");
    expect(activated).toBe(true);
    expect(internal.openAi).toBeUndefined();
    expect(internal.ollama).toBe(ollamaClient);

    // The configured model is an openai one, so the fallback must swap in the ollama default.
    const fallbackCall = createEmbeddingProviderMock.mock.calls[1]?.[0] as
      | { provider?: string; model?: string }
      | undefined;
    expect(fallbackCall?.provider).toBe("ollama");
    expect(fallbackCall?.model).toBe(DEFAULT_OLLAMA_EMBEDDING_MODEL);
  });
});