Merge 714b5947424e54e48377c8ee191ffe6ef0b329ad into d78e13f545136fcbba1feceecc5e0485a06c33a6
This commit is contained in:
commit
0ca2ec6793
@ -537,6 +537,9 @@ public struct AgentParams: Codable, Sendable {
|
||||
public let lane: String?
|
||||
public let extrasystemprompt: String?
|
||||
public let internalevents: [[String: AnyCodable]]?
|
||||
public let clienttools: [[String: AnyCodable]]?
|
||||
public let disabletools: Bool?
|
||||
public let streamparams: [String: AnyCodable]?
|
||||
public let inputprovenance: [String: AnyCodable]?
|
||||
public let idempotencykey: String
|
||||
public let label: String?
|
||||
@ -566,6 +569,9 @@ public struct AgentParams: Codable, Sendable {
|
||||
lane: String?,
|
||||
extrasystemprompt: String?,
|
||||
internalevents: [[String: AnyCodable]]?,
|
||||
clienttools: [[String: AnyCodable]]?,
|
||||
disabletools: Bool?,
|
||||
streamparams: [String: AnyCodable]?,
|
||||
inputprovenance: [String: AnyCodable]?,
|
||||
idempotencykey: String,
|
||||
label: String?)
|
||||
@ -594,6 +600,9 @@ public struct AgentParams: Codable, Sendable {
|
||||
self.lane = lane
|
||||
self.extrasystemprompt = extrasystemprompt
|
||||
self.internalevents = internalevents
|
||||
self.clienttools = clienttools
|
||||
self.disabletools = disabletools
|
||||
self.streamparams = streamparams
|
||||
self.inputprovenance = inputprovenance
|
||||
self.idempotencykey = idempotencykey
|
||||
self.label = label
|
||||
@ -624,6 +633,9 @@ public struct AgentParams: Codable, Sendable {
|
||||
case lane
|
||||
case extrasystemprompt = "extraSystemPrompt"
|
||||
case internalevents = "internalEvents"
|
||||
case clienttools = "clientTools"
|
||||
case disabletools = "disableTools"
|
||||
case streamparams = "streamParams"
|
||||
case inputprovenance = "inputProvenance"
|
||||
case idempotencykey = "idempotencyKey"
|
||||
case label
|
||||
|
||||
@ -537,6 +537,9 @@ public struct AgentParams: Codable, Sendable {
|
||||
public let lane: String?
|
||||
public let extrasystemprompt: String?
|
||||
public let internalevents: [[String: AnyCodable]]?
|
||||
public let clienttools: [[String: AnyCodable]]?
|
||||
public let disabletools: Bool?
|
||||
public let streamparams: [String: AnyCodable]?
|
||||
public let inputprovenance: [String: AnyCodable]?
|
||||
public let idempotencykey: String
|
||||
public let label: String?
|
||||
@ -566,6 +569,9 @@ public struct AgentParams: Codable, Sendable {
|
||||
lane: String?,
|
||||
extrasystemprompt: String?,
|
||||
internalevents: [[String: AnyCodable]]?,
|
||||
clienttools: [[String: AnyCodable]]?,
|
||||
disabletools: Bool?,
|
||||
streamparams: [String: AnyCodable]?,
|
||||
inputprovenance: [String: AnyCodable]?,
|
||||
idempotencykey: String,
|
||||
label: String?)
|
||||
@ -594,6 +600,9 @@ public struct AgentParams: Codable, Sendable {
|
||||
self.lane = lane
|
||||
self.extrasystemprompt = extrasystemprompt
|
||||
self.internalevents = internalevents
|
||||
self.clienttools = clienttools
|
||||
self.disabletools = disabletools
|
||||
self.streamparams = streamparams
|
||||
self.inputprovenance = inputprovenance
|
||||
self.idempotencykey = idempotencykey
|
||||
self.label = label
|
||||
@ -624,6 +633,9 @@ public struct AgentParams: Codable, Sendable {
|
||||
case lane
|
||||
case extrasystemprompt = "extraSystemPrompt"
|
||||
case internalevents = "internalEvents"
|
||||
case clienttools = "clientTools"
|
||||
case disabletools = "disableTools"
|
||||
case streamparams = "streamParams"
|
||||
case inputprovenance = "inputProvenance"
|
||||
case idempotencykey = "idempotencyKey"
|
||||
case label
|
||||
|
||||
@ -512,6 +512,7 @@ function runAgentAttempt(params: {
|
||||
prompt: effectivePrompt,
|
||||
images: params.isFallbackRetry ? undefined : params.opts.images,
|
||||
clientTools: params.opts.clientTools,
|
||||
disableTools: params.opts.disableTools,
|
||||
provider: params.providerOverride,
|
||||
model: params.modelOverride,
|
||||
authProfileId,
|
||||
@ -1236,6 +1237,7 @@ async function agentCommandInternal(
|
||||
endedAt: Date.now(),
|
||||
aborted: result.meta.aborted ?? false,
|
||||
stopReason,
|
||||
pendingToolCalls: result.meta.pendingToolCalls,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@ -17,6 +17,17 @@ export type AgentStreamParams = {
|
||||
maxTokens?: number;
|
||||
/** Provider fast-mode override (best-effort). */
|
||||
fastMode?: boolean;
|
||||
/** Provider tool-choice override (best-effort). */
|
||||
toolChoice?:
|
||||
| "auto"
|
||||
| "none"
|
||||
| "required"
|
||||
| {
|
||||
type: "function";
|
||||
function: {
|
||||
name: string;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
export type AgentRunContext = {
|
||||
@ -37,6 +48,8 @@ export type AgentCommandOpts = {
|
||||
images?: ImageContent[];
|
||||
/** Optional client-provided tools (OpenResponses hosted tools). */
|
||||
clientTools?: ClientToolDefinition[];
|
||||
/** Disable built-in tools for this run while still allowing clientTools. */
|
||||
disableTools?: boolean;
|
||||
/** Agent id override (must exist in config). */
|
||||
agentId?: string;
|
||||
/** Per-run provider override. */
|
||||
|
||||
@ -144,4 +144,98 @@ describe("cacheRetention default behavior", () => {
|
||||
provider: "anthropic",
|
||||
});
|
||||
});
|
||||
|
||||
it("forwards toolChoice overrides to the wrapped stream options", async () => {
|
||||
const baseStreamFn = vi.fn(async () => undefined);
|
||||
const agent: { streamFn?: StreamFn } = { streamFn: baseStreamFn as unknown as StreamFn };
|
||||
|
||||
applyExtraParamsToAgent(
|
||||
agent,
|
||||
undefined,
|
||||
"openai",
|
||||
"gpt-4.1",
|
||||
{
|
||||
toolChoice: {
|
||||
type: "function",
|
||||
function: {
|
||||
name: "emit_structured_result",
|
||||
},
|
||||
},
|
||||
},
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
new Set(["emit_structured_result"]),
|
||||
);
|
||||
|
||||
expect(agent.streamFn).toBeDefined();
|
||||
await agent.streamFn?.(
|
||||
{ api: "openai-responses", provider: "openai", id: "gpt-4.1", compat: {} } as never,
|
||||
{} as never,
|
||||
{},
|
||||
);
|
||||
|
||||
expect(baseStreamFn).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
toolChoice: {
|
||||
type: "function",
|
||||
function: {
|
||||
name: "emit_structured_result",
|
||||
},
|
||||
},
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("drops stale function toolChoice selections that are not in the allowed tool set", async () => {
|
||||
const baseStreamFn = vi.fn(async () => undefined);
|
||||
const agent: { streamFn?: StreamFn } = { streamFn: baseStreamFn as unknown as StreamFn };
|
||||
const cfg = {
|
||||
agents: {
|
||||
defaults: {
|
||||
models: {
|
||||
"openai/gpt-4.1": {
|
||||
params: {
|
||||
toolChoice: {
|
||||
type: "function",
|
||||
function: {
|
||||
name: "bash",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
applyExtraParamsToAgent(
|
||||
agent,
|
||||
cfg,
|
||||
"openai",
|
||||
"gpt-4.1",
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
new Set(["emit_structured_result"]),
|
||||
);
|
||||
|
||||
expect(agent.streamFn).toBeDefined();
|
||||
await agent.streamFn?.(
|
||||
{ api: "openai-responses", provider: "openai", id: "gpt-4.1", compat: {} } as never,
|
||||
{} as never,
|
||||
{},
|
||||
);
|
||||
|
||||
expect(baseStreamFn).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
toolChoice: "auto",
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@ -2,11 +2,13 @@ import type { StreamFn } from "@mariozechner/pi-agent-core";
|
||||
import type { SimpleStreamOptions } from "@mariozechner/pi-ai";
|
||||
import { streamSimple } from "@mariozechner/pi-ai";
|
||||
import type { ThinkLevel } from "../../auto-reply/thinking.js";
|
||||
import type { AgentStreamParams } from "../../commands/agent/types.js";
|
||||
import type { OpenClawConfig } from "../../config/config.js";
|
||||
import {
|
||||
prepareProviderExtraParams,
|
||||
wrapProviderStreamFn,
|
||||
} from "../../plugins/provider-runtime.js";
|
||||
import { normalizeToolName } from "../tool-policy.js";
|
||||
import {
|
||||
createAnthropicBetaHeadersWrapper,
|
||||
createAnthropicFastModeWrapper,
|
||||
@ -72,8 +74,101 @@ export function resolveExtraParams(params: {
|
||||
type CacheRetentionStreamOptions = Partial<SimpleStreamOptions> & {
|
||||
cacheRetention?: "none" | "short" | "long";
|
||||
openaiWsWarmup?: boolean;
|
||||
toolChoice?: NonNullable<AgentStreamParams["toolChoice"]>;
|
||||
};
|
||||
|
||||
type ToolChoiceOverride = NonNullable<AgentStreamParams["toolChoice"]>;
|
||||
|
||||
function isToolChoiceOverride(value: unknown): value is ToolChoiceOverride {
|
||||
if (value === "auto" || value === "none" || value === "required") {
|
||||
return true;
|
||||
}
|
||||
if (!value || typeof value !== "object") {
|
||||
return false;
|
||||
}
|
||||
const record = value as Record<string, unknown>;
|
||||
const fn = record.function;
|
||||
return (
|
||||
record.type === "function" &&
|
||||
!!fn &&
|
||||
typeof fn === "object" &&
|
||||
typeof (fn as Record<string, unknown>).name === "string"
|
||||
);
|
||||
}
|
||||
|
||||
function resolveAllowedToolChoiceName(
|
||||
rawName: string,
|
||||
allowedToolNames?: Set<string>,
|
||||
): string | undefined {
|
||||
if (!allowedToolNames || allowedToolNames.size === 0) {
|
||||
return undefined;
|
||||
}
|
||||
const trimmed = rawName.trim();
|
||||
if (!trimmed) {
|
||||
return undefined;
|
||||
}
|
||||
if (allowedToolNames.has(trimmed)) {
|
||||
return trimmed;
|
||||
}
|
||||
const normalized = normalizeToolName(trimmed);
|
||||
if (allowedToolNames.has(normalized)) {
|
||||
return normalized;
|
||||
}
|
||||
const lowered = normalized.toLowerCase();
|
||||
let caseInsensitiveMatch: string | undefined;
|
||||
for (const candidate of allowedToolNames) {
|
||||
if (candidate.toLowerCase() !== lowered) {
|
||||
continue;
|
||||
}
|
||||
if (caseInsensitiveMatch && caseInsensitiveMatch !== candidate) {
|
||||
return undefined;
|
||||
}
|
||||
caseInsensitiveMatch = candidate;
|
||||
}
|
||||
return caseInsensitiveMatch;
|
||||
}
|
||||
|
||||
function sanitizeToolChoiceOverride(
|
||||
extraParams: Record<string, unknown> | undefined,
|
||||
allowedToolNames?: Set<string>,
|
||||
): Record<string, unknown> | undefined {
|
||||
if (!extraParams || !isToolChoiceOverride(extraParams.toolChoice)) {
|
||||
return extraParams;
|
||||
}
|
||||
if (!allowedToolNames || allowedToolNames.size === 0) {
|
||||
if (extraParams.toolChoice === "none") {
|
||||
return extraParams;
|
||||
}
|
||||
return {
|
||||
...extraParams,
|
||||
toolChoice: "none",
|
||||
};
|
||||
}
|
||||
const toolChoice = extraParams.toolChoice;
|
||||
if (toolChoice === "auto" || toolChoice === "none" || toolChoice === "required") {
|
||||
return extraParams;
|
||||
}
|
||||
const resolvedName = resolveAllowedToolChoiceName(toolChoice.function.name, allowedToolNames);
|
||||
if (!resolvedName) {
|
||||
return {
|
||||
...extraParams,
|
||||
toolChoice: "auto",
|
||||
};
|
||||
}
|
||||
if (resolvedName === toolChoice.function.name) {
|
||||
return extraParams;
|
||||
}
|
||||
return {
|
||||
...extraParams,
|
||||
toolChoice: {
|
||||
type: "function",
|
||||
function: {
|
||||
name: resolvedName,
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function createStreamFnWithExtraParams(
|
||||
baseStreamFn: StreamFn | undefined,
|
||||
extraParams: Record<string, unknown> | undefined,
|
||||
@ -90,6 +185,9 @@ function createStreamFnWithExtraParams(
|
||||
if (typeof extraParams.maxTokens === "number") {
|
||||
streamParams.maxTokens = extraParams.maxTokens;
|
||||
}
|
||||
if (isToolChoiceOverride(extraParams.toolChoice)) {
|
||||
streamParams.toolChoice = extraParams.toolChoice;
|
||||
}
|
||||
const transport = extraParams.transport;
|
||||
if (transport === "sse" || transport === "websocket" || transport === "auto") {
|
||||
streamParams.transport = transport;
|
||||
@ -184,6 +282,7 @@ export function applyExtraParamsToAgent(
|
||||
thinkingLevel?: ThinkLevel,
|
||||
agentId?: string,
|
||||
workspaceDir?: string,
|
||||
allowedToolNames?: Set<string>,
|
||||
): void {
|
||||
const resolvedExtraParams = resolveExtraParams({
|
||||
cfg,
|
||||
@ -210,10 +309,11 @@ export function applyExtraParamsToAgent(
|
||||
thinkingLevel,
|
||||
},
|
||||
}) ?? merged;
|
||||
const sanitizedExtraParams = sanitizeToolChoiceOverride(effectiveExtraParams, allowedToolNames);
|
||||
|
||||
const wrappedStreamFn = createStreamFnWithExtraParams(
|
||||
agent.streamFn,
|
||||
effectiveExtraParams,
|
||||
sanitizedExtraParams,
|
||||
provider,
|
||||
);
|
||||
|
||||
@ -222,7 +322,7 @@ export function applyExtraParamsToAgent(
|
||||
agent.streamFn = wrappedStreamFn;
|
||||
}
|
||||
|
||||
const anthropicBetas = resolveAnthropicBetas(effectiveExtraParams, provider, modelId);
|
||||
const anthropicBetas = resolveAnthropicBetas(sanitizedExtraParams, provider, modelId);
|
||||
if (anthropicBetas?.length) {
|
||||
log.debug(
|
||||
`applying Anthropic beta header for ${provider}/${modelId}: ${anthropicBetas.join(",")}`,
|
||||
@ -249,7 +349,7 @@ export function applyExtraParamsToAgent(
|
||||
config: cfg,
|
||||
provider,
|
||||
modelId,
|
||||
extraParams: effectiveExtraParams,
|
||||
extraParams: sanitizedExtraParams,
|
||||
thinkingLevel,
|
||||
streamFn: providerStreamBase,
|
||||
},
|
||||
@ -263,7 +363,7 @@ export function applyExtraParamsToAgent(
|
||||
// actually handled the stream function. This covers tests/disabled plugins
|
||||
// and Ollama Cloud Kimi models until they gain a dedicated runtime hook.
|
||||
const thinkingType = resolveMoonshotThinkingType({
|
||||
configuredThinking: effectiveExtraParams?.thinking,
|
||||
configuredThinking: sanitizedExtraParams?.thinking,
|
||||
thinkingLevel,
|
||||
});
|
||||
agent.streamFn = createMoonshotThinkingWrapper(agent.streamFn, thinkingType);
|
||||
@ -275,13 +375,13 @@ export function applyExtraParamsToAgent(
|
||||
agent.streamFn = createAnthropicFastModeWrapper(agent.streamFn, anthropicFastMode);
|
||||
}
|
||||
|
||||
const openAIFastMode = resolveOpenAIFastMode(effectiveExtraParams);
|
||||
const openAIFastMode = resolveOpenAIFastMode(sanitizedExtraParams);
|
||||
if (openAIFastMode) {
|
||||
log.debug(`applying OpenAI fast mode for ${provider}/${modelId}`);
|
||||
agent.streamFn = createOpenAIFastModeWrapper(agent.streamFn);
|
||||
}
|
||||
|
||||
const openAIServiceTier = resolveOpenAIServiceTier(effectiveExtraParams);
|
||||
const openAIServiceTier = resolveOpenAIServiceTier(sanitizedExtraParams);
|
||||
if (openAIServiceTier) {
|
||||
log.debug(`applying OpenAI service_tier=${openAIServiceTier} for ${provider}/${modelId}`);
|
||||
agent.streamFn = createOpenAIServiceTierWrapper(agent.streamFn, openAIServiceTier);
|
||||
@ -292,7 +392,7 @@ export function applyExtraParamsToAgent(
|
||||
// server-side compaction for compatible OpenAI Responses payloads.
|
||||
agent.streamFn = createOpenAIResponsesContextManagementWrapper(
|
||||
agent.streamFn,
|
||||
effectiveExtraParams,
|
||||
sanitizedExtraParams,
|
||||
);
|
||||
|
||||
const rawParallelToolCalls = resolveAliasedParamValue(
|
||||
|
||||
@ -2243,18 +2243,21 @@ export async function runEmbeddedAttempt(
|
||||
activeSession.agent.streamFn = wrapOllamaCompatNumCtx(activeSession.agent.streamFn, numCtx);
|
||||
}
|
||||
|
||||
const effectiveStreamParams: Record<string, unknown> = {
|
||||
...params.streamParams,
|
||||
...(params.fastMode !== undefined ? { fastMode: params.fastMode } : {}),
|
||||
...(allowedToolNames.size === 0 ? { toolChoice: "none" } : {}),
|
||||
};
|
||||
applyExtraParamsToAgent(
|
||||
activeSession.agent,
|
||||
params.config,
|
||||
params.provider,
|
||||
params.modelId,
|
||||
{
|
||||
...params.streamParams,
|
||||
fastMode: params.fastMode,
|
||||
},
|
||||
effectiveStreamParams,
|
||||
params.thinkLevel,
|
||||
sessionAgentId,
|
||||
effectiveWorkspace,
|
||||
allowedToolNames,
|
||||
);
|
||||
|
||||
if (cacheTrace) {
|
||||
|
||||
@ -29,6 +29,49 @@ export const AgentEventSchema = Type.Object(
|
||||
{ additionalProperties: false },
|
||||
);
|
||||
|
||||
const ClientToolDefinitionSchema = Type.Object(
|
||||
{
|
||||
type: Type.Literal("function"),
|
||||
function: Type.Object(
|
||||
{
|
||||
name: NonEmptyString,
|
||||
description: Type.Optional(Type.String()),
|
||||
parameters: Type.Optional(Type.Record(Type.String(), Type.Unknown())),
|
||||
},
|
||||
{ additionalProperties: false },
|
||||
),
|
||||
},
|
||||
{ additionalProperties: false },
|
||||
);
|
||||
|
||||
const AgentToolChoiceSchema = Type.Union([
|
||||
Type.Literal("auto"),
|
||||
Type.Literal("none"),
|
||||
Type.Literal("required"),
|
||||
Type.Object(
|
||||
{
|
||||
type: Type.Literal("function"),
|
||||
function: Type.Object(
|
||||
{
|
||||
name: NonEmptyString,
|
||||
},
|
||||
{ additionalProperties: false },
|
||||
),
|
||||
},
|
||||
{ additionalProperties: false },
|
||||
),
|
||||
]);
|
||||
|
||||
const AgentStreamParamsSchema = Type.Object(
|
||||
{
|
||||
temperature: Type.Optional(Type.Number()),
|
||||
maxTokens: Type.Optional(Type.Number()),
|
||||
fastMode: Type.Optional(Type.Boolean()),
|
||||
toolChoice: Type.Optional(AgentToolChoiceSchema),
|
||||
},
|
||||
{ additionalProperties: false },
|
||||
);
|
||||
|
||||
export const SendParamsSchema = Type.Object(
|
||||
{
|
||||
to: NonEmptyString,
|
||||
@ -97,6 +140,9 @@ export const AgentParamsSchema = Type.Object(
|
||||
lane: Type.Optional(Type.String()),
|
||||
extraSystemPrompt: Type.Optional(Type.String()),
|
||||
internalEvents: Type.Optional(Type.Array(AgentInternalEventSchema)),
|
||||
clientTools: Type.Optional(Type.Array(ClientToolDefinitionSchema)),
|
||||
disableTools: Type.Optional(Type.Boolean()),
|
||||
streamParams: Type.Optional(AgentStreamParamsSchema),
|
||||
inputProvenance: Type.Optional(InputProvenanceSchema),
|
||||
idempotencyKey: NonEmptyString,
|
||||
label: Type.Optional(SessionLabelString),
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import { onAgentEvent } from "../../infra/agent-events.js";
|
||||
import { parsePendingToolCalls } from "./pending-tool-calls.js";
|
||||
|
||||
const AGENT_RUN_CACHE_TTL_MS = 10 * 60_000;
|
||||
/**
|
||||
@ -19,6 +20,12 @@ type AgentRunSnapshot = {
|
||||
startedAt?: number;
|
||||
endedAt?: number;
|
||||
error?: string;
|
||||
stopReason?: string;
|
||||
pendingToolCalls?: Array<{
|
||||
id: string;
|
||||
name: string;
|
||||
arguments: string;
|
||||
}>;
|
||||
ts: number;
|
||||
};
|
||||
|
||||
@ -86,12 +93,16 @@ function createSnapshotFromLifecycleEvent(params: {
|
||||
typeof data?.startedAt === "number" ? data.startedAt : agentRunStarts.get(runId);
|
||||
const endedAt = typeof data?.endedAt === "number" ? data.endedAt : undefined;
|
||||
const error = typeof data?.error === "string" ? data.error : undefined;
|
||||
const stopReason = typeof data?.stopReason === "string" ? data.stopReason : undefined;
|
||||
const pendingToolCalls = parsePendingToolCalls(data?.pendingToolCalls);
|
||||
return {
|
||||
runId,
|
||||
status: phase === "error" ? "error" : data?.aborted ? "timeout" : "ok",
|
||||
startedAt,
|
||||
endedAt,
|
||||
error,
|
||||
stopReason,
|
||||
pendingToolCalls: pendingToolCalls?.length ? pendingToolCalls : undefined,
|
||||
ts: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
@ -258,6 +258,54 @@ describe("agent wait dedupe helper", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("extracts stopReason and pendingToolCalls from nested agent result metadata", () => {
|
||||
const dedupe = new Map();
|
||||
const runId = "run-structured-agent";
|
||||
setRunEntry({
|
||||
dedupe,
|
||||
kind: "agent",
|
||||
runId,
|
||||
payload: {
|
||||
runId,
|
||||
status: "ok",
|
||||
startedAt: 10,
|
||||
endedAt: 20,
|
||||
result: {
|
||||
meta: {
|
||||
stopReason: "tool_calls",
|
||||
pendingToolCalls: [
|
||||
{
|
||||
id: "call-1",
|
||||
name: "emit_structured_result",
|
||||
arguments: '{"entries":[]}',
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
expect(
|
||||
readTerminalSnapshotFromGatewayDedupe({
|
||||
dedupe,
|
||||
runId,
|
||||
}),
|
||||
).toEqual({
|
||||
status: "ok",
|
||||
startedAt: 10,
|
||||
endedAt: 20,
|
||||
error: undefined,
|
||||
stopReason: "tool_calls",
|
||||
pendingToolCalls: [
|
||||
{
|
||||
id: "call-1",
|
||||
name: "emit_structured_result",
|
||||
arguments: '{"entries":[]}',
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it("resolves multiple waiters for the same run id", async () => {
|
||||
const dedupe = new Map();
|
||||
const runId = "run-multi";
|
||||
|
||||
@ -1,10 +1,17 @@
|
||||
import type { DedupeEntry } from "../server-shared.js";
|
||||
import { parsePendingToolCalls } from "./pending-tool-calls.js";
|
||||
|
||||
export type AgentWaitTerminalSnapshot = {
|
||||
status: "ok" | "error" | "timeout";
|
||||
startedAt?: number;
|
||||
endedAt?: number;
|
||||
error?: string;
|
||||
stopReason?: string;
|
||||
pendingToolCalls?: Array<{
|
||||
id: string;
|
||||
name: string;
|
||||
arguments: string;
|
||||
}>;
|
||||
};
|
||||
|
||||
const AGENT_WAITERS_BY_RUN_ID = new Map<string, Set<() => void>>();
|
||||
@ -72,6 +79,14 @@ export function readTerminalSnapshotFromDedupeEntry(
|
||||
endedAt?: unknown;
|
||||
error?: unknown;
|
||||
summary?: unknown;
|
||||
stopReason?: unknown;
|
||||
pendingToolCalls?: unknown;
|
||||
result?: {
|
||||
meta?: {
|
||||
stopReason?: unknown;
|
||||
pendingToolCalls?: unknown;
|
||||
};
|
||||
};
|
||||
}
|
||||
| undefined;
|
||||
const status = typeof payload?.status === "string" ? payload.status : undefined;
|
||||
@ -87,6 +102,15 @@ export function readTerminalSnapshotFromDedupeEntry(
|
||||
: typeof payload?.summary === "string"
|
||||
? payload.summary
|
||||
: entry.error?.message;
|
||||
const stopReason =
|
||||
typeof payload?.result?.meta?.stopReason === "string"
|
||||
? payload.result.meta.stopReason
|
||||
: typeof payload?.stopReason === "string"
|
||||
? payload.stopReason
|
||||
: undefined;
|
||||
const pendingToolCalls =
|
||||
parsePendingToolCalls(payload?.result?.meta?.pendingToolCalls) ??
|
||||
parsePendingToolCalls(payload?.pendingToolCalls);
|
||||
|
||||
if (status === "ok" || status === "timeout") {
|
||||
return {
|
||||
@ -94,6 +118,8 @@ export function readTerminalSnapshotFromDedupeEntry(
|
||||
startedAt,
|
||||
endedAt,
|
||||
error: status === "timeout" ? errorMessage : undefined,
|
||||
stopReason,
|
||||
pendingToolCalls,
|
||||
};
|
||||
}
|
||||
if (status === "error" || !entry.ok) {
|
||||
@ -102,6 +128,8 @@ export function readTerminalSnapshotFromDedupeEntry(
|
||||
startedAt,
|
||||
endedAt,
|
||||
error: errorMessage,
|
||||
stopReason,
|
||||
pendingToolCalls,
|
||||
};
|
||||
}
|
||||
return null;
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { BARE_SESSION_RESET_PROMPT } from "../../auto-reply/reply/session-reset-prompt.js";
|
||||
import { agentHandlers } from "./agent.js";
|
||||
import type { GatewayRequestContext } from "./types.js";
|
||||
@ -10,6 +10,9 @@ const mocks = vi.hoisted(() => ({
|
||||
agentCommand: vi.fn(),
|
||||
registerAgentRunContext: vi.fn(),
|
||||
performGatewaySessionReset: vi.fn(),
|
||||
waitForAgentJob: vi.fn(),
|
||||
readTerminalSnapshotFromGatewayDedupe: vi.fn(),
|
||||
waitForTerminalGatewayDedupe: vi.fn(),
|
||||
getSubagentRunByChildSessionKey: vi.fn(),
|
||||
replaceSubagentRunAfterSteer: vi.fn(),
|
||||
loadConfigReturn: {} as Record<string, unknown>,
|
||||
@ -76,6 +79,17 @@ vi.mock("../session-reset-service.js", () => ({
|
||||
(mocks.performGatewaySessionReset as (...args: unknown[]) => unknown)(...args),
|
||||
}));
|
||||
|
||||
vi.mock("./agent-job.js", () => ({
|
||||
waitForAgentJob: (...args: unknown[]) => mocks.waitForAgentJob(...args),
|
||||
}));
|
||||
|
||||
vi.mock("./agent-wait-dedupe.js", () => ({
|
||||
readTerminalSnapshotFromGatewayDedupe: (...args: unknown[]) =>
|
||||
mocks.readTerminalSnapshotFromGatewayDedupe(...args),
|
||||
setGatewayDedupeEntry: vi.fn(),
|
||||
waitForTerminalGatewayDedupe: (...args: unknown[]) => mocks.waitForTerminalGatewayDedupe(...args),
|
||||
}));
|
||||
|
||||
vi.mock("../../sessions/send-policy.js", () => ({
|
||||
resolveSendPolicy: () => "allow",
|
||||
}));
|
||||
@ -105,6 +119,16 @@ type AgentParams = AgentHandlerArgs["params"];
|
||||
type AgentIdentityGetHandlerArgs = Parameters<(typeof agentHandlers)["agent.identity.get"]>[0];
|
||||
type AgentIdentityGetParams = AgentIdentityGetHandlerArgs["params"];
|
||||
|
||||
beforeEach(() => {
|
||||
mocks.waitForAgentJob.mockReset();
|
||||
mocks.waitForAgentJob.mockResolvedValue(null);
|
||||
mocks.readTerminalSnapshotFromGatewayDedupe.mockReset();
|
||||
mocks.readTerminalSnapshotFromGatewayDedupe.mockReturnValue(null);
|
||||
mocks.waitForTerminalGatewayDedupe.mockReset();
|
||||
mocks.waitForTerminalGatewayDedupe.mockResolvedValue(null);
|
||||
mocks.loadConfigReturn = {};
|
||||
});
|
||||
|
||||
async function waitForAssertion(assertion: () => void, timeoutMs = 2_000, stepMs = 5) {
|
||||
vi.useFakeTimers();
|
||||
try {
|
||||
@ -633,6 +657,64 @@ describe("gateway agent handler", () => {
|
||||
expect(callArgs.bestEffortDeliver).toBe(false);
|
||||
});
|
||||
|
||||
it("forwards structured subagent options to agentCommandFromIngress", async () => {
|
||||
primeMainAgentRun();
|
||||
|
||||
await invokeAgent(
|
||||
{
|
||||
message: "structured helper run",
|
||||
agentId: "main",
|
||||
sessionKey: "agent:main:main",
|
||||
disableTools: true,
|
||||
clientTools: [
|
||||
{
|
||||
type: "function",
|
||||
function: {
|
||||
name: "emit_structured_result",
|
||||
description: "Return a structured result payload.",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
entries: { type: "array" },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
streamParams: {
|
||||
toolChoice: {
|
||||
type: "function",
|
||||
function: {
|
||||
name: "emit_structured_result",
|
||||
},
|
||||
},
|
||||
},
|
||||
idempotencyKey: "test-structured-helper",
|
||||
} as AgentParams,
|
||||
{ reqId: "structured-helper-1" },
|
||||
);
|
||||
|
||||
await vi.waitFor(() => expect(mocks.agentCommand).toHaveBeenCalled());
|
||||
const callArgs = mocks.agentCommand.mock.calls.at(-1)?.[0] as Record<string, unknown>;
|
||||
expect(callArgs.disableTools).toBe(true);
|
||||
expect(callArgs.clientTools).toEqual([
|
||||
{
|
||||
type: "function",
|
||||
function: expect.objectContaining({
|
||||
name: "emit_structured_result",
|
||||
}),
|
||||
},
|
||||
]);
|
||||
expect(callArgs.streamParams).toEqual({
|
||||
toolChoice: {
|
||||
type: "function",
|
||||
function: {
|
||||
name: "emit_structured_result",
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("rejects public spawned-run metadata fields", async () => {
|
||||
primeMainAgentRun();
|
||||
mocks.agentCommand.mockClear();
|
||||
@ -877,4 +959,206 @@ describe("gateway agent handler", () => {
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("returns structured fields for cached agent.wait snapshots", async () => {
|
||||
const respond = vi.fn();
|
||||
const context = makeContext();
|
||||
context.chatAbortControllers = new Map();
|
||||
mocks.readTerminalSnapshotFromGatewayDedupe.mockReturnValue({
|
||||
status: "ok",
|
||||
startedAt: 10,
|
||||
endedAt: 20,
|
||||
stopReason: "tool_calls",
|
||||
pendingToolCalls: [
|
||||
{
|
||||
id: "call-1",
|
||||
name: "emit_structured_result",
|
||||
arguments: '{"entries":[]}',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
await agentHandlers["agent.wait"]({
|
||||
params: { runId: "wait-cached", timeoutMs: 100 },
|
||||
respond,
|
||||
context,
|
||||
} as unknown as Parameters<(typeof agentHandlers)["agent.wait"]>[0]);
|
||||
|
||||
expect(respond).toHaveBeenCalledWith(
|
||||
true,
|
||||
expect.objectContaining({
|
||||
runId: "wait-cached",
|
||||
status: "ok",
|
||||
stopReason: "tool_calls",
|
||||
pendingToolCalls: [
|
||||
{
|
||||
id: "call-1",
|
||||
name: "emit_structured_result",
|
||||
arguments: '{"entries":[]}',
|
||||
},
|
||||
],
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("merges structured fields from dedupe when lifecycle resolves first", async () => {
|
||||
const respond = vi.fn();
|
||||
const context = makeContext();
|
||||
context.chatAbortControllers = new Map();
|
||||
mocks.waitForAgentJob.mockResolvedValue({
|
||||
status: "ok",
|
||||
startedAt: 10,
|
||||
endedAt: 20,
|
||||
});
|
||||
mocks.waitForTerminalGatewayDedupe.mockImplementation(async () => {
|
||||
await Promise.resolve();
|
||||
return {
|
||||
status: "ok",
|
||||
startedAt: 10,
|
||||
endedAt: 20,
|
||||
stopReason: "tool_calls",
|
||||
pendingToolCalls: [
|
||||
{
|
||||
id: "call-1",
|
||||
name: "emit_structured_result",
|
||||
arguments: '{"entries":[]}',
|
||||
},
|
||||
],
|
||||
};
|
||||
});
|
||||
|
||||
await agentHandlers["agent.wait"]({
|
||||
params: { runId: "wait-live", timeoutMs: 100 },
|
||||
respond,
|
||||
context,
|
||||
} as unknown as Parameters<(typeof agentHandlers)["agent.wait"]>[0]);
|
||||
|
||||
expect(respond).toHaveBeenCalledWith(
|
||||
true,
|
||||
expect.objectContaining({
|
||||
runId: "wait-live",
|
||||
status: "ok",
|
||||
stopReason: "tool_calls",
|
||||
pendingToolCalls: [
|
||||
{
|
||||
id: "call-1",
|
||||
name: "emit_structured_result",
|
||||
arguments: '{"entries":[]}',
|
||||
},
|
||||
],
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("does not grace-wait when lifecycle resolves without tool calls", async () => {
|
||||
vi.useFakeTimers();
|
||||
try {
|
||||
const respond = vi.fn();
|
||||
const context = makeContext();
|
||||
context.chatAbortControllers = new Map([
|
||||
[
|
||||
"wait-no-tools",
|
||||
{
|
||||
controller: new AbortController(),
|
||||
sessionKey: "agent:main:main",
|
||||
sessionId: "test-session",
|
||||
startedAtMs: Date.now(),
|
||||
expiresAtMs: Date.now() + 60_000,
|
||||
},
|
||||
],
|
||||
]);
|
||||
mocks.waitForAgentJob.mockResolvedValue({
|
||||
status: "ok",
|
||||
startedAt: 10,
|
||||
endedAt: 20,
|
||||
stopReason: "stop",
|
||||
});
|
||||
mocks.waitForTerminalGatewayDedupe.mockImplementation(
|
||||
() =>
|
||||
new Promise<null>((resolve) => {
|
||||
setTimeout(() => resolve(null), 1_000);
|
||||
}),
|
||||
);
|
||||
|
||||
const waitPromise = agentHandlers["agent.wait"]({
|
||||
params: { runId: "wait-no-tools", timeoutMs: 100 },
|
||||
respond,
|
||||
context,
|
||||
} as unknown as Parameters<(typeof agentHandlers)["agent.wait"]>[0]);
|
||||
|
||||
await vi.advanceTimersByTimeAsync(0);
|
||||
|
||||
expect(respond).toHaveBeenCalledWith(
|
||||
true,
|
||||
expect.objectContaining({
|
||||
runId: "wait-no-tools",
|
||||
status: "ok",
|
||||
stopReason: "stop",
|
||||
pendingToolCalls: undefined,
|
||||
}),
|
||||
);
|
||||
expect(mocks.waitForTerminalGatewayDedupe).toHaveBeenCalledTimes(1);
|
||||
|
||||
await vi.runAllTimersAsync();
|
||||
await waitPromise;
|
||||
} finally {
|
||||
vi.useRealTimers();
|
||||
}
|
||||
});
|
||||
|
||||
it("does not grace-wait for errors when lifecycle metadata omits stopReason", async () => {
|
||||
vi.useFakeTimers();
|
||||
try {
|
||||
const respond = vi.fn();
|
||||
const context = makeContext();
|
||||
context.chatAbortControllers = new Map([
|
||||
[
|
||||
"wait-no-stop-reason",
|
||||
{
|
||||
controller: new AbortController(),
|
||||
sessionKey: "agent:main:main",
|
||||
sessionId: "test-session",
|
||||
startedAtMs: Date.now(),
|
||||
expiresAtMs: Date.now() + 60_000,
|
||||
},
|
||||
],
|
||||
]);
|
||||
mocks.waitForAgentJob.mockResolvedValue({
|
||||
status: "error",
|
||||
startedAt: 10,
|
||||
endedAt: 20,
|
||||
error: "boom",
|
||||
});
|
||||
mocks.waitForTerminalGatewayDedupe.mockImplementation(
|
||||
() =>
|
||||
new Promise<null>((resolve) => {
|
||||
setTimeout(() => resolve(null), 1_000);
|
||||
}),
|
||||
);
|
||||
|
||||
const waitPromise = agentHandlers["agent.wait"]({
|
||||
params: { runId: "wait-no-stop-reason", timeoutMs: 100 },
|
||||
respond,
|
||||
context,
|
||||
} as unknown as Parameters<(typeof agentHandlers)["agent.wait"]>[0]);
|
||||
|
||||
await vi.advanceTimersByTimeAsync(0);
|
||||
|
||||
expect(respond).toHaveBeenCalledWith(
|
||||
true,
|
||||
expect.objectContaining({
|
||||
runId: "wait-no-stop-reason",
|
||||
status: "error",
|
||||
error: "boom",
|
||||
stopReason: undefined,
|
||||
pendingToolCalls: undefined,
|
||||
}),
|
||||
);
|
||||
|
||||
await vi.runAllTimersAsync();
|
||||
await waitPromise;
|
||||
} finally {
|
||||
vi.useRealTimers();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@ -7,6 +7,7 @@ import {
|
||||
} from "../../agents/spawned-context.js";
|
||||
import { buildBareSessionResetPrompt } from "../../auto-reply/reply/session-reset-prompt.js";
|
||||
import { agentCommandFromIngress } from "../../commands/agent.js";
|
||||
import type { AgentStreamParams } from "../../commands/agent/types.js";
|
||||
import { loadConfig } from "../../config/config.js";
|
||||
import {
|
||||
mergeSessionEntry,
|
||||
@ -67,6 +68,28 @@ import { normalizeRpcAttachmentsToChatAttachments } from "./attachment-normalize
|
||||
import type { GatewayRequestHandlerOptions, GatewayRequestHandlers } from "./types.js";
|
||||
|
||||
const RESET_COMMAND_RE = /^\/(new|reset)(?:\s+([\s\S]*))?$/i;
|
||||
const AGENT_WAIT_DEDUPE_METADATA_GRACE_MS = 5_000;
|
||||
|
||||
function mergeAgentWaitStructuredMetadata<T extends AgentWaitTerminalSnapshot>(
|
||||
snapshot: T,
|
||||
dedupeSnapshot: AgentWaitTerminalSnapshot | null | undefined,
|
||||
): T {
|
||||
if (!dedupeSnapshot) {
|
||||
return snapshot;
|
||||
}
|
||||
return {
|
||||
...snapshot,
|
||||
stopReason: snapshot.stopReason ?? dedupeSnapshot.stopReason,
|
||||
pendingToolCalls: snapshot.pendingToolCalls ?? dedupeSnapshot.pendingToolCalls,
|
||||
};
|
||||
}
|
||||
|
||||
function isMissingAgentWaitStructuredMetadata(snapshot: AgentWaitTerminalSnapshot): boolean {
|
||||
if (snapshot.stopReason === undefined && snapshot.status === "ok") {
|
||||
return true;
|
||||
}
|
||||
return snapshot.stopReason === "tool_calls" && snapshot.pendingToolCalls === undefined;
|
||||
}
|
||||
|
||||
function resolveSenderIsOwnerFromClient(client: GatewayRequestHandlerOptions["client"]): boolean {
|
||||
const scopes = Array.isArray(client?.connect?.scopes) ? client.connect.scopes : [];
|
||||
@ -232,6 +255,16 @@ export const agentHandlers: GatewayRequestHandlers = {
|
||||
lane?: string;
|
||||
extraSystemPrompt?: string;
|
||||
internalEvents?: AgentInternalEvent[];
|
||||
clientTools?: Array<{
|
||||
type: "function";
|
||||
function: {
|
||||
name: string;
|
||||
description?: string;
|
||||
parameters?: Record<string, unknown>;
|
||||
};
|
||||
}>;
|
||||
disableTools?: boolean;
|
||||
streamParams?: AgentStreamParams;
|
||||
idempotencyKey: string;
|
||||
timeout?: number;
|
||||
bestEffortDeliver?: boolean;
|
||||
@ -698,6 +731,9 @@ export const agentHandlers: GatewayRequestHandlers = {
|
||||
lane: request.lane,
|
||||
extraSystemPrompt: request.extraSystemPrompt,
|
||||
internalEvents: request.internalEvents,
|
||||
clientTools: request.clientTools,
|
||||
disableTools: request.disableTools,
|
||||
streamParams: request.streamParams,
|
||||
inputProvenance,
|
||||
// Internal-only: allow workspace override for spawned subagent runs.
|
||||
workspaceDir: resolveIngressWorkspaceOverrideForSpawnedRun({
|
||||
@ -799,6 +835,8 @@ export const agentHandlers: GatewayRequestHandlers = {
|
||||
startedAt: cachedGatewaySnapshot.startedAt,
|
||||
endedAt: cachedGatewaySnapshot.endedAt,
|
||||
error: cachedGatewaySnapshot.error,
|
||||
stopReason: cachedGatewaySnapshot.stopReason,
|
||||
pendingToolCalls: cachedGatewaySnapshot.pendingToolCalls,
|
||||
});
|
||||
return;
|
||||
}
|
||||
@ -830,6 +868,44 @@ export const agentHandlers: GatewayRequestHandlers = {
|
||||
first.snapshot;
|
||||
if (snapshot) {
|
||||
if (first.source === "lifecycle") {
|
||||
snapshot = mergeAgentWaitStructuredMetadata(
|
||||
snapshot,
|
||||
readTerminalSnapshotFromGatewayDedupe({
|
||||
dedupe: context.dedupe,
|
||||
runId,
|
||||
ignoreAgentTerminalSnapshot: hasActiveChatRun,
|
||||
}),
|
||||
);
|
||||
if (snapshot.stopReason === undefined) {
|
||||
const immediateDedupeMetadata =
|
||||
(await Promise.race([
|
||||
dedupePromise,
|
||||
Promise.resolve<AgentWaitTerminalSnapshot | null>(null),
|
||||
])) ?? null;
|
||||
snapshot = mergeAgentWaitStructuredMetadata(snapshot, immediateDedupeMetadata);
|
||||
}
|
||||
if (isMissingAgentWaitStructuredMetadata(snapshot)) {
|
||||
let graceTimer: ReturnType<typeof setTimeout> | null = null;
|
||||
const dedupeMetadata =
|
||||
(await Promise.race([
|
||||
dedupePromise.finally(() => {
|
||||
if (graceTimer != null) {
|
||||
clearTimeout(graceTimer);
|
||||
}
|
||||
}),
|
||||
new Promise<null>((resolve) => {
|
||||
graceTimer = setTimeout(
|
||||
() => resolve(null),
|
||||
Math.max(
|
||||
1,
|
||||
Math.min(timeoutMs, AGENT_WAIT_DEDUPE_METADATA_GRACE_MS, 2_147_483_647),
|
||||
),
|
||||
);
|
||||
graceTimer.unref?.();
|
||||
}),
|
||||
])) ?? null;
|
||||
snapshot = mergeAgentWaitStructuredMetadata(snapshot, dedupeMetadata);
|
||||
}
|
||||
dedupeAbortController.abort();
|
||||
} else {
|
||||
lifecycleAbortController.abort();
|
||||
@ -853,6 +929,8 @@ export const agentHandlers: GatewayRequestHandlers = {
|
||||
startedAt: snapshot.startedAt,
|
||||
endedAt: snapshot.endedAt,
|
||||
error: snapshot.error,
|
||||
stopReason: snapshot.stopReason,
|
||||
pendingToolCalls: snapshot.pendingToolCalls,
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
29
src/gateway/server-methods/pending-tool-calls.ts
Normal file
29
src/gateway/server-methods/pending-tool-calls.ts
Normal file
@ -0,0 +1,29 @@
|
||||
export type PendingToolCall = {
|
||||
id: string;
|
||||
name: string;
|
||||
arguments: string;
|
||||
};
|
||||
|
||||
export function parsePendingToolCalls(value: unknown): PendingToolCall[] | undefined {
|
||||
if (!Array.isArray(value)) {
|
||||
return undefined;
|
||||
}
|
||||
const calls = value
|
||||
.map((entry) => {
|
||||
if (!entry || typeof entry !== "object") {
|
||||
return null;
|
||||
}
|
||||
const record = entry as Record<string, unknown>;
|
||||
return typeof record.id === "string" &&
|
||||
typeof record.name === "string" &&
|
||||
typeof record.arguments === "string"
|
||||
? {
|
||||
id: record.id,
|
||||
name: record.name,
|
||||
arguments: record.arguments,
|
||||
}
|
||||
: null;
|
||||
})
|
||||
.filter((entry): entry is PendingToolCall => entry !== null);
|
||||
return calls.length > 0 ? calls : undefined;
|
||||
}
|
||||
@ -29,14 +29,15 @@ vi.mock("./server-methods.js", () => ({
|
||||
}));
|
||||
|
||||
vi.mock("../channels/registry.js", () => ({
|
||||
CHAT_CHANNEL_ORDER: [],
|
||||
CHANNEL_IDS: [],
|
||||
CHAT_CHANNEL_ORDER: ["telegram", "discord", "slack"],
|
||||
CHANNEL_IDS: ["telegram", "discord", "slack"],
|
||||
listChatChannels: () => [],
|
||||
listChatChannelAliases: () => [],
|
||||
getChatChannelMeta: () => null,
|
||||
normalizeChatChannelId: () => null,
|
||||
normalizeChannelId: () => null,
|
||||
normalizeAnyChannelId: () => null,
|
||||
normalizeAnyChannelId: (raw?: string | null) =>
|
||||
typeof raw === "string" && raw.trim().length > 0 ? raw.trim().toLowerCase() : null,
|
||||
formatChannelPrimerLine: () => "",
|
||||
formatChannelSelectionLine: () => "",
|
||||
}));
|
||||
@ -92,6 +93,11 @@ function getLastDispatchedClientScopes(): string[] {
|
||||
return Array.isArray(scopes) ? scopes : [];
|
||||
}
|
||||
|
||||
function getLastDispatchedRequest() {
|
||||
const call = handleGatewayRequest.mock.calls.at(-1)?.[0];
|
||||
return call?.req;
|
||||
}
|
||||
|
||||
async function loadTestModules() {
|
||||
serverPluginsModule = await import("./server-plugins.js");
|
||||
runtimeModule = await import("../plugins/runtime/index.js");
|
||||
@ -99,6 +105,10 @@ async function loadTestModules() {
|
||||
methodScopesModule = await import("./method-scopes.js");
|
||||
}
|
||||
|
||||
async function importServerPluginsModule(): Promise<ServerPluginsModule> {
|
||||
return import("./server-plugins.js");
|
||||
}
|
||||
|
||||
async function createSubagentRuntime(
|
||||
serverPlugins: ServerPluginsModule,
|
||||
cfg: Record<string, unknown> = {},
|
||||
@ -147,7 +157,17 @@ beforeEach(() => {
|
||||
opts.respond(true, { runId: "run-1" });
|
||||
return;
|
||||
case "agent.wait":
|
||||
opts.respond(true, { status: "ok" });
|
||||
opts.respond(true, {
|
||||
status: "ok",
|
||||
stopReason: "tool_calls",
|
||||
pendingToolCalls: [
|
||||
{
|
||||
id: "call-1",
|
||||
name: "emit_structured_result",
|
||||
arguments: '{"entries":[]}',
|
||||
},
|
||||
],
|
||||
});
|
||||
return;
|
||||
case "sessions.get":
|
||||
opts.respond(true, { messages: [] });
|
||||
@ -579,4 +599,92 @@ describe("loadGatewayPlugins", () => {
|
||||
| undefined;
|
||||
expect(dispatched?.marker).toBe("after-mutation");
|
||||
});
|
||||
|
||||
test("forwards structured plugin subagent options to gateway agent methods", async () => {
|
||||
const serverPlugins = await importServerPluginsModule();
|
||||
const runtime = await createSubagentRuntime(serverPlugins);
|
||||
serverPlugins.setFallbackGatewayContext(createTestContext("structured-output"));
|
||||
|
||||
await runtime.run({
|
||||
sessionKey: "s-structured",
|
||||
message: "extract memories",
|
||||
disableTools: true,
|
||||
clientTools: [
|
||||
{
|
||||
type: "function",
|
||||
function: {
|
||||
name: "emit_structured_result",
|
||||
description: "Return a structured result payload.",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
entries: {
|
||||
type: "array",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
streamParams: {
|
||||
toolChoice: {
|
||||
type: "function",
|
||||
function: {
|
||||
name: "emit_structured_result",
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
expect(getLastDispatchedRequest()).toEqual(
|
||||
expect.objectContaining({
|
||||
type: "req",
|
||||
id: expect.any(String),
|
||||
method: "agent",
|
||||
params: expect.objectContaining({
|
||||
sessionKey: "s-structured",
|
||||
message: "extract memories",
|
||||
disableTools: true,
|
||||
clientTools: [
|
||||
{
|
||||
type: "function",
|
||||
function: expect.objectContaining({
|
||||
name: "emit_structured_result",
|
||||
}),
|
||||
},
|
||||
],
|
||||
streamParams: {
|
||||
toolChoice: {
|
||||
type: "function",
|
||||
function: {
|
||||
name: "emit_structured_result",
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
test("returns pending tool calls from gateway agent.wait", async () => {
|
||||
const serverPlugins = await importServerPluginsModule();
|
||||
const runtime = await createSubagentRuntime(serverPlugins);
|
||||
|
||||
const result = await runtime.waitForRun({
|
||||
runId: "run-1",
|
||||
timeoutMs: 1_000,
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
status: "ok",
|
||||
stopReason: "tool_calls",
|
||||
pendingToolCalls: [
|
||||
{
|
||||
id: "call-1",
|
||||
name: "emit_structured_result",
|
||||
arguments: '{"entries":[]}',
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@ -333,6 +333,9 @@ function createGatewaySubagentRuntime(): PluginRuntime["subagent"] {
|
||||
...(allowOverride && params.model && { model: params.model }),
|
||||
...(params.extraSystemPrompt && { extraSystemPrompt: params.extraSystemPrompt }),
|
||||
...(params.lane && { lane: params.lane }),
|
||||
...(params.clientTools && { clientTools: params.clientTools }),
|
||||
...(params.disableTools === true && { disableTools: true }),
|
||||
...(params.streamParams && { streamParams: params.streamParams }),
|
||||
...(params.idempotencyKey && { idempotencyKey: params.idempotencyKey }),
|
||||
},
|
||||
{
|
||||
@ -346,13 +349,19 @@ function createGatewaySubagentRuntime(): PluginRuntime["subagent"] {
|
||||
return { runId };
|
||||
},
|
||||
async waitForRun(params) {
|
||||
const payload = await dispatchGatewayMethod<{ status?: string; error?: string }>(
|
||||
"agent.wait",
|
||||
{
|
||||
runId: params.runId,
|
||||
...(params.timeoutMs != null && { timeoutMs: params.timeoutMs }),
|
||||
},
|
||||
);
|
||||
const payload = await dispatchGatewayMethod<{
|
||||
status?: string;
|
||||
error?: string;
|
||||
stopReason?: string;
|
||||
pendingToolCalls?: Array<{
|
||||
id: string;
|
||||
name: string;
|
||||
arguments: string;
|
||||
}>;
|
||||
}>("agent.wait", {
|
||||
runId: params.runId,
|
||||
...(params.timeoutMs != null && { timeoutMs: params.timeoutMs }),
|
||||
});
|
||||
const status = payload?.status;
|
||||
if (status !== "ok" && status !== "error" && status !== "timeout") {
|
||||
throw new Error(`Gateway agent.wait returned unexpected status: ${status}`);
|
||||
@ -360,6 +369,11 @@ function createGatewaySubagentRuntime(): PluginRuntime["subagent"] {
|
||||
return {
|
||||
status,
|
||||
...(typeof payload?.error === "string" && payload.error && { error: payload.error }),
|
||||
...(typeof payload?.stopReason === "string" &&
|
||||
payload.stopReason && { stopReason: payload.stopReason }),
|
||||
...(Array.isArray(payload?.pendingToolCalls) && payload.pendingToolCalls.length > 0
|
||||
? { pendingToolCalls: payload.pendingToolCalls }
|
||||
: {}),
|
||||
};
|
||||
},
|
||||
getSessionMessages,
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import type { ClientToolDefinition } from "../../agents/pi-embedded-runner/run/params.js";
|
||||
import type { AgentStreamParams } from "../../commands/agent/types.js";
|
||||
import type { PluginRuntimeChannel } from "./types-channel.js";
|
||||
import type { PluginRuntimeCore, RuntimeLogger } from "./types-core.js";
|
||||
|
||||
@ -14,6 +16,9 @@ export type SubagentRunParams = {
|
||||
lane?: string;
|
||||
deliver?: boolean;
|
||||
idempotencyKey?: string;
|
||||
clientTools?: ClientToolDefinition[];
|
||||
disableTools?: boolean;
|
||||
streamParams?: AgentStreamParams;
|
||||
};
|
||||
|
||||
export type SubagentRunResult = {
|
||||
@ -28,6 +33,12 @@ export type SubagentWaitParams = {
|
||||
export type SubagentWaitResult = {
|
||||
status: "ok" | "error" | "timeout";
|
||||
error?: string;
|
||||
stopReason?: string;
|
||||
pendingToolCalls?: Array<{
|
||||
id: string;
|
||||
name: string;
|
||||
arguments: string;
|
||||
}>;
|
||||
};
|
||||
|
||||
export type SubagentGetSessionMessagesParams = {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user