diff --git a/src/extension-host/contributions/acp-runtime-backend-registry.test.ts b/src/extension-host/contributions/acp-runtime-backend-registry.test.ts new file mode 100644 index 00000000000..a178d987c96 --- /dev/null +++ b/src/extension-host/contributions/acp-runtime-backend-registry.test.ts @@ -0,0 +1,85 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { AcpRuntimeError } from "../../acp/runtime/errors.js"; +import type { AcpRuntime } from "../../acp/runtime/types.js"; +import { + __testing, + getExtensionHostAcpRuntimeBackend, + registerExtensionHostAcpRuntimeBackend, + requireExtensionHostAcpRuntimeBackend, + unregisterExtensionHostAcpRuntimeBackend, +} from "./acp-runtime-backend-registry.js"; + +function createRuntimeStub(): AcpRuntime { + return { + ensureSession: vi.fn(async (input) => ({ + sessionKey: input.sessionKey, + backend: "stub", + runtimeSessionName: `${input.sessionKey}:runtime`, + })), + runTurn: vi.fn(async function* () {}), + cancel: vi.fn(async () => {}), + close: vi.fn(async () => {}), + }; +} + +describe("extension host acp runtime backend registry", () => { + beforeEach(() => { + __testing.resetExtensionHostAcpRuntimeBackendsForTests(); + }); + + it("registers and resolves backends by id", () => { + const runtime = createRuntimeStub(); + registerExtensionHostAcpRuntimeBackend({ id: "acpx", runtime }); + + const backend = getExtensionHostAcpRuntimeBackend("acpx"); + expect(backend?.id).toBe("acpx"); + expect(backend?.runtime).toBe(runtime); + }); + + it("prefers a healthy backend when resolving without explicit id", () => { + registerExtensionHostAcpRuntimeBackend({ + id: "unhealthy", + runtime: createRuntimeStub(), + healthy: () => false, + }); + registerExtensionHostAcpRuntimeBackend({ + id: "healthy", + runtime: createRuntimeStub(), + healthy: () => true, + }); + + expect(getExtensionHostAcpRuntimeBackend()?.id).toBe("healthy"); + }); + + it("throws typed errors for missing or unavailable backends", () => { + expect(() => requireExtensionHostAcpRuntimeBackend()).toThrowError(AcpRuntimeError); + + registerExtensionHostAcpRuntimeBackend({ + id: "acpx", + runtime: createRuntimeStub(), + healthy: () => false, + }); + + try { + requireExtensionHostAcpRuntimeBackend("acpx"); + throw new Error("expected requireExtensionHostAcpRuntimeBackend to throw"); + } catch (error) { + expect(error).toBeInstanceOf(AcpRuntimeError); + expect((error as AcpRuntimeError).code).toBe("ACP_BACKEND_UNAVAILABLE"); + } + }); + + it("shares backend state globally for cross-loader access", () => { + const runtime = createRuntimeStub(); + const sharedState = __testing.getExtensionHostAcpRuntimeRegistryGlobalStateForTests(); + + sharedState.backendsById.set("acpx", { + id: "acpx", + runtime, + }); + + expect(getExtensionHostAcpRuntimeBackend("acpx")?.runtime).toBe(runtime); + unregisterExtensionHostAcpRuntimeBackend("acpx"); + expect(getExtensionHostAcpRuntimeBackend("acpx")).toBeNull(); + }); +}); diff --git a/src/extension-host/contributions/acp-runtime-backend-registry.ts b/src/extension-host/contributions/acp-runtime-backend-registry.ts new file mode 100644 index 00000000000..68b16b7b9ba --- /dev/null +++ b/src/extension-host/contributions/acp-runtime-backend-registry.ts @@ -0,0 +1,124 @@ +import { AcpRuntimeError } from "../../acp/runtime/errors.js"; +import type { AcpRuntime } from "../../acp/runtime/types.js"; + +export type ExtensionHostAcpRuntimeBackend = { + id: string; + runtime: AcpRuntime; + healthy?: () => boolean; +}; + +type ExtensionHostAcpRuntimeRegistryGlobalState = { + backendsById: Map; +}; + +const ACP_RUNTIME_REGISTRY_STATE_KEY = Symbol.for("openclaw.acpRuntimeRegistryState"); + +function createExtensionHostAcpRuntimeRegistryGlobalState(): ExtensionHostAcpRuntimeRegistryGlobalState { + return { + backendsById: new Map(), + }; +} + +function resolveExtensionHostAcpRuntimeRegistryGlobalState(): ExtensionHostAcpRuntimeRegistryGlobalState { + const runtimeGlobal = globalThis as typeof globalThis & { + [ACP_RUNTIME_REGISTRY_STATE_KEY]?: ExtensionHostAcpRuntimeRegistryGlobalState; + }; + if (!runtimeGlobal[ACP_RUNTIME_REGISTRY_STATE_KEY]) { + runtimeGlobal[ACP_RUNTIME_REGISTRY_STATE_KEY] = + createExtensionHostAcpRuntimeRegistryGlobalState(); + } + return runtimeGlobal[ACP_RUNTIME_REGISTRY_STATE_KEY]; +} + +const EXTENSION_HOST_ACP_BACKENDS_BY_ID = + resolveExtensionHostAcpRuntimeRegistryGlobalState().backendsById; + +function normalizeBackendId(id: string | undefined): string { + return id?.trim().toLowerCase() || ""; +} + +function isBackendHealthy(backend: ExtensionHostAcpRuntimeBackend): boolean { + if (!backend.healthy) { + return true; + } + try { + return backend.healthy(); + } catch { + return false; + } +} + +export function registerExtensionHostAcpRuntimeBackend( + backend: ExtensionHostAcpRuntimeBackend, +): void { + const id = normalizeBackendId(backend.id); + if (!id) { + throw new Error("ACP runtime backend id is required"); + } + if (!backend.runtime) { + throw new Error(`ACP runtime backend "${id}" is missing runtime implementation`); + } + EXTENSION_HOST_ACP_BACKENDS_BY_ID.set(id, { + ...backend, + id, + }); +} + +export function unregisterExtensionHostAcpRuntimeBackend(id: string): void { + const normalized = normalizeBackendId(id); + if (!normalized) { + return; + } + EXTENSION_HOST_ACP_BACKENDS_BY_ID.delete(normalized); +} + +export function getExtensionHostAcpRuntimeBackend( + id?: string, +): ExtensionHostAcpRuntimeBackend | null { + const normalized = normalizeBackendId(id); + if (normalized) { + return EXTENSION_HOST_ACP_BACKENDS_BY_ID.get(normalized) ?? null; + } + if (EXTENSION_HOST_ACP_BACKENDS_BY_ID.size === 0) { + return null; + } + for (const backend of EXTENSION_HOST_ACP_BACKENDS_BY_ID.values()) { + if (isBackendHealthy(backend)) { + return backend; + } + } + return EXTENSION_HOST_ACP_BACKENDS_BY_ID.values().next().value ?? null; +} + +export function requireExtensionHostAcpRuntimeBackend(id?: string): ExtensionHostAcpRuntimeBackend { + const normalized = normalizeBackendId(id); + const backend = getExtensionHostAcpRuntimeBackend(normalized || undefined); + if (!backend) { + throw new AcpRuntimeError( + "ACP_BACKEND_MISSING", + "ACP runtime backend is not configured. Install and enable the acpx runtime plugin.", + ); + } + if (!isBackendHealthy(backend)) { + throw new AcpRuntimeError( + "ACP_BACKEND_UNAVAILABLE", + "ACP runtime backend is currently unavailable. Try again in a moment.", + ); + } + if (normalized && backend.id !== normalized) { + throw new AcpRuntimeError( + "ACP_BACKEND_MISSING", + `ACP runtime backend "${normalized}" is not registered.`, + ); + } + return backend; +} + +export const __testing = { + resetExtensionHostAcpRuntimeBackendsForTests() { + EXTENSION_HOST_ACP_BACKENDS_BY_ID.clear(); + }, + getExtensionHostAcpRuntimeRegistryGlobalStateForTests() { + return resolveExtensionHostAcpRuntimeRegistryGlobalState(); + }, +}; diff --git a/src/extension-host/contributions/cli-lifecycle.test.ts b/src/extension-host/contributions/cli-lifecycle.test.ts new file mode 100644 index 00000000000..504a9bce6bc --- /dev/null +++ b/src/extension-host/contributions/cli-lifecycle.test.ts @@ -0,0 +1,97 @@ +import { Command } from "commander"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { createEmptyPluginRegistry } from "../../plugins/registry.js"; +import type { PluginLogger } from "../../plugins/types.js"; +import { registerExtensionHostCliCommands } from "./cli-lifecycle.js"; + +function createLogger(): PluginLogger { + return { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }; +} + +describe("registerExtensionHostCliCommands", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("skips overlapping command registrations", () => { + const program = new Command(); + program.command("memory"); + const registry = createEmptyPluginRegistry(); + const memoryRegister = vi.fn(); + const otherRegister = vi.fn(); + registry.cliRegistrars.push( + { + pluginId: "memory-core", + register: memoryRegister, + commands: ["memory"], + source: "bundled", + }, + { + pluginId: "other", + register: otherRegister, + commands: ["other"], + source: "bundled", + }, + ); + const logger = createLogger(); + + registerExtensionHostCliCommands({ + program, + registry, + config: {} as never, + workspaceDir: "/tmp/workspace", + logger, + }); + + expect(memoryRegister).not.toHaveBeenCalled(); + expect(otherRegister).toHaveBeenCalledOnce(); + expect(logger.debug).toHaveBeenCalledWith( + "plugin CLI register skipped (memory-core): command already registered (memory)", + ); + }); + + it("warns on sync and async registration failures", async () => { + const program = new Command(); + const registry = createEmptyPluginRegistry(); + registry.cliRegistrars.push( + { + pluginId: "sync-fail", + register: () => { + throw new Error("sync fail"); + }, + commands: ["sync"], + source: "bundled", + }, + { + pluginId: "async-fail", + register: async () => { + throw new Error("async fail"); + }, + commands: ["async"], + source: "bundled", + }, + ); + const logger = createLogger(); + + registerExtensionHostCliCommands({ + program, + registry, + config: {} as never, + workspaceDir: "/tmp/workspace", + logger, + }); + await Promise.resolve(); + + expect(logger.warn).toHaveBeenCalledWith( + "plugin CLI register failed (sync-fail): Error: sync fail", + ); + expect(logger.warn).toHaveBeenCalledWith( + "plugin CLI register failed (async-fail): Error: async fail", + ); + }); +}); diff --git a/src/extension-host/contributions/cli-lifecycle.ts b/src/extension-host/contributions/cli-lifecycle.ts new file mode 100644 index 00000000000..665db1b2f2e --- /dev/null +++ b/src/extension-host/contributions/cli-lifecycle.ts @@ -0,0 +1,47 @@ +import type { Command } from "commander"; +import type { OpenClawConfig } from "../../config/config.js"; +import type { PluginRegistry } from "../../plugins/registry.js"; +import type { PluginLogger } from "../../plugins/types.js"; +import { listExtensionHostCliRegistrations } from "./runtime-registry.js"; + +export function registerExtensionHostCliCommands(params: { + program: Command; + registry: PluginRegistry; + config: OpenClawConfig; + workspaceDir: string; + logger: PluginLogger; +}): void { + const existingCommands = new Set(params.program.commands.map((cmd) => cmd.name())); + + for (const entry of listExtensionHostCliRegistrations(params.registry)) { + if (entry.commands.length > 0) { + const overlaps = entry.commands.filter((command) => existingCommands.has(command)); + if (overlaps.length > 0) { + params.logger.debug( + `plugin CLI register skipped (${entry.pluginId}): command already registered (${overlaps.join( + ", ", + )})`, + ); + continue; + } + } + try { + const result = entry.register({ + program: params.program, + config: params.config, + workspaceDir: params.workspaceDir, + logger: params.logger, + }); + if (result && typeof result.then === "function") { + void result.catch((err) => { + params.logger.warn(`plugin CLI register failed (${entry.pluginId}): ${String(err)}`); + }); + } + for (const command of entry.commands) { + existingCommands.add(command); + } + } catch (err) { + params.logger.warn(`plugin CLI register failed (${entry.pluginId}): ${String(err)}`); + } + } +} diff --git a/src/extension-host/contributions/command-runtime.test.ts b/src/extension-host/contributions/command-runtime.test.ts new file mode 100644 index 00000000000..fe1cbd6eff2 --- /dev/null +++ b/src/extension-host/contributions/command-runtime.test.ts @@ -0,0 +1,93 @@ +import { afterEach, describe, expect, it } from "vitest"; +import { + clearExtensionHostPluginCommands, + getExtensionHostPluginCommandSpecs, + listExtensionHostPluginCommands, + registerExtensionHostPluginCommand, +} from "./command-runtime.js"; + +afterEach(() => { + clearExtensionHostPluginCommands(); +}); + +describe("extension host command runtime", () => { + it("rejects malformed runtime command shapes", () => { + const invalidName = registerExtensionHostPluginCommand("demo-plugin", { + name: undefined as unknown as string, + description: "Demo", + handler: async () => ({ text: "ok" }), + }); + expect(invalidName).toEqual({ + ok: false, + error: "Command name must be a string", + }); + + const invalidDescription = registerExtensionHostPluginCommand("demo-plugin", { + name: "demo", + description: undefined as unknown as string, + handler: async () => ({ text: "ok" }), + }); + expect(invalidDescription).toEqual({ + ok: false, + error: "Command description must be a string", + }); + }); + + it("normalizes command metadata for downstream consumers", () => { + const result = registerExtensionHostPluginCommand("demo-plugin", { + name: " demo_cmd ", + description: " Demo command ", + handler: async () => ({ text: "ok" }), + }); + expect(result).toEqual({ ok: true }); + expect(listExtensionHostPluginCommands()).toEqual([ + { + name: "demo_cmd", + description: "Demo command", + pluginId: "demo-plugin", + }, + ]); + expect(getExtensionHostPluginCommandSpecs()).toEqual([ + { + name: "demo_cmd", + description: "Demo command", + acceptsArgs: false, + }, + ]); + }); + + it("supports provider-specific native command aliases", () => { + const result = registerExtensionHostPluginCommand("demo-plugin", { + name: "voice", + nativeNames: { + default: "talkvoice", + discord: "discordvoice", + }, + description: "Demo command", + handler: async () => ({ text: "ok" }), + }); + + expect(result).toEqual({ ok: true }); + expect(getExtensionHostPluginCommandSpecs()).toEqual([ + { + name: "talkvoice", + description: "Demo command", + acceptsArgs: false, + }, + ]); + expect(getExtensionHostPluginCommandSpecs("discord")).toEqual([ + { + name: "discordvoice", + description: "Demo command", + acceptsArgs: false, + }, + ]); + expect(getExtensionHostPluginCommandSpecs("telegram")).toEqual([ + { + name: "talkvoice", + description: "Demo command", + acceptsArgs: false, + }, + ]); + }); +}); diff --git a/src/extension-host/contributions/command-runtime.ts b/src/extension-host/contributions/command-runtime.ts new file mode 100644 index 00000000000..afa666fc241 --- /dev/null +++ b/src/extension-host/contributions/command-runtime.ts @@ -0,0 +1,275 @@ +import type { OpenClawConfig } from "../../config/config.js"; +import { logVerbose } from "../../globals.js"; +import type { + OpenClawPluginCommandDefinition, + PluginCommandContext, + PluginCommandResult, +} from "../../plugins/types.js"; + +export type RegisteredExtensionHostPluginCommand = OpenClawPluginCommandDefinition & { + pluginId: string; +}; + +const extensionHostPluginCommands = new Map(); + +let extensionHostCommandRegistryLocked = false; + +const MAX_ARGS_LENGTH = 4096; + +const RESERVED_COMMANDS = new Set([ + "help", + "commands", + "status", + "whoami", + "context", + "btw", + "stop", + "restart", + "reset", + "new", + "compact", + "config", + "debug", + "allowlist", + "activation", + "skill", + "subagents", + "kill", + "steer", + "tell", + "model", + "models", + "queue", + "send", + "bash", + "exec", + "think", + "verbose", + "reasoning", + "elevated", + "usage", +]); + +export type CommandRegistrationResult = { + ok: boolean; + error?: string; +}; + +export function validateExtensionHostCommandName(name: string): string | null { + const trimmed = name.trim().toLowerCase(); + + if (!trimmed) { + return "Command name cannot be empty"; + } + + if (!/^[a-z][a-z0-9_-]*$/.test(trimmed)) { + return "Command name must start with a letter and contain only letters, numbers, hyphens, and underscores"; + } + + if (RESERVED_COMMANDS.has(trimmed)) { + return `Command name "${trimmed}" is reserved by a built-in command`; + } + + return null; +} + +export function registerExtensionHostPluginCommand( + pluginId: string, + command: OpenClawPluginCommandDefinition, +): CommandRegistrationResult { + if (extensionHostCommandRegistryLocked) { + return { ok: false, error: "Cannot register commands while processing is in progress" }; + } + + if (typeof command.handler !== "function") { + return { ok: false, error: "Command handler must be a function" }; + } + + if (typeof command.name !== "string") { + return { ok: false, error: "Command name must be a string" }; + } + + if (typeof command.description !== "string") { + return { ok: false, error: "Command description must be a string" }; + } + + const name = command.name.trim(); + const description = command.description.trim(); + if (!description) { + return { ok: false, error: "Command description cannot be empty" }; + } + + const validationError = validateExtensionHostCommandName(name); + if (validationError) { + return { ok: false, error: validationError }; + } + + const key = `/${name.toLowerCase()}`; + const existing = extensionHostPluginCommands.get(key); + if (existing) { + return { + ok: false, + error: `Command "${name}" already registered by plugin "${existing.pluginId}"`, + }; + } + + extensionHostPluginCommands.set(key, { ...command, name, description, pluginId }); + logVerbose(`Registered plugin command: ${key} (plugin: ${pluginId})`); + return { ok: true }; +} + +export function clearExtensionHostPluginCommands(): void { + extensionHostPluginCommands.clear(); +} + +export function clearExtensionHostPluginCommandsForPlugin(pluginId: string): void { + for (const [key, cmd] of extensionHostPluginCommands.entries()) { + if (cmd.pluginId === pluginId) { + extensionHostPluginCommands.delete(key); + } + } +} + +export function matchExtensionHostPluginCommand( + commandBody: string, +): { command: RegisteredExtensionHostPluginCommand; args?: string } | null { + const trimmed = commandBody.trim(); + if (!trimmed.startsWith("/")) { + return null; + } + + const spaceIndex = trimmed.indexOf(" "); + const commandName = spaceIndex === -1 ? trimmed : trimmed.slice(0, spaceIndex); + const args = spaceIndex === -1 ? undefined : trimmed.slice(spaceIndex + 1).trim(); + + const command = extensionHostPluginCommands.get(commandName.toLowerCase()); + if (!command) { + return null; + } + + if (args && !command.acceptsArgs) { + return null; + } + + return { command, args: args || undefined }; +} + +function sanitizeArgs(args: string | undefined): string | undefined { + if (!args) { + return undefined; + } + + if (args.length > MAX_ARGS_LENGTH) { + return args.slice(0, MAX_ARGS_LENGTH); + } + + let sanitized = ""; + for (const char of args) { + const code = char.charCodeAt(0); + const isControl = (code <= 0x1f && code !== 0x09 && code !== 0x0a) || code === 0x7f; + if (!isControl) { + sanitized += char; + } + } + return sanitized; +} + +export async function executeExtensionHostPluginCommand(params: { + command: RegisteredExtensionHostPluginCommand; + args?: string; + senderId?: string; + channel: string; + channelId?: PluginCommandContext["channelId"]; + isAuthorizedSender: boolean; + commandBody: string; + config: OpenClawConfig; + from?: PluginCommandContext["from"]; + to?: PluginCommandContext["to"]; + accountId?: PluginCommandContext["accountId"]; + messageThreadId?: PluginCommandContext["messageThreadId"]; +}): Promise { + const { command, args, senderId, channel, isAuthorizedSender, commandBody, config } = params; + + const requireAuth = command.requireAuth !== false; + if (requireAuth && !isAuthorizedSender) { + logVerbose( + `Plugin command /${command.name} blocked: unauthorized sender ${senderId || ""}`, + ); + return { text: "⚠️ This command requires authorization." }; + } + + const ctx: PluginCommandContext = { + senderId, + channel, + channelId: params.channelId, + isAuthorizedSender, + args: sanitizeArgs(args), + commandBody, + config, + from: params.from, + to: params.to, + accountId: params.accountId, + messageThreadId: params.messageThreadId, + requestConversationBinding: async () => ({ + status: "error" as const, + message: "Conversation binding is unavailable for this command surface.", + }), + detachConversationBinding: async () => ({ removed: false }), + getCurrentConversationBinding: async () => null, + }; + + extensionHostCommandRegistryLocked = true; + try { + const result = await command.handler(ctx); + logVerbose( + `Plugin command /${command.name} executed successfully for ${senderId || "unknown"}`, + ); + return result; + } catch (err) { + const error = err as Error; + logVerbose(`Plugin command /${command.name} error: ${error.message}`); + return { text: "⚠️ Command failed. Please try again later." }; + } finally { + extensionHostCommandRegistryLocked = false; + } +} + +function resolveExtensionHostPluginNativeName( + command: OpenClawPluginCommandDefinition, + provider?: string, +): string { + const providerName = provider?.trim().toLowerCase(); + const providerOverride = providerName ? command.nativeNames?.[providerName] : undefined; + if (typeof providerOverride === "string" && providerOverride.trim()) { + return providerOverride.trim(); + } + const defaultOverride = command.nativeNames?.default; + if (typeof defaultOverride === "string" && defaultOverride.trim()) { + return defaultOverride.trim(); + } + return command.name; +} + +export function listExtensionHostPluginCommands(): Array<{ + name: string; + description: string; + pluginId: string; +}> { + return Array.from(extensionHostPluginCommands.values()).map((cmd) => ({ + name: cmd.name, + description: cmd.description, + pluginId: cmd.pluginId, + })); +} + +export function getExtensionHostPluginCommandSpecs(provider?: string): Array<{ + name: string; + description: string; + acceptsArgs: boolean; +}> { + return Array.from(extensionHostPluginCommands.values()).map((cmd) => ({ + name: resolveExtensionHostPluginNativeName(cmd, provider), + description: cmd.description, + acceptsArgs: cmd.acceptsArgs ?? false, + })); +} diff --git a/src/extension-host/contributions/context-engine-runtime.test.ts b/src/extension-host/contributions/context-engine-runtime.test.ts new file mode 100644 index 00000000000..abbce28f77d --- /dev/null +++ b/src/extension-host/contributions/context-engine-runtime.test.ts @@ -0,0 +1,42 @@ +import { describe, expect, it } from "vitest"; +import type { ContextEngine } from "../../context-engine/types.js"; +import { + getExtensionHostContextEngineFactory, + listExtensionHostContextEngineIds, + registerExtensionHostContextEngine, +} from "./context-engine-runtime.js"; + +class TestContextEngine implements ContextEngine { + readonly info = { + id: "host-test", + name: "Host Test", + version: "1.0.0", + }; + + async ingest() { + return { ingested: false }; + } + + async assemble(params: { messages: [] }) { + return { messages: params.messages, estimatedTokens: 0 }; + } + + async afterTurn() {} + + async compact() { + return { ok: true, compacted: false, reason: "noop" }; + } +} + +describe("extension host context engine runtime", () => { + it("stores registered context-engine factories in the host-owned runtime", async () => { + const factory = () => new TestContextEngine(); + registerExtensionHostContextEngine("host-test", factory); + + expect(getExtensionHostContextEngineFactory("host-test")).toBe(factory); + expect(listExtensionHostContextEngineIds()).toContain("host-test"); + expect(await getExtensionHostContextEngineFactory("host-test")?.()).toBeInstanceOf( + TestContextEngine, + ); + }); +}); diff --git a/src/extension-host/contributions/context-engine-runtime.ts b/src/extension-host/contributions/context-engine-runtime.ts new file mode 100644 index 00000000000..736115c1c94 --- /dev/null +++ b/src/extension-host/contributions/context-engine-runtime.ts @@ -0,0 +1,60 @@ +import type { OpenClawConfig } from "../../config/config.js"; +import type { ContextEngine } from "../../context-engine/types.js"; +import { getExtensionHostDefaultSlotId } from "../policy/slot-arbitration.js"; + +export type ExtensionHostContextEngineFactory = () => ContextEngine | Promise; + +const CONTEXT_ENGINE_RUNTIME_STATE = Symbol.for("openclaw.contextEngineRegistryState"); + +type ExtensionHostContextEngineRuntimeState = { + engines: Map; +}; + +function getExtensionHostContextEngineRuntimeState(): ExtensionHostContextEngineRuntimeState { + const globalState = globalThis as typeof globalThis & { + [CONTEXT_ENGINE_RUNTIME_STATE]?: ExtensionHostContextEngineRuntimeState; + }; + if (!globalState[CONTEXT_ENGINE_RUNTIME_STATE]) { + globalState[CONTEXT_ENGINE_RUNTIME_STATE] = { + engines: new Map(), + }; + } + return globalState[CONTEXT_ENGINE_RUNTIME_STATE]; +} + +export function registerExtensionHostContextEngine( + id: string, + factory: ExtensionHostContextEngineFactory, +): void { + getExtensionHostContextEngineRuntimeState().engines.set(id, factory); +} + +export function getExtensionHostContextEngineFactory( + id: string, +): ExtensionHostContextEngineFactory | undefined { + return getExtensionHostContextEngineRuntimeState().engines.get(id); +} + +export function listExtensionHostContextEngineIds(): string[] { + return [...getExtensionHostContextEngineRuntimeState().engines.keys()]; +} + +export async function resolveExtensionHostContextEngine( + config?: OpenClawConfig, +): Promise { + const slotValue = config?.plugins?.slots?.contextEngine; + const engineId = + typeof slotValue === "string" && slotValue.trim() + ? slotValue.trim() + : getExtensionHostDefaultSlotId("contextEngine"); + + const factory = getExtensionHostContextEngineRuntimeState().engines.get(engineId); + if (!factory) { + throw new Error( + `Context engine "${engineId}" is not registered. ` + + `Available engines: ${listExtensionHostContextEngineIds().join(", ") || "(none)"}`, + ); + } + + return factory(); +} diff --git a/src/extension-host/contributions/embedding-manager-runtime.test.ts b/src/extension-host/contributions/embedding-manager-runtime.test.ts new file mode 100644 index 00000000000..95af8df8936 --- /dev/null +++ b/src/extension-host/contributions/embedding-manager-runtime.test.ts @@ -0,0 +1,65 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const createEmbeddingProvider = vi.hoisted(() => vi.fn()); +const resolveAgentDir = vi.hoisted(() => vi.fn(() => "/tmp/agent")); + +vi.mock("./embedding-runtime.js", () => ({ + createEmbeddingProvider, +})); + +vi.mock("../agents/agent-scope.js", () => ({ + resolveAgentDir, +})); + +describe("embedding-manager-runtime", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("uses the shared fallback policy for manager fallback activation", async () => { + createEmbeddingProvider.mockResolvedValue({ + provider: { + id: "ollama", + model: "nomic-embed-text", + embedQuery: vi.fn(), + embedBatch: vi.fn(), + }, + ollama: { kind: "ollama" }, + }); + + const { activateEmbeddingManagerFallbackProvider } = + await import("./embedding-manager-runtime.js"); + const result = await activateEmbeddingManagerFallbackProvider({ + cfg: {} as never, + agentId: "main", + settings: { + fallback: "ollama", + model: "text-embedding-3-small", + outputDimensionality: undefined, + remote: undefined, + local: undefined, + }, + state: { + provider: { + id: "openai", + model: "text-embedding-3-small", + embedQuery: vi.fn(), + embedBatch: vi.fn(), + }, + }, + reason: "forced fallback", + }); + + expect(createEmbeddingProvider).toHaveBeenCalledWith( + expect.objectContaining({ + provider: "ollama", + model: "nomic-embed-text", + fallback: "none", + }), + ); + expect(result).toMatchObject({ + fallbackFrom: "openai", + fallbackReason: "forced fallback", + }); + }); +}); diff --git a/src/extension-host/contributions/embedding-manager-runtime.ts b/src/extension-host/contributions/embedding-manager-runtime.ts new file mode 100644 index 00000000000..6ea1acaab3c --- /dev/null +++ b/src/extension-host/contributions/embedding-manager-runtime.ts @@ -0,0 +1,105 @@ +import { resolveAgentDir } from "../../agents/agent-scope.js"; +import type { ResolvedMemorySearchConfig } from "../../agents/memory-search.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import { resolveExtensionHostEmbeddingFallbackPolicy } from "../policy/embedding-runtime-policy.js"; +import { + createEmbeddingProvider, + type EmbeddingProvider, + type EmbeddingProviderId, + type GeminiEmbeddingClient, + type MistralEmbeddingClient, + type OllamaEmbeddingClient, + type OpenAiEmbeddingClient, + type VoyageEmbeddingClient, +} from "./embedding-runtime.js"; + +export type EmbeddingManagerBatchConfig = { + enabled: boolean; + wait: boolean; + concurrency: number; + pollIntervalMs: number; + timeoutMs: number; +}; + +export type EmbeddingManagerRuntimeState = { + provider: EmbeddingProvider | null; + fallbackFrom?: EmbeddingProviderId; + openAi?: OpenAiEmbeddingClient; + gemini?: GeminiEmbeddingClient; + voyage?: VoyageEmbeddingClient; + mistral?: MistralEmbeddingClient; + ollama?: OllamaEmbeddingClient; +}; + +export type EmbeddingManagerFallbackActivation = EmbeddingManagerRuntimeState & { + fallbackFrom: EmbeddingProviderId; + fallbackReason: string; +}; + +export function resolveEmbeddingManagerBatchConfig(params: { + settings: Pick; + state: EmbeddingManagerRuntimeState; +}): EmbeddingManagerBatchConfig { + const batch = params.settings.remote?.batch; + const { provider } = params.state; + const enabled = Boolean( + batch?.enabled && + provider && + ((params.state.openAi && provider.id === "openai") || + (params.state.gemini && provider.id === "gemini") || + (params.state.voyage && provider.id === "voyage")), + ); + return { + enabled, + wait: batch?.wait ?? true, + concurrency: Math.max(1, batch?.concurrency ?? 2), + pollIntervalMs: batch?.pollIntervalMs ?? 2000, + timeoutMs: (batch?.timeoutMinutes ?? 60) * 60 * 1000, + }; +} + +export async function activateEmbeddingManagerFallbackProvider(params: { + cfg: OpenClawConfig; + agentId: string; + settings: Pick< + ResolvedMemorySearchConfig, + "fallback" | "local" | "model" | "outputDimensionality" | "remote" + >; + state: EmbeddingManagerRuntimeState; + reason: string; +}): Promise { + const { provider, fallbackFrom } = params.state; + if (!provider || fallbackFrom) { + return null; + } + const fallbackPolicy = resolveExtensionHostEmbeddingFallbackPolicy({ + requestedProvider: provider.id as EmbeddingProviderId, + fallback: params.settings.fallback, + configuredModel: params.settings.model, + }); + if (!fallbackPolicy) { + return null; + } + + const result = await createEmbeddingProvider({ + config: params.cfg, + agentDir: resolveAgentDir(params.cfg, params.agentId), + provider: fallbackPolicy.provider, + remote: params.settings.remote, + model: fallbackPolicy.model, + outputDimensionality: params.settings.outputDimensionality, + fallback: "none", + local: params.settings.local, + }); + + return { + provider: result.provider, + fallbackFrom: provider.id as EmbeddingProviderId, + fallbackReason: params.reason, + openAi: result.openAi, + gemini: result.gemini, + voyage: result.voyage, + mistral: result.mistral, + ollama: result.ollama, + }; +} diff --git a/src/extension-host/contributions/embedding-reindex-execution.test.ts b/src/extension-host/contributions/embedding-reindex-execution.test.ts new file mode 100644 index 00000000000..52ba9da9352 --- /dev/null +++ b/src/extension-host/contributions/embedding-reindex-execution.test.ts @@ -0,0 +1,112 @@ +import { describe, expect, it, vi } from "vitest"; +import { + resetExtensionHostEmbeddingIndexStore, + runExtensionHostEmbeddingReindexBody, +} from "./embedding-reindex-execution.js"; + +describe("embedding-reindex-execution", () => { + it("runs full reindex syncs, clears dirty flags, and writes metadata", async () => { + const syncMemoryFiles = vi.fn(async () => {}); + const syncSessionFiles = vi.fn(async () => {}); + const setDirty = vi.fn(); + const setSessionsDirty = vi.fn(); + const clearAllSessionDirtyFiles = vi.fn(); + const writeMeta = vi.fn(); + const pruneEmbeddingCacheIfNeeded = vi.fn(); + + const nextMeta = await runExtensionHostEmbeddingReindexBody({ + shouldSyncMemory: true, + shouldSyncSessions: true, + hasDirtySessionFiles: true, + syncMemoryFiles, + syncSessionFiles, + setDirty, + setSessionsDirty, + clearAllSessionDirtyFiles, + buildNextMeta: () => ({ + model: "model", + provider: "openai", + providerKey: "key", + sources: ["memory", "sessions"], + scopeHash: "scope", + chunkTokens: 200, + chunkOverlap: 20, + }), + vectorDims: 1536, + writeMeta, + pruneEmbeddingCacheIfNeeded, + }); + + expect(syncMemoryFiles).toHaveBeenCalledWith({ + needsFullReindex: true, + progress: undefined, + }); + expect(syncSessionFiles).toHaveBeenCalledWith({ + needsFullReindex: true, + progress: undefined, + }); + expect(setDirty).toHaveBeenCalledWith(false); + expect(setSessionsDirty).toHaveBeenCalledWith(false); + expect(clearAllSessionDirtyFiles).toHaveBeenCalled(); + expect(writeMeta).toHaveBeenCalledWith({ + model: "model", + provider: "openai", + providerKey: "key", + sources: ["memory", "sessions"], + scopeHash: "scope", + chunkTokens: 200, + chunkOverlap: 20, + vectorDims: 1536, + }); + expect(pruneEmbeddingCacheIfNeeded).toHaveBeenCalled(); + expect(nextMeta.vectorDims).toBe(1536); + }); + + it("preserves session dirty state when sessions are not reindexed", async () => { + const setSessionsDirty = vi.fn(); + + await runExtensionHostEmbeddingReindexBody({ + shouldSyncMemory: false, + shouldSyncSessions: false, + hasDirtySessionFiles: true, + syncMemoryFiles: vi.fn(async () => {}), + syncSessionFiles: vi.fn(async () => {}), + setDirty: vi.fn(), + setSessionsDirty, + clearAllSessionDirtyFiles: vi.fn(), + buildNextMeta: () => ({ + model: "model", + provider: "openai", + chunkTokens: 200, + chunkOverlap: 20, + }), + writeMeta: vi.fn(), + }); + + expect(setSessionsDirty).toHaveBeenCalledWith(true); + }); + + it("resets the index store and FTS rows when available", () => { + const execSql = vi.fn(); + const dropVectorTable = vi.fn(); + const clearVectorDims = vi.fn(); + const clearAllSessionDirtyFiles = vi.fn(); + + resetExtensionHostEmbeddingIndexStore({ + execSql, + ftsEnabled: true, + ftsAvailable: true, + ftsTable: "chunks_fts", + dropVectorTable, + clearVectorDims, + clearAllSessionDirtyFiles, + }); + + expect(execSql).toHaveBeenNthCalledWith(1, "DELETE FROM files"); + expect(execSql).toHaveBeenNthCalledWith(2, "DELETE FROM chunks"); + expect(execSql).toHaveBeenNthCalledWith(3, "DELETE FROM chunks_fts"); + expect(dropVectorTable).toHaveBeenCalled(); + expect(clearVectorDims).toHaveBeenCalled(); + expect(clearAllSessionDirtyFiles).toHaveBeenCalled(); + }); +}); diff --git a/src/extension-host/contributions/embedding-reindex-execution.ts b/src/extension-host/contributions/embedding-reindex-execution.ts new file mode 100644 index 00000000000..548d3397937 --- /dev/null +++ b/src/extension-host/contributions/embedding-reindex-execution.ts @@ -0,0 +1,80 @@ +import type { EmbeddingIndexMeta } from "./embedding-sync-planning.js"; + +type EmbeddingReindexProgress = unknown; + +type EmbeddingReindexMemoryFiles = (params: { + needsFullReindex: boolean; + progress?: TProgress; +}) => Promise; + +type EmbeddingReindexSessionFiles = (params: { + needsFullReindex: boolean; + progress?: TProgress; +}) => Promise; + +export async function runExtensionHostEmbeddingReindexBody< + TProgress = EmbeddingReindexProgress, +>(params: { + shouldSyncMemory: boolean; + shouldSyncSessions: boolean; + hasDirtySessionFiles: boolean; + progress?: TProgress; + syncMemoryFiles: EmbeddingReindexMemoryFiles; + syncSessionFiles: EmbeddingReindexSessionFiles; + setDirty: (value: boolean) => void; + setSessionsDirty: (value: boolean) => void; + clearAllSessionDirtyFiles: () => void; + buildNextMeta: () => EmbeddingIndexMeta; + vectorDims?: number; + writeMeta: (meta: EmbeddingIndexMeta) => void; + pruneEmbeddingCacheIfNeeded?: () => void; +}): Promise { + if (params.shouldSyncMemory) { + await params.syncMemoryFiles({ + needsFullReindex: true, + progress: params.progress, + }); + params.setDirty(false); + } + + if (params.shouldSyncSessions) { + await params.syncSessionFiles({ + needsFullReindex: true, + progress: params.progress, + }); + params.setSessionsDirty(false); + params.clearAllSessionDirtyFiles(); + } else { + params.setSessionsDirty(params.hasDirtySessionFiles); + } + + const nextMeta = params.buildNextMeta(); + if (params.vectorDims) { + nextMeta.vectorDims = params.vectorDims; + } + + params.writeMeta(nextMeta); + params.pruneEmbeddingCacheIfNeeded?.(); + return nextMeta; +} + +export function resetExtensionHostEmbeddingIndexStore(params: { + execSql: (sql: string) => void; + ftsEnabled: boolean; + ftsAvailable: boolean; + ftsTable: string; + dropVectorTable: () => void; + clearVectorDims: () => void; + clearAllSessionDirtyFiles: () => void; +}): void { + params.execSql("DELETE FROM files"); + params.execSql("DELETE FROM chunks"); + if (params.ftsEnabled && params.ftsAvailable) { + try { + params.execSql(`DELETE FROM ${params.ftsTable}`); + } catch {} + } + params.dropVectorTable(); + params.clearVectorDims(); + params.clearAllSessionDirtyFiles(); +} diff --git a/src/extension-host/contributions/embedding-runtime-registry.test.ts b/src/extension-host/contributions/embedding-runtime-registry.test.ts new file mode 100644 index 00000000000..89f1f4ad818 --- /dev/null +++ b/src/extension-host/contributions/embedding-runtime-registry.test.ts @@ -0,0 +1,115 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const listExtensionHostEmbeddingRemoteRuntimeBackendIds = vi.hoisted(() => + vi.fn(() => ["gemini", "openai"] as const), +); +const createGeminiEmbeddingProvider = vi.hoisted(() => vi.fn()); +const createOpenAiEmbeddingProvider = vi.hoisted(() => vi.fn()); + +vi.mock("../policy/embedding-runtime-policy.js", async () => ({ + ...(await vi.importActual( + "../policy/embedding-runtime-policy.js", + )), + listExtensionHostEmbeddingRemoteRuntimeBackendIds, +})); + +vi.mock("../../memory/embeddings-gemini.js", () => ({ + createGeminiEmbeddingProvider, +})); + +vi.mock("../../memory/embeddings-openai.js", () => ({ + createOpenAiEmbeddingProvider, +})); + +vi.mock("../../memory/embeddings-mistral.js", () => ({ + createMistralEmbeddingProvider: vi.fn(), +})); + +vi.mock("../../memory/embeddings-ollama.js", () => ({ + createOllamaEmbeddingProvider: vi.fn(), +})); + +vi.mock("../../memory/embeddings-voyage.js", () => ({ + createVoyageEmbeddingProvider: vi.fn(), +})); + +vi.mock("../../memory/node-llama.js", () => ({ + importNodeLlamaCpp: vi.fn(), +})); + +describe("extension host embedding runtime registry", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("uses the runtime-backend catalog for auto provider order", async () => { + createGeminiEmbeddingProvider.mockResolvedValue({ + provider: { + id: "gemini", + model: "gemini-embedding-001", + embedQuery: vi.fn(), + embedBatch: vi.fn(), + }, + client: { kind: "gemini" }, + }); + + const { createExtensionHostEmbeddingProvider } = + await import("./embedding-runtime-registry.js"); + const result = await createExtensionHostEmbeddingProvider({ + config: {} as never, + provider: "auto", + model: "gemini-embedding-001", + fallback: "none", + }); + + expect(listExtensionHostEmbeddingRemoteRuntimeBackendIds).toHaveBeenCalledTimes(1); + expect(createGeminiEmbeddingProvider).toHaveBeenCalledTimes(1); + expect(createOpenAiEmbeddingProvider).not.toHaveBeenCalled(); + expect(result.provider?.id).toBe("gemini"); + }); + + it("uses the same catalog order in local setup guidance", async () => { + const { formatExtensionHostLocalEmbeddingSetupError } = + await import("./embedding-runtime-registry.js"); + + const message = formatExtensionHostLocalEmbeddingSetupError( + new Error("Cannot find package 'node-llama-cpp'"), + ); + + expect(listExtensionHostEmbeddingRemoteRuntimeBackendIds).toHaveBeenCalledTimes(1); + expect(message).toContain('agents.defaults.memorySearch.provider = "gemini"'); + expect(message).toContain('agents.defaults.memorySearch.provider = "openai"'); + }); + + it("uses the shared fallback policy for explicit provider fallback requests", async () => { + createOpenAiEmbeddingProvider.mockRejectedValueOnce(new Error("openai failed")); + createGeminiEmbeddingProvider.mockResolvedValueOnce({ + provider: { + id: "gemini", + model: "gemini-embedding-001", + embedQuery: vi.fn(), + embedBatch: vi.fn(), + }, + client: { kind: "gemini" }, + }); + + const { createExtensionHostEmbeddingProvider } = + await import("./embedding-runtime-registry.js"); + const result = await createExtensionHostEmbeddingProvider({ + config: {} as never, + provider: "openai", + model: "text-embedding-3-small", + fallback: "gemini", + }); + + expect(createGeminiEmbeddingProvider).toHaveBeenCalledWith( + expect.objectContaining({ + provider: "gemini", + model: "gemini-embedding-001", + fallback: "none", + }), + ); + expect(result.fallbackFrom).toBe("openai"); + expect(result.provider?.id).toBe("gemini"); + }); +}); diff --git a/src/extension-host/contributions/embedding-runtime-registry.ts b/src/extension-host/contributions/embedding-runtime-registry.ts new file mode 100644 index 00000000000..68581d31505 --- /dev/null +++ b/src/extension-host/contributions/embedding-runtime-registry.ts @@ -0,0 +1,315 @@ +import fsSync from "node:fs"; +import type { Llama, LlamaEmbeddingContext, LlamaModel } from "node-llama-cpp"; +import { formatErrorMessage } from "../../infra/errors.js"; +import { sanitizeAndNormalizeEmbedding } from "../../memory/embedding-vectors.js"; +import { + createGeminiEmbeddingProvider, + type GeminiEmbeddingClient, + type GeminiTaskType, +} from "../../memory/embeddings-gemini.js"; +import { + createMistralEmbeddingProvider, + type MistralEmbeddingClient, +} from "../../memory/embeddings-mistral.js"; +import { + createOllamaEmbeddingProvider, + type OllamaEmbeddingClient, +} from "../../memory/embeddings-ollama.js"; +import { + createOpenAiEmbeddingProvider, + type OpenAiEmbeddingClient, +} from "../../memory/embeddings-openai.js"; +import { + createVoyageEmbeddingProvider, + type VoyageEmbeddingClient, +} from "../../memory/embeddings-voyage.js"; +import { importNodeLlamaCpp } from "../../memory/node-llama.js"; +import { resolveUserPath } from "../../utils.js"; +import { + listExtensionHostEmbeddingRemoteRuntimeBackendIds, + resolveExtensionHostEmbeddingFallbackPolicy, +} from "../policy/embedding-runtime-policy.js"; +import { DEFAULT_EXTENSION_HOST_LOCAL_EMBEDDING_MODEL } from "../static/embedding-runtime-backends.js"; +import type { + EmbeddingProvider, + EmbeddingProviderId, + EmbeddingProviderOptions, + EmbeddingProviderResult, +} from "./embedding-runtime-types.js"; + +export type { + GeminiEmbeddingClient, + GeminiTaskType, + MistralEmbeddingClient, + OllamaEmbeddingClient, + OpenAiEmbeddingClient, + VoyageEmbeddingClient, +}; + +export function canAutoSelectExtensionHostLocalEmbedding( + options: EmbeddingProviderOptions, +): boolean { + const modelPath = options.local?.modelPath?.trim(); + if (!modelPath) { + return false; + } + if (/^(hf:|https?:)/i.test(modelPath)) { + return false; + } + const resolved = resolveUserPath(modelPath); + try { + return fsSync.statSync(resolved).isFile(); + } catch { + return false; + } +} + +export function isMissingExtensionHostEmbeddingApiKeyError(err: unknown): boolean { + const message = formatErrorMessage(err); + return message.includes("No API key found for provider"); +} + +async function createExtensionHostLocalEmbeddingProvider( + options: EmbeddingProviderOptions, +): Promise { + const modelPath = + options.local?.modelPath?.trim() || DEFAULT_EXTENSION_HOST_LOCAL_EMBEDDING_MODEL; + const modelCacheDir = options.local?.modelCacheDir?.trim(); + + // Lazy-load node-llama-cpp to keep startup light unless local is enabled. + const { getLlama, resolveModelFile, LlamaLogLevel } = await importNodeLlamaCpp(); + + let llama: Llama | null = null; + let embeddingModel: LlamaModel | null = null; + let embeddingContext: LlamaEmbeddingContext | null = null; + let initPromise: Promise | null = null; + + const ensureContext = async (): Promise => { + if (embeddingContext) { + return embeddingContext; + } + if (initPromise) { + return initPromise; + } + initPromise = (async () => { + try { + if (!llama) { + llama = await getLlama({ logLevel: LlamaLogLevel.error }); + } + if (!embeddingModel) { + const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined); + embeddingModel = await llama.loadModel({ modelPath: resolved }); + } + if (!embeddingContext) { + embeddingContext = await embeddingModel.createEmbeddingContext(); + } + return embeddingContext; + } catch (err) { + initPromise = null; + throw err; + } + })(); + return initPromise; + }; + + return { + id: "local", + model: modelPath, + embedQuery: async (text) => { + const ctx = await ensureContext(); + const embedding = await ctx.getEmbeddingFor(text); + return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector)); + }, + embedBatch: async (texts) => { + const ctx = await ensureContext(); + return Promise.all( + texts.map(async (text) => { + const embedding = await ctx.getEmbeddingFor(text); + return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector)); + }), + ); + }, + }; +} + +async function createExtensionHostEmbeddingProviderById( + id: EmbeddingProviderId, + options: EmbeddingProviderOptions, +): Promise< + Omit< + EmbeddingProviderResult, + "requestedProvider" | "fallbackFrom" | "fallbackReason" | "providerUnavailableReason" + > +> { + if (id === "local") { + const provider = await createExtensionHostLocalEmbeddingProvider(options); + return { provider }; + } + if (id === "ollama") { + const { provider, client } = await createOllamaEmbeddingProvider(options); + return { provider, ollama: client }; + } + if (id === "gemini") { + const { provider, client } = await createGeminiEmbeddingProvider(options); + return { provider, gemini: client }; + } + if (id === "voyage") { + const { provider, client } = await createVoyageEmbeddingProvider(options); + return { provider, voyage: client }; + } + if (id === "mistral") { + const { provider, client } = await createMistralEmbeddingProvider(options); + return { provider, mistral: client }; + } + const { provider, client } = await createOpenAiEmbeddingProvider(options); + return { provider, openAi: client }; +} + +function formatExtensionHostPrimaryEmbeddingError( + err: unknown, + provider: EmbeddingProviderId, +): string { + return provider === "local" + ? formatExtensionHostLocalEmbeddingSetupError(err) + : formatErrorMessage(err); +} + +export async function createExtensionHostEmbeddingProvider( + options: EmbeddingProviderOptions, +): Promise { + const requestedProvider = options.provider; + const fallback = options.fallback; + + if (requestedProvider === "auto") { + const missingKeyErrors: string[] = []; + let localError: string | null = null; + + if (canAutoSelectExtensionHostLocalEmbedding(options)) { + try { + const local = await createExtensionHostEmbeddingProviderById("local", options); + return { ...local, requestedProvider }; + } catch (err) { + localError = formatExtensionHostLocalEmbeddingSetupError(err); + } + } + + for (const provider of listExtensionHostEmbeddingRemoteRuntimeBackendIds()) { + try { + const result = await createExtensionHostEmbeddingProviderById(provider, options); + return { ...result, requestedProvider }; + } catch (err) { + const message = formatExtensionHostPrimaryEmbeddingError(err, provider); + if (isMissingExtensionHostEmbeddingApiKeyError(err)) { + missingKeyErrors.push(message); + continue; + } + const wrapped = new Error(message) as Error & { cause?: unknown }; + wrapped.cause = err; + throw wrapped; + } + } + + const details = [...missingKeyErrors, localError].filter(Boolean) as string[]; + const reason = details.length > 0 ? details.join("\n\n") : "No embeddings provider available."; + return { + provider: null, + requestedProvider, + providerUnavailableReason: reason, + }; + } + + try { + const primary = await createExtensionHostEmbeddingProviderById(requestedProvider, options); + return { ...primary, requestedProvider }; + } catch (primaryErr) { + const reason = formatExtensionHostPrimaryEmbeddingError(primaryErr, requestedProvider); + const fallbackPolicy = resolveExtensionHostEmbeddingFallbackPolicy({ + requestedProvider, + fallback, + configuredModel: options.model, + }); + if (fallbackPolicy) { + try { + const fallbackResult = await createExtensionHostEmbeddingProviderById( + fallbackPolicy.provider, + { + ...options, + provider: fallbackPolicy.provider, + model: fallbackPolicy.model, + fallback: "none", + }, + ); + return { + ...fallbackResult, + requestedProvider, + fallbackFrom: requestedProvider, + fallbackReason: reason, + }; + } catch (fallbackErr) { + const fallbackReason = formatErrorMessage(fallbackErr); + const combinedReason = `${reason}\n\nFallback to ${fallbackPolicy.provider} failed: ${fallbackReason}`; + if ( + isMissingExtensionHostEmbeddingApiKeyError(primaryErr) && + isMissingExtensionHostEmbeddingApiKeyError(fallbackErr) + ) { + return { + provider: null, + requestedProvider, + fallbackFrom: requestedProvider, + fallbackReason: reason, + providerUnavailableReason: combinedReason, + }; + } + const wrapped = new Error(combinedReason) as Error & { cause?: unknown }; + wrapped.cause = fallbackErr; + throw wrapped; + } + } + if (isMissingExtensionHostEmbeddingApiKeyError(primaryErr)) { + return { + provider: null, + requestedProvider, + providerUnavailableReason: reason, + }; + } + const wrapped = new Error(reason) as Error & { cause?: unknown }; + wrapped.cause = primaryErr; + throw wrapped; + } +} + +function isNodeLlamaCppMissing(err: unknown): boolean { + if (!(err instanceof Error)) { + return false; + } + const code = (err as Error & { code?: unknown }).code; + if (code === "ERR_MODULE_NOT_FOUND") { + return err.message.includes("node-llama-cpp"); + } + return false; +} + +export function formatExtensionHostLocalEmbeddingSetupError(err: unknown): string { + const detail = formatErrorMessage(err); + const missing = isNodeLlamaCppMissing(err); + return [ + "Local embeddings unavailable.", + missing + ? "Reason: optional dependency node-llama-cpp is missing (or failed to install)." + : detail + ? `Reason: ${detail}` + : undefined, + missing && detail ? `Detail: ${detail}` : null, + "To enable local embeddings:", + "1) Use Node 24 (recommended for installs/updates; Node 22 LTS, currently 22.16+, remains supported)", + missing + ? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest" + : null, + "3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp", + ...listExtensionHostEmbeddingRemoteRuntimeBackendIds().map( + (provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`, + ), + ] + .filter(Boolean) + .join("\n"); +} diff --git a/src/extension-host/contributions/embedding-runtime-types.ts b/src/extension-host/contributions/embedding-runtime-types.ts new file mode 100644 index 00000000000..12081afa639 --- /dev/null +++ b/src/extension-host/contributions/embedding-runtime-types.ts @@ -0,0 +1,61 @@ +import type { OpenClawConfig } from "../../config/config.js"; +import type { SecretInput } from "../../config/types.secrets.js"; +import type { EmbeddingInput } from "../../memory/embedding-inputs.js"; +import type { GeminiEmbeddingClient, GeminiTaskType } from "../../memory/embeddings-gemini.js"; +import type { MistralEmbeddingClient } from "../../memory/embeddings-mistral.js"; +import type { OllamaEmbeddingClient } from "../../memory/embeddings-ollama.js"; +import type { OpenAiEmbeddingClient } from "../../memory/embeddings-openai.js"; +import type { VoyageEmbeddingClient } from "../../memory/embeddings-voyage.js"; + +export type { GeminiEmbeddingClient } from "../../memory/embeddings-gemini.js"; +export type { MistralEmbeddingClient } from "../../memory/embeddings-mistral.js"; +export type { OpenAiEmbeddingClient } from "../../memory/embeddings-openai.js"; +export type { VoyageEmbeddingClient } from "../../memory/embeddings-voyage.js"; +export type { OllamaEmbeddingClient } from "../../memory/embeddings-ollama.js"; + +export type EmbeddingProvider = { + id: string; + model: string; + maxInputTokens?: number; + embedQuery: (text: string) => Promise; + embedBatch: (texts: string[]) => Promise; + embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise; +}; + +export type EmbeddingProviderId = "openai" | "local" | "gemini" | "voyage" | "mistral" | "ollama"; +export type EmbeddingProviderRequest = EmbeddingProviderId | "auto"; +export type EmbeddingProviderFallback = EmbeddingProviderId | "none"; + +export type EmbeddingProviderResult = { + provider: EmbeddingProvider | null; + requestedProvider: EmbeddingProviderRequest; + fallbackFrom?: EmbeddingProviderId; + fallbackReason?: string; + providerUnavailableReason?: string; + openAi?: OpenAiEmbeddingClient; + gemini?: GeminiEmbeddingClient; + voyage?: VoyageEmbeddingClient; + mistral?: MistralEmbeddingClient; + ollama?: OllamaEmbeddingClient; +}; + +export type EmbeddingProviderOptions = { + config: OpenClawConfig; + agentDir?: string; + provider: EmbeddingProviderRequest; + remote?: { + baseUrl?: string; + apiKey?: SecretInput; + headers?: Record; + }; + model: string; + fallback: EmbeddingProviderFallback; + local?: { + modelPath?: string; + modelCacheDir?: string; + }; + /** Gemini embedding-2: output vector dimensions (768, 1536, or 3072). */ + outputDimensionality?: number; + /** Gemini: override the default task type sent with embedding requests. */ + taskType?: GeminiTaskType; +}; diff --git a/src/extension-host/contributions/embedding-runtime.ts b/src/extension-host/contributions/embedding-runtime.ts new file mode 100644 index 00000000000..54be56e2948 --- /dev/null +++ b/src/extension-host/contributions/embedding-runtime.ts @@ -0,0 +1,28 @@ +import { DEFAULT_EXTENSION_HOST_LOCAL_EMBEDDING_MODEL } from "../static/embedding-runtime-backends.js"; +import { createExtensionHostEmbeddingProvider } from "./embedding-runtime-registry.js"; +import type { + EmbeddingProviderOptions, + EmbeddingProviderResult, +} from "./embedding-runtime-types.js"; + +export type { + EmbeddingProvider, + EmbeddingProviderFallback, + EmbeddingProviderId, + EmbeddingProviderOptions, + EmbeddingProviderRequest, + EmbeddingProviderResult, + GeminiEmbeddingClient, + MistralEmbeddingClient, + OllamaEmbeddingClient, + OpenAiEmbeddingClient, + VoyageEmbeddingClient, +} from "./embedding-runtime-types.js"; + +export const DEFAULT_LOCAL_EMBEDDING_MODEL = DEFAULT_EXTENSION_HOST_LOCAL_EMBEDDING_MODEL; + +export async function createEmbeddingProvider( + options: EmbeddingProviderOptions, +): Promise { + return createExtensionHostEmbeddingProvider(options); +} diff --git a/src/extension-host/contributions/embedding-safe-reindex.test.ts b/src/extension-host/contributions/embedding-safe-reindex.test.ts new file mode 100644 index 00000000000..0656d5e8857 --- /dev/null +++ b/src/extension-host/contributions/embedding-safe-reindex.test.ts @@ -0,0 +1,153 @@ +import fs from "node:fs/promises"; +import os from "node:os"; +import path from "node:path"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { + moveExtensionHostIndexFiles, + removeExtensionHostIndexFiles, + runExtensionHostEmbeddingSafeReindex, + swapExtensionHostIndexFiles, +} from "./embedding-safe-reindex.js"; + +async function writeIndexFiles(basePath: string, value: string): Promise { + await fs.writeFile(basePath, `${value}-db`); + await fs.writeFile(`${basePath}-wal`, `${value}-wal`); + await fs.writeFile(`${basePath}-shm`, `${value}-shm`); +} + +async function readIndexFiles(basePath: string): Promise { + return await Promise.all([ + fs.readFile(basePath, "utf8"), + fs.readFile(`${basePath}-wal`, "utf8"), + fs.readFile(`${basePath}-shm`, "utf8"), + ]); +} + +describe("embedding-safe-reindex", () => { + const tempRoots: string[] = []; + + afterEach(async () => { + await Promise.all( + tempRoots.map(async (root) => await fs.rm(root, { recursive: true, force: true })), + ); + tempRoots.length = 0; + }); + + it("moves, swaps, and removes index sidecar files together", async () => { + const root = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-embed-safe-reindex-")); + tempRoots.push(root); + const sourcePath = path.join(root, "source.sqlite"); + const targetPath = path.join(root, "target.sqlite"); + + await writeIndexFiles(sourcePath, "source"); + await moveExtensionHostIndexFiles(sourcePath, targetPath); + await expect(readIndexFiles(targetPath)).resolves.toEqual([ + "source-db", + "source-wal", + "source-shm", + ]); + + await writeIndexFiles(sourcePath, "new-source"); + await swapExtensionHostIndexFiles(targetPath, sourcePath, "backup-id"); + await expect(readIndexFiles(targetPath)).resolves.toEqual([ + "new-source-db", + "new-source-wal", + "new-source-shm", + ]); + + await removeExtensionHostIndexFiles(targetPath); + await expect(fs.stat(targetPath)).rejects.toMatchObject({ code: "ENOENT" }); + }); + + it("runs the safe reindex flow, swaps files, and reopens the active database", async () => { + const root = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-embed-safe-reindex-")); + tempRoots.push(root); + const dbPath = path.join(root, "index.sqlite"); + await writeIndexFiles(dbPath, "active"); + + const closeDatabase = vi.fn(); + const captureOriginalState = vi.fn(() => ({ state: "original" })); + const restoreOriginalState = vi.fn(); + const prepareTempDb = vi.fn(); + const seedEmbeddingCache = vi.fn(); + const reopenAfterSwap = vi.fn(); + + const currentDb = { label: "current" }; + const openDatabaseAtPath = vi.fn((openedPath: string) => { + if (openedPath !== dbPath) { + void writeIndexFiles(openedPath, "temp"); + } + return { label: openedPath }; + }); + + const nextMeta = await runExtensionHostEmbeddingSafeReindex({ + dbPath, + currentDb, + openDatabaseAtPath, + closeDatabase, + captureOriginalState, + restoreOriginalState, + prepareTempDb, + seedEmbeddingCache, + runReindexBody: async () => ({ vectorDims: 1536 }), + reopenAfterSwap, + randomId: () => "temp-id", + }); + + expect(nextMeta).toEqual({ vectorDims: 1536 }); + expect(prepareTempDb).toHaveBeenCalledWith({ label: `${dbPath}.tmp-temp-id` }); + expect(seedEmbeddingCache).toHaveBeenCalledWith(currentDb); + expect(closeDatabase).toHaveBeenCalledTimes(2); + expect(reopenAfterSwap).toHaveBeenCalledWith(dbPath, { vectorDims: 1536 }); + expect(restoreOriginalState).not.toHaveBeenCalled(); + await expect(readIndexFiles(dbPath)).resolves.toEqual(["temp-db", "temp-wal", "temp-shm"]); + }); + + it("restores original state and removes temp files when reindex body fails", async () => { + const root = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-embed-safe-reindex-")); + tempRoots.push(root); + const dbPath = path.join(root, "index.sqlite"); + await writeIndexFiles(dbPath, "active"); + + const currentDb = { label: "current" }; + const restoreOriginalState = vi.fn(); + const closeDatabase = vi.fn(); + const openDatabaseAtPath = vi.fn((openedPath: string) => { + if (openedPath !== dbPath) { + void writeIndexFiles(openedPath, "temp"); + } + return { label: openedPath }; + }); + + await expect( + runExtensionHostEmbeddingSafeReindex({ + dbPath, + currentDb, + openDatabaseAtPath, + closeDatabase, + captureOriginalState: () => ({ state: "original" }), + restoreOriginalState, + prepareTempDb: vi.fn(), + seedEmbeddingCache: vi.fn(), + runReindexBody: async () => { + throw new Error("boom"); + }, + reopenAfterSwap: vi.fn(), + randomId: () => "temp-id", + }), + ).rejects.toThrow("boom"); + + expect(restoreOriginalState).toHaveBeenCalledWith({ + originalDb: currentDb, + originalState: { state: "original" }, + originalDbClosed: false, + dbPath, + }); + await expect(readIndexFiles(dbPath)).resolves.toEqual([ + "active-db", + "active-wal", + "active-shm", + ]); + await expect(fs.stat(`${dbPath}.tmp-temp-id`)).rejects.toMatchObject({ code: "ENOENT" }); + }); +}); diff --git a/src/extension-host/contributions/embedding-safe-reindex.ts b/src/extension-host/contributions/embedding-safe-reindex.ts new file mode 100644 index 00000000000..24c63098e03 --- /dev/null +++ b/src/extension-host/contributions/embedding-safe-reindex.ts @@ -0,0 +1,99 @@ +import { randomUUID } from "node:crypto"; +import fs from "node:fs/promises"; + +const INDEX_FILE_SUFFIXES = ["", "-wal", "-shm"]; + +export async function moveExtensionHostIndexFiles( + sourceBase: string, + targetBase: string, +): Promise { + for (const suffix of INDEX_FILE_SUFFIXES) { + const source = `${sourceBase}${suffix}`; + const target = `${targetBase}${suffix}`; + try { + await fs.rename(source, target); + } catch (err) { + if ((err as NodeJS.ErrnoException).code !== "ENOENT") { + throw err; + } + } + } +} + +export async function removeExtensionHostIndexFiles(basePath: string): Promise { + await Promise.all( + INDEX_FILE_SUFFIXES.map((suffix) => fs.rm(`${basePath}${suffix}`, { force: true })), + ); +} + +export async function swapExtensionHostIndexFiles( + targetPath: string, + tempPath: string, + backupId = randomUUID(), +): Promise { + const backupPath = `${targetPath}.backup-${backupId}`; + await moveExtensionHostIndexFiles(targetPath, backupPath); + try { + await moveExtensionHostIndexFiles(tempPath, targetPath); + } catch (err) { + await moveExtensionHostIndexFiles(backupPath, targetPath); + throw err; + } + await removeExtensionHostIndexFiles(backupPath); +} + +export async function runExtensionHostEmbeddingSafeReindex< + TDb, + TState, + TMeta extends { vectorDims?: number }, +>(params: { + dbPath: string; + currentDb: TDb; + openDatabaseAtPath: (dbPath: string) => TDb; + closeDatabase: (db: TDb) => void; + captureOriginalState: () => TState; + restoreOriginalState: (params: { + originalDb: TDb; + originalState: TState; + originalDbClosed: boolean; + dbPath: string; + }) => void; + prepareTempDb: (tempDb: TDb) => void; + seedEmbeddingCache: (sourceDb: TDb) => void; + runReindexBody: () => Promise; + reopenAfterSwap: (dbPath: string, nextMeta: TMeta) => void; + randomId?: () => string; +}): Promise { + const tempDbPath = `${params.dbPath}.tmp-${(params.randomId ?? randomUUID)()}`; + const tempDb = params.openDatabaseAtPath(tempDbPath); + const originalDb = params.currentDb; + const originalState = params.captureOriginalState(); + let originalDbClosed = false; + + params.prepareTempDb(tempDb); + + try { + params.seedEmbeddingCache(originalDb); + const nextMeta = await params.runReindexBody(); + + params.closeDatabase(tempDb); + params.closeDatabase(originalDb); + originalDbClosed = true; + + await swapExtensionHostIndexFiles(params.dbPath, tempDbPath); + params.reopenAfterSwap(params.dbPath, nextMeta); + return nextMeta; + } catch (err) { + try { + params.closeDatabase(tempDb); + } catch {} + await removeExtensionHostIndexFiles(tempDbPath); + params.restoreOriginalState({ + originalDb, + originalState, + originalDbClosed, + dbPath: params.dbPath, + }); + throw err; + } +} diff --git a/src/extension-host/contributions/embedding-sync-execution.test.ts b/src/extension-host/contributions/embedding-sync-execution.test.ts new file mode 100644 index 00000000000..700e6c06b8b --- /dev/null +++ b/src/extension-host/contributions/embedding-sync-execution.test.ts @@ -0,0 +1,234 @@ +import { describe, expect, it, vi } from "vitest"; +import { runExtensionHostEmbeddingSync } from "./embedding-sync-execution.js"; + +describe("embedding-sync-execution", () => { + it("prefers targeted session refreshes and clears only the targeted dirty files", async () => { + const syncSessionFiles = vi.fn(async () => {}); + const clearSyncedSessionFiles = vi.fn(); + + await runExtensionHostEmbeddingSync({ + reason: "post-compaction", + targetSessionFiles: new Set(["/tmp/a.jsonl"]), + vectorReady: false, + meta: null, + configuredSources: ["sessions"], + configuredScopeHash: "scope", + provider: null, + providerKey: null, + chunkTokens: 200, + chunkOverlap: 20, + sessionsEnabled: true, + dirty: true, + shouldSyncSessions: true, + useUnsafeReindex: false, + hasDirtySessionFiles: true, + syncMemoryFiles: vi.fn(async () => {}), + syncSessionFiles, + clearSyncedSessionFiles, + clearAllSessionDirtyFiles: vi.fn(), + setDirty: vi.fn(), + setSessionsDirty: vi.fn(), + shouldFallbackOnError: vi.fn(() => false), + activateFallbackProvider: vi.fn(async () => false), + runSafeReindex: vi.fn(async () => {}), + runUnsafeReindex: vi.fn(async () => {}), + }); + + expect(syncSessionFiles).toHaveBeenCalledWith({ + needsFullReindex: false, + targetSessionFiles: ["/tmp/a.jsonl"], + progress: undefined, + }); + expect(clearSyncedSessionFiles).toHaveBeenCalledWith(new Set(["/tmp/a.jsonl"])); + }); + + it("runs an unsafe reindex when fallback activates during a targeted refresh", async () => { + const runUnsafeReindex = vi.fn(async () => {}); + + await runExtensionHostEmbeddingSync({ + reason: "post-compaction", + targetSessionFiles: new Set(["/tmp/a.jsonl"]), + vectorReady: false, + meta: null, + configuredSources: ["sessions"], + configuredScopeHash: "scope", + provider: null, + providerKey: null, + chunkTokens: 200, + chunkOverlap: 20, + sessionsEnabled: true, + dirty: false, + shouldSyncSessions: true, + useUnsafeReindex: true, + hasDirtySessionFiles: false, + syncMemoryFiles: vi.fn(async () => {}), + syncSessionFiles: vi.fn(async () => { + throw new Error("embedding backend failed"); + }), + clearSyncedSessionFiles: vi.fn(), + clearAllSessionDirtyFiles: vi.fn(), + setDirty: vi.fn(), + setSessionsDirty: vi.fn(), + shouldFallbackOnError: vi.fn(() => true), + activateFallbackProvider: vi.fn(async () => true), + runSafeReindex: vi.fn(async () => {}), + runUnsafeReindex, + }); + + expect(runUnsafeReindex).toHaveBeenCalledWith({ + reason: "post-compaction", + force: true, + progress: undefined, + }); + }); + + it("runs a full safe reindex when planning detects metadata drift", async () => { + const runSafeReindex = vi.fn(async () => {}); + + await runExtensionHostEmbeddingSync({ + reason: "test", + force: false, + targetSessionFiles: null, + vectorReady: true, + meta: { + model: "old-model", + provider: "openai", + providerKey: "key", + sources: ["memory"], + scopeHash: "scope", + chunkTokens: 200, + chunkOverlap: 20, + }, + configuredSources: ["memory"], + configuredScopeHash: "scope", + provider: { + id: "openai", + model: "new-model", + embedQuery: async () => [1], + embedBatch: async () => [[1]], + }, + providerKey: "key", + chunkTokens: 200, + chunkOverlap: 20, + sessionsEnabled: false, + dirty: false, + shouldSyncSessions: false, + useUnsafeReindex: false, + hasDirtySessionFiles: false, + syncMemoryFiles: vi.fn(async () => {}), + syncSessionFiles: vi.fn(async () => {}), + clearSyncedSessionFiles: vi.fn(), + clearAllSessionDirtyFiles: vi.fn(), + setDirty: vi.fn(), + setSessionsDirty: vi.fn(), + shouldFallbackOnError: vi.fn(() => false), + activateFallbackProvider: vi.fn(async () => false), + runSafeReindex, + runUnsafeReindex: vi.fn(async () => {}), + }); + + expect(runSafeReindex).toHaveBeenCalledWith({ + reason: "test", + force: false, + progress: undefined, + }); + }); + + it("clears dirty flags after incremental syncs and preserves pending session dirtiness otherwise", async () => { + const setDirty = vi.fn(); + const setSessionsDirty = vi.fn(); + const clearAllSessionDirtyFiles = vi.fn(); + + await runExtensionHostEmbeddingSync({ + reason: "watch", + targetSessionFiles: null, + vectorReady: true, + meta: { + model: "model", + provider: "openai", + providerKey: "key", + sources: ["memory", "sessions"], + scopeHash: "scope", + chunkTokens: 200, + chunkOverlap: 20, + vectorDims: 1536, + }, + configuredSources: ["memory", "sessions"], + configuredScopeHash: "scope", + provider: { + id: "openai", + model: "model", + embedQuery: async () => [1], + embedBatch: async () => [[1]], + }, + providerKey: "key", + chunkTokens: 200, + chunkOverlap: 20, + sessionsEnabled: true, + dirty: true, + shouldSyncSessions: true, + useUnsafeReindex: false, + hasDirtySessionFiles: true, + syncMemoryFiles: vi.fn(async () => {}), + syncSessionFiles: vi.fn(async () => {}), + clearSyncedSessionFiles: vi.fn(), + clearAllSessionDirtyFiles, + setDirty, + setSessionsDirty, + shouldFallbackOnError: vi.fn(() => false), + activateFallbackProvider: vi.fn(async () => false), + runSafeReindex: vi.fn(async () => {}), + runUnsafeReindex: vi.fn(async () => {}), + }); + + expect(setDirty).toHaveBeenCalledWith(false); + expect(setSessionsDirty).toHaveBeenCalledWith(false); + expect(clearAllSessionDirtyFiles).toHaveBeenCalled(); + + setSessionsDirty.mockClear(); + + await runExtensionHostEmbeddingSync({ + reason: "watch", + targetSessionFiles: null, + vectorReady: true, + meta: { + model: "model", + provider: "openai", + providerKey: "key", + sources: ["memory", "sessions"], + scopeHash: "scope", + chunkTokens: 200, + chunkOverlap: 20, + vectorDims: 1536, + }, + configuredSources: ["memory", "sessions"], + configuredScopeHash: "scope", + provider: { + id: "openai", + model: "model", + embedQuery: async () => [1], + embedBatch: async () => [[1]], + }, + providerKey: "key", + chunkTokens: 200, + chunkOverlap: 20, + sessionsEnabled: true, + dirty: false, + shouldSyncSessions: false, + useUnsafeReindex: false, + hasDirtySessionFiles: true, + syncMemoryFiles: vi.fn(async () => {}), + syncSessionFiles: vi.fn(async () => {}), + clearSyncedSessionFiles: vi.fn(), + clearAllSessionDirtyFiles: vi.fn(), + setDirty: vi.fn(), + setSessionsDirty, + shouldFallbackOnError: vi.fn(() => false), + activateFallbackProvider: vi.fn(async () => false), + runSafeReindex: vi.fn(async () => {}), + runUnsafeReindex: vi.fn(async () => {}), + }); + + expect(setSessionsDirty).toHaveBeenCalledWith(true); + }); +}); diff --git a/src/extension-host/contributions/embedding-sync-execution.ts b/src/extension-host/contributions/embedding-sync-execution.ts new file mode 100644 index 00000000000..e39547f95bc --- /dev/null +++ b/src/extension-host/contributions/embedding-sync-execution.ts @@ -0,0 +1,153 @@ +import type { EmbeddingProvider } from "./embedding-runtime.js"; +import { + type EmbeddingIndexMeta, + type EmbeddingMemorySource, + resolveEmbeddingSyncPlan, +} from "./embedding-sync-planning.js"; + +type EmbeddingSyncProgress = unknown; + +type EmbeddingSyncMemoryFiles = (params: { + needsFullReindex: boolean; + progress?: TProgress; +}) => Promise; + +type EmbeddingSyncSessionFiles = (params: { + needsFullReindex: boolean; + targetSessionFiles?: string[]; + progress?: TProgress; +}) => Promise; + +type EmbeddingReindex = (params: { + reason?: string; + force?: boolean; + progress?: TProgress; +}) => Promise; + +export async function runExtensionHostEmbeddingSync(params: { + reason?: string; + force?: boolean; + targetSessionFiles: Set | null; + vectorReady: boolean; + meta: EmbeddingIndexMeta | null; + configuredSources: EmbeddingMemorySource[]; + configuredScopeHash: string; + provider: EmbeddingProvider | null; + providerKey: string | null; + chunkTokens: number; + chunkOverlap: number; + sessionsEnabled: boolean; + dirty: boolean; + shouldSyncSessions: boolean; + useUnsafeReindex: boolean; + hasDirtySessionFiles: boolean; + progress?: TProgress; + syncMemoryFiles: EmbeddingSyncMemoryFiles; + syncSessionFiles: EmbeddingSyncSessionFiles; + clearSyncedSessionFiles: (targetSessionFiles?: Iterable | null) => void; + clearAllSessionDirtyFiles: () => void; + setDirty: (value: boolean) => void; + setSessionsDirty: (value: boolean) => void; + shouldFallbackOnError: (message: string) => boolean; + activateFallbackProvider: (reason: string) => Promise; + runSafeReindex: EmbeddingReindex; + runUnsafeReindex: EmbeddingReindex; +}): Promise { + const hasTargetSessionFiles = params.targetSessionFiles !== null; + const syncPlan = resolveEmbeddingSyncPlan({ + force: params.force, + hasTargetSessionFiles, + targetSessionFiles: params.targetSessionFiles, + sessionsEnabled: params.sessionsEnabled, + dirty: params.dirty, + shouldSyncSessions: params.shouldSyncSessions, + useUnsafeReindex: params.useUnsafeReindex, + vectorReady: params.vectorReady, + meta: params.meta, + provider: params.provider, + providerKey: params.providerKey, + configuredSources: params.configuredSources, + configuredScopeHash: params.configuredScopeHash, + chunkTokens: params.chunkTokens, + chunkOverlap: params.chunkOverlap, + }); + + if (syncPlan.kind === "targeted-sessions") { + try { + await params.syncSessionFiles({ + needsFullReindex: false, + targetSessionFiles: syncPlan.targetSessionFiles, + progress: params.progress, + }); + params.clearSyncedSessionFiles(new Set(syncPlan.targetSessionFiles)); + } catch (err) { + const reason = err instanceof Error ? err.message : String(err); + const activated = + params.shouldFallbackOnError(reason) && (await params.activateFallbackProvider(reason)); + if (activated) { + const reindexParams = { + reason: params.reason, + force: true, + progress: params.progress, + }; + if (params.useUnsafeReindex) { + await params.runUnsafeReindex(reindexParams); + } else { + await params.runSafeReindex(reindexParams); + } + return; + } + throw err; + } + return; + } + + try { + if (syncPlan.kind === "full-reindex") { + const reindexParams = { + reason: params.reason, + force: params.force, + progress: params.progress, + }; + if (syncPlan.unsafe) { + await params.runUnsafeReindex(reindexParams); + } else { + await params.runSafeReindex(reindexParams); + } + return; + } + + if (syncPlan.shouldSyncMemory) { + await params.syncMemoryFiles({ + needsFullReindex: false, + progress: params.progress, + }); + params.setDirty(false); + } + + if (syncPlan.shouldSyncSessions) { + await params.syncSessionFiles({ + needsFullReindex: false, + targetSessionFiles: syncPlan.targetSessionFiles, + progress: params.progress, + }); + params.setSessionsDirty(false); + params.clearAllSessionDirtyFiles(); + } else { + params.setSessionsDirty(params.hasDirtySessionFiles); + } + } catch (err) { + const reason = err instanceof Error ? err.message : String(err); + const activated = + params.shouldFallbackOnError(reason) && (await params.activateFallbackProvider(reason)); + if (activated) { + await params.runSafeReindex({ + reason: params.reason ?? "fallback", + force: true, + progress: params.progress, + }); + return; + } + throw err; + } +} diff --git a/src/extension-host/contributions/embedding-sync-planning.test.ts b/src/extension-host/contributions/embedding-sync-planning.test.ts new file mode 100644 index 00000000000..3c73ba0f54d --- /dev/null +++ b/src/extension-host/contributions/embedding-sync-planning.test.ts @@ -0,0 +1,171 @@ +import { describe, expect, it } from "vitest"; +import { + buildEmbeddingIndexMeta, + metaSourcesDiffer, + normalizeEmbeddingMetaSources, + resolveEmbeddingSyncPlan, + shouldUseUnsafeEmbeddingReindex, +} from "./embedding-sync-planning.js"; + +describe("embedding-sync-planning", () => { + it("prefers targeted session refreshes over broader sync decisions", () => { + const plan = resolveEmbeddingSyncPlan({ + hasTargetSessionFiles: true, + targetSessionFiles: new Set(["/tmp/session.jsonl"]), + sessionsEnabled: true, + dirty: true, + shouldSyncSessions: true, + useUnsafeReindex: false, + vectorReady: false, + meta: null, + provider: null, + providerKey: null, + configuredSources: ["sessions"], + configuredScopeHash: "scope", + chunkTokens: 200, + chunkOverlap: 20, + }); + + expect(plan).toEqual({ + kind: "targeted-sessions", + targetSessionFiles: ["/tmp/session.jsonl"], + }); + }); + + it("plans a full reindex when metadata drift is detected", () => { + const plan = resolveEmbeddingSyncPlan({ + force: false, + hasTargetSessionFiles: false, + targetSessionFiles: null, + sessionsEnabled: true, + dirty: false, + shouldSyncSessions: false, + useUnsafeReindex: true, + vectorReady: true, + meta: { + model: "old-model", + provider: "openai", + providerKey: "key", + sources: ["memory"], + scopeHash: "scope", + chunkTokens: 200, + chunkOverlap: 20, + }, + provider: { + id: "openai", + model: "new-model", + embedQuery: async () => [1], + embedBatch: async () => [[1]], + }, + providerKey: "key", + configuredSources: ["memory"], + configuredScopeHash: "scope", + chunkTokens: 200, + chunkOverlap: 20, + }); + + expect(plan).toEqual({ + kind: "full-reindex", + unsafe: true, + }); + }); + + it("builds incremental sync plans from dirty/session state", () => { + const plan = resolveEmbeddingSyncPlan({ + force: false, + hasTargetSessionFiles: false, + targetSessionFiles: null, + sessionsEnabled: true, + dirty: true, + shouldSyncSessions: true, + useUnsafeReindex: false, + vectorReady: false, + meta: { + model: "model", + provider: "openai", + providerKey: "key", + sources: ["memory", "sessions"], + scopeHash: "scope", + chunkTokens: 200, + chunkOverlap: 20, + vectorDims: 1536, + }, + provider: { + id: "openai", + model: "model", + embedQuery: async () => [1], + embedBatch: async () => [[1]], + }, + providerKey: "key", + configuredSources: ["memory", "sessions"], + configuredScopeHash: "scope", + chunkTokens: 200, + chunkOverlap: 20, + }); + + expect(plan).toEqual({ + kind: "incremental", + shouldSyncMemory: true, + shouldSyncSessions: true, + targetSessionFiles: undefined, + }); + }); + + it("builds embedding metadata with provider and vector dimensions", () => { + expect( + buildEmbeddingIndexMeta({ + provider: { + id: "openai", + model: "text-embedding-3-small", + embedQuery: async () => [1], + embedBatch: async () => [[1]], + }, + providerKey: "provider-key", + configuredSources: ["memory", "sessions"], + configuredScopeHash: "scope", + chunkTokens: 256, + chunkOverlap: 32, + vectorDims: 1536, + }), + ).toEqual({ + model: "text-embedding-3-small", + provider: "openai", + providerKey: "provider-key", + sources: ["memory", "sessions"], + scopeHash: "scope", + chunkTokens: 256, + chunkOverlap: 32, + vectorDims: 1536, + }); + }); + + it("normalizes legacy meta sources and detects drift", () => { + expect(normalizeEmbeddingMetaSources(null)).toEqual(["memory"]); + expect(normalizeEmbeddingMetaSources({ sources: ["sessions", "memory", "sessions"] })).toEqual([ + "memory", + "sessions", + ]); + expect( + metaSourcesDiffer( + { + model: "model", + provider: "openai", + sources: ["memory"], + chunkTokens: 200, + chunkOverlap: 20, + }, + ["memory", "sessions"], + ), + ).toBe(true); + }); + + it("reads the unsafe test reindex gate from env vars", () => { + expect( + shouldUseUnsafeEmbeddingReindex({ + OPENCLAW_TEST_FAST: "1", + OPENCLAW_TEST_MEMORY_UNSAFE_REINDEX: "1", + } as NodeJS.ProcessEnv), + ).toBe(true); + expect(shouldUseUnsafeEmbeddingReindex({} as NodeJS.ProcessEnv)).toBe(false); + }); +}); diff --git a/src/extension-host/contributions/embedding-sync-planning.ts b/src/extension-host/contributions/embedding-sync-planning.ts new file mode 100644 index 00000000000..46f396599ad --- /dev/null +++ b/src/extension-host/contributions/embedding-sync-planning.ts @@ -0,0 +1,138 @@ +import type { EmbeddingProvider } from "./embedding-runtime.js"; + +export type EmbeddingMemorySource = "memory" | "sessions"; + +export type EmbeddingIndexMeta = { + model: string; + provider: string; + providerKey?: string; + sources?: EmbeddingMemorySource[]; + scopeHash?: string; + chunkTokens: number; + chunkOverlap: number; + vectorDims?: number; +}; + +export type EmbeddingSyncPlan = + | { + kind: "targeted-sessions"; + targetSessionFiles: string[]; + } + | { + kind: "full-reindex"; + unsafe: boolean; + } + | { + kind: "incremental"; + shouldSyncMemory: boolean; + shouldSyncSessions: boolean; + targetSessionFiles?: string[]; + }; + +export function resolveEmbeddingSyncPlan(params: { + force?: boolean; + hasTargetSessionFiles: boolean; + targetSessionFiles: Set | null; + sessionsEnabled: boolean; + dirty: boolean; + shouldSyncSessions: boolean; + useUnsafeReindex: boolean; + vectorReady: boolean; + meta: EmbeddingIndexMeta | null; + provider: EmbeddingProvider | null; + providerKey: string | null; + configuredSources: EmbeddingMemorySource[]; + configuredScopeHash: string; + chunkTokens: number; + chunkOverlap: number; +}): EmbeddingSyncPlan { + if (params.hasTargetSessionFiles && params.targetSessionFiles && params.sessionsEnabled) { + return { + kind: "targeted-sessions", + targetSessionFiles: Array.from(params.targetSessionFiles), + }; + } + + const needsFullReindex = + (params.force && !params.hasTargetSessionFiles) || + !params.meta || + (params.provider && params.meta.model !== params.provider.model) || + (params.provider && params.meta.provider !== params.provider.id) || + params.meta?.providerKey !== params.providerKey || + metaSourcesDiffer(params.meta, params.configuredSources) || + params.meta?.scopeHash !== params.configuredScopeHash || + params.meta?.chunkTokens !== params.chunkTokens || + params.meta?.chunkOverlap !== params.chunkOverlap || + (params.vectorReady && !params.meta?.vectorDims); + + if (needsFullReindex) { + return { + kind: "full-reindex", + unsafe: params.useUnsafeReindex, + }; + } + + return { + kind: "incremental", + shouldSyncMemory: !params.hasTargetSessionFiles && (Boolean(params.force) || params.dirty), + shouldSyncSessions: params.shouldSyncSessions, + targetSessionFiles: params.targetSessionFiles + ? Array.from(params.targetSessionFiles) + : undefined, + }; +} + +export function buildEmbeddingIndexMeta(params: { + provider: EmbeddingProvider | null; + providerKey: string | null; + configuredSources: EmbeddingMemorySource[]; + configuredScopeHash: string; + chunkTokens: number; + chunkOverlap: number; + vectorDims?: number; +}): EmbeddingIndexMeta { + const meta: EmbeddingIndexMeta = { + model: params.provider?.model ?? "fts-only", + provider: params.provider?.id ?? "none", + providerKey: params.providerKey ?? undefined, + sources: params.configuredSources, + scopeHash: params.configuredScopeHash, + chunkTokens: params.chunkTokens, + chunkOverlap: params.chunkOverlap, + }; + if (params.vectorDims) { + meta.vectorDims = params.vectorDims; + } + return meta; +} + +export function shouldUseUnsafeEmbeddingReindex(env = process.env): boolean { + return env.OPENCLAW_TEST_FAST === "1" && env.OPENCLAW_TEST_MEMORY_UNSAFE_REINDEX === "1"; +} + +export function metaSourcesDiffer( + meta: EmbeddingIndexMeta | null, + configuredSources: EmbeddingMemorySource[], +): boolean { + const metaSources = normalizeEmbeddingMetaSources(meta); + if (metaSources.length !== configuredSources.length) { + return true; + } + return metaSources.some((source, index) => source !== configuredSources[index]); +} + +export function normalizeEmbeddingMetaSources( + meta: Pick | null, +): EmbeddingMemorySource[] { + if (!Array.isArray(meta?.sources)) { + return ["memory"]; + } + const normalized = Array.from( + new Set( + meta.sources.filter( + (source): source is EmbeddingMemorySource => source === "memory" || source === "sessions", + ), + ), + ).toSorted(); + return normalized.length > 0 ? normalized : ["memory"]; +} diff --git a/src/extension-host/contributions/gateway-methods.test.ts b/src/extension-host/contributions/gateway-methods.test.ts new file mode 100644 index 00000000000..e1704749b17 --- /dev/null +++ b/src/extension-host/contributions/gateway-methods.test.ts @@ -0,0 +1,70 @@ +import { describe, expect, it, vi } from "vitest"; +import { createEmptyPluginRegistry } from "../../plugins/registry.js"; +import { + createExtensionHostGatewayExtraHandlers, + logExtensionHostPluginDiagnostics, + resolveExtensionHostGatewayMethods, +} from "./gateway-methods.js"; +import { setExtensionHostGatewayHandler } from "./runtime-registry.js"; + +describe("resolveExtensionHostGatewayMethods", () => { + it("adds plugin methods without duplicating base methods", () => { + const registry = createEmptyPluginRegistry(); + setExtensionHostGatewayHandler({ registry, method: "health", handler: vi.fn() }); + setExtensionHostGatewayHandler({ registry, method: "plugin.echo", handler: vi.fn() }); + + expect( + resolveExtensionHostGatewayMethods({ + registry, + baseMethods: ["health", "config.get"], + }), + ).toEqual(["health", "config.get", "plugin.echo"]); + }); +}); + +describe("createExtensionHostGatewayExtraHandlers", () => { + it("lets caller-provided handlers override plugin handlers", () => { + const pluginHandler = vi.fn(); + const callerHandler = vi.fn(); + const registry = createEmptyPluginRegistry(); + setExtensionHostGatewayHandler({ registry, method: "demo", handler: pluginHandler }); + + const handlers = createExtensionHostGatewayExtraHandlers({ + registry, + extraHandlers: { demo: callerHandler, health: vi.fn() }, + }); + + expect(handlers.demo).toBe(callerHandler); + expect(handlers.health).toBeTypeOf("function"); + }); +}); + +describe("logExtensionHostPluginDiagnostics", () => { + it("routes error diagnostics to error and others to info", () => { + const log = { + info: vi.fn(), + error: vi.fn(), + }; + + logExtensionHostPluginDiagnostics({ + diagnostics: [ + { + level: "warn", + pluginId: "demo", + source: "bundled", + message: "warned", + }, + { + level: "error", + pluginId: "demo", + source: "bundled", + message: "failed", + }, + ], + log, + }); + + expect(log.info).toHaveBeenCalledWith("[plugins] warned (plugin=demo, source=bundled)"); + expect(log.error).toHaveBeenCalledWith("[plugins] failed (plugin=demo, source=bundled)"); + }); +}); diff --git a/src/extension-host/contributions/gateway-methods.ts b/src/extension-host/contributions/gateway-methods.ts new file mode 100644 index 00000000000..5ab034e32c4 --- /dev/null +++ b/src/extension-host/contributions/gateway-methods.ts @@ -0,0 +1,48 @@ +import type { GatewayRequestHandlers } from "../../gateway/server-methods/types.js"; +import type { PluginRegistry } from "../../plugins/registry.js"; +import type { PluginDiagnostic } from "../../plugins/types.js"; +import { getExtensionHostGatewayHandlers } from "./runtime-registry.js"; + +export function resolveExtensionHostGatewayMethods(params: { + registry: PluginRegistry; + baseMethods: string[]; +}): string[] { + const pluginMethods = Object.keys(getExtensionHostGatewayHandlers(params.registry)); + return Array.from(new Set([...params.baseMethods, ...pluginMethods])); +} + +export function createExtensionHostGatewayExtraHandlers(params: { + registry: PluginRegistry; + extraHandlers?: GatewayRequestHandlers; +}): GatewayRequestHandlers { + const pluginHandlers = getExtensionHostGatewayHandlers(params.registry); + return { + ...pluginHandlers, + ...params.extraHandlers, + }; +} + +export function logExtensionHostPluginDiagnostics(params: { + diagnostics: PluginDiagnostic[]; + log: { + info: (msg: string) => void; + error: (msg: string) => void; + }; +}): void { + for (const diag of params.diagnostics) { + const details = [ + diag.pluginId ? `plugin=${diag.pluginId}` : null, + diag.source ? `source=${diag.source}` : null, + ] + .filter((entry): entry is string => Boolean(entry)) + .join(", "); + const message = details + ? `[plugins] ${diag.message} (${details})` + : `[plugins] ${diag.message}`; + if (diag.level === "error") { + params.log.error(message); + continue; + } + params.log.info(message); + } +} diff --git a/src/extension-host/contributions/media-runtime-api.test.ts b/src/extension-host/contributions/media-runtime-api.test.ts new file mode 100644 index 00000000000..04e4171ab74 --- /dev/null +++ b/src/extension-host/contributions/media-runtime-api.test.ts @@ -0,0 +1,140 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { + buildExtensionHostMediaProviderRegistry, + normalizeExtensionHostMediaAttachments, + resolveExtensionHostAutoImageModel, + resolveExtensionHostMediaAttachmentLocalRoots, + runExtensionHostMediaApiCapability, +} from "./media-runtime-api.js"; + +vi.mock("./media-runtime-auto.js", () => ({ + clearMediaUnderstandingBinaryCacheForTests: vi.fn(), + resolveAutoImageModel: vi.fn(), +})); + +vi.mock("./media-runtime-orchestration.js", () => ({ + runCapability: vi.fn(), +})); + +vi.mock("./media-runtime-registry.js", () => ({ + buildExtensionHostMediaUnderstandingRegistry: vi.fn(), +})); + +vi.mock("../media/inbound-path-policy.js", () => ({ + mergeInboundPathRoots: vi.fn(), + resolveIMessageAttachmentRoots: vi.fn(), +})); + +vi.mock("../media/local-roots.js", () => ({ + getDefaultMediaLocalRoots: vi.fn(), +})); + +vi.mock("../media-understanding/attachments.js", () => ({ + MediaAttachmentCache: class MediaAttachmentCache { + constructor( + readonly attachments: unknown[], + readonly options?: unknown, + ) {} + }, + normalizeAttachments: vi.fn(), +})); + +describe("media-runtime-api", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("delegates provider-registry construction to the host-owned registry", async () => { + const registryModule = await import("./media-runtime-registry.js"); + const registry = new Map(); + vi.mocked(registryModule.buildExtensionHostMediaUnderstandingRegistry).mockReturnValue( + registry, + ); + + expect(buildExtensionHostMediaProviderRegistry({ openai: {} as never })).toBe(registry); + expect(registryModule.buildExtensionHostMediaUnderstandingRegistry).toHaveBeenCalledWith({ + openai: {} as never, + }); + }); + + it("resolves local roots through the host-owned inbound-path policy", async () => { + const localRootsModule = await import("../media/local-roots.js"); + const inboundPolicyModule = await import("../media/inbound-path-policy.js"); + + vi.mocked(localRootsModule.getDefaultMediaLocalRoots).mockReturnValue(["/tmp/openclaw"]); + vi.mocked(inboundPolicyModule.resolveIMessageAttachmentRoots).mockReturnValue(["/messages"]); + vi.mocked(inboundPolicyModule.mergeInboundPathRoots).mockReturnValue([ + "/tmp/openclaw", + "/messages", + ]); + + const roots = resolveExtensionHostMediaAttachmentLocalRoots({ + cfg: { channels: { imessage: {} } } as never, + ctx: { AccountId: "primary" } as never, + }); + + expect(roots).toEqual(["/tmp/openclaw", "/messages"]); + expect(inboundPolicyModule.resolveIMessageAttachmentRoots).toHaveBeenCalledWith({ + cfg: { channels: { imessage: {} } }, + accountId: "primary", + }); + }); + + it("injects the default registry when resolving the auto image model", async () => { + const registryModule = await import("./media-runtime-registry.js"); + const autoModule = await import("./media-runtime-auto.js"); + const registry = new Map(); + + vi.mocked(registryModule.buildExtensionHostMediaUnderstandingRegistry).mockReturnValue( + registry, + ); + vi.mocked(autoModule.resolveAutoImageModel).mockResolvedValue({ + provider: "openai", + model: "gpt-4.1", + }); + + await expect( + resolveExtensionHostAutoImageModel({ + cfg: {} as never, + }), + ).resolves.toEqual({ + provider: "openai", + model: "gpt-4.1", + }); + + expect(autoModule.resolveAutoImageModel).toHaveBeenCalledWith({ + cfg: {}, + providerRegistry: registry, + }); + }); + + it("delegates top-level capability execution to the host-owned orchestration", async () => { + const orchestrationModule = await import("./media-runtime-orchestration.js"); + const attachments = { cleanup: vi.fn() } as never; + const media = [{ kind: "image" }] as never; + const providerRegistry = new Map() as never; + const result = { outputs: [], decision: { capability: "image", outcome: "skipped" } } as never; + + vi.mocked(orchestrationModule.runCapability).mockResolvedValue(result); + + await expect( + runExtensionHostMediaApiCapability({ + capability: "image", + cfg: {} as never, + ctx: {} as never, + attachments, + media, + providerRegistry, + }), + ).resolves.toBe(result); + }); + + it("delegates attachment normalization to the shared media attachment helper", async () => { + const attachmentsModule = await import("../media-understanding/attachments.js"); + vi.mocked(attachmentsModule.normalizeAttachments).mockReturnValue([{ kind: "audio" }] as never); + + expect(normalizeExtensionHostMediaAttachments({ MediaPath: "/tmp/test.wav" } as never)).toEqual( + [{ kind: "audio" }], + ); + }); +}); diff --git a/src/extension-host/contributions/media-runtime-api.ts b/src/extension-host/contributions/media-runtime-api.ts new file mode 100644 index 00000000000..ec431447d66 --- /dev/null +++ b/src/extension-host/contributions/media-runtime-api.ts @@ -0,0 +1,95 @@ +import type { MsgContext } from "../../auto-reply/templating.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import type { MediaUnderstandingConfig } from "../../config/types.tools.js"; +import { + MediaAttachmentCache, + type MediaAttachmentCacheOptions, + normalizeAttachments, +} from "../../media-understanding/attachments.js"; +import type { + MediaAttachment, + MediaUnderstandingCapability, + MediaUnderstandingProvider, +} from "../../media-understanding/types.js"; +import { + mergeInboundPathRoots, + resolveIMessageAttachmentRoots, +} from "../../media/inbound-path-policy.js"; +import { getDefaultMediaLocalRoots } from "../../media/local-roots.js"; +import { + clearMediaUnderstandingBinaryCacheForTests as clearExtensionHostMediaUnderstandingBinaryCacheForTests, + resolveAutoImageModel as resolveExtensionHostMediaRuntimeAutoImageModel, + type ActiveMediaModel, +} from "./media-runtime-auto.js"; +import { + runCapability as runExtensionHostMediaCapability, + type RunCapabilityResult, +} from "./media-runtime-orchestration.js"; +import { + buildExtensionHostMediaUnderstandingRegistry, + type ExtensionHostMediaUnderstandingProviderRegistry, +} from "./media-runtime-registry.js"; + +type ProviderRegistry = ExtensionHostMediaUnderstandingProviderRegistry; + +export type { ActiveMediaModel, RunCapabilityResult }; +export type ExtensionHostMediaProviderRegistry = ProviderRegistry; + +export function buildExtensionHostMediaProviderRegistry( + overrides?: Record, +): ProviderRegistry { + return buildExtensionHostMediaUnderstandingRegistry(overrides); +} + +export function normalizeExtensionHostMediaAttachments(ctx: MsgContext): MediaAttachment[] { + return normalizeAttachments(ctx); +} + +export function resolveExtensionHostMediaAttachmentLocalRoots(params: { + cfg: OpenClawConfig; + ctx: MsgContext; +}): readonly string[] { + return mergeInboundPathRoots( + getDefaultMediaLocalRoots(), + resolveIMessageAttachmentRoots({ + cfg: params.cfg, + accountId: params.ctx.AccountId, + }), + ); +} + +export function createExtensionHostMediaAttachmentCache( + attachments: MediaAttachment[], + options?: MediaAttachmentCacheOptions, +): MediaAttachmentCache { + return new MediaAttachmentCache(attachments, options); +} + +export function clearExtensionHostMediaBinaryCacheForTests(): void { + clearExtensionHostMediaUnderstandingBinaryCacheForTests(); +} + +export async function resolveExtensionHostAutoImageModel(params: { + cfg: OpenClawConfig; + agentDir?: string; + activeModel?: ActiveMediaModel; +}): Promise { + return await resolveExtensionHostMediaRuntimeAutoImageModel({ + ...params, + providerRegistry: buildExtensionHostMediaProviderRegistry(), + }); +} + +export async function runExtensionHostMediaApiCapability(params: { + capability: MediaUnderstandingCapability; + cfg: OpenClawConfig; + ctx: MsgContext; + attachments: MediaAttachmentCache; + media: MediaAttachment[]; + agentDir?: string; + providerRegistry: ProviderRegistry; + config?: MediaUnderstandingConfig; + activeModel?: ActiveMediaModel; +}): Promise { + return await runExtensionHostMediaCapability(params); +} diff --git a/src/extension-host/contributions/media-runtime-auto.test.ts b/src/extension-host/contributions/media-runtime-auto.test.ts new file mode 100644 index 00000000000..01eeaf53280 --- /dev/null +++ b/src/extension-host/contributions/media-runtime-auto.test.ts @@ -0,0 +1,110 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { OpenClawConfig } from "../../config/config.js"; +import { resolveExtensionHostMediaRuntimeDefaultModel } from "../static/runtime-backend-catalog.js"; +import { buildExtensionHostMediaUnderstandingRegistry } from "./media-runtime-registry.js"; + +const resolveApiKeyForProvider = vi.hoisted(() => vi.fn()); + +vi.mock("../agents/model-auth.js", () => ({ + resolveApiKeyForProvider, +})); + +import { resolveAutoImageModel } from "./media-runtime-auto.js"; + +function createImageCfg(): OpenClawConfig { + return { + models: { + providers: { + openai: { + apiKey: "test-key", + models: [], + }, + }, + }, + } as unknown as OpenClawConfig; +} + +describe("media runtime auto image model", () => { + beforeEach(() => { + resolveApiKeyForProvider.mockReset(); + resolveApiKeyForProvider.mockImplementation( + async ({ provider, cfg }: { provider: string; cfg: OpenClawConfig }) => { + if (cfg.models?.providers?.[provider]) { + return { + apiKey: "test-key", + source: "config", + mode: "api-key", + }; + } + throw new Error("missing key"); + }, + ); + }); + + it("keeps a valid active image model", async () => { + const result = await resolveAutoImageModel({ + cfg: createImageCfg(), + providerRegistry: buildExtensionHostMediaUnderstandingRegistry(), + activeModel: { + provider: "openai", + model: "gpt-4.1-mini", + }, + }); + + expect(result).toEqual({ + provider: "openai", + model: "gpt-4.1-mini", + }); + }); + + it("falls back to the default keyed image model when the active model cannot be used", async () => { + const result = await resolveAutoImageModel({ + cfg: createImageCfg(), + providerRegistry: buildExtensionHostMediaUnderstandingRegistry(), + activeModel: { + provider: "missing-provider", + model: "ignored", + }, + }); + + expect(result).toEqual({ + provider: "openai", + model: resolveExtensionHostMediaRuntimeDefaultModel({ + capability: "image", + backendId: "openai", + }), + }); + }); + + it("keeps catalog image provider ordering when multiple keyed providers are available", async () => { + const result = await resolveAutoImageModel({ + cfg: { + models: { + providers: { + anthropic: { + apiKey: "anthropic-test-key", + models: [], + }, + google: { + apiKey: "google-test-key", + models: [], + }, + }, + }, + } as unknown as OpenClawConfig, + providerRegistry: buildExtensionHostMediaUnderstandingRegistry(), + activeModel: { + provider: "missing-provider", + model: "ignored", + }, + }); + + expect(result).toEqual({ + provider: "anthropic", + model: resolveExtensionHostMediaRuntimeDefaultModel({ + capability: "image", + backendId: "anthropic", + }), + }); + }); +}); diff --git a/src/extension-host/contributions/media-runtime-auto.ts b/src/extension-host/contributions/media-runtime-auto.ts new file mode 100644 index 00000000000..0d73d2bac33 --- /dev/null +++ b/src/extension-host/contributions/media-runtime-auto.ts @@ -0,0 +1,458 @@ +import { constants as fsConstants } from "node:fs"; +import fs from "node:fs/promises"; +import os from "node:os"; +import path from "node:path"; +import { resolveApiKeyForProvider } from "../../agents/model-auth.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import { + resolveAgentModelFallbackValues, + resolveAgentModelPrimaryValue, +} from "../../config/model-input.js"; +import type { MediaUnderstandingModelConfig } from "../../config/types.tools.js"; +import { fileExists } from "../../media-understanding/fs.js"; +import { extractGeminiResponse } from "../../media-understanding/output-extract.js"; +import type { MediaUnderstandingCapability } from "../../media-understanding/types.js"; +import { runExec } from "../../process/exec.js"; +import { + resolveExtensionHostMediaProviderCandidates, + type ExtensionHostMediaActiveModel, +} from "../policy/media-runtime-policy.js"; +import { + getExtensionHostMediaUnderstandingProvider, + normalizeExtensionHostMediaProviderId, + type ExtensionHostMediaUnderstandingProviderRegistry, +} from "./media-runtime-registry.js"; + +export type ActiveMediaModel = { + provider: string; + model?: string; +}; + +type ProviderRegistry = ExtensionHostMediaUnderstandingProviderRegistry; + +const binaryCache = new Map>(); +const geminiProbeCache = new Map>(); + +export function clearMediaUnderstandingBinaryCacheForTests(): void { + binaryCache.clear(); + geminiProbeCache.clear(); +} + +function expandHomeDir(value: string): string { + if (!value.startsWith("~")) { + return value; + } + const home = os.homedir(); + if (value === "~") { + return home; + } + if (value.startsWith("~/")) { + return path.join(home, value.slice(2)); + } + return value; +} + +function hasPathSeparator(value: string): boolean { + return value.includes("/") || value.includes("\\"); +} + +function candidateBinaryNames(name: string): string[] { + if (process.platform !== "win32") { + return [name]; + } + const ext = path.extname(name); + if (ext) { + return [name]; + } + const pathext = (process.env.PATHEXT ?? ".EXE;.CMD;.BAT;.COM") + .split(";") + .map((item) => item.trim()) + .filter(Boolean) + .map((item) => (item.startsWith(".") ? item : `.${item}`)); + const unique = Array.from(new Set(pathext)); + return [name, ...unique.map((item) => `${name}${item}`)]; +} + +async function isExecutable(filePath: string): Promise { + try { + const stat = await fs.stat(filePath); + if (!stat.isFile()) { + return false; + } + if (process.platform === "win32") { + return true; + } + await fs.access(filePath, fsConstants.X_OK); + return true; + } catch { + return false; + } +} + +async function findBinary(name: string): Promise { + const cached = binaryCache.get(name); + if (cached) { + return cached; + } + const resolved = (async () => { + const direct = expandHomeDir(name.trim()); + if (direct && hasPathSeparator(direct)) { + for (const candidate of candidateBinaryNames(direct)) { + if (await isExecutable(candidate)) { + return candidate; + } + } + } + + const searchName = name.trim(); + if (!searchName) { + return null; + } + const pathEntries = (process.env.PATH ?? "").split(path.delimiter); + const candidates = candidateBinaryNames(searchName); + for (const entryRaw of pathEntries) { + const entry = expandHomeDir(entryRaw.trim().replace(/^"(.*)"$/, "$1")); + if (!entry) { + continue; + } + for (const candidate of candidates) { + const fullPath = path.join(entry, candidate); + if (await isExecutable(fullPath)) { + return fullPath; + } + } + } + + return null; + })(); + binaryCache.set(name, resolved); + return resolved; +} + +async function hasBinary(name: string): Promise { + return Boolean(await findBinary(name)); +} + +async function probeGeminiCli(): Promise { + const cached = geminiProbeCache.get("gemini"); + if (cached) { + return cached; + } + const resolved = (async () => { + if (!(await hasBinary("gemini"))) { + return false; + } + try { + const { stdout } = await runExec("gemini", ["--output-format", "json", "ok"], { + timeoutMs: 8000, + }); + return Boolean(extractGeminiResponse(stdout) ?? stdout.toLowerCase().includes("ok")); + } catch { + return false; + } + })(); + geminiProbeCache.set("gemini", resolved); + return resolved; +} + +async function resolveLocalWhisperCppEntry(): Promise { + if (!(await hasBinary("whisper-cli"))) { + return null; + } + const envModel = process.env.WHISPER_CPP_MODEL?.trim(); + const defaultModel = "/opt/homebrew/share/whisper-cpp/for-tests-ggml-tiny.bin"; + const modelPath = envModel && (await fileExists(envModel)) ? envModel : defaultModel; + if (!(await fileExists(modelPath))) { + return null; + } + return { + type: "cli", + command: "whisper-cli", + args: ["-m", modelPath, "-otxt", "-of", "{{OutputBase}}", "-np", "-nt", "{{MediaPath}}"], + }; +} + +async function resolveLocalWhisperEntry(): Promise { + if (!(await hasBinary("whisper"))) { + return null; + } + return { + type: "cli", + command: "whisper", + args: [ + "--model", + "turbo", + "--output_format", + "txt", + "--output_dir", + "{{OutputDir}}", + "--verbose", + "False", + "{{MediaPath}}", + ], + }; +} + +async function resolveSherpaOnnxEntry(): Promise { + if (!(await hasBinary("sherpa-onnx-offline"))) { + return null; + } + const modelDir = process.env.SHERPA_ONNX_MODEL_DIR?.trim(); + if (!modelDir) { + return null; + } + const tokens = path.join(modelDir, "tokens.txt"); + const encoder = path.join(modelDir, "encoder.onnx"); + const decoder = path.join(modelDir, "decoder.onnx"); + const joiner = path.join(modelDir, "joiner.onnx"); + if (!(await fileExists(tokens))) { + return null; + } + if (!(await fileExists(encoder))) { + return null; + } + if (!(await fileExists(decoder))) { + return null; + } + if (!(await fileExists(joiner))) { + return null; + } + return { + type: "cli", + command: "sherpa-onnx-offline", + args: [ + `--tokens=${tokens}`, + `--encoder=${encoder}`, + `--decoder=${decoder}`, + `--joiner=${joiner}`, + "{{MediaPath}}", + ], + }; +} + +async function resolveLocalAudioEntry(): Promise { + const sherpa = await resolveSherpaOnnxEntry(); + if (sherpa) { + return sherpa; + } + const whisperCpp = await resolveLocalWhisperCppEntry(); + if (whisperCpp) { + return whisperCpp; + } + return await resolveLocalWhisperEntry(); +} + +async function resolveGeminiCliEntry( + _capability: MediaUnderstandingCapability, +): Promise { + if (!(await probeGeminiCli())) { + return null; + } + return { + type: "cli", + command: "gemini", + args: [ + "--output-format", + "json", + "--allowed-tools", + "read_many_files", + "--include-directories", + "{{MediaDir}}", + "{{Prompt}}", + "Use read_many_files to read {{MediaPath}} and respond with only the text output.", + ], + }; +} + +async function resolveActiveModelEntry(params: { + cfg: OpenClawConfig; + agentDir?: string; + providerRegistry: ProviderRegistry; + capability: MediaUnderstandingCapability; + activeModel?: ActiveMediaModel; +}): Promise { + const activeProviderRaw = params.activeModel?.provider?.trim(); + if (!activeProviderRaw) { + return null; + } + const providerId = normalizeExtensionHostMediaProviderId(activeProviderRaw); + if (!providerId) { + return null; + } + const provider = getExtensionHostMediaUnderstandingProvider(providerId, params.providerRegistry); + if (!provider) { + return null; + } + if (params.capability === "audio" && !provider.transcribeAudio) { + return null; + } + if (params.capability === "image" && !provider.describeImage) { + return null; + } + if (params.capability === "video" && !provider.describeVideo) { + return null; + } + try { + await resolveApiKeyForProvider({ + provider: providerId, + cfg: params.cfg, + agentDir: params.agentDir, + }); + } catch { + return null; + } + return { + type: "provider", + provider: providerId, + model: params.activeModel?.model, + }; +} + +async function resolveKeyEntry(params: { + cfg: OpenClawConfig; + agentDir?: string; + providerRegistry: ProviderRegistry; + capability: MediaUnderstandingCapability; + activeModel?: ExtensionHostMediaActiveModel; +}): Promise { + const { cfg, agentDir, providerRegistry, capability } = params; + const checkProvider = async ( + providerId: string, + model?: string, + ): Promise => { + const provider = getExtensionHostMediaUnderstandingProvider(providerId, providerRegistry); + if (!provider) { + return null; + } + if (capability === "audio" && !provider.transcribeAudio) { + return null; + } + if (capability === "image" && !provider.describeImage) { + return null; + } + if (capability === "video" && !provider.describeVideo) { + return null; + } + try { + await resolveApiKeyForProvider({ provider: providerId, cfg, agentDir }); + return { type: "provider", provider: providerId, model }; + } catch { + return null; + } + }; + + for (const candidate of resolveExtensionHostMediaProviderCandidates({ + capability, + activeModel: params.activeModel, + })) { + const entry = await checkProvider(candidate.provider, candidate.model); + if (entry) { + return entry; + } + } + return null; +} + +function resolveImageModelFromAgentDefaults(cfg: OpenClawConfig): MediaUnderstandingModelConfig[] { + const refs: string[] = []; + const primary = resolveAgentModelPrimaryValue(cfg.agents?.defaults?.imageModel); + if (primary?.trim()) { + refs.push(primary.trim()); + } + for (const fb of resolveAgentModelFallbackValues(cfg.agents?.defaults?.imageModel)) { + if (fb?.trim()) { + refs.push(fb.trim()); + } + } + if (refs.length === 0) { + return []; + } + const entries: MediaUnderstandingModelConfig[] = []; + for (const ref of refs) { + const slashIdx = ref.indexOf("/"); + if (slashIdx <= 0 || slashIdx >= ref.length - 1) { + continue; + } + entries.push({ + type: "provider", + provider: ref.slice(0, slashIdx), + model: ref.slice(slashIdx + 1), + }); + } + return entries; +} + +export async function resolveAutoEntries(params: { + cfg: OpenClawConfig; + agentDir?: string; + providerRegistry: ProviderRegistry; + capability: MediaUnderstandingCapability; + activeModel?: ActiveMediaModel; +}): Promise { + const activeEntry = await resolveActiveModelEntry(params); + if (activeEntry) { + return [activeEntry]; + } + if (params.capability === "audio") { + const localAudio = await resolveLocalAudioEntry(); + if (localAudio) { + return [localAudio]; + } + } + if (params.capability === "image") { + const imageModelEntries = resolveImageModelFromAgentDefaults(params.cfg); + if (imageModelEntries.length > 0) { + return imageModelEntries; + } + } + const gemini = await resolveGeminiCliEntry(params.capability); + if (gemini) { + return [gemini]; + } + const keys = await resolveKeyEntry(params); + if (keys) { + return [keys]; + } + return []; +} + +export async function resolveAutoImageModel(params: { + cfg: OpenClawConfig; + agentDir?: string; + activeModel?: ActiveMediaModel; + providerRegistry: ProviderRegistry; +}): Promise { + const toActive = (entry: MediaUnderstandingModelConfig | null): ActiveMediaModel | null => { + if (!entry || entry.type === "cli") { + return null; + } + const provider = entry.provider; + if (!provider) { + return null; + } + const model = entry.model; + if (!model) { + return null; + } + return { provider, model }; + }; + const activeEntry = await resolveActiveModelEntry({ + cfg: params.cfg, + agentDir: params.agentDir, + providerRegistry: params.providerRegistry, + capability: "image", + activeModel: params.activeModel, + }); + const resolvedActive = toActive(activeEntry); + if (resolvedActive) { + return resolvedActive; + } + const keyEntry = await resolveKeyEntry({ + cfg: params.cfg, + agentDir: params.agentDir, + providerRegistry: params.providerRegistry, + capability: "image", + activeModel: params.activeModel, + }); + return toActive(keyEntry); +} diff --git a/src/extension-host/contributions/media-runtime-config.ts b/src/extension-host/contributions/media-runtime-config.ts new file mode 100644 index 00000000000..9c52adc4df5 --- /dev/null +++ b/src/extension-host/contributions/media-runtime-config.ts @@ -0,0 +1,190 @@ +import type { MsgContext } from "../../auto-reply/templating.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import type { + MediaUnderstandingConfig, + MediaUnderstandingModelConfig, + MediaUnderstandingScopeConfig, +} from "../../config/types.tools.js"; +import { logVerbose, shouldLogVerbose } from "../../globals.js"; +import { + DEFAULT_MAX_BYTES, + DEFAULT_MAX_CHARS_BY_CAPABILITY, + DEFAULT_MEDIA_CONCURRENCY, + DEFAULT_PROMPT, +} from "../../media-understanding/defaults.js"; +import { + normalizeMediaUnderstandingChatType, + resolveMediaUnderstandingScope, +} from "../../media-understanding/scope.js"; +import type { MediaUnderstandingCapability } from "../../media-understanding/types.js"; +import { normalizeExtensionHostMediaProviderId } from "./media-runtime-registry.js"; + +export function resolveTimeoutMs(seconds: number | undefined, fallbackSeconds: number): number { + const value = typeof seconds === "number" && Number.isFinite(seconds) ? seconds : fallbackSeconds; + return Math.max(1000, Math.floor(value * 1000)); +} + +export function resolvePrompt( + capability: MediaUnderstandingCapability, + prompt?: string, + maxChars?: number, +): string { + const base = prompt?.trim() || DEFAULT_PROMPT[capability]; + if (!maxChars || capability === "audio") { + return base; + } + return `${base} Respond in at most ${maxChars} characters.`; +} + +export function resolveMaxChars(params: { + capability: MediaUnderstandingCapability; + entry: MediaUnderstandingModelConfig; + cfg: OpenClawConfig; + config?: MediaUnderstandingConfig; +}): number | undefined { + const { capability, entry, cfg } = params; + const configured = + entry.maxChars ?? params.config?.maxChars ?? cfg.tools?.media?.[capability]?.maxChars; + if (typeof configured === "number") { + return configured; + } + return DEFAULT_MAX_CHARS_BY_CAPABILITY[capability]; +} + +export function resolveMaxBytes(params: { + capability: MediaUnderstandingCapability; + entry: MediaUnderstandingModelConfig; + cfg: OpenClawConfig; + config?: MediaUnderstandingConfig; +}): number { + const configured = + params.entry.maxBytes ?? + params.config?.maxBytes ?? + params.cfg.tools?.media?.[params.capability]?.maxBytes; + if (typeof configured === "number") { + return configured; + } + return DEFAULT_MAX_BYTES[params.capability]; +} + +export function resolveCapabilityConfig( + cfg: OpenClawConfig, + capability: MediaUnderstandingCapability, +): MediaUnderstandingConfig | undefined { + return cfg.tools?.media?.[capability]; +} + +export function resolveScopeDecision(params: { + scope?: MediaUnderstandingScopeConfig; + ctx: MsgContext; +}): "allow" | "deny" { + return resolveMediaUnderstandingScope({ + scope: params.scope, + sessionKey: params.ctx.SessionKey, + channel: params.ctx.Surface ?? params.ctx.Provider, + chatType: normalizeMediaUnderstandingChatType(params.ctx.ChatType), + }); +} + +function resolveEntryCapabilities(params: { + entry: MediaUnderstandingModelConfig; + providerRegistry: Map; +}): MediaUnderstandingCapability[] | undefined { + const entryType = params.entry.type ?? (params.entry.command ? "cli" : "provider"); + if (entryType === "cli") { + return undefined; + } + const providerId = normalizeExtensionHostMediaProviderId(params.entry.provider ?? ""); + if (!providerId) { + return undefined; + } + return params.providerRegistry.get(providerId)?.capabilities; +} + +export function resolveModelEntries(params: { + cfg: OpenClawConfig; + capability: MediaUnderstandingCapability; + config?: MediaUnderstandingConfig; + providerRegistry: Map; +}): MediaUnderstandingModelConfig[] { + const { cfg, capability, config } = params; + const sharedModels = cfg.tools?.media?.models ?? []; + const entries = [ + ...(config?.models ?? []).map((entry) => ({ entry, source: "capability" as const })), + ...sharedModels.map((entry) => ({ entry, source: "shared" as const })), + ]; + if (entries.length === 0) { + return []; + } + + return entries + .filter(({ entry, source }) => { + const caps = + entry.capabilities && entry.capabilities.length > 0 + ? entry.capabilities + : source === "shared" + ? resolveEntryCapabilities({ entry, providerRegistry: params.providerRegistry }) + : undefined; + if (!caps || caps.length === 0) { + if (source === "shared") { + if (shouldLogVerbose()) { + logVerbose( + `Skipping shared media model without capabilities: ${entry.provider ?? entry.command ?? "unknown"}`, + ); + } + return false; + } + return true; + } + return caps.includes(capability); + }) + .map(({ entry }) => entry); +} + +export function resolveConcurrency(cfg: OpenClawConfig): number { + const configured = cfg.tools?.media?.concurrency; + if (typeof configured === "number" && Number.isFinite(configured) && configured > 0) { + return Math.floor(configured); + } + return DEFAULT_MEDIA_CONCURRENCY; +} + +export function resolveEntriesWithActiveFallback(params: { + cfg: OpenClawConfig; + capability: MediaUnderstandingCapability; + config?: MediaUnderstandingConfig; + providerRegistry: Map; + activeModel?: { provider: string; model?: string }; +}): MediaUnderstandingModelConfig[] { + const entries = resolveModelEntries({ + cfg: params.cfg, + capability: params.capability, + config: params.config, + providerRegistry: params.providerRegistry, + }); + if (entries.length > 0) { + return entries; + } + if (params.config?.enabled !== true) { + return entries; + } + const activeProviderRaw = params.activeModel?.provider?.trim(); + if (!activeProviderRaw) { + return entries; + } + const activeProvider = normalizeExtensionHostMediaProviderId(activeProviderRaw); + if (!activeProvider) { + return entries; + } + const capabilities = params.providerRegistry.get(activeProvider)?.capabilities; + if (!capabilities || !capabilities.includes(params.capability)) { + return entries; + } + return [ + { + type: "provider", + provider: activeProvider, + model: params.activeModel?.model, + }, + ]; +} diff --git a/src/extension-host/contributions/media-runtime-decision.ts b/src/extension-host/contributions/media-runtime-decision.ts new file mode 100644 index 00000000000..93b9524ed75 --- /dev/null +++ b/src/extension-host/contributions/media-runtime-decision.ts @@ -0,0 +1,58 @@ +import type { MediaUnderstandingModelConfig } from "../../config/types.tools.js"; +import type { + MediaUnderstandingDecision, + MediaUnderstandingModelDecision, +} from "../../media-understanding/types.js"; +import { normalizeExtensionHostMediaProviderId } from "./media-runtime-registry.js"; + +export function buildModelDecision(params: { + entry: MediaUnderstandingModelConfig; + entryType: "provider" | "cli"; + outcome: MediaUnderstandingModelDecision["outcome"]; + reason?: string; +}): MediaUnderstandingModelDecision { + if (params.entryType === "cli") { + const command = params.entry.command?.trim(); + return { + type: "cli", + provider: command ?? "cli", + model: params.entry.model ?? command, + outcome: params.outcome, + reason: params.reason, + }; + } + const providerIdRaw = params.entry.provider?.trim(); + const providerId = providerIdRaw + ? normalizeExtensionHostMediaProviderId(providerIdRaw) + : undefined; + return { + type: "provider", + provider: providerId ?? providerIdRaw, + model: params.entry.model, + outcome: params.outcome, + reason: params.reason, + }; +} + +export function formatDecisionSummary(decision: MediaUnderstandingDecision): string { + const attachments = Array.isArray(decision.attachments) ? decision.attachments : []; + const total = attachments.length; + const success = attachments.filter((entry) => entry?.chosen?.outcome === "success").length; + const chosen = attachments.find((entry) => entry?.chosen)?.chosen; + const provider = typeof chosen?.provider === "string" ? chosen.provider.trim() : undefined; + const model = typeof chosen?.model === "string" ? chosen.model.trim() : undefined; + const modelLabel = provider ? (model ? `${provider}/${model}` : provider) : undefined; + const reason = attachments + .flatMap((entry) => { + const attempts = Array.isArray(entry?.attempts) ? entry.attempts : []; + return attempts + .map((attempt) => (typeof attempt?.reason === "string" ? attempt.reason : undefined)) + .filter((value): value is string => Boolean(value)); + }) + .find((value) => value.trim().length > 0); + const shortReason = reason ? reason.split(":")[0]?.trim() : undefined; + const countLabel = total > 0 ? ` (${success}/${total})` : ""; + const viaLabel = modelLabel ? ` via ${modelLabel}` : ""; + const reasonLabel = shortReason ? ` reason=${shortReason}` : ""; + return `${decision.capability}: ${decision.outcome}${countLabel}${viaLabel}${reasonLabel}`; +} diff --git a/src/extension-host/contributions/media-runtime-entrypoints.ts b/src/extension-host/contributions/media-runtime-entrypoints.ts new file mode 100644 index 00000000000..0da8b9d4304 --- /dev/null +++ b/src/extension-host/contributions/media-runtime-entrypoints.ts @@ -0,0 +1,42 @@ +import type { MsgContext } from "../../auto-reply/templating.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import type { + MediaUnderstandingConfig, + MediaUnderstandingModelConfig, +} from "../../config/types.tools.js"; +import type { MediaAttachmentCache } from "../../media-understanding/attachments.js"; +import type { + MediaUnderstandingCapability, + MediaUnderstandingOutput, + MediaUnderstandingProvider, +} from "../../media-understanding/types.js"; + +export type ExtensionHostMediaProviderRegistry = Map; + +export async function runExtensionHostMediaProviderEntry(params: { + capability: MediaUnderstandingCapability; + entry: MediaUnderstandingModelConfig; + cfg: OpenClawConfig; + ctx: MsgContext; + attachmentIndex: number; + cache: MediaAttachmentCache; + agentDir?: string; + providerRegistry: ExtensionHostMediaProviderRegistry; + config?: MediaUnderstandingConfig; +}): Promise { + const runtime = await import("./media-runtime-execution.js"); + return runtime.runProviderEntry(params); +} + +export async function runExtensionHostMediaCliEntry(params: { + capability: MediaUnderstandingCapability; + entry: MediaUnderstandingModelConfig; + cfg: OpenClawConfig; + ctx: MsgContext; + attachmentIndex: number; + cache: MediaAttachmentCache; + config?: MediaUnderstandingConfig; +}): Promise { + const runtime = await import("./media-runtime-execution.js"); + return runtime.runCliEntry(params); +} diff --git a/src/extension-host/contributions/media-runtime-execution.ts b/src/extension-host/contributions/media-runtime-execution.ts new file mode 100644 index 00000000000..6108653e1ff --- /dev/null +++ b/src/extension-host/contributions/media-runtime-execution.ts @@ -0,0 +1,630 @@ +import fs from "node:fs/promises"; +import path from "node:path"; +import { + collectProviderApiKeysForExecution, + executeWithApiKeyRotation, +} from "../../agents/api-key-rotation.js"; +import { requireApiKey, resolveApiKeyForProvider } from "../../agents/model-auth.js"; +import type { MsgContext } from "../../auto-reply/templating.js"; +import { applyTemplate } from "../../auto-reply/templating.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import type { + MediaUnderstandingConfig, + MediaUnderstandingModelConfig, +} from "../../config/types.tools.js"; +import { logVerbose, shouldLogVerbose } from "../../globals.js"; +import { resolveProxyFetchFromEnv } from "../../infra/net/proxy-fetch.js"; +import { resolvePreferredOpenClawTmpDir } from "../../infra/tmp-openclaw-dir.js"; +import { MediaAttachmentCache } from "../../media-understanding/attachments.js"; +import { + CLI_OUTPUT_MAX_BUFFER, + DEFAULT_TIMEOUT_SECONDS, + MIN_AUDIO_FILE_BYTES, +} from "../../media-understanding/defaults.js"; +import { MediaUnderstandingSkipError } from "../../media-understanding/errors.js"; +import { fileExists } from "../../media-understanding/fs.js"; +import { extractGeminiResponse } from "../../media-understanding/output-extract.js"; +import type { + MediaUnderstandingCapability, + MediaUnderstandingOutput, + MediaUnderstandingProvider, +} from "../../media-understanding/types.js"; +import { estimateBase64Size, resolveVideoMaxBase64Bytes } from "../../media-understanding/video.js"; +import { runExec } from "../../process/exec.js"; +import { resolveExtensionHostMediaRuntimeDefaultModel } from "../static/runtime-backend-catalog.js"; +import { + resolveMaxBytes, + resolveMaxChars, + resolvePrompt, + resolveTimeoutMs, +} from "./media-runtime-config.js"; +import { + getExtensionHostMediaUnderstandingProvider, + normalizeExtensionHostMediaProviderId, +} from "./media-runtime-registry.js"; + +export type ProviderRegistry = Map; + +function sanitizeProviderHeaders( + headers: Record | undefined, +): Record | undefined { + if (!headers) { + return undefined; + } + const next: Record = {}; + for (const [key, value] of Object.entries(headers)) { + if (typeof value !== "string") { + continue; + } + // Intentionally preserve marker-shaped values here. This path handles + // explicit config/runtime provider headers, where literal values may + // legitimately match marker patterns; discovered models.json entries are + // sanitized separately in the model registry path. + next[key] = value; + } + return Object.keys(next).length > 0 ? next : undefined; +} + +function trimOutput(text: string, maxChars?: number): string { + const trimmed = text.trim(); + if (!maxChars || trimmed.length <= maxChars) { + return trimmed; + } + return trimmed.slice(0, maxChars).trim(); +} + +function extractSherpaOnnxText(raw: string): string | null { + const tryParse = (value: string): string | null => { + const trimmed = value.trim(); + if (!trimmed) { + return null; + } + const head = trimmed[0]; + if (head !== "{" && head !== '"') { + return null; + } + try { + const parsed = JSON.parse(trimmed) as unknown; + if (typeof parsed === "string") { + return tryParse(parsed); + } + if (parsed && typeof parsed === "object") { + const text = (parsed as { text?: unknown }).text; + if (typeof text === "string" && text.trim()) { + return text.trim(); + } + } + } catch {} + return null; + }; + + const direct = tryParse(raw); + if (direct) { + return direct; + } + + const lines = raw + .split("\n") + .map((line) => line.trim()) + .filter(Boolean); + for (let i = lines.length - 1; i >= 0; i -= 1) { + const parsed = tryParse(lines[i] ?? ""); + if (parsed) { + return parsed; + } + } + return null; +} + +function commandBase(command: string): string { + return path.parse(command).name; +} + +function findArgValue(args: string[], keys: string[]): string | undefined { + for (let i = 0; i < args.length; i += 1) { + if (keys.includes(args[i] ?? "")) { + const value = args[i + 1]; + if (value) { + return value; + } + } + } + return undefined; +} + +function hasArg(args: string[], keys: string[]): boolean { + return args.some((arg) => keys.includes(arg)); +} + +function resolveWhisperOutputPath(args: string[], mediaPath: string): string | null { + const outputDir = findArgValue(args, ["--output_dir", "-o"]); + const outputFormat = findArgValue(args, ["--output_format"]); + if (!outputDir || !outputFormat) { + return null; + } + const formats = outputFormat.split(",").map((value) => value.trim()); + if (!formats.includes("txt")) { + return null; + } + const base = path.parse(mediaPath).name; + return path.join(outputDir, `${base}.txt`); +} + +function resolveWhisperCppOutputPath(args: string[]): string | null { + if (!hasArg(args, ["-otxt", "--output-txt"])) { + return null; + } + const outputBase = findArgValue(args, ["-of", "--output-file"]); + if (!outputBase) { + return null; + } + return `${outputBase}.txt`; +} + +function resolveParakeetOutputPath(args: string[], mediaPath: string): string | null { + const outputDir = findArgValue(args, ["--output-dir"]); + const outputFormat = findArgValue(args, ["--output-format"]); + if (!outputDir) { + return null; + } + if (outputFormat && outputFormat !== "txt") { + return null; + } + const base = path.parse(mediaPath).name; + return path.join(outputDir, `${base}.txt`); +} + +async function resolveCliOutput(params: { + command: string; + args: string[]; + stdout: string; + mediaPath: string; +}): Promise { + const commandId = commandBase(params.command); + const fileOutput = + commandId === "whisper-cli" + ? resolveWhisperCppOutputPath(params.args) + : commandId === "whisper" + ? resolveWhisperOutputPath(params.args, params.mediaPath) + : commandId === "parakeet-mlx" + ? resolveParakeetOutputPath(params.args, params.mediaPath) + : null; + if (fileOutput && (await fileExists(fileOutput))) { + try { + const content = await fs.readFile(fileOutput, "utf8"); + if (content.trim()) { + return content.trim(); + } + } catch {} + } + + if (commandId === "gemini") { + const response = extractGeminiResponse(params.stdout); + if (response) { + return response; + } + } + + if (commandId === "sherpa-onnx-offline") { + const response = extractSherpaOnnxText(params.stdout); + if (response) { + return response; + } + } + + return params.stdout.trim(); +} + +type ProviderQuery = Record; + +function normalizeProviderQuery( + options?: Record, +): ProviderQuery | undefined { + if (!options) { + return undefined; + } + const query: ProviderQuery = {}; + for (const [key, value] of Object.entries(options)) { + if (value === undefined) { + continue; + } + query[key] = value; + } + return Object.keys(query).length > 0 ? query : undefined; +} + +function buildDeepgramCompatQuery(options?: { + detectLanguage?: boolean; + punctuate?: boolean; + smartFormat?: boolean; +}): ProviderQuery | undefined { + if (!options) { + return undefined; + } + const query: ProviderQuery = {}; + if (typeof options.detectLanguage === "boolean") { + query.detect_language = options.detectLanguage; + } + if (typeof options.punctuate === "boolean") { + query.punctuate = options.punctuate; + } + if (typeof options.smartFormat === "boolean") { + query.smart_format = options.smartFormat; + } + return Object.keys(query).length > 0 ? query : undefined; +} + +function normalizeDeepgramQueryKeys(query: ProviderQuery): ProviderQuery { + const normalized = { ...query }; + if ("detectLanguage" in normalized) { + normalized.detect_language = normalized.detectLanguage as boolean; + delete normalized.detectLanguage; + } + if ("smartFormat" in normalized) { + normalized.smart_format = normalized.smartFormat as boolean; + delete normalized.smartFormat; + } + return normalized; +} + +function resolveProviderQuery(params: { + providerId: string; + config?: MediaUnderstandingConfig; + entry: MediaUnderstandingModelConfig; +}): ProviderQuery | undefined { + const { providerId, config, entry } = params; + const mergedOptions = normalizeProviderQuery({ + ...config?.providerOptions?.[providerId], + ...entry.providerOptions?.[providerId], + }); + if (providerId !== "deepgram") { + return mergedOptions; + } + const query = normalizeDeepgramQueryKeys(mergedOptions ?? {}); + const compat = buildDeepgramCompatQuery({ ...config?.deepgram, ...entry.deepgram }); + for (const [key, value] of Object.entries(compat ?? {})) { + if (query[key] === undefined) { + query[key] = value; + } + } + return Object.keys(query).length > 0 ? query : undefined; +} + +function resolveEntryRunOptions(params: { + capability: MediaUnderstandingCapability; + entry: MediaUnderstandingModelConfig; + cfg: OpenClawConfig; + config?: MediaUnderstandingConfig; +}): { maxBytes: number; maxChars?: number; timeoutMs: number; prompt: string } { + const { capability, entry, cfg } = params; + const maxBytes = resolveMaxBytes({ capability, entry, cfg, config: params.config }); + const maxChars = resolveMaxChars({ capability, entry, cfg, config: params.config }); + const timeoutMs = resolveTimeoutMs( + entry.timeoutSeconds ?? + params.config?.timeoutSeconds ?? + cfg.tools?.media?.[capability]?.timeoutSeconds, + DEFAULT_TIMEOUT_SECONDS[capability], + ); + const prompt = resolvePrompt( + capability, + entry.prompt ?? params.config?.prompt ?? cfg.tools?.media?.[capability]?.prompt, + maxChars, + ); + return { maxBytes, maxChars, timeoutMs, prompt }; +} + +async function resolveProviderExecutionAuth(params: { + providerId: string; + cfg: OpenClawConfig; + entry: MediaUnderstandingModelConfig; + agentDir?: string; +}) { + const auth = await resolveApiKeyForProvider({ + provider: params.providerId, + cfg: params.cfg, + profileId: params.entry.profile, + preferredProfile: params.entry.preferredProfile, + agentDir: params.agentDir, + }); + return { + apiKeys: collectProviderApiKeysForExecution({ + provider: params.providerId, + primaryApiKey: requireApiKey(auth, params.providerId), + }), + providerConfig: params.cfg.models?.providers?.[params.providerId], + }; +} + +async function resolveProviderExecutionContext(params: { + providerId: string; + cfg: OpenClawConfig; + entry: MediaUnderstandingModelConfig; + config?: MediaUnderstandingConfig; + agentDir?: string; +}) { + const { apiKeys, providerConfig } = await resolveProviderExecutionAuth({ + providerId: params.providerId, + cfg: params.cfg, + entry: params.entry, + agentDir: params.agentDir, + }); + const baseUrl = params.entry.baseUrl ?? params.config?.baseUrl ?? providerConfig?.baseUrl; + const mergedHeaders = { + ...sanitizeProviderHeaders(providerConfig?.headers as Record | undefined), + ...sanitizeProviderHeaders(params.config?.headers as Record | undefined), + ...sanitizeProviderHeaders(params.entry.headers as Record | undefined), + }; + const headers = Object.keys(mergedHeaders).length > 0 ? mergedHeaders : undefined; + return { apiKeys, baseUrl, headers }; +} + +function assertMinAudioSize(params: { size: number; attachmentIndex: number }): void { + if (params.size >= MIN_AUDIO_FILE_BYTES) { + return; + } + throw new MediaUnderstandingSkipError( + "tooSmall", + `Audio attachment ${params.attachmentIndex + 1} is too small (${params.size} bytes, minimum ${MIN_AUDIO_FILE_BYTES})`, + ); +} + +export async function runProviderEntry(params: { + capability: MediaUnderstandingCapability; + entry: MediaUnderstandingModelConfig; + cfg: OpenClawConfig; + ctx: MsgContext; + attachmentIndex: number; + cache: MediaAttachmentCache; + agentDir?: string; + providerRegistry: ProviderRegistry; + config?: MediaUnderstandingConfig; +}): Promise { + const { entry, capability, cfg } = params; + const providerIdRaw = entry.provider?.trim(); + if (!providerIdRaw) { + throw new Error(`Provider entry missing provider for ${capability}`); + } + const providerId = normalizeExtensionHostMediaProviderId(providerIdRaw); + const { maxBytes, maxChars, timeoutMs, prompt } = resolveEntryRunOptions({ + capability, + entry, + cfg, + config: params.config, + }); + + if (capability === "image") { + if (!params.agentDir) { + throw new Error("Image understanding requires agentDir"); + } + const modelId = entry.model?.trim(); + if (!modelId) { + throw new Error("Image understanding requires model id"); + } + const media = await params.cache.getBuffer({ + attachmentIndex: params.attachmentIndex, + maxBytes, + timeoutMs, + }); + const provider = getExtensionHostMediaUnderstandingProvider( + providerId, + params.providerRegistry, + ); + const imageInput = { + buffer: media.buffer, + fileName: media.fileName, + mime: media.mime, + model: modelId, + provider: providerId, + prompt, + timeoutMs, + profile: entry.profile, + preferredProfile: entry.preferredProfile, + agentDir: params.agentDir, + cfg: params.cfg, + }; + const { describeImageWithModel } = await import("../../media-understanding/providers/image.js"); + const describeImage = provider?.describeImage ?? describeImageWithModel; + const result = await describeImage(imageInput); + return { + kind: "image.description", + attachmentIndex: params.attachmentIndex, + text: trimOutput(result.text, maxChars), + provider: providerId, + model: result.model ?? modelId, + }; + } + + const provider = getExtensionHostMediaUnderstandingProvider(providerId, params.providerRegistry); + if (!provider) { + throw new Error(`Media provider not available: ${providerId}`); + } + + // Resolve proxy-aware fetch from env vars (HTTPS_PROXY, HTTP_PROXY, etc.) + // so provider HTTP calls are routed through the proxy when configured. + const fetchFn = resolveProxyFetchFromEnv(); + + if (capability === "audio") { + if (!provider.transcribeAudio) { + throw new Error(`Audio transcription provider "${providerId}" not available.`); + } + const transcribeAudio = provider.transcribeAudio; + const media = await params.cache.getBuffer({ + attachmentIndex: params.attachmentIndex, + maxBytes, + timeoutMs, + }); + assertMinAudioSize({ size: media.size, attachmentIndex: params.attachmentIndex }); + const { apiKeys, baseUrl, headers } = await resolveProviderExecutionContext({ + providerId, + cfg, + entry, + config: params.config, + agentDir: params.agentDir, + }); + const providerQuery = resolveProviderQuery({ + providerId, + config: params.config, + entry, + }); + const model = + entry.model?.trim() || + resolveExtensionHostMediaRuntimeDefaultModel({ + capability: "audio", + backendId: providerId, + }) || + entry.model; + const result = await executeWithApiKeyRotation({ + provider: providerId, + apiKeys, + execute: async (apiKey) => + transcribeAudio({ + buffer: media.buffer, + fileName: media.fileName, + mime: media.mime, + apiKey, + baseUrl, + headers, + model, + language: entry.language ?? params.config?.language ?? cfg.tools?.media?.audio?.language, + prompt, + query: providerQuery, + timeoutMs, + fetchFn, + }), + }); + return { + kind: "audio.transcription", + attachmentIndex: params.attachmentIndex, + text: trimOutput(result.text, maxChars), + provider: providerId, + model: result.model ?? model, + }; + } + + if (!provider.describeVideo) { + throw new Error(`Video understanding provider "${providerId}" not available.`); + } + const describeVideo = provider.describeVideo; + const media = await params.cache.getBuffer({ + attachmentIndex: params.attachmentIndex, + maxBytes, + timeoutMs, + }); + const estimatedBase64Bytes = estimateBase64Size(media.size); + const maxBase64Bytes = resolveVideoMaxBase64Bytes(maxBytes); + if (estimatedBase64Bytes > maxBase64Bytes) { + throw new MediaUnderstandingSkipError( + "maxBytes", + `Video attachment ${params.attachmentIndex + 1} base64 payload ${estimatedBase64Bytes} exceeds ${maxBase64Bytes}`, + ); + } + const { apiKeys, baseUrl, headers } = await resolveProviderExecutionContext({ + providerId, + cfg, + entry, + config: params.config, + agentDir: params.agentDir, + }); + const result = await executeWithApiKeyRotation({ + provider: providerId, + apiKeys, + execute: (apiKey) => + describeVideo({ + buffer: media.buffer, + fileName: media.fileName, + mime: media.mime, + apiKey, + baseUrl, + headers, + model: entry.model, + prompt, + timeoutMs, + fetchFn, + }), + }); + return { + kind: "video.description", + attachmentIndex: params.attachmentIndex, + text: trimOutput(result.text, maxChars), + provider: providerId, + model: result.model ?? entry.model, + }; +} + +export async function runCliEntry(params: { + capability: MediaUnderstandingCapability; + entry: MediaUnderstandingModelConfig; + cfg: OpenClawConfig; + ctx: MsgContext; + attachmentIndex: number; + cache: MediaAttachmentCache; + config?: MediaUnderstandingConfig; +}): Promise { + const { entry, capability, cfg, ctx } = params; + const command = entry.command?.trim(); + const args = entry.args ?? []; + if (!command) { + throw new Error(`CLI entry missing command for ${capability}`); + } + const { maxBytes, maxChars, timeoutMs, prompt } = resolveEntryRunOptions({ + capability, + entry, + cfg, + config: params.config, + }); + const pathResult = await params.cache.getPath({ + attachmentIndex: params.attachmentIndex, + maxBytes, + timeoutMs, + }); + if (capability === "audio") { + const stat = await fs.stat(pathResult.path); + assertMinAudioSize({ size: stat.size, attachmentIndex: params.attachmentIndex }); + } + const outputDir = await fs.mkdtemp( + path.join(resolvePreferredOpenClawTmpDir(), "openclaw-media-cli-"), + ); + const mediaPath = pathResult.path; + const outputBase = path.join(outputDir, path.parse(mediaPath).name); + + const templCtx: MsgContext = { + ...ctx, + MediaPath: mediaPath, + MediaDir: path.dirname(mediaPath), + OutputDir: outputDir, + OutputBase: outputBase, + Prompt: prompt, + MaxChars: maxChars, + }; + const argv = [command, ...args].map((part, index) => + index === 0 ? part : applyTemplate(part, templCtx), + ); + try { + if (shouldLogVerbose()) { + logVerbose(`Media understanding via CLI: ${argv.join(" ")}`); + } + const { stdout } = await runExec(argv[0], argv.slice(1), { + timeoutMs, + maxBuffer: CLI_OUTPUT_MAX_BUFFER, + }); + const resolved = await resolveCliOutput({ + command, + args: argv.slice(1), + stdout, + mediaPath, + }); + const text = trimOutput(resolved, maxChars); + if (!text) { + return null; + } + return { + kind: capability === "audio" ? "audio.transcription" : `${capability}.description`, + attachmentIndex: params.attachmentIndex, + text, + provider: "cli", + model: command, + }; + } finally { + await fs.rm(outputDir, { recursive: true, force: true }).catch(() => {}); + } +} diff --git a/src/extension-host/contributions/media-runtime-orchestration.test.ts b/src/extension-host/contributions/media-runtime-orchestration.test.ts new file mode 100644 index 00000000000..b0d1ef647be --- /dev/null +++ b/src/extension-host/contributions/media-runtime-orchestration.test.ts @@ -0,0 +1,58 @@ +import { describe, expect, it, vi } from "vitest"; +import type { MsgContext } from "../../auto-reply/templating.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import { + createMediaAttachmentCache, + normalizeMediaAttachments, +} from "../../media-understanding/runner.js"; +import { runCapability } from "./media-runtime-orchestration.js"; +import { buildExtensionHostMediaUnderstandingRegistry } from "./media-runtime-registry.js"; + +const catalog = [ + { + id: "gpt-4.1", + name: "GPT-4.1", + provider: "openai", + input: ["text", "image"] as const, + }, +]; + +vi.mock("../agents/model-catalog.js", async () => { + const actual = await vi.importActual( + "../agents/model-catalog.js", + ); + return { + ...actual, + loadModelCatalog: vi.fn(async () => catalog), + }; +}); + +describe("media runtime orchestration", () => { + it("skips image understanding when the active model already supports vision", async () => { + const ctx: MsgContext = { MediaPath: "/tmp/image.png", MediaType: "image/png" }; + const media = normalizeMediaAttachments(ctx); + const cache = createMediaAttachmentCache(media); + const cfg = {} as OpenClawConfig; + + try { + const result = await runCapability({ + capability: "image", + cfg, + ctx, + attachments: cache, + media, + providerRegistry: buildExtensionHostMediaUnderstandingRegistry(), + activeModel: { provider: "openai", model: "gpt-4.1" }, + }); + + expect(result.outputs).toHaveLength(0); + expect(result.decision.outcome).toBe("skipped"); + expect(result.decision.attachments).toHaveLength(1); + expect(result.decision.attachments[0]?.attempts[0]?.reason).toBe( + "primary model supports vision natively", + ); + } finally { + await cache.cleanup(); + } + }); +}); diff --git a/src/extension-host/contributions/media-runtime-orchestration.ts b/src/extension-host/contributions/media-runtime-orchestration.ts new file mode 100644 index 00000000000..efb2a6318d6 --- /dev/null +++ b/src/extension-host/contributions/media-runtime-orchestration.ts @@ -0,0 +1,271 @@ +import { + findModelInCatalog, + loadModelCatalog, + modelSupportsVision, +} from "../../agents/model-catalog.js"; +import type { MsgContext } from "../../auto-reply/templating.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import type { + MediaUnderstandingConfig, + MediaUnderstandingModelConfig, +} from "../../config/types.tools.js"; +import { logVerbose, shouldLogVerbose } from "../../globals.js"; +import { MediaAttachmentCache, selectAttachments } from "../../media-understanding/attachments.js"; +import { isMediaUnderstandingSkipError } from "../../media-understanding/errors.js"; +import type { + MediaAttachment, + MediaUnderstandingCapability, + MediaUnderstandingDecision, + MediaUnderstandingModelDecision, + MediaUnderstandingOutput, + MediaUnderstandingProvider, +} from "../../media-understanding/types.js"; +import { resolveAutoEntries, type ActiveMediaModel } from "./media-runtime-auto.js"; +import { resolveModelEntries, resolveScopeDecision } from "./media-runtime-config.js"; +import { buildModelDecision, formatDecisionSummary } from "./media-runtime-decision.js"; +import { + runExtensionHostMediaCliEntry, + runExtensionHostMediaProviderEntry, +} from "./media-runtime-entrypoints.js"; + +type ProviderRegistry = Map; + +export type RunCapabilityResult = { + outputs: MediaUnderstandingOutput[]; + decision: MediaUnderstandingDecision; +}; + +async function runAttachmentEntries(params: { + capability: MediaUnderstandingCapability; + cfg: OpenClawConfig; + ctx: MsgContext; + attachmentIndex: number; + agentDir?: string; + providerRegistry: ProviderRegistry; + cache: MediaAttachmentCache; + entries: MediaUnderstandingModelConfig[]; + config?: MediaUnderstandingConfig; +}): Promise<{ + output: MediaUnderstandingOutput | null; + attempts: MediaUnderstandingModelDecision[]; +}> { + const { entries, capability } = params; + const attempts: MediaUnderstandingModelDecision[] = []; + for (const entry of entries) { + const entryType = entry.type ?? (entry.command ? "cli" : "provider"); + try { + const result = + entryType === "cli" + ? await runExtensionHostMediaCliEntry({ + capability, + entry, + cfg: params.cfg, + ctx: params.ctx, + attachmentIndex: params.attachmentIndex, + cache: params.cache, + config: params.config, + }) + : await runExtensionHostMediaProviderEntry({ + capability, + entry, + cfg: params.cfg, + ctx: params.ctx, + attachmentIndex: params.attachmentIndex, + cache: params.cache, + agentDir: params.agentDir, + providerRegistry: params.providerRegistry, + config: params.config, + }); + if (result) { + const decision = buildModelDecision({ entry, entryType, outcome: "success" }); + if (result.provider) { + decision.provider = result.provider; + } + if (result.model) { + decision.model = result.model; + } + attempts.push(decision); + return { output: result, attempts }; + } + attempts.push( + buildModelDecision({ entry, entryType, outcome: "skipped", reason: "empty output" }), + ); + } catch (err) { + if (isMediaUnderstandingSkipError(err)) { + attempts.push( + buildModelDecision({ + entry, + entryType, + outcome: "skipped", + reason: `${err.reason}: ${err.message}`, + }), + ); + if (shouldLogVerbose()) { + logVerbose(`Skipping ${capability} model due to ${err.reason}: ${err.message}`); + } + continue; + } + attempts.push( + buildModelDecision({ + entry, + entryType, + outcome: "failed", + reason: String(err), + }), + ); + if (shouldLogVerbose()) { + logVerbose(`${capability} understanding failed: ${String(err)}`); + } + } + } + + return { output: null, attempts }; +} + +export async function runCapability(params: { + capability: MediaUnderstandingCapability; + cfg: OpenClawConfig; + ctx: MsgContext; + attachments: MediaAttachmentCache; + media: MediaAttachment[]; + agentDir?: string; + providerRegistry: ProviderRegistry; + config?: MediaUnderstandingConfig; + activeModel?: ActiveMediaModel; +}): Promise { + const { capability, cfg, ctx } = params; + const config = params.config ?? cfg.tools?.media?.[capability]; + if (config?.enabled === false) { + return { + outputs: [], + decision: { capability, outcome: "disabled", attachments: [] }, + }; + } + + const attachmentPolicy = config?.attachments; + const selected = selectAttachments({ + capability, + attachments: params.media, + policy: attachmentPolicy, + }); + if (selected.length === 0) { + return { + outputs: [], + decision: { capability, outcome: "no-attachment", attachments: [] }, + }; + } + + const scopeDecision = resolveScopeDecision({ scope: config?.scope, ctx }); + if (scopeDecision === "deny") { + if (shouldLogVerbose()) { + logVerbose(`${capability} understanding disabled by scope policy.`); + } + return { + outputs: [], + decision: { + capability, + outcome: "scope-deny", + attachments: selected.map((item) => ({ attachmentIndex: item.index, attempts: [] })), + }, + }; + } + + // Skip image understanding when the primary model supports vision natively. + // The image will be injected directly into the model context instead. + const activeProvider = params.activeModel?.provider?.trim(); + if (capability === "image" && activeProvider) { + const catalog = await loadModelCatalog({ config: cfg }); + const entry = findModelInCatalog(catalog, activeProvider, params.activeModel?.model ?? ""); + if (modelSupportsVision(entry)) { + if (shouldLogVerbose()) { + logVerbose("Skipping image understanding: primary model supports vision natively"); + } + const model = params.activeModel?.model?.trim(); + const reason = "primary model supports vision natively"; + return { + outputs: [], + decision: { + capability, + outcome: "skipped", + attachments: selected.map((item) => { + const attempt = { + type: "provider" as const, + provider: activeProvider, + model: model || undefined, + outcome: "skipped" as const, + reason, + }; + return { + attachmentIndex: item.index, + attempts: [attempt], + chosen: attempt, + }; + }), + }, + }; + } + } + + const entries = resolveModelEntries({ + cfg, + capability, + config, + providerRegistry: params.providerRegistry, + }); + let resolvedEntries = entries; + if (resolvedEntries.length === 0) { + resolvedEntries = await resolveAutoEntries({ + cfg, + agentDir: params.agentDir, + providerRegistry: params.providerRegistry, + capability, + activeModel: params.activeModel, + }); + } + if (resolvedEntries.length === 0) { + return { + outputs: [], + decision: { + capability, + outcome: "skipped", + attachments: selected.map((item) => ({ attachmentIndex: item.index, attempts: [] })), + }, + }; + } + + const outputs: MediaUnderstandingOutput[] = []; + const attachmentDecisions: MediaUnderstandingDecision["attachments"] = []; + for (const attachment of selected) { + const { output, attempts } = await runAttachmentEntries({ + capability, + cfg, + ctx, + attachmentIndex: attachment.index, + agentDir: params.agentDir, + providerRegistry: params.providerRegistry, + cache: params.attachments, + entries: resolvedEntries, + config, + }); + if (output) { + outputs.push(output); + } + attachmentDecisions.push({ + attachmentIndex: attachment.index, + attempts, + chosen: attempts.find((attempt) => attempt.outcome === "success"), + }); + } + const decision: MediaUnderstandingDecision = { + capability, + outcome: outputs.length > 0 ? "success" : "skipped", + attachments: attachmentDecisions, + }; + if (shouldLogVerbose()) { + logVerbose(`Media understanding ${formatDecisionSummary(decision)}`); + } + return { + outputs, + decision, + }; +} diff --git a/src/extension-host/contributions/media-runtime-registry.test.ts b/src/extension-host/contributions/media-runtime-registry.test.ts new file mode 100644 index 00000000000..b99867873b9 --- /dev/null +++ b/src/extension-host/contributions/media-runtime-registry.test.ts @@ -0,0 +1,47 @@ +import { describe, expect, it } from "vitest"; +import { + buildExtensionHostMediaUnderstandingRegistry, + getExtensionHostMediaUnderstandingProvider, + normalizeExtensionHostMediaProviderId, +} from "./media-runtime-registry.js"; + +describe("extension host media runtime registry", () => { + it("registers built-in providers", () => { + const registry = buildExtensionHostMediaUnderstandingRegistry(); + const provider = getExtensionHostMediaUnderstandingProvider("mistral", registry); + + expect(provider?.id).toBe("mistral"); + expect(provider?.capabilities).toEqual(["audio"]); + }); + + it("keeps media-specific provider normalization", () => { + expect(normalizeExtensionHostMediaProviderId("gemini")).toBe("google"); + }); + + it("merges overrides onto built-in providers", () => { + const registry = buildExtensionHostMediaUnderstandingRegistry({ + openai: { + id: "openai", + capabilities: ["image"], + }, + }); + + const provider = getExtensionHostMediaUnderstandingProvider("openai", registry); + expect(provider?.id).toBe("openai"); + expect(provider?.capabilities).toEqual(["image"]); + expect(provider?.describeImage).toBeTypeOf("function"); + }); + + it("adds brand new providers", () => { + const registry = buildExtensionHostMediaUnderstandingRegistry({ + custom: { + id: "custom", + capabilities: ["audio"], + }, + }); + + const provider = getExtensionHostMediaUnderstandingProvider("custom", registry); + expect(provider?.id).toBe("custom"); + expect(provider?.capabilities).toEqual(["audio"]); + }); +}); diff --git a/src/extension-host/contributions/media-runtime-registry.ts b/src/extension-host/contributions/media-runtime-registry.ts new file mode 100644 index 00000000000..21998d9acc2 --- /dev/null +++ b/src/extension-host/contributions/media-runtime-registry.ts @@ -0,0 +1,45 @@ +import type { MediaUnderstandingProvider } from "../../media-understanding/types.js"; +import { + listExtensionHostMediaUnderstandingProviders, + normalizeExtensionHostMediaProviderId, +} from "../static/media-runtime-backends.js"; + +export type ExtensionHostMediaUnderstandingProviderRegistry = Map< + string, + MediaUnderstandingProvider +>; + +export { normalizeExtensionHostMediaProviderId } from "../static/media-runtime-backends.js"; + +export function buildExtensionHostMediaUnderstandingRegistry( + overrides?: Record, +): ExtensionHostMediaUnderstandingProviderRegistry { + const registry: ExtensionHostMediaUnderstandingProviderRegistry = new Map(); + for (const provider of listExtensionHostMediaUnderstandingProviders()) { + registry.set(normalizeExtensionHostMediaProviderId(provider.id), provider); + } + if (!overrides) { + return registry; + } + + for (const [key, provider] of Object.entries(overrides)) { + const normalizedKey = normalizeExtensionHostMediaProviderId(key); + const existing = registry.get(normalizedKey); + const merged = existing + ? { + ...existing, + ...provider, + capabilities: provider.capabilities ?? existing.capabilities, + } + : provider; + registry.set(normalizedKey, merged); + } + return registry; +} + +export function getExtensionHostMediaUnderstandingProvider( + id: string, + registry: ExtensionHostMediaUnderstandingProviderRegistry, +): MediaUnderstandingProvider | undefined { + return registry.get(normalizeExtensionHostMediaProviderId(id)); +} diff --git a/src/extension-host/contributions/provider-auth-flow.ts b/src/extension-host/contributions/provider-auth-flow.ts new file mode 100644 index 00000000000..4ad5bf3a768 --- /dev/null +++ b/src/extension-host/contributions/provider-auth-flow.ts @@ -0,0 +1,233 @@ +import { resolveOpenClawAgentDir } from "../../agents/agent-paths.js"; +import { + resolveDefaultAgentId, + resolveAgentDir, + resolveAgentWorkspaceDir, +} from "../../agents/agent-scope.js"; +import { upsertAuthProfile } from "../../agents/auth-profiles.js"; +import { resolveDefaultAgentWorkspaceDir } from "../../agents/workspace.js"; +import type { + ApplyAuthChoiceParams, + ApplyAuthChoiceResult, +} from "../../commands/auth-choice.apply.js"; +import { isRemoteEnvironment } from "../../commands/oauth-env.js"; +import { createVpsAwareOAuthHandlers } from "../../commands/oauth-flow.js"; +import { applyAuthProfileConfig } from "../../commands/onboard-auth.js"; +import { openUrl } from "../../commands/onboard-helpers.js"; +import { enablePluginInConfig } from "../../plugins/enable.js"; +import { resolvePluginProviders } from "../../plugins/providers.js"; +import type { ProviderAuthMethod } from "../../plugins/types.js"; +import { + applyExtensionHostDefaultModel, + mergeExtensionHostConfigPatch, + pickExtensionHostAuthMethod, + resolveExtensionHostProviderMatch, +} from "./provider-auth.js"; +import { runExtensionHostProviderModelSelectedHook } from "./provider-model-selection.js"; +import { resolveExtensionHostProviderChoice } from "./provider-wizard.js"; + +export type ExtensionHostPluginProviderAuthChoiceOptions = { + authChoice: string; + pluginId: string; + providerId: string; + methodId?: string; + label: string; +}; + +export async function runExtensionHostProviderAuthMethod(params: { + config: ApplyAuthChoiceParams["config"]; + runtime: ApplyAuthChoiceParams["runtime"]; + prompter: ApplyAuthChoiceParams["prompter"]; + method: ProviderAuthMethod; + agentDir?: string; + agentId?: string; + workspaceDir?: string; + emitNotes?: boolean; +}): Promise<{ config: ApplyAuthChoiceParams["config"]; defaultModel?: string }> { + const agentId = params.agentId ?? resolveDefaultAgentId(params.config); + const defaultAgentId = resolveDefaultAgentId(params.config); + const agentDir = + params.agentDir ?? + (agentId === defaultAgentId + ? resolveOpenClawAgentDir() + : resolveAgentDir(params.config, agentId)); + const workspaceDir = + params.workspaceDir ?? + resolveAgentWorkspaceDir(params.config, agentId) ?? + resolveDefaultAgentWorkspaceDir(); + + const isRemote = isRemoteEnvironment(); + const result = await params.method.run({ + config: params.config, + agentDir, + workspaceDir, + prompter: params.prompter, + runtime: params.runtime, + isRemote, + openUrl: async (url) => { + await openUrl(url); + }, + oauth: { + createVpsAwareHandlers: (opts) => createVpsAwareOAuthHandlers(opts), + }, + }); + + let nextConfig = params.config; + if (result.configPatch) { + nextConfig = mergeExtensionHostConfigPatch(nextConfig, result.configPatch); + } + + for (const profile of result.profiles) { + upsertAuthProfile({ + profileId: profile.profileId, + credential: profile.credential, + agentDir, + }); + + nextConfig = applyAuthProfileConfig(nextConfig, { + profileId: profile.profileId, + provider: profile.credential.provider, + mode: profile.credential.type === "token" ? "token" : profile.credential.type, + ...("email" in profile.credential && profile.credential.email + ? { email: profile.credential.email } + : {}), + }); + } + + if (params.emitNotes !== false && result.notes && result.notes.length > 0) { + await params.prompter.note(result.notes.join("\n"), "Provider notes"); + } + + return { + config: nextConfig, + defaultModel: result.defaultModel, + }; +} + +export async function applyExtensionHostLoadedPluginProvider( + params: ApplyAuthChoiceParams, +): Promise { + const agentId = params.agentId ?? resolveDefaultAgentId(params.config); + const workspaceDir = + resolveAgentWorkspaceDir(params.config, agentId) ?? resolveDefaultAgentWorkspaceDir(); + const providers = resolvePluginProviders({ config: params.config, workspaceDir }); + const resolved = resolveExtensionHostProviderChoice({ + providers, + choice: params.authChoice, + }); + if (!resolved) { + return null; + } + + const applied = await runExtensionHostProviderAuthMethod({ + config: params.config, + runtime: params.runtime, + prompter: params.prompter, + method: resolved.method, + agentDir: params.agentDir, + agentId: params.agentId, + workspaceDir, + }); + + let agentModelOverride: string | undefined; + if (applied.defaultModel) { + if (params.setDefaultModel) { + const nextConfig = applyExtensionHostDefaultModel(applied.config, applied.defaultModel); + await runExtensionHostProviderModelSelectedHook({ + config: nextConfig, + model: applied.defaultModel, + prompter: params.prompter, + agentDir: params.agentDir, + workspaceDir, + }); + await params.prompter.note( + `Default model set to ${applied.defaultModel}`, + "Model configured", + ); + return { config: nextConfig }; + } + agentModelOverride = applied.defaultModel; + } + + return { config: applied.config, agentModelOverride }; +} + +export async function applyExtensionHostPluginProvider( + params: ApplyAuthChoiceParams, + options: ExtensionHostPluginProviderAuthChoiceOptions, +): Promise { + if (params.authChoice !== options.authChoice) { + return null; + } + + const enableResult = enablePluginInConfig(params.config, options.pluginId); + let nextConfig = enableResult.config; + if (!enableResult.enabled) { + await params.prompter.note( + `${options.label} plugin is disabled (${enableResult.reason ?? "blocked"}).`, + options.label, + ); + return { config: nextConfig }; + } + + const agentId = params.agentId ?? resolveDefaultAgentId(nextConfig); + const defaultAgentId = resolveDefaultAgentId(nextConfig); + const agentDir = + params.agentDir ?? + (agentId === defaultAgentId ? resolveOpenClawAgentDir() : resolveAgentDir(nextConfig, agentId)); + const workspaceDir = + resolveAgentWorkspaceDir(nextConfig, agentId) ?? resolveDefaultAgentWorkspaceDir(); + + const providers = resolvePluginProviders({ config: nextConfig, workspaceDir }); + const provider = resolveExtensionHostProviderMatch(providers, options.providerId); + if (!provider) { + await params.prompter.note( + `${options.label} auth plugin is not available. Enable it and re-run the wizard.`, + options.label, + ); + return { config: nextConfig }; + } + + const method = pickExtensionHostAuthMethod(provider, options.methodId) ?? provider.auth[0]; + if (!method) { + await params.prompter.note(`${options.label} auth method missing.`, options.label); + return { config: nextConfig }; + } + + const applied = await runExtensionHostProviderAuthMethod({ + config: nextConfig, + runtime: params.runtime, + prompter: params.prompter, + method, + agentDir, + agentId, + workspaceDir, + }); + nextConfig = applied.config; + + let agentModelOverride: string | undefined; + if (applied.defaultModel) { + if (params.setDefaultModel) { + nextConfig = applyExtensionHostDefaultModel(nextConfig, applied.defaultModel); + await runExtensionHostProviderModelSelectedHook({ + config: nextConfig, + model: applied.defaultModel, + prompter: params.prompter, + agentDir, + workspaceDir, + }); + await params.prompter.note( + `Default model set to ${applied.defaultModel}`, + "Model configured", + ); + } else if (params.agentId) { + agentModelOverride = applied.defaultModel; + await params.prompter.note( + `Default model set to ${applied.defaultModel} for agent "${params.agentId}".`, + "Model configured", + ); + } + } + + return { config: nextConfig, agentModelOverride }; +} diff --git a/src/extension-host/contributions/provider-auth.test.ts b/src/extension-host/contributions/provider-auth.test.ts new file mode 100644 index 00000000000..e049d877812 --- /dev/null +++ b/src/extension-host/contributions/provider-auth.test.ts @@ -0,0 +1,106 @@ +import { describe, expect, it, vi } from "vitest"; +import type { ProviderPlugin } from "../../plugins/types.js"; +import { + applyExtensionHostDefaultModel, + mergeExtensionHostConfigPatch, + pickExtensionHostAuthMethod, + resolveExtensionHostProviderMatch, +} from "./provider-auth.js"; + +function makeProvider(overrides: Partial & Pick) { + return { + auth: [], + ...overrides, + } satisfies ProviderPlugin; +} + +describe("resolveExtensionHostProviderMatch", () => { + it("matches providers by normalized id and aliases", () => { + const providers = [ + makeProvider({ + id: "openrouter", + label: "OpenRouter", + aliases: ["Open Router"], + }), + ]; + + expect(resolveExtensionHostProviderMatch(providers, "openrouter")?.id).toBe("openrouter"); + expect(resolveExtensionHostProviderMatch(providers, " Open Router ")?.id).toBe("openrouter"); + expect(resolveExtensionHostProviderMatch(providers, "missing")).toBeNull(); + }); +}); + +describe("pickExtensionHostAuthMethod", () => { + it("matches auth methods by id or label", () => { + const provider = makeProvider({ + id: "ollama", + label: "Ollama", + auth: [ + { id: "local", label: "Local", kind: "custom", run: vi.fn() }, + { id: "cloud", label: "Cloud", kind: "custom", run: vi.fn() }, + ], + }); + + expect(pickExtensionHostAuthMethod(provider, "local")?.id).toBe("local"); + expect(pickExtensionHostAuthMethod(provider, "cloud")?.id).toBe("cloud"); + expect(pickExtensionHostAuthMethod(provider, "Cloud")?.id).toBe("cloud"); + expect(pickExtensionHostAuthMethod(provider, "missing")).toBeNull(); + }); +}); + +describe("mergeExtensionHostConfigPatch", () => { + it("deep-merges plain record config patches", () => { + expect( + mergeExtensionHostConfigPatch( + { + models: { providers: { ollama: { baseUrl: "http://127.0.0.1:11434" } } }, + auth: { profiles: { existing: { provider: "anthropic" } } }, + }, + { + models: { providers: { ollama: { api: "ollama" } } }, + auth: { profiles: { fresh: { provider: "ollama" } } }, + }, + ), + ).toEqual({ + models: { providers: { ollama: { baseUrl: "http://127.0.0.1:11434", api: "ollama" } } }, + auth: { + profiles: { + existing: { provider: "anthropic" }, + fresh: { provider: "ollama" }, + }, + }, + }); + }); +}); + +describe("applyExtensionHostDefaultModel", () => { + it("sets the primary model while preserving fallback config", () => { + expect( + applyExtensionHostDefaultModel( + { + agents: { + defaults: { + model: { + primary: "anthropic/claude-sonnet-4-5", + fallbacks: ["openai/gpt-5"], + }, + }, + }, + }, + "ollama/qwen3:4b", + ), + ).toEqual({ + agents: { + defaults: { + models: { + "ollama/qwen3:4b": {}, + }, + model: { + primary: "ollama/qwen3:4b", + fallbacks: ["openai/gpt-5"], + }, + }, + }, + }); + }); +}); diff --git a/src/extension-host/contributions/provider-auth.ts b/src/extension-host/contributions/provider-auth.ts new file mode 100644 index 00000000000..9ac2947f8de --- /dev/null +++ b/src/extension-host/contributions/provider-auth.ts @@ -0,0 +1,82 @@ +import { normalizeProviderId } from "../../agents/provider-id.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import type { ProviderAuthMethod, ProviderPlugin } from "../../plugins/types.js"; + +export function resolveExtensionHostProviderMatch( + providers: ProviderPlugin[], + rawProvider?: string, +): ProviderPlugin | null { + const raw = rawProvider?.trim(); + if (!raw) { + return null; + } + const normalized = normalizeProviderId(raw); + return ( + providers.find((provider) => normalizeProviderId(provider.id) === normalized) ?? + providers.find( + (provider) => + provider.aliases?.some((alias) => normalizeProviderId(alias) === normalized) ?? false, + ) ?? + null + ); +} + +export function pickExtensionHostAuthMethod( + provider: ProviderPlugin, + rawMethod?: string, +): ProviderAuthMethod | null { + const raw = rawMethod?.trim(); + if (!raw) { + return null; + } + const normalized = raw.toLowerCase(); + return ( + provider.auth.find((method) => method.id.toLowerCase() === normalized) ?? + provider.auth.find((method) => method.label.toLowerCase() === normalized) ?? + null + ); +} + +function isPlainRecord(value: unknown): value is Record { + return Boolean(value && typeof value === "object" && !Array.isArray(value)); +} + +export function mergeExtensionHostConfigPatch(base: T, patch: unknown): T { + if (!isPlainRecord(base) || !isPlainRecord(patch)) { + return patch as T; + } + + const next: Record = { ...base }; + for (const [key, value] of Object.entries(patch)) { + const existing = next[key]; + if (isPlainRecord(existing) && isPlainRecord(value)) { + next[key] = mergeExtensionHostConfigPatch(existing, value); + } else { + next[key] = value; + } + } + return next as T; +} + +export function applyExtensionHostDefaultModel(cfg: OpenClawConfig, model: string): OpenClawConfig { + const models = { ...cfg.agents?.defaults?.models }; + models[model] = models[model] ?? {}; + + const existingModel = cfg.agents?.defaults?.model; + return { + ...cfg, + agents: { + ...cfg.agents, + defaults: { + ...cfg.agents?.defaults, + models, + model: { + ...(existingModel && typeof existingModel === "object" && "fallbacks" in existingModel + ? { fallbacks: (existingModel as { fallbacks?: string[] }).fallbacks } + : undefined), + primary: model, + }, + }, + }, + }; +} diff --git a/src/extension-host/contributions/provider-discovery.test.ts b/src/extension-host/contributions/provider-discovery.test.ts new file mode 100644 index 00000000000..cde7c52b822 --- /dev/null +++ b/src/extension-host/contributions/provider-discovery.test.ts @@ -0,0 +1,107 @@ +import { describe, expect, it } from "vitest"; +import type { ModelProviderConfig } from "../../config/types.js"; +import type { ProviderDiscoveryOrder, ProviderPlugin } from "../../plugins/types.js"; +import { + groupExtensionHostDiscoveryProvidersByOrder, + normalizeExtensionHostDiscoveryResult, + resolveExtensionHostDiscoveryProviders, +} from "./provider-discovery.js"; + +function makeProvider(params: { + id: string; + label?: string; + order?: ProviderDiscoveryOrder; + discovery?: boolean; +}): ProviderPlugin { + return { + id: params.id, + label: params.label ?? params.id, + auth: [], + ...(params.discovery === false + ? {} + : { + discovery: { + ...(params.order ? { order: params.order } : {}), + run: async () => null, + }, + }), + }; +} + +function makeModelProviderConfig(overrides?: Partial): ModelProviderConfig { + return { + baseUrl: "http://127.0.0.1:8000/v1", + models: [], + ...overrides, + }; +} + +describe("resolveExtensionHostDiscoveryProviders", () => { + it("keeps only providers with discovery handlers", () => { + expect( + resolveExtensionHostDiscoveryProviders([ + makeProvider({ id: "simple" }), + makeProvider({ id: "hidden", discovery: false }), + ]).map((provider) => provider.id), + ).toEqual(["simple"]); + }); +}); + +describe("groupExtensionHostDiscoveryProvidersByOrder", () => { + it("groups providers by declared order and sorts labels within each group", () => { + const grouped = groupExtensionHostDiscoveryProvidersByOrder([ + makeProvider({ id: "late-b", label: "Zulu" }), + makeProvider({ id: "late-a", label: "Alpha" }), + makeProvider({ id: "paired", label: "Paired", order: "paired" }), + makeProvider({ id: "profile", label: "Profile", order: "profile" }), + makeProvider({ id: "simple", label: "Simple", order: "simple" }), + ]); + + expect(grouped.simple.map((provider) => provider.id)).toEqual(["simple"]); + expect(grouped.profile.map((provider) => provider.id)).toEqual(["profile"]); + expect(grouped.paired.map((provider) => provider.id)).toEqual(["paired"]); + expect(grouped.late.map((provider) => provider.id)).toEqual(["late-a", "late-b"]); + }); +}); + +describe("normalizeExtensionHostDiscoveryResult", () => { + it("maps a single provider result to the provider id", () => { + const provider = makeProvider({ id: "Ollama" }); + const normalized = normalizeExtensionHostDiscoveryResult({ + provider, + result: { + provider: makeModelProviderConfig({ + baseUrl: "http://127.0.0.1:11434", + api: "ollama", + }), + }, + }); + + expect(normalized).toEqual({ + ollama: { + baseUrl: "http://127.0.0.1:11434", + api: "ollama", + models: [], + }, + }); + }); + + it("normalizes keys for multi-provider discovery results", () => { + const normalized = normalizeExtensionHostDiscoveryResult({ + provider: makeProvider({ id: "ignored" }), + result: { + providers: { + " VLLM ": makeModelProviderConfig(), + "": makeModelProviderConfig({ baseUrl: "http://ignored" }), + }, + }, + }); + + expect(normalized).toEqual({ + vllm: { + baseUrl: "http://127.0.0.1:8000/v1", + models: [], + }, + }); + }); +}); diff --git a/src/extension-host/contributions/provider-discovery.ts b/src/extension-host/contributions/provider-discovery.ts new file mode 100644 index 00000000000..9846e304015 --- /dev/null +++ b/src/extension-host/contributions/provider-discovery.ts @@ -0,0 +1,61 @@ +import { normalizeProviderId } from "../../agents/provider-id.js"; +import type { ModelProviderConfig } from "../../config/types.js"; +import type { ProviderDiscoveryOrder, ProviderPlugin } from "../../plugins/types.js"; + +const DISCOVERY_ORDER: readonly ProviderDiscoveryOrder[] = ["simple", "profile", "paired", "late"]; + +export function resolveExtensionHostDiscoveryProviders( + providers: ProviderPlugin[], +): ProviderPlugin[] { + return providers.filter((provider) => provider.discovery); +} + +export function groupExtensionHostDiscoveryProvidersByOrder( + providers: ProviderPlugin[], +): Record { + const grouped = { + simple: [], + profile: [], + paired: [], + late: [], + } as Record; + + for (const provider of providers) { + const order = provider.discovery?.order ?? "late"; + grouped[order].push(provider); + } + + for (const order of DISCOVERY_ORDER) { + grouped[order].sort((a, b) => a.label.localeCompare(b.label)); + } + + return grouped; +} + +export function normalizeExtensionHostDiscoveryResult(params: { + provider: ProviderPlugin; + result: + | { provider: ModelProviderConfig } + | { providers: Record } + | null + | undefined; +}): Record { + const result = params.result; + if (!result) { + return {}; + } + + if ("provider" in result) { + return { [normalizeProviderId(params.provider.id)]: result.provider }; + } + + const normalized: Record = {}; + for (const [key, value] of Object.entries(result.providers)) { + const normalizedKey = normalizeProviderId(key); + if (!normalizedKey || !value) { + continue; + } + normalized[normalizedKey] = value; + } + return normalized; +} diff --git a/src/extension-host/contributions/provider-model-selection.ts b/src/extension-host/contributions/provider-model-selection.ts new file mode 100644 index 00000000000..cf78b4ab545 --- /dev/null +++ b/src/extension-host/contributions/provider-model-selection.ts @@ -0,0 +1,40 @@ +import { DEFAULT_PROVIDER } from "../../agents/defaults.js"; +import { parseModelRef } from "../../agents/model-ref.js"; +import { normalizeProviderId } from "../../agents/provider-id.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import { resolvePluginProviders } from "../../plugins/providers.js"; +import type { WizardPrompter } from "../../wizard/prompts.js"; + +export async function runExtensionHostProviderModelSelectedHook(params: { + config: OpenClawConfig; + model: string; + prompter: WizardPrompter; + agentDir?: string; + workspaceDir?: string; + env?: NodeJS.ProcessEnv; +}): Promise { + const parsed = parseModelRef(params.model, DEFAULT_PROVIDER); + if (!parsed) { + return; + } + + const providers = resolvePluginProviders({ + config: params.config, + workspaceDir: params.workspaceDir, + env: params.env, + }); + const provider = providers.find( + (entry) => normalizeProviderId(entry.id) === normalizeProviderId(parsed.provider), + ); + if (!provider?.onModelSelected) { + return; + } + + await provider.onModelSelected({ + config: params.config, + model: params.model, + prompter: params.prompter, + agentDir: params.agentDir, + workspaceDir: params.workspaceDir, + }); +} diff --git a/src/extension-host/contributions/provider-runtime.test.ts b/src/extension-host/contributions/provider-runtime.test.ts new file mode 100644 index 00000000000..722375219e6 --- /dev/null +++ b/src/extension-host/contributions/provider-runtime.test.ts @@ -0,0 +1,28 @@ +import { describe, expect, it } from "vitest"; +import { createEmptyPluginRegistry } from "../../plugins/registry.js"; +import { resolveExtensionHostProviders } from "./provider-runtime.js"; +import { addExtensionHostProviderRegistration } from "./runtime-registry.js"; + +describe("resolveExtensionHostProviders", () => { + it("projects provider registrations into provider plugins with plugin ids", () => { + const registry = createEmptyPluginRegistry(); + addExtensionHostProviderRegistration(registry, { + pluginId: "demo-plugin", + source: "bundled", + provider: { + id: "demo-provider", + label: "Demo Provider", + auth: [], + }, + }); + + expect(resolveExtensionHostProviders({ registry })).toEqual([ + { + id: "demo-provider", + label: "Demo Provider", + auth: [], + pluginId: "demo-plugin", + }, + ]); + }); +}); diff --git a/src/extension-host/contributions/provider-runtime.ts b/src/extension-host/contributions/provider-runtime.ts new file mode 100644 index 00000000000..ab45a16e05e --- /dev/null +++ b/src/extension-host/contributions/provider-runtime.ts @@ -0,0 +1,22 @@ +import type { PluginRegistry } from "../../plugins/registry.js"; +import type { ProviderPlugin } from "../../plugins/types.js"; +import { listExtensionHostProviderRegistrations } from "./runtime-registry.js"; + +export function resolveExtensionHostProviders(params: { + registry: Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + >; +}): ProviderPlugin[] { + return listExtensionHostProviderRegistrations(params.registry).map((entry) => ({ + ...entry.provider, + pluginId: entry.pluginId, + })); +} diff --git a/src/extension-host/contributions/provider-wizard.test.ts b/src/extension-host/contributions/provider-wizard.test.ts new file mode 100644 index 00000000000..aa6e499122d --- /dev/null +++ b/src/extension-host/contributions/provider-wizard.test.ts @@ -0,0 +1,83 @@ +import { describe, expect, it, vi } from "vitest"; +import type { ProviderPlugin } from "../../plugins/types.js"; +import { + buildExtensionHostProviderMethodChoice, + resolveExtensionHostProviderChoice, + resolveExtensionHostProviderModelPickerEntries, + resolveExtensionHostProviderWizardOptions, +} from "./provider-wizard.js"; + +function makeProvider(overrides: Partial & Pick) { + return { + auth: [], + ...overrides, + } satisfies ProviderPlugin; +} + +describe("resolveExtensionHostProviderWizardOptions", () => { + it("uses explicit onboarding choice ids and bound method ids", () => { + const provider = makeProvider({ + id: "vllm", + label: "vLLM", + auth: [ + { id: "local", label: "Local", kind: "custom", run: vi.fn() }, + { id: "cloud", label: "Cloud", kind: "custom", run: vi.fn() }, + ], + wizard: { + onboarding: { + choiceId: "self-hosted-vllm", + methodId: "local", + choiceLabel: "vLLM local", + groupId: "local-runtimes", + groupLabel: "Local runtimes", + }, + }, + }); + + expect(resolveExtensionHostProviderWizardOptions([provider])).toEqual([ + { + value: "self-hosted-vllm", + label: "vLLM local", + groupId: "local-runtimes", + groupLabel: "Local runtimes", + }, + ]); + expect( + resolveExtensionHostProviderChoice({ + providers: [provider], + choice: "self-hosted-vllm", + }), + ).toEqual({ + provider, + method: provider.auth[0], + }); + }); +}); + +describe("resolveExtensionHostProviderModelPickerEntries", () => { + it("builds model-picker entries from provider metadata", () => { + const provider = makeProvider({ + id: "sglang", + label: "SGLang", + auth: [ + { id: "server", label: "Server", kind: "custom", run: vi.fn() }, + { id: "cloud", label: "Cloud", kind: "custom", run: vi.fn() }, + ], + wizard: { + modelPicker: { + label: "SGLang server", + hint: "OpenAI-compatible local runtime", + methodId: "server", + }, + }, + }); + + expect(resolveExtensionHostProviderModelPickerEntries([provider])).toEqual([ + { + value: buildExtensionHostProviderMethodChoice("sglang", "server"), + label: "SGLang server", + hint: "OpenAI-compatible local runtime", + }, + ]); + }); +}); diff --git a/src/extension-host/contributions/provider-wizard.ts b/src/extension-host/contributions/provider-wizard.ts new file mode 100644 index 00000000000..afb5fd9055c --- /dev/null +++ b/src/extension-host/contributions/provider-wizard.ts @@ -0,0 +1,201 @@ +import { normalizeProviderId } from "../../agents/provider-id.js"; +import type { + ProviderAuthMethod, + ProviderPlugin, + ProviderPluginWizardModelPicker, + ProviderPluginWizardOnboarding, +} from "../../plugins/types.js"; + +export const EXTENSION_HOST_PROVIDER_CHOICE_PREFIX = "provider-plugin:"; + +export type ExtensionHostProviderWizardOption = { + value: string; + label: string; + hint?: string; + groupId: string; + groupLabel: string; + groupHint?: string; +}; + +export type ExtensionHostProviderModelPickerEntry = { + value: string; + label: string; + hint?: string; +}; + +function normalizeChoiceId(choiceId: string): string { + return choiceId.trim(); +} + +function resolveWizardOnboardingChoiceId( + provider: ProviderPlugin, + wizard: ProviderPluginWizardOnboarding, +): string { + const explicit = wizard.choiceId?.trim(); + if (explicit) { + return explicit; + } + const explicitMethodId = wizard.methodId?.trim(); + if (explicitMethodId) { + return buildExtensionHostProviderMethodChoice(provider.id, explicitMethodId); + } + if (provider.auth.length === 1) { + return provider.id; + } + return buildExtensionHostProviderMethodChoice(provider.id, provider.auth[0]?.id ?? "default"); +} + +function resolveMethodById( + provider: ProviderPlugin, + methodId?: string, +): ProviderAuthMethod | undefined { + const normalizedMethodId = methodId?.trim().toLowerCase(); + if (!normalizedMethodId) { + return provider.auth[0]; + } + return provider.auth.find((method) => method.id.trim().toLowerCase() === normalizedMethodId); +} + +function buildOnboardingOptionForMethod(params: { + provider: ProviderPlugin; + wizard: ProviderPluginWizardOnboarding; + method: ProviderAuthMethod; + value: string; +}): ExtensionHostProviderWizardOption { + const normalizedGroupId = params.wizard.groupId?.trim() || params.provider.id; + return { + value: normalizeChoiceId(params.value), + label: + params.wizard.choiceLabel?.trim() || + (params.provider.auth.length === 1 ? params.provider.label : params.method.label), + hint: params.wizard.choiceHint?.trim() || params.method.hint, + groupId: normalizedGroupId, + groupLabel: params.wizard.groupLabel?.trim() || params.provider.label, + groupHint: params.wizard.groupHint?.trim(), + }; +} + +function resolveModelPickerChoiceValue( + provider: ProviderPlugin, + modelPicker: ProviderPluginWizardModelPicker, +): string { + const explicitMethodId = modelPicker.methodId?.trim(); + if (explicitMethodId) { + return buildExtensionHostProviderMethodChoice(provider.id, explicitMethodId); + } + if (provider.auth.length === 1) { + return provider.id; + } + return buildExtensionHostProviderMethodChoice(provider.id, provider.auth[0]?.id ?? "default"); +} + +export function buildExtensionHostProviderMethodChoice( + providerId: string, + methodId: string, +): string { + return `${EXTENSION_HOST_PROVIDER_CHOICE_PREFIX}${providerId.trim()}:${methodId.trim()}`; +} + +export function resolveExtensionHostProviderWizardOptions( + providers: ProviderPlugin[], +): ExtensionHostProviderWizardOption[] { + const options: ExtensionHostProviderWizardOption[] = []; + + for (const provider of providers) { + const wizard = provider.wizard?.onboarding; + if (!wizard) { + continue; + } + const explicitMethod = resolveMethodById(provider, wizard.methodId); + if (explicitMethod) { + options.push( + buildOnboardingOptionForMethod({ + provider, + wizard, + method: explicitMethod, + value: resolveWizardOnboardingChoiceId(provider, wizard), + }), + ); + continue; + } + + for (const method of provider.auth) { + options.push( + buildOnboardingOptionForMethod({ + provider, + wizard, + method, + value: buildExtensionHostProviderMethodChoice(provider.id, method.id), + }), + ); + } + } + + return options; +} + +export function resolveExtensionHostProviderModelPickerEntries( + providers: ProviderPlugin[], +): ExtensionHostProviderModelPickerEntry[] { + const entries: ExtensionHostProviderModelPickerEntry[] = []; + + for (const provider of providers) { + const modelPicker = provider.wizard?.modelPicker; + if (!modelPicker) { + continue; + } + entries.push({ + value: resolveModelPickerChoiceValue(provider, modelPicker), + label: modelPicker.label?.trim() || `${provider.label} (custom)`, + hint: modelPicker.hint?.trim(), + }); + } + + return entries; +} + +export function resolveExtensionHostProviderChoice(params: { + providers: ProviderPlugin[]; + choice: string; +}): { provider: ProviderPlugin; method: ProviderAuthMethod } | null { + const choice = params.choice.trim(); + if (!choice) { + return null; + } + + if (choice.startsWith(EXTENSION_HOST_PROVIDER_CHOICE_PREFIX)) { + const payload = choice.slice(EXTENSION_HOST_PROVIDER_CHOICE_PREFIX.length); + const separator = payload.indexOf(":"); + const providerId = separator >= 0 ? payload.slice(0, separator) : payload; + const methodId = separator >= 0 ? payload.slice(separator + 1) : undefined; + const provider = params.providers.find( + (entry) => normalizeProviderId(entry.id) === normalizeProviderId(providerId), + ); + if (!provider) { + return null; + } + const method = resolveMethodById(provider, methodId); + return method ? { provider, method } : null; + } + + for (const provider of params.providers) { + const onboarding = provider.wizard?.onboarding; + if (onboarding) { + const onboardingChoiceId = resolveWizardOnboardingChoiceId(provider, onboarding); + if (normalizeChoiceId(onboardingChoiceId) === choice) { + const method = resolveMethodById(provider, onboarding.methodId); + if (method) { + return { provider, method }; + } + } + } + if ( + normalizeProviderId(provider.id) === normalizeProviderId(choice) && + provider.auth.length > 0 + ) { + return { provider, method: provider.auth[0] }; + } + } + + return null; +} diff --git a/src/extension-host/contributions/registry-writes.test.ts b/src/extension-host/contributions/registry-writes.test.ts new file mode 100644 index 00000000000..7d96daa361a --- /dev/null +++ b/src/extension-host/contributions/registry-writes.test.ts @@ -0,0 +1,203 @@ +import { describe, expect, it, vi } from "vitest"; +import { createEmptyPluginRegistry, type PluginRecord } from "../../plugins/registry.js"; +import { + addExtensionChannelRegistration, + addExtensionCliRegistration, + addExtensionCommandRegistration, + addExtensionContextEngineRegistration, + addExtensionGatewayMethodRegistration, + addExtensionLegacyHookRegistration, + addExtensionHttpRouteRegistration, + addExtensionProviderRegistration, + addExtensionServiceRegistration, + addExtensionToolRegistration, + addExtensionTypedHookRegistration, +} from "./registry-writes.js"; + +function createRecord(): PluginRecord { + return { + id: "demo", + name: "Demo", + source: "/plugins/demo.ts", + origin: "workspace", + enabled: true, + status: "loaded", + toolNames: [], + hookNames: [], + channelIds: [], + providerIds: [], + gatewayMethods: [], + cliCommands: [], + services: [], + commands: [], + httpRoutes: 0, + hookCount: 0, + configSchema: false, + }; +} + +describe("extension host registry writes", () => { + it("writes tool registrations through the host helper", () => { + const registry = createEmptyPluginRegistry(); + const record = createRecord(); + + addExtensionToolRegistration({ + registry, + record, + names: ["tool-a"], + entry: { + pluginId: record.id, + factory: (() => ({}) as never) as never, + names: ["tool-a"], + optional: false, + source: record.source, + }, + }); + + expect(record.toolNames).toEqual(["tool-a"]); + expect(registry.tools).toHaveLength(1); + }); + + it("writes cli, service, and command registrations through host helpers", () => { + const registry = createEmptyPluginRegistry(); + const record = createRecord(); + + addExtensionCliRegistration({ + registry, + record, + commands: ["demo"], + entry: { + pluginId: record.id, + register: (() => {}) as never, + commands: ["demo"], + source: record.source, + }, + }); + addExtensionServiceRegistration({ + registry, + record, + serviceId: "svc", + entry: { + pluginId: record.id, + service: { id: "svc", start: async () => {}, stop: async () => {} } as never, + source: record.source, + }, + }); + addExtensionCommandRegistration({ + registry, + record, + commandName: "cmd", + entry: { + pluginId: record.id, + command: { name: "cmd", description: "demo", run: async () => {} } as never, + source: record.source, + }, + }); + + expect(record.cliCommands).toEqual(["demo"]); + expect(record.services).toEqual(["svc"]); + expect(record.commands).toEqual(["cmd"]); + expect(registry.cliRegistrars).toHaveLength(1); + expect(registry.services).toHaveLength(1); + expect(registry.commands).toHaveLength(1); + }); + + it("writes gateway, http, channel, and provider registrations through host helpers", () => { + const registry = createEmptyPluginRegistry(); + const record = createRecord(); + + addExtensionGatewayMethodRegistration({ + registry, + record, + method: "demo.method", + handler: (() => {}) as never, + }); + addExtensionHttpRouteRegistration({ + registry, + record, + action: "append", + entry: { + pluginId: record.id, + path: "/demo", + handler: (() => {}) as never, + auth: "optional", + match: "exact", + source: record.source, + }, + }); + addExtensionChannelRegistration({ + registry, + record, + channelId: "demo-channel", + entry: { + pluginId: record.id, + plugin: {} as never, + source: record.source, + }, + }); + addExtensionProviderRegistration({ + registry, + record, + providerId: "demo-provider", + entry: { + pluginId: record.id, + provider: {} as never, + source: record.source, + }, + }); + + expect(record.gatewayMethods).toEqual(["demo.method"]); + expect(record.httpRoutes).toBe(1); + expect(record.channelIds).toEqual(["demo-channel"]); + expect(record.providerIds).toEqual(["demo-provider"]); + expect(registry.gatewayHandlers["demo.method"]).toBeTypeOf("function"); + expect(registry.httpRoutes).toHaveLength(1); + expect(registry.channels).toHaveLength(1); + expect(registry.providers).toHaveLength(1); + expect(registry.providers[0]?.pluginId).toBe("demo"); + }); + + it("writes legacy hooks, typed hooks, and context engines through host helpers", () => { + const registry = createEmptyPluginRegistry(); + const record = createRecord(); + const registerEngine = vi.fn(); + + addExtensionLegacyHookRegistration({ + registry, + record, + hookName: "before_send", + events: ["before_send"], + entry: { + pluginId: record.id, + entry: {} as never, + events: ["before_send"], + source: record.source, + handler: (() => {}) as never, + }, + }); + addExtensionTypedHookRegistration({ + registry, + record, + entry: { + pluginId: record.id, + hookName: "before_send" as never, + handler: (() => {}) as never, + priority: 0, + source: record.source, + } as never, + }); + addExtensionContextEngineRegistration({ + entry: { + engineId: "context-demo", + factory: (() => ({}) as never) as never, + }, + registerEngine, + }); + + expect(record.hookNames).toEqual(["before_send"]); + expect(record.hookCount).toBe(1); + expect(registry.hooks).toHaveLength(1); + expect(registry.typedHooks).toHaveLength(1); + expect(registerEngine).toHaveBeenCalledWith("context-demo", expect.any(Function)); + }); +}); diff --git a/src/extension-host/contributions/registry-writes.ts b/src/extension-host/contributions/registry-writes.ts new file mode 100644 index 00000000000..92dfb4a2791 --- /dev/null +++ b/src/extension-host/contributions/registry-writes.ts @@ -0,0 +1,172 @@ +import type { GatewayRequestHandler } from "../../gateway/server-methods/types.js"; +import type { + PluginChannelRegistration, + PluginCliRegistration, + PluginCommandRegistration, + PluginHookRegistration, + PluginHttpRouteRegistration, + PluginRecord, + PluginRegistry, + PluginProviderRegistration, + PluginServiceRegistration, + PluginToolRegistration, +} from "../../plugins/registry.js"; +import type { PluginHookRegistration as TypedPluginHookRegistration } from "../../plugins/types.js"; +import { + registerExtensionHostContextEngine, + type ExtensionHostContextEngineFactory, +} from "./context-engine-runtime.js"; +import type { + ExtensionHostChannelRegistration, + ExtensionHostCliRegistration, + ExtensionHostCommandRegistration, + ExtensionHostContextEngineRegistration, + ExtensionHostLegacyHookRegistration, + ExtensionHostHttpRouteRegistration, + ExtensionHostProviderRegistration, + ExtensionHostServiceRegistration, + ExtensionHostToolRegistration, +} from "./runtime-registrations.js"; +import { + addExtensionHostChannelRegistration, + addExtensionHostCliRegistration, + addExtensionHostCommandRegistration, + addExtensionHostHttpRoute, + addExtensionHostProviderRegistration, + addExtensionHostServiceRegistration, + addExtensionHostToolRegistration, + replaceExtensionHostHttpRoute, + setExtensionHostGatewayHandler, +} from "./runtime-registry.js"; + +export function addExtensionGatewayMethodRegistration(params: { + registry: PluginRegistry; + record: PluginRecord; + method: string; + handler: GatewayRequestHandler; +}): void { + setExtensionHostGatewayHandler({ + registry: params.registry, + method: params.method, + handler: params.handler, + }); + params.record.gatewayMethods.push(params.method); +} + +export function addExtensionHttpRouteRegistration(params: { + registry: PluginRegistry; + record: PluginRecord; + entry: ExtensionHostHttpRouteRegistration; + action: "replace" | "append"; + existingIndex?: number; +}): void { + if (params.action === "replace") { + if (params.existingIndex === undefined) { + return; + } + replaceExtensionHostHttpRoute({ + registry: params.registry, + index: params.existingIndex, + entry: params.entry as PluginHttpRouteRegistration, + }); + return; + } + + params.record.httpRoutes += 1; + addExtensionHostHttpRoute(params.registry, params.entry as PluginHttpRouteRegistration); +} + +export function addExtensionChannelRegistration(params: { + registry: PluginRegistry; + record: PluginRecord; + channelId: string; + entry: ExtensionHostChannelRegistration; +}): void { + params.record.channelIds.push(params.channelId); + addExtensionHostChannelRegistration(params.registry, params.entry as PluginChannelRegistration); +} + +export function addExtensionProviderRegistration(params: { + registry: PluginRegistry; + record: PluginRecord; + providerId: string; + entry: ExtensionHostProviderRegistration; +}): void { + params.record.providerIds.push(params.providerId); + addExtensionHostProviderRegistration(params.registry, params.entry as PluginProviderRegistration); +} + +export function addExtensionLegacyHookRegistration(params: { + registry: PluginRegistry; + record: PluginRecord; + hookName: string; + entry: ExtensionHostLegacyHookRegistration; + events: string[]; +}): void { + params.record.hookNames.push(params.hookName); + params.registry.hooks.push({ + pluginId: params.entry.pluginId, + entry: params.entry.entry, + events: params.events, + source: params.entry.source, + } as PluginHookRegistration); +} + +export function addExtensionTypedHookRegistration(params: { + registry: PluginRegistry; + record: PluginRecord; + entry: TypedPluginHookRegistration; +}): void { + params.record.hookCount += 1; + params.registry.typedHooks.push(params.entry); +} + +export function addExtensionToolRegistration(params: { + registry: PluginRegistry; + record: PluginRecord; + names: string[]; + entry: ExtensionHostToolRegistration; +}): void { + if (params.names.length > 0) { + params.record.toolNames.push(...params.names); + } + addExtensionHostToolRegistration(params.registry, params.entry as PluginToolRegistration); +} + +export function addExtensionCliRegistration(params: { + registry: PluginRegistry; + record: PluginRecord; + commands: string[]; + entry: ExtensionHostCliRegistration; +}): void { + params.record.cliCommands.push(...params.commands); + addExtensionHostCliRegistration(params.registry, params.entry as PluginCliRegistration); +} + +export function addExtensionServiceRegistration(params: { + registry: PluginRegistry; + record: PluginRecord; + serviceId: string; + entry: ExtensionHostServiceRegistration; +}): void { + params.record.services.push(params.serviceId); + addExtensionHostServiceRegistration(params.registry, params.entry as PluginServiceRegistration); +} + +export function addExtensionCommandRegistration(params: { + registry: PluginRegistry; + record: PluginRecord; + commandName: string; + entry: ExtensionHostCommandRegistration; +}): void { + params.record.commands.push(params.commandName); + addExtensionHostCommandRegistration(params.registry, params.entry as PluginCommandRegistration); +} + +export function addExtensionContextEngineRegistration(params: { + entry: ExtensionHostContextEngineRegistration; + registerEngine?: (engineId: string, factory: ExtensionHostContextEngineFactory) => void; +}): void { + const registerEngine = params.registerEngine ?? registerExtensionHostContextEngine; + registerEngine(params.entry.engineId, params.entry.factory); +} diff --git a/src/extension-host/contributions/runtime-registrations.test.ts b/src/extension-host/contributions/runtime-registrations.test.ts new file mode 100644 index 00000000000..6a81f95980c --- /dev/null +++ b/src/extension-host/contributions/runtime-registrations.test.ts @@ -0,0 +1,524 @@ +import { describe, expect, it, vi } from "vitest"; +import type { AnyAgentTool } from "../../agents/tools/common.js"; +import type { ChannelPlugin } from "../../channels/plugins/types.js"; +import type { ContextEngineFactory } from "../../context-engine/registry.js"; +import type { InternalHookHandler } from "../../hooks/internal-hooks.js"; +import type { HookEntry } from "../../hooks/types.js"; +import type { + OpenClawPluginCliContext, + OpenClawPluginCommandDefinition, + OpenClawPluginHookOptions, + OpenClawPluginService, + PluginHookRegistration, + ProviderPlugin, +} from "../../plugins/types.js"; +import { + resolveExtensionChannelRegistration, + resolveExtensionCliRegistration, + resolveExtensionCommandRegistration, + resolveExtensionContextEngineRegistration, + resolveExtensionGatewayMethodRegistration, + resolveExtensionLegacyHookRegistration, + resolveExtensionHttpRouteRegistration, + resolveExtensionProviderRegistration, + resolveExtensionServiceRegistration, + resolveExtensionToolRegistration, + resolveExtensionTypedHookRegistration, + type ExtensionHostChannelRegistration, + type ExtensionHostHttpRouteRegistration, + type ExtensionHostProviderRegistration, +} from "./runtime-registrations.js"; + +function createChannelPlugin(id: string): ChannelPlugin { + return { + id, + meta: { + id, + label: id, + selectionLabel: id, + docsPath: `/channels/${id}`, + blurb: "test", + }, + capabilities: { chatTypes: ["direct"] }, + config: { + listAccountIds: () => [], + resolveAccount: () => ({}), + }, + }; +} + +function createProviderPlugin(id: string): ProviderPlugin { + return { + id, + label: id, + auth: [], + }; +} + +function createService(id: string): OpenClawPluginService { + return { + id, + start: vi.fn(), + }; +} + +function createCommand(name: string): OpenClawPluginCommandDefinition { + return { + name, + description: "demo command", + handler: vi.fn(), + }; +} + +function createLegacyHookEntry(name: string): HookEntry { + return { + hook: { + name, + description: "hook description", + source: "openclaw-plugin", + pluginId: "demo-plugin", + filePath: "/demo/plugin.ts", + baseDir: "/demo", + handlerPath: "/demo/plugin.ts", + }, + frontmatter: {}, + metadata: { events: ["message:received"] }, + invocation: { enabled: true }, + }; +} + +describe("runtime registration helpers", () => { + it("normalizes tool registration metadata", () => { + const tool = { name: "demo-tool" } as AnyAgentTool; + const result = resolveExtensionToolRegistration({ + ownerPluginId: "tool-plugin", + ownerSource: "tool-source", + tool, + opts: { + names: [" demo-tool ", "alias"], + optional: true, + }, + }); + + expect(result).toMatchObject({ + names: ["demo-tool", "alias"], + entry: { + pluginId: "tool-plugin", + names: ["demo-tool", "alias"], + optional: true, + source: "tool-source", + }, + }); + expect(result.entry.factory({} as never)).toBe(tool); + }); + + it("normalizes cli registration metadata", () => { + const registrar = (_ctx: OpenClawPluginCliContext) => {}; + const result = resolveExtensionCliRegistration({ + ownerPluginId: "cli-plugin", + ownerSource: "cli-source", + registrar, + opts: { commands: [" foo ", "bar", "foo"] }, + }); + + expect(result).toEqual({ + commands: ["foo", "bar"], + entry: { + pluginId: "cli-plugin", + register: registrar, + commands: ["foo", "bar"], + source: "cli-source", + }, + }); + }); + + it("normalizes service registrations", () => { + const result = resolveExtensionServiceRegistration({ + ownerPluginId: "service-plugin", + ownerSource: "service-source", + service: createService(" demo-service "), + }); + + expect(result).toMatchObject({ + ok: true, + serviceId: "demo-service", + entry: { + pluginId: "service-plugin", + source: "service-source", + service: { id: "demo-service" }, + }, + }); + }); + + it("rejects service registrations without ids", () => { + const result = resolveExtensionServiceRegistration({ + ownerPluginId: "service-plugin", + ownerSource: "service-source", + service: createService(" "), + }); + + expect(result).toEqual({ + ok: false, + message: "service registration missing id", + }); + }); + + it("normalizes command registrations", () => { + const result = resolveExtensionCommandRegistration({ + ownerPluginId: "command-plugin", + ownerSource: "command-source", + command: createCommand(" demo "), + }); + + expect(result).toMatchObject({ + ok: true, + commandName: "demo", + entry: { + pluginId: "command-plugin", + source: "command-source", + command: { name: "demo" }, + }, + }); + }); + + it("rejects command registrations without names", () => { + const result = resolveExtensionCommandRegistration({ + ownerPluginId: "command-plugin", + ownerSource: "command-source", + command: createCommand(" "), + }); + + expect(result).toEqual({ + ok: false, + message: "command registration missing name", + }); + }); + + it("normalizes context-engine registrations", () => { + const factory = vi.fn() as unknown as ContextEngineFactory; + const result = resolveExtensionContextEngineRegistration({ + engineId: " demo-engine ", + factory, + }); + + expect(result).toEqual({ + ok: true, + entry: { + engineId: "demo-engine", + factory, + }, + }); + }); + + it("rejects context-engine registrations without ids", () => { + const result = resolveExtensionContextEngineRegistration({ + engineId: " ", + factory: vi.fn() as unknown as ContextEngineFactory, + }); + + expect(result).toEqual({ + ok: false, + message: "context engine registration missing id", + }); + }); + + it("normalizes legacy hook registrations", () => { + const handler = vi.fn() as unknown as InternalHookHandler; + const result = resolveExtensionLegacyHookRegistration({ + ownerPluginId: "hook-plugin", + ownerSource: "/plugins/hook.ts", + events: [" message:received ", "message:received", "message:sent"], + handler, + opts: { + name: "demo-hook", + description: "hook description", + } satisfies OpenClawPluginHookOptions, + }); + + expect(result).toMatchObject({ + ok: true, + hookName: "demo-hook", + events: ["message:received", "message:sent"], + entry: { + pluginId: "hook-plugin", + source: "/plugins/hook.ts", + }, + }); + }); + + it("preserves explicit legacy hook entries while normalizing events", () => { + const result = resolveExtensionLegacyHookRegistration({ + ownerPluginId: "hook-plugin", + ownerSource: "/plugins/hook.ts", + events: " message:received ", + handler: vi.fn() as unknown as InternalHookHandler, + opts: { + entry: createLegacyHookEntry("demo-hook"), + }, + }); + + expect(result).toMatchObject({ + ok: true, + hookName: "demo-hook", + events: ["message:received"], + }); + if (result.ok) { + expect(result.entry.entry.hook.pluginId).toBe("hook-plugin"); + expect(result.entry.entry.metadata?.events).toEqual(["message:received"]); + } + }); + + it("rejects legacy hook registrations without names", () => { + const result = resolveExtensionLegacyHookRegistration({ + ownerPluginId: "hook-plugin", + ownerSource: "/plugins/hook.ts", + events: "message:received", + handler: vi.fn() as unknown as InternalHookHandler, + opts: {}, + }); + + expect(result).toEqual({ + ok: false, + message: "hook registration missing name", + }); + }); + + it("normalizes typed hook registrations", () => { + const handler = vi.fn() as PluginHookRegistration<"before_prompt_build">["handler"]; + const result = resolveExtensionTypedHookRegistration({ + ownerPluginId: "typed-hook-plugin", + ownerSource: "/plugins/typed-hook.ts", + hookName: "before_prompt_build", + handler, + priority: 10, + }); + + expect(result).toEqual({ + ok: true, + hookName: "before_prompt_build", + entry: { + pluginId: "typed-hook-plugin", + hookName: "before_prompt_build", + handler, + priority: 10, + source: "/plugins/typed-hook.ts", + }, + }); + }); + + it("rejects unknown typed hook registrations", () => { + const result = resolveExtensionTypedHookRegistration({ + ownerPluginId: "typed-hook-plugin", + ownerSource: "/plugins/typed-hook.ts", + hookName: "totally_unknown_hook_name", + handler: vi.fn() as never, + priority: 10, + }); + + expect(result).toEqual({ + ok: false, + message: 'unknown typed hook "totally_unknown_hook_name" ignored', + }); + }); + + it("normalizes and accepts a unique channel registration", () => { + const result = resolveExtensionChannelRegistration({ + existing: [], + ownerPluginId: "demo-plugin", + ownerSource: "demo-source", + registration: createChannelPlugin("demo-channel"), + }); + + expect(result).toMatchObject({ + ok: true, + channelId: "demo-channel", + entry: { + pluginId: "demo-plugin", + source: "demo-source", + }, + }); + }); + + it("rejects duplicate channel registrations", () => { + const existing: ExtensionHostChannelRegistration[] = [ + { + pluginId: "demo-a", + plugin: createChannelPlugin("demo-channel"), + source: "demo-a-source", + }, + ]; + + const result = resolveExtensionChannelRegistration({ + existing, + ownerPluginId: "demo-b", + ownerSource: "demo-b-source", + registration: createChannelPlugin("demo-channel"), + }); + + expect(result).toEqual({ + ok: false, + message: "channel already registered: demo-channel (demo-a)", + }); + }); + + it("accepts a unique provider registration", () => { + const result = resolveExtensionProviderRegistration({ + existing: [], + ownerPluginId: "provider-plugin", + ownerSource: "provider-source", + provider: createProviderPlugin("demo-provider"), + }); + + expect(result).toMatchObject({ + ok: true, + providerId: "demo-provider", + entry: { + pluginId: "provider-plugin", + source: "provider-source", + }, + }); + }); + + it("rejects duplicate provider registrations", () => { + const existing: ExtensionHostProviderRegistration[] = [ + { + pluginId: "provider-a", + provider: createProviderPlugin("demo-provider"), + source: "provider-a-source", + }, + ]; + + const result = resolveExtensionProviderRegistration({ + existing, + ownerPluginId: "provider-b", + ownerSource: "provider-b-source", + provider: createProviderPlugin("demo-provider"), + }); + + expect(result).toEqual({ + ok: false, + message: "provider already registered: demo-provider (provider-a)", + }); + }); + + it("accepts a unique http route registration", () => { + const result = resolveExtensionHttpRouteRegistration({ + existing: [], + ownerPluginId: "route-plugin", + ownerSource: "route-source", + route: { + path: "/demo", + auth: "plugin", + handler: vi.fn(), + }, + }); + + expect(result).toMatchObject({ + ok: true, + action: "append", + entry: { + pluginId: "route-plugin", + path: "/demo", + auth: "plugin", + match: "exact", + source: "route-source", + }, + }); + }); + + it("rejects conflicting http routes owned by another plugin", () => { + const existing: ExtensionHostHttpRouteRegistration[] = [ + { + pluginId: "route-a", + path: "/demo", + auth: "plugin", + match: "exact", + handler: vi.fn(), + source: "route-a-source", + }, + ]; + + const result = resolveExtensionHttpRouteRegistration({ + existing, + ownerPluginId: "route-b", + ownerSource: "route-b-source", + route: { + path: "/demo", + auth: "plugin", + handler: vi.fn(), + }, + }); + + expect(result).toEqual({ + ok: false, + message: "http route already registered: /demo (exact) by route-a (route-a-source)", + }); + }); + + it("supports same-owner http route replacement", () => { + const existing: ExtensionHostHttpRouteRegistration[] = [ + { + pluginId: "route-plugin", + path: "/demo", + auth: "plugin", + match: "exact", + handler: vi.fn(), + source: "route-source", + }, + ]; + + const result = resolveExtensionHttpRouteRegistration({ + existing, + ownerPluginId: "route-plugin", + ownerSource: "route-source", + route: { + path: "/demo", + auth: "plugin", + replaceExisting: true, + handler: vi.fn(), + }, + }); + + expect(result).toMatchObject({ + ok: true, + action: "replace", + existingIndex: 0, + entry: { + pluginId: "route-plugin", + path: "/demo", + }, + }); + }); + + it("accepts a unique gateway method registration", () => { + const handler = vi.fn(); + const result = resolveExtensionGatewayMethodRegistration({ + existing: {}, + coreGatewayMethods: new Set(["core.method"]), + method: "plugin.method", + handler, + }); + + expect(result).toEqual({ + ok: true, + method: "plugin.method", + handler, + }); + }); + + it("rejects duplicate gateway method registrations", () => { + const result = resolveExtensionGatewayMethodRegistration({ + existing: { + "plugin.method": vi.fn(), + }, + coreGatewayMethods: new Set(["core.method"]), + method: "plugin.method", + handler: vi.fn(), + }); + + expect(result).toEqual({ + ok: false, + message: "gateway method already registered: plugin.method", + }); + }); +}); diff --git a/src/extension-host/contributions/runtime-registrations.ts b/src/extension-host/contributions/runtime-registrations.ts new file mode 100644 index 00000000000..a1107ac3cb6 --- /dev/null +++ b/src/extension-host/contributions/runtime-registrations.ts @@ -0,0 +1,556 @@ +import path from "node:path"; +import type { AnyAgentTool } from "../../agents/tools/common.js"; +import type { ChannelDock } from "../../channels/dock.js"; +import type { ChannelPlugin } from "../../channels/plugins/types.js"; +import type { ContextEngineFactory } from "../../context-engine/registry.js"; +import type { + GatewayRequestHandler, + GatewayRequestHandlers, +} from "../../gateway/server-methods/types.js"; +import type { InternalHookHandler } from "../../hooks/internal-hooks.js"; +import type { HookEntry } from "../../hooks/types.js"; +import { normalizePluginHttpPath } from "../../plugins/http-path.js"; +import { findOverlappingPluginHttpRoute } from "../../plugins/http-route-overlap.js"; +import type { + OpenClawPluginCliRegistrar, + OpenClawPluginCommandDefinition, + OpenClawPluginChannelRegistration, + OpenClawPluginHookOptions, + OpenClawPluginHttpRouteAuth, + OpenClawPluginHttpRouteHandler, + OpenClawPluginHttpRouteMatch, + OpenClawPluginHttpRouteParams, + OpenClawPluginService, + OpenClawPluginToolContext, + OpenClawPluginToolFactory, + PluginHookHandlerMap, + PluginHookName, + PluginHookRegistration, + ProviderPlugin, +} from "../../plugins/types.js"; +import { isPluginHookName } from "../../plugins/types.js"; + +export type ExtensionHostChannelRegistration = { + pluginId: string; + plugin: ChannelPlugin; + dock?: ChannelDock; + source: string; +}; + +export type ExtensionHostProviderRegistration = { + pluginId: string; + provider: ProviderPlugin; + source: string; +}; + +export type ExtensionHostToolRegistration = { + pluginId: string; + factory: OpenClawPluginToolFactory; + names: string[]; + optional: boolean; + source: string; +}; + +export type ExtensionHostCliRegistration = { + pluginId: string; + register: OpenClawPluginCliRegistrar; + commands: string[]; + source: string; +}; + +export type ExtensionHostServiceRegistration = { + pluginId: string; + service: OpenClawPluginService; + source: string; +}; + +export type ExtensionHostCommandRegistration = { + pluginId: string; + command: OpenClawPluginCommandDefinition; + source: string; +}; + +export type ExtensionHostContextEngineRegistration = { + engineId: string; + factory: ContextEngineFactory; +}; + +export type ExtensionHostLegacyHookRegistration = { + pluginId: string; + entry: HookEntry; + events: string[]; + source: string; + handler: InternalHookHandler; +}; + +export type ExtensionHostHttpRouteRegistration = { + pluginId?: string; + path: string; + handler: OpenClawPluginHttpRouteHandler; + auth: OpenClawPluginHttpRouteAuth; + match: OpenClawPluginHttpRouteMatch; + source?: string; +}; + +function normalizeNameList(names: string[]): string[] { + return Array.from(new Set(names.map((name) => name.trim()).filter(Boolean))); +} + +export function resolveExtensionToolRegistration(params: { + ownerPluginId: string; + ownerSource: string; + tool: AnyAgentTool | OpenClawPluginToolFactory; + opts?: { name?: string; names?: string[]; optional?: boolean }; +}): { + names: string[]; + entry: ExtensionHostToolRegistration; +} { + const names = [...(params.opts?.names ?? []), ...(params.opts?.name ? [params.opts.name] : [])]; + if (typeof params.tool !== "function") { + names.push(params.tool.name); + } + const normalizedNames = normalizeNameList(names); + let factory: OpenClawPluginToolFactory; + if (typeof params.tool === "function") { + factory = params.tool; + } else { + const tool = params.tool; + factory = (_ctx: OpenClawPluginToolContext) => tool; + } + + return { + names: normalizedNames, + entry: { + pluginId: params.ownerPluginId, + factory, + names: normalizedNames, + optional: params.opts?.optional === true, + source: params.ownerSource, + }, + }; +} + +export function resolveExtensionCliRegistration(params: { + ownerPluginId: string; + ownerSource: string; + registrar: OpenClawPluginCliRegistrar; + opts?: { commands?: string[] }; +}): { + commands: string[]; + entry: ExtensionHostCliRegistration; +} { + const commands = normalizeNameList(params.opts?.commands ?? []); + return { + commands, + entry: { + pluginId: params.ownerPluginId, + register: params.registrar, + commands, + source: params.ownerSource, + }, + }; +} + +export function resolveExtensionServiceRegistration(params: { + ownerPluginId: string; + ownerSource: string; + service: OpenClawPluginService; +}): + | { + ok: true; + serviceId: string; + entry: ExtensionHostServiceRegistration; + } + | { + ok: false; + message: string; + } { + const serviceId = params.service.id.trim(); + if (!serviceId) { + return { ok: false, message: "service registration missing id" }; + } + return { + ok: true, + serviceId, + entry: { + pluginId: params.ownerPluginId, + service: { + ...params.service, + id: serviceId, + }, + source: params.ownerSource, + }, + }; +} + +export function resolveExtensionCommandRegistration(params: { + ownerPluginId: string; + ownerSource: string; + command: OpenClawPluginCommandDefinition; +}): + | { + ok: true; + commandName: string; + entry: ExtensionHostCommandRegistration; + } + | { + ok: false; + message: string; + } { + const commandName = params.command.name.trim(); + if (!commandName) { + return { ok: false, message: "command registration missing name" }; + } + return { + ok: true, + commandName, + entry: { + pluginId: params.ownerPluginId, + command: { + ...params.command, + name: commandName, + }, + source: params.ownerSource, + }, + }; +} + +export function resolveExtensionContextEngineRegistration(params: { + engineId: string; + factory: ContextEngineFactory; +}): + | { + ok: true; + entry: ExtensionHostContextEngineRegistration; + } + | { + ok: false; + message: string; + } { + const engineId = params.engineId.trim(); + if (!engineId) { + return { ok: false, message: "context engine registration missing id" }; + } + return { + ok: true, + entry: { + engineId, + factory: params.factory, + }, + }; +} + +export function resolveExtensionLegacyHookRegistration(params: { + ownerPluginId: string; + ownerSource: string; + events: string | string[]; + handler: InternalHookHandler; + opts?: OpenClawPluginHookOptions; +}): + | { + ok: true; + hookName: string; + events: string[]; + entry: ExtensionHostLegacyHookRegistration; + } + | { + ok: false; + message: string; + } { + const eventList = Array.isArray(params.events) ? params.events : [params.events]; + const normalizedEvents = normalizeNameList(eventList); + const entry = params.opts?.entry ?? null; + const hookName = entry?.hook.name ?? params.opts?.name?.trim(); + if (!hookName) { + return { ok: false, message: "hook registration missing name" }; + } + + const description = entry?.hook.description ?? params.opts?.description ?? ""; + const hookEntry: HookEntry = entry + ? { + ...entry, + hook: { + ...entry.hook, + name: hookName, + description, + source: "openclaw-plugin", + pluginId: params.ownerPluginId, + }, + metadata: { + ...entry.metadata, + events: normalizedEvents, + }, + } + : { + hook: { + name: hookName, + description, + source: "openclaw-plugin", + pluginId: params.ownerPluginId, + filePath: params.ownerSource, + baseDir: path.dirname(params.ownerSource), + handlerPath: params.ownerSource, + }, + frontmatter: {}, + metadata: { events: normalizedEvents }, + invocation: { enabled: true }, + }; + + return { + ok: true, + hookName, + events: normalizedEvents, + entry: { + pluginId: params.ownerPluginId, + entry: hookEntry, + events: normalizedEvents, + source: params.ownerSource, + handler: params.handler, + }, + }; +} + +export function resolveExtensionTypedHookRegistration(params: { + ownerPluginId: string; + ownerSource: string; + hookName: unknown; + handler: PluginHookHandlerMap[K]; + priority?: number; +}): + | { + ok: true; + hookName: K; + entry: PluginHookRegistration; + } + | { + ok: false; + message: string; + } { + if (!isPluginHookName(params.hookName)) { + return { + ok: false, + message: `unknown typed hook "${String(params.hookName)}" ignored`, + }; + } + return { + ok: true, + hookName: params.hookName as K, + entry: { + pluginId: params.ownerPluginId, + hookName: params.hookName as K, + handler: params.handler, + priority: params.priority, + source: params.ownerSource, + }, + }; +} + +export function resolveExtensionGatewayMethodRegistration(params: { + existing: GatewayRequestHandlers; + coreGatewayMethods: ReadonlySet; + method: string; + handler: GatewayRequestHandler; +}): + | { + ok: true; + method: string; + handler: GatewayRequestHandler; + } + | { + ok: false; + message: string; + } { + const method = params.method.trim(); + if (!method) { + return { ok: false, message: "gateway method registration missing name" }; + } + if (params.coreGatewayMethods.has(method) || params.existing[method]) { + return { + ok: false, + message: `gateway method already registered: ${method}`, + }; + } + return { + ok: true, + method, + handler: params.handler, + }; +} + +function normalizeChannelRegistration( + registration: OpenClawPluginChannelRegistration | ChannelPlugin, +): { plugin: ChannelPlugin; dock?: ChannelDock } { + return typeof (registration as OpenClawPluginChannelRegistration).plugin === "object" + ? (registration as OpenClawPluginChannelRegistration) + : { plugin: registration as ChannelPlugin }; +} + +export function resolveExtensionChannelRegistration(params: { + existing: ExtensionHostChannelRegistration[]; + ownerPluginId: string; + ownerSource: string; + registration: OpenClawPluginChannelRegistration | ChannelPlugin; +}): + | { + ok: true; + channelId: string; + entry: ExtensionHostChannelRegistration; + } + | { + ok: false; + message: string; + } { + const normalized = normalizeChannelRegistration(params.registration); + const plugin = normalized.plugin; + const channelId = + typeof plugin?.id === "string" ? plugin.id.trim() : String(plugin?.id ?? "").trim(); + if (!channelId) { + return { ok: false, message: "channel registration missing id" }; + } + const existing = params.existing.find((entry) => entry.plugin.id === channelId); + if (existing) { + return { + ok: false, + message: `channel already registered: ${channelId} (${existing.pluginId})`, + }; + } + return { + ok: true, + channelId, + entry: { + pluginId: params.ownerPluginId, + plugin, + dock: normalized.dock, + source: params.ownerSource, + }, + }; +} + +export function resolveExtensionProviderRegistration(params: { + existing: ExtensionHostProviderRegistration[]; + ownerPluginId: string; + ownerSource: string; + provider: ProviderPlugin; +}): + | { + ok: true; + providerId: string; + entry: ExtensionHostProviderRegistration; + } + | { + ok: false; + message: string; + } { + const providerId = params.provider.id; + const existing = params.existing.find((entry) => entry.provider.id === providerId); + if (existing) { + return { + ok: false, + message: `provider already registered: ${providerId} (${existing.pluginId})`, + }; + } + return { + ok: true, + providerId, + entry: { + pluginId: params.ownerPluginId, + provider: params.provider, + source: params.ownerSource, + }, + }; +} + +function describeHttpRouteOwner(entry: ExtensionHostHttpRouteRegistration): string { + const plugin = entry.pluginId?.trim() || "unknown-plugin"; + const source = entry.source?.trim() || "unknown-source"; + return `${plugin} (${source})`; +} + +export function resolveExtensionHttpRouteRegistration(params: { + existing: ExtensionHostHttpRouteRegistration[]; + ownerPluginId: string; + ownerSource: string; + route: OpenClawPluginHttpRouteParams; +}): + | { + ok: true; + action: "append" | "replace"; + entry: ExtensionHostHttpRouteRegistration; + existingIndex?: number; + } + | { + ok: false; + message: string; + } { + const normalizedPath = normalizePluginHttpPath(params.route.path); + if (!normalizedPath) { + return { ok: false, message: "http route registration missing path" }; + } + if (params.route.auth !== "gateway" && params.route.auth !== "plugin") { + return { + ok: false, + message: `http route registration missing or invalid auth: ${normalizedPath}`, + }; + } + + const match = params.route.match ?? "exact"; + const overlappingRoute = findOverlappingPluginHttpRoute(params.existing, { + path: normalizedPath, + match, + }); + if (overlappingRoute && overlappingRoute.auth !== params.route.auth) { + return { + ok: false, + message: + `http route overlap rejected: ${normalizedPath} (${match}, ${params.route.auth}) ` + + `overlaps ${overlappingRoute.path} (${overlappingRoute.match}, ${overlappingRoute.auth}) ` + + `owned by ${describeHttpRouteOwner(overlappingRoute)}`, + }; + } + + const existingIndex = params.existing.findIndex( + (entry) => entry.path === normalizedPath && entry.match === match, + ); + const nextEntry: ExtensionHostHttpRouteRegistration = { + pluginId: params.ownerPluginId, + path: normalizedPath, + handler: params.route.handler, + auth: params.route.auth, + match, + source: params.ownerSource, + }; + + if (existingIndex >= 0) { + const existing = params.existing[existingIndex]; + if (!existing) { + return { + ok: false, + message: `http route registration missing existing route: ${normalizedPath}`, + }; + } + if (!params.route.replaceExisting) { + return { + ok: false, + message: `http route already registered: ${normalizedPath} (${match}) by ${describeHttpRouteOwner(existing)}`, + }; + } + if (existing.pluginId && existing.pluginId !== params.ownerPluginId) { + return { + ok: false, + message: `http route replacement rejected: ${normalizedPath} (${match}) owned by ${describeHttpRouteOwner(existing)}`, + }; + } + return { + ok: true, + action: "replace", + existingIndex, + entry: nextEntry, + }; + } + + return { + ok: true, + action: "append", + entry: nextEntry, + }; +} diff --git a/src/extension-host/contributions/runtime-registry.test.ts b/src/extension-host/contributions/runtime-registry.test.ts new file mode 100644 index 00000000000..2d1e982f7fe --- /dev/null +++ b/src/extension-host/contributions/runtime-registry.test.ts @@ -0,0 +1,343 @@ +import { describe, expect, it, vi } from "vitest"; +import { createEmptyPluginRegistry } from "../../plugins/registry.js"; +import { + addExtensionHostChannelRegistration, + addExtensionHostCliRegistration, + addExtensionHostCommandRegistration, + addExtensionHostHttpRoute, + addExtensionHostProviderRegistration, + addExtensionHostServiceRegistration, + addExtensionHostToolRegistration, + getExtensionHostGatewayHandlers, + hasExtensionHostRuntimeEntries, + listExtensionHostChannelRegistrations, + listExtensionHostCliRegistrations, + listExtensionHostCommandRegistrations, + listExtensionHostHttpRoutes, + listExtensionHostProviderRegistrations, + listExtensionHostServiceRegistrations, + listExtensionHostToolRegistrations, + removeExtensionHostHttpRoute, + replaceExtensionHostHttpRoute, + setExtensionHostGatewayHandler, +} from "./runtime-registry.js"; + +describe("extension host runtime registry accessors", () => { + it("detects runtime entries across non-tool surfaces", () => { + const providerRegistry = createEmptyPluginRegistry(); + addExtensionHostProviderRegistration(providerRegistry, { + pluginId: "provider-demo", + source: "test", + provider: { + id: "provider-demo", + label: "Provider Demo", + auth: [], + }, + }); + expect(hasExtensionHostRuntimeEntries(providerRegistry)).toBe(true); + + const routeRegistry = createEmptyPluginRegistry(); + addExtensionHostHttpRoute(routeRegistry, { + path: "/plugins/demo", + handler: vi.fn(), + auth: "plugin", + match: "exact", + pluginId: "route-demo", + source: "test", + }); + expect(hasExtensionHostRuntimeEntries(routeRegistry)).toBe(true); + + const channelRegistry = createEmptyPluginRegistry(); + addExtensionHostChannelRegistration(channelRegistry, { + pluginId: "channel-demo", + source: "test", + plugin: { + id: "channel-demo", + meta: { + id: "channel-demo", + label: "Channel Demo", + selectionLabel: "Channel Demo", + docsPath: "/channels/channel-demo", + blurb: "demo", + }, + capabilities: { chatTypes: ["direct"] }, + config: { + listAccountIds: () => [], + resolveAccount: () => ({}), + }, + }, + }); + expect(hasExtensionHostRuntimeEntries(channelRegistry)).toBe(true); + + const gatewayRegistry = createEmptyPluginRegistry(); + setExtensionHostGatewayHandler({ + registry: gatewayRegistry, + method: "demo.echo", + handler: vi.fn(), + }); + expect(hasExtensionHostRuntimeEntries(gatewayRegistry)).toBe(true); + + const cliRegistry = createEmptyPluginRegistry(); + addExtensionHostCliRegistration(cliRegistry, { + pluginId: "cli-demo", + source: "test", + commands: ["demo"], + register: () => undefined, + }); + expect(hasExtensionHostRuntimeEntries(cliRegistry)).toBe(true); + + const commandRegistry = createEmptyPluginRegistry(); + addExtensionHostCommandRegistration(commandRegistry, { + pluginId: "cmd-demo", + source: "test", + command: { + name: "demo", + description: "Demo command", + handler: async () => ({ text: "ok" }), + }, + }); + expect(hasExtensionHostRuntimeEntries(commandRegistry)).toBe(true); + + const serviceRegistry = createEmptyPluginRegistry(); + addExtensionHostServiceRegistration(serviceRegistry, { + pluginId: "svc-demo", + source: "test", + service: { + id: "svc-demo", + start: () => undefined, + }, + }); + expect(hasExtensionHostRuntimeEntries(serviceRegistry)).toBe(true); + }); + + it("returns stable empty views for missing registries", () => { + expect(hasExtensionHostRuntimeEntries(null)).toBe(false); + expect(listExtensionHostProviderRegistrations(null)).toEqual([]); + expect(listExtensionHostChannelRegistrations(null)).toEqual([]); + expect(listExtensionHostToolRegistrations(null)).toEqual([]); + expect(listExtensionHostServiceRegistrations(null)).toEqual([]); + expect(listExtensionHostCliRegistrations(null)).toEqual([]); + expect(listExtensionHostCommandRegistrations(null)).toEqual([]); + expect(listExtensionHostHttpRoutes(null)).toEqual([]); + expect(getExtensionHostGatewayHandlers(null)).toEqual({}); + }); + + it("projects existing registry collections without copying them", () => { + const registry = createEmptyPluginRegistry(); + addExtensionHostToolRegistration(registry, { + pluginId: "tool-demo", + optional: false, + source: "test", + names: ["tool_demo"], + factory: () => ({ + name: "tool_demo", + description: "tool demo", + parameters: { type: "object", properties: {} }, + async execute() { + return { content: [{ type: "text", text: "ok" }] }; + }, + }), + }); + addExtensionHostProviderRegistration(registry, { + pluginId: "provider-demo", + source: "test", + provider: { + id: "provider-demo", + label: "Provider Demo", + auth: [], + }, + }); + addExtensionHostServiceRegistration(registry, { + pluginId: "svc-demo", + source: "test", + service: { + id: "svc-demo", + start: () => undefined, + }, + }); + addExtensionHostCliRegistration(registry, { + pluginId: "cli-demo", + source: "test", + commands: ["demo"], + register: () => undefined, + }); + addExtensionHostCommandRegistration(registry, { + pluginId: "cmd-demo", + source: "test", + command: { + name: "demo", + description: "Demo command", + handler: async () => ({ text: "ok" }), + }, + }); + addExtensionHostHttpRoute(registry, { + path: "/plugins/demo", + handler: vi.fn(), + auth: "plugin", + match: "exact", + pluginId: "route-demo", + source: "test", + }); + const handler = vi.fn(); + setExtensionHostGatewayHandler({ + registry, + method: "demo.echo", + handler, + }); + + addExtensionHostChannelRegistration(registry, { + pluginId: "channel-demo", + source: "test", + plugin: { + id: "channel-demo", + meta: { + id: "channel-demo", + label: "Channel Demo", + selectionLabel: "Channel Demo", + docsPath: "/channels/channel-demo", + blurb: "demo", + }, + capabilities: { chatTypes: ["direct"] }, + config: { + listAccountIds: () => [], + resolveAccount: () => ({}), + }, + }, + }); + + expect(listExtensionHostChannelRegistrations(registry)).toEqual(registry.channels); + expect(listExtensionHostToolRegistrations(registry)).toEqual(registry.tools); + expect(listExtensionHostProviderRegistrations(registry)).toEqual(registry.providers); + expect(listExtensionHostServiceRegistrations(registry)).toEqual(registry.services); + expect(listExtensionHostCliRegistrations(registry)).toEqual(registry.cliRegistrars); + expect(listExtensionHostCommandRegistrations(registry)).toEqual(registry.commands); + expect(listExtensionHostHttpRoutes(registry)).toEqual(registry.httpRoutes); + expect(getExtensionHostGatewayHandlers(registry)).toEqual(registry.gatewayHandlers); + expect(getExtensionHostGatewayHandlers(registry)["demo.echo"]).toBe(handler); + }); + + it("keeps legacy route and gateway mirrors synchronized with host-owned state", () => { + const registry = createEmptyPluginRegistry(); + const firstHandler = vi.fn(); + const secondHandler = vi.fn(); + const entry = { + path: "/plugins/demo", + handler: firstHandler, + auth: "plugin" as const, + match: "exact" as const, + pluginId: "route-demo", + source: "test", + }; + + addExtensionHostHttpRoute(registry, entry); + setExtensionHostGatewayHandler({ + registry, + method: "demo.echo", + handler: firstHandler, + }); + replaceExtensionHostHttpRoute({ + registry, + index: 0, + entry: { ...entry, handler: secondHandler }, + }); + removeExtensionHostHttpRoute(registry, entry); + + expect(registry.httpRoutes).toHaveLength(1); + expect(registry.httpRoutes[0]?.handler).toBe(secondHandler); + expect(getExtensionHostGatewayHandlers(registry)).toEqual(registry.gatewayHandlers); + }); + + it("keeps legacy CLI and service mirrors synchronized with host-owned state", () => { + const registry = createEmptyPluginRegistry(); + const service = { + id: "svc-demo", + start: () => undefined, + }; + const register = () => undefined; + const command = { + name: "demo", + description: "Demo command", + handler: async () => ({ text: "ok" }), + }; + + addExtensionHostServiceRegistration(registry, { + pluginId: "svc-demo", + source: "test", + service, + }); + addExtensionHostCliRegistration(registry, { + pluginId: "cli-demo", + source: "test", + commands: ["demo"], + register, + }); + addExtensionHostCommandRegistration(registry, { + pluginId: "cmd-demo", + source: "test", + command, + }); + + expect(listExtensionHostServiceRegistrations(registry)).toEqual(registry.services); + expect(listExtensionHostCliRegistrations(registry)).toEqual(registry.cliRegistrars); + expect(listExtensionHostCommandRegistrations(registry)).toEqual(registry.commands); + expect(registry.services[0]?.service).toBe(service); + expect(registry.cliRegistrars[0]?.register).toBe(register); + expect(registry.commands[0]?.command).toBe(command); + }); + + it("keeps legacy tool and provider mirrors synchronized with host-owned state", () => { + const registry = createEmptyPluginRegistry(); + const factory = (() => ({}) as never) as never; + const provider = { + id: "provider-demo", + label: "Provider Demo", + auth: [], + }; + + addExtensionHostToolRegistration(registry, { + pluginId: "tool-demo", + optional: false, + source: "test", + names: ["tool_demo"], + factory, + }); + addExtensionHostProviderRegistration(registry, { + pluginId: "provider-demo", + source: "test", + provider, + }); + + expect(listExtensionHostToolRegistrations(registry)).toEqual(registry.tools); + expect(listExtensionHostProviderRegistrations(registry)).toEqual(registry.providers); + expect(registry.tools[0]?.factory).toBe(factory); + expect(registry.providers[0]?.provider).toBe(provider); + }); + + it("keeps legacy channel mirrors synchronized with host-owned state", () => { + const registry = createEmptyPluginRegistry(); + const plugin = { + id: "channel-demo", + meta: { + id: "channel-demo", + label: "Channel Demo", + selectionLabel: "Channel Demo", + docsPath: "/channels/channel-demo", + blurb: "demo", + }, + capabilities: { chatTypes: ["direct"] as const }, + config: { + listAccountIds: () => [], + resolveAccount: () => ({}), + }, + }; + + addExtensionHostChannelRegistration(registry, { + pluginId: "channel-demo", + source: "test", + plugin, + }); + + expect(listExtensionHostChannelRegistrations(registry)).toEqual(registry.channels); + expect(registry.channels[0]?.plugin).toBe(plugin); + }); +}); diff --git a/src/extension-host/contributions/runtime-registry.ts b/src/extension-host/contributions/runtime-registry.ts new file mode 100644 index 00000000000..919ab3823d9 --- /dev/null +++ b/src/extension-host/contributions/runtime-registry.ts @@ -0,0 +1,604 @@ +import type { GatewayRequestHandlers } from "../../gateway/server-methods/types.js"; +import type { + PluginChannelRegistration, + PluginCliRegistration, + PluginCommandRegistration, + PluginHttpRouteRegistration, + PluginProviderRegistration, + PluginRegistry, + PluginServiceRegistration, + PluginToolRegistration, +} from "../../plugins/registry.js"; + +const EMPTY_PROVIDERS: readonly PluginProviderRegistration[] = []; +const EMPTY_TOOLS: readonly PluginToolRegistration[] = []; +const EMPTY_CHANNELS: readonly PluginChannelRegistration[] = []; +const EMPTY_SERVICES: readonly PluginServiceRegistration[] = []; +const EMPTY_CLI_REGISTRARS: readonly PluginCliRegistration[] = []; +const EMPTY_COMMANDS: readonly PluginCommandRegistration[] = []; +const EMPTY_HTTP_ROUTES: readonly PluginHttpRouteRegistration[] = []; +const EMPTY_GATEWAY_HANDLERS: Readonly = Object.freeze({}); +const EXTENSION_HOST_RUNTIME_REGISTRY_STATE = Symbol.for("openclaw.extensionHostRuntimeRegistry"); + +type ExtensionHostRuntimeRegistryState = { + channels: PluginChannelRegistration[]; + legacyChannels: PluginChannelRegistration[]; + tools: PluginToolRegistration[]; + legacyTools: PluginToolRegistration[]; + providers: PluginProviderRegistration[]; + legacyProviders: PluginProviderRegistration[]; + cliRegistrars: PluginCliRegistration[]; + legacyCliRegistrars: PluginCliRegistration[]; + commands: PluginCommandRegistration[]; + legacyCommands: PluginCommandRegistration[]; + services: PluginServiceRegistration[]; + legacyServices: PluginServiceRegistration[]; + httpRoutes: PluginHttpRouteRegistration[]; + legacyHttpRoutes: PluginHttpRouteRegistration[]; + gatewayHandlers: GatewayRequestHandlers; + legacyGatewayHandlers: GatewayRequestHandlers; +}; + +type RuntimeRegistryBackedPluginRegistry = Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" +> & { + [EXTENSION_HOST_RUNTIME_REGISTRY_STATE]?: ExtensionHostRuntimeRegistryState; +}; + +function ensureExtensionHostRuntimeRegistryState( + registry: RuntimeRegistryBackedPluginRegistry, +): ExtensionHostRuntimeRegistryState { + const existing = registry[EXTENSION_HOST_RUNTIME_REGISTRY_STATE]; + if (existing) { + if (registry.channels !== existing.legacyChannels) { + existing.legacyChannels = registry.channels ?? []; + existing.channels = [...existing.legacyChannels]; + } + if (registry.tools !== existing.legacyTools) { + existing.legacyTools = registry.tools ?? []; + existing.tools = [...existing.legacyTools]; + } + if (registry.providers !== existing.legacyProviders) { + existing.legacyProviders = registry.providers ?? []; + existing.providers = [...existing.legacyProviders]; + } + if (registry.cliRegistrars !== existing.legacyCliRegistrars) { + existing.legacyCliRegistrars = registry.cliRegistrars ?? []; + existing.cliRegistrars = [...existing.legacyCliRegistrars]; + } + if (registry.commands !== existing.legacyCommands) { + existing.legacyCommands = registry.commands ?? []; + existing.commands = [...existing.legacyCommands]; + } + if (registry.services !== existing.legacyServices) { + existing.legacyServices = registry.services ?? []; + existing.services = [...existing.legacyServices]; + } + if (registry.httpRoutes !== existing.legacyHttpRoutes) { + existing.legacyHttpRoutes = registry.httpRoutes ?? []; + existing.httpRoutes = [...existing.legacyHttpRoutes]; + } + if (registry.gatewayHandlers !== existing.legacyGatewayHandlers) { + existing.legacyGatewayHandlers = registry.gatewayHandlers ?? {}; + existing.gatewayHandlers = { ...existing.legacyGatewayHandlers }; + } + return existing; + } + + const legacyHttpRoutes = registry.httpRoutes ?? []; + registry.httpRoutes = legacyHttpRoutes; + const legacyGatewayHandlers = registry.gatewayHandlers ?? {}; + registry.gatewayHandlers = legacyGatewayHandlers; + const legacyCliRegistrars = registry.cliRegistrars ?? []; + registry.cliRegistrars = legacyCliRegistrars; + const legacyCommands = registry.commands ?? []; + registry.commands = legacyCommands; + const legacyServices = registry.services ?? []; + registry.services = legacyServices; + const legacyChannels = registry.channels ?? []; + registry.channels = legacyChannels; + const legacyTools = registry.tools ?? []; + registry.tools = legacyTools; + const legacyProviders = registry.providers ?? []; + registry.providers = legacyProviders; + + const state: ExtensionHostRuntimeRegistryState = { + channels: [...legacyChannels], + legacyChannels, + tools: [...legacyTools], + legacyTools, + providers: [...legacyProviders], + legacyProviders, + cliRegistrars: [...legacyCliRegistrars], + legacyCliRegistrars, + commands: [...legacyCommands], + legacyCommands, + services: [...legacyServices], + legacyServices, + httpRoutes: [...legacyHttpRoutes], + legacyHttpRoutes, + gatewayHandlers: { ...legacyGatewayHandlers }, + legacyGatewayHandlers, + }; + registry[EXTENSION_HOST_RUNTIME_REGISTRY_STATE] = state; + return state; +} + +function syncLegacyChannels(state: ExtensionHostRuntimeRegistryState): void { + state.legacyChannels.splice(0, state.legacyChannels.length, ...state.channels); +} + +function syncLegacyTools(state: ExtensionHostRuntimeRegistryState): void { + state.legacyTools.splice(0, state.legacyTools.length, ...state.tools); +} + +function syncLegacyProviders(state: ExtensionHostRuntimeRegistryState): void { + state.legacyProviders.splice(0, state.legacyProviders.length, ...state.providers); +} + +function syncLegacyCliRegistrars(state: ExtensionHostRuntimeRegistryState): void { + state.legacyCliRegistrars.splice(0, state.legacyCliRegistrars.length, ...state.cliRegistrars); +} + +function syncLegacyCommands(state: ExtensionHostRuntimeRegistryState): void { + state.legacyCommands.splice(0, state.legacyCommands.length, ...state.commands); +} + +function syncLegacyServices(state: ExtensionHostRuntimeRegistryState): void { + state.legacyServices.splice(0, state.legacyServices.length, ...state.services); +} + +function syncLegacyHttpRoutes(state: ExtensionHostRuntimeRegistryState): void { + state.legacyHttpRoutes.splice(0, state.legacyHttpRoutes.length, ...state.httpRoutes); +} + +function syncLegacyGatewayHandlers(state: ExtensionHostRuntimeRegistryState): void { + for (const key of Object.keys(state.legacyGatewayHandlers)) { + if (!(key in state.gatewayHandlers)) { + delete state.legacyGatewayHandlers[key]; + } + } + Object.assign(state.legacyGatewayHandlers, state.gatewayHandlers); +} + +export function hasExtensionHostRuntimeEntries( + registry: + | Pick< + PluginRegistry, + | "plugins" + | "channels" + | "tools" + | "providers" + | "gatewayHandlers" + | "httpRoutes" + | "cliRegistrars" + | "services" + | "commands" + | "hooks" + | "typedHooks" + > + | null + | undefined, +): boolean { + if (!registry) { + return false; + } + return ( + registry.plugins.length > 0 || + listExtensionHostChannelRegistrations(registry).length > 0 || + listExtensionHostToolRegistrations(registry).length > 0 || + listExtensionHostProviderRegistrations(registry).length > 0 || + Object.keys(getExtensionHostGatewayHandlers(registry)).length > 0 || + listExtensionHostHttpRoutes(registry).length > 0 || + listExtensionHostCliRegistrations(registry).length > 0 || + listExtensionHostCommandRegistrations(registry).length > 0 || + listExtensionHostServiceRegistrations(registry).length > 0 || + registry.hooks.length > 0 || + registry.typedHooks.length > 0 + ); +} + +export function listExtensionHostProviderRegistrations( + registry: + | Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + > + | null + | undefined, +): readonly PluginProviderRegistration[] { + if (!registry) { + return EMPTY_PROVIDERS; + } + return ensureExtensionHostRuntimeRegistryState(registry as RuntimeRegistryBackedPluginRegistry) + .providers; +} + +export function listExtensionHostToolRegistrations( + registry: + | Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + > + | null + | undefined, +): readonly PluginToolRegistration[] { + if (!registry) { + return EMPTY_TOOLS; + } + return ensureExtensionHostRuntimeRegistryState(registry as RuntimeRegistryBackedPluginRegistry) + .tools; +} + +export function listExtensionHostChannelRegistrations( + registry: + | Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "services" + | "httpRoutes" + | "gatewayHandlers" + > + | null + | undefined, +): readonly PluginChannelRegistration[] { + if (!registry) { + return EMPTY_CHANNELS; + } + return ensureExtensionHostRuntimeRegistryState(registry as RuntimeRegistryBackedPluginRegistry) + .channels; +} + +export function listExtensionHostServiceRegistrations( + registry: + | Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "services" + | "httpRoutes" + | "gatewayHandlers" + > + | null + | undefined, +): readonly PluginServiceRegistration[] { + if (!registry) { + return EMPTY_SERVICES; + } + return ensureExtensionHostRuntimeRegistryState(registry as RuntimeRegistryBackedPluginRegistry) + .services; +} + +export function listExtensionHostCliRegistrations( + registry: + | Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + > + | null + | undefined, +): readonly PluginCliRegistration[] { + if (!registry) { + return EMPTY_CLI_REGISTRARS; + } + return ensureExtensionHostRuntimeRegistryState(registry as RuntimeRegistryBackedPluginRegistry) + .cliRegistrars; +} + +export function listExtensionHostCommandRegistrations( + registry: + | Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + > + | null + | undefined, +): readonly PluginCommandRegistration[] { + if (!registry) { + return EMPTY_COMMANDS; + } + return ensureExtensionHostRuntimeRegistryState(registry as RuntimeRegistryBackedPluginRegistry) + .commands; +} + +export function listExtensionHostHttpRoutes( + registry: + | Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + > + | null + | undefined, +): readonly PluginHttpRouteRegistration[] { + if (!registry) { + return EMPTY_HTTP_ROUTES; + } + return ensureExtensionHostRuntimeRegistryState(registry as RuntimeRegistryBackedPluginRegistry) + .httpRoutes; +} + +export function getExtensionHostGatewayHandlers( + registry: + | Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + > + | null + | undefined, +): Readonly { + if (!registry) { + return EMPTY_GATEWAY_HANDLERS; + } + return ensureExtensionHostRuntimeRegistryState(registry as RuntimeRegistryBackedPluginRegistry) + .gatewayHandlers; +} + +export function addExtensionHostHttpRoute( + registry: Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + >, + entry: PluginHttpRouteRegistration, +): void { + const state = ensureExtensionHostRuntimeRegistryState( + registry as RuntimeRegistryBackedPluginRegistry, + ); + state.httpRoutes.push(entry); + syncLegacyHttpRoutes(state); +} + +export function replaceExtensionHostHttpRoute(params: { + registry: Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + >; + index: number; + entry: PluginHttpRouteRegistration; +}): void { + const state = ensureExtensionHostRuntimeRegistryState( + params.registry as RuntimeRegistryBackedPluginRegistry, + ); + state.httpRoutes[params.index] = params.entry; + syncLegacyHttpRoutes(state); +} + +export function removeExtensionHostHttpRoute( + registry: Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + >, + entry: PluginHttpRouteRegistration, +): void { + const state = ensureExtensionHostRuntimeRegistryState( + registry as RuntimeRegistryBackedPluginRegistry, + ); + const index = state.httpRoutes.indexOf(entry); + if (index < 0) { + return; + } + state.httpRoutes.splice(index, 1); + syncLegacyHttpRoutes(state); +} + +export function setExtensionHostGatewayHandler(params: { + registry: Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + >; + method: string; + handler: GatewayRequestHandlers[string]; +}): void { + const state = ensureExtensionHostRuntimeRegistryState( + params.registry as RuntimeRegistryBackedPluginRegistry, + ); + state.gatewayHandlers[params.method] = params.handler; + syncLegacyGatewayHandlers(state); +} + +export function addExtensionHostCliRegistration( + registry: Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + >, + entry: PluginCliRegistration, +): void { + const state = ensureExtensionHostRuntimeRegistryState( + registry as RuntimeRegistryBackedPluginRegistry, + ); + state.cliRegistrars.push(entry); + syncLegacyCliRegistrars(state); +} + +export function addExtensionHostCommandRegistration( + registry: Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + >, + entry: PluginCommandRegistration, +): void { + const state = ensureExtensionHostRuntimeRegistryState( + registry as RuntimeRegistryBackedPluginRegistry, + ); + state.commands.push(entry); + syncLegacyCommands(state); +} + +export function addExtensionHostServiceRegistration( + registry: Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + >, + entry: PluginServiceRegistration, +): void { + const state = ensureExtensionHostRuntimeRegistryState( + registry as RuntimeRegistryBackedPluginRegistry, + ); + state.services.push(entry); + syncLegacyServices(state); +} + +export function addExtensionHostToolRegistration( + registry: Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + >, + entry: PluginToolRegistration, +): void { + const state = ensureExtensionHostRuntimeRegistryState( + registry as RuntimeRegistryBackedPluginRegistry, + ); + state.tools.push(entry); + syncLegacyTools(state); +} + +export function addExtensionHostProviderRegistration( + registry: Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + >, + entry: PluginProviderRegistration, +): void { + const state = ensureExtensionHostRuntimeRegistryState( + registry as RuntimeRegistryBackedPluginRegistry, + ); + state.providers.push(entry); + syncLegacyProviders(state); +} + +export function addExtensionHostChannelRegistration( + registry: Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "services" + | "httpRoutes" + | "gatewayHandlers" + >, + entry: PluginChannelRegistration, +): void { + const state = ensureExtensionHostRuntimeRegistryState( + registry as RuntimeRegistryBackedPluginRegistry, + ); + state.channels.push(entry); + syncLegacyChannels(state); +} diff --git a/src/extension-host/contributions/service-lifecycle.test.ts b/src/extension-host/contributions/service-lifecycle.test.ts new file mode 100644 index 00000000000..c7908d9bf33 --- /dev/null +++ b/src/extension-host/contributions/service-lifecycle.test.ts @@ -0,0 +1,127 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { createEmptyPluginRegistry } from "../../plugins/registry.js"; +import type { OpenClawPluginService, OpenClawPluginServiceContext } from "../../plugins/types.js"; + +const mockedLogger = vi.hoisted(() => ({ + info: vi.fn<(msg: string) => void>(), + warn: vi.fn<(msg: string) => void>(), + error: vi.fn<(msg: string) => void>(), + debug: vi.fn<(msg: string) => void>(), +})); + +vi.mock("../logging/subsystem.js", () => ({ + createSubsystemLogger: () => mockedLogger, +})); + +import { STATE_DIR } from "../../config/paths.js"; +import { startExtensionHostServices } from "./service-lifecycle.js"; + +function createRegistry(services: OpenClawPluginService[]) { + const registry = createEmptyPluginRegistry(); + for (const service of services) { + registry.services.push({ pluginId: "plugin:test", service, source: "test" }); + } + return registry; +} + +describe("startExtensionHostServices", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("starts services and stops them in reverse order", async () => { + const starts: string[] = []; + const stops: string[] = []; + const contexts: OpenClawPluginServiceContext[] = []; + + const serviceA: OpenClawPluginService = { + id: "service-a", + start: (ctx) => { + starts.push("a"); + contexts.push(ctx); + }, + stop: () => { + stops.push("a"); + }, + }; + const serviceB: OpenClawPluginService = { + id: "service-b", + start: (ctx) => { + starts.push("b"); + contexts.push(ctx); + }, + }; + const serviceC: OpenClawPluginService = { + id: "service-c", + start: (ctx) => { + starts.push("c"); + contexts.push(ctx); + }, + stop: () => { + stops.push("c"); + }, + }; + + const config = {} as Parameters[0]["config"]; + const handle = await startExtensionHostServices({ + registry: createRegistry([serviceA, serviceB, serviceC]), + config, + workspaceDir: "/tmp/workspace", + }); + await handle.stop(); + + expect(starts).toEqual(["a", "b", "c"]); + expect(stops).toEqual(["c", "a"]); + expect(contexts).toHaveLength(3); + for (const ctx of contexts) { + expect(ctx.config).toBe(config); + expect(ctx.workspaceDir).toBe("/tmp/workspace"); + expect(ctx.stateDir).toBe(STATE_DIR); + expect(ctx.logger).toBeDefined(); + expect(typeof ctx.logger.info).toBe("function"); + expect(typeof ctx.logger.warn).toBe("function"); + expect(typeof ctx.logger.error).toBe("function"); + } + }); + + it("logs start and stop failures and continues", async () => { + const stopOk = vi.fn(); + const stopThrows = vi.fn(() => { + throw new Error("stop failed"); + }); + + const handle = await startExtensionHostServices({ + registry: createRegistry([ + { + id: "service-start-fail", + start: () => { + throw new Error("start failed"); + }, + stop: vi.fn(), + }, + { + id: "service-ok", + start: () => undefined, + stop: stopOk, + }, + { + id: "service-stop-fail", + start: () => undefined, + stop: stopThrows, + }, + ]), + config: {} as Parameters[0]["config"], + }); + + await handle.stop(); + + expect(mockedLogger.error).toHaveBeenCalledWith( + expect.stringContaining("plugin service failed (service-start-fail):"), + ); + expect(mockedLogger.warn).toHaveBeenCalledWith( + expect.stringContaining("plugin service stop failed (service-stop-fail):"), + ); + expect(stopOk).toHaveBeenCalledOnce(); + expect(stopThrows).toHaveBeenCalledOnce(); + }); +}); diff --git a/src/extension-host/contributions/service-lifecycle.ts b/src/extension-host/contributions/service-lifecycle.ts new file mode 100644 index 00000000000..3c6d1ccf101 --- /dev/null +++ b/src/extension-host/contributions/service-lifecycle.ts @@ -0,0 +1,76 @@ +import type { OpenClawConfig } from "../../config/config.js"; +import { STATE_DIR } from "../../config/paths.js"; +import { createSubsystemLogger } from "../../logging/subsystem.js"; +import type { PluginRegistry } from "../../plugins/registry.js"; +import type { OpenClawPluginServiceContext, PluginLogger } from "../../plugins/types.js"; +import { listExtensionHostServiceRegistrations } from "./runtime-registry.js"; + +const log = createSubsystemLogger("plugins"); + +function createExtensionHostServiceLogger(): PluginLogger { + return { + info: (msg) => log.info(msg), + warn: (msg) => log.warn(msg), + error: (msg) => log.error(msg), + debug: (msg) => log.debug(msg), + }; +} + +function createExtensionHostServiceContext(params: { + config: OpenClawConfig; + workspaceDir?: string; +}): OpenClawPluginServiceContext { + return { + config: params.config, + workspaceDir: params.workspaceDir, + stateDir: STATE_DIR, + logger: createExtensionHostServiceLogger(), + }; +} + +export type ExtensionHostServicesHandle = { + stop: () => Promise; +}; + +export async function startExtensionHostServices(params: { + registry: PluginRegistry; + config: OpenClawConfig; + workspaceDir?: string; +}): Promise { + const running: Array<{ + id: string; + stop?: () => void | Promise; + }> = []; + const serviceContext = createExtensionHostServiceContext({ + config: params.config, + workspaceDir: params.workspaceDir, + }); + + for (const entry of listExtensionHostServiceRegistrations(params.registry)) { + const service = entry.service; + try { + await service.start(serviceContext); + running.push({ + id: service.id, + stop: service.stop ? () => service.stop?.(serviceContext) : undefined, + }); + } catch (err) { + log.error(`plugin service failed (${service.id}): ${String(err)}`); + } + } + + return { + stop: async () => { + for (const entry of running.toReversed()) { + if (!entry.stop) { + continue; + } + try { + await entry.stop(); + } catch (err) { + log.warn(`plugin service stop failed (${entry.id}): ${String(err)}`); + } + } + }, + }; +} diff --git a/src/extension-host/contributions/tool-runtime.test.ts b/src/extension-host/contributions/tool-runtime.test.ts new file mode 100644 index 00000000000..d16020a47f8 --- /dev/null +++ b/src/extension-host/contributions/tool-runtime.test.ts @@ -0,0 +1,124 @@ +import { describe, expect, it, vi } from "vitest"; +import type { AnyAgentTool } from "../../agents/tools/common.js"; +import { createEmptyPluginRegistry } from "../../plugins/registry.js"; +import { addExtensionHostToolRegistration } from "./runtime-registry.js"; +import { getExtensionHostPluginToolMeta, resolveExtensionHostPluginTools } from "./tool-runtime.js"; + +function makeTool(name: string): AnyAgentTool { + return { + name, + description: `${name} tool`, + parameters: { type: "object", properties: {} }, + async execute() { + return { content: [{ type: "text", text: "ok" }] }; + }, + }; +} + +function createContext() { + return { + config: { + plugins: { + enabled: true, + }, + }, + workspaceDir: "/tmp", + }; +} + +describe("resolveExtensionHostPluginTools", () => { + it("allows optional tools through tool, plugin, and plugin-group allowlists", () => { + const registry = createEmptyPluginRegistry(); + addExtensionHostToolRegistration(registry, { + pluginId: "optional-demo", + optional: true, + source: "/tmp/optional-demo.js", + factory: () => makeTool("optional_tool"), + names: ["optional_tool"], + }); + + expect( + resolveExtensionHostPluginTools({ + registry, + context: createContext() as never, + }), + ).toEqual([]); + expect( + resolveExtensionHostPluginTools({ + registry, + context: createContext() as never, + toolAllowlist: ["optional_tool"], + }).map((tool) => tool.name), + ).toEqual(["optional_tool"]); + expect( + resolveExtensionHostPluginTools({ + registry, + context: createContext() as never, + toolAllowlist: ["optional-demo"], + }).map((tool) => tool.name), + ).toEqual(["optional_tool"]); + expect( + resolveExtensionHostPluginTools({ + registry, + context: createContext() as never, + toolAllowlist: ["group:plugins"], + }).map((tool) => tool.name), + ).toEqual(["optional_tool"]); + }); + + it("records conflict diagnostics and preserves tool metadata", () => { + const registry = createEmptyPluginRegistry(); + const extraTool = makeTool("other_tool"); + addExtensionHostToolRegistration(registry, { + pluginId: "message", + optional: false, + source: "/tmp/message.js", + factory: () => makeTool("optional_tool"), + names: ["optional_tool"], + }); + addExtensionHostToolRegistration(registry, { + pluginId: "multi", + optional: false, + source: "/tmp/multi.js", + factory: () => [makeTool("message"), extraTool], + names: ["message", "other_tool"], + }); + + const tools = resolveExtensionHostPluginTools({ + registry, + context: createContext() as never, + existingToolNames: new Set(["message"]), + }); + + expect(tools.map((tool) => tool.name)).toEqual(["other_tool"]); + expect(registry.diagnostics).toHaveLength(2); + expect(registry.diagnostics[0]?.message).toContain("plugin id conflicts with core tool name"); + expect(registry.diagnostics[1]?.message).toContain("plugin tool name conflict"); + expect(getExtensionHostPluginToolMeta(extraTool)).toEqual({ + pluginId: "multi", + optional: false, + }); + }); + + it("skips tool factories that throw", () => { + const registry = createEmptyPluginRegistry(); + const factory = vi.fn(() => { + throw new Error("boom"); + }); + addExtensionHostToolRegistration(registry, { + pluginId: "broken", + optional: false, + source: "/tmp/broken.js", + factory, + names: ["broken_tool"], + }); + + expect( + resolveExtensionHostPluginTools({ + registry, + context: createContext() as never, + }), + ).toEqual([]); + expect(factory).toHaveBeenCalledOnce(); + }); +}); diff --git a/src/extension-host/contributions/tool-runtime.ts b/src/extension-host/contributions/tool-runtime.ts new file mode 100644 index 00000000000..2018be901f5 --- /dev/null +++ b/src/extension-host/contributions/tool-runtime.ts @@ -0,0 +1,138 @@ +import { normalizeToolName } from "../../agents/tool-policy.js"; +import type { AnyAgentTool } from "../../agents/tools/common.js"; +import { createSubsystemLogger } from "../../logging/subsystem.js"; +import type { PluginRegistry } from "../../plugins/registry.js"; +import type { OpenClawPluginToolContext } from "../../plugins/types.js"; +import { listExtensionHostToolRegistrations } from "./runtime-registry.js"; + +const log = createSubsystemLogger("plugins"); + +export type ExtensionHostPluginToolMeta = { + pluginId: string; + optional: boolean; +}; + +const extensionHostPluginToolMeta = new WeakMap(); + +export function getExtensionHostPluginToolMeta( + tool: AnyAgentTool, +): ExtensionHostPluginToolMeta | undefined { + return extensionHostPluginToolMeta.get(tool); +} + +function normalizeAllowlist(list?: string[]) { + return new Set((list ?? []).map(normalizeToolName).filter(Boolean)); +} + +function isOptionalToolAllowed(params: { + toolName: string; + pluginId: string; + allowlist: Set; +}): boolean { + if (params.allowlist.size === 0) { + return false; + } + const toolName = normalizeToolName(params.toolName); + if (params.allowlist.has(toolName)) { + return true; + } + const pluginKey = normalizeToolName(params.pluginId); + if (params.allowlist.has(pluginKey)) { + return true; + } + return params.allowlist.has("group:plugins"); +} + +export function resolveExtensionHostPluginTools(params: { + registry: Pick< + PluginRegistry, + | "channels" + | "tools" + | "providers" + | "cliRegistrars" + | "commands" + | "services" + | "httpRoutes" + | "gatewayHandlers" + | "diagnostics" + >; + context: OpenClawPluginToolContext; + existingToolNames?: Set; + toolAllowlist?: string[]; + suppressNameConflicts?: boolean; +}): AnyAgentTool[] { + const tools: AnyAgentTool[] = []; + const existing = params.existingToolNames ?? new Set(); + const existingNormalized = new Set(Array.from(existing, (tool) => normalizeToolName(tool))); + const allowlist = normalizeAllowlist(params.toolAllowlist); + const blockedPlugins = new Set(); + + for (const entry of listExtensionHostToolRegistrations(params.registry)) { + if (blockedPlugins.has(entry.pluginId)) { + continue; + } + const pluginIdKey = normalizeToolName(entry.pluginId); + if (existingNormalized.has(pluginIdKey)) { + const message = `plugin id conflicts with core tool name (${entry.pluginId})`; + if (!params.suppressNameConflicts) { + log.error(message); + params.registry.diagnostics.push({ + level: "error", + pluginId: entry.pluginId, + source: entry.source, + message, + }); + } + blockedPlugins.add(entry.pluginId); + continue; + } + let resolved: AnyAgentTool | AnyAgentTool[] | null | undefined = null; + try { + resolved = entry.factory(params.context); + } catch (err) { + log.error(`plugin tool failed (${entry.pluginId}): ${String(err)}`); + continue; + } + if (!resolved) { + continue; + } + const listRaw = Array.isArray(resolved) ? resolved : [resolved]; + const list = entry.optional + ? listRaw.filter((tool) => + isOptionalToolAllowed({ + toolName: tool.name, + pluginId: entry.pluginId, + allowlist, + }), + ) + : listRaw; + if (list.length === 0) { + continue; + } + const nameSet = new Set(); + for (const tool of list) { + if (nameSet.has(tool.name) || existing.has(tool.name)) { + const message = `plugin tool name conflict (${entry.pluginId}): ${tool.name}`; + if (!params.suppressNameConflicts) { + log.error(message); + params.registry.diagnostics.push({ + level: "error", + pluginId: entry.pluginId, + source: entry.source, + message, + }); + } + continue; + } + nameSet.add(tool.name); + existing.add(tool.name); + extensionHostPluginToolMeta.set(tool, { + pluginId: entry.pluginId, + optional: entry.optional, + }); + tools.push(tool); + } + } + + return tools; +} diff --git a/src/extension-host/contributions/tts-api.test.ts b/src/extension-host/contributions/tts-api.test.ts new file mode 100644 index 00000000000..97306d3633d --- /dev/null +++ b/src/extension-host/contributions/tts-api.test.ts @@ -0,0 +1,135 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { + applyExtensionHostTtsToPayload, + buildExtensionHostTtsSystemPromptHint, + runExtensionHostTextToSpeech, +} from "./tts-api.js"; + +vi.mock("./tts-config.js", () => ({ + normalizeExtensionHostTtsConfigAutoMode: vi.fn(), + resolveExtensionHostTtsConfig: vi.fn(), + resolveExtensionHostTtsModelOverridePolicy: vi.fn(), +})); + +vi.mock("./tts-preferences.js", () => ({ + getExtensionHostTtsMaxLength: vi.fn(), + isExtensionHostTtsSummarizationEnabled: vi.fn(), + resolveExtensionHostTtsAutoMode: vi.fn(), + resolveExtensionHostTtsPrefsPath: vi.fn(), +})); + +vi.mock("./tts-payload.js", () => ({ + resolveExtensionHostTtsPayloadPlan: vi.fn(), +})); + +vi.mock("./tts-runtime-setup.js", () => ({ + resolveExtensionHostTtsRequestSetup: vi.fn(), +})); + +vi.mock("./tts-runtime-execution.js", () => ({ + executeExtensionHostTextToSpeech: vi.fn(), + executeExtensionHostTextToSpeechTelephony: vi.fn(), + isExtensionHostTtsVoiceBubbleChannel: vi.fn(() => false), + resolveExtensionHostEdgeOutputFormat: vi.fn(() => "audio-24khz-48kbitrate-mono-mp3"), + resolveExtensionHostTtsOutputFormat: vi.fn(() => ({ + openai: "mp3", + elevenlabs: "mp3_44100_128", + extension: ".mp3", + voiceCompatible: false, + })), +})); + +vi.mock("./tts-status.js", () => ({ + getExtensionHostLastTtsAttempt: vi.fn(), + setExtensionHostLastTtsAttempt: vi.fn(), +})); + +describe("tts-api", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("builds the remaining system prompt hint through host-owned preferences", async () => { + const configModule = await import("./tts-config.js"); + const prefsModule = await import("./tts-preferences.js"); + + vi.mocked(configModule.resolveExtensionHostTtsConfig).mockReturnValue({} as never); + vi.mocked(prefsModule.resolveExtensionHostTtsPrefsPath).mockReturnValue("/tmp/tts.json"); + vi.mocked(prefsModule.resolveExtensionHostTtsAutoMode).mockReturnValue("inbound"); + vi.mocked(prefsModule.getExtensionHostTtsMaxLength).mockReturnValue(900); + vi.mocked(prefsModule.isExtensionHostTtsSummarizationEnabled).mockReturnValue(false); + + const hint = buildExtensionHostTtsSystemPromptHint({} as never); + + expect(hint).toContain("Voice (TTS) is enabled."); + expect(hint).toContain("Only use TTS when the user's last message includes audio/voice."); + expect(hint).toContain("Keep spoken text ≤900 chars"); + expect(hint).toContain("summary off"); + }); + + it("returns setup validation errors through the host-owned TTS API", async () => { + const configModule = await import("./tts-config.js"); + const prefsModule = await import("./tts-preferences.js"); + const setupModule = await import("./tts-runtime-setup.js"); + + vi.mocked(configModule.resolveExtensionHostTtsConfig).mockReturnValue({} as never); + vi.mocked(prefsModule.resolveExtensionHostTtsPrefsPath).mockReturnValue("/tmp/tts.json"); + vi.mocked(setupModule.resolveExtensionHostTtsRequestSetup).mockReturnValue({ + error: "Text too long (5000 chars, max 4096)", + }); + + await expect( + runExtensionHostTextToSpeech({ + text: "x".repeat(5000), + cfg: {} as never, + }), + ).resolves.toEqual({ + success: false, + error: "Text too long (5000 chars, max 4096)", + }); + }); + + it("returns the planned payload when TTS conversion fails", async () => { + const configModule = await import("./tts-config.js"); + const prefsModule = await import("./tts-preferences.js"); + const payloadModule = await import("./tts-payload.js"); + const setupModule = await import("./tts-runtime-setup.js"); + const executionModule = await import("./tts-runtime-execution.js"); + const statusModule = await import("./tts-status.js"); + + vi.mocked(configModule.resolveExtensionHostTtsConfig).mockReturnValue({} as never); + vi.mocked(prefsModule.resolveExtensionHostTtsPrefsPath).mockReturnValue("/tmp/tts.json"); + vi.mocked(payloadModule.resolveExtensionHostTtsPayloadPlan).mockResolvedValue({ + kind: "ready", + nextPayload: { text: "cleaned" }, + textForAudio: "speak this", + wasSummarized: true, + overrides: {}, + }); + vi.mocked(setupModule.resolveExtensionHostTtsRequestSetup).mockReturnValue({ + config: {} as never, + providers: ["openai"], + }); + vi.mocked(executionModule.executeExtensionHostTextToSpeech).mockResolvedValue({ + success: false, + error: "provider failed", + }); + + const result = await applyExtensionHostTtsToPayload({ + payload: { text: "original" }, + cfg: {} as never, + channel: "telegram", + kind: "final", + }); + + expect(result).toEqual({ text: "cleaned" }); + expect(statusModule.setExtensionHostLastTtsAttempt).toHaveBeenCalledWith( + expect.objectContaining({ + success: false, + textLength: "original".length, + summarized: true, + error: "provider failed", + }), + ); + }); +}); diff --git a/src/extension-host/contributions/tts-api.ts b/src/extension-host/contributions/tts-api.ts new file mode 100644 index 00000000000..25edd726a0e --- /dev/null +++ b/src/extension-host/contributions/tts-api.ts @@ -0,0 +1,169 @@ +import type { ReplyPayload } from "../../auto-reply/types.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import { logVerbose } from "../../globals.js"; +import type { TtsDirectiveOverrides, TtsResult, TtsTelephonyResult } from "../../tts/tts.js"; +import { + resolveExtensionHostTtsConfig, + resolveExtensionHostTtsModelOverridePolicy, +} from "./tts-config.js"; +import { resolveExtensionHostTtsPayloadPlan } from "./tts-payload.js"; +import { + getExtensionHostTtsMaxLength, + isExtensionHostTtsSummarizationEnabled, + resolveExtensionHostTtsAutoMode, + resolveExtensionHostTtsPrefsPath, +} from "./tts-preferences.js"; +import { + executeExtensionHostTextToSpeech, + executeExtensionHostTextToSpeechTelephony, + isExtensionHostTtsVoiceBubbleChannel, + resolveExtensionHostEdgeOutputFormat, + resolveExtensionHostTtsOutputFormat, +} from "./tts-runtime-execution.js"; +import { resolveExtensionHostTtsRequestSetup } from "./tts-runtime-setup.js"; +import { setExtensionHostLastTtsAttempt, type ExtensionHostTtsStatusEntry } from "./tts-status.js"; + +export type { ExtensionHostTtsStatusEntry }; + +export { resolveExtensionHostTtsModelOverridePolicy }; +export { resolveExtensionHostTtsOutputFormat, resolveExtensionHostEdgeOutputFormat }; + +export function buildExtensionHostTtsSystemPromptHint(cfg: OpenClawConfig): string | undefined { + const config = resolveExtensionHostTtsConfig(cfg); + const prefsPath = resolveExtensionHostTtsPrefsPath(config); + const autoMode = resolveExtensionHostTtsAutoMode({ config, prefsPath }); + if (autoMode === "off") { + return undefined; + } + const maxLength = getExtensionHostTtsMaxLength(prefsPath); + const summarize = isExtensionHostTtsSummarizationEnabled(prefsPath) ? "on" : "off"; + const autoHint = + autoMode === "inbound" + ? "Only use TTS when the user's last message includes audio/voice." + : autoMode === "tagged" + ? "Only use TTS when you include [[tts]] or [[tts:text]] tags." + : undefined; + return [ + "Voice (TTS) is enabled.", + autoHint, + `Keep spoken text ≤${maxLength} chars to avoid auto-summary (summary ${summarize}).`, + "Use [[tts:...]] and optional [[tts:text]]...[[/tts:text]] to control voice/expressiveness.", + ] + .filter(Boolean) + .join("\n"); +} + +export async function runExtensionHostTextToSpeech(params: { + text: string; + cfg: OpenClawConfig; + prefsPath?: string; + channel?: string; + overrides?: TtsDirectiveOverrides; +}): Promise { + const config = resolveExtensionHostTtsConfig(params.cfg); + const prefsPath = params.prefsPath ?? resolveExtensionHostTtsPrefsPath(config); + const setup = resolveExtensionHostTtsRequestSetup({ + text: params.text, + config, + prefsPath, + providerOverride: params.overrides?.provider, + }); + if ("error" in setup) { + return { success: false, error: setup.error }; + } + + return executeExtensionHostTextToSpeech({ + text: params.text, + config: setup.config, + providers: setup.providers, + channel: params.channel, + overrides: params.overrides, + }); +} + +export async function runExtensionHostTextToSpeechTelephony(params: { + text: string; + cfg: OpenClawConfig; + prefsPath?: string; +}): Promise { + const config = resolveExtensionHostTtsConfig(params.cfg); + const prefsPath = params.prefsPath ?? resolveExtensionHostTtsPrefsPath(config); + const setup = resolveExtensionHostTtsRequestSetup({ + text: params.text, + config, + prefsPath, + }); + if ("error" in setup) { + return { success: false, error: setup.error }; + } + + return executeExtensionHostTextToSpeechTelephony({ + text: params.text, + config: setup.config, + providers: setup.providers, + }); +} + +export async function applyExtensionHostTtsToPayload(params: { + payload: ReplyPayload; + cfg: OpenClawConfig; + channel?: string; + kind?: "tool" | "block" | "final"; + inboundAudio?: boolean; + ttsAuto?: string; +}): Promise { + const config = resolveExtensionHostTtsConfig(params.cfg); + const prefsPath = resolveExtensionHostTtsPrefsPath(config); + const plan = await resolveExtensionHostTtsPayloadPlan({ + payload: params.payload, + cfg: params.cfg, + config, + prefsPath, + kind: params.kind, + inboundAudio: params.inboundAudio, + ttsAuto: params.ttsAuto, + }); + if (plan.kind === "skip") { + return plan.payload; + } + + const ttsStart = Date.now(); + const result = await runExtensionHostTextToSpeech({ + text: plan.textForAudio, + cfg: params.cfg, + prefsPath, + channel: params.channel, + overrides: plan.overrides, + }); + + if (result.success && result.audioPath) { + setExtensionHostLastTtsAttempt({ + timestamp: Date.now(), + success: true, + textLength: (params.payload.text ?? "").length, + summarized: plan.wasSummarized, + provider: result.provider, + latencyMs: result.latencyMs, + }); + + const shouldVoice = + isExtensionHostTtsVoiceBubbleChannel(params.channel) && result.voiceCompatible === true; + return { + ...plan.nextPayload, + mediaUrl: result.audioPath, + audioAsVoice: shouldVoice || params.payload.audioAsVoice, + }; + } + + setExtensionHostLastTtsAttempt({ + timestamp: Date.now(), + success: false, + textLength: (params.payload.text ?? "").length, + summarized: plan.wasSummarized, + error: result.error, + }); + + const latency = Date.now() - ttsStart; + logVerbose(`TTS: conversion failed after ${latency}ms (${result.error ?? "unknown"}).`); + return plan.nextPayload; +} diff --git a/src/extension-host/contributions/tts-config.ts b/src/extension-host/contributions/tts-config.ts new file mode 100644 index 00000000000..f51efc5a710 --- /dev/null +++ b/src/extension-host/contributions/tts-config.ts @@ -0,0 +1,193 @@ +import type { OpenClawConfig } from "../../config/config.js"; +import { normalizeResolvedSecretInputString } from "../../config/types.secrets.js"; +import type { + TtsAutoMode, + TtsConfig, + TtsMode, + TtsModelOverrideConfig, + TtsProvider, +} from "../../config/types.tts.js"; +import { normalizeExtensionHostTtsAutoMode } from "./tts-preferences.js"; + +export const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"; + +const DEFAULT_TIMEOUT_MS = 30_000; +const DEFAULT_MAX_TEXT_LENGTH = 4096; +const DEFAULT_ELEVENLABS_BASE_URL = "https://api.elevenlabs.io"; +const DEFAULT_ELEVENLABS_VOICE_ID = "pMsXgVXv3BLzUgSXRplE"; +const DEFAULT_ELEVENLABS_MODEL_ID = "eleven_multilingual_v2"; +const DEFAULT_OPENAI_MODEL = "gpt-4o-mini-tts"; +const DEFAULT_OPENAI_VOICE = "alloy"; +const DEFAULT_EDGE_VOICE = "en-US-MichelleNeural"; +const DEFAULT_EDGE_LANG = "en-US"; +const DEFAULT_EDGE_OUTPUT_FORMAT = "audio-24khz-48kbitrate-mono-mp3"; + +const DEFAULT_ELEVENLABS_VOICE_SETTINGS = { + stability: 0.5, + similarityBoost: 0.75, + style: 0.0, + useSpeakerBoost: true, + speed: 1.0, +}; + +export type ResolvedTtsConfig = { + auto: TtsAutoMode; + mode: TtsMode; + provider: TtsProvider; + providerSource: "config" | "default"; + summaryModel?: string; + modelOverrides: ResolvedTtsModelOverrides; + elevenlabs: { + apiKey?: string; + baseUrl: string; + voiceId: string; + modelId: string; + seed?: number; + applyTextNormalization?: "auto" | "on" | "off"; + languageCode?: string; + voiceSettings: { + stability: number; + similarityBoost: number; + style: number; + useSpeakerBoost: boolean; + speed: number; + }; + }; + openai: { + apiKey?: string; + baseUrl: string; + model: string; + voice: string; + speed?: number; + instructions?: string; + }; + edge: { + enabled: boolean; + voice: string; + lang: string; + outputFormat: string; + outputFormatConfigured: boolean; + pitch?: string; + rate?: string; + volume?: string; + saveSubtitles: boolean; + proxy?: string; + timeoutMs?: number; + }; + prefsPath?: string; + maxTextLength: number; + timeoutMs: number; +}; + +export type ResolvedTtsModelOverrides = { + enabled: boolean; + allowText: boolean; + allowProvider: boolean; + allowVoice: boolean; + allowModelId: boolean; + allowVoiceSettings: boolean; + allowNormalization: boolean; + allowSeed: boolean; +}; + +export const normalizeExtensionHostTtsConfigAutoMode = normalizeExtensionHostTtsAutoMode; + +export function resolveExtensionHostTtsModelOverridePolicy( + overrides: TtsModelOverrideConfig | undefined, +): ResolvedTtsModelOverrides { + const enabled = overrides?.enabled ?? true; + if (!enabled) { + return { + enabled: false, + allowText: false, + allowProvider: false, + allowVoice: false, + allowModelId: false, + allowVoiceSettings: false, + allowNormalization: false, + allowSeed: false, + }; + } + const allow = (value: boolean | undefined, defaultValue = true) => value ?? defaultValue; + return { + enabled: true, + allowText: allow(overrides?.allowText), + allowProvider: allow(overrides?.allowProvider, false), + allowVoice: allow(overrides?.allowVoice), + allowModelId: allow(overrides?.allowModelId), + allowVoiceSettings: allow(overrides?.allowVoiceSettings), + allowNormalization: allow(overrides?.allowNormalization), + allowSeed: allow(overrides?.allowSeed), + }; +} + +export function resolveExtensionHostTtsConfig(cfg: OpenClawConfig): ResolvedTtsConfig { + const raw: TtsConfig = cfg.messages?.tts ?? {}; + const providerSource = raw.provider ? "config" : "default"; + const edgeOutputFormat = raw.edge?.outputFormat?.trim(); + const auto = + normalizeExtensionHostTtsConfigAutoMode(raw.auto) ?? (raw.enabled ? "always" : "off"); + return { + auto, + mode: raw.mode ?? "final", + provider: raw.provider ?? "edge", + providerSource, + summaryModel: raw.summaryModel?.trim() || undefined, + modelOverrides: resolveExtensionHostTtsModelOverridePolicy(raw.modelOverrides), + elevenlabs: { + apiKey: normalizeResolvedSecretInputString({ + value: raw.elevenlabs?.apiKey, + path: "messages.tts.elevenlabs.apiKey", + }), + baseUrl: raw.elevenlabs?.baseUrl?.trim() || DEFAULT_ELEVENLABS_BASE_URL, + voiceId: raw.elevenlabs?.voiceId ?? DEFAULT_ELEVENLABS_VOICE_ID, + modelId: raw.elevenlabs?.modelId ?? DEFAULT_ELEVENLABS_MODEL_ID, + seed: raw.elevenlabs?.seed, + applyTextNormalization: raw.elevenlabs?.applyTextNormalization, + languageCode: raw.elevenlabs?.languageCode, + voiceSettings: { + stability: + raw.elevenlabs?.voiceSettings?.stability ?? DEFAULT_ELEVENLABS_VOICE_SETTINGS.stability, + similarityBoost: + raw.elevenlabs?.voiceSettings?.similarityBoost ?? + DEFAULT_ELEVENLABS_VOICE_SETTINGS.similarityBoost, + style: raw.elevenlabs?.voiceSettings?.style ?? DEFAULT_ELEVENLABS_VOICE_SETTINGS.style, + useSpeakerBoost: + raw.elevenlabs?.voiceSettings?.useSpeakerBoost ?? + DEFAULT_ELEVENLABS_VOICE_SETTINGS.useSpeakerBoost, + speed: raw.elevenlabs?.voiceSettings?.speed ?? DEFAULT_ELEVENLABS_VOICE_SETTINGS.speed, + }, + }, + openai: { + apiKey: normalizeResolvedSecretInputString({ + value: raw.openai?.apiKey, + path: "messages.tts.openai.apiKey", + }), + baseUrl: ( + raw.openai?.baseUrl?.trim() || + process.env.OPENAI_TTS_BASE_URL?.trim() || + DEFAULT_OPENAI_BASE_URL + ).replace(/\/+$/, ""), + model: raw.openai?.model ?? DEFAULT_OPENAI_MODEL, + voice: raw.openai?.voice ?? DEFAULT_OPENAI_VOICE, + speed: raw.openai?.speed, + instructions: raw.openai?.instructions?.trim() || undefined, + }, + edge: { + enabled: raw.edge?.enabled ?? true, + voice: raw.edge?.voice?.trim() || DEFAULT_EDGE_VOICE, + lang: raw.edge?.lang?.trim() || DEFAULT_EDGE_LANG, + outputFormat: edgeOutputFormat || DEFAULT_EDGE_OUTPUT_FORMAT, + outputFormatConfigured: Boolean(edgeOutputFormat), + pitch: raw.edge?.pitch?.trim() || undefined, + rate: raw.edge?.rate?.trim() || undefined, + volume: raw.edge?.volume?.trim() || undefined, + saveSubtitles: raw.edge?.saveSubtitles ?? false, + proxy: raw.edge?.proxy?.trim() || undefined, + timeoutMs: raw.edge?.timeoutMs, + }, + prefsPath: raw.prefsPath, + maxTextLength: raw.maxTextLength ?? DEFAULT_MAX_TEXT_LENGTH, + timeoutMs: raw.timeoutMs ?? DEFAULT_TIMEOUT_MS, + }; +} diff --git a/src/extension-host/contributions/tts-payload.ts b/src/extension-host/contributions/tts-payload.ts new file mode 100644 index 00000000000..d02ce484f53 --- /dev/null +++ b/src/extension-host/contributions/tts-payload.ts @@ -0,0 +1,140 @@ +import type { ReplyPayload } from "../../auto-reply/types.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import { logVerbose } from "../../globals.js"; +import { stripMarkdown } from "../../line/markdown-to-line.js"; +import { parseTtsDirectives, summarizeText } from "../../tts/tts-core.js"; +import type { TtsDirectiveOverrides } from "../../tts/tts.js"; +import type { ResolvedTtsConfig } from "./tts-config.js"; +import { + getExtensionHostTtsMaxLength, + isExtensionHostTtsSummarizationEnabled, + resolveExtensionHostTtsAutoMode, +} from "./tts-preferences.js"; + +export type ExtensionHostTtsPayloadPlan = + | { + kind: "skip"; + payload: ReplyPayload; + } + | { + kind: "ready"; + nextPayload: ReplyPayload; + textForAudio: string; + wasSummarized: boolean; + overrides: TtsDirectiveOverrides; + }; + +export async function resolveExtensionHostTtsPayloadPlan(params: { + payload: ReplyPayload; + cfg: OpenClawConfig; + config: ResolvedTtsConfig; + prefsPath: string; + kind?: "tool" | "block" | "final"; + inboundAudio?: boolean; + ttsAuto?: string; +}): Promise { + const autoMode = resolveExtensionHostTtsAutoMode({ + config: params.config, + prefsPath: params.prefsPath, + sessionAuto: params.ttsAuto, + }); + if (autoMode === "off") { + return { kind: "skip", payload: params.payload }; + } + + const text = params.payload.text ?? ""; + const directives = parseTtsDirectives( + text, + params.config.modelOverrides, + params.config.openai.baseUrl, + ); + if (directives.warnings.length > 0) { + logVerbose(`TTS: ignored directive overrides (${directives.warnings.join("; ")})`); + } + + const cleanedText = directives.cleanedText; + const trimmedCleaned = cleanedText.trim(); + const visibleText = trimmedCleaned.length > 0 ? trimmedCleaned : ""; + const ttsText = directives.ttsText?.trim() || visibleText; + + const nextPayload = + visibleText === text.trim() + ? params.payload + : { + ...params.payload, + text: visibleText.length > 0 ? visibleText : undefined, + }; + + if (autoMode === "tagged" && !directives.hasDirective) { + return { kind: "skip", payload: nextPayload }; + } + if (autoMode === "inbound" && params.inboundAudio !== true) { + return { kind: "skip", payload: nextPayload }; + } + + const mode = params.config.mode ?? "final"; + if (mode === "final" && params.kind && params.kind !== "final") { + return { kind: "skip", payload: nextPayload }; + } + + if (!ttsText.trim()) { + return { kind: "skip", payload: nextPayload }; + } + if (params.payload.mediaUrl || (params.payload.mediaUrls?.length ?? 0) > 0) { + return { kind: "skip", payload: nextPayload }; + } + if (text.includes("MEDIA:")) { + return { kind: "skip", payload: nextPayload }; + } + if (ttsText.trim().length < 10) { + return { kind: "skip", payload: nextPayload }; + } + + const maxLength = getExtensionHostTtsMaxLength(params.prefsPath); + let textForAudio = ttsText.trim(); + let wasSummarized = false; + + if (textForAudio.length > maxLength) { + if (!isExtensionHostTtsSummarizationEnabled(params.prefsPath)) { + logVerbose( + `TTS: truncating long text (${textForAudio.length} > ${maxLength}), summarization disabled.`, + ); + textForAudio = `${textForAudio.slice(0, maxLength - 3)}...`; + } else { + try { + const summary = await summarizeText({ + text: textForAudio, + targetLength: maxLength, + cfg: params.cfg, + config: params.config, + timeoutMs: params.config.timeoutMs, + }); + textForAudio = summary.summary; + wasSummarized = true; + if (textForAudio.length > params.config.maxTextLength) { + logVerbose( + `TTS: summary exceeded hard limit (${textForAudio.length} > ${params.config.maxTextLength}); truncating.`, + ); + textForAudio = `${textForAudio.slice(0, params.config.maxTextLength - 3)}...`; + } + } catch (err) { + const error = err as Error; + logVerbose(`TTS: summarization failed, truncating instead: ${error.message}`); + textForAudio = `${textForAudio.slice(0, maxLength - 3)}...`; + } + } + } + + textForAudio = stripMarkdown(textForAudio).trim(); + if (textForAudio.length < 10) { + return { kind: "skip", payload: nextPayload }; + } + + return { + kind: "ready", + nextPayload, + textForAudio, + wasSummarized, + overrides: directives.overrides, + }; +} diff --git a/src/extension-host/contributions/tts-preferences.test.ts b/src/extension-host/contributions/tts-preferences.test.ts new file mode 100644 index 00000000000..835c9da012d --- /dev/null +++ b/src/extension-host/contributions/tts-preferences.test.ts @@ -0,0 +1,121 @@ +import { mkdtempSync, readFileSync, rmSync } from "node:fs"; +import os from "node:os"; +import path from "node:path"; +import { afterEach, describe, expect, it } from "vitest"; +import { withEnv } from "../../test-utils/env.js"; +import type { ResolvedTtsConfig } from "./tts-config.js"; +import { + getExtensionHostTtsMaxLength, + isExtensionHostTtsEnabled, + isExtensionHostTtsSummarizationEnabled, + resolveExtensionHostTtsAutoMode, + resolveExtensionHostTtsPrefsPath, + setExtensionHostTtsAutoMode, + setExtensionHostTtsMaxLength, + setExtensionHostTtsSummarizationEnabled, +} from "./tts-preferences.js"; + +const tempDirs: string[] = []; + +function createPrefsPath(): string { + const tempDir = mkdtempSync(path.join(os.tmpdir(), "openclaw-tts-prefs-")); + tempDirs.push(tempDir); + return path.join(tempDir, "tts.json"); +} + +function createResolvedConfig(overrides?: Partial): ResolvedTtsConfig { + return { + auto: "off", + mode: "final", + provider: "edge", + providerSource: "default", + modelOverrides: { + enabled: true, + allowText: true, + allowProvider: false, + allowVoice: true, + allowModelId: true, + allowVoiceSettings: true, + allowNormalization: true, + allowSeed: true, + }, + elevenlabs: { + baseUrl: "https://api.elevenlabs.io", + voiceId: "voice-id", + modelId: "eleven_multilingual_v2", + voiceSettings: { + stability: 0.5, + similarityBoost: 0.75, + style: 0, + useSpeakerBoost: true, + speed: 1, + }, + }, + openai: { + baseUrl: "https://api.openai.com/v1", + model: "gpt-4o-mini-tts", + voice: "alloy", + }, + edge: { + enabled: true, + voice: "en-US-MichelleNeural", + lang: "en-US", + outputFormat: "audio-24khz-48kbitrate-mono-mp3", + outputFormatConfigured: false, + saveSubtitles: false, + }, + maxTextLength: 4096, + timeoutMs: 30_000, + ...overrides, + }; +} + +afterEach(() => { + for (const tempDir of tempDirs.splice(0)) { + rmSync(tempDir, { recursive: true, force: true }); + } +}); + +describe("tts-preferences", () => { + it("prefers config prefsPath over env and default locations", () => { + const config = createResolvedConfig({ prefsPath: "~/custom-tts.json" }); + + withEnv({ OPENCLAW_TTS_PREFS: "/tmp/ignored-tts.json" }, () => { + expect(resolveExtensionHostTtsPrefsPath(config)).toContain("custom-tts.json"); + }); + }); + + it("resolves session, persisted, and config auto modes in precedence order", () => { + const prefsPath = createPrefsPath(); + const config = createResolvedConfig({ auto: "inbound" }); + + setExtensionHostTtsAutoMode(prefsPath, "tagged"); + + expect( + resolveExtensionHostTtsAutoMode({ + config, + prefsPath, + sessionAuto: "always", + }), + ).toBe("always"); + expect(resolveExtensionHostTtsAutoMode({ config, prefsPath })).toBe("tagged"); + + const persisted = JSON.parse(readFileSync(prefsPath, "utf8")) as { + tts?: { auto?: string; enabled?: boolean }; + }; + expect(persisted.tts?.auto).toBe("tagged"); + expect("enabled" in (persisted.tts ?? {})).toBe(false); + }); + + it("persists max-length and summarization preferences through the host helper", () => { + const prefsPath = createPrefsPath(); + const config = createResolvedConfig({ auto: "always" }); + + setExtensionHostTtsMaxLength(prefsPath, 900); + setExtensionHostTtsSummarizationEnabled(prefsPath, false); + + expect(getExtensionHostTtsMaxLength(prefsPath)).toBe(900); + expect(isExtensionHostTtsSummarizationEnabled(prefsPath)).toBe(false); + expect(isExtensionHostTtsEnabled(config, prefsPath)).toBe(true); + }); +}); diff --git a/src/extension-host/contributions/tts-preferences.ts b/src/extension-host/contributions/tts-preferences.ts new file mode 100644 index 00000000000..c3baac206b6 --- /dev/null +++ b/src/extension-host/contributions/tts-preferences.ts @@ -0,0 +1,162 @@ +import { randomBytes } from "node:crypto"; +import { + existsSync, + mkdirSync, + readFileSync, + renameSync, + unlinkSync, + writeFileSync, +} from "node:fs"; +import path from "node:path"; +import type { TtsAutoMode, TtsProvider } from "../../config/types.tts.js"; +import { CONFIG_DIR, resolveUserPath } from "../../utils.js"; +import type { ResolvedTtsConfig } from "./tts-config.js"; + +export const DEFAULT_EXTENSION_HOST_TTS_MAX_LENGTH = 1500; +export const DEFAULT_EXTENSION_HOST_TTS_SUMMARIZE = true; + +type TtsUserPrefs = { + tts?: { + auto?: TtsAutoMode; + enabled?: boolean; + provider?: TtsProvider; + maxLength?: number; + summarize?: boolean; + }; +}; + +function readExtensionHostTtsPrefs(prefsPath: string): TtsUserPrefs { + try { + if (!existsSync(prefsPath)) { + return {}; + } + return JSON.parse(readFileSync(prefsPath, "utf8")) as TtsUserPrefs; + } catch { + return {}; + } +} + +function atomicWriteExtensionHostTtsPrefs(filePath: string, content: string): void { + const tmpPath = `${filePath}.tmp.${Date.now()}.${randomBytes(8).toString("hex")}`; + writeFileSync(tmpPath, content, { mode: 0o600 }); + try { + renameSync(tmpPath, filePath); + } catch (err) { + try { + unlinkSync(tmpPath); + } catch {} + throw err; + } +} + +function updateExtensionHostTtsPrefs( + prefsPath: string, + update: (prefs: TtsUserPrefs) => void, +): void { + const prefs = readExtensionHostTtsPrefs(prefsPath); + update(prefs); + mkdirSync(path.dirname(prefsPath), { recursive: true }); + atomicWriteExtensionHostTtsPrefs(prefsPath, JSON.stringify(prefs, null, 2)); +} + +export function normalizeExtensionHostTtsAutoMode(value: unknown): TtsAutoMode | undefined { + if (typeof value !== "string") { + return undefined; + } + const normalized = value.trim().toLowerCase(); + return normalized === "off" || + normalized === "always" || + normalized === "inbound" || + normalized === "tagged" + ? normalized + : undefined; +} + +export function resolveExtensionHostTtsPrefsPath(config: ResolvedTtsConfig): string { + if (config.prefsPath?.trim()) { + return resolveUserPath(config.prefsPath.trim()); + } + const envPath = process.env.OPENCLAW_TTS_PREFS?.trim(); + if (envPath) { + return resolveUserPath(envPath); + } + return path.join(CONFIG_DIR, "settings", "tts.json"); +} + +function resolveExtensionHostTtsAutoModeFromPrefs(prefs: TtsUserPrefs): TtsAutoMode | undefined { + const auto = normalizeExtensionHostTtsAutoMode(prefs.tts?.auto); + if (auto) { + return auto; + } + if (typeof prefs.tts?.enabled === "boolean") { + return prefs.tts.enabled ? "always" : "off"; + } + return undefined; +} + +export function resolveExtensionHostTtsAutoMode(params: { + config: ResolvedTtsConfig; + prefsPath: string; + sessionAuto?: string; +}): TtsAutoMode { + const sessionAuto = normalizeExtensionHostTtsAutoMode(params.sessionAuto); + if (sessionAuto) { + return sessionAuto; + } + const prefsAuto = resolveExtensionHostTtsAutoModeFromPrefs( + readExtensionHostTtsPrefs(params.prefsPath), + ); + if (prefsAuto) { + return prefsAuto; + } + return params.config.auto; +} + +export function isExtensionHostTtsEnabled( + config: ResolvedTtsConfig, + prefsPath: string, + sessionAuto?: string, +): boolean { + return resolveExtensionHostTtsAutoMode({ config, prefsPath, sessionAuto }) !== "off"; +} + +export function setExtensionHostTtsAutoMode(prefsPath: string, mode: TtsAutoMode): void { + updateExtensionHostTtsPrefs(prefsPath, (prefs) => { + const next = { ...prefs.tts }; + delete next.enabled; + next.auto = mode; + prefs.tts = next; + }); +} + +export function setExtensionHostTtsEnabled(prefsPath: string, enabled: boolean): void { + setExtensionHostTtsAutoMode(prefsPath, enabled ? "always" : "off"); +} + +export function setExtensionHostTtsProvider(prefsPath: string, provider: TtsProvider): void { + updateExtensionHostTtsPrefs(prefsPath, (prefs) => { + prefs.tts = { ...prefs.tts, provider }; + }); +} + +export function getExtensionHostTtsMaxLength(prefsPath: string): number { + const prefs = readExtensionHostTtsPrefs(prefsPath); + return prefs.tts?.maxLength ?? DEFAULT_EXTENSION_HOST_TTS_MAX_LENGTH; +} + +export function setExtensionHostTtsMaxLength(prefsPath: string, maxLength: number): void { + updateExtensionHostTtsPrefs(prefsPath, (prefs) => { + prefs.tts = { ...prefs.tts, maxLength }; + }); +} + +export function isExtensionHostTtsSummarizationEnabled(prefsPath: string): boolean { + const prefs = readExtensionHostTtsPrefs(prefsPath); + return prefs.tts?.summarize ?? DEFAULT_EXTENSION_HOST_TTS_SUMMARIZE; +} + +export function setExtensionHostTtsSummarizationEnabled(prefsPath: string, enabled: boolean): void { + updateExtensionHostTtsPrefs(prefsPath, (prefs) => { + prefs.tts = { ...prefs.tts, summarize: enabled }; + }); +} diff --git a/src/extension-host/contributions/tts-runtime-execution.ts b/src/extension-host/contributions/tts-runtime-execution.ts new file mode 100644 index 00000000000..f45718808d0 --- /dev/null +++ b/src/extension-host/contributions/tts-runtime-execution.ts @@ -0,0 +1,313 @@ +import { mkdirSync, mkdtempSync, rmSync, writeFileSync } from "node:fs"; +import path from "node:path"; +import type { TtsProvider } from "../../config/types.tts.js"; +import { logVerbose } from "../../globals.js"; +import { resolvePreferredOpenClawTmpDir } from "../../infra/tmp-openclaw-dir.js"; +import { isVoiceCompatibleAudio } from "../../media/audio.js"; +import { + edgeTTS, + elevenLabsTTS, + inferEdgeExtension, + openaiTTS, + scheduleCleanup, +} from "../../tts/tts-core.js"; +import type { TtsDirectiveOverrides, TtsResult, TtsTelephonyResult } from "../../tts/tts.js"; +import type { ResolvedTtsConfig } from "./tts-config.js"; +import { + resolveExtensionHostTtsApiKey, + supportsExtensionHostTtsTelephony, +} from "./tts-runtime-registry.js"; + +const TELEGRAM_OUTPUT: ExtensionHostTtsOutputFormat = { + openai: "opus" as const, + // ElevenLabs output formats use codec_sample_rate_bitrate naming. + // Opus @ 48kHz/64kbps is a good voice-note tradeoff for Telegram. + elevenlabs: "opus_48000_64", + extension: ".opus", + voiceCompatible: true, +}; + +const DEFAULT_OUTPUT: ExtensionHostTtsOutputFormat = { + openai: "mp3" as const, + elevenlabs: "mp3_44100_128", + extension: ".mp3", + voiceCompatible: false, +}; + +const TELEPHONY_OUTPUT = { + openai: { format: "pcm" as const, sampleRate: 24000 }, + elevenlabs: { format: "pcm_22050", sampleRate: 22050 }, +}; + +const DEFAULT_EDGE_OUTPUT_FORMAT = "audio-24khz-48kbitrate-mono-mp3"; + +const VOICE_BUBBLE_CHANNELS = new Set(["telegram", "feishu", "whatsapp"]); + +type ExtensionHostTtsOutputFormat = { + openai: "opus" | "mp3"; + elevenlabs: string; + extension: ".opus" | ".mp3"; + voiceCompatible: boolean; +}; + +export function isExtensionHostTtsVoiceBubbleChannel(channel?: string | null): boolean { + const channelId = channel?.trim().toLowerCase(); + return typeof channelId === "string" && VOICE_BUBBLE_CHANNELS.has(channelId); +} + +export function resolveExtensionHostTtsOutputFormat( + channel?: string | null, +): ExtensionHostTtsOutputFormat { + if (isExtensionHostTtsVoiceBubbleChannel(channel)) { + return TELEGRAM_OUTPUT; + } + return DEFAULT_OUTPUT; +} + +export function resolveExtensionHostEdgeOutputFormat(config: ResolvedTtsConfig): string { + return config.edge.outputFormat || DEFAULT_EDGE_OUTPUT_FORMAT; +} + +export function formatExtensionHostTtsProviderError(provider: TtsProvider, err: unknown): string { + const error = err instanceof Error ? err : new Error(String(err)); + if (error.name === "AbortError") { + return `${provider}: request timed out`; + } + return `${provider}: ${error.message}`; +} + +export function buildExtensionHostTtsFailureResult(errors: string[]): { + success: false; + error: string; +} { + return { + success: false, + error: `TTS conversion failed: ${errors.join("; ") || "no providers available"}`, + }; +} + +export async function executeExtensionHostTextToSpeech(params: { + text: string; + config: ResolvedTtsConfig; + providers: TtsProvider[]; + channel?: string; + overrides?: TtsDirectiveOverrides; +}): Promise { + const { config, providers } = params; + const output = resolveExtensionHostTtsOutputFormat(params.channel); + const errors: string[] = []; + + for (const provider of providers) { + const providerStart = Date.now(); + try { + if (provider === "edge") { + if (!config.edge.enabled) { + errors.push("edge: disabled"); + continue; + } + + const tempRoot = resolvePreferredOpenClawTmpDir(); + mkdirSync(tempRoot, { recursive: true, mode: 0o700 }); + const tempDir = mkdtempSync(path.join(tempRoot, "tts-")); + let edgeOutputFormat = resolveExtensionHostEdgeOutputFormat(config); + const fallbackEdgeOutputFormat = + edgeOutputFormat !== DEFAULT_EDGE_OUTPUT_FORMAT ? DEFAULT_EDGE_OUTPUT_FORMAT : undefined; + + const attemptEdgeTts = async (outputFormat: string) => { + const extension = inferEdgeExtension(outputFormat); + const audioPath = path.join(tempDir, `voice-${Date.now()}${extension}`); + await edgeTTS({ + text: params.text, + outputPath: audioPath, + config: { + ...config.edge, + outputFormat, + }, + timeoutMs: config.timeoutMs, + }); + return { audioPath, outputFormat }; + }; + + let edgeResult: { audioPath: string; outputFormat: string }; + try { + edgeResult = await attemptEdgeTts(edgeOutputFormat); + } catch (err) { + if (fallbackEdgeOutputFormat && fallbackEdgeOutputFormat !== edgeOutputFormat) { + logVerbose( + `TTS: Edge output ${edgeOutputFormat} failed; retrying with ${fallbackEdgeOutputFormat}.`, + ); + edgeOutputFormat = fallbackEdgeOutputFormat; + try { + edgeResult = await attemptEdgeTts(edgeOutputFormat); + } catch (fallbackErr) { + try { + rmSync(tempDir, { recursive: true, force: true }); + } catch {} + throw fallbackErr; + } + } else { + try { + rmSync(tempDir, { recursive: true, force: true }); + } catch {} + throw err; + } + } + + scheduleCleanup(tempDir); + const voiceCompatible = isVoiceCompatibleAudio({ fileName: edgeResult.audioPath }); + + return { + success: true, + audioPath: edgeResult.audioPath, + latencyMs: Date.now() - providerStart, + provider, + outputFormat: edgeResult.outputFormat, + voiceCompatible, + }; + } + + const apiKey = resolveExtensionHostTtsApiKey(config, provider); + if (!apiKey) { + errors.push(`${provider}: no API key`); + continue; + } + + let audioBuffer: Buffer; + if (provider === "elevenlabs") { + const voiceIdOverride = params.overrides?.elevenlabs?.voiceId; + const modelIdOverride = params.overrides?.elevenlabs?.modelId; + const voiceSettings = { + ...config.elevenlabs.voiceSettings, + ...params.overrides?.elevenlabs?.voiceSettings, + }; + const seedOverride = params.overrides?.elevenlabs?.seed; + const normalizationOverride = params.overrides?.elevenlabs?.applyTextNormalization; + const languageOverride = params.overrides?.elevenlabs?.languageCode; + audioBuffer = await elevenLabsTTS({ + text: params.text, + apiKey, + baseUrl: config.elevenlabs.baseUrl, + voiceId: voiceIdOverride ?? config.elevenlabs.voiceId, + modelId: modelIdOverride ?? config.elevenlabs.modelId, + outputFormat: output.elevenlabs, + seed: seedOverride ?? config.elevenlabs.seed, + applyTextNormalization: normalizationOverride ?? config.elevenlabs.applyTextNormalization, + languageCode: languageOverride ?? config.elevenlabs.languageCode, + voiceSettings, + timeoutMs: config.timeoutMs, + }); + } else { + const openaiModelOverride = params.overrides?.openai?.model; + const openaiVoiceOverride = params.overrides?.openai?.voice; + audioBuffer = await openaiTTS({ + text: params.text, + apiKey, + baseUrl: config.openai.baseUrl, + model: openaiModelOverride ?? config.openai.model, + voice: openaiVoiceOverride ?? config.openai.voice, + speed: config.openai.speed, + instructions: config.openai.instructions, + responseFormat: output.openai, + timeoutMs: config.timeoutMs, + }); + } + + const tempRoot = resolvePreferredOpenClawTmpDir(); + mkdirSync(tempRoot, { recursive: true, mode: 0o700 }); + const tempDir = mkdtempSync(path.join(tempRoot, "tts-")); + const audioPath = path.join(tempDir, `voice-${Date.now()}${output.extension}`); + writeFileSync(audioPath, audioBuffer); + scheduleCleanup(tempDir); + + return { + success: true, + audioPath, + latencyMs: Date.now() - providerStart, + provider, + outputFormat: provider === "openai" ? output.openai : output.elevenlabs, + voiceCompatible: output.voiceCompatible, + }; + } catch (err) { + errors.push(formatExtensionHostTtsProviderError(provider, err)); + } + } + + return buildExtensionHostTtsFailureResult(errors); +} + +export async function executeExtensionHostTextToSpeechTelephony(params: { + text: string; + config: ResolvedTtsConfig; + providers: TtsProvider[]; +}): Promise { + const { config, providers } = params; + const errors: string[] = []; + + for (const provider of providers) { + const providerStart = Date.now(); + try { + if (!supportsExtensionHostTtsTelephony(provider)) { + errors.push("edge: unsupported for telephony"); + continue; + } + + const apiKey = resolveExtensionHostTtsApiKey(config, provider); + if (!apiKey) { + errors.push(`${provider}: no API key`); + continue; + } + + if (provider === "elevenlabs") { + const output = TELEPHONY_OUTPUT.elevenlabs; + const audioBuffer = await elevenLabsTTS({ + text: params.text, + apiKey, + baseUrl: config.elevenlabs.baseUrl, + voiceId: config.elevenlabs.voiceId, + modelId: config.elevenlabs.modelId, + outputFormat: output.format, + seed: config.elevenlabs.seed, + applyTextNormalization: config.elevenlabs.applyTextNormalization, + languageCode: config.elevenlabs.languageCode, + voiceSettings: config.elevenlabs.voiceSettings, + timeoutMs: config.timeoutMs, + }); + + return { + success: true, + audioBuffer, + latencyMs: Date.now() - providerStart, + provider, + outputFormat: output.format, + sampleRate: output.sampleRate, + }; + } + + const output = TELEPHONY_OUTPUT.openai; + const audioBuffer = await openaiTTS({ + text: params.text, + apiKey, + baseUrl: config.openai.baseUrl, + model: config.openai.model, + voice: config.openai.voice, + speed: config.openai.speed, + instructions: config.openai.instructions, + responseFormat: output.format, + timeoutMs: config.timeoutMs, + }); + + return { + success: true, + audioBuffer, + latencyMs: Date.now() - providerStart, + provider, + outputFormat: output.format, + sampleRate: output.sampleRate, + }; + } catch (err) { + errors.push(formatExtensionHostTtsProviderError(provider, err)); + } + } + + return buildExtensionHostTtsFailureResult(errors); +} diff --git a/src/extension-host/contributions/tts-runtime-registry.test.ts b/src/extension-host/contributions/tts-runtime-registry.test.ts new file mode 100644 index 00000000000..3445227ef07 --- /dev/null +++ b/src/extension-host/contributions/tts-runtime-registry.test.ts @@ -0,0 +1,52 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; +import { + EXTENSION_HOST_TTS_PROVIDER_IDS, + isExtensionHostTtsProviderConfigured, + resolveExtensionHostTtsApiKey, + resolveExtensionHostTtsProviderOrder, + supportsExtensionHostTtsTelephony, +} from "./tts-runtime-registry.js"; + +describe("extension host TTS runtime registry", () => { + afterEach(() => { + vi.unstubAllEnvs(); + }); + + it("keeps the built-in provider order stable", () => { + expect(EXTENSION_HOST_TTS_PROVIDER_IDS).toEqual(["openai", "elevenlabs", "edge"]); + expect(resolveExtensionHostTtsProviderOrder("edge")).toEqual(["edge", "openai", "elevenlabs"]); + }); + + it("resolves API keys for remote providers", () => { + const config = { + openai: { apiKey: "openai-key" }, + elevenlabs: { apiKey: "xi-key" }, + edge: { enabled: false }, + } as const; + + expect(resolveExtensionHostTtsApiKey(config, "openai")).toBe("openai-key"); + expect(resolveExtensionHostTtsApiKey(config, "elevenlabs")).toBe("xi-key"); + expect(resolveExtensionHostTtsApiKey(config, "edge")).toBeUndefined(); + }); + + it("checks provider configuration through the host-owned definitions", () => { + vi.stubEnv("ELEVENLABS_API_KEY", ""); + vi.stubEnv("XI_API_KEY", ""); + + const config = { + openai: { apiKey: "openai-key" }, + elevenlabs: { apiKey: "" }, + edge: { enabled: true }, + } as const; + + expect(isExtensionHostTtsProviderConfigured(config, "openai")).toBe(true); + expect(isExtensionHostTtsProviderConfigured(config, "elevenlabs")).toBe(false); + expect(isExtensionHostTtsProviderConfigured(config, "edge")).toBe(true); + }); + + it("tracks telephony support per provider", () => { + expect(supportsExtensionHostTtsTelephony("openai")).toBe(true); + expect(supportsExtensionHostTtsTelephony("elevenlabs")).toBe(true); + expect(supportsExtensionHostTtsTelephony("edge")).toBe(false); + }); +}); diff --git a/src/extension-host/contributions/tts-runtime-registry.ts b/src/extension-host/contributions/tts-runtime-registry.ts new file mode 100644 index 00000000000..82a521b1426 --- /dev/null +++ b/src/extension-host/contributions/tts-runtime-registry.ts @@ -0,0 +1,45 @@ +import type { TtsProvider } from "../../config/types.tts.js"; +import { resolveExtensionHostTtsRuntimeBackendOrder } from "../static/runtime-backend-catalog.js"; +import { + EXTENSION_HOST_TTS_RUNTIME_BACKEND_IDS, + getExtensionHostTtsRuntimeBackend, + listExtensionHostTtsRuntimeBackends, + type ExtensionHostTtsRuntimeBackend, +} from "../static/tts-runtime-backends.js"; +import type { ResolvedTtsConfig } from "./tts-config.js"; + +export type ExtensionHostTtsRuntimeProvider = ExtensionHostTtsRuntimeBackend; + +export const EXTENSION_HOST_TTS_PROVIDER_IDS = EXTENSION_HOST_TTS_RUNTIME_BACKEND_IDS; + +export function listExtensionHostTtsRuntimeProviders(): readonly ExtensionHostTtsRuntimeProvider[] { + return listExtensionHostTtsRuntimeBackends(); +} + +export function getExtensionHostTtsRuntimeProvider( + id: TtsProvider, +): ExtensionHostTtsRuntimeProvider | undefined { + return getExtensionHostTtsRuntimeBackend(id); +} + +export function resolveExtensionHostTtsApiKey( + config: ResolvedTtsConfig, + provider: TtsProvider, +): string | undefined { + return getExtensionHostTtsRuntimeProvider(provider)?.resolveApiKey(config); +} + +export function isExtensionHostTtsProviderConfigured( + config: ResolvedTtsConfig, + provider: TtsProvider, +): boolean { + return getExtensionHostTtsRuntimeProvider(provider)?.isConfigured(config) ?? false; +} + +export function resolveExtensionHostTtsProviderOrder(primary: TtsProvider): TtsProvider[] { + return [...resolveExtensionHostTtsRuntimeBackendOrder(primary)]; +} + +export function supportsExtensionHostTtsTelephony(provider: TtsProvider): boolean { + return getExtensionHostTtsRuntimeProvider(provider)?.supportsTelephony ?? false; +} diff --git a/src/extension-host/contributions/tts-runtime-setup.test.ts b/src/extension-host/contributions/tts-runtime-setup.test.ts new file mode 100644 index 00000000000..24669fcce1d --- /dev/null +++ b/src/extension-host/contributions/tts-runtime-setup.test.ts @@ -0,0 +1,168 @@ +import { mkdtempSync, rmSync, writeFileSync } from "node:fs"; +import os from "node:os"; +import path from "node:path"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { withEnv } from "../../test-utils/env.js"; +import type { ResolvedTtsConfig } from "./tts-config.js"; +import { + resolveExtensionHostTtsProvider, + resolveExtensionHostTtsRequestSetup, +} from "./tts-runtime-setup.js"; + +vi.mock("./runtime-backend-catalog.js", () => ({ + resolveExtensionHostTtsRuntimeBackendOrder: vi.fn((provider: string) => + [provider, "openai", "elevenlabs", "edge"].filter( + (candidate, index, items) => items.indexOf(candidate) === index, + ), + ), + listExtensionHostTtsRuntimeBackendCatalogEntries: vi.fn(() => [ + { + id: "capability.runtime-backend:tts:openai", + family: "capability.runtime-backend", + subsystemId: "tts", + backendId: "openai", + source: "builtin", + defaultRank: 0, + selectorKeys: ["openai"], + capabilities: ["tts.synthesis", "tts.telephony"], + }, + { + id: "capability.runtime-backend:tts:elevenlabs", + family: "capability.runtime-backend", + subsystemId: "tts", + backendId: "elevenlabs", + source: "builtin", + defaultRank: 1, + selectorKeys: ["elevenlabs"], + capabilities: ["tts.synthesis", "tts.telephony"], + }, + { + id: "capability.runtime-backend:tts:edge", + family: "capability.runtime-backend", + subsystemId: "tts", + backendId: "edge", + source: "builtin", + defaultRank: 2, + selectorKeys: ["edge"], + capabilities: ["tts.synthesis"], + }, + ]), +})); + +const tempDirs: string[] = []; + +function createPrefsPath(contents: object): string { + const tempDir = mkdtempSync(path.join(os.tmpdir(), "openclaw-tts-setup-")); + tempDirs.push(tempDir); + const prefsPath = path.join(tempDir, "tts.json"); + writeFileSync(prefsPath, JSON.stringify(contents), "utf8"); + return prefsPath; +} + +function createResolvedConfig(overrides?: Partial): ResolvedTtsConfig { + return { + auto: "off", + mode: "final", + provider: "edge", + providerSource: "default", + modelOverrides: { + enabled: true, + allowText: true, + allowProvider: false, + allowVoice: true, + allowModelId: true, + allowVoiceSettings: true, + allowNormalization: true, + allowSeed: true, + }, + elevenlabs: { + baseUrl: "https://api.elevenlabs.io", + voiceId: "voice-id", + modelId: "eleven_multilingual_v2", + voiceSettings: { + stability: 0.5, + similarityBoost: 0.75, + style: 0, + useSpeakerBoost: true, + speed: 1, + }, + }, + openai: { + baseUrl: "https://api.openai.com/v1", + model: "gpt-4o-mini-tts", + voice: "alloy", + }, + edge: { + enabled: true, + voice: "en-US-MichelleNeural", + lang: "en-US", + outputFormat: "audio-24khz-48kbitrate-mono-mp3", + outputFormatConfigured: false, + saveSubtitles: false, + }, + maxTextLength: 4096, + timeoutMs: 30_000, + ...overrides, + }; +} + +afterEach(() => { + for (const tempDir of tempDirs.splice(0)) { + rmSync(tempDir, { recursive: true, force: true }); + } +}); + +describe("tts-runtime-setup", () => { + it("prefers the stored provider over config and environment", () => { + const prefsPath = createPrefsPath({ tts: { provider: "elevenlabs" } }); + const config = createResolvedConfig({ + provider: "openai", + providerSource: "config", + openai: { + baseUrl: "https://api.openai.com/v1", + model: "gpt-4o-mini-tts", + voice: "alloy", + apiKey: "config-openai-key", + }, + }); + + withEnv({ OPENAI_API_KEY: "env-openai-key", ELEVENLABS_API_KEY: undefined }, () => { + expect(resolveExtensionHostTtsProvider(config, prefsPath)).toBe("elevenlabs"); + }); + }); + + it("returns a validation error when text exceeds the configured hard limit", () => { + const config = createResolvedConfig({ maxTextLength: 5 }); + const prefsPath = createPrefsPath({}); + + expect( + resolveExtensionHostTtsRequestSetup({ + text: "too-long", + config, + prefsPath, + }), + ).toEqual({ + error: "Text too long (8 chars, max 5)", + }); + }); + + it("uses the override provider to build the host-owned configured fallback order", () => { + const config = createResolvedConfig({ + provider: "edge", + providerSource: "config", + }); + const prefsPath = createPrefsPath({}); + + expect( + resolveExtensionHostTtsRequestSetup({ + text: "hello world", + config, + prefsPath, + providerOverride: "elevenlabs", + }), + ).toEqual({ + config, + providers: ["elevenlabs", "edge"], + }); + }); +}); diff --git a/src/extension-host/contributions/tts-runtime-setup.ts b/src/extension-host/contributions/tts-runtime-setup.ts new file mode 100644 index 00000000000..4cb951f5813 --- /dev/null +++ b/src/extension-host/contributions/tts-runtime-setup.ts @@ -0,0 +1,73 @@ +import { existsSync, readFileSync } from "node:fs"; +import type { TtsProvider } from "../../config/types.tts.js"; +import { + resolveExtensionHostDefaultTtsProvider, + resolveExtensionHostTtsFallbackProviders, +} from "../policy/tts-runtime-policy.js"; +import type { ResolvedTtsConfig } from "./tts-config.js"; + +type TtsUserPrefs = { + tts?: { + provider?: TtsProvider; + }; +}; + +function readExtensionHostTtsPrefs(prefsPath: string): TtsUserPrefs { + try { + if (!existsSync(prefsPath)) { + return {}; + } + const raw = readFileSync(prefsPath, "utf8"); + const parsed = JSON.parse(raw) as TtsUserPrefs; + return parsed && typeof parsed === "object" ? parsed : {}; + } catch { + return {}; + } +} + +export function resolveExtensionHostTtsProvider( + config: ResolvedTtsConfig, + prefsPath: string, +): TtsProvider { + const prefs = readExtensionHostTtsPrefs(prefsPath); + if (prefs.tts?.provider) { + return prefs.tts.provider; + } + if (config.providerSource === "config") { + return config.provider; + } + + return resolveExtensionHostDefaultTtsProvider(config); +} + +export function resolveExtensionHostTtsRequestSetup(params: { + text: string; + config: ResolvedTtsConfig; + prefsPath: string; + providerOverride?: TtsProvider; +}): + | { + config: ResolvedTtsConfig; + providers: TtsProvider[]; + } + | { + error: string; + } { + if (params.text.length > params.config.maxTextLength) { + return { + error: `Text too long (${params.text.length} chars, max ${params.config.maxTextLength})`, + }; + } + + const provider = + params.providerOverride ?? resolveExtensionHostTtsProvider(params.config, params.prefsPath); + return { + config: params.config, + providers: [ + ...resolveExtensionHostTtsFallbackProviders({ + config: params.config, + preferredProvider: provider, + }), + ], + }; +} diff --git a/src/extension-host/contributions/tts-status.test.ts b/src/extension-host/contributions/tts-status.test.ts new file mode 100644 index 00000000000..16644c13e6a --- /dev/null +++ b/src/extension-host/contributions/tts-status.test.ts @@ -0,0 +1,157 @@ +import { describe, expect, it, vi } from "vitest"; +import { + formatExtensionHostTtsStatusText, + resolveExtensionHostTtsStatusSnapshot, + setExtensionHostLastTtsAttempt, +} from "./tts-status.js"; + +vi.mock("./runtime-backend-catalog.js", () => ({ + resolveExtensionHostTtsRuntimeBackendOrder: vi.fn((provider: string) => + [provider, "openai", "elevenlabs", "edge"].filter( + (candidate, index, items) => items.indexOf(candidate) === index, + ), + ), + listExtensionHostTtsRuntimeBackendCatalogEntries: vi.fn(() => [ + { + id: "capability.runtime-backend:tts:openai", + family: "capability.runtime-backend", + subsystemId: "tts", + backendId: "openai", + source: "builtin", + defaultRank: 0, + selectorKeys: ["openai"], + capabilities: ["tts.synthesis", "tts.telephony"], + }, + { + id: "capability.runtime-backend:tts:elevenlabs", + family: "capability.runtime-backend", + subsystemId: "tts", + backendId: "elevenlabs", + source: "builtin", + defaultRank: 1, + selectorKeys: ["elevenlabs"], + capabilities: ["tts.synthesis", "tts.telephony"], + }, + { + id: "capability.runtime-backend:tts:edge", + family: "capability.runtime-backend", + subsystemId: "tts", + backendId: "edge", + source: "builtin", + defaultRank: 2, + selectorKeys: ["edge"], + capabilities: ["tts.synthesis"], + }, + ]), +})); + +describe("tts-status", () => { + it("builds a status snapshot from host-owned preferences and runtime state", () => { + const config = { + auto: "always", + provider: "openai", + providerSource: "config", + prefsPath: "/tmp/tts-status.json", + modelOverrides: { + enabled: true, + allowText: true, + allowProvider: false, + allowVoice: true, + allowModelId: true, + allowVoiceSettings: true, + allowNormalization: true, + allowSeed: true, + }, + elevenlabs: { + apiKey: undefined, + baseUrl: "https://api.elevenlabs.io", + voiceId: "voice-id", + modelId: "eleven_multilingual_v2", + voiceSettings: { + stability: 0.5, + similarityBoost: 0.75, + style: 0, + useSpeakerBoost: true, + speed: 1, + }, + }, + openai: { + apiKey: "openai-key", + baseUrl: "https://api.openai.com/v1", + model: "gpt-4o-mini-tts", + voice: "alloy", + }, + edge: { + enabled: true, + voice: "en-US-MichelleNeural", + lang: "en-US", + outputFormat: "audio-24khz-48kbitrate-mono-mp3", + outputFormatConfigured: false, + saveSubtitles: false, + }, + mode: "final", + maxTextLength: 4096, + timeoutMs: 30000, + }; + + const status = resolveExtensionHostTtsStatusSnapshot({ + config, + prefsPath: "/tmp/tts-status.json", + }); + + expect(status).toMatchObject({ + enabled: true, + auto: "always", + provider: "openai", + providerConfigured: true, + hasOpenAIKey: true, + edgeEnabled: true, + maxLength: 1500, + summarize: true, + }); + expect(status.fallbackProviders.length).toBeGreaterThan(0); + expect(status.fallbackProviders).toContain(status.fallbackProvider); + }); + + it("formats the last attempt details in the host-owned status text", () => { + setExtensionHostLastTtsAttempt({ + timestamp: 1000, + success: false, + textLength: 42, + summarized: true, + error: "provider failed", + }); + + const text = formatExtensionHostTtsStatusText( + { + enabled: true, + auto: "always", + provider: "openai", + providerConfigured: true, + fallbackProvider: "edge", + fallbackProviders: ["edge"], + prefsPath: "/tmp/tts-status.json", + maxLength: 1500, + summarize: true, + hasOpenAIKey: true, + hasElevenLabsKey: false, + edgeEnabled: true, + lastAttempt: { + timestamp: 1000, + success: false, + textLength: 42, + summarized: true, + error: "provider failed", + }, + }, + 6000, + ); + + expect(text).toContain("📊 TTS status"); + expect(text).toContain("Last attempt (5s ago): ❌"); + expect(text).toContain("Text: 42 chars (summarized)"); + expect(text).toContain("Error: provider failed"); + + setExtensionHostLastTtsAttempt(undefined); + }); +}); diff --git a/src/extension-host/contributions/tts-status.ts b/src/extension-host/contributions/tts-status.ts new file mode 100644 index 00000000000..1617c58fd6b --- /dev/null +++ b/src/extension-host/contributions/tts-status.ts @@ -0,0 +1,109 @@ +import type { TtsProvider } from "../../config/types.tts.js"; +import { resolveExtensionHostTtsFallbackProviders } from "../policy/tts-runtime-policy.js"; +import type { ResolvedTtsConfig } from "./tts-config.js"; +import { + getExtensionHostTtsMaxLength, + isExtensionHostTtsEnabled, + isExtensionHostTtsSummarizationEnabled, + resolveExtensionHostTtsAutoMode, +} from "./tts-preferences.js"; +import { + isExtensionHostTtsProviderConfigured, + resolveExtensionHostTtsApiKey, +} from "./tts-runtime-registry.js"; +import { resolveExtensionHostTtsProvider } from "./tts-runtime-setup.js"; + +export type ExtensionHostTtsStatusEntry = { + timestamp: number; + success: boolean; + textLength: number; + summarized: boolean; + provider?: string; + latencyMs?: number; + error?: string; +}; + +export type ExtensionHostTtsStatusSnapshot = { + enabled: boolean; + auto: ReturnType; + provider: TtsProvider; + providerConfigured: boolean; + fallbackProvider: TtsProvider | null; + fallbackProviders: TtsProvider[]; + prefsPath: string; + maxLength: number; + summarize: boolean; + hasOpenAIKey: boolean; + hasElevenLabsKey: boolean; + edgeEnabled: boolean; + lastAttempt?: ExtensionHostTtsStatusEntry; +}; + +let lastExtensionHostTtsAttempt: ExtensionHostTtsStatusEntry | undefined; + +export function getExtensionHostLastTtsAttempt(): ExtensionHostTtsStatusEntry | undefined { + return lastExtensionHostTtsAttempt; +} + +export function setExtensionHostLastTtsAttempt( + entry: ExtensionHostTtsStatusEntry | undefined, +): void { + lastExtensionHostTtsAttempt = entry; +} + +export function resolveExtensionHostTtsStatusSnapshot(params: { + config: ResolvedTtsConfig; + prefsPath: string; +}): ExtensionHostTtsStatusSnapshot { + const { config, prefsPath } = params; + const provider = resolveExtensionHostTtsProvider(config, prefsPath); + const fallbackProviders = resolveExtensionHostTtsFallbackProviders({ + config, + preferredProvider: provider, + }).slice(1); + return { + enabled: isExtensionHostTtsEnabled(config, prefsPath), + auto: resolveExtensionHostTtsAutoMode({ config, prefsPath }), + provider, + providerConfigured: isExtensionHostTtsProviderConfigured(config, provider), + fallbackProvider: fallbackProviders[0] ?? null, + fallbackProviders, + prefsPath, + maxLength: getExtensionHostTtsMaxLength(prefsPath), + summarize: isExtensionHostTtsSummarizationEnabled(prefsPath), + hasOpenAIKey: Boolean(resolveExtensionHostTtsApiKey(config, "openai")), + hasElevenLabsKey: Boolean(resolveExtensionHostTtsApiKey(config, "elevenlabs")), + edgeEnabled: isExtensionHostTtsProviderConfigured(config, "edge"), + lastAttempt: getExtensionHostLastTtsAttempt(), + }; +} + +export function formatExtensionHostTtsStatusText( + status: ExtensionHostTtsStatusSnapshot, + now = Date.now(), +): string { + const lines = [ + "📊 TTS status", + `State: ${status.enabled ? "✅ enabled" : "❌ disabled"}`, + `Provider: ${status.provider} (${status.providerConfigured ? "✅ configured" : "❌ not configured"})`, + `Text limit: ${status.maxLength} chars`, + `Auto-summary: ${status.summarize ? "on" : "off"}`, + ]; + if (!status.lastAttempt) { + return lines.join("\n"); + } + + const timeAgo = Math.round((now - status.lastAttempt.timestamp) / 1000); + lines.push(""); + lines.push(`Last attempt (${timeAgo}s ago): ${status.lastAttempt.success ? "✅" : "❌"}`); + lines.push( + `Text: ${status.lastAttempt.textLength} chars${status.lastAttempt.summarized ? " (summarized)" : ""}`, + ); + if (status.lastAttempt.success) { + lines.push(`Provider: ${status.lastAttempt.provider ?? "unknown"}`); + lines.push(`Latency: ${status.lastAttempt.latencyMs ?? 0}ms`); + } else if (status.lastAttempt.error) { + lines.push(`Error: ${status.lastAttempt.error}`); + } + return lines.join("\n"); +} diff --git a/src/extension-host/manifests/manifest-registry.ts b/src/extension-host/manifests/manifest-registry.ts new file mode 100644 index 00000000000..2e242fdf5c3 --- /dev/null +++ b/src/extension-host/manifests/manifest-registry.ts @@ -0,0 +1,52 @@ +import type { PluginCandidate } from "../../plugins/discovery.js"; +import { + loadPackageManifest, + type PackageManifest, + type PluginManifest, +} from "../../plugins/manifest.js"; +import { resolveLegacyExtensionDescriptor, type ResolvedExtension } from "./schema.js"; + +export type ResolvedExtensionRecord = { + extension: ResolvedExtension; + manifestPath: string; + schemaCacheKey?: string; +}; + +export function buildResolvedExtensionRecord(params: { + manifest: PluginManifest; + candidate: PluginCandidate; + manifestPath: string; + schemaCacheKey?: string; + configSchema?: Record; +}): ResolvedExtensionRecord { + const packageDir = params.candidate.packageDir ?? params.candidate.rootDir; + const packageManifest = + params.candidate.packageManifest || + params.candidate.packageName || + params.candidate.packageVersion + ? ({ + openclaw: params.candidate.packageManifest, + name: params.candidate.packageName, + version: params.candidate.packageVersion, + description: params.candidate.packageDescription, + } as PackageManifest) + : (loadPackageManifest(packageDir, params.candidate.origin !== "bundled") ?? undefined); + + const extension = resolveLegacyExtensionDescriptor({ + manifest: { + ...params.manifest, + configSchema: params.configSchema ?? params.manifest.configSchema, + }, + packageManifest, + origin: params.candidate.origin, + rootDir: params.candidate.rootDir, + source: params.candidate.source, + workspaceDir: params.candidate.workspaceDir, + }); + + return { + extension, + manifestPath: params.manifestPath, + schemaCacheKey: params.schemaCacheKey, + }; +} diff --git a/src/extension-host/manifests/resolved-registry.ts b/src/extension-host/manifests/resolved-registry.ts new file mode 100644 index 00000000000..ccaaf92a557 --- /dev/null +++ b/src/extension-host/manifests/resolved-registry.ts @@ -0,0 +1,70 @@ +import type { OpenClawConfig } from "../../config/config.js"; +import { + loadPluginManifestRegistry, + type PluginManifestRegistry, +} from "../../plugins/manifest-registry.js"; +import type { PluginDiagnostic } from "../../plugins/types.js"; +import type { ResolvedExtension } from "./schema.js"; + +export type ResolvedExtensionRegistryEntry = { + extension: ResolvedExtension; + manifestPath: string; + schemaCacheKey?: string; +}; + +export type ResolvedExtensionRegistry = { + extensions: ResolvedExtensionRegistryEntry[]; + diagnostics: PluginDiagnostic[]; +}; + +export function resolvedExtensionRegistryFromPluginManifestRegistry( + registry: PluginManifestRegistry, +): ResolvedExtensionRegistry { + return { + diagnostics: registry.diagnostics, + extensions: registry.plugins.map((plugin) => ({ + extension: + plugin.resolvedExtension ?? + ({ + id: plugin.id, + name: plugin.name, + description: plugin.description, + version: plugin.version, + kind: plugin.kind, + origin: plugin.origin, + rootDir: plugin.rootDir, + source: plugin.source, + workspaceDir: plugin.workspaceDir, + manifest: { + id: plugin.id, + name: plugin.name, + description: plugin.description, + version: plugin.version, + kind: plugin.kind, + channels: plugin.channels, + providers: plugin.providers, + skills: plugin.skills, + configSchema: plugin.configSchema ?? {}, + uiHints: plugin.configUiHints, + }, + staticMetadata: { + configSchema: plugin.configSchema ?? {}, + configUiHints: plugin.configUiHints, + package: { entries: [] }, + }, + contributions: [], + } satisfies ResolvedExtension), + manifestPath: plugin.manifestPath, + schemaCacheKey: plugin.schemaCacheKey, + })), + }; +} + +export function loadResolvedExtensionRegistry(params: { + config?: OpenClawConfig; + workspaceDir?: string; + cache?: boolean; + env?: NodeJS.ProcessEnv; +}): ResolvedExtensionRegistry { + return resolvedExtensionRegistryFromPluginManifestRegistry(loadPluginManifestRegistry(params)); +} diff --git a/src/extension-host/manifests/schema.test.ts b/src/extension-host/manifests/schema.test.ts new file mode 100644 index 00000000000..89cf344c7d8 --- /dev/null +++ b/src/extension-host/manifests/schema.test.ts @@ -0,0 +1,112 @@ +import { describe, expect, it } from "vitest"; +import { + DEFAULT_EXTENSION_ENTRY_CANDIDATES, + getExtensionPackageMetadata, + resolveExtensionEntryCandidates, + resolveLegacyExtensionDescriptor, +} from "./schema.js"; + +describe("extension host schema helpers", () => { + it("normalizes package metadata through the host boundary", () => { + const metadata = getExtensionPackageMetadata({ + openclaw: { + channel: { + id: "telegram", + label: "Telegram", + }, + install: { + npmSpec: "@openclaw/telegram", + defaultChoice: "npm", + }, + }, + }); + + expect(metadata).toEqual({ + channel: { + id: "telegram", + label: "Telegram", + }, + install: { + npmSpec: "@openclaw/telegram", + defaultChoice: "npm", + }, + }); + }); + + it("preserves current extension entry resolution semantics", () => { + expect(resolveExtensionEntryCandidates(undefined)).toEqual({ + status: "missing", + entries: [], + }); + expect(DEFAULT_EXTENSION_ENTRY_CANDIDATES).toContain("index.ts"); + expect( + resolveExtensionEntryCandidates({ + openclaw: { + extensions: ["./dist/index.js"], + }, + }), + ).toEqual({ + status: "ok", + entries: ["./dist/index.js"], + }); + }); + + it("builds a normalized legacy extension descriptor", () => { + const resolved = resolveLegacyExtensionDescriptor({ + manifest: { + id: "telegram", + name: "Telegram", + configSchema: { type: "object" }, + channels: ["telegram"], + providers: ["telegram-provider"], + }, + packageManifest: { + openclaw: { + channel: { + id: "telegram", + label: "Telegram", + }, + install: { + npmSpec: "@openclaw/telegram", + defaultChoice: "npm", + }, + }, + }, + origin: "bundled", + rootDir: "/tmp/telegram", + source: "/tmp/telegram/index.ts", + }); + + expect(resolved.id).toBe("telegram"); + expect(resolved.staticMetadata.package.entries).toEqual([ + "index.ts", + "index.js", + "index.mjs", + "index.cjs", + ]); + expect(resolved.contributions).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + id: "telegram/config", + kind: "surface.config", + }), + expect.objectContaining({ + id: "telegram/channel/telegram", + kind: "adapter.runtime", + }), + expect.objectContaining({ + id: "telegram/provider/telegram-provider", + kind: "capability.provider-integration", + }), + expect.objectContaining({ + id: "telegram/channel-catalog", + kind: "surface.channel-catalog", + }), + expect.objectContaining({ + id: "telegram/install", + kind: "surface.install", + }), + ]), + ); + }); +}); diff --git a/src/extension-host/manifests/schema.ts b/src/extension-host/manifests/schema.ts new file mode 100644 index 00000000000..92986c04a63 --- /dev/null +++ b/src/extension-host/manifests/schema.ts @@ -0,0 +1,182 @@ +import { + DEFAULT_PLUGIN_ENTRY_CANDIDATES, + getPackageManifestMetadata, + resolvePackageExtensionEntries, + type OpenClawPackageManifest, + type PackageExtensionResolution, + type PackageManifest, + type PluginManifest, +} from "../../plugins/manifest.js"; +import type { PluginConfigUiHint, PluginKind, PluginOrigin } from "../../plugins/types.js"; + +export type { OpenClawPackageManifest, PackageExtensionResolution, PackageManifest }; + +export const DEFAULT_EXTENSION_ENTRY_CANDIDATES = DEFAULT_PLUGIN_ENTRY_CANDIDATES; + +export type ContributionPolicy = { + promptMutation?: "none" | "append-only" | "replace-allowed"; + routeEffect?: "observe-only" | "augment" | "veto" | "resolve"; + executionMode?: "sync-hot-path" | "sequential" | "parallel"; +}; + +export type ResolvedContributionKind = + | "adapter.runtime" + | "capability.context-engine" + | "capability.memory" + | "capability.provider-integration" + | "capability.runtime-backend" + | "surface.channel-catalog" + | "surface.config" + | "surface.install"; + +export type ResolvedContribution = { + id: string; + kind: ResolvedContributionKind; + source: "manifest" | "package"; + policy?: ContributionPolicy; + metadata?: Record; +}; + +export type ResolvedExtensionPackageMetadata = { + entries: string[]; + manifest?: OpenClawPackageManifest; +}; + +export type ResolvedExtensionStaticMetadata = { + configSchema: Record; + configUiHints?: Record; + package: ResolvedExtensionPackageMetadata; +}; + +export type ResolvedExtension = { + id: string; + name?: string; + description?: string; + version?: string; + kind?: PluginKind; + origin?: PluginOrigin; + rootDir?: string; + source?: string; + workspaceDir?: string; + manifest: PluginManifest; + staticMetadata: ResolvedExtensionStaticMetadata; + contributions: ResolvedContribution[]; +}; + +export function getExtensionPackageMetadata( + manifest: PackageManifest | undefined, +): OpenClawPackageManifest | undefined { + return getPackageManifestMetadata(manifest); +} + +export function resolveExtensionEntryCandidates( + manifest: PackageManifest | undefined, +): PackageExtensionResolution { + return resolvePackageExtensionEntries(manifest); +} + +function normalizeResolvedEntries( + packageManifest: PackageManifest | undefined, +): ResolvedExtensionPackageMetadata { + const manifest = getExtensionPackageMetadata(packageManifest); + const entries = resolveExtensionEntryCandidates(packageManifest); + return { + entries: + entries.status === "ok" ? entries.entries : Array.from(DEFAULT_EXTENSION_ENTRY_CANDIDATES), + manifest, + }; +} + +export function resolveLegacyExtensionDescriptor(params: { + manifest: PluginManifest; + packageManifest?: PackageManifest; + origin?: PluginOrigin; + rootDir?: string; + source?: string; + workspaceDir?: string; +}): ResolvedExtension { + const packageMetadata = normalizeResolvedEntries(params.packageManifest); + const contributions: ResolvedContribution[] = [ + { + id: `${params.manifest.id}/config`, + kind: "surface.config", + source: "manifest", + }, + ]; + + for (const channelId of params.manifest.channels ?? []) { + contributions.push({ + id: `${params.manifest.id}/channel/${channelId}`, + kind: "adapter.runtime", + source: "manifest", + metadata: { channelId }, + }); + } + + for (const providerId of params.manifest.providers ?? []) { + contributions.push({ + id: `${params.manifest.id}/provider/${providerId}`, + kind: "capability.provider-integration", + source: "manifest", + metadata: { providerId }, + }); + } + + if (params.manifest.kind === "memory") { + contributions.push({ + id: `${params.manifest.id}/memory`, + kind: "capability.memory", + source: "manifest", + }); + } + + if (params.manifest.kind === "context-engine") { + contributions.push({ + id: `${params.manifest.id}/context-engine`, + kind: "capability.context-engine", + source: "manifest", + }); + } + + if (packageMetadata.manifest?.channel) { + contributions.push({ + id: `${params.manifest.id}/channel-catalog`, + kind: "surface.channel-catalog", + source: "package", + metadata: { + channelId: packageMetadata.manifest.channel.id, + }, + }); + } + + if (packageMetadata.manifest?.install) { + contributions.push({ + id: `${params.manifest.id}/install`, + kind: "surface.install", + source: "package", + metadata: { + defaultChoice: packageMetadata.manifest.install.defaultChoice, + npmSpec: packageMetadata.manifest.install.npmSpec, + }, + }); + } + + return { + id: params.manifest.id, + name: params.manifest.name, + description: params.manifest.description, + version: params.manifest.version, + kind: params.manifest.kind, + origin: params.origin, + rootDir: params.rootDir, + source: params.source, + workspaceDir: params.workspaceDir, + manifest: params.manifest, + staticMetadata: { + configSchema: params.manifest.configSchema, + configUiHints: params.manifest.uiHints, + package: packageMetadata, + }, + contributions, + }; +} diff --git a/src/extension-host/static/active-registry.test.ts b/src/extension-host/static/active-registry.test.ts new file mode 100644 index 00000000000..679d5b3d7da --- /dev/null +++ b/src/extension-host/static/active-registry.test.ts @@ -0,0 +1,58 @@ +import { describe, expect, it } from "vitest"; +import { createEmptyPluginRegistry } from "../../plugins/registry.js"; +import { + createEmptyExtensionHostRegistry, + getActiveExtensionHostRegistry, + getActiveExtensionHostRegistryKey, + getActiveExtensionHostRegistryVersion, + requireActiveExtensionHostRegistry, + setActiveExtensionHostRegistry, +} from "./active-registry.js"; + +describe("extension host active registry", () => { + it("initializes with an empty registry", () => { + const emptyRegistry = createEmptyExtensionHostRegistry(); + setActiveExtensionHostRegistry(emptyRegistry, "empty"); + const registry = requireActiveExtensionHostRegistry(); + expect(registry).toBeDefined(); + expect(registry).toBe(emptyRegistry); + expect(registry.channels).toEqual([]); + expect(registry.plugins).toEqual([]); + }); + + it("tracks registry replacement and cache keys", () => { + const baseVersion = getActiveExtensionHostRegistryVersion(); + const registry = createEmptyPluginRegistry(); + registry.plugins.push({ + id: "host-test", + name: "host-test", + source: "test", + origin: "workspace", + enabled: true, + status: "loaded", + toolNames: [], + hookNames: [], + channelIds: [], + providerIds: [], + gatewayMethods: [], + cliCommands: [], + services: [], + commands: [], + httpRoutes: 0, + hookCount: 0, + configSchema: false, + }); + + setActiveExtensionHostRegistry(registry, "host-registry"); + + expect(getActiveExtensionHostRegistry()).toBe(registry); + expect(getActiveExtensionHostRegistryKey()).toBe("host-registry"); + expect(getActiveExtensionHostRegistryVersion()).toBe(baseVersion + 1); + }); + + it("can create a fresh empty registry", () => { + const registry = createEmptyExtensionHostRegistry(); + expect(registry).not.toBe(getActiveExtensionHostRegistry()); + expect(registry).toEqual(createEmptyPluginRegistry()); + }); +}); diff --git a/src/extension-host/static/active-registry.ts b/src/extension-host/static/active-registry.ts new file mode 100644 index 00000000000..777ae393a78 --- /dev/null +++ b/src/extension-host/static/active-registry.ts @@ -0,0 +1,58 @@ +import { createEmptyPluginRegistry, type PluginRegistry } from "../../plugins/registry.js"; + +const EXTENSION_HOST_REGISTRY_STATE = Symbol.for("openclaw.extensionHostRegistryState"); + +export type ExtensionHostRegistry = PluginRegistry; + +type ExtensionHostRegistryState = { + registry: ExtensionHostRegistry | null; + key: string | null; + version: number; +}; + +const state: ExtensionHostRegistryState = (() => { + const globalState = globalThis as typeof globalThis & { + [EXTENSION_HOST_REGISTRY_STATE]?: ExtensionHostRegistryState; + }; + if (!globalState[EXTENSION_HOST_REGISTRY_STATE]) { + globalState[EXTENSION_HOST_REGISTRY_STATE] = { + registry: createEmptyExtensionHostRegistry(), + key: null, + version: 0, + }; + } + return globalState[EXTENSION_HOST_REGISTRY_STATE]; +})(); + +export function createEmptyExtensionHostRegistry(): ExtensionHostRegistry { + return createEmptyPluginRegistry(); +} + +export function setActiveExtensionHostRegistry( + registry: ExtensionHostRegistry, + cacheKey?: string, +): void { + state.registry = registry; + state.key = cacheKey ?? null; + state.version += 1; +} + +export function getActiveExtensionHostRegistry(): ExtensionHostRegistry | null { + return state.registry; +} + +export function requireActiveExtensionHostRegistry(): ExtensionHostRegistry { + if (!state.registry) { + state.registry = createEmptyExtensionHostRegistry(); + state.version += 1; + } + return state.registry; +} + +export function getActiveExtensionHostRegistryKey(): string | null { + return state.key; +} + +export function getActiveExtensionHostRegistryVersion(): number { + return state.version; +} diff --git a/src/extension-host/static/embedding-runtime-backends.test.ts b/src/extension-host/static/embedding-runtime-backends.test.ts new file mode 100644 index 00000000000..5441ff96fbb --- /dev/null +++ b/src/extension-host/static/embedding-runtime-backends.test.ts @@ -0,0 +1,34 @@ +import { describe, expect, it } from "vitest"; +import { + DEFAULT_EXTENSION_HOST_LOCAL_EMBEDDING_MODEL, + EXTENSION_HOST_EMBEDDING_RUNTIME_BACKEND_IDS, + EXTENSION_HOST_REMOTE_EMBEDDING_PROVIDER_IDS, + isExtensionHostEmbeddingRuntimeBackendAutoSelectable, +} from "./embedding-runtime-backends.js"; + +describe("embedding-runtime-backends", () => { + it("keeps the built-in embedding backend order stable", () => { + expect(DEFAULT_EXTENSION_HOST_LOCAL_EMBEDDING_MODEL).toContain("embeddinggemma"); + expect(EXTENSION_HOST_REMOTE_EMBEDDING_PROVIDER_IDS).toEqual([ + "openai", + "gemini", + "voyage", + "mistral", + ]); + expect(EXTENSION_HOST_EMBEDDING_RUNTIME_BACKEND_IDS).toEqual([ + "local", + "openai", + "gemini", + "voyage", + "mistral", + "ollama", + ]); + }); + + it("marks only local and remote embedding backends as auto-selectable", () => { + expect(isExtensionHostEmbeddingRuntimeBackendAutoSelectable("local")).toBe(true); + expect(isExtensionHostEmbeddingRuntimeBackendAutoSelectable("openai")).toBe(true); + expect(isExtensionHostEmbeddingRuntimeBackendAutoSelectable("mistral")).toBe(true); + expect(isExtensionHostEmbeddingRuntimeBackendAutoSelectable("ollama")).toBe(false); + }); +}); diff --git a/src/extension-host/static/embedding-runtime-backends.ts b/src/extension-host/static/embedding-runtime-backends.ts new file mode 100644 index 00000000000..ec057d138bb --- /dev/null +++ b/src/extension-host/static/embedding-runtime-backends.ts @@ -0,0 +1,53 @@ +import type { EmbeddingProviderId } from "../contributions/embedding-runtime-types.js"; + +export const DEFAULT_EXTENSION_HOST_LOCAL_EMBEDDING_MODEL = + "hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf"; +export const DEFAULT_EXTENSION_HOST_OPENAI_EMBEDDING_MODEL = "text-embedding-3-small"; +export const DEFAULT_EXTENSION_HOST_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001"; +export const DEFAULT_EXTENSION_HOST_VOYAGE_EMBEDDING_MODEL = "voyage-4-large"; +export const DEFAULT_EXTENSION_HOST_MISTRAL_EMBEDDING_MODEL = "mistral-embed"; +export const DEFAULT_EXTENSION_HOST_OLLAMA_EMBEDDING_MODEL = "nomic-embed-text"; + +export const EXTENSION_HOST_REMOTE_EMBEDDING_PROVIDER_IDS = [ + "openai", + "gemini", + "voyage", + "mistral", +] as const satisfies readonly EmbeddingProviderId[]; + +export const EXTENSION_HOST_EMBEDDING_RUNTIME_BACKEND_IDS = [ + "local", + ...EXTENSION_HOST_REMOTE_EMBEDDING_PROVIDER_IDS, + "ollama", +] as const satisfies readonly EmbeddingProviderId[]; + +export function isExtensionHostEmbeddingRuntimeBackendAutoSelectable( + backendId: EmbeddingProviderId, +): boolean { + return ( + backendId === "local" || + backendId === "openai" || + backendId === "gemini" || + backendId === "voyage" || + backendId === "mistral" + ); +} + +export function resolveExtensionHostEmbeddingRuntimeDefaultModel( + backendId: EmbeddingProviderId, +): string { + switch (backendId) { + case "openai": + return DEFAULT_EXTENSION_HOST_OPENAI_EMBEDDING_MODEL; + case "gemini": + return DEFAULT_EXTENSION_HOST_GEMINI_EMBEDDING_MODEL; + case "voyage": + return DEFAULT_EXTENSION_HOST_VOYAGE_EMBEDDING_MODEL; + case "mistral": + return DEFAULT_EXTENSION_HOST_MISTRAL_EMBEDDING_MODEL; + case "ollama": + return DEFAULT_EXTENSION_HOST_OLLAMA_EMBEDDING_MODEL; + case "local": + return DEFAULT_EXTENSION_HOST_LOCAL_EMBEDDING_MODEL; + } +} diff --git a/src/extension-host/static/media-runtime-backends.test.ts b/src/extension-host/static/media-runtime-backends.test.ts new file mode 100644 index 00000000000..d5acac9661b --- /dev/null +++ b/src/extension-host/static/media-runtime-backends.test.ts @@ -0,0 +1,60 @@ +import { describe, expect, it } from "vitest"; +import { + buildExtensionHostMediaRuntimeSelectorKeys, + listExtensionHostMediaAutoRuntimeBackendSeedIds, + listExtensionHostMediaRuntimeBackendIds, + listExtensionHostMediaUnderstandingProviders, + normalizeExtensionHostMediaProviderId, + resolveExtensionHostMediaRuntimeDefaultModelMetadata, +} from "./media-runtime-backends.js"; + +describe("extension host media runtime backends", () => { + it("publishes the built-in media providers once", () => { + const providers = listExtensionHostMediaUnderstandingProviders(); + + expect(providers.some((provider) => provider.id === "openai")).toBe(true); + expect(providers.some((provider) => provider.id === "deepgram")).toBe(true); + }); + + it("keeps media-specific provider normalization and selector aliases", () => { + expect(normalizeExtensionHostMediaProviderId("gemini")).toBe("google"); + expect(buildExtensionHostMediaRuntimeSelectorKeys("google")).toEqual(["google", "gemini"]); + }); + + it("keeps auto-seeded runtime backends ordered ahead of the rest", () => { + expect(listExtensionHostMediaAutoRuntimeBackendSeedIds("image")).toEqual([ + "openai", + "anthropic", + "google", + "minimax", + "minimax-portal", + "zai", + ]); + expect(listExtensionHostMediaRuntimeBackendIds("audio").slice(0, 3)).toEqual([ + "openai", + "groq", + "deepgram", + ]); + expect(listExtensionHostMediaRuntimeBackendIds("image").slice(0, 4)).toEqual([ + "openai", + "anthropic", + "google", + "minimax", + ]); + }); + + it("keeps default-model metadata with the shared backend definitions", () => { + expect( + resolveExtensionHostMediaRuntimeDefaultModelMetadata({ + capability: "image", + backendId: "openai", + }), + ).toBe("gpt-5-mini"); + expect( + resolveExtensionHostMediaRuntimeDefaultModelMetadata({ + capability: "video", + backendId: "openai", + }), + ).toBeUndefined(); + }); +}); diff --git a/src/extension-host/static/media-runtime-backends.ts b/src/extension-host/static/media-runtime-backends.ts new file mode 100644 index 00000000000..471c3349beb --- /dev/null +++ b/src/extension-host/static/media-runtime-backends.ts @@ -0,0 +1,118 @@ +import { normalizeProviderId } from "../../agents/provider-id.js"; +import { + AUTO_AUDIO_KEY_PROVIDERS, + AUTO_IMAGE_KEY_PROVIDERS, + AUTO_VIDEO_KEY_PROVIDERS, + DEFAULT_AUDIO_MODELS, + DEFAULT_IMAGE_MODELS, +} from "../../media-understanding/defaults.js"; +import { anthropicProvider } from "../../media-understanding/providers/anthropic/index.js"; +import { deepgramProvider } from "../../media-understanding/providers/deepgram/index.js"; +import { googleProvider } from "../../media-understanding/providers/google/index.js"; +import { groqProvider } from "../../media-understanding/providers/groq/index.js"; +import { + minimaxPortalProvider, + minimaxProvider, +} from "../../media-understanding/providers/minimax/index.js"; +import { mistralProvider } from "../../media-understanding/providers/mistral/index.js"; +import { moonshotProvider } from "../../media-understanding/providers/moonshot/index.js"; +import { openaiProvider } from "../../media-understanding/providers/openai/index.js"; +import { zaiProvider } from "../../media-understanding/providers/zai/index.js"; +import type { + MediaUnderstandingCapability, + MediaUnderstandingProvider, +} from "../../media-understanding/types.js"; + +const EXTENSION_HOST_MEDIA_UNDERSTANDING_PROVIDERS: readonly MediaUnderstandingProvider[] = [ + groqProvider, + openaiProvider, + googleProvider, + anthropicProvider, + minimaxProvider, + minimaxPortalProvider, + moonshotProvider, + mistralProvider, + zaiProvider, + deepgramProvider, +]; + +const EXTENSION_HOST_MEDIA_AUTO_RUNTIME_BACKEND_IDS: Record< + MediaUnderstandingCapability, + readonly string[] +> = { + audio: AUTO_AUDIO_KEY_PROVIDERS, + image: AUTO_IMAGE_KEY_PROVIDERS, + video: AUTO_VIDEO_KEY_PROVIDERS, +}; + +export function listExtensionHostMediaUnderstandingProviders(): readonly MediaUnderstandingProvider[] { + return EXTENSION_HOST_MEDIA_UNDERSTANDING_PROVIDERS; +} + +export function normalizeExtensionHostMediaProviderId(id: string): string { + const normalized = normalizeProviderId(id); + if (normalized === "gemini") { + return "google"; + } + return normalized; +} + +export function buildExtensionHostMediaRuntimeSelectorKeys(providerId: string): readonly string[] { + const normalized = normalizeExtensionHostMediaProviderId(providerId); + if (normalized === "google") { + return [providerId, "gemini"]; + } + return normalized === providerId ? [providerId] : [providerId, normalized]; +} + +export function listExtensionHostMediaAutoRuntimeBackendSeedIds( + capability: MediaUnderstandingCapability, +): readonly string[] { + return EXTENSION_HOST_MEDIA_AUTO_RUNTIME_BACKEND_IDS[capability]; +} + +export function listExtensionHostMediaRuntimeBackendIds( + capability: MediaUnderstandingCapability, +): readonly string[] { + const ordered: string[] = []; + const seen = new Set(); + const pushProvider = (provider: MediaUnderstandingProvider | undefined) => { + if (!provider || !(provider.capabilities ?? []).includes(capability)) { + return; + } + const normalized = normalizeExtensionHostMediaProviderId(provider.id); + if (seen.has(normalized)) { + return; + } + seen.add(normalized); + ordered.push(normalized); + }; + + const providersById = new Map( + listExtensionHostMediaUnderstandingProviders().map((provider) => [ + normalizeExtensionHostMediaProviderId(provider.id), + provider, + ]), + ); + + for (const providerId of listExtensionHostMediaAutoRuntimeBackendSeedIds(capability)) { + pushProvider(providersById.get(normalizeExtensionHostMediaProviderId(providerId))); + } + for (const provider of providersById.values()) { + pushProvider(provider); + } + return ordered; +} + +export function resolveExtensionHostMediaRuntimeDefaultModelMetadata(params: { + capability: MediaUnderstandingCapability; + backendId: string; +}): string | undefined { + if (params.capability === "audio") { + return DEFAULT_AUDIO_MODELS[params.backendId]; + } + if (params.capability === "image") { + return DEFAULT_IMAGE_MODELS[params.backendId]; + } + return undefined; +} diff --git a/src/extension-host/static/runtime-backend-catalog.test.ts b/src/extension-host/static/runtime-backend-catalog.test.ts new file mode 100644 index 00000000000..d70a60c8393 --- /dev/null +++ b/src/extension-host/static/runtime-backend-catalog.test.ts @@ -0,0 +1,165 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { listExtensionHostEmbeddingRemoteRuntimeBackendIds } from "../policy/embedding-runtime-policy.js"; + +vi.mock("./embedding-runtime-backends.js", () => ({ + EXTENSION_HOST_EMBEDDING_RUNTIME_BACKEND_IDS: [ + "local", + "openai", + "gemini", + "voyage", + "mistral", + "ollama", + ], + EXTENSION_HOST_REMOTE_EMBEDDING_PROVIDER_IDS: ["openai", "gemini", "voyage", "mistral"], + isExtensionHostEmbeddingRuntimeBackendAutoSelectable: vi.fn( + (backendId: string) => backendId !== "ollama", + ), + resolveExtensionHostEmbeddingRuntimeDefaultModel: vi.fn((backendId: string) => + backendId === "local" ? "local-model.gguf" : `${backendId}-default-model`, + ), +})); + +vi.mock("./media-runtime-backends.js", () => ({ + buildExtensionHostMediaRuntimeSelectorKeys: vi.fn((id: string) => + id === "google" ? ["google", "gemini"] : [id], + ), + listExtensionHostMediaAutoRuntimeBackendSeedIds: vi.fn( + (capability: "audio" | "image" | "video") => + ({ + audio: ["deepgram"], + image: ["openai", "google"], + video: ["openai"], + })[capability], + ), + listExtensionHostMediaRuntimeBackendIds: vi.fn( + (capability: "audio" | "image" | "video") => + ({ + audio: ["deepgram"], + image: ["openai", "google"], + video: ["openai"], + })[capability], + ), + normalizeExtensionHostMediaProviderId: vi.fn((id: string) => + id.trim().toLowerCase() === "gemini" ? "google" : id.trim().toLowerCase(), + ), + resolveExtensionHostMediaRuntimeDefaultModelMetadata: vi.fn( + (params: { capability: "audio" | "image" | "video"; backendId: string }) => + params.capability === "image" && params.backendId === "openai" ? "gpt-5-mini" : undefined, + ), +})); + +vi.mock("./tts-runtime-backends.js", () => ({ + listExtensionHostTtsRuntimeBackends: vi.fn(() => [ + { id: "openai", supportsTelephony: true }, + { id: "elevenlabs", supportsTelephony: true }, + { id: "edge", supportsTelephony: false }, + ]), +})); + +describe("runtime-backend-catalog", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("publishes embedding backends as host-owned runtime-backend catalog entries", async () => { + const catalog = await import("./runtime-backend-catalog.js"); + const entries = catalog.listExtensionHostEmbeddingRuntimeBackendCatalogEntries(); + + expect(entries.map((entry) => entry.backendId)).toEqual([ + "local", + "openai", + "gemini", + "voyage", + "mistral", + "ollama", + ]); + expect( + entries.every((entry) => entry.family === catalog.EXTENSION_HOST_RUNTIME_BACKEND_FAMILY), + ).toBe(true); + expect(entries.every((entry) => entry.subsystemId === "embedding")).toBe(true); + expect(entries[0]?.capabilities).toContain("embed.query"); + expect(entries[0]?.metadata).toMatchObject({ + autoSelectable: true, + defaultModel: "local-model.gguf", + }); + expect(entries.at(-1)?.metadata).toMatchObject({ + autoSelectable: false, + defaultModel: "ollama-default-model", + }); + }); + + it("splits media providers into subsystem-specific runtime-backend catalog entries", async () => { + const catalog = await import("./runtime-backend-catalog.js"); + const entries = catalog.listExtensionHostMediaRuntimeBackendCatalogEntries(); + + expect(entries).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + subsystemId: "media.image", + backendId: "openai", + capabilities: ["image"], + }), + expect.objectContaining({ + subsystemId: "media.audio", + backendId: "deepgram", + capabilities: ["audio"], + }), + ]), + ); + expect(entries.find((entry) => entry.backendId === "google")?.selectorKeys).toContain("gemini"); + expect(catalog.listExtensionHostMediaAutoRuntimeBackendIds("image")).toEqual([ + "openai", + "google", + ]); + expect( + catalog.resolveExtensionHostMediaRuntimeDefaultModel({ + capability: "image", + backendId: "openai", + }), + ).toBe("gpt-5-mini"); + }); + + it("publishes TTS backends with telephony capability metadata", async () => { + const catalog = await import("./runtime-backend-catalog.js"); + const entries = catalog.listExtensionHostTtsRuntimeBackendCatalogEntries(); + + expect(entries.map((entry) => entry.backendId)).toEqual(["openai", "elevenlabs", "edge"]); + expect(entries.find((entry) => entry.backendId === "openai")?.capabilities).toContain( + "tts.telephony", + ); + expect(entries.find((entry) => entry.backendId === "edge")?.capabilities).toEqual([ + "tts.synthesis", + ]); + expect(catalog.listExtensionHostTtsRuntimeBackendIds()).toEqual([ + "openai", + "elevenlabs", + "edge", + ]); + expect(catalog.resolveExtensionHostTtsRuntimeBackendOrder("edge")).toEqual([ + "edge", + "openai", + "elevenlabs", + ]); + }); + + it("aggregates runtime-backend catalog entries across subsystem families", async () => { + const catalog = await import("./runtime-backend-catalog.js"); + const entries = catalog.listExtensionHostRuntimeBackendCatalogEntries(); + const ids = new Set(entries.map((entry) => entry.id)); + + expect(ids.size).toBe(entries.length); + expect( + catalog.getExtensionHostRuntimeBackendCatalogEntry({ subsystemId: "tts", backendId: "edge" }), + ).toMatchObject({ + id: `${catalog.EXTENSION_HOST_RUNTIME_BACKEND_FAMILY}:tts:edge`, + subsystemId: "tts", + backendId: "edge", + }); + expect(listExtensionHostEmbeddingRemoteRuntimeBackendIds()).toEqual([ + "openai", + "gemini", + "voyage", + "mistral", + ]); + }); +}); diff --git a/src/extension-host/static/runtime-backend-catalog.ts b/src/extension-host/static/runtime-backend-catalog.ts new file mode 100644 index 00000000000..996f086ff14 --- /dev/null +++ b/src/extension-host/static/runtime-backend-catalog.ts @@ -0,0 +1,209 @@ +import type { TtsProvider } from "../../config/types.tts.js"; +import type { MediaUnderstandingCapability } from "../../media-understanding/types.js"; +import { resolveExtensionHostRuntimeBackendIdsByPolicy } from "../policy/runtime-backend-policy.js"; +import { + resolveExtensionHostEmbeddingRuntimeDefaultModel, + EXTENSION_HOST_EMBEDDING_RUNTIME_BACKEND_IDS, + isExtensionHostEmbeddingRuntimeBackendAutoSelectable, +} from "./embedding-runtime-backends.js"; +import { + buildExtensionHostMediaRuntimeSelectorKeys, + listExtensionHostMediaAutoRuntimeBackendSeedIds, + listExtensionHostMediaRuntimeBackendIds as listExtensionHostMediaRuntimeBackendIdsFromDefinitions, + normalizeExtensionHostMediaProviderId, + resolveExtensionHostMediaRuntimeDefaultModelMetadata, +} from "./media-runtime-backends.js"; +import { listExtensionHostTtsRuntimeBackends } from "./tts-runtime-backends.js"; + +export const EXTENSION_HOST_RUNTIME_BACKEND_FAMILY = "capability.runtime-backend"; + +export type ExtensionHostRuntimeBackendFamily = typeof EXTENSION_HOST_RUNTIME_BACKEND_FAMILY; + +export type ExtensionHostRuntimeBackendSubsystemId = + | "embedding" + | "media.audio" + | "media.image" + | "media.video" + | "tts"; + +export type ExtensionHostRuntimeBackendCatalogEntry = { + id: string; + family: ExtensionHostRuntimeBackendFamily; + subsystemId: ExtensionHostRuntimeBackendSubsystemId; + backendId: string; + source: "builtin"; + defaultRank: number; + selectorKeys: readonly string[]; + capabilities: readonly string[]; + metadata?: Record; +}; + +type ExtensionHostMediaRuntimeSubsystemId = Extract< + ExtensionHostRuntimeBackendSubsystemId, + "media.audio" | "media.image" | "media.video" +>; + +function buildRuntimeBackendCatalogId( + subsystemId: ExtensionHostRuntimeBackendSubsystemId, + backendId: string, +): string { + return `${EXTENSION_HOST_RUNTIME_BACKEND_FAMILY}:${subsystemId}:${backendId}`; +} + +function mapMediaCapabilityToSubsystem( + capability: MediaUnderstandingCapability, +): ExtensionHostRuntimeBackendSubsystemId { + if (capability === "audio") { + return "media.audio"; + } + if (capability === "video") { + return "media.video"; + } + return "media.image"; +} + +export function listExtensionHostEmbeddingRuntimeBackendCatalogEntries(): readonly ExtensionHostRuntimeBackendCatalogEntry[] { + return EXTENSION_HOST_EMBEDDING_RUNTIME_BACKEND_IDS.map((backendId, defaultRank) => ({ + id: buildRuntimeBackendCatalogId("embedding", backendId), + family: EXTENSION_HOST_RUNTIME_BACKEND_FAMILY, + subsystemId: "embedding", + backendId, + source: "builtin", + defaultRank, + selectorKeys: [backendId], + capabilities: ["embed.query", "embed.batch"], + metadata: { + autoSelectable: isExtensionHostEmbeddingRuntimeBackendAutoSelectable(backendId), + defaultModel: resolveExtensionHostEmbeddingRuntimeDefaultModel(backendId), + }, + })); +} + +export function listExtensionHostMediaRuntimeBackendCatalogEntries(): readonly ExtensionHostRuntimeBackendCatalogEntry[] { + const entries: ExtensionHostRuntimeBackendCatalogEntry[] = []; + for (const capability of ["audio", "image", "video"] as const) { + const providerIds = listExtensionHostMediaRuntimeBackendIdsFromDefinitions(capability); + for (const [defaultRank, providerId] of providerIds.entries()) { + const defaultModel = resolveExtensionHostMediaRuntimeDefaultModelMetadata({ + capability, + backendId: providerId, + }); + entries.push({ + id: buildRuntimeBackendCatalogId(mapMediaCapabilityToSubsystem(capability), providerId), + family: EXTENSION_HOST_RUNTIME_BACKEND_FAMILY, + subsystemId: mapMediaCapabilityToSubsystem(capability), + backendId: providerId, + source: "builtin", + defaultRank, + selectorKeys: buildExtensionHostMediaRuntimeSelectorKeys(providerId), + capabilities: [capability], + metadata: { + autoSelectable: listExtensionHostMediaAutoRuntimeBackendSeedIds(capability).includes( + normalizeExtensionHostMediaProviderId(providerId), + ), + ...(defaultModel ? { defaultModel } : {}), + }, + }); + } + } + return entries; +} + +export function listExtensionHostMediaAutoRuntimeBackendIds( + capability: MediaUnderstandingCapability, +): readonly string[] { + const subsystemId = mapMediaCapabilityToSubsystem(capability); + return resolveExtensionHostRuntimeBackendIdsByPolicy({ + entries: listExtensionHostMediaRuntimeBackendCatalogEntries(), + subsystemId, + include: (entry) => entry.metadata?.autoSelectable === true, + }); +} + +export function resolveExtensionHostMediaRuntimeDefaultModel(params: { + capability: MediaUnderstandingCapability; + backendId: string; +}): string | undefined { + const subsystemId = mapMediaCapabilityToSubsystem(params.capability); + const entry = listExtensionHostMediaRuntimeBackendCatalogEntries().find( + (candidate) => + candidate.subsystemId === subsystemId && candidate.backendId === params.backendId, + ); + const defaultModel = entry?.metadata?.defaultModel; + return typeof defaultModel === "string" ? defaultModel : undefined; +} + +export function listExtensionHostTtsRuntimeBackendCatalogEntries(): readonly ExtensionHostRuntimeBackendCatalogEntry[] { + return listExtensionHostTtsRuntimeBackends().map((provider, defaultRank) => ({ + id: buildRuntimeBackendCatalogId("tts", provider.id), + family: EXTENSION_HOST_RUNTIME_BACKEND_FAMILY, + subsystemId: "tts", + backendId: provider.id, + source: "builtin", + defaultRank, + selectorKeys: [provider.id], + capabilities: provider.supportsTelephony + ? ["tts.synthesis", "tts.telephony"] + : ["tts.synthesis"], + metadata: { + supportsTelephony: provider.supportsTelephony, + }, + })); +} + +export function listExtensionHostTtsRuntimeBackendIds(): readonly TtsProvider[] { + return listExtensionHostTtsRuntimeBackendCatalogEntries().map( + (entry) => entry.backendId as TtsProvider, + ); +} + +export function listExtensionHostRuntimeBackendIdsForSubsystem( + subsystemId: ExtensionHostRuntimeBackendSubsystemId, +): readonly string[] { + return resolveExtensionHostRuntimeBackendIdsByPolicy({ + entries: listExtensionHostRuntimeBackendCatalogEntries(), + subsystemId, + }); +} + +export function resolveExtensionHostRuntimeBackendOrderForSubsystem( + subsystemId: ExtensionHostRuntimeBackendSubsystemId, + preferredBackendId: string, +): readonly string[] { + return resolveExtensionHostRuntimeBackendIdsByPolicy({ + entries: listExtensionHostRuntimeBackendCatalogEntries(), + subsystemId, + preferredBackendId, + }); +} + +export function listExtensionHostMediaRuntimeBackendIds( + subsystemId: ExtensionHostMediaRuntimeSubsystemId, +): readonly string[] { + return listExtensionHostRuntimeBackendIdsForSubsystem(subsystemId); +} + +export function resolveExtensionHostTtsRuntimeBackendOrder( + preferredBackendId: TtsProvider, +): readonly TtsProvider[] { + return resolveExtensionHostRuntimeBackendOrderForSubsystem("tts", preferredBackendId).map( + (backendId) => backendId as TtsProvider, + ); +} + +export function listExtensionHostRuntimeBackendCatalogEntries(): readonly ExtensionHostRuntimeBackendCatalogEntry[] { + return [ + ...listExtensionHostEmbeddingRuntimeBackendCatalogEntries(), + ...listExtensionHostMediaRuntimeBackendCatalogEntries(), + ...listExtensionHostTtsRuntimeBackendCatalogEntries(), + ]; +} + +export function getExtensionHostRuntimeBackendCatalogEntry(params: { + subsystemId: ExtensionHostRuntimeBackendSubsystemId; + backendId: string; +}): ExtensionHostRuntimeBackendCatalogEntry | undefined { + return listExtensionHostRuntimeBackendCatalogEntries().find( + (entry) => entry.subsystemId === params.subsystemId && entry.backendId === params.backendId, + ); +} diff --git a/src/extension-host/static/tts-runtime-backends.test.ts b/src/extension-host/static/tts-runtime-backends.test.ts new file mode 100644 index 00000000000..5f5aea6be60 --- /dev/null +++ b/src/extension-host/static/tts-runtime-backends.test.ts @@ -0,0 +1,37 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; +import { + EXTENSION_HOST_TTS_RUNTIME_BACKEND_IDS, + getExtensionHostTtsRuntimeBackend, + listExtensionHostTtsRuntimeBackends, +} from "./tts-runtime-backends.js"; + +describe("tts-runtime-backends", () => { + afterEach(() => { + vi.unstubAllEnvs(); + }); + + it("keeps the built-in backend order stable", () => { + expect(EXTENSION_HOST_TTS_RUNTIME_BACKEND_IDS).toEqual(["openai", "elevenlabs", "edge"]); + expect(listExtensionHostTtsRuntimeBackends().map((backend) => backend.id)).toEqual([ + "openai", + "elevenlabs", + "edge", + ]); + }); + + it("resolves API keys and configuration through shared backend definitions", () => { + vi.stubEnv("OPENAI_API_KEY", ""); + vi.stubEnv("ELEVENLABS_API_KEY", ""); + vi.stubEnv("XI_API_KEY", ""); + + const config = { + openai: { apiKey: "openai-key" }, + elevenlabs: { apiKey: "" }, + edge: { enabled: true }, + } as const; + + expect(getExtensionHostTtsRuntimeBackend("openai")?.resolveApiKey(config)).toBe("openai-key"); + expect(getExtensionHostTtsRuntimeBackend("elevenlabs")?.isConfigured(config)).toBe(false); + expect(getExtensionHostTtsRuntimeBackend("edge")?.supportsTelephony).toBe(false); + }); +}); diff --git a/src/extension-host/static/tts-runtime-backends.ts b/src/extension-host/static/tts-runtime-backends.ts new file mode 100644 index 00000000000..7ca6f536006 --- /dev/null +++ b/src/extension-host/static/tts-runtime-backends.ts @@ -0,0 +1,56 @@ +import type { TtsProvider } from "../../config/types.tts.js"; +import type { ResolvedTtsConfig } from "../contributions/tts-config.js"; + +export type ExtensionHostTtsRuntimeBackend = { + id: TtsProvider; + supportsTelephony: boolean; + resolveApiKey: (config: ResolvedTtsConfig) => string | undefined; + isConfigured: (config: ResolvedTtsConfig) => boolean; +}; + +const EXTENSION_HOST_TTS_RUNTIME_BACKENDS: readonly ExtensionHostTtsRuntimeBackend[] = [ + { + id: "openai", + supportsTelephony: true, + resolveApiKey(config) { + return config.openai.apiKey || process.env.OPENAI_API_KEY; + }, + isConfigured(config) { + return Boolean(this.resolveApiKey(config)); + }, + }, + { + id: "elevenlabs", + supportsTelephony: true, + resolveApiKey(config) { + return config.elevenlabs.apiKey || process.env.ELEVENLABS_API_KEY || process.env.XI_API_KEY; + }, + isConfigured(config) { + return Boolean(this.resolveApiKey(config)); + }, + }, + { + id: "edge", + supportsTelephony: false, + resolveApiKey() { + return undefined; + }, + isConfigured(config) { + return config.edge.enabled; + }, + }, +] as const; + +export const EXTENSION_HOST_TTS_RUNTIME_BACKEND_IDS = EXTENSION_HOST_TTS_RUNTIME_BACKENDS.map( + (backend) => backend.id, +) as readonly TtsProvider[]; + +export function listExtensionHostTtsRuntimeBackends(): readonly ExtensionHostTtsRuntimeBackend[] { + return EXTENSION_HOST_TTS_RUNTIME_BACKENDS; +} + +export function getExtensionHostTtsRuntimeBackend( + id: TtsProvider, +): ExtensionHostTtsRuntimeBackend | undefined { + return EXTENSION_HOST_TTS_RUNTIME_BACKENDS.find((backend) => backend.id === id); +}