Merge 714b5947424e54e48377c8ee191ffe6ef0b329ad into d78e13f545136fcbba1feceecc5e0485a06c33a6

This commit is contained in:
j.osawa 2026-03-21 04:50:38 +00:00 committed by GitHub
commit 0ca2ec6793
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 916 additions and 23 deletions

View File

@ -537,6 +537,9 @@ public struct AgentParams: Codable, Sendable {
public let lane: String?
public let extrasystemprompt: String?
public let internalevents: [[String: AnyCodable]]?
public let clienttools: [[String: AnyCodable]]?
public let disabletools: Bool?
public let streamparams: [String: AnyCodable]?
public let inputprovenance: [String: AnyCodable]?
public let idempotencykey: String
public let label: String?
@ -566,6 +569,9 @@ public struct AgentParams: Codable, Sendable {
lane: String?,
extrasystemprompt: String?,
internalevents: [[String: AnyCodable]]?,
clienttools: [[String: AnyCodable]]?,
disabletools: Bool?,
streamparams: [String: AnyCodable]?,
inputprovenance: [String: AnyCodable]?,
idempotencykey: String,
label: String?)
@ -594,6 +600,9 @@ public struct AgentParams: Codable, Sendable {
self.lane = lane
self.extrasystemprompt = extrasystemprompt
self.internalevents = internalevents
self.clienttools = clienttools
self.disabletools = disabletools
self.streamparams = streamparams
self.inputprovenance = inputprovenance
self.idempotencykey = idempotencykey
self.label = label
@ -624,6 +633,9 @@ public struct AgentParams: Codable, Sendable {
case lane
case extrasystemprompt = "extraSystemPrompt"
case internalevents = "internalEvents"
case clienttools = "clientTools"
case disabletools = "disableTools"
case streamparams = "streamParams"
case inputprovenance = "inputProvenance"
case idempotencykey = "idempotencyKey"
case label

View File

@ -537,6 +537,9 @@ public struct AgentParams: Codable, Sendable {
public let lane: String?
public let extrasystemprompt: String?
public let internalevents: [[String: AnyCodable]]?
public let clienttools: [[String: AnyCodable]]?
public let disabletools: Bool?
public let streamparams: [String: AnyCodable]?
public let inputprovenance: [String: AnyCodable]?
public let idempotencykey: String
public let label: String?
@ -566,6 +569,9 @@ public struct AgentParams: Codable, Sendable {
lane: String?,
extrasystemprompt: String?,
internalevents: [[String: AnyCodable]]?,
clienttools: [[String: AnyCodable]]?,
disabletools: Bool?,
streamparams: [String: AnyCodable]?,
inputprovenance: [String: AnyCodable]?,
idempotencykey: String,
label: String?)
@ -594,6 +600,9 @@ public struct AgentParams: Codable, Sendable {
self.lane = lane
self.extrasystemprompt = extrasystemprompt
self.internalevents = internalevents
self.clienttools = clienttools
self.disabletools = disabletools
self.streamparams = streamparams
self.inputprovenance = inputprovenance
self.idempotencykey = idempotencykey
self.label = label
@ -624,6 +633,9 @@ public struct AgentParams: Codable, Sendable {
case lane
case extrasystemprompt = "extraSystemPrompt"
case internalevents = "internalEvents"
case clienttools = "clientTools"
case disabletools = "disableTools"
case streamparams = "streamParams"
case inputprovenance = "inputProvenance"
case idempotencykey = "idempotencyKey"
case label

View File

@ -512,6 +512,7 @@ function runAgentAttempt(params: {
prompt: effectivePrompt,
images: params.isFallbackRetry ? undefined : params.opts.images,
clientTools: params.opts.clientTools,
disableTools: params.opts.disableTools,
provider: params.providerOverride,
model: params.modelOverride,
authProfileId,
@ -1236,6 +1237,7 @@ async function agentCommandInternal(
endedAt: Date.now(),
aborted: result.meta.aborted ?? false,
stopReason,
pendingToolCalls: result.meta.pendingToolCalls,
},
});
}

View File

@ -17,6 +17,17 @@ export type AgentStreamParams = {
maxTokens?: number;
/** Provider fast-mode override (best-effort). */
fastMode?: boolean;
/** Provider tool-choice override (best-effort). */
toolChoice?:
| "auto"
| "none"
| "required"
| {
type: "function";
function: {
name: string;
};
};
};
export type AgentRunContext = {
@ -37,6 +48,8 @@ export type AgentCommandOpts = {
images?: ImageContent[];
/** Optional client-provided tools (OpenResponses hosted tools). */
clientTools?: ClientToolDefinition[];
/** Disable built-in tools for this run while still allowing clientTools. */
disableTools?: boolean;
/** Agent id override (must exist in config). */
agentId?: string;
/** Per-run provider override. */

View File

@ -144,4 +144,98 @@ describe("cacheRetention default behavior", () => {
provider: "anthropic",
});
});
// Verifies that a per-run function toolChoice whose name is in the allowed tool set
// is passed through unchanged to the underlying stream function.
it("forwards toolChoice overrides to the wrapped stream options", async () => {
const baseStreamFn = vi.fn(async () => undefined);
const agent: { streamFn?: StreamFn } = { streamFn: baseStreamFn as unknown as StreamFn };
// No config is supplied; the override comes only from the explicit extra params argument.
applyExtraParamsToAgent(
agent,
undefined,
"openai",
"gpt-4.1",
{
toolChoice: {
type: "function",
function: {
name: "emit_structured_result",
},
},
},
undefined,
undefined,
undefined,
// Allowed tool names: contains the tool the override selects.
new Set(["emit_structured_result"]),
);
expect(agent.streamFn).toBeDefined();
// Invoke the wrapped stream function with empty options so only injected params show up.
await agent.streamFn?.(
{ api: "openai-responses", provider: "openai", id: "gpt-4.1", compat: {} } as never,
{} as never,
{},
);
expect(baseStreamFn).toHaveBeenCalledWith(
expect.anything(),
expect.anything(),
expect.objectContaining({
toolChoice: {
type: "function",
function: {
name: "emit_structured_result",
},
},
}),
);
});
// Verifies that a function toolChoice coming from model config is replaced by "auto"
// when the selected tool ("bash") is not among the allowed tool names for the run.
it("drops stale function toolChoice selections that are not in the allowed tool set", async () => {
const baseStreamFn = vi.fn(async () => undefined);
const agent: { streamFn?: StreamFn } = { streamFn: baseStreamFn as unknown as StreamFn };
// Model-level params for openai/gpt-4.1 pin toolChoice to a tool named "bash".
const cfg = {
agents: {
defaults: {
models: {
"openai/gpt-4.1": {
params: {
toolChoice: {
type: "function",
function: {
name: "bash",
},
},
},
},
},
},
},
};
applyExtraParamsToAgent(
agent,
cfg,
"openai",
"gpt-4.1",
undefined,
undefined,
undefined,
undefined,
// Allowed tool names deliberately exclude "bash".
new Set(["emit_structured_result"]),
);
expect(agent.streamFn).toBeDefined();
await agent.streamFn?.(
{ api: "openai-responses", provider: "openai", id: "gpt-4.1", compat: {} } as never,
{} as never,
{},
);
// The unresolvable function selection is expected to degrade to "auto".
expect(baseStreamFn).toHaveBeenCalledWith(
expect.anything(),
expect.anything(),
expect.objectContaining({
toolChoice: "auto",
}),
);
});
});

View File

@ -2,11 +2,13 @@ import type { StreamFn } from "@mariozechner/pi-agent-core";
import type { SimpleStreamOptions } from "@mariozechner/pi-ai";
import { streamSimple } from "@mariozechner/pi-ai";
import type { ThinkLevel } from "../../auto-reply/thinking.js";
import type { AgentStreamParams } from "../../commands/agent/types.js";
import type { OpenClawConfig } from "../../config/config.js";
import {
prepareProviderExtraParams,
wrapProviderStreamFn,
} from "../../plugins/provider-runtime.js";
import { normalizeToolName } from "../tool-policy.js";
import {
createAnthropicBetaHeadersWrapper,
createAnthropicFastModeWrapper,
@ -72,8 +74,101 @@ export function resolveExtraParams(params: {
type CacheRetentionStreamOptions = Partial<SimpleStreamOptions> & {
cacheRetention?: "none" | "short" | "long";
openaiWsWarmup?: boolean;
toolChoice?: NonNullable<AgentStreamParams["toolChoice"]>;
};
type ToolChoiceOverride = NonNullable<AgentStreamParams["toolChoice"]>;
/**
 * Type guard for a tool-choice override value.
 *
 * Accepts the string modes "auto", "none" and "required", or an object shaped
 * like `{ type: "function", function: { name: <string> } }`. Only `type` and
 * `function.name` are inspected; any other properties are ignored.
 *
 * @param value - Arbitrary value to test.
 * @returns `true` when `value` matches one of the accepted forms.
 */
function isToolChoiceOverride(value: unknown): value is ToolChoiceOverride {
  switch (value) {
    case "auto":
    case "none":
    case "required":
      return true;
    default:
      break;
  }
  // Everything else must be a non-null object describing a function selection.
  if (typeof value !== "object" || value === null) {
    return false;
  }
  const candidate = value as Record<string, unknown>;
  if (candidate.type !== "function") {
    return false;
  }
  const target = candidate.function;
  if (typeof target !== "object" || target === null) {
    return false;
  }
  return typeof (target as Record<string, unknown>).name === "string";
}
/**
 * Maps a requested tool name onto the spelling used in the allowed tool set.
 *
 * Lookup order: the trimmed name verbatim, then its `normalizeToolName` form,
 * then a case-insensitive comparison of that normalized form against every
 * allowed name.
 *
 * @param rawName - Tool name taken from the tool-choice override.
 * @param allowedToolNames - Names of the tools available for the run.
 * @returns The matching allowed name, or `undefined` when the set is missing
 *   or empty, the name is blank, nothing matches, or more than one allowed
 *   name matches case-insensitively.
 */
function resolveAllowedToolChoiceName(
  rawName: string,
  allowedToolNames?: Set<string>,
): string | undefined {
  if (!allowedToolNames?.size) {
    return undefined;
  }
  const requested = rawName.trim();
  if (requested.length === 0) {
    return undefined;
  }
  if (allowedToolNames.has(requested)) {
    return requested;
  }
  const canonical = normalizeToolName(requested);
  if (allowedToolNames.has(canonical)) {
    return canonical;
  }
  // Last resort: accept a case-insensitive hit, but only when it is unambiguous.
  const folded = canonical.toLowerCase();
  const matches = Array.from(allowedToolNames).filter(
    (candidate) => candidate.toLowerCase() === folded,
  );
  return matches.length === 1 ? matches[0] : undefined;
}
/**
 * Reconciles a `toolChoice` entry in the extra params with the tools that are
 * actually available for the run.
 *
 * - Missing params, or a `toolChoice` that is not a recognised override, are
 *   returned untouched.
 * - With no allowed tools (set missing or empty) the choice is forced to
 *   "none".
 * - The string modes are kept as-is when tools are available.
 * - A function selection is rewritten to the allowed spelling of the tool
 *   name, or replaced by "auto" when the name cannot be resolved.
 *
 * The input object is never mutated; it is returned by reference when no
 * change is needed and shallow-copied otherwise.
 *
 * @param extraParams - Resolved extra params, possibly carrying `toolChoice`.
 * @param allowedToolNames - Names of the tools available for the run.
 * @returns The original params or a copy with an adjusted `toolChoice`.
 */
function sanitizeToolChoiceOverride(
  extraParams: Record<string, unknown> | undefined,
  allowedToolNames?: Set<string>,
): Record<string, unknown> | undefined {
  const requested = extraParams?.toolChoice;
  if (!extraParams || !isToolChoiceOverride(requested)) {
    return extraParams;
  }
  const replaceToolChoice = (toolChoice: unknown): Record<string, unknown> => ({
    ...extraParams,
    toolChoice,
  });

  const hasAllowedTools = (allowedToolNames?.size ?? 0) > 0;
  if (!hasAllowedTools) {
    // Nothing can be called, so any other choice would be unsatisfiable.
    return requested === "none" ? extraParams : replaceToolChoice("none");
  }
  if (typeof requested === "string") {
    // "auto" | "none" | "required" do not name a tool and need no checking.
    return extraParams;
  }

  const requestedName = requested.function.name;
  const resolvedName = resolveAllowedToolChoiceName(requestedName, allowedToolNames);
  if (!resolvedName) {
    return replaceToolChoice("auto");
  }
  return resolvedName === requestedName
    ? extraParams
    : replaceToolChoice({ type: "function", function: { name: resolvedName } });
}
function createStreamFnWithExtraParams(
baseStreamFn: StreamFn | undefined,
extraParams: Record<string, unknown> | undefined,
@ -90,6 +185,9 @@ function createStreamFnWithExtraParams(
if (typeof extraParams.maxTokens === "number") {
streamParams.maxTokens = extraParams.maxTokens;
}
if (isToolChoiceOverride(extraParams.toolChoice)) {
streamParams.toolChoice = extraParams.toolChoice;
}
const transport = extraParams.transport;
if (transport === "sse" || transport === "websocket" || transport === "auto") {
streamParams.transport = transport;
@ -184,6 +282,7 @@ export function applyExtraParamsToAgent(
thinkingLevel?: ThinkLevel,
agentId?: string,
workspaceDir?: string,
allowedToolNames?: Set<string>,
): void {
const resolvedExtraParams = resolveExtraParams({
cfg,
@ -210,10 +309,11 @@ export function applyExtraParamsToAgent(
thinkingLevel,
},
}) ?? merged;
const sanitizedExtraParams = sanitizeToolChoiceOverride(effectiveExtraParams, allowedToolNames);
const wrappedStreamFn = createStreamFnWithExtraParams(
agent.streamFn,
effectiveExtraParams,
sanitizedExtraParams,
provider,
);
@ -222,7 +322,7 @@ export function applyExtraParamsToAgent(
agent.streamFn = wrappedStreamFn;
}
const anthropicBetas = resolveAnthropicBetas(effectiveExtraParams, provider, modelId);
const anthropicBetas = resolveAnthropicBetas(sanitizedExtraParams, provider, modelId);
if (anthropicBetas?.length) {
log.debug(
`applying Anthropic beta header for ${provider}/${modelId}: ${anthropicBetas.join(",")}`,
@ -249,7 +349,7 @@ export function applyExtraParamsToAgent(
config: cfg,
provider,
modelId,
extraParams: effectiveExtraParams,
extraParams: sanitizedExtraParams,
thinkingLevel,
streamFn: providerStreamBase,
},
@ -263,7 +363,7 @@ export function applyExtraParamsToAgent(
// actually handled the stream function. This covers tests/disabled plugins
// and Ollama Cloud Kimi models until they gain a dedicated runtime hook.
const thinkingType = resolveMoonshotThinkingType({
configuredThinking: effectiveExtraParams?.thinking,
configuredThinking: sanitizedExtraParams?.thinking,
thinkingLevel,
});
agent.streamFn = createMoonshotThinkingWrapper(agent.streamFn, thinkingType);
@ -275,13 +375,13 @@ export function applyExtraParamsToAgent(
agent.streamFn = createAnthropicFastModeWrapper(agent.streamFn, anthropicFastMode);
}
const openAIFastMode = resolveOpenAIFastMode(effectiveExtraParams);
const openAIFastMode = resolveOpenAIFastMode(sanitizedExtraParams);
if (openAIFastMode) {
log.debug(`applying OpenAI fast mode for ${provider}/${modelId}`);
agent.streamFn = createOpenAIFastModeWrapper(agent.streamFn);
}
const openAIServiceTier = resolveOpenAIServiceTier(effectiveExtraParams);
const openAIServiceTier = resolveOpenAIServiceTier(sanitizedExtraParams);
if (openAIServiceTier) {
log.debug(`applying OpenAI service_tier=${openAIServiceTier} for ${provider}/${modelId}`);
agent.streamFn = createOpenAIServiceTierWrapper(agent.streamFn, openAIServiceTier);
@ -292,7 +392,7 @@ export function applyExtraParamsToAgent(
// server-side compaction for compatible OpenAI Responses payloads.
agent.streamFn = createOpenAIResponsesContextManagementWrapper(
agent.streamFn,
effectiveExtraParams,
sanitizedExtraParams,
);
const rawParallelToolCalls = resolveAliasedParamValue(

View File

@ -2243,18 +2243,21 @@ export async function runEmbeddedAttempt(
activeSession.agent.streamFn = wrapOllamaCompatNumCtx(activeSession.agent.streamFn, numCtx);
}
const effectiveStreamParams: Record<string, unknown> = {
...params.streamParams,
...(params.fastMode !== undefined ? { fastMode: params.fastMode } : {}),
...(allowedToolNames.size === 0 ? { toolChoice: "none" } : {}),
};
applyExtraParamsToAgent(
activeSession.agent,
params.config,
params.provider,
params.modelId,
{
...params.streamParams,
fastMode: params.fastMode,
},
effectiveStreamParams,
params.thinkLevel,
sessionAgentId,
effectiveWorkspace,
allowedToolNames,
);
if (cacheTrace) {

View File

@ -29,6 +29,49 @@ export const AgentEventSchema = Type.Object(
{ additionalProperties: false },
);
/**
 * Schema for one client-supplied tool definition.
 *
 * Shape: `{ type: "function", function: { name, description?, parameters? } }`.
 * `name` must satisfy `NonEmptyString`; `parameters` is an open string-keyed
 * record whose values are not validated here. Unknown properties are rejected
 * at both nesting levels.
 */
const ClientToolDefinitionSchema = Type.Object(
{
type: Type.Literal("function"),
function: Type.Object(
{
name: NonEmptyString,
description: Type.Optional(Type.String()),
parameters: Type.Optional(Type.Record(Type.String(), Type.Unknown())),
},
{ additionalProperties: false },
),
},
{ additionalProperties: false },
);
/**
 * Schema for a tool-choice override: one of the literals "auto", "none" or
 * "required", or `{ type: "function", function: { name } }` selecting a single
 * tool by its non-empty name. Unknown properties on the object form are rejected.
 */
const AgentToolChoiceSchema = Type.Union([
Type.Literal("auto"),
Type.Literal("none"),
Type.Literal("required"),
Type.Object(
{
type: Type.Literal("function"),
function: Type.Object(
{
name: NonEmptyString,
},
{ additionalProperties: false },
),
},
{ additionalProperties: false },
),
]);
/**
 * Schema for the optional per-request stream parameters. Every field is
 * optional and unknown properties are rejected.
 */
const AgentStreamParamsSchema = Type.Object(
{
temperature: Type.Optional(Type.Number()),
// NOTE(review): accepts any number, including fractions — confirm whether an integer is intended.
maxTokens: Type.Optional(Type.Number()),
fastMode: Type.Optional(Type.Boolean()),
toolChoice: Type.Optional(AgentToolChoiceSchema),
},
{ additionalProperties: false },
);
export const SendParamsSchema = Type.Object(
{
to: NonEmptyString,
@ -97,6 +140,9 @@ export const AgentParamsSchema = Type.Object(
lane: Type.Optional(Type.String()),
extraSystemPrompt: Type.Optional(Type.String()),
internalEvents: Type.Optional(Type.Array(AgentInternalEventSchema)),
clientTools: Type.Optional(Type.Array(ClientToolDefinitionSchema)),
disableTools: Type.Optional(Type.Boolean()),
streamParams: Type.Optional(AgentStreamParamsSchema),
inputProvenance: Type.Optional(InputProvenanceSchema),
idempotencyKey: NonEmptyString,
label: Type.Optional(SessionLabelString),

View File

@ -1,4 +1,5 @@
import { onAgentEvent } from "../../infra/agent-events.js";
import { parsePendingToolCalls } from "./pending-tool-calls.js";
const AGENT_RUN_CACHE_TTL_MS = 10 * 60_000;
/**
@ -19,6 +20,12 @@ type AgentRunSnapshot = {
startedAt?: number;
endedAt?: number;
error?: string;
stopReason?: string;
pendingToolCalls?: Array<{
id: string;
name: string;
arguments: string;
}>;
ts: number;
};
@ -86,12 +93,16 @@ function createSnapshotFromLifecycleEvent(params: {
typeof data?.startedAt === "number" ? data.startedAt : agentRunStarts.get(runId);
const endedAt = typeof data?.endedAt === "number" ? data.endedAt : undefined;
const error = typeof data?.error === "string" ? data.error : undefined;
const stopReason = typeof data?.stopReason === "string" ? data.stopReason : undefined;
const pendingToolCalls = parsePendingToolCalls(data?.pendingToolCalls);
return {
runId,
status: phase === "error" ? "error" : data?.aborted ? "timeout" : "ok",
startedAt,
endedAt,
error,
stopReason,
pendingToolCalls: pendingToolCalls?.length ? pendingToolCalls : undefined,
ts: Date.now(),
};
}

View File

@ -258,6 +258,54 @@ describe("agent wait dedupe helper", () => {
});
});
// Verifies that the terminal snapshot read from the dedupe map picks up
// stopReason and pendingToolCalls stored under payload.result.meta.
it("extracts stopReason and pendingToolCalls from nested agent result metadata", () => {
const dedupe = new Map();
const runId = "run-structured-agent";
// Seed an "agent" entry whose structured fields live only in result.meta.
setRunEntry({
dedupe,
kind: "agent",
runId,
payload: {
runId,
status: "ok",
startedAt: 10,
endedAt: 20,
result: {
meta: {
stopReason: "tool_calls",
pendingToolCalls: [
{
id: "call-1",
name: "emit_structured_result",
arguments: '{"entries":[]}',
},
],
},
},
},
});
// The snapshot is expected to be flat: the nested metadata surfaces as top-level fields.
expect(
readTerminalSnapshotFromGatewayDedupe({
dedupe,
runId,
}),
).toEqual({
status: "ok",
startedAt: 10,
endedAt: 20,
error: undefined,
stopReason: "tool_calls",
pendingToolCalls: [
{
id: "call-1",
name: "emit_structured_result",
arguments: '{"entries":[]}',
},
],
});
});
it("resolves multiple waiters for the same run id", async () => {
const dedupe = new Map();
const runId = "run-multi";

View File

@ -1,10 +1,17 @@
import type { DedupeEntry } from "../server-shared.js";
import { parsePendingToolCalls } from "./pending-tool-calls.js";
/**
 * Terminal state of an agent run as reported to `agent.wait` callers.
 */
export type AgentWaitTerminalSnapshot = {
/** Final outcome of the run. */
status: "ok" | "error" | "timeout";
// NOTE(review): startedAt/endedAt look like epoch-millisecond timestamps — confirm against producers.
startedAt?: number;
endedAt?: number;
/** Error or summary message, when one is available. */
error?: string;
/** Stop reason string reported with the run result (e.g. "tool_calls"), if any. */
stopReason?: string;
/** Tool calls carried in the run's result metadata; `arguments` is kept as a raw string. */
pendingToolCalls?: Array<{
id: string;
name: string;
arguments: string;
}>;
};
const AGENT_WAITERS_BY_RUN_ID = new Map<string, Set<() => void>>();
@ -72,6 +79,14 @@ export function readTerminalSnapshotFromDedupeEntry(
endedAt?: unknown;
error?: unknown;
summary?: unknown;
stopReason?: unknown;
pendingToolCalls?: unknown;
result?: {
meta?: {
stopReason?: unknown;
pendingToolCalls?: unknown;
};
};
}
| undefined;
const status = typeof payload?.status === "string" ? payload.status : undefined;
@ -87,6 +102,15 @@ export function readTerminalSnapshotFromDedupeEntry(
: typeof payload?.summary === "string"
? payload.summary
: entry.error?.message;
const stopReason =
typeof payload?.result?.meta?.stopReason === "string"
? payload.result.meta.stopReason
: typeof payload?.stopReason === "string"
? payload.stopReason
: undefined;
const pendingToolCalls =
parsePendingToolCalls(payload?.result?.meta?.pendingToolCalls) ??
parsePendingToolCalls(payload?.pendingToolCalls);
if (status === "ok" || status === "timeout") {
return {
@ -94,6 +118,8 @@ export function readTerminalSnapshotFromDedupeEntry(
startedAt,
endedAt,
error: status === "timeout" ? errorMessage : undefined,
stopReason,
pendingToolCalls,
};
}
if (status === "error" || !entry.ok) {
@ -102,6 +128,8 @@ export function readTerminalSnapshotFromDedupeEntry(
startedAt,
endedAt,
error: errorMessage,
stopReason,
pendingToolCalls,
};
}
return null;

View File

@ -1,4 +1,4 @@
import { describe, expect, it, vi } from "vitest";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { BARE_SESSION_RESET_PROMPT } from "../../auto-reply/reply/session-reset-prompt.js";
import { agentHandlers } from "./agent.js";
import type { GatewayRequestContext } from "./types.js";
@ -10,6 +10,9 @@ const mocks = vi.hoisted(() => ({
agentCommand: vi.fn(),
registerAgentRunContext: vi.fn(),
performGatewaySessionReset: vi.fn(),
waitForAgentJob: vi.fn(),
readTerminalSnapshotFromGatewayDedupe: vi.fn(),
waitForTerminalGatewayDedupe: vi.fn(),
getSubagentRunByChildSessionKey: vi.fn(),
replaceSubagentRunAfterSteer: vi.fn(),
loadConfigReturn: {} as Record<string, unknown>,
@ -76,6 +79,17 @@ vi.mock("../session-reset-service.js", () => ({
(mocks.performGatewaySessionReset as (...args: unknown[]) => unknown)(...args),
}));
// Route the agent.wait dependencies through the hoisted `mocks` object so each
// test can control what the lifecycle wait and the dedupe lookups return.
vi.mock("./agent-job.js", () => ({
waitForAgentJob: (...args: unknown[]) => mocks.waitForAgentJob(...args),
}));
vi.mock("./agent-wait-dedupe.js", () => ({
readTerminalSnapshotFromGatewayDedupe: (...args: unknown[]) =>
mocks.readTerminalSnapshotFromGatewayDedupe(...args),
// Writes to the dedupe map are stubbed out with a bare spy.
setGatewayDedupeEntry: vi.fn(),
waitForTerminalGatewayDedupe: (...args: unknown[]) => mocks.waitForTerminalGatewayDedupe(...args),
}));
vi.mock("../../sessions/send-policy.js", () => ({
resolveSendPolicy: () => "allow",
}));
@ -105,6 +119,16 @@ type AgentParams = AgentHandlerArgs["params"];
type AgentIdentityGetHandlerArgs = Parameters<(typeof agentHandlers)["agent.identity.get"]>[0];
type AgentIdentityGetParams = AgentIdentityGetHandlerArgs["params"];
// Reset the agent.wait mocks before every test so that, unless a test overrides
// them, neither the lifecycle wait nor the dedupe lookups produce a snapshot.
beforeEach(() => {
mocks.waitForAgentJob.mockReset();
mocks.waitForAgentJob.mockResolvedValue(null);
mocks.readTerminalSnapshotFromGatewayDedupe.mockReset();
mocks.readTerminalSnapshotFromGatewayDedupe.mockReturnValue(null);
mocks.waitForTerminalGatewayDedupe.mockReset();
mocks.waitForTerminalGatewayDedupe.mockResolvedValue(null);
// Start each test from an empty config object.
mocks.loadConfigReturn = {};
});
async function waitForAssertion(assertion: () => void, timeoutMs = 2_000, stepMs = 5) {
vi.useFakeTimers();
try {
@ -633,6 +657,64 @@ describe("gateway agent handler", () => {
expect(callArgs.bestEffortDeliver).toBe(false);
});
// Verifies that disableTools, clientTools and streamParams from the request
// are handed to the agent command unchanged.
it("forwards structured subagent options to agentCommandFromIngress", async () => {
primeMainAgentRun();
await invokeAgent(
{
message: "structured helper run",
agentId: "main",
sessionKey: "agent:main:main",
disableTools: true,
clientTools: [
{
type: "function",
function: {
name: "emit_structured_result",
description: "Return a structured result payload.",
parameters: {
type: "object",
properties: {
entries: { type: "array" },
},
},
},
},
],
streamParams: {
toolChoice: {
type: "function",
function: {
name: "emit_structured_result",
},
},
},
idempotencyKey: "test-structured-helper",
} as AgentParams,
{ reqId: "structured-helper-1" },
);
// Poll until the mocked agent command has been invoked.
await vi.waitFor(() => expect(mocks.agentCommand).toHaveBeenCalled());
// Inspect the first argument of the most recent agentCommand call.
const callArgs = mocks.agentCommand.mock.calls.at(-1)?.[0] as Record<string, unknown>;
expect(callArgs.disableTools).toBe(true);
expect(callArgs.clientTools).toEqual([
{
type: "function",
function: expect.objectContaining({
name: "emit_structured_result",
}),
},
]);
expect(callArgs.streamParams).toEqual({
toolChoice: {
type: "function",
function: {
name: "emit_structured_result",
},
},
});
});
it("rejects public spawned-run metadata fields", async () => {
primeMainAgentRun();
mocks.agentCommand.mockClear();
@ -877,4 +959,206 @@ describe("gateway agent handler", () => {
}),
);
});
// Verifies that when a terminal snapshot is already available from the dedupe
// lookup, agent.wait responds with its stopReason and pendingToolCalls.
it("returns structured fields for cached agent.wait snapshots", async () => {
const respond = vi.fn();
const context = makeContext();
context.chatAbortControllers = new Map();
// The synchronous dedupe read returns a complete snapshot.
mocks.readTerminalSnapshotFromGatewayDedupe.mockReturnValue({
status: "ok",
startedAt: 10,
endedAt: 20,
stopReason: "tool_calls",
pendingToolCalls: [
{
id: "call-1",
name: "emit_structured_result",
arguments: '{"entries":[]}',
},
],
});
await agentHandlers["agent.wait"]({
params: { runId: "wait-cached", timeoutMs: 100 },
respond,
context,
} as unknown as Parameters<(typeof agentHandlers)["agent.wait"]>[0]);
expect(respond).toHaveBeenCalledWith(
true,
expect.objectContaining({
runId: "wait-cached",
status: "ok",
stopReason: "tool_calls",
pendingToolCalls: [
{
id: "call-1",
name: "emit_structured_result",
arguments: '{"entries":[]}',
},
],
}),
);
});
// Verifies that when the lifecycle wait resolves without structured metadata,
// agent.wait fills in stopReason and pendingToolCalls from the dedupe snapshot.
it("merges structured fields from dedupe when lifecycle resolves first", async () => {
const respond = vi.fn();
const context = makeContext();
context.chatAbortControllers = new Map();
// Lifecycle result: status "ok" but no stopReason / pendingToolCalls.
mocks.waitForAgentJob.mockResolvedValue({
status: "ok",
startedAt: 10,
endedAt: 20,
});
// Dedupe result: settles one microtask later and carries the structured fields.
mocks.waitForTerminalGatewayDedupe.mockImplementation(async () => {
await Promise.resolve();
return {
status: "ok",
startedAt: 10,
endedAt: 20,
stopReason: "tool_calls",
pendingToolCalls: [
{
id: "call-1",
name: "emit_structured_result",
arguments: '{"entries":[]}',
},
],
};
});
await agentHandlers["agent.wait"]({
params: { runId: "wait-live", timeoutMs: 100 },
respond,
context,
} as unknown as Parameters<(typeof agentHandlers)["agent.wait"]>[0]);
expect(respond).toHaveBeenCalledWith(
true,
expect.objectContaining({
runId: "wait-live",
status: "ok",
stopReason: "tool_calls",
pendingToolCalls: [
{
id: "call-1",
name: "emit_structured_result",
arguments: '{"entries":[]}',
},
],
}),
);
});
// Verifies that agent.wait responds right away when the lifecycle snapshot
// already has a stopReason other than "tool_calls", instead of waiting for
// the slower dedupe promise.
it("does not grace-wait when lifecycle resolves without tool calls", async () => {
vi.useFakeTimers();
try {
const respond = vi.fn();
const context = makeContext();
// Register an active chat run for the same run id.
context.chatAbortControllers = new Map([
[
"wait-no-tools",
{
controller: new AbortController(),
sessionKey: "agent:main:main",
sessionId: "test-session",
startedAtMs: Date.now(),
expiresAtMs: Date.now() + 60_000,
},
],
]);
mocks.waitForAgentJob.mockResolvedValue({
status: "ok",
startedAt: 10,
endedAt: 20,
stopReason: "stop",
});
// The dedupe wait only settles after 1s of (fake) time.
mocks.waitForTerminalGatewayDedupe.mockImplementation(
() =>
new Promise<null>((resolve) => {
setTimeout(() => resolve(null), 1_000);
}),
);
const waitPromise = agentHandlers["agent.wait"]({
params: { runId: "wait-no-tools", timeoutMs: 100 },
respond,
context,
} as unknown as Parameters<(typeof agentHandlers)["agent.wait"]>[0]);
// Advance by 0ms: pending promise callbacks run, but the 1s timer does not fire.
await vi.advanceTimersByTimeAsync(0);
expect(respond).toHaveBeenCalledWith(
true,
expect.objectContaining({
runId: "wait-no-tools",
status: "ok",
stopReason: "stop",
pendingToolCalls: undefined,
}),
);
expect(mocks.waitForTerminalGatewayDedupe).toHaveBeenCalledTimes(1);
// Drain the remaining timers so nothing leaks into the next test.
await vi.runAllTimersAsync();
await waitPromise;
} finally {
vi.useRealTimers();
}
});
// Verifies that an "error" lifecycle snapshot without a stopReason is returned
// right away; the missing stopReason must not trigger a wait on the dedupe promise.
it("does not grace-wait for errors when lifecycle metadata omits stopReason", async () => {
vi.useFakeTimers();
try {
const respond = vi.fn();
const context = makeContext();
// Register an active chat run for the same run id.
context.chatAbortControllers = new Map([
[
"wait-no-stop-reason",
{
controller: new AbortController(),
sessionKey: "agent:main:main",
sessionId: "test-session",
startedAtMs: Date.now(),
expiresAtMs: Date.now() + 60_000,
},
],
]);
mocks.waitForAgentJob.mockResolvedValue({
status: "error",
startedAt: 10,
endedAt: 20,
error: "boom",
});
// The dedupe wait only settles after 1s of (fake) time.
mocks.waitForTerminalGatewayDedupe.mockImplementation(
() =>
new Promise<null>((resolve) => {
setTimeout(() => resolve(null), 1_000);
}),
);
const waitPromise = agentHandlers["agent.wait"]({
params: { runId: "wait-no-stop-reason", timeoutMs: 100 },
respond,
context,
} as unknown as Parameters<(typeof agentHandlers)["agent.wait"]>[0]);
// Advance by 0ms: pending promise callbacks run, but the 1s timer does not fire.
await vi.advanceTimersByTimeAsync(0);
expect(respond).toHaveBeenCalledWith(
true,
expect.objectContaining({
runId: "wait-no-stop-reason",
status: "error",
error: "boom",
stopReason: undefined,
pendingToolCalls: undefined,
}),
);
// Drain the remaining timers so nothing leaks into the next test.
await vi.runAllTimersAsync();
await waitPromise;
} finally {
vi.useRealTimers();
}
});
});

View File

@ -7,6 +7,7 @@ import {
} from "../../agents/spawned-context.js";
import { buildBareSessionResetPrompt } from "../../auto-reply/reply/session-reset-prompt.js";
import { agentCommandFromIngress } from "../../commands/agent.js";
import type { AgentStreamParams } from "../../commands/agent/types.js";
import { loadConfig } from "../../config/config.js";
import {
mergeSessionEntry,
@ -67,6 +68,28 @@ import { normalizeRpcAttachmentsToChatAttachments } from "./attachment-normalize
import type { GatewayRequestHandlerOptions, GatewayRequestHandlers } from "./types.js";
const RESET_COMMAND_RE = /^\/(new|reset)(?:\s+([\s\S]*))?$/i;
/**
 * Upper bound, in milliseconds, on how long `agent.wait` keeps waiting for the
 * dedupe snapshot to supply missing structured metadata after the lifecycle
 * snapshot has resolved. The effective wait is also capped by the request's timeout.
 */
const AGENT_WAIT_DEDUPE_METADATA_GRACE_MS = 5_000;
/**
 * Fills in `stopReason` and `pendingToolCalls` on a snapshot from a dedupe
 * snapshot, without overriding values the snapshot already has.
 *
 * @param snapshot - Primary snapshot; its own fields always take precedence.
 * @param dedupeSnapshot - Fallback source for the two structured fields.
 * @returns `snapshot` itself when there is no dedupe snapshot, otherwise a
 *   shallow copy with both structured fields set (possibly to `undefined`).
 */
function mergeAgentWaitStructuredMetadata<T extends AgentWaitTerminalSnapshot>(
  snapshot: T,
  dedupeSnapshot: AgentWaitTerminalSnapshot | null | undefined,
): T {
  if (dedupeSnapshot == null) {
    return snapshot;
  }
  const stopReason = snapshot.stopReason ?? dedupeSnapshot.stopReason;
  const pendingToolCalls = snapshot.pendingToolCalls ?? dedupeSnapshot.pendingToolCalls;
  return { ...snapshot, stopReason, pendingToolCalls };
}
/**
 * Reports whether a snapshot still lacks structured metadata.
 *
 * True in two cases: the run finished "ok" but has no `stopReason`, or the
 * `stopReason` is "tool_calls" but no `pendingToolCalls` are attached.
 *
 * @param snapshot - Terminal snapshot to inspect.
 * @returns `true` when the structured fields look incomplete.
 */
function isMissingAgentWaitStructuredMetadata(snapshot: AgentWaitTerminalSnapshot): boolean {
  const { status, stopReason, pendingToolCalls } = snapshot;
  switch (stopReason) {
    case undefined:
      // Non-"ok" outcomes are not expected to carry a stop reason.
      return status === "ok";
    case "tool_calls":
      return pendingToolCalls === undefined;
    default:
      return false;
  }
}
function resolveSenderIsOwnerFromClient(client: GatewayRequestHandlerOptions["client"]): boolean {
const scopes = Array.isArray(client?.connect?.scopes) ? client.connect.scopes : [];
@ -232,6 +255,16 @@ export const agentHandlers: GatewayRequestHandlers = {
lane?: string;
extraSystemPrompt?: string;
internalEvents?: AgentInternalEvent[];
clientTools?: Array<{
type: "function";
function: {
name: string;
description?: string;
parameters?: Record<string, unknown>;
};
}>;
disableTools?: boolean;
streamParams?: AgentStreamParams;
idempotencyKey: string;
timeout?: number;
bestEffortDeliver?: boolean;
@ -698,6 +731,9 @@ export const agentHandlers: GatewayRequestHandlers = {
lane: request.lane,
extraSystemPrompt: request.extraSystemPrompt,
internalEvents: request.internalEvents,
clientTools: request.clientTools,
disableTools: request.disableTools,
streamParams: request.streamParams,
inputProvenance,
// Internal-only: allow workspace override for spawned subagent runs.
workspaceDir: resolveIngressWorkspaceOverrideForSpawnedRun({
@ -799,6 +835,8 @@ export const agentHandlers: GatewayRequestHandlers = {
startedAt: cachedGatewaySnapshot.startedAt,
endedAt: cachedGatewaySnapshot.endedAt,
error: cachedGatewaySnapshot.error,
stopReason: cachedGatewaySnapshot.stopReason,
pendingToolCalls: cachedGatewaySnapshot.pendingToolCalls,
});
return;
}
@ -830,6 +868,44 @@ export const agentHandlers: GatewayRequestHandlers = {
first.snapshot;
if (snapshot) {
if (first.source === "lifecycle") {
snapshot = mergeAgentWaitStructuredMetadata(
snapshot,
readTerminalSnapshotFromGatewayDedupe({
dedupe: context.dedupe,
runId,
ignoreAgentTerminalSnapshot: hasActiveChatRun,
}),
);
if (snapshot.stopReason === undefined) {
const immediateDedupeMetadata =
(await Promise.race([
dedupePromise,
Promise.resolve<AgentWaitTerminalSnapshot | null>(null),
])) ?? null;
snapshot = mergeAgentWaitStructuredMetadata(snapshot, immediateDedupeMetadata);
}
if (isMissingAgentWaitStructuredMetadata(snapshot)) {
let graceTimer: ReturnType<typeof setTimeout> | null = null;
const dedupeMetadata =
(await Promise.race([
dedupePromise.finally(() => {
if (graceTimer != null) {
clearTimeout(graceTimer);
}
}),
new Promise<null>((resolve) => {
graceTimer = setTimeout(
() => resolve(null),
Math.max(
1,
Math.min(timeoutMs, AGENT_WAIT_DEDUPE_METADATA_GRACE_MS, 2_147_483_647),
),
);
graceTimer.unref?.();
}),
])) ?? null;
snapshot = mergeAgentWaitStructuredMetadata(snapshot, dedupeMetadata);
}
dedupeAbortController.abort();
} else {
lifecycleAbortController.abort();
@ -853,6 +929,8 @@ export const agentHandlers: GatewayRequestHandlers = {
startedAt: snapshot.startedAt,
endedAt: snapshot.endedAt,
error: snapshot.error,
stopReason: snapshot.stopReason,
pendingToolCalls: snapshot.pendingToolCalls,
});
},
};

View File

@ -0,0 +1,29 @@
/** One pending tool call entry; `arguments` is kept as the raw string it arrived as. */
export type PendingToolCall = {
  id: string;
  name: string;
  arguments: string;
};

/**
 * Converts a single untyped entry into a `PendingToolCall`, or `undefined`
 * when it is not an object with string `id`, `name` and `arguments`.
 */
function toPendingToolCall(entry: unknown): PendingToolCall | undefined {
  if (!entry || typeof entry !== "object") {
    return undefined;
  }
  const record = entry as Record<string, unknown>;
  if (
    typeof record.id !== "string" ||
    typeof record.name !== "string" ||
    typeof record.arguments !== "string"
  ) {
    return undefined;
  }
  // Copy only the known fields so extra properties are not carried along.
  return { id: record.id, name: record.name, arguments: record.arguments };
}

/**
 * Extracts the well-formed pending tool calls from an untyped value.
 *
 * @param value - Candidate list, typically taken from untyped event or payload data.
 * @returns The valid entries in their original order, or `undefined` when
 *   `value` is not an array or contains no valid entry.
 */
export function parsePendingToolCalls(value: unknown): PendingToolCall[] | undefined {
  if (!Array.isArray(value)) {
    return undefined;
  }
  const calls: PendingToolCall[] = [];
  for (const entry of value) {
    const call = toPendingToolCall(entry);
    if (call !== undefined) {
      calls.push(call);
    }
  }
  return calls.length === 0 ? undefined : calls;
}

View File

@ -29,14 +29,15 @@ vi.mock("./server-methods.js", () => ({
}));
vi.mock("../channels/registry.js", () => ({
CHAT_CHANNEL_ORDER: [],
CHANNEL_IDS: [],
CHAT_CHANNEL_ORDER: ["telegram", "discord", "slack"],
CHANNEL_IDS: ["telegram", "discord", "slack"],
listChatChannels: () => [],
listChatChannelAliases: () => [],
getChatChannelMeta: () => null,
normalizeChatChannelId: () => null,
normalizeChannelId: () => null,
normalizeAnyChannelId: () => null,
normalizeAnyChannelId: (raw?: string | null) =>
typeof raw === "string" && raw.trim().length > 0 ? raw.trim().toLowerCase() : null,
formatChannelPrimerLine: () => "",
formatChannelSelectionLine: () => "",
}));
@ -92,6 +93,11 @@ function getLastDispatchedClientScopes(): string[] {
return Array.isArray(scopes) ? scopes : [];
}
/**
 * Returns the `req` property of the first argument of the most recent
 * `handleGatewayRequest` call, or `undefined` when there is no such call or argument.
 */
function getLastDispatchedRequest() {
  const { calls } = handleGatewayRequest.mock;
  return calls.at(-1)?.[0]?.req;
}
async function loadTestModules() {
serverPluginsModule = await import("./server-plugins.js");
runtimeModule = await import("../plugins/runtime/index.js");
@ -99,6 +105,10 @@ async function loadTestModules() {
methodScopesModule = await import("./method-scopes.js");
}
/** Dynamically imports `./server-plugins.js` and resolves with the loaded module. */
async function importServerPluginsModule(): Promise<ServerPluginsModule> {
  const loadedModule = await import("./server-plugins.js");
  return loadedModule;
}
async function createSubagentRuntime(
serverPlugins: ServerPluginsModule,
cfg: Record<string, unknown> = {},
@ -147,7 +157,17 @@ beforeEach(() => {
opts.respond(true, { runId: "run-1" });
return;
case "agent.wait":
opts.respond(true, { status: "ok" });
opts.respond(true, {
status: "ok",
stopReason: "tool_calls",
pendingToolCalls: [
{
id: "call-1",
name: "emit_structured_result",
arguments: '{"entries":[]}',
},
],
});
return;
case "sessions.get":
opts.respond(true, { messages: [] });
@ -579,4 +599,92 @@ describe("loadGatewayPlugins", () => {
| undefined;
expect(dispatched?.marker).toBe("after-mutation");
});
test("forwards structured plugin subagent options to gateway agent methods", async () => {
  const serverPlugins = await importServerPluginsModule();
  const runtime = await createSubagentRuntime(serverPlugins);
  serverPlugins.setFallbackGatewayContext(createTestContext("structured-output"));

  // Run a subagent with tools disabled, one client tool, and a forced tool choice.
  await runtime.run({
    sessionKey: "s-structured",
    message: "extract memories",
    disableTools: true,
    clientTools: [
      {
        type: "function",
        function: {
          name: "emit_structured_result",
          description: "Return a structured result payload.",
          parameters: {
            type: "object",
            properties: { entries: { type: "array" } },
          },
        },
      },
    ],
    streamParams: {
      toolChoice: { type: "function", function: { name: "emit_structured_result" } },
    },
  });

  // The structured options must appear on the dispatched "agent" request params.
  const expectedParams = expect.objectContaining({
    sessionKey: "s-structured",
    message: "extract memories",
    disableTools: true,
    clientTools: [
      {
        type: "function",
        function: expect.objectContaining({ name: "emit_structured_result" }),
      },
    ],
    streamParams: {
      toolChoice: { type: "function", function: { name: "emit_structured_result" } },
    },
  });
  const expectedRequest = expect.objectContaining({
    type: "req",
    id: expect.any(String),
    method: "agent",
    params: expectedParams,
  });

  expect(getLastDispatchedRequest()).toEqual(expectedRequest);
});
test("returns pending tool calls from gateway agent.wait", async () => {
  const serverPlugins = await importServerPluginsModule();
  const runtime = await createSubagentRuntime(serverPlugins);

  const waitResult = await runtime.waitForRun({ runId: "run-1", timeoutMs: 1_000 });

  // The wait result should carry the stop reason and the single pending call.
  const expectedPendingCall = {
    id: "call-1",
    name: "emit_structured_result",
    arguments: '{"entries":[]}',
  };
  expect(waitResult).toEqual({
    status: "ok",
    stopReason: "tool_calls",
    pendingToolCalls: [expectedPendingCall],
  });
});
});

View File

@ -333,6 +333,9 @@ function createGatewaySubagentRuntime(): PluginRuntime["subagent"] {
...(allowOverride && params.model && { model: params.model }),
...(params.extraSystemPrompt && { extraSystemPrompt: params.extraSystemPrompt }),
...(params.lane && { lane: params.lane }),
...(params.clientTools && { clientTools: params.clientTools }),
...(params.disableTools === true && { disableTools: true }),
...(params.streamParams && { streamParams: params.streamParams }),
...(params.idempotencyKey && { idempotencyKey: params.idempotencyKey }),
},
{
@ -346,13 +349,19 @@ function createGatewaySubagentRuntime(): PluginRuntime["subagent"] {
return { runId };
},
async waitForRun(params) {
const payload = await dispatchGatewayMethod<{ status?: string; error?: string }>(
"agent.wait",
{
runId: params.runId,
...(params.timeoutMs != null && { timeoutMs: params.timeoutMs }),
},
);
const payload = await dispatchGatewayMethod<{
status?: string;
error?: string;
stopReason?: string;
pendingToolCalls?: Array<{
id: string;
name: string;
arguments: string;
}>;
}>("agent.wait", {
runId: params.runId,
...(params.timeoutMs != null && { timeoutMs: params.timeoutMs }),
});
const status = payload?.status;
if (status !== "ok" && status !== "error" && status !== "timeout") {
throw new Error(`Gateway agent.wait returned unexpected status: ${status}`);
@ -360,6 +369,11 @@ function createGatewaySubagentRuntime(): PluginRuntime["subagent"] {
return {
status,
...(typeof payload?.error === "string" && payload.error && { error: payload.error }),
...(typeof payload?.stopReason === "string" &&
payload.stopReason && { stopReason: payload.stopReason }),
...(Array.isArray(payload?.pendingToolCalls) && payload.pendingToolCalls.length > 0
? { pendingToolCalls: payload.pendingToolCalls }
: {}),
};
},
getSessionMessages,

View File

@ -1,3 +1,5 @@
import type { ClientToolDefinition } from "../../agents/pi-embedded-runner/run/params.js";
import type { AgentStreamParams } from "../../commands/agent/types.js";
import type { PluginRuntimeChannel } from "./types-channel.js";
import type { PluginRuntimeCore, RuntimeLogger } from "./types-core.js";
@ -14,6 +16,9 @@ export type SubagentRunParams = {
lane?: string;
deliver?: boolean;
idempotencyKey?: string;
clientTools?: ClientToolDefinition[];
disableTools?: boolean;
streamParams?: AgentStreamParams;
};
export type SubagentRunResult = {
@ -28,6 +33,12 @@ export type SubagentWaitParams = {
/** Result of waiting on a subagent run. */
export type SubagentWaitResult = {
  /** Terminal state of the wait: one of "ok", "error" or "timeout". */
  status: "ok" | "error" | "timeout";
  /** Optional error message. */
  error?: string;
  /** Optional reason the run stopped. */
  stopReason?: string;
  /**
   * Optional list of pending tool calls.
   * NOTE(review): `arguments` is presumably a JSON-encoded string — confirm with the producer.
   */
  pendingToolCalls?: {
    id: string;
    name: string;
    arguments: string;
  }[];
};
export type SubagentGetSessionMessagesParams = {