diff --git a/src/agents/pi-embedded-runner/run.ts b/src/agents/pi-embedded-runner/run.ts index 0c66203992f..64b83d0e677 100644 --- a/src/agents/pi-embedded-runner/run.ts +++ b/src/agents/pi-embedded-runner/run.ts @@ -82,6 +82,10 @@ import type { EmbeddedPiAgentMeta, EmbeddedPiRunResult } from "./types.js"; import { describeUnknownError } from "./utils.js"; type ApiKeyInfo = ResolvedProviderAuth; +type RequiredToolUse = { + toolNames?: string[]; + reason?: string; +}; type RuntimeAuthState = { sourceApiKey: string; @@ -166,6 +170,117 @@ const hasUsageValues = ( (value) => typeof value === "number" && Number.isFinite(value) && value > 0, ); +function normalizeRequiredToolNames(toolNames: string[] | undefined): string[] | undefined { + const normalized = (toolNames ?? []) + .map((name) => + String(name ?? "") + .trim() + .toLowerCase(), + ) + .filter(Boolean); + return normalized.length > 0 ? Array.from(new Set(normalized)) : undefined; +} + +function inferRequiredToolUseFromPrompt( + prompt: string, + trigger?: RunEmbeddedPiAgentParams["trigger"], +): RequiredToolUse | undefined { + if (trigger === "heartbeat") { + return { + toolNames: ["read"], + reason: "Heartbeat runs must check HEARTBEAT.md before replying.", + }; + } + + const trimmedPrompt = prompt.trim(); + if (trimmedPrompt.startsWith("Read HEARTBEAT.md if it exists")) { + return { + toolNames: ["read"], + reason: "Heartbeat prompts must check HEARTBEAT.md before replying.", + }; + } + + const explicitToolNames = Array.from( + trimmedPrompt.matchAll(/\b(?:call|use)\s+(?:the\s+)?([a-z0-9_]+)\s+tool\b/gi), + ) + .map((match) => match[1]?.trim().toLowerCase()) + .filter((name): name is string => Boolean(name)); + if (explicitToolNames.length === 0) { + return undefined; + } + const toolNames = Array.from(new Set(explicitToolNames)); + return { + toolNames, + reason: + toolNames.length === 1 + ? `Prompt explicitly requires the ${toolNames[0]} tool.` + : `Prompt explicitly requires one of these tools: ${toolNames.join(", ")}.`, + }; +} + +function resolveRequiredToolUse( + requiredToolUse: RunEmbeddedPiAgentParams["requireToolUse"], + prompt: string, + trigger?: RunEmbeddedPiAgentParams["trigger"], +): RequiredToolUse | undefined { + if (requiredToolUse === false) { + return undefined; + } + if (requiredToolUse) { + if (requiredToolUse === true) { + return { reason: "This run requires at least one tool call before success." }; + } + return { + toolNames: normalizeRequiredToolNames(requiredToolUse.toolNames), + reason: requiredToolUse.reason?.trim() || "This run requires tool use before success.", + }; + } + return inferRequiredToolUseFromPrompt(prompt, trigger); +} + +function didSatisfyRequiredToolUse( + attempt: { + toolMetas: Array<{ toolName: string; meta?: string }>; + clientToolCall?: { name: string; params: Record }; + didSendViaMessagingTool: boolean; + successfulCronAdds?: number; + }, + requirement: RequiredToolUse, +): boolean { + const usedToolNames = new Set( + attempt.toolMetas.map((entry) => entry.toolName.trim().toLowerCase()).filter(Boolean), + ); + if (attempt.clientToolCall?.name) { + usedToolNames.add(attempt.clientToolCall.name.trim().toLowerCase()); + } + if (attempt.didSendViaMessagingTool) { + usedToolNames.add("message"); + } + if ((attempt.successfulCronAdds ?? 0) > 0) { + usedToolNames.add("cron"); + } + + const requiredToolNames = normalizeRequiredToolNames(requirement.toolNames); + if (!requiredToolNames) { + return usedToolNames.size > 0; + } + return requiredToolNames.some((toolName) => usedToolNames.has(toolName)); +} + +function formatRequiredToolUseErrorMessage(requirement: RequiredToolUse): string { + const requiredToolNames = normalizeRequiredToolNames(requirement.toolNames); + const expected = + requiredToolNames && requiredToolNames.length > 0 + ? requiredToolNames.length === 1 + ? `the ${requiredToolNames[0]} tool` + : `one of these tools: ${requiredToolNames.join(", ")}` + : "at least one tool"; + const reason = requirement.reason?.trim(); + return reason + ? `${reason} The model replied without calling ${expected}.` + : `Tool use was required for this run, but the model replied without calling ${expected}.`; +} + const mergeUsageIntoAccumulator = ( target: UsageAccumulator, usage: ReturnType, @@ -282,6 +397,11 @@ export async function runEmbeddedPiAgent( : "plain" : "markdown"); const isProbeSession = params.sessionId?.startsWith("probe-") ?? false; + const requiredToolUse = resolveRequiredToolUse( + params.requireToolUse, + params.prompt, + params.trigger, + ); return enqueueSession(() => enqueueGlobal(async () => { @@ -1660,6 +1780,45 @@ export async function runEmbeddedPiAgent( }; } + if ( + requiredToolUse && + !timedOut && + !didSatisfyRequiredToolUse( + { + toolMetas: attempt.toolMetas, + clientToolCall: attempt.clientToolCall, + didSendViaMessagingTool: attempt.didSendViaMessagingTool, + successfulCronAdds: attempt.successfulCronAdds, + }, + requiredToolUse, + ) + ) { + const message = formatRequiredToolUseErrorMessage(requiredToolUse); + return { + payloads: [ + { + text: message, + isError: true, + }, + ], + meta: { + durationMs: Date.now() - started, + agentMeta, + aborted, + systemPromptReport: attempt.systemPromptReport, + error: { + kind: "required_tool_use", + message, + }, + }, + didSendViaMessagingTool: attempt.didSendViaMessagingTool, + messagingToolSentTexts: attempt.messagingToolSentTexts, + messagingToolSentMediaUrls: attempt.messagingToolSentMediaUrls, + messagingToolSentTargets: attempt.messagingToolSentTargets, + successfulCronAdds: attempt.successfulCronAdds, + }; + } + log.debug( `embedded run done: runId=${params.runId} sessionId=${params.sessionId} durationMs=${Date.now() - started} aborted=${aborted}`, ); diff --git a/src/agents/pi-embedded-runner/run/params.ts b/src/agents/pi-embedded-runner/run/params.ts index f59bb8f27b5..69102cea072 100644 --- a/src/agents/pi-embedded-runner/run/params.ts +++ b/src/agents/pi-embedded-runner/run/params.ts @@ -76,6 +76,11 @@ export type RunEmbeddedPiAgentParams = { clientTools?: ClientToolDefinition[]; /** Disable built-in tools for this run (LLM-only mode). */ disableTools?: boolean; + /** + * Require at least one tool call before treating the run as successful. + * When `toolNames` is provided, at least one of those tools must be called. + */ + requireToolUse?: boolean | { toolNames?: string[]; reason?: string }; provider?: string; model?: string; authProfileId?: string; diff --git a/src/agents/pi-embedded-runner/types.ts b/src/agents/pi-embedded-runner/types.ts index 722abbf2a9a..b712e22a805 100644 --- a/src/agents/pi-embedded-runner/types.ts +++ b/src/agents/pi-embedded-runner/types.ts @@ -41,6 +41,7 @@ export type EmbeddedPiRunMeta = { | "compaction_failure" | "role_ordering" | "image_size" + | "required_tool_use" | "retry_limit"; message: string; }; diff --git a/src/agents/pi-embedded-runner/usage-reporting.test.ts b/src/agents/pi-embedded-runner/usage-reporting.test.ts index f748ac3b9b5..11d970e400d 100644 --- a/src/agents/pi-embedded-runner/usage-reporting.test.ts +++ b/src/agents/pi-embedded-runner/usage-reporting.test.ts @@ -188,4 +188,185 @@ describe("runEmbeddedPiAgent usage reporting", () => { // If the bug exists, it will likely be 350 expect(usage?.total).toBe(200); }); + + it("returns an error when a run explicitly requires tool use and no tool was called", async () => { + mockedRunEmbeddedAttempt.mockResolvedValueOnce({ + aborted: false, + promptError: null, + timedOut: false, + timedOutDuringCompaction: false, + sessionIdUsed: "test-session", + assistantTexts: ["I checked it."], + toolMetas: [], + didSendViaMessagingTool: false, + lastAssistant: { + usage: { input: 10, output: 5, total: 15 }, + stopReason: "end_turn", + }, + attemptUsage: { input: 10, output: 5, total: 15 }, + messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], + messagingToolSentTargets: [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); + + const result = await runEmbeddedPiAgent({ + sessionId: "test-session", + sessionKey: "test-key", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp/workspace", + prompt: "Do the thing.", + requireToolUse: true, + timeoutMs: 30000, + runId: "run-require-tool", + }); + + expect(result.payloads?.[0]?.isError).toBe(true); + expect(result.meta.error?.kind).toBe("required_tool_use"); + expect(result.payloads?.[0]?.text).toContain("requires at least one tool call"); + }); + + it("infers required read tool use for heartbeat prompts", async () => { + mockedRunEmbeddedAttempt.mockResolvedValueOnce({ + aborted: false, + promptError: null, + timedOut: false, + timedOutDuringCompaction: false, + sessionIdUsed: "test-session", + assistantTexts: ["HEARTBEAT_OK"], + toolMetas: [], + didSendViaMessagingTool: false, + lastAssistant: { + usage: { input: 10, output: 5, total: 15 }, + stopReason: "end_turn", + }, + attemptUsage: { input: 10, output: 5, total: 15 }, + messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], + messagingToolSentTargets: [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); + + const result = await runEmbeddedPiAgent({ + sessionId: "test-session", + sessionKey: "test-key", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp/workspace", + prompt: "Use the heartbeat instructions for this workspace. Reply HEARTBEAT_OK if idle.", + trigger: "heartbeat", + timeoutMs: 30000, + runId: "run-heartbeat-tool-requirement", + }); + + expect(result.payloads?.[0]?.isError).toBe(true); + expect(result.meta.error?.kind).toBe("required_tool_use"); + expect(result.payloads?.[0]?.text).toContain("HEARTBEAT.md"); + expect(result.payloads?.[0]?.text).toContain("read tool"); + }); + + it("lets timeout errors win over required-tool-use errors", async () => { + mockedRunEmbeddedAttempt.mockResolvedValueOnce({ + aborted: true, + promptError: null, + timedOut: true, + timedOutDuringCompaction: false, + sessionIdUsed: "test-session", + assistantTexts: [], + toolMetas: [], + didSendViaMessagingTool: false, + lastAssistant: null, + attemptUsage: { input: 10, output: 0, total: 10 }, + messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], + messagingToolSentTargets: [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); + + const result = await runEmbeddedPiAgent({ + sessionId: "test-session", + sessionKey: "test-key", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp/workspace", + prompt: "Call the cron tool with action=list.", + timeoutMs: 30000, + runId: "run-timeout-before-tool", + }); + + expect(result.meta.error?.kind).toBeUndefined(); + expect(result.payloads?.[0]?.isError).toBe(true); + expect(result.payloads?.[0]?.text).toContain( + "Request timed out before a response was generated", + ); + }); + + it("allows callers to disable prompt-inferred tool requirements explicitly", async () => { + mockedRunEmbeddedAttempt.mockResolvedValueOnce({ + aborted: false, + promptError: null, + timedOut: false, + timedOutDuringCompaction: false, + sessionIdUsed: "test-session", + assistantTexts: ["Handled without tools."], + toolMetas: [], + didSendViaMessagingTool: false, + lastAssistant: { + usage: { input: 10, output: 5, total: 15 }, + stopReason: "end_turn", + }, + attemptUsage: { input: 10, output: 5, total: 15 }, + messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], + messagingToolSentTargets: [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); + + const result = await runEmbeddedPiAgent({ + sessionId: "test-session", + sessionKey: "test-key", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp/workspace", + prompt: "Call the cron tool with action=list.", + requireToolUse: false, + timeoutMs: 30000, + runId: "run-disable-inferred-tool-requirement", + }); + + expect(result.meta.error).toBeUndefined(); + expect(result.payloads?.[0]?.isError).not.toBe(true); + }); + + it("allows success when a required tool was called", async () => { + mockedRunEmbeddedAttempt.mockResolvedValueOnce({ + aborted: false, + promptError: null, + timedOut: false, + timedOutDuringCompaction: false, + sessionIdUsed: "test-session", + assistantTexts: ["Done."], + toolMetas: [{ toolName: "cron", meta: "action=list" }], + didSendViaMessagingTool: false, + lastAssistant: { + usage: { input: 10, output: 5, total: 15 }, + stopReason: "end_turn", + }, + attemptUsage: { input: 10, output: 5, total: 15 }, + messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], + messagingToolSentTargets: [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); + + const result = await runEmbeddedPiAgent({ + sessionId: "test-session", + sessionKey: "test-key", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp/workspace", + prompt: "Call the cron tool with action=list.", + timeoutMs: 30000, + runId: "run-explicit-cron-tool", + }); + + expect(result.meta.error).toBeUndefined(); + expect(result.payloads?.[0]?.isError).not.toBe(true); + }); });