diff --git a/src/plugins/contracts/catalog.contract.test.ts b/src/plugins/contracts/catalog.contract.test.ts index 4339b6edec4..a87e632ac45 100644 --- a/src/plugins/contracts/catalog.contract.test.ts +++ b/src/plugins/contracts/catalog.contract.test.ts @@ -5,36 +5,57 @@ import { expectCodexMissingAuthHint, } from "../provider-runtime.test-support.js"; import { - providerContractPluginIds, + resolveProviderContractPluginIdsForProvider, resolveProviderContractProvidersForPluginIds, uniqueProviderContractProviders, } from "./registry.js"; -const resolvePluginProvidersMock = vi.fn(); -const resolveOwningPluginIdsForProviderMock = vi.fn(); -const resolveNonBundledProviderPluginIdsMock = vi.fn(); +type ResolvePluginProviders = typeof import("../providers.js").resolvePluginProviders; +type ResolveOwningPluginIdsForProvider = + typeof import("../providers.js").resolveOwningPluginIdsForProvider; +type ResolveNonBundledProviderPluginIds = + typeof import("../providers.js").resolveNonBundledProviderPluginIds; + +const resolvePluginProvidersMock = vi.hoisted(() => + vi.fn((_) => uniqueProviderContractProviders), +); +const resolveOwningPluginIdsForProviderMock = vi.hoisted(() => + vi.fn((params) => + resolveProviderContractPluginIdsForProvider(params.provider), + ), +); +const resolveNonBundledProviderPluginIdsMock = vi.hoisted(() => + vi.fn((_) => [] as string[]), +); vi.mock("../providers.js", () => ({ - resolvePluginProviders: (...args: unknown[]) => resolvePluginProvidersMock(...args), - resolveOwningPluginIdsForProvider: (...args: unknown[]) => - resolveOwningPluginIdsForProviderMock(...args), - resolveNonBundledProviderPluginIds: (...args: unknown[]) => - resolveNonBundledProviderPluginIdsMock(...args), + resolvePluginProviders: (params: unknown) => resolvePluginProvidersMock(params as never), + resolveOwningPluginIdsForProvider: (params: unknown) => + resolveOwningPluginIdsForProviderMock(params as never), + resolveNonBundledProviderPluginIds: (params: unknown) => + resolveNonBundledProviderPluginIdsMock(params as never), })); -const { - augmentModelCatalogWithProviderPlugins, - buildProviderMissingAuthMessageWithPlugin, - resetProviderRuntimeHookCacheForTest, - resolveProviderBuiltInModelSuppression, -} = await import("../provider-runtime.js"); +let augmentModelCatalogWithProviderPlugins: typeof import("../provider-runtime.js").augmentModelCatalogWithProviderPlugins; +let buildProviderMissingAuthMessageWithPlugin: typeof import("../provider-runtime.js").buildProviderMissingAuthMessageWithPlugin; +let resetProviderRuntimeHookCacheForTest: typeof import("../provider-runtime.js").resetProviderRuntimeHookCacheForTest; +let resolveProviderBuiltInModelSuppression: typeof import("../provider-runtime.js").resolveProviderBuiltInModelSuppression; describe("provider catalog contract", () => { - beforeEach(() => { + beforeEach(async () => { + vi.resetModules(); + ({ + augmentModelCatalogWithProviderPlugins, + buildProviderMissingAuthMessageWithPlugin, + resetProviderRuntimeHookCacheForTest, + resolveProviderBuiltInModelSuppression, + } = await import("../provider-runtime.js")); resetProviderRuntimeHookCacheForTest(); resolveOwningPluginIdsForProviderMock.mockReset(); - resolveOwningPluginIdsForProviderMock.mockReturnValue(providerContractPluginIds); + resolveOwningPluginIdsForProviderMock.mockImplementation((params) => + resolveProviderContractPluginIdsForProvider(params.provider), + ); resolveNonBundledProviderPluginIdsMock.mockReset(); resolveNonBundledProviderPluginIdsMock.mockReturnValue([]); diff --git a/src/plugins/contracts/registry.ts b/src/plugins/contracts/registry.ts index adedfe57d0c..f33571b8008 100644 --- a/src/plugins/contracts/registry.ts +++ b/src/plugins/contracts/registry.ts @@ -177,6 +177,19 @@ export function requireProviderContractProvider(providerId: string): ProviderPlu return provider; } +export function resolveProviderContractPluginIdsForProvider( + providerId: string, +): string[] | undefined { + const pluginIds = [ + ...new Set( + providerContractRegistry + .filter((entry) => entry.provider.id === providerId) + .map((entry) => entry.pluginId), + ), + ]; + return pluginIds.length > 0 ? pluginIds : undefined; +} + export function resolveProviderContractProvidersForPluginIds( pluginIds: readonly string[], ): ProviderPlugin[] {