diff --git a/src/agents/model-selection.ts b/src/agents/model-selection.ts index 0f8f5568618..2013d7a7828 100644 --- a/src/agents/model-selection.ts +++ b/src/agents/model-selection.ts @@ -60,45 +60,6 @@ export function legacyModelKey(provider: string, model: string): string | null { return rawKey === canonicalKey ? null : rawKey; } -export function normalizeProviderId(provider: string): string { - const normalized = provider.trim().toLowerCase(); - if (normalized === "z.ai" || normalized === "z-ai") { - return "zai"; - } - if (normalized === "opencode-zen") { - return "opencode"; - } - if (normalized === "opencode-go-auth") { - return "opencode-go"; - } - if (normalized === "qwen") { - return "qwen-portal"; - } - if (normalized === "kimi-code") { - return "kimi-coding"; - } - if (normalized === "bedrock" || normalized === "aws-bedrock") { - return "amazon-bedrock"; - } - // Backward compatibility for older provider naming. - if (normalized === "bytedance" || normalized === "doubao") { - return "volcengine"; - } - return normalized; -} - -/** Normalize provider ID for auth lookup. Coding-plan variants share auth with base. */ -export function normalizeProviderIdForAuth(provider: string): string { - const normalized = normalizeProviderId(provider); - if (normalized === "volcengine-plan") { - return "volcengine"; - } - if (normalized === "byteplus-plan") { - return "byteplus"; - } - return normalized; -} - export function findNormalizedProviderValue( entries: Record | undefined, provider: string, diff --git a/src/agents/provider-id.test.ts b/src/agents/provider-id.test.ts new file mode 100644 index 00000000000..9fee7012c8c --- /dev/null +++ b/src/agents/provider-id.test.ts @@ -0,0 +1,24 @@ +import { describe, expect, it } from "vitest"; +import { normalizeProviderId, normalizeProviderIdForAuth } from "./provider-id.js"; + +describe("normalizeProviderId", () => { + it("applies provider aliases without pulling heavier model-selection dependencies", () => { + expect(normalizeProviderId("Anthropic")).toBe("anthropic"); + expect(normalizeProviderId("Z.ai")).toBe("zai"); + expect(normalizeProviderId("z-ai")).toBe("zai"); + expect(normalizeProviderId("OpenCode-Zen")).toBe("opencode"); + expect(normalizeProviderId("qwen")).toBe("qwen-portal"); + expect(normalizeProviderId("kimi-code")).toBe("kimi-coding"); + expect(normalizeProviderId("bedrock")).toBe("amazon-bedrock"); + expect(normalizeProviderId("aws-bedrock")).toBe("amazon-bedrock"); + expect(normalizeProviderId("doubao")).toBe("volcengine"); + }); +}); + +describe("normalizeProviderIdForAuth", () => { + it("maps coding-plan variants back to their base auth providers", () => { + expect(normalizeProviderIdForAuth("volcengine-plan")).toBe("volcengine"); + expect(normalizeProviderIdForAuth("byteplus-plan")).toBe("byteplus"); + expect(normalizeProviderIdForAuth("anthropic")).toBe("anthropic"); + }); +}); diff --git a/src/agents/provider-id.ts b/src/agents/provider-id.ts new file mode 100644 index 00000000000..6259e878264 --- /dev/null +++ b/src/agents/provider-id.ts @@ -0,0 +1,38 @@ +export function normalizeProviderId(provider: string): string { + const normalized = provider.trim().toLowerCase(); + if (normalized === "z.ai" || normalized === "z-ai") { + return "zai"; + } + if (normalized === "opencode-zen") { + return "opencode"; + } + if (normalized === "opencode-go-auth") { + return "opencode-go"; + } + if (normalized === "qwen") { + return "qwen-portal"; + } + if (normalized === "kimi-code") { + return "kimi-coding"; + } + if (normalized === "bedrock" || normalized === "aws-bedrock") { + return "amazon-bedrock"; + } + // Backward compatibility for older provider naming. + if (normalized === "bytedance" || normalized === "doubao") { + return "volcengine"; + } + return normalized; +} + +/** Normalize provider ID for auth lookup. Coding-plan variants share auth with base. */ +export function normalizeProviderIdForAuth(provider: string): string { + const normalized = normalizeProviderId(provider); + if (normalized === "volcengine-plan") { + return "volcengine"; + } + if (normalized === "byteplus-plan") { + return "byteplus"; + } + return normalized; +} diff --git a/src/extension-host/provider-discovery.test.ts b/src/extension-host/provider-discovery.test.ts new file mode 100644 index 00000000000..dec3d6a96d0 --- /dev/null +++ b/src/extension-host/provider-discovery.test.ts @@ -0,0 +1,107 @@ +import { describe, expect, it } from "vitest"; +import type { ModelProviderConfig } from "../config/types.js"; +import type { ProviderDiscoveryOrder, ProviderPlugin } from "../plugins/types.js"; +import { + groupExtensionHostDiscoveryProvidersByOrder, + normalizeExtensionHostDiscoveryResult, + resolveExtensionHostDiscoveryProviders, +} from "./provider-discovery.js"; + +function makeProvider(params: { + id: string; + label?: string; + order?: ProviderDiscoveryOrder; + discovery?: boolean; +}): ProviderPlugin { + return { + id: params.id, + label: params.label ?? params.id, + auth: [], + ...(params.discovery === false + ? {} + : { + discovery: { + ...(params.order ? { order: params.order } : {}), + run: async () => null, + }, + }), + }; +} + +function makeModelProviderConfig(overrides?: Partial): ModelProviderConfig { + return { + baseUrl: "http://127.0.0.1:8000/v1", + models: [], + ...overrides, + }; +} + +describe("resolveExtensionHostDiscoveryProviders", () => { + it("keeps only providers with discovery handlers", () => { + expect( + resolveExtensionHostDiscoveryProviders([ + makeProvider({ id: "simple" }), + makeProvider({ id: "hidden", discovery: false }), + ]).map((provider) => provider.id), + ).toEqual(["simple"]); + }); +}); + +describe("groupExtensionHostDiscoveryProvidersByOrder", () => { + it("groups providers by declared order and sorts labels within each group", () => { + const grouped = groupExtensionHostDiscoveryProvidersByOrder([ + makeProvider({ id: "late-b", label: "Zulu" }), + makeProvider({ id: "late-a", label: "Alpha" }), + makeProvider({ id: "paired", label: "Paired", order: "paired" }), + makeProvider({ id: "profile", label: "Profile", order: "profile" }), + makeProvider({ id: "simple", label: "Simple", order: "simple" }), + ]); + + expect(grouped.simple.map((provider) => provider.id)).toEqual(["simple"]); + expect(grouped.profile.map((provider) => provider.id)).toEqual(["profile"]); + expect(grouped.paired.map((provider) => provider.id)).toEqual(["paired"]); + expect(grouped.late.map((provider) => provider.id)).toEqual(["late-a", "late-b"]); + }); +}); + +describe("normalizeExtensionHostDiscoveryResult", () => { + it("maps a single provider result to the provider id", () => { + const provider = makeProvider({ id: "Ollama" }); + const normalized = normalizeExtensionHostDiscoveryResult({ + provider, + result: { + provider: makeModelProviderConfig({ + baseUrl: "http://127.0.0.1:11434", + api: "ollama", + }), + }, + }); + + expect(normalized).toEqual({ + ollama: { + baseUrl: "http://127.0.0.1:11434", + api: "ollama", + models: [], + }, + }); + }); + + it("normalizes keys for multi-provider discovery results", () => { + const normalized = normalizeExtensionHostDiscoveryResult({ + provider: makeProvider({ id: "ignored" }), + result: { + providers: { + " VLLM ": makeModelProviderConfig(), + "": makeModelProviderConfig({ baseUrl: "http://ignored" }), + }, + }, + }); + + expect(normalized).toEqual({ + vllm: { + baseUrl: "http://127.0.0.1:8000/v1", + models: [], + }, + }); + }); +}); diff --git a/src/extension-host/provider-discovery.ts b/src/extension-host/provider-discovery.ts new file mode 100644 index 00000000000..0eb59b14d5a --- /dev/null +++ b/src/extension-host/provider-discovery.ts @@ -0,0 +1,61 @@ +import { normalizeProviderId } from "../agents/provider-id.js"; +import type { ModelProviderConfig } from "../config/types.js"; +import type { ProviderDiscoveryOrder, ProviderPlugin } from "../plugins/types.js"; + +const DISCOVERY_ORDER: readonly ProviderDiscoveryOrder[] = ["simple", "profile", "paired", "late"]; + +export function resolveExtensionHostDiscoveryProviders( + providers: ProviderPlugin[], +): ProviderPlugin[] { + return providers.filter((provider) => provider.discovery); +} + +export function groupExtensionHostDiscoveryProvidersByOrder( + providers: ProviderPlugin[], +): Record { + const grouped = { + simple: [], + profile: [], + paired: [], + late: [], + } as Record; + + for (const provider of providers) { + const order = provider.discovery?.order ?? "late"; + grouped[order].push(provider); + } + + for (const order of DISCOVERY_ORDER) { + grouped[order].sort((a, b) => a.label.localeCompare(b.label)); + } + + return grouped; +} + +export function normalizeExtensionHostDiscoveryResult(params: { + provider: ProviderPlugin; + result: + | { provider: ModelProviderConfig } + | { providers: Record } + | null + | undefined; +}): Record { + const result = params.result; + if (!result) { + return {}; + } + + if ("provider" in result) { + return { [normalizeProviderId(params.provider.id)]: result.provider }; + } + + const normalized: Record = {}; + for (const [key, value] of Object.entries(result.providers)) { + const normalizedKey = normalizeProviderId(key); + if (!normalizedKey || !value) { + continue; + } + normalized[normalizedKey] = value; + } + return normalized; +} diff --git a/src/plugins/provider-discovery.ts b/src/plugins/provider-discovery.ts index e249bf6e45a..1fb95d221a5 100644 --- a/src/plugins/provider-discovery.ts +++ b/src/plugins/provider-discovery.ts @@ -1,6 +1,10 @@ -import { normalizeProviderId } from "../agents/model-selection.js"; import type { OpenClawConfig } from "../config/config.js"; import type { ModelProviderConfig } from "../config/types.js"; +import { + groupExtensionHostDiscoveryProvidersByOrder, + normalizeExtensionHostDiscoveryResult, + resolveExtensionHostDiscoveryProviders, +} from "../extension-host/provider-discovery.js"; import { resolvePluginProviders } from "./providers.js"; import type { ProviderDiscoveryOrder, ProviderPlugin } from "./types.js"; @@ -51,24 +55,7 @@ export function normalizePluginDiscoveryResult(params: { | null | undefined; }): Record { - const result = params.result; - if (!result) { - return {}; - } - - if ("provider" in result) { - return { [normalizeProviderId(params.provider.id)]: result.provider }; - } - - const normalized: Record = {}; - for (const [key, value] of Object.entries(result.providers)) { - const normalizedKey = normalizeProviderId(key); - if (!normalizedKey || !value) { - continue; - } - normalized[normalizedKey] = value; - } - return normalized; + return normalizeExtensionHostDiscoveryResult(params); } export function runProviderCatalog(params: {