diff --git a/extensions/discord/src/monitor/message-handler.inbound-context.test.ts b/extensions/discord/src/monitor/message-handler.inbound-context.test.ts index 6eb378e7bbb..29d49887d36 100644 --- a/extensions/discord/src/monitor/message-handler.inbound-context.test.ts +++ b/extensions/discord/src/monitor/message-handler.inbound-context.test.ts @@ -1,5 +1,5 @@ import { describe, expect, it } from "vitest"; -import { inboundCtxCapture as capture } from "../../../../src/channels/plugins/contracts/inbound-contract-dispatch-mock.js"; +import { inboundCtxCapture as capture } from "../../../../src/channels/plugins/contracts/inbound-testkit.js"; import { expectChannelInboundContextContract as expectInboundContextContract } from "../../../../src/channels/plugins/contracts/suites.js"; import type { DiscordMessagePreflightContext } from "./message-handler.preflight.js"; import { processDiscordMessage } from "./message-handler.process.js"; diff --git a/extensions/signal/src/monitor/event-handler.mention-gating.test.ts b/extensions/signal/src/monitor/event-handler.mention-gating.test.ts index 60222d4a7ab..ffcdb5baba6 100644 --- a/extensions/signal/src/monitor/event-handler.mention-gating.test.ts +++ b/extensions/signal/src/monitor/event-handler.mention-gating.test.ts @@ -1,6 +1,6 @@ import { describe, expect, it, vi } from "vitest"; import type { MsgContext } from "../../../../src/auto-reply/templating.js"; -import { buildDispatchInboundCaptureMock } from "../../../../src/channels/plugins/contracts/dispatch-inbound-capture.js"; +import { buildDispatchInboundCaptureMock } from "../../../../src/channels/plugins/contracts/inbound-testkit.js"; import type { OpenClawConfig } from "../../../../src/config/types.js"; import { createBaseSignalEventHandlerDeps, diff --git a/src/channels/plugins/contracts/dispatch-inbound-capture.ts b/src/channels/plugins/contracts/dispatch-inbound-capture.ts deleted file mode 100644 index cd7b0bd5fdb..00000000000 --- a/src/channels/plugins/contracts/dispatch-inbound-capture.ts +++ /dev/null @@ -1,18 +0,0 @@ -import { vi } from "vitest"; - -export function buildDispatchInboundCaptureMock>( - actual: T, - setCtx: (ctx: unknown) => void, -) { - const dispatchInboundMessage = vi.fn(async (params: { ctx: unknown }) => { - setCtx(params.ctx); - return { queuedFinal: false, counts: { tool: 0, block: 0, final: 0 } }; - }); - - return { - ...actual, - dispatchInboundMessage, - dispatchInboundMessageWithDispatcher: dispatchInboundMessage, - dispatchInboundMessageWithBufferedDispatcher: dispatchInboundMessage, - }; -} diff --git a/src/channels/plugins/contracts/inbound-contract-capture.ts b/src/channels/plugins/contracts/inbound-contract-capture.ts deleted file mode 100644 index b74164c7a79..00000000000 --- a/src/channels/plugins/contracts/inbound-contract-capture.ts +++ /dev/null @@ -1,20 +0,0 @@ -import type { MsgContext } from "../../../auto-reply/templating.js"; -import { buildDispatchInboundCaptureMock } from "./dispatch-inbound-capture.js"; - -export type InboundContextCapture = { - ctx: MsgContext | undefined; -}; - -export function createInboundContextCapture(): InboundContextCapture { - return { ctx: undefined }; -} - -export async function buildDispatchInboundContextCapture( - importOriginal: >() => Promise, - capture: InboundContextCapture, -) { - const actual = await importOriginal(); - return buildDispatchInboundCaptureMock(actual, (ctx) => { - capture.ctx = ctx as MsgContext; - }); -} diff --git a/src/channels/plugins/contracts/inbound-contract-dispatch-mock.ts b/src/channels/plugins/contracts/inbound-contract-dispatch-mock.ts deleted file mode 100644 index 05698d628c5..00000000000 --- a/src/channels/plugins/contracts/inbound-contract-dispatch-mock.ts +++ /dev/null @@ -1,9 +0,0 @@ -import { vi } from "vitest"; -import { createInboundContextCapture } from "./inbound-contract-capture.js"; -import { buildDispatchInboundContextCapture } from "./inbound-contract-capture.js"; - -export const inboundCtxCapture = createInboundContextCapture(); - -vi.mock("../../../auto-reply/dispatch.js", async (importOriginal) => { - return await buildDispatchInboundContextCapture(importOriginal, inboundCtxCapture); -}); diff --git a/src/channels/plugins/contracts/inbound-testkit.ts b/src/channels/plugins/contracts/inbound-testkit.ts new file mode 100644 index 00000000000..b3241572f56 --- /dev/null +++ b/src/channels/plugins/contracts/inbound-testkit.ts @@ -0,0 +1,39 @@ +import { vi } from "vitest"; +import type { MsgContext } from "../../../auto-reply/templating.js"; + +export type InboundContextCapture = { + ctx: MsgContext | undefined; +}; + +export function createInboundContextCapture(): InboundContextCapture { + return { ctx: undefined }; +} + +export function buildDispatchInboundCaptureMock>( + actual: T, + setCtx: (ctx: unknown) => void, +) { + const dispatchInboundMessage = vi.fn(async (params: { ctx: unknown }) => { + setCtx(params.ctx); + return { queuedFinal: false, counts: { tool: 0, block: 0, final: 0 } }; + }); + + return { + ...actual, + dispatchInboundMessage, + dispatchInboundMessageWithDispatcher: dispatchInboundMessage, + dispatchInboundMessageWithBufferedDispatcher: dispatchInboundMessage, + }; +} + +export async function buildDispatchInboundContextCapture( + importOriginal: >() => Promise, + capture: InboundContextCapture, +) { + const actual = await importOriginal(); + return buildDispatchInboundCaptureMock(actual, (ctx) => { + capture.ctx = ctx as MsgContext; + }); +} + +export const inboundCtxCapture = createInboundContextCapture(); diff --git a/src/channels/plugins/contracts/inbound.contract.test.ts b/src/channels/plugins/contracts/inbound.contract.test.ts index e90e5090e6b..eadb1913544 100644 --- a/src/channels/plugins/contracts/inbound.contract.test.ts +++ b/src/channels/plugins/contracts/inbound.contract.test.ts @@ -8,7 +8,7 @@ import { createInboundSlackTestContext } from "../../../../extensions/slack/src/ import type { SlackMessageEvent } from "../../../../extensions/slack/src/types.js"; import type { MsgContext } from "../../../auto-reply/templating.js"; import type { OpenClawConfig } from "../../../config/config.js"; -import { inboundCtxCapture } from "./inbound-contract-dispatch-mock.js"; +import { inboundCtxCapture } from "./inbound-testkit.js"; import { expectChannelInboundContextContract } from "./suites.js"; const signalCapture = vi.hoisted(() => ({ ctx: undefined as MsgContext | undefined })); diff --git a/src/plugins/contracts/auth-choice.contract.test.ts b/src/plugins/contracts/auth-choice.contract.test.ts index 7f3f6535e54..fc301051065 100644 --- a/src/plugins/contracts/auth-choice.contract.test.ts +++ b/src/plugins/contracts/auth-choice.contract.test.ts @@ -6,13 +6,12 @@ import { readAuthProfilesForAgent, requireOpenClawAgentDir, setupAuthTestEnv, -} from "../../../test/helpers/auth-wizard.js"; +} from "../../commands/test-wizard-helpers.js"; import { clearRuntimeAuthProfileStoreSnapshots } from "../../agents/auth-profiles/store.js"; import { applyAuthChoiceLoadedPluginProvider } from "../../plugins/provider-auth-choice.js"; -import { createCapturedPluginRegistration } from "../../test-utils/plugin-registration.js"; import { buildProviderPluginMethodChoice } from "../provider-wizard.js"; -import type { OpenClawPluginApi, ProviderPlugin } from "../types.js"; import { requireProviderContractProvider, uniqueProviderContractProviders } from "./registry.js"; +import { registerProviders, requireProvider } from "./testkit.js"; type ResolvePluginProviders = typeof import("../../plugins/provider-auth-choice.runtime.js").resolvePluginProviders; @@ -67,22 +66,6 @@ type StoredAuthProfile = { const qwenPortalPlugin = (await import("../../../extensions/qwen-portal-auth/index.js")).default; -function registerProviders(...plugins: Array<{ register(api: OpenClawPluginApi): void }>) { - const captured = createCapturedPluginRegistration(); - for (const plugin of plugins) { - plugin.register(captured.api); - } - return captured.providers; -} - -function requireProvider(providers: ProviderPlugin[], providerId: string) { - const provider = providers.find((entry) => entry.id === providerId); - if (!provider) { - throw new Error(`provider ${providerId} missing`); - } - return provider; -} - describe("provider auth-choice contract", () => { const lifecycle = createAuthTestLifecycle([ "OPENCLAW_STATE_DIR", diff --git a/src/plugins/contracts/auth.contract.test.ts b/src/plugins/contracts/auth.contract.test.ts index 1b8c809f9df..4842bef5e76 100644 --- a/src/plugins/contracts/auth.contract.test.ts +++ b/src/plugins/contracts/auth.contract.test.ts @@ -4,14 +4,13 @@ import { replaceRuntimeAuthProfileStoreSnapshots, } from "../../agents/auth-profiles/store.js"; import { createNonExitingRuntime } from "../../runtime.js"; -import { createCapturedPluginRegistration } from "../../test-utils/plugin-registration.js"; import type { WizardMultiSelectParams, WizardPrompter, WizardProgress, WizardSelectParams, } from "../../wizard/prompts.js"; -import type { OpenClawPluginApi, ProviderPlugin } from "../types.js"; +import { registerProviders, requireProvider } from "./testkit.js"; type LoginOpenAICodexOAuth = (typeof import("../../plugins/provider-openai-codex-oauth.js"))["loginOpenAICodexOAuth"]; @@ -78,22 +77,6 @@ function buildAuthContext() { }; } -function registerProviders(...plugins: Array<{ register(api: OpenClawPluginApi): void }>) { - const captured = createCapturedPluginRegistration(); - for (const plugin of plugins) { - plugin.register(captured.api); - } - return captured.providers; -} - -function requireProvider(providers: ProviderPlugin[], providerId: string) { - const provider = providers.find((entry) => entry.id === providerId); - if (!provider) { - throw new Error(`provider ${providerId} missing`); - } - return provider; -} - describe("provider auth contract", () => { afterEach(() => { loginOpenAICodexOAuthMock.mockReset(); diff --git a/src/plugins/contracts/discovery.contract.test.ts b/src/plugins/contracts/discovery.contract.test.ts index 072e657616e..0a334a619a1 100644 --- a/src/plugins/contracts/discovery.contract.test.ts +++ b/src/plugins/contracts/discovery.contract.test.ts @@ -5,9 +5,8 @@ import { } from "../../agents/auth-profiles/store.js"; import { QWEN_OAUTH_MARKER } from "../../agents/model-auth-markers.js"; import type { ModelDefinitionConfig } from "../../config/types.models.js"; -import { createCapturedPluginRegistration } from "../../test-utils/plugin-registration.js"; import { runProviderCatalog } from "../provider-discovery.js"; -import type { OpenClawPluginApi, ProviderPlugin } from "../types.js"; +import { registerProviders, requireProvider } from "./testkit.js"; const resolveCopilotApiTokenMock = vi.hoisted(() => vi.fn()); const buildOllamaProviderMock = vi.hoisted(() => vi.fn()); @@ -60,22 +59,6 @@ const cloudflareAiGatewayPlugin = ( await import("../../../extensions/cloudflare-ai-gateway/index.js") ).default; -function registerProviders(...plugins: Array<{ register(api: OpenClawPluginApi): void }>) { - const captured = createCapturedPluginRegistration(); - for (const plugin of plugins) { - plugin.register(captured.api); - } - return captured.providers; -} - -function requireProvider(providers: ProviderPlugin[], providerId: string) { - const provider = providers.find((entry) => entry.id === providerId); - if (!provider) { - throw new Error(`provider ${providerId} missing`); - } - return provider; -} - function createModelConfig(id: string, name = id): ModelDefinitionConfig { return { id, diff --git a/src/plugins/contracts/loader.contract.test.ts b/src/plugins/contracts/loader.contract.test.ts index cdac689af52..dde3ef19c19 100644 --- a/src/plugins/contracts/loader.contract.test.ts +++ b/src/plugins/contracts/loader.contract.test.ts @@ -2,15 +2,8 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; import { withBundledPluginAllowlistCompat } from "../bundled-compat.js"; import { __testing as providerTesting } from "../providers.js"; import { resolvePluginWebSearchProviders } from "../web-search-providers.js"; -import { providerContractPluginIds, webSearchProviderContractRegistry } from "./registry.js"; - -function uniqueSortedPluginIds(values: string[]) { - return [...new Set(values)].toSorted((left, right) => left.localeCompare(right)); -} - -function normalizeProviderContractPluginId(pluginId: string) { - return pluginId === "kimi-coding" ? "kimi" : pluginId; -} +import { providerContractCompatPluginIds, webSearchProviderContractRegistry } from "./registry.js"; +import { uniqueSortedStrings } from "./testkit.js"; describe("plugin loader contract", () => { beforeEach(() => { @@ -18,9 +11,7 @@ describe("plugin loader contract", () => { }); it("keeps bundled provider compatibility wired to the provider registry", () => { - const providerPluginIds = uniqueSortedPluginIds( - providerContractPluginIds.map(normalizeProviderContractPluginId), - ); + const providerPluginIds = uniqueSortedStrings(providerContractCompatPluginIds); const compatPluginIds = providerTesting.resolveBundledProviderCompatPluginIds({ config: { plugins: { @@ -38,16 +29,12 @@ describe("plugin loader contract", () => { pluginIds: compatPluginIds, }); - expect(uniqueSortedPluginIds(compatPluginIds)).toEqual( - expect.arrayContaining(providerPluginIds), - ); + expect(uniqueSortedStrings(compatPluginIds)).toEqual(expect.arrayContaining(providerPluginIds)); expect(compatConfig?.plugins?.allow).toEqual(expect.arrayContaining(providerPluginIds)); }); it("keeps vitest bundled provider enablement wired to the provider registry", () => { - const providerPluginIds = uniqueSortedPluginIds( - providerContractPluginIds.map(normalizeProviderContractPluginId), - ); + const providerPluginIds = uniqueSortedStrings(providerContractCompatPluginIds); const compatConfig = providerTesting.withBundledProviderVitestCompat({ config: undefined, pluginIds: providerPluginIds, @@ -61,19 +48,19 @@ describe("plugin loader contract", () => { }); it("keeps bundled web search loading scoped to the web search registry", () => { - const webSearchPluginIds = uniqueSortedPluginIds( + const webSearchPluginIds = uniqueSortedStrings( webSearchProviderContractRegistry.map((entry) => entry.pluginId), ); const providers = resolvePluginWebSearchProviders({}); - expect(uniqueSortedPluginIds(providers.map((provider) => provider.pluginId))).toEqual( + expect(uniqueSortedStrings(providers.map((provider) => provider.pluginId))).toEqual( webSearchPluginIds, ); }); it("keeps bundled web search allowlist compatibility wired to the web search registry", () => { - const webSearchPluginIds = uniqueSortedPluginIds( + const webSearchPluginIds = uniqueSortedStrings( webSearchProviderContractRegistry.map((entry) => entry.pluginId), ); @@ -86,7 +73,7 @@ describe("plugin loader contract", () => { }, }); - expect(uniqueSortedPluginIds(providers.map((provider) => provider.pluginId))).toEqual( + expect(uniqueSortedStrings(providers.map((provider) => provider.pluginId))).toEqual( webSearchPluginIds, ); }); diff --git a/src/plugins/contracts/registry.ts b/src/plugins/contracts/registry.ts index 8247b8b273d..8ab7422c1e2 100644 --- a/src/plugins/contracts/registry.ts +++ b/src/plugins/contracts/registry.ts @@ -160,6 +160,10 @@ export const providerContractPluginIds = [ ...new Set(providerContractRegistry.map((entry) => entry.pluginId)), ].toSorted((left, right) => left.localeCompare(right)); +export const providerContractCompatPluginIds = providerContractPluginIds.map((pluginId) => + pluginId === "kimi-coding" ? "kimi" : pluginId, +); + export function requireProviderContractProvider(providerId: string): ProviderPlugin { const provider = uniqueProviderContractProviders.find((entry) => entry.id === providerId); if (!provider) { diff --git a/src/plugins/contracts/testkit.ts b/src/plugins/contracts/testkit.ts new file mode 100644 index 00000000000..e3f98c70759 --- /dev/null +++ b/src/plugins/contracts/testkit.ts @@ -0,0 +1,26 @@ +import { createCapturedPluginRegistration } from "../../test-utils/plugin-registration.js"; +import type { OpenClawPluginApi, ProviderPlugin } from "../types.js"; + +type RegistrablePlugin = { + register(api: OpenClawPluginApi): void; +}; + +export function registerProviders(...plugins: RegistrablePlugin[]) { + const captured = createCapturedPluginRegistration(); + for (const plugin of plugins) { + plugin.register(captured.api); + } + return captured.providers; +} + +export function requireProvider(providers: ProviderPlugin[], providerId: string) { + const provider = providers.find((entry) => entry.id === providerId); + if (!provider) { + throw new Error(`provider ${providerId} missing`); + } + return provider; +} + +export function uniqueSortedStrings(values: readonly string[]) { + return [...new Set(values)].toSorted((left, right) => left.localeCompare(right)); +}