From 5a850f34bf95168a9b83f07c58c0fec2d329c818 Mon Sep 17 00:00:00 2001 From: Gustavo Madeira Santana Date: Sun, 15 Mar 2026 21:45:38 +0000 Subject: [PATCH] Media: add runtime selection policy --- src/extension-host/media-runtime-auto.ts | 67 +++-------------- .../media-runtime-policy.test.ts | 72 +++++++++++++++++++ src/extension-host/media-runtime-policy.ts | 68 ++++++++++++++++++ 3 files changed, 150 insertions(+), 57 deletions(-) create mode 100644 src/extension-host/media-runtime-policy.test.ts create mode 100644 src/extension-host/media-runtime-policy.ts diff --git a/src/extension-host/media-runtime-auto.ts b/src/extension-host/media-runtime-auto.ts index a9a62a7c384..35825eec33b 100644 --- a/src/extension-host/media-runtime-auto.ts +++ b/src/extension-host/media-runtime-auto.ts @@ -19,9 +19,9 @@ import { extractGeminiResponse } from "../media-understanding/output-extract.js" import type { MediaUnderstandingCapability } from "../media-understanding/types.js"; import { runExec } from "../process/exec.js"; import { - listExtensionHostMediaAutoRuntimeBackendIds, - resolveExtensionHostMediaRuntimeDefaultModel, -} from "./runtime-backend-catalog.js"; + resolveExtensionHostMediaProviderCandidates, + type ExtensionHostMediaActiveModel, +} from "./media-runtime-policy.js"; export type ActiveMediaModel = { provider: string; @@ -313,7 +313,7 @@ async function resolveKeyEntry(params: { agentDir?: string; providerRegistry: ProviderRegistry; capability: MediaUnderstandingCapability; - activeModel?: ActiveMediaModel; + activeModel?: ExtensionHostMediaActiveModel; }): Promise { const { cfg, agentDir, providerRegistry, capability } = params; const checkProvider = async ( @@ -341,53 +341,11 @@ async function resolveKeyEntry(params: { } }; - if (capability === "image") { - const activeProvider = params.activeModel?.provider?.trim(); - if (activeProvider) { - const activeEntry = await checkProvider(activeProvider, params.activeModel?.model); - if (activeEntry) { - return activeEntry; - } - } - for (const providerId of listExtensionHostMediaAutoRuntimeBackendIds("image")) { - const model = resolveExtensionHostMediaRuntimeDefaultModel({ - capability: "image", - backendId: providerId, - }); - const entry = await checkProvider(providerId, model); - if (entry) { - return entry; - } - } - return null; - } - - if (capability === "video") { - const activeProvider = params.activeModel?.provider?.trim(); - if (activeProvider) { - const activeEntry = await checkProvider(activeProvider, params.activeModel?.model); - if (activeEntry) { - return activeEntry; - } - } - for (const providerId of listExtensionHostMediaAutoRuntimeBackendIds("video")) { - const entry = await checkProvider(providerId, undefined); - if (entry) { - return entry; - } - } - return null; - } - - const activeProvider = params.activeModel?.provider?.trim(); - if (activeProvider) { - const activeEntry = await checkProvider(activeProvider, params.activeModel?.model); - if (activeEntry) { - return activeEntry; - } - } - for (const providerId of listExtensionHostMediaAutoRuntimeBackendIds("audio")) { - const entry = await checkProvider(providerId, undefined); + for (const candidate of resolveExtensionHostMediaProviderCandidates({ + capability, + activeModel: params.activeModel, + })) { + const entry = await checkProvider(candidate.provider, candidate.model); if (entry) { return entry; } @@ -472,12 +430,7 @@ export async function resolveAutoImageModel(params: { if (!provider) { return null; } - const model = - entry.model ?? - resolveExtensionHostMediaRuntimeDefaultModel({ - capability: "image", - backendId: provider, - }); + const model = entry.model; if (!model) { return null; } diff --git a/src/extension-host/media-runtime-policy.test.ts b/src/extension-host/media-runtime-policy.test.ts new file mode 100644 index 00000000000..4442fd5cd67 --- /dev/null +++ b/src/extension-host/media-runtime-policy.test.ts @@ -0,0 +1,72 @@ +import { describe, expect, it, vi } from "vitest"; + +vi.mock("./runtime-backend-catalog.js", () => ({ + listExtensionHostMediaAutoRuntimeBackendIds: vi.fn( + (capability: "audio" | "image" | "video") => + ({ + audio: ["deepgram", "openai"], + image: ["openai", "google"], + video: ["openai"], + })[capability], + ), + resolveExtensionHostMediaRuntimeDefaultModel: vi.fn( + (params: { capability: "audio" | "image" | "video"; backendId: string }) => + params.capability === "image" ? `${params.backendId}-default` : undefined, + ), +})); + +vi.mock("./media-runtime-registry.js", () => ({ + normalizeExtensionHostMediaProviderId: vi.fn((id: string) => + id.trim().toLowerCase() === "gemini" ? "google" : id.trim().toLowerCase(), + ), +})); + +import { resolveExtensionHostMediaProviderCandidates } from "./media-runtime-policy.js"; + +describe("media-runtime-policy", () => { + it("puts the active provider first and keeps the configured model", () => { + expect( + resolveExtensionHostMediaProviderCandidates({ + capability: "image", + activeModel: { + provider: "Google", + model: "gemini-2.5-flash", + }, + }), + ).toEqual([ + { provider: "google", model: "gemini-2.5-flash" }, + { provider: "openai", model: "openai-default" }, + ]); + }); + + it("uses catalog-backed defaults for fallback image providers", () => { + expect( + resolveExtensionHostMediaProviderCandidates({ + capability: "image", + activeModel: { + provider: "missing-provider", + model: "ignored", + }, + }), + ).toEqual([ + { provider: "missing-provider", model: "ignored" }, + { provider: "openai", model: "openai-default" }, + { provider: "google", model: "google-default" }, + ]); + }); + + it("keeps non-image fallback candidates model-free", () => { + expect( + resolveExtensionHostMediaProviderCandidates({ + capability: "audio", + activeModel: { + provider: "openai", + model: "gpt-4o-mini-transcribe", + }, + }), + ).toEqual([ + { provider: "openai", model: "gpt-4o-mini-transcribe" }, + { provider: "deepgram", model: undefined }, + ]); + }); +}); diff --git a/src/extension-host/media-runtime-policy.ts b/src/extension-host/media-runtime-policy.ts new file mode 100644 index 00000000000..ba9dd5e0e6d --- /dev/null +++ b/src/extension-host/media-runtime-policy.ts @@ -0,0 +1,68 @@ +import type { MediaUnderstandingCapability } from "../media-understanding/types.js"; +import { normalizeExtensionHostMediaProviderId } from "./media-runtime-registry.js"; +import { + listExtensionHostMediaAutoRuntimeBackendIds, + resolveExtensionHostMediaRuntimeDefaultModel, +} from "./runtime-backend-catalog.js"; + +export type ExtensionHostMediaActiveModel = { + provider: string; + model?: string; +}; + +export type ExtensionHostMediaProviderCandidate = { + provider: string; + model?: string; +}; + +function resolveExtensionHostMediaCandidateModel(params: { + capability: MediaUnderstandingCapability; + provider: string; + activeModel?: ExtensionHostMediaActiveModel; +}): string | undefined { + const activeProvider = params.activeModel?.provider?.trim(); + if ( + activeProvider && + normalizeExtensionHostMediaProviderId(activeProvider) === + normalizeExtensionHostMediaProviderId(params.provider) + ) { + return params.activeModel?.model; + } + return resolveExtensionHostMediaRuntimeDefaultModel({ + capability: params.capability, + backendId: params.provider, + }); +} + +export function resolveExtensionHostMediaProviderCandidates(params: { + capability: MediaUnderstandingCapability; + activeModel?: ExtensionHostMediaActiveModel; +}): readonly ExtensionHostMediaProviderCandidate[] { + const candidates: ExtensionHostMediaProviderCandidate[] = []; + const seen = new Set(); + + const pushCandidate = (provider: string | undefined): void => { + const normalized = provider?.trim() + ? normalizeExtensionHostMediaProviderId(provider) + : undefined; + if (!normalized || seen.has(normalized)) { + return; + } + seen.add(normalized); + candidates.push({ + provider: normalized, + model: resolveExtensionHostMediaCandidateModel({ + capability: params.capability, + provider: normalized, + activeModel: params.activeModel, + }), + }); + }; + + pushCandidate(params.activeModel?.provider); + for (const providerId of listExtensionHostMediaAutoRuntimeBackendIds(params.capability)) { + pushCandidate(providerId); + } + + return candidates; +}