diff --git a/apps/macos/Sources/OpenClawProtocol/GatewayModels.swift b/apps/macos/Sources/OpenClawProtocol/GatewayModels.swift index 0b1d7b13e01..5283ad3907a 100644 --- a/apps/macos/Sources/OpenClawProtocol/GatewayModels.swift +++ b/apps/macos/Sources/OpenClawProtocol/GatewayModels.swift @@ -537,6 +537,9 @@ public struct AgentParams: Codable, Sendable { public let lane: String? public let extrasystemprompt: String? public let internalevents: [[String: AnyCodable]]? + public let clienttools: [[String: AnyCodable]]? + public let disabletools: Bool? + public let streamparams: [String: AnyCodable]? public let inputprovenance: [String: AnyCodable]? public let idempotencykey: String public let label: String? @@ -566,6 +569,9 @@ public struct AgentParams: Codable, Sendable { lane: String?, extrasystemprompt: String?, internalevents: [[String: AnyCodable]]?, + clienttools: [[String: AnyCodable]]?, + disabletools: Bool?, + streamparams: [String: AnyCodable]?, inputprovenance: [String: AnyCodable]?, idempotencykey: String, label: String?) @@ -594,6 +600,9 @@ public struct AgentParams: Codable, Sendable { self.lane = lane self.extrasystemprompt = extrasystemprompt self.internalevents = internalevents + self.clienttools = clienttools + self.disabletools = disabletools + self.streamparams = streamparams self.inputprovenance = inputprovenance self.idempotencykey = idempotencykey self.label = label @@ -624,6 +633,9 @@ public struct AgentParams: Codable, Sendable { case lane case extrasystemprompt = "extraSystemPrompt" case internalevents = "internalEvents" + case clienttools = "clientTools" + case disabletools = "disableTools" + case streamparams = "streamParams" case inputprovenance = "inputProvenance" case idempotencykey = "idempotencyKey" case label diff --git a/apps/shared/OpenClawKit/Sources/OpenClawProtocol/GatewayModels.swift b/apps/shared/OpenClawKit/Sources/OpenClawProtocol/GatewayModels.swift index 0b1d7b13e01..5283ad3907a 100644 --- a/apps/shared/OpenClawKit/Sources/OpenClawProtocol/GatewayModels.swift +++ b/apps/shared/OpenClawKit/Sources/OpenClawProtocol/GatewayModels.swift @@ -537,6 +537,9 @@ public struct AgentParams: Codable, Sendable { public let lane: String? public let extrasystemprompt: String? public let internalevents: [[String: AnyCodable]]? + public let clienttools: [[String: AnyCodable]]? + public let disabletools: Bool? + public let streamparams: [String: AnyCodable]? public let inputprovenance: [String: AnyCodable]? public let idempotencykey: String public let label: String? @@ -566,6 +569,9 @@ public struct AgentParams: Codable, Sendable { lane: String?, extrasystemprompt: String?, internalevents: [[String: AnyCodable]]?, + clienttools: [[String: AnyCodable]]?, + disabletools: Bool?, + streamparams: [String: AnyCodable]?, inputprovenance: [String: AnyCodable]?, idempotencykey: String, label: String?) @@ -594,6 +600,9 @@ public struct AgentParams: Codable, Sendable { self.lane = lane self.extrasystemprompt = extrasystemprompt self.internalevents = internalevents + self.clienttools = clienttools + self.disabletools = disabletools + self.streamparams = streamparams self.inputprovenance = inputprovenance self.idempotencykey = idempotencykey self.label = label @@ -624,6 +633,9 @@ public struct AgentParams: Codable, Sendable { case lane case extrasystemprompt = "extraSystemPrompt" case internalevents = "internalEvents" + case clienttools = "clientTools" + case disabletools = "disableTools" + case streamparams = "streamParams" case inputprovenance = "inputProvenance" case idempotencykey = "idempotencyKey" case label diff --git a/src/agents/agent-command.ts b/src/agents/agent-command.ts index 5db40b13a27..a35588d38c9 100644 --- a/src/agents/agent-command.ts +++ b/src/agents/agent-command.ts @@ -512,6 +512,7 @@ function runAgentAttempt(params: { prompt: effectivePrompt, images: params.isFallbackRetry ? undefined : params.opts.images, clientTools: params.opts.clientTools, + disableTools: params.opts.disableTools, provider: params.providerOverride, model: params.modelOverride, authProfileId, @@ -1236,6 +1237,7 @@ async function agentCommandInternal( endedAt: Date.now(), aborted: result.meta.aborted ?? false, stopReason, + pendingToolCalls: result.meta.pendingToolCalls, }, }); } diff --git a/src/agents/command/types.ts b/src/agents/command/types.ts index a85157bb191..d4bfe2fee8c 100644 --- a/src/agents/command/types.ts +++ b/src/agents/command/types.ts @@ -17,6 +17,17 @@ export type AgentStreamParams = { maxTokens?: number; /** Provider fast-mode override (best-effort). */ fastMode?: boolean; + /** Provider tool-choice override (best-effort). */ + toolChoice?: + | "auto" + | "none" + | "required" + | { + type: "function"; + function: { + name: string; + }; + }; }; export type AgentRunContext = { @@ -37,6 +48,8 @@ export type AgentCommandOpts = { images?: ImageContent[]; /** Optional client-provided tools (OpenResponses hosted tools). */ clientTools?: ClientToolDefinition[]; + /** Disable built-in tools for this run while still allowing clientTools. */ + disableTools?: boolean; /** Agent id override (must exist in config). */ agentId?: string; /** Per-run provider override. */ diff --git a/src/agents/pi-embedded-runner/extra-params.cache-retention-default.test.ts b/src/agents/pi-embedded-runner/extra-params.cache-retention-default.test.ts index b988a8c3c59..8d827924a2d 100644 --- a/src/agents/pi-embedded-runner/extra-params.cache-retention-default.test.ts +++ b/src/agents/pi-embedded-runner/extra-params.cache-retention-default.test.ts @@ -144,4 +144,98 @@ describe("cacheRetention default behavior", () => { provider: "anthropic", }); }); + + it("forwards toolChoice overrides to the wrapped stream options", async () => { + const baseStreamFn = vi.fn(async () => undefined); + const agent: { streamFn?: StreamFn } = { streamFn: baseStreamFn as unknown as StreamFn }; + + applyExtraParamsToAgent( + agent, + undefined, + "openai", + "gpt-4.1", + { + toolChoice: { + type: "function", + function: { + name: "emit_structured_result", + }, + }, + }, + undefined, + undefined, + undefined, + new Set(["emit_structured_result"]), + ); + + expect(agent.streamFn).toBeDefined(); + await agent.streamFn?.( + { api: "openai-responses", provider: "openai", id: "gpt-4.1", compat: {} } as never, + {} as never, + {}, + ); + + expect(baseStreamFn).toHaveBeenCalledWith( + expect.anything(), + expect.anything(), + expect.objectContaining({ + toolChoice: { + type: "function", + function: { + name: "emit_structured_result", + }, + }, + }), + ); + }); + + it("drops stale function toolChoice selections that are not in the allowed tool set", async () => { + const baseStreamFn = vi.fn(async () => undefined); + const agent: { streamFn?: StreamFn } = { streamFn: baseStreamFn as unknown as StreamFn }; + const cfg = { + agents: { + defaults: { + models: { + "openai/gpt-4.1": { + params: { + toolChoice: { + type: "function", + function: { + name: "bash", + }, + }, + }, + }, + }, + }, + }, + }; + + applyExtraParamsToAgent( + agent, + cfg, + "openai", + "gpt-4.1", + undefined, + undefined, + undefined, + undefined, + new Set(["emit_structured_result"]), + ); + + expect(agent.streamFn).toBeDefined(); + await agent.streamFn?.( + { api: "openai-responses", provider: "openai", id: "gpt-4.1", compat: {} } as never, + {} as never, + {}, + ); + + expect(baseStreamFn).toHaveBeenCalledWith( + expect.anything(), + expect.anything(), + expect.objectContaining({ + toolChoice: "auto", + }), + ); + }); }); diff --git a/src/agents/pi-embedded-runner/extra-params.ts b/src/agents/pi-embedded-runner/extra-params.ts index 8ed61d4aeff..61d5c2c3363 100644 --- a/src/agents/pi-embedded-runner/extra-params.ts +++ b/src/agents/pi-embedded-runner/extra-params.ts @@ -2,11 +2,13 @@ import type { StreamFn } from "@mariozechner/pi-agent-core"; import type { SimpleStreamOptions } from "@mariozechner/pi-ai"; import { streamSimple } from "@mariozechner/pi-ai"; import type { ThinkLevel } from "../../auto-reply/thinking.js"; +import type { AgentStreamParams } from "../../commands/agent/types.js"; import type { OpenClawConfig } from "../../config/config.js"; import { prepareProviderExtraParams, wrapProviderStreamFn, } from "../../plugins/provider-runtime.js"; +import { normalizeToolName } from "../tool-policy.js"; import { createAnthropicBetaHeadersWrapper, createAnthropicFastModeWrapper, @@ -72,8 +74,101 @@ export function resolveExtraParams(params: { type CacheRetentionStreamOptions = Partial & { cacheRetention?: "none" | "short" | "long"; openaiWsWarmup?: boolean; + toolChoice?: NonNullable; }; +type ToolChoiceOverride = NonNullable; + +function isToolChoiceOverride(value: unknown): value is ToolChoiceOverride { + if (value === "auto" || value === "none" || value === "required") { + return true; + } + if (!value || typeof value !== "object") { + return false; + } + const record = value as Record; + const fn = record.function; + return ( + record.type === "function" && + !!fn && + typeof fn === "object" && + typeof (fn as Record).name === "string" + ); +} + +function resolveAllowedToolChoiceName( + rawName: string, + allowedToolNames?: Set, +): string | undefined { + if (!allowedToolNames || allowedToolNames.size === 0) { + return undefined; + } + const trimmed = rawName.trim(); + if (!trimmed) { + return undefined; + } + if (allowedToolNames.has(trimmed)) { + return trimmed; + } + const normalized = normalizeToolName(trimmed); + if (allowedToolNames.has(normalized)) { + return normalized; + } + const lowered = normalized.toLowerCase(); + let caseInsensitiveMatch: string | undefined; + for (const candidate of allowedToolNames) { + if (candidate.toLowerCase() !== lowered) { + continue; + } + if (caseInsensitiveMatch && caseInsensitiveMatch !== candidate) { + return undefined; + } + caseInsensitiveMatch = candidate; + } + return caseInsensitiveMatch; +} + +function sanitizeToolChoiceOverride( + extraParams: Record | undefined, + allowedToolNames?: Set, +): Record | undefined { + if (!extraParams || !isToolChoiceOverride(extraParams.toolChoice)) { + return extraParams; + } + if (!allowedToolNames || allowedToolNames.size === 0) { + if (extraParams.toolChoice === "none") { + return extraParams; + } + return { + ...extraParams, + toolChoice: "none", + }; + } + const toolChoice = extraParams.toolChoice; + if (toolChoice === "auto" || toolChoice === "none" || toolChoice === "required") { + return extraParams; + } + const resolvedName = resolveAllowedToolChoiceName(toolChoice.function.name, allowedToolNames); + if (!resolvedName) { + return { + ...extraParams, + toolChoice: "auto", + }; + } + if (resolvedName === toolChoice.function.name) { + return extraParams; + } + return { + ...extraParams, + toolChoice: { + type: "function", + function: { + name: resolvedName, + }, + }, + }; +} + function createStreamFnWithExtraParams( baseStreamFn: StreamFn | undefined, extraParams: Record | undefined, @@ -90,6 +185,9 @@ function createStreamFnWithExtraParams( if (typeof extraParams.maxTokens === "number") { streamParams.maxTokens = extraParams.maxTokens; } + if (isToolChoiceOverride(extraParams.toolChoice)) { + streamParams.toolChoice = extraParams.toolChoice; + } const transport = extraParams.transport; if (transport === "sse" || transport === "websocket" || transport === "auto") { streamParams.transport = transport; @@ -184,6 +282,7 @@ export function applyExtraParamsToAgent( thinkingLevel?: ThinkLevel, agentId?: string, workspaceDir?: string, + allowedToolNames?: Set, ): void { const resolvedExtraParams = resolveExtraParams({ cfg, @@ -210,10 +309,11 @@ export function applyExtraParamsToAgent( thinkingLevel, }, }) ?? merged; + const sanitizedExtraParams = sanitizeToolChoiceOverride(effectiveExtraParams, allowedToolNames); const wrappedStreamFn = createStreamFnWithExtraParams( agent.streamFn, - effectiveExtraParams, + sanitizedExtraParams, provider, ); @@ -222,7 +322,7 @@ export function applyExtraParamsToAgent( agent.streamFn = wrappedStreamFn; } - const anthropicBetas = resolveAnthropicBetas(effectiveExtraParams, provider, modelId); + const anthropicBetas = resolveAnthropicBetas(sanitizedExtraParams, provider, modelId); if (anthropicBetas?.length) { log.debug( `applying Anthropic beta header for ${provider}/${modelId}: ${anthropicBetas.join(",")}`, @@ -249,7 +349,7 @@ export function applyExtraParamsToAgent( config: cfg, provider, modelId, - extraParams: effectiveExtraParams, + extraParams: sanitizedExtraParams, thinkingLevel, streamFn: providerStreamBase, }, @@ -263,7 +363,7 @@ export function applyExtraParamsToAgent( // actually handled the stream function. This covers tests/disabled plugins // and Ollama Cloud Kimi models until they gain a dedicated runtime hook. const thinkingType = resolveMoonshotThinkingType({ - configuredThinking: effectiveExtraParams?.thinking, + configuredThinking: sanitizedExtraParams?.thinking, thinkingLevel, }); agent.streamFn = createMoonshotThinkingWrapper(agent.streamFn, thinkingType); @@ -275,13 +375,13 @@ export function applyExtraParamsToAgent( agent.streamFn = createAnthropicFastModeWrapper(agent.streamFn, anthropicFastMode); } - const openAIFastMode = resolveOpenAIFastMode(effectiveExtraParams); + const openAIFastMode = resolveOpenAIFastMode(sanitizedExtraParams); if (openAIFastMode) { log.debug(`applying OpenAI fast mode for ${provider}/${modelId}`); agent.streamFn = createOpenAIFastModeWrapper(agent.streamFn); } - const openAIServiceTier = resolveOpenAIServiceTier(effectiveExtraParams); + const openAIServiceTier = resolveOpenAIServiceTier(sanitizedExtraParams); if (openAIServiceTier) { log.debug(`applying OpenAI service_tier=${openAIServiceTier} for ${provider}/${modelId}`); agent.streamFn = createOpenAIServiceTierWrapper(agent.streamFn, openAIServiceTier); @@ -292,7 +392,7 @@ export function applyExtraParamsToAgent( // server-side compaction for compatible OpenAI Responses payloads. agent.streamFn = createOpenAIResponsesContextManagementWrapper( agent.streamFn, - effectiveExtraParams, + sanitizedExtraParams, ); const rawParallelToolCalls = resolveAliasedParamValue( diff --git a/src/agents/pi-embedded-runner/run/attempt.ts b/src/agents/pi-embedded-runner/run/attempt.ts index d785218f819..9565fef1332 100644 --- a/src/agents/pi-embedded-runner/run/attempt.ts +++ b/src/agents/pi-embedded-runner/run/attempt.ts @@ -2243,18 +2243,21 @@ export async function runEmbeddedAttempt( activeSession.agent.streamFn = wrapOllamaCompatNumCtx(activeSession.agent.streamFn, numCtx); } + const effectiveStreamParams: Record = { + ...params.streamParams, + ...(params.fastMode !== undefined ? { fastMode: params.fastMode } : {}), + ...(allowedToolNames.size === 0 ? { toolChoice: "none" } : {}), + }; applyExtraParamsToAgent( activeSession.agent, params.config, params.provider, params.modelId, - { - ...params.streamParams, - fastMode: params.fastMode, - }, + effectiveStreamParams, params.thinkLevel, sessionAgentId, effectiveWorkspace, + allowedToolNames, ); if (cacheTrace) { diff --git a/src/gateway/protocol/schema/agent.ts b/src/gateway/protocol/schema/agent.ts index b9c844b135b..f9d12dc2aaa 100644 --- a/src/gateway/protocol/schema/agent.ts +++ b/src/gateway/protocol/schema/agent.ts @@ -29,6 +29,49 @@ export const AgentEventSchema = Type.Object( { additionalProperties: false }, ); +const ClientToolDefinitionSchema = Type.Object( + { + type: Type.Literal("function"), + function: Type.Object( + { + name: NonEmptyString, + description: Type.Optional(Type.String()), + parameters: Type.Optional(Type.Record(Type.String(), Type.Unknown())), + }, + { additionalProperties: false }, + ), + }, + { additionalProperties: false }, +); + +const AgentToolChoiceSchema = Type.Union([ + Type.Literal("auto"), + Type.Literal("none"), + Type.Literal("required"), + Type.Object( + { + type: Type.Literal("function"), + function: Type.Object( + { + name: NonEmptyString, + }, + { additionalProperties: false }, + ), + }, + { additionalProperties: false }, + ), +]); + +const AgentStreamParamsSchema = Type.Object( + { + temperature: Type.Optional(Type.Number()), + maxTokens: Type.Optional(Type.Number()), + fastMode: Type.Optional(Type.Boolean()), + toolChoice: Type.Optional(AgentToolChoiceSchema), + }, + { additionalProperties: false }, +); + export const SendParamsSchema = Type.Object( { to: NonEmptyString, @@ -97,6 +140,9 @@ export const AgentParamsSchema = Type.Object( lane: Type.Optional(Type.String()), extraSystemPrompt: Type.Optional(Type.String()), internalEvents: Type.Optional(Type.Array(AgentInternalEventSchema)), + clientTools: Type.Optional(Type.Array(ClientToolDefinitionSchema)), + disableTools: Type.Optional(Type.Boolean()), + streamParams: Type.Optional(AgentStreamParamsSchema), inputProvenance: Type.Optional(InputProvenanceSchema), idempotencyKey: NonEmptyString, label: Type.Optional(SessionLabelString), diff --git a/src/gateway/server-methods/agent-job.ts b/src/gateway/server-methods/agent-job.ts index 2c7e7a6aeba..3a2d847c11b 100644 --- a/src/gateway/server-methods/agent-job.ts +++ b/src/gateway/server-methods/agent-job.ts @@ -1,4 +1,5 @@ import { onAgentEvent } from "../../infra/agent-events.js"; +import { parsePendingToolCalls } from "./pending-tool-calls.js"; const AGENT_RUN_CACHE_TTL_MS = 10 * 60_000; /** @@ -19,6 +20,12 @@ type AgentRunSnapshot = { startedAt?: number; endedAt?: number; error?: string; + stopReason?: string; + pendingToolCalls?: Array<{ + id: string; + name: string; + arguments: string; + }>; ts: number; }; @@ -86,12 +93,16 @@ function createSnapshotFromLifecycleEvent(params: { typeof data?.startedAt === "number" ? data.startedAt : agentRunStarts.get(runId); const endedAt = typeof data?.endedAt === "number" ? data.endedAt : undefined; const error = typeof data?.error === "string" ? data.error : undefined; + const stopReason = typeof data?.stopReason === "string" ? data.stopReason : undefined; + const pendingToolCalls = parsePendingToolCalls(data?.pendingToolCalls); return { runId, status: phase === "error" ? "error" : data?.aborted ? "timeout" : "ok", startedAt, endedAt, error, + stopReason, + pendingToolCalls: pendingToolCalls?.length ? pendingToolCalls : undefined, ts: Date.now(), }; } diff --git a/src/gateway/server-methods/agent-wait-dedupe.test.ts b/src/gateway/server-methods/agent-wait-dedupe.test.ts index 4bbf2a575a0..c9d467f6227 100644 --- a/src/gateway/server-methods/agent-wait-dedupe.test.ts +++ b/src/gateway/server-methods/agent-wait-dedupe.test.ts @@ -258,6 +258,54 @@ describe("agent wait dedupe helper", () => { }); }); + it("extracts stopReason and pendingToolCalls from nested agent result metadata", () => { + const dedupe = new Map(); + const runId = "run-structured-agent"; + setRunEntry({ + dedupe, + kind: "agent", + runId, + payload: { + runId, + status: "ok", + startedAt: 10, + endedAt: 20, + result: { + meta: { + stopReason: "tool_calls", + pendingToolCalls: [ + { + id: "call-1", + name: "emit_structured_result", + arguments: '{"entries":[]}', + }, + ], + }, + }, + }, + }); + + expect( + readTerminalSnapshotFromGatewayDedupe({ + dedupe, + runId, + }), + ).toEqual({ + status: "ok", + startedAt: 10, + endedAt: 20, + error: undefined, + stopReason: "tool_calls", + pendingToolCalls: [ + { + id: "call-1", + name: "emit_structured_result", + arguments: '{"entries":[]}', + }, + ], + }); + }); + it("resolves multiple waiters for the same run id", async () => { const dedupe = new Map(); const runId = "run-multi"; diff --git a/src/gateway/server-methods/agent-wait-dedupe.ts b/src/gateway/server-methods/agent-wait-dedupe.ts index 50629beb3eb..59990b83225 100644 --- a/src/gateway/server-methods/agent-wait-dedupe.ts +++ b/src/gateway/server-methods/agent-wait-dedupe.ts @@ -1,10 +1,17 @@ import type { DedupeEntry } from "../server-shared.js"; +import { parsePendingToolCalls } from "./pending-tool-calls.js"; export type AgentWaitTerminalSnapshot = { status: "ok" | "error" | "timeout"; startedAt?: number; endedAt?: number; error?: string; + stopReason?: string; + pendingToolCalls?: Array<{ + id: string; + name: string; + arguments: string; + }>; }; const AGENT_WAITERS_BY_RUN_ID = new Map void>>(); @@ -72,6 +79,14 @@ export function readTerminalSnapshotFromDedupeEntry( endedAt?: unknown; error?: unknown; summary?: unknown; + stopReason?: unknown; + pendingToolCalls?: unknown; + result?: { + meta?: { + stopReason?: unknown; + pendingToolCalls?: unknown; + }; + }; } | undefined; const status = typeof payload?.status === "string" ? payload.status : undefined; @@ -87,6 +102,15 @@ export function readTerminalSnapshotFromDedupeEntry( : typeof payload?.summary === "string" ? payload.summary : entry.error?.message; + const stopReason = + typeof payload?.result?.meta?.stopReason === "string" + ? payload.result.meta.stopReason + : typeof payload?.stopReason === "string" + ? payload.stopReason + : undefined; + const pendingToolCalls = + parsePendingToolCalls(payload?.result?.meta?.pendingToolCalls) ?? + parsePendingToolCalls(payload?.pendingToolCalls); if (status === "ok" || status === "timeout") { return { @@ -94,6 +118,8 @@ export function readTerminalSnapshotFromDedupeEntry( startedAt, endedAt, error: status === "timeout" ? errorMessage : undefined, + stopReason, + pendingToolCalls, }; } if (status === "error" || !entry.ok) { @@ -102,6 +128,8 @@ export function readTerminalSnapshotFromDedupeEntry( startedAt, endedAt, error: errorMessage, + stopReason, + pendingToolCalls, }; } return null; diff --git a/src/gateway/server-methods/agent.test.ts b/src/gateway/server-methods/agent.test.ts index f29a9a4c85d..7648d9835e6 100644 --- a/src/gateway/server-methods/agent.test.ts +++ b/src/gateway/server-methods/agent.test.ts @@ -1,4 +1,4 @@ -import { describe, expect, it, vi } from "vitest"; +import { beforeEach, describe, expect, it, vi } from "vitest"; import { BARE_SESSION_RESET_PROMPT } from "../../auto-reply/reply/session-reset-prompt.js"; import { agentHandlers } from "./agent.js"; import type { GatewayRequestContext } from "./types.js"; @@ -10,6 +10,9 @@ const mocks = vi.hoisted(() => ({ agentCommand: vi.fn(), registerAgentRunContext: vi.fn(), performGatewaySessionReset: vi.fn(), + waitForAgentJob: vi.fn(), + readTerminalSnapshotFromGatewayDedupe: vi.fn(), + waitForTerminalGatewayDedupe: vi.fn(), getSubagentRunByChildSessionKey: vi.fn(), replaceSubagentRunAfterSteer: vi.fn(), loadConfigReturn: {} as Record, @@ -76,6 +79,17 @@ vi.mock("../session-reset-service.js", () => ({ (mocks.performGatewaySessionReset as (...args: unknown[]) => unknown)(...args), })); +vi.mock("./agent-job.js", () => ({ + waitForAgentJob: (...args: unknown[]) => mocks.waitForAgentJob(...args), +})); + +vi.mock("./agent-wait-dedupe.js", () => ({ + readTerminalSnapshotFromGatewayDedupe: (...args: unknown[]) => + mocks.readTerminalSnapshotFromGatewayDedupe(...args), + setGatewayDedupeEntry: vi.fn(), + waitForTerminalGatewayDedupe: (...args: unknown[]) => mocks.waitForTerminalGatewayDedupe(...args), +})); + vi.mock("../../sessions/send-policy.js", () => ({ resolveSendPolicy: () => "allow", })); @@ -105,6 +119,16 @@ type AgentParams = AgentHandlerArgs["params"]; type AgentIdentityGetHandlerArgs = Parameters<(typeof agentHandlers)["agent.identity.get"]>[0]; type AgentIdentityGetParams = AgentIdentityGetHandlerArgs["params"]; +beforeEach(() => { + mocks.waitForAgentJob.mockReset(); + mocks.waitForAgentJob.mockResolvedValue(null); + mocks.readTerminalSnapshotFromGatewayDedupe.mockReset(); + mocks.readTerminalSnapshotFromGatewayDedupe.mockReturnValue(null); + mocks.waitForTerminalGatewayDedupe.mockReset(); + mocks.waitForTerminalGatewayDedupe.mockResolvedValue(null); + mocks.loadConfigReturn = {}; +}); + async function waitForAssertion(assertion: () => void, timeoutMs = 2_000, stepMs = 5) { vi.useFakeTimers(); try { @@ -633,6 +657,64 @@ describe("gateway agent handler", () => { expect(callArgs.bestEffortDeliver).toBe(false); }); + it("forwards structured subagent options to agentCommandFromIngress", async () => { + primeMainAgentRun(); + + await invokeAgent( + { + message: "structured helper run", + agentId: "main", + sessionKey: "agent:main:main", + disableTools: true, + clientTools: [ + { + type: "function", + function: { + name: "emit_structured_result", + description: "Return a structured result payload.", + parameters: { + type: "object", + properties: { + entries: { type: "array" }, + }, + }, + }, + }, + ], + streamParams: { + toolChoice: { + type: "function", + function: { + name: "emit_structured_result", + }, + }, + }, + idempotencyKey: "test-structured-helper", + } as AgentParams, + { reqId: "structured-helper-1" }, + ); + + await vi.waitFor(() => expect(mocks.agentCommand).toHaveBeenCalled()); + const callArgs = mocks.agentCommand.mock.calls.at(-1)?.[0] as Record; + expect(callArgs.disableTools).toBe(true); + expect(callArgs.clientTools).toEqual([ + { + type: "function", + function: expect.objectContaining({ + name: "emit_structured_result", + }), + }, + ]); + expect(callArgs.streamParams).toEqual({ + toolChoice: { + type: "function", + function: { + name: "emit_structured_result", + }, + }, + }); + }); + it("rejects public spawned-run metadata fields", async () => { primeMainAgentRun(); mocks.agentCommand.mockClear(); @@ -877,4 +959,206 @@ describe("gateway agent handler", () => { }), ); }); + + it("returns structured fields for cached agent.wait snapshots", async () => { + const respond = vi.fn(); + const context = makeContext(); + context.chatAbortControllers = new Map(); + mocks.readTerminalSnapshotFromGatewayDedupe.mockReturnValue({ + status: "ok", + startedAt: 10, + endedAt: 20, + stopReason: "tool_calls", + pendingToolCalls: [ + { + id: "call-1", + name: "emit_structured_result", + arguments: '{"entries":[]}', + }, + ], + }); + + await agentHandlers["agent.wait"]({ + params: { runId: "wait-cached", timeoutMs: 100 }, + respond, + context, + } as unknown as Parameters<(typeof agentHandlers)["agent.wait"]>[0]); + + expect(respond).toHaveBeenCalledWith( + true, + expect.objectContaining({ + runId: "wait-cached", + status: "ok", + stopReason: "tool_calls", + pendingToolCalls: [ + { + id: "call-1", + name: "emit_structured_result", + arguments: '{"entries":[]}', + }, + ], + }), + ); + }); + + it("merges structured fields from dedupe when lifecycle resolves first", async () => { + const respond = vi.fn(); + const context = makeContext(); + context.chatAbortControllers = new Map(); + mocks.waitForAgentJob.mockResolvedValue({ + status: "ok", + startedAt: 10, + endedAt: 20, + }); + mocks.waitForTerminalGatewayDedupe.mockImplementation(async () => { + await Promise.resolve(); + return { + status: "ok", + startedAt: 10, + endedAt: 20, + stopReason: "tool_calls", + pendingToolCalls: [ + { + id: "call-1", + name: "emit_structured_result", + arguments: '{"entries":[]}', + }, + ], + }; + }); + + await agentHandlers["agent.wait"]({ + params: { runId: "wait-live", timeoutMs: 100 }, + respond, + context, + } as unknown as Parameters<(typeof agentHandlers)["agent.wait"]>[0]); + + expect(respond).toHaveBeenCalledWith( + true, + expect.objectContaining({ + runId: "wait-live", + status: "ok", + stopReason: "tool_calls", + pendingToolCalls: [ + { + id: "call-1", + name: "emit_structured_result", + arguments: '{"entries":[]}', + }, + ], + }), + ); + }); + + it("does not grace-wait when lifecycle resolves without tool calls", async () => { + vi.useFakeTimers(); + try { + const respond = vi.fn(); + const context = makeContext(); + context.chatAbortControllers = new Map([ + [ + "wait-no-tools", + { + controller: new AbortController(), + sessionKey: "agent:main:main", + sessionId: "test-session", + startedAtMs: Date.now(), + expiresAtMs: Date.now() + 60_000, + }, + ], + ]); + mocks.waitForAgentJob.mockResolvedValue({ + status: "ok", + startedAt: 10, + endedAt: 20, + stopReason: "stop", + }); + mocks.waitForTerminalGatewayDedupe.mockImplementation( + () => + new Promise((resolve) => { + setTimeout(() => resolve(null), 1_000); + }), + ); + + const waitPromise = agentHandlers["agent.wait"]({ + params: { runId: "wait-no-tools", timeoutMs: 100 }, + respond, + context, + } as unknown as Parameters<(typeof agentHandlers)["agent.wait"]>[0]); + + await vi.advanceTimersByTimeAsync(0); + + expect(respond).toHaveBeenCalledWith( + true, + expect.objectContaining({ + runId: "wait-no-tools", + status: "ok", + stopReason: "stop", + pendingToolCalls: undefined, + }), + ); + expect(mocks.waitForTerminalGatewayDedupe).toHaveBeenCalledTimes(1); + + await vi.runAllTimersAsync(); + await waitPromise; + } finally { + vi.useRealTimers(); + } + }); + + it("does not grace-wait for errors when lifecycle metadata omits stopReason", async () => { + vi.useFakeTimers(); + try { + const respond = vi.fn(); + const context = makeContext(); + context.chatAbortControllers = new Map([ + [ + "wait-no-stop-reason", + { + controller: new AbortController(), + sessionKey: "agent:main:main", + sessionId: "test-session", + startedAtMs: Date.now(), + expiresAtMs: Date.now() + 60_000, + }, + ], + ]); + mocks.waitForAgentJob.mockResolvedValue({ + status: "error", + startedAt: 10, + endedAt: 20, + error: "boom", + }); + mocks.waitForTerminalGatewayDedupe.mockImplementation( + () => + new Promise((resolve) => { + setTimeout(() => resolve(null), 1_000); + }), + ); + + const waitPromise = agentHandlers["agent.wait"]({ + params: { runId: "wait-no-stop-reason", timeoutMs: 100 }, + respond, + context, + } as unknown as Parameters<(typeof agentHandlers)["agent.wait"]>[0]); + + await vi.advanceTimersByTimeAsync(0); + + expect(respond).toHaveBeenCalledWith( + true, + expect.objectContaining({ + runId: "wait-no-stop-reason", + status: "error", + error: "boom", + stopReason: undefined, + pendingToolCalls: undefined, + }), + ); + + await vi.runAllTimersAsync(); + await waitPromise; + } finally { + vi.useRealTimers(); + } + }); }); diff --git a/src/gateway/server-methods/agent.ts b/src/gateway/server-methods/agent.ts index bd5637fa78f..3b3d4787f06 100644 --- a/src/gateway/server-methods/agent.ts +++ b/src/gateway/server-methods/agent.ts @@ -7,6 +7,7 @@ import { } from "../../agents/spawned-context.js"; import { buildBareSessionResetPrompt } from "../../auto-reply/reply/session-reset-prompt.js"; import { agentCommandFromIngress } from "../../commands/agent.js"; +import type { AgentStreamParams } from "../../commands/agent/types.js"; import { loadConfig } from "../../config/config.js"; import { mergeSessionEntry, @@ -67,6 +68,28 @@ import { normalizeRpcAttachmentsToChatAttachments } from "./attachment-normalize import type { GatewayRequestHandlerOptions, GatewayRequestHandlers } from "./types.js"; const RESET_COMMAND_RE = /^\/(new|reset)(?:\s+([\s\S]*))?$/i; +const AGENT_WAIT_DEDUPE_METADATA_GRACE_MS = 5_000; + +function mergeAgentWaitStructuredMetadata( + snapshot: T, + dedupeSnapshot: AgentWaitTerminalSnapshot | null | undefined, +): T { + if (!dedupeSnapshot) { + return snapshot; + } + return { + ...snapshot, + stopReason: snapshot.stopReason ?? dedupeSnapshot.stopReason, + pendingToolCalls: snapshot.pendingToolCalls ?? dedupeSnapshot.pendingToolCalls, + }; +} + +function isMissingAgentWaitStructuredMetadata(snapshot: AgentWaitTerminalSnapshot): boolean { + if (snapshot.stopReason === undefined && snapshot.status === "ok") { + return true; + } + return snapshot.stopReason === "tool_calls" && snapshot.pendingToolCalls === undefined; +} function resolveSenderIsOwnerFromClient(client: GatewayRequestHandlerOptions["client"]): boolean { const scopes = Array.isArray(client?.connect?.scopes) ? client.connect.scopes : []; @@ -232,6 +255,16 @@ export const agentHandlers: GatewayRequestHandlers = { lane?: string; extraSystemPrompt?: string; internalEvents?: AgentInternalEvent[]; + clientTools?: Array<{ + type: "function"; + function: { + name: string; + description?: string; + parameters?: Record; + }; + }>; + disableTools?: boolean; + streamParams?: AgentStreamParams; idempotencyKey: string; timeout?: number; bestEffortDeliver?: boolean; @@ -698,6 +731,9 @@ export const agentHandlers: GatewayRequestHandlers = { lane: request.lane, extraSystemPrompt: request.extraSystemPrompt, internalEvents: request.internalEvents, + clientTools: request.clientTools, + disableTools: request.disableTools, + streamParams: request.streamParams, inputProvenance, // Internal-only: allow workspace override for spawned subagent runs. workspaceDir: resolveIngressWorkspaceOverrideForSpawnedRun({ @@ -799,6 +835,8 @@ export const agentHandlers: GatewayRequestHandlers = { startedAt: cachedGatewaySnapshot.startedAt, endedAt: cachedGatewaySnapshot.endedAt, error: cachedGatewaySnapshot.error, + stopReason: cachedGatewaySnapshot.stopReason, + pendingToolCalls: cachedGatewaySnapshot.pendingToolCalls, }); return; } @@ -830,6 +868,44 @@ export const agentHandlers: GatewayRequestHandlers = { first.snapshot; if (snapshot) { if (first.source === "lifecycle") { + snapshot = mergeAgentWaitStructuredMetadata( + snapshot, + readTerminalSnapshotFromGatewayDedupe({ + dedupe: context.dedupe, + runId, + ignoreAgentTerminalSnapshot: hasActiveChatRun, + }), + ); + if (snapshot.stopReason === undefined) { + const immediateDedupeMetadata = + (await Promise.race([ + dedupePromise, + Promise.resolve(null), + ])) ?? null; + snapshot = mergeAgentWaitStructuredMetadata(snapshot, immediateDedupeMetadata); + } + if (isMissingAgentWaitStructuredMetadata(snapshot)) { + let graceTimer: ReturnType | null = null; + const dedupeMetadata = + (await Promise.race([ + dedupePromise.finally(() => { + if (graceTimer != null) { + clearTimeout(graceTimer); + } + }), + new Promise((resolve) => { + graceTimer = setTimeout( + () => resolve(null), + Math.max( + 1, + Math.min(timeoutMs, AGENT_WAIT_DEDUPE_METADATA_GRACE_MS, 2_147_483_647), + ), + ); + graceTimer.unref?.(); + }), + ])) ?? null; + snapshot = mergeAgentWaitStructuredMetadata(snapshot, dedupeMetadata); + } dedupeAbortController.abort(); } else { lifecycleAbortController.abort(); @@ -853,6 +929,8 @@ export const agentHandlers: GatewayRequestHandlers = { startedAt: snapshot.startedAt, endedAt: snapshot.endedAt, error: snapshot.error, + stopReason: snapshot.stopReason, + pendingToolCalls: snapshot.pendingToolCalls, }); }, }; diff --git a/src/gateway/server-methods/pending-tool-calls.ts b/src/gateway/server-methods/pending-tool-calls.ts new file mode 100644 index 00000000000..11bbbddf722 --- /dev/null +++ b/src/gateway/server-methods/pending-tool-calls.ts @@ -0,0 +1,29 @@ +export type PendingToolCall = { + id: string; + name: string; + arguments: string; +}; + +export function parsePendingToolCalls(value: unknown): PendingToolCall[] | undefined { + if (!Array.isArray(value)) { + return undefined; + } + const calls = value + .map((entry) => { + if (!entry || typeof entry !== "object") { + return null; + } + const record = entry as Record; + return typeof record.id === "string" && + typeof record.name === "string" && + typeof record.arguments === "string" + ? { + id: record.id, + name: record.name, + arguments: record.arguments, + } + : null; + }) + .filter((entry): entry is PendingToolCall => entry !== null); + return calls.length > 0 ? calls : undefined; +} diff --git a/src/gateway/server-plugins.test.ts b/src/gateway/server-plugins.test.ts index 7bb0fb20f06..fe374c70aae 100644 --- a/src/gateway/server-plugins.test.ts +++ b/src/gateway/server-plugins.test.ts @@ -29,14 +29,15 @@ vi.mock("./server-methods.js", () => ({ })); vi.mock("../channels/registry.js", () => ({ - CHAT_CHANNEL_ORDER: [], - CHANNEL_IDS: [], + CHAT_CHANNEL_ORDER: ["telegram", "discord", "slack"], + CHANNEL_IDS: ["telegram", "discord", "slack"], listChatChannels: () => [], listChatChannelAliases: () => [], getChatChannelMeta: () => null, normalizeChatChannelId: () => null, normalizeChannelId: () => null, - normalizeAnyChannelId: () => null, + normalizeAnyChannelId: (raw?: string | null) => + typeof raw === "string" && raw.trim().length > 0 ? raw.trim().toLowerCase() : null, formatChannelPrimerLine: () => "", formatChannelSelectionLine: () => "", })); @@ -92,6 +93,11 @@ function getLastDispatchedClientScopes(): string[] { return Array.isArray(scopes) ? scopes : []; } +function getLastDispatchedRequest() { + const call = handleGatewayRequest.mock.calls.at(-1)?.[0]; + return call?.req; +} + async function loadTestModules() { serverPluginsModule = await import("./server-plugins.js"); runtimeModule = await import("../plugins/runtime/index.js"); @@ -99,6 +105,10 @@ async function loadTestModules() { methodScopesModule = await import("./method-scopes.js"); } +async function importServerPluginsModule(): Promise { + return import("./server-plugins.js"); +} + async function createSubagentRuntime( serverPlugins: ServerPluginsModule, cfg: Record = {}, @@ -147,7 +157,17 @@ beforeEach(() => { opts.respond(true, { runId: "run-1" }); return; case "agent.wait": - opts.respond(true, { status: "ok" }); + opts.respond(true, { + status: "ok", + stopReason: "tool_calls", + pendingToolCalls: [ + { + id: "call-1", + name: "emit_structured_result", + arguments: '{"entries":[]}', + }, + ], + }); return; case "sessions.get": opts.respond(true, { messages: [] }); @@ -579,4 +599,92 @@ describe("loadGatewayPlugins", () => { | undefined; expect(dispatched?.marker).toBe("after-mutation"); }); + + test("forwards structured plugin subagent options to gateway agent methods", async () => { + const serverPlugins = await importServerPluginsModule(); + const runtime = await createSubagentRuntime(serverPlugins); + serverPlugins.setFallbackGatewayContext(createTestContext("structured-output")); + + await runtime.run({ + sessionKey: "s-structured", + message: "extract memories", + disableTools: true, + clientTools: [ + { + type: "function", + function: { + name: "emit_structured_result", + description: "Return a structured result payload.", + parameters: { + type: "object", + properties: { + entries: { + type: "array", + }, + }, + }, + }, + }, + ], + streamParams: { + toolChoice: { + type: "function", + function: { + name: "emit_structured_result", + }, + }, + }, + }); + + expect(getLastDispatchedRequest()).toEqual( + expect.objectContaining({ + type: "req", + id: expect.any(String), + method: "agent", + params: expect.objectContaining({ + sessionKey: "s-structured", + message: "extract memories", + disableTools: true, + clientTools: [ + { + type: "function", + function: expect.objectContaining({ + name: "emit_structured_result", + }), + }, + ], + streamParams: { + toolChoice: { + type: "function", + function: { + name: "emit_structured_result", + }, + }, + }, + }), + }), + ); + }); + + test("returns pending tool calls from gateway agent.wait", async () => { + const serverPlugins = await importServerPluginsModule(); + const runtime = await createSubagentRuntime(serverPlugins); + + const result = await runtime.waitForRun({ + runId: "run-1", + timeoutMs: 1_000, + }); + + expect(result).toEqual({ + status: "ok", + stopReason: "tool_calls", + pendingToolCalls: [ + { + id: "call-1", + name: "emit_structured_result", + arguments: '{"entries":[]}', + }, + ], + }); + }); }); diff --git a/src/gateway/server-plugins.ts b/src/gateway/server-plugins.ts index 071819be73e..aa84d6b4955 100644 --- a/src/gateway/server-plugins.ts +++ b/src/gateway/server-plugins.ts @@ -333,6 +333,9 @@ function createGatewaySubagentRuntime(): PluginRuntime["subagent"] { ...(allowOverride && params.model && { model: params.model }), ...(params.extraSystemPrompt && { extraSystemPrompt: params.extraSystemPrompt }), ...(params.lane && { lane: params.lane }), + ...(params.clientTools && { clientTools: params.clientTools }), + ...(params.disableTools === true && { disableTools: true }), + ...(params.streamParams && { streamParams: params.streamParams }), ...(params.idempotencyKey && { idempotencyKey: params.idempotencyKey }), }, { @@ -346,13 +349,19 @@ function createGatewaySubagentRuntime(): PluginRuntime["subagent"] { return { runId }; }, async waitForRun(params) { - const payload = await dispatchGatewayMethod<{ status?: string; error?: string }>( - "agent.wait", - { - runId: params.runId, - ...(params.timeoutMs != null && { timeoutMs: params.timeoutMs }), - }, - ); + const payload = await dispatchGatewayMethod<{ + status?: string; + error?: string; + stopReason?: string; + pendingToolCalls?: Array<{ + id: string; + name: string; + arguments: string; + }>; + }>("agent.wait", { + runId: params.runId, + ...(params.timeoutMs != null && { timeoutMs: params.timeoutMs }), + }); const status = payload?.status; if (status !== "ok" && status !== "error" && status !== "timeout") { throw new Error(`Gateway agent.wait returned unexpected status: ${status}`); @@ -360,6 +369,11 @@ function createGatewaySubagentRuntime(): PluginRuntime["subagent"] { return { status, ...(typeof payload?.error === "string" && payload.error && { error: payload.error }), + ...(typeof payload?.stopReason === "string" && + payload.stopReason && { stopReason: payload.stopReason }), + ...(Array.isArray(payload?.pendingToolCalls) && payload.pendingToolCalls.length > 0 + ? { pendingToolCalls: payload.pendingToolCalls } + : {}), }; }, getSessionMessages, diff --git a/src/plugins/runtime/types.ts b/src/plugins/runtime/types.ts index aa1118ecf92..e9511e709d9 100644 --- a/src/plugins/runtime/types.ts +++ b/src/plugins/runtime/types.ts @@ -1,3 +1,5 @@ +import type { ClientToolDefinition } from "../../agents/pi-embedded-runner/run/params.js"; +import type { AgentStreamParams } from "../../commands/agent/types.js"; import type { PluginRuntimeChannel } from "./types-channel.js"; import type { PluginRuntimeCore, RuntimeLogger } from "./types-core.js"; @@ -14,6 +16,9 @@ export type SubagentRunParams = { lane?: string; deliver?: boolean; idempotencyKey?: string; + clientTools?: ClientToolDefinition[]; + disableTools?: boolean; + streamParams?: AgentStreamParams; }; export type SubagentRunResult = { @@ -28,6 +33,12 @@ export type SubagentWaitParams = { export type SubagentWaitResult = { status: "ok" | "error" | "timeout"; error?: string; + stopReason?: string; + pendingToolCalls?: Array<{ + id: string; + name: string; + arguments: string; + }>; }; export type SubagentGetSessionMessagesParams = {