From dc8253a84d9595972b3dfc4f559fbec3c8978ad9 Mon Sep 17 00:00:00 2001 From: huangcj <43933609+SubtleSpark@users.noreply.github.com> Date: Wed, 4 Mar 2026 23:09:03 +0800 Subject: [PATCH] fix(memory): serialize local embedding initialization to avoid duplicate model loads (#15639) Merged via squash. Prepared head SHA: a085fc21a8ba7163fffdb5de640dd4dc1ff5a88e Co-authored-by: SubtleSpark <43933609+SubtleSpark@users.noreply.github.com> Co-authored-by: gumadeiras <5599352+gumadeiras@users.noreply.github.com> Reviewed-by: @gumadeiras --- CHANGELOG.md | 1 + src/memory/embeddings.test.ts | 181 ++++++++++++++++++++++++++++++++++ src/memory/embeddings.ts | 35 +++++-- 3 files changed, 207 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0fb849832b4..8dd5008de32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -82,6 +82,7 @@ Docs: https://docs.openclaw.ai - Agents/Compaction continuity: expand staged-summary merge instructions to preserve active task status, batch progress, latest user request, and follow-up commitments so compaction handoffs retain in-flight work context. (#8903) thanks @joetomasone. - Gateway/status self version reporting: make Gateway self version in `openclaw status` prefer runtime `VERSION` (while preserving explicit `OPENCLAW_VERSION` override), preventing stale post-upgrade app version output. (#32655) thanks @liuxiaopai-ai. - Memory/QMD index isolation: set `QMD_CONFIG_DIR` alongside `XDG_CONFIG_HOME` so QMD config state stays per-agent despite upstream XDG handling bugs, preventing cross-agent collection indexing and excess disk/CPU usage. (#27028) thanks @HenryLoenwind. +- Memory/local embedding initialization hardening: add regression coverage for transient initialization retry and mixed `embedQuery` + `embedBatch` concurrent startup to lock single-flight initialization behavior. (#15639) thanks @SubtleSpark. 
- CLI/Coding-agent reliability: switch default `claude-cli` non-interactive args to `--permission-mode bypassPermissions`, auto-normalize legacy `--dangerously-skip-permissions` backend overrides to the modern permission-mode form, align coding-agent + live-test docs with the non-PTY Claude path, and emit session system-event heartbeat notices when CLI watchdog no-output timeouts terminate runs. Related to #28261. Landed from contributor PRs #28610 and #31149. Thanks @niceysam, @cryptomaltese and @vincentkoc. - ACP/ACPX session bootstrap: retry with `sessions new` when `sessions ensure` returns no session identifiers so ACP spawns avoid `NO_SESSION`/`ACP_TURN_FAILED` failures on affected agents. Related to #28786. Landed from contributor PR #31338. Thanks @Sid-Qin and @vincentkoc. - LINE/auth boundary hardening synthesis: enforce strict LINE webhook authn/z boundary semantics across pairing-store account scoping, DM/group allowlist separation, fail-closed webhook auth/runtime behavior, and replay/duplication controls (including in-flight replay reservation and post-success dedupe marking). (from #26701, #26683, #25978, #17593, #16619, #31990, #26047, #30584, #18777) Thanks @bmendonca3, @davidahmann, @harshang03, @haosenwang1018, @liuxiaopai-ai, @coygeek, and @Takhoffman. 
diff --git a/src/memory/embeddings.test.ts b/src/memory/embeddings.test.ts index 57e4410f821..91cfb567a37 100644 --- a/src/memory/embeddings.test.ts +++ b/src/memory/embeddings.test.ts @@ -471,6 +471,187 @@ describe("local embedding normalization", () => { }); }); +describe("local embedding ensureContext concurrency", () => { + afterEach(() => { + vi.resetAllMocks(); + vi.resetModules(); + vi.unstubAllGlobals(); + vi.doUnmock("./node-llama.js"); + }); + + it("loads the model only once when embedBatch is called concurrently", async () => { + const getLlamaSpy = vi.fn(); + const loadModelSpy = vi.fn(); + const createContextSpy = vi.fn(); + + const nodeLlamaModule = await import("./node-llama.js"); + vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({ + getLlama: async (...args: unknown[]) => { + getLlamaSpy(...args); + await new Promise((r) => setTimeout(r, 50)); + return { + loadModel: async (...modelArgs: unknown[]) => { + loadModelSpy(...modelArgs); + await new Promise((r) => setTimeout(r, 50)); + return { + createEmbeddingContext: async () => { + createContextSpy(); + return { + getEmbeddingFor: vi.fn().mockResolvedValue({ + vector: new Float32Array([1, 0, 0, 0]), + }), + }; + }, + }; + }, + }; + }, + resolveModelFile: async () => "/fake/model.gguf", + LlamaLogLevel: { error: 0 }, + } as never); + + const { createEmbeddingProvider } = await import("./embeddings.js"); + + const result = await createEmbeddingProvider({ + config: {} as never, + provider: "local", + model: "", + fallback: "none", + }); + + const provider = requireProvider(result); + const results = await Promise.all([ + provider.embedBatch(["text1"]), + provider.embedBatch(["text2"]), + provider.embedBatch(["text3"]), + provider.embedBatch(["text4"]), + ]); + + expect(results).toHaveLength(4); + for (const embeddings of results) { + expect(embeddings).toHaveLength(1); + expect(embeddings[0]).toHaveLength(4); + } + + expect(getLlamaSpy).toHaveBeenCalledTimes(1); + 
expect(loadModelSpy).toHaveBeenCalledTimes(1); + expect(createContextSpy).toHaveBeenCalledTimes(1); + }); + + it("retries initialization after a transient ensureContext failure", async () => { + const getLlamaSpy = vi.fn(); + const loadModelSpy = vi.fn(); + const createContextSpy = vi.fn(); + + let failFirstGetLlama = true; + const nodeLlamaModule = await import("./node-llama.js"); + vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({ + getLlama: async (...args: unknown[]) => { + getLlamaSpy(...args); + if (failFirstGetLlama) { + failFirstGetLlama = false; + throw new Error("transient init failure"); + } + return { + loadModel: async (...modelArgs: unknown[]) => { + loadModelSpy(...modelArgs); + return { + createEmbeddingContext: async () => { + createContextSpy(); + return { + getEmbeddingFor: vi.fn().mockResolvedValue({ + vector: new Float32Array([1, 0, 0, 0]), + }), + }; + }, + }; + }, + }; + }, + resolveModelFile: async () => "/fake/model.gguf", + LlamaLogLevel: { error: 0 }, + } as never); + + const { createEmbeddingProvider } = await import("./embeddings.js"); + + const result = await createEmbeddingProvider({ + config: {} as never, + provider: "local", + model: "", + fallback: "none", + }); + + const provider = requireProvider(result); + await expect(provider.embedBatch(["first"])).rejects.toThrow("transient init failure"); + + const recovered = await provider.embedBatch(["second"]); + expect(recovered).toHaveLength(1); + expect(recovered[0]).toHaveLength(4); + + expect(getLlamaSpy).toHaveBeenCalledTimes(2); + expect(loadModelSpy).toHaveBeenCalledTimes(1); + expect(createContextSpy).toHaveBeenCalledTimes(1); + }); + + it("shares initialization when embedQuery and embedBatch start concurrently", async () => { + const getLlamaSpy = vi.fn(); + const loadModelSpy = vi.fn(); + const createContextSpy = vi.fn(); + + const nodeLlamaModule = await import("./node-llama.js"); + vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({ + 
getLlama: async (...args: unknown[]) => { + getLlamaSpy(...args); + await new Promise((r) => setTimeout(r, 50)); + return { + loadModel: async (...modelArgs: unknown[]) => { + loadModelSpy(...modelArgs); + await new Promise((r) => setTimeout(r, 50)); + return { + createEmbeddingContext: async () => { + createContextSpy(); + return { + getEmbeddingFor: vi.fn().mockResolvedValue({ + vector: new Float32Array([1, 0, 0, 0]), + }), + }; + }, + }; + }, + }; + }, + resolveModelFile: async () => "/fake/model.gguf", + LlamaLogLevel: { error: 0 }, + } as never); + + const { createEmbeddingProvider } = await import("./embeddings.js"); + + const result = await createEmbeddingProvider({ + config: {} as never, + provider: "local", + model: "", + fallback: "none", + }); + + const provider = requireProvider(result); + const [queryA, batch, queryB] = await Promise.all([ + provider.embedQuery("query-a"), + provider.embedBatch(["batch-a", "batch-b"]), + provider.embedQuery("query-b"), + ]); + + expect(queryA).toHaveLength(4); + expect(batch).toHaveLength(2); + expect(queryB).toHaveLength(4); + expect(batch[0]).toHaveLength(4); + expect(batch[1]).toHaveLength(4); + + expect(getLlamaSpy).toHaveBeenCalledTimes(1); + expect(loadModelSpy).toHaveBeenCalledTimes(1); + expect(createContextSpy).toHaveBeenCalledTimes(1); + }); +}); + describe("FTS-only fallback when no provider available", () => { it("returns null provider with reason when auto mode finds no providers", async () => { vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue( diff --git a/src/memory/embeddings.ts b/src/memory/embeddings.ts index 9682c08582a..faf1c795b95 100644 --- a/src/memory/embeddings.ts +++ b/src/memory/embeddings.ts @@ -111,19 +111,34 @@ async function createLocalEmbeddingProvider( let llama: Llama | null = null; let embeddingModel: LlamaModel | null = null; let embeddingContext: LlamaEmbeddingContext | null = null; + let initPromise: Promise<LlamaEmbeddingContext> | null = null; - const ensureContext = async () => { - if
(!llama) { - llama = await getLlama({ logLevel: LlamaLogLevel.error }); + const ensureContext = async (): Promise<LlamaEmbeddingContext> => { + if (embeddingContext) { + return embeddingContext; } - if (!embeddingModel) { - const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined); - embeddingModel = await llama.loadModel({ modelPath: resolved }); + if (initPromise) { + return initPromise; } - if (!embeddingContext) { - embeddingContext = await embeddingModel.createEmbeddingContext(); - } - return embeddingContext; + initPromise = (async () => { + try { + if (!llama) { + llama = await getLlama({ logLevel: LlamaLogLevel.error }); + } + if (!embeddingModel) { + const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined); + embeddingModel = await llama.loadModel({ modelPath: resolved }); + } + if (!embeddingContext) { + embeddingContext = await embeddingModel.createEmbeddingContext(); + } + return embeddingContext; + } catch (err) { + initPromise = null; + throw err; + } + })(); + return initPromise; }; return {