diff --git a/src/extension-host/media-runtime-policy.test.ts b/src/extension-host/media-runtime-policy.test.ts index 4442fd5cd67..6321d3604a1 100644 --- a/src/extension-host/media-runtime-policy.test.ts +++ b/src/extension-host/media-runtime-policy.test.ts @@ -1,6 +1,63 @@ import { describe, expect, it, vi } from "vitest"; vi.mock("./runtime-backend-catalog.js", () => ({ + listExtensionHostMediaRuntimeBackendCatalogEntries: vi.fn(() => [ + { + id: "capability.runtime-backend:media.audio:deepgram", + family: "capability.runtime-backend", + subsystemId: "media.audio", + backendId: "deepgram", + source: "builtin", + defaultRank: 0, + selectorKeys: ["deepgram"], + capabilities: ["audio"], + metadata: { autoSelectable: true }, + }, + { + id: "capability.runtime-backend:media.audio:openai", + family: "capability.runtime-backend", + subsystemId: "media.audio", + backendId: "openai", + source: "builtin", + defaultRank: 1, + selectorKeys: ["openai"], + capabilities: ["audio"], + metadata: { autoSelectable: true }, + }, + { + id: "capability.runtime-backend:media.image:openai", + family: "capability.runtime-backend", + subsystemId: "media.image", + backendId: "openai", + source: "builtin", + defaultRank: 0, + selectorKeys: ["openai"], + capabilities: ["image"], + metadata: { autoSelectable: true, defaultModel: "openai-default" }, + }, + { + id: "capability.runtime-backend:media.image:google", + family: "capability.runtime-backend", + subsystemId: "media.image", + backendId: "google", + source: "builtin", + defaultRank: 1, + selectorKeys: ["google", "gemini"], + capabilities: ["image"], + metadata: { autoSelectable: true, defaultModel: "google-default" }, + }, + { + id: "capability.runtime-backend:media.video:openai", + family: "capability.runtime-backend", + subsystemId: "media.video", + backendId: "openai", + source: "builtin", + defaultRank: 0, + selectorKeys: ["openai"], + capabilities: ["video"], + metadata: { autoSelectable: true }, + }, + ]), listExtensionHostMediaAutoRuntimeBackendIds: vi.fn( (capability: "audio" | "image" | "video") => ({ diff --git a/src/extension-host/media-runtime-policy.ts b/src/extension-host/media-runtime-policy.ts index ba9dd5e0e6d..9115fb7ae97 100644 --- a/src/extension-host/media-runtime-policy.ts +++ b/src/extension-host/media-runtime-policy.ts @@ -1,9 +1,10 @@ import type { MediaUnderstandingCapability } from "../media-understanding/types.js"; import { normalizeExtensionHostMediaProviderId } from "./media-runtime-registry.js"; import { - listExtensionHostMediaAutoRuntimeBackendIds, + listExtensionHostMediaRuntimeBackendCatalogEntries, resolveExtensionHostMediaRuntimeDefaultModel, } from "./runtime-backend-catalog.js"; +import { resolveExtensionHostRuntimeBackendIdsByPolicy } from "./runtime-backend-policy.js"; export type ExtensionHostMediaActiveModel = { provider: string; @@ -15,6 +16,18 @@ export type ExtensionHostMediaProviderCandidate = { model?: string; }; +function resolveExtensionHostMediaRuntimeSubsystemId( + capability: MediaUnderstandingCapability, +): "media.audio" | "media.image" | "media.video" { + if (capability === "audio") { + return "media.audio"; + } + if (capability === "video") { + return "media.video"; + } + return "media.image"; +} + function resolveExtensionHostMediaCandidateModel(params: { capability: MediaUnderstandingCapability; provider: string; @@ -39,16 +52,16 @@ export function resolveExtensionHostMediaProviderCandidates(params: { activeModel?: ExtensionHostMediaActiveModel; }): readonly ExtensionHostMediaProviderCandidate[] { const candidates: ExtensionHostMediaProviderCandidate[] = []; - const seen = new Set(); - - const pushCandidate = (provider: string | undefined): void => { - const normalized = provider?.trim() - ? normalizeExtensionHostMediaProviderId(provider) - : undefined; - if (!normalized || seen.has(normalized)) { - return; - } - seen.add(normalized); + const preferredProvider = params.activeModel?.provider?.trim() + ? normalizeExtensionHostMediaProviderId(params.activeModel.provider) + : undefined; + for (const provider of resolveExtensionHostRuntimeBackendIdsByPolicy({ + entries: listExtensionHostMediaRuntimeBackendCatalogEntries(), + subsystemId: resolveExtensionHostMediaRuntimeSubsystemId(params.capability), + preferredBackendId: preferredProvider, + include: (entry) => entry.metadata?.autoSelectable === true, + })) { + const normalized = normalizeExtensionHostMediaProviderId(provider); candidates.push({ provider: normalized, model: resolveExtensionHostMediaCandidateModel({ @@ -57,11 +70,6 @@ export function resolveExtensionHostMediaProviderCandidates(params: { activeModel: params.activeModel, }), }); - }; - - pushCandidate(params.activeModel?.provider); - for (const providerId of listExtensionHostMediaAutoRuntimeBackendIds(params.capability)) { - pushCandidate(providerId); } return candidates; diff --git a/src/extension-host/runtime-backend-policy.test.ts b/src/extension-host/runtime-backend-policy.test.ts new file mode 100644 index 00000000000..e029769ae45 --- /dev/null +++ b/src/extension-host/runtime-backend-policy.test.ts @@ -0,0 +1,92 @@ +import { describe, expect, it } from "vitest"; +import { resolveExtensionHostRuntimeBackendIdsByPolicy } from "./runtime-backend-policy.js"; + +const entries = [ + { + id: "capability.runtime-backend:media.image:openai", + family: "capability.runtime-backend", + subsystemId: "media.image", + backendId: "openai", + source: "builtin", + defaultRank: 0, + selectorKeys: ["openai"], + capabilities: ["image"], + metadata: { autoSelectable: true }, + }, + { + id: "capability.runtime-backend:media.image:google", + family: "capability.runtime-backend", + subsystemId: "media.image", + backendId: "google", + source: "builtin", + defaultRank: 1, + selectorKeys: ["google", "gemini"], + capabilities: ["image"], + metadata: { autoSelectable: true }, + }, + { + id: "capability.runtime-backend:media.image:custom", + family: "capability.runtime-backend", + subsystemId: "media.image", + backendId: "custom", + source: "builtin", + defaultRank: 2, + selectorKeys: ["custom"], + capabilities: ["image"], + metadata: { autoSelectable: false }, + }, + { + id: "capability.runtime-backend:tts:edge", + family: "capability.runtime-backend", + subsystemId: "tts", + backendId: "edge", + source: "builtin", + defaultRank: 0, + selectorKeys: ["edge"], + capabilities: ["tts.synthesis"], + }, + { + id: "capability.runtime-backend:tts:openai", + family: "capability.runtime-backend", + subsystemId: "tts", + backendId: "openai", + source: "builtin", + defaultRank: 1, + selectorKeys: ["openai"], + capabilities: ["tts.synthesis", "tts.telephony"], + }, +] as const; + +describe("runtime-backend-policy", () => { + it("resolves the default-ranked filtered chain when no preferred backend is provided", () => { + expect( + resolveExtensionHostRuntimeBackendIdsByPolicy({ + entries, + subsystemId: "media.image", + include: (entry) => entry.metadata?.autoSelectable === true, + }), + ).toEqual(["openai", "google"]); + }); + + it("keeps the preferred backend first even when it is outside the filtered chain", () => { + expect( + resolveExtensionHostRuntimeBackendIdsByPolicy({ + entries, + subsystemId: "media.image", + preferredBackendId: "missing-provider", + include: (entry) => entry.metadata?.autoSelectable === true, + }), + ).toEqual(["missing-provider", "openai", "google"]); + }); + + it("falls back to an explicit backend id when no filtered default exists", () => { + expect( + resolveExtensionHostRuntimeBackendIdsByPolicy({ + entries, + subsystemId: "tts", + include: () => false, + fallbackBackendId: "edge", + }), + ).toEqual(["edge"]); + }); +}); diff --git a/src/extension-host/runtime-backend-policy.ts b/src/extension-host/runtime-backend-policy.ts new file mode 100644 index 00000000000..f2a87a6d997 --- /dev/null +++ b/src/extension-host/runtime-backend-policy.ts @@ -0,0 +1,47 @@ +import { + resolveExtensionHostDefaultRuntimeBackendIdByArbitration, + resolveExtensionHostRuntimeBackendFallbackChainByArbitration, +} from "./runtime-backend-arbitration.js"; +import type { + ExtensionHostRuntimeBackendCatalogEntry, + ExtensionHostRuntimeBackendSubsystemId, +} from "./runtime-backend-catalog.js"; + +type ExtensionHostRuntimeBackendPolicyPredicate = ( + entry: ExtensionHostRuntimeBackendCatalogEntry, +) => boolean; + +export function resolveExtensionHostRuntimeBackendIdsByPolicy(params: { + entries: readonly ExtensionHostRuntimeBackendCatalogEntry[]; + subsystemId: ExtensionHostRuntimeBackendSubsystemId; + preferredBackendId?: string; + include?: ExtensionHostRuntimeBackendPolicyPredicate; + fallbackBackendId?: string; +}): readonly string[] { + const preferredBackendId = params.preferredBackendId?.trim(); + if (preferredBackendId) { + return resolveExtensionHostRuntimeBackendFallbackChainByArbitration({ + entries: params.entries, + subsystemId: params.subsystemId, + preferredBackendId, + include: params.include, + }); + } + + const defaultBackendId = resolveExtensionHostDefaultRuntimeBackendIdByArbitration({ + entries: params.entries, + subsystemId: params.subsystemId, + include: params.include, + fallbackBackendId: params.fallbackBackendId, + }); + if (!defaultBackendId) { + return []; + } + + return resolveExtensionHostRuntimeBackendFallbackChainByArbitration({ + entries: params.entries, + subsystemId: params.subsystemId, + preferredBackendId: defaultBackendId, + include: params.include, + }); +} diff --git a/src/extension-host/tts-runtime-policy.ts b/src/extension-host/tts-runtime-policy.ts index 3a27c395be0..68242a0ff6d 100644 --- a/src/extension-host/tts-runtime-policy.ts +++ b/src/extension-host/tts-runtime-policy.ts @@ -1,12 +1,9 @@ import type { TtsProvider } from "../config/types.tts.js"; -import { - resolveExtensionHostDefaultRuntimeBackendIdByArbitration, - resolveExtensionHostRuntimeBackendFallbackChainByArbitration, -} from "./runtime-backend-arbitration.js"; import { listExtensionHostTtsRuntimeBackendCatalogEntries, type ExtensionHostRuntimeBackendCatalogEntry, } from "./runtime-backend-catalog.js"; +import { resolveExtensionHostRuntimeBackendIdsByPolicy } from "./runtime-backend-policy.js"; import type { ResolvedTtsConfig } from "./tts-config.js"; import { isExtensionHostTtsProviderConfigured } from "./tts-runtime-registry.js"; @@ -18,19 +15,19 @@ function isConfiguredTtsRuntimeBackend( } export function resolveExtensionHostDefaultTtsProvider(config: ResolvedTtsConfig): TtsProvider { - return (resolveExtensionHostDefaultRuntimeBackendIdByArbitration({ + return (resolveExtensionHostRuntimeBackendIdsByPolicy({ entries: listExtensionHostTtsRuntimeBackendCatalogEntries(), subsystemId: "tts", include: (entry) => isConfiguredTtsRuntimeBackend(config, entry), fallbackBackendId: "edge", - }) ?? "edge") as TtsProvider; + })[0] ?? "edge") as TtsProvider; } export function resolveExtensionHostTtsFallbackProviders(params: { config: ResolvedTtsConfig; preferredProvider: TtsProvider; }): readonly TtsProvider[] { - return resolveExtensionHostRuntimeBackendFallbackChainByArbitration({ + return resolveExtensionHostRuntimeBackendIdsByPolicy({ entries: listExtensionHostTtsRuntimeBackendCatalogEntries(), subsystemId: "tts", preferredBackendId: params.preferredProvider,