From a96ef12061769ce3247462a14fb56cc209831e64 Mon Sep 17 00:00:00 2001
From: Peter Steinberger
Date: Sat, 7 Mar 2026 17:36:42 +0000
Subject: [PATCH] refactor(memory): dedupe local embedding init concurrency
 fixtures

---
 src/memory/embeddings.test.ts | 123 ++++++++++------------------------
 1 file changed, 37 insertions(+), 86 deletions(-)

diff --git a/src/memory/embeddings.test.ts b/src/memory/embeddings.test.ts
index 027673c7099..df22885fefd 100644
--- a/src/memory/embeddings.test.ts
+++ b/src/memory/embeddings.test.ts
@@ -516,20 +516,32 @@ describe("local embedding ensureContext concurrency", () => {
     vi.doUnmock("./node-llama.js");
   });
 
-  it("loads the model only once when embedBatch is called concurrently", async () => {
+  async function setupLocalProviderWithMockedInit(params?: {
+    initializationDelayMs?: number;
+    failFirstGetLlama?: boolean;
+  }) {
     const getLlamaSpy = vi.fn();
     const loadModelSpy = vi.fn();
     const createContextSpy = vi.fn();
+    let shouldFail = params?.failFirstGetLlama ?? false;
 
     const nodeLlamaModule = await import("./node-llama.js");
     vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
       getLlama: async (...args: unknown[]) => {
         getLlamaSpy(...args);
-        await new Promise((r) => setTimeout(r, 50));
+        if (shouldFail) {
+          shouldFail = false;
+          throw new Error("transient init failure");
+        }
+        if (params?.initializationDelayMs) {
+          await new Promise((r) => setTimeout(r, params.initializationDelayMs));
+        }
         return {
           loadModel: async (...modelArgs: unknown[]) => {
             loadModelSpy(...modelArgs);
-            await new Promise((r) => setTimeout(r, 50));
+            if (params?.initializationDelayMs) {
+              await new Promise((r) => setTimeout(r, params.initializationDelayMs));
+            }
             return {
               createEmbeddingContext: async () => {
                 createContextSpy();
@@ -548,7 +560,6 @@ describe("local embedding ensureContext concurrency", () => {
     } as never);
 
     const { createEmbeddingProvider } = await import("./embeddings.js");
-
     const result = await createEmbeddingProvider({
       config: {} as never,
       provider: "local",
@@ -556,7 +567,20 @@
       fallback: "none",
     });
 
-    const provider = requireProvider(result);
+    return {
+      provider: requireProvider(result),
+      getLlamaSpy,
+      loadModelSpy,
+      createContextSpy,
+    };
+  }
+
+  it("loads the model only once when embedBatch is called concurrently", async () => {
+    const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
+      await setupLocalProviderWithMockedInit({
+        initializationDelayMs: 50,
+      });
+
     const results = await Promise.all([
       provider.embedBatch(["text1"]),
       provider.embedBatch(["text2"]),
@@ -576,49 +600,11 @@
   });
 
   it("retries initialization after a transient ensureContext failure", async () => {
-    const getLlamaSpy = vi.fn();
-    const loadModelSpy = vi.fn();
-    const createContextSpy = vi.fn();
+    const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
+      await setupLocalProviderWithMockedInit({
+        failFirstGetLlama: true,
+      });
 
-    let failFirstGetLlama = true;
-    const nodeLlamaModule = await import("./node-llama.js");
-    vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
-      getLlama: async (...args: unknown[]) => {
-        getLlamaSpy(...args);
-        if (failFirstGetLlama) {
-          failFirstGetLlama = false;
-          throw new Error("transient init failure");
-        }
-        return {
-          loadModel: async (...modelArgs: unknown[]) => {
-            loadModelSpy(...modelArgs);
-            return {
-              createEmbeddingContext: async () => {
-                createContextSpy();
-                return {
-                  getEmbeddingFor: vi.fn().mockResolvedValue({
-                    vector: new Float32Array([1, 0, 0, 0]),
-                  }),
-                };
-              },
-            };
-          },
-        };
-      },
-      resolveModelFile: async () => "/fake/model.gguf",
-      LlamaLogLevel: { error: 0 },
-    } as never);
-
-    const { createEmbeddingProvider } = await import("./embeddings.js");
-
-    const result = await createEmbeddingProvider({
-      config: {} as never,
-      provider: "local",
-      model: "",
-      fallback: "none",
-    });
-
-    const provider = requireProvider(result);
     await expect(provider.embedBatch(["first"])).rejects.toThrow("transient init failure");
 
     const recovered = await provider.embedBatch(["second"]);
@@ -631,46 +617,11 @@
   });
 
   it("shares initialization when embedQuery and embedBatch start concurrently", async () => {
-    const getLlamaSpy = vi.fn();
-    const loadModelSpy = vi.fn();
-    const createContextSpy = vi.fn();
+    const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
+      await setupLocalProviderWithMockedInit({
+        initializationDelayMs: 50,
+      });
 
-    const nodeLlamaModule = await import("./node-llama.js");
-    vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
-      getLlama: async (...args: unknown[]) => {
-        getLlamaSpy(...args);
-        await new Promise((r) => setTimeout(r, 50));
-        return {
-          loadModel: async (...modelArgs: unknown[]) => {
-            loadModelSpy(...modelArgs);
-            await new Promise((r) => setTimeout(r, 50));
-            return {
-              createEmbeddingContext: async () => {
-                createContextSpy();
-                return {
-                  getEmbeddingFor: vi.fn().mockResolvedValue({
-                    vector: new Float32Array([1, 0, 0, 0]),
-                  }),
-                };
-              },
-            };
-          },
-        };
-      },
-      resolveModelFile: async () => "/fake/model.gguf",
-      LlamaLogLevel: { error: 0 },
-    } as never);
-
-    const { createEmbeddingProvider } = await import("./embeddings.js");
-
-    const result = await createEmbeddingProvider({
-      config: {} as never,
-      provider: "local",
-      model: "",
-      fallback: "none",
-    });
-
-    const provider = requireProvider(result);
     const [queryA, batch, queryB] = await Promise.all([
       provider.embedQuery("query-a"),
       provider.embedBatch(["batch-a", "batch-b"]),