From 6a52c4f6f2e9377aa4a34c09c6209cd31df8ec1e Mon Sep 17 00:00:00 2001 From: Youyou972 <50808411+Youyou972@users.noreply.github.com> Date: Thu, 5 Mar 2026 15:50:44 -0500 Subject: [PATCH 1/2] Agent: enforce required tool use --- src/agents/pi-embedded-runner/run.ts | 139 ++++++++++++++++++ src/agents/pi-embedded-runner/run/params.ts | 5 + src/agents/pi-embedded-runner/types.ts | 1 + .../usage-reporting.test.ts | 110 ++++++++++++++ 4 files changed, 255 insertions(+) diff --git a/src/agents/pi-embedded-runner/run.ts b/src/agents/pi-embedded-runner/run.ts index de2274cc3f4..965cb093b74 100644 --- a/src/agents/pi-embedded-runner/run.ts +++ b/src/agents/pi-embedded-runner/run.ts @@ -65,6 +65,10 @@ import type { EmbeddedPiAgentMeta, EmbeddedPiRunResult } from "./types.js"; import { describeUnknownError } from "./utils.js"; type ApiKeyInfo = ResolvedProviderAuth; +type RequiredToolUse = { + toolNames?: string[]; + reason?: string; +}; type CopilotTokenState = { githubToken: string; @@ -139,6 +143,103 @@ const hasUsageValues = ( (value) => typeof value === "number" && Number.isFinite(value) && value > 0, ); +function normalizeRequiredToolNames(toolNames: string[] | undefined): string[] | undefined { + const normalized = (toolNames ?? []) + .map((name) => + String(name ?? "") + .trim() + .toLowerCase(), + ) + .filter(Boolean); + return normalized.length > 0 ? Array.from(new Set(normalized)) : undefined; +} + +function inferRequiredToolUseFromPrompt(prompt: string): RequiredToolUse | undefined { + const trimmedPrompt = prompt.trim(); + if (trimmedPrompt.startsWith("Read HEARTBEAT.md if it exists")) { + return { + toolNames: ["read"], + reason: "Heartbeat prompts must check HEARTBEAT.md before replying.", + }; + } + + const explicitToolNames = Array.from( + trimmedPrompt.matchAll(/\b(?:call|use)\s+(?:the\s+)?([a-z0-9_]+)\s+tool\b/gi), + ) + .map((match) => match[1]?.trim().toLowerCase()) + .filter((name): name is string => Boolean(name)); + if (explicitToolNames.length === 0) { + return undefined; + } + const toolNames = Array.from(new Set(explicitToolNames)); + return { + toolNames, + reason: + toolNames.length === 1 + ? `Prompt explicitly requires the ${toolNames[0]} tool.` + : `Prompt explicitly requires one of these tools: ${toolNames.join(", ")}.`, + }; +} + +function resolveRequiredToolUse( + requiredToolUse: RunEmbeddedPiAgentParams["requireToolUse"], + prompt: string, +): RequiredToolUse | undefined { + if (requiredToolUse) { + if (requiredToolUse === true) { + return { reason: "This run requires at least one tool call before success." }; + } + return { + toolNames: normalizeRequiredToolNames(requiredToolUse.toolNames), + reason: requiredToolUse.reason?.trim() || "This run requires tool use before success.", + }; + } + return inferRequiredToolUseFromPrompt(prompt); +} + +function didSatisfyRequiredToolUse( + attempt: { + toolMetas: Array<{ toolName: string; meta?: string }>; + clientToolCall?: { name: string; params: Record }; + didSendViaMessagingTool: boolean; + successfulCronAdds?: number; + }, + requirement: RequiredToolUse, +): boolean { + const usedToolNames = new Set( + attempt.toolMetas.map((entry) => entry.toolName.trim().toLowerCase()).filter(Boolean), + ); + if (attempt.clientToolCall?.name) { + usedToolNames.add(attempt.clientToolCall.name.trim().toLowerCase()); + } + if (attempt.didSendViaMessagingTool) { + usedToolNames.add("message"); + } + if ((attempt.successfulCronAdds ?? 0) > 0) { + usedToolNames.add("cron"); + } + + const requiredToolNames = normalizeRequiredToolNames(requirement.toolNames); + if (!requiredToolNames) { + return usedToolNames.size > 0; + } + return requiredToolNames.some((toolName) => usedToolNames.has(toolName)); +} + +function formatRequiredToolUseErrorMessage(requirement: RequiredToolUse): string { + const requiredToolNames = normalizeRequiredToolNames(requirement.toolNames); + const expected = + requiredToolNames && requiredToolNames.length > 0 + ? requiredToolNames.length === 1 + ? `the ${requiredToolNames[0]} tool` + : `one of these tools: ${requiredToolNames.join(", ")}` + : "at least one tool"; + const reason = requirement.reason?.trim(); + return reason + ? `${reason} The model replied without calling ${expected}.` + : `Tool use was required for this run, but the model replied without calling ${expected}.`; +} + const mergeUsageIntoAccumulator = ( target: UsageAccumulator, usage: ReturnType, @@ -255,6 +356,7 @@ export async function runEmbeddedPiAgent( : "plain" : "markdown"); const isProbeSession = params.sessionId?.startsWith("probe-") ?? false; + const requiredToolUse = resolveRequiredToolUse(params.requireToolUse, params.prompt); return enqueueSession(() => enqueueGlobal(async () => { @@ -1297,6 +1399,43 @@ export async function runEmbeddedPiAgent( compactionCount: autoCompactionCount > 0 ? autoCompactionCount : undefined, }; + if ( + requiredToolUse && + !didSatisfyRequiredToolUse( + { + toolMetas: attempt.toolMetas, + clientToolCall: attempt.clientToolCall, + didSendViaMessagingTool: attempt.didSendViaMessagingTool, + successfulCronAdds: attempt.successfulCronAdds, + }, + requiredToolUse, + ) + ) { + return { + payloads: [ + { + text: formatRequiredToolUseErrorMessage(requiredToolUse), + isError: true, + }, + ], + meta: { + durationMs: Date.now() - started, + agentMeta, + aborted, + systemPromptReport: attempt.systemPromptReport, + error: { + kind: "required_tool_use", + message: formatRequiredToolUseErrorMessage(requiredToolUse), + }, + }, + didSendViaMessagingTool: attempt.didSendViaMessagingTool, + messagingToolSentTexts: attempt.messagingToolSentTexts, + messagingToolSentMediaUrls: attempt.messagingToolSentMediaUrls, + messagingToolSentTargets: attempt.messagingToolSentTargets, + successfulCronAdds: attempt.successfulCronAdds, + }; + } + const payloads = buildEmbeddedRunPayloads({ assistantTexts: attempt.assistantTexts, toolMetas: attempt.toolMetas, diff --git a/src/agents/pi-embedded-runner/run/params.ts b/src/agents/pi-embedded-runner/run/params.ts index 048efd2cbe4..437ae3f42a7 100644 --- a/src/agents/pi-embedded-runner/run/params.ts +++ b/src/agents/pi-embedded-runner/run/params.ts @@ -71,6 +71,11 @@ export type RunEmbeddedPiAgentParams = { clientTools?: ClientToolDefinition[]; /** Disable built-in tools for this run (LLM-only mode). */ disableTools?: boolean; + /** + * Require at least one tool call before treating the run as successful. + * When `toolNames` is provided, at least one of those tools must be called. + */ + requireToolUse?: boolean | { toolNames?: string[]; reason?: string }; provider?: string; model?: string; authProfileId?: string; diff --git a/src/agents/pi-embedded-runner/types.ts b/src/agents/pi-embedded-runner/types.ts index 722abbf2a9a..b712e22a805 100644 --- a/src/agents/pi-embedded-runner/types.ts +++ b/src/agents/pi-embedded-runner/types.ts @@ -41,6 +41,7 @@ export type EmbeddedPiRunMeta = { | "compaction_failure" | "role_ordering" | "image_size" + | "required_tool_use" | "retry_limit"; message: string; }; diff --git a/src/agents/pi-embedded-runner/usage-reporting.test.ts b/src/agents/pi-embedded-runner/usage-reporting.test.ts index ed8d1227225..4359c200b72 100644 --- a/src/agents/pi-embedded-runner/usage-reporting.test.ts +++ b/src/agents/pi-embedded-runner/usage-reporting.test.ts @@ -58,4 +58,114 @@ describe("runEmbeddedPiAgent usage reporting", () => { // If the bug exists, it will likely be 350 expect(usage?.total).toBe(200); }); + + it("returns an error when a run explicitly requires tool use and no tool was called", async () => { + mockedRunEmbeddedAttempt.mockResolvedValueOnce({ + aborted: false, + promptError: null, + timedOut: false, + timedOutDuringCompaction: false, + sessionIdUsed: "test-session", + assistantTexts: ["I checked it."], + toolMetas: [], + didSendViaMessagingTool: false, + lastAssistant: { + usage: { input: 10, output: 5, total: 15 }, + stopReason: "end_turn", + }, + attemptUsage: { input: 10, output: 5, total: 15 }, + messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], + messagingToolSentTargets: [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); + + const result = await runEmbeddedPiAgent({ + sessionId: "test-session", + sessionKey: "test-key", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp/workspace", + prompt: "Do the thing.", + requireToolUse: true, + timeoutMs: 30000, + runId: "run-require-tool", + }); + + expect(result.payloads?.[0]?.isError).toBe(true); + expect(result.meta.error?.kind).toBe("required_tool_use"); + expect(result.payloads?.[0]?.text).toContain("requires at least one tool call"); + }); + + it("infers required read tool use for heartbeat prompts", async () => { + mockedRunEmbeddedAttempt.mockResolvedValueOnce({ + aborted: false, + promptError: null, + timedOut: false, + timedOutDuringCompaction: false, + sessionIdUsed: "test-session", + assistantTexts: ["HEARTBEAT_OK"], + toolMetas: [], + didSendViaMessagingTool: false, + lastAssistant: { + usage: { input: 10, output: 5, total: 15 }, + stopReason: "end_turn", + }, + attemptUsage: { input: 10, output: 5, total: 15 }, + messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], + messagingToolSentTargets: [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); + + const result = await runEmbeddedPiAgent({ + sessionId: "test-session", + sessionKey: "test-key", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp/workspace", + prompt: + "Read HEARTBEAT.md if it exists (workspace context). Follow it strictly. If nothing needs attention, reply HEARTBEAT_OK.", + timeoutMs: 30000, + runId: "run-heartbeat-tool-requirement", + }); + + expect(result.payloads?.[0]?.isError).toBe(true); + expect(result.meta.error?.kind).toBe("required_tool_use"); + expect(result.payloads?.[0]?.text).toContain("HEARTBEAT.md"); + expect(result.payloads?.[0]?.text).toContain("read tool"); + }); + + it("allows success when a required tool was called", async () => { + mockedRunEmbeddedAttempt.mockResolvedValueOnce({ + aborted: false, + promptError: null, + timedOut: false, + timedOutDuringCompaction: false, + sessionIdUsed: "test-session", + assistantTexts: ["Done."], + toolMetas: [{ toolName: "cron", meta: "action=list" }], + didSendViaMessagingTool: false, + lastAssistant: { + usage: { input: 10, output: 5, total: 15 }, + stopReason: "end_turn", + }, + attemptUsage: { input: 10, output: 5, total: 15 }, + messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], + messagingToolSentTargets: [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); + + const result = await runEmbeddedPiAgent({ + sessionId: "test-session", + sessionKey: "test-key", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp/workspace", + prompt: "Call the cron tool with action=list.", + timeoutMs: 30000, + runId: "run-explicit-cron-tool", + }); + + expect(result.meta.error).toBeUndefined(); + expect(result.payloads?.[0]?.isError).not.toBe(true); + }); }); From f4d4702b84158dc226c29b87b660e578b593c29c Mon Sep 17 00:00:00 2001 From: Youyou972 <50808411+Youyou972@users.noreply.github.com> Date: Thu, 5 Mar 2026 16:26:14 -0500 Subject: [PATCH 2/2] fix: honor trigger-based required tool use --- src/agents/pi-embedded-runner/run.ts | 100 +++++++++++------- .../usage-reporting.test.ts | 75 ++++++++++++- 2 files changed, 133 insertions(+), 42 deletions(-) diff --git a/src/agents/pi-embedded-runner/run.ts b/src/agents/pi-embedded-runner/run.ts index 965cb093b74..77312e65ccb 100644 --- a/src/agents/pi-embedded-runner/run.ts +++ b/src/agents/pi-embedded-runner/run.ts @@ -154,7 +154,17 @@ function normalizeRequiredToolNames(toolNames: string[] | undefined): string[] | return normalized.length > 0 ? Array.from(new Set(normalized)) : undefined; } -function inferRequiredToolUseFromPrompt(prompt: string): RequiredToolUse | undefined { +function inferRequiredToolUseFromPrompt( + prompt: string, + trigger?: RunEmbeddedPiAgentParams["trigger"], +): RequiredToolUse | undefined { + if (trigger === "heartbeat") { + return { + toolNames: ["read"], + reason: "Heartbeat runs must check HEARTBEAT.md before replying.", + }; + } + const trimmedPrompt = prompt.trim(); if (trimmedPrompt.startsWith("Read HEARTBEAT.md if it exists")) { return { @@ -184,7 +194,11 @@ function inferRequiredToolUseFromPrompt(prompt: string): RequiredToolUse | undef function resolveRequiredToolUse( requiredToolUse: RunEmbeddedPiAgentParams["requireToolUse"], prompt: string, + trigger?: RunEmbeddedPiAgentParams["trigger"], ): RequiredToolUse | undefined { + if (requiredToolUse === false) { + return undefined; + } if (requiredToolUse) { if (requiredToolUse === true) { return { reason: "This run requires at least one tool call before success." }; @@ -194,7 +208,7 @@ function resolveRequiredToolUse( reason: requiredToolUse.reason?.trim() || "This run requires tool use before success.", }; } - return inferRequiredToolUseFromPrompt(prompt); + return inferRequiredToolUseFromPrompt(prompt, trigger); } function didSatisfyRequiredToolUse( @@ -356,7 +370,11 @@ export async function runEmbeddedPiAgent( : "plain" : "markdown"); const isProbeSession = params.sessionId?.startsWith("probe-") ?? false; - const requiredToolUse = resolveRequiredToolUse(params.requireToolUse, params.prompt); + const requiredToolUse = resolveRequiredToolUse( + params.requireToolUse, + params.prompt, + params.trigger, + ); return enqueueSession(() => enqueueGlobal(async () => { @@ -1399,43 +1417,6 @@ export async function runEmbeddedPiAgent( compactionCount: autoCompactionCount > 0 ? autoCompactionCount : undefined, }; - if ( - requiredToolUse && - !didSatisfyRequiredToolUse( - { - toolMetas: attempt.toolMetas, - clientToolCall: attempt.clientToolCall, - didSendViaMessagingTool: attempt.didSendViaMessagingTool, - successfulCronAdds: attempt.successfulCronAdds, - }, - requiredToolUse, - ) - ) { - return { - payloads: [ - { - text: formatRequiredToolUseErrorMessage(requiredToolUse), - isError: true, - }, - ], - meta: { - durationMs: Date.now() - started, - agentMeta, - aborted, - systemPromptReport: attempt.systemPromptReport, - error: { - kind: "required_tool_use", - message: formatRequiredToolUseErrorMessage(requiredToolUse), - }, - }, - didSendViaMessagingTool: attempt.didSendViaMessagingTool, - messagingToolSentTexts: attempt.messagingToolSentTexts, - messagingToolSentMediaUrls: attempt.messagingToolSentMediaUrls, - messagingToolSentTargets: attempt.messagingToolSentTargets, - successfulCronAdds: attempt.successfulCronAdds, - }; - } - const payloads = buildEmbeddedRunPayloads({ assistantTexts: attempt.assistantTexts, toolMetas: attempt.toolMetas, @@ -1480,6 +1461,45 @@ export async function runEmbeddedPiAgent( }; } + if ( + requiredToolUse && + !timedOut && + !didSatisfyRequiredToolUse( + { + toolMetas: attempt.toolMetas, + clientToolCall: attempt.clientToolCall, + didSendViaMessagingTool: attempt.didSendViaMessagingTool, + successfulCronAdds: attempt.successfulCronAdds, + }, + requiredToolUse, + ) + ) { + const message = formatRequiredToolUseErrorMessage(requiredToolUse); + return { + payloads: [ + { + text: message, + isError: true, + }, + ], + meta: { + durationMs: Date.now() - started, + agentMeta, + aborted, + systemPromptReport: attempt.systemPromptReport, + error: { + kind: "required_tool_use", + message, + }, + }, + didSendViaMessagingTool: attempt.didSendViaMessagingTool, + messagingToolSentTexts: attempt.messagingToolSentTexts, + messagingToolSentMediaUrls: attempt.messagingToolSentMediaUrls, + messagingToolSentTargets: attempt.messagingToolSentTargets, + successfulCronAdds: attempt.successfulCronAdds, + }; + } + log.debug( `embedded run done: runId=${params.runId} sessionId=${params.sessionId} durationMs=${Date.now() - started} aborted=${aborted}`, ); diff --git a/src/agents/pi-embedded-runner/usage-reporting.test.ts b/src/agents/pi-embedded-runner/usage-reporting.test.ts index 4359c200b72..a9ccee193c5 100644 --- a/src/agents/pi-embedded-runner/usage-reporting.test.ts +++ b/src/agents/pi-embedded-runner/usage-reporting.test.ts @@ -122,8 +122,8 @@ describe("runEmbeddedPiAgent usage reporting", () => { sessionKey: "test-key", sessionFile: "/tmp/session.json", workspaceDir: "/tmp/workspace", - prompt: - "Read HEARTBEAT.md if it exists (workspace context). Follow it strictly. If nothing needs attention, reply HEARTBEAT_OK.", + prompt: "Use the heartbeat instructions for this workspace. Reply HEARTBEAT_OK if idle.", + trigger: "heartbeat", timeoutMs: 30000, runId: "run-heartbeat-tool-requirement", }); @@ -134,6 +134,77 @@ describe("runEmbeddedPiAgent usage reporting", () => { expect(result.payloads?.[0]?.text).toContain("read tool"); }); + it("lets timeout errors win over required-tool-use errors", async () => { + mockedRunEmbeddedAttempt.mockResolvedValueOnce({ + aborted: true, + promptError: null, + timedOut: true, + timedOutDuringCompaction: false, + sessionIdUsed: "test-session", + assistantTexts: [], + toolMetas: [], + didSendViaMessagingTool: false, + lastAssistant: null, + attemptUsage: { input: 10, output: 0, total: 10 }, + messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], + messagingToolSentTargets: [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); + + const result = await runEmbeddedPiAgent({ + sessionId: "test-session", + sessionKey: "test-key", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp/workspace", + prompt: "Call the cron tool with action=list.", + timeoutMs: 30000, + runId: "run-timeout-before-tool", + }); + + expect(result.meta.error?.kind).toBeUndefined(); + expect(result.payloads?.[0]?.isError).toBe(true); + expect(result.payloads?.[0]?.text).toContain( + "Request timed out before a response was generated", + ); + }); + + it("allows callers to disable prompt-inferred tool requirements explicitly", async () => { + mockedRunEmbeddedAttempt.mockResolvedValueOnce({ + aborted: false, + promptError: null, + timedOut: false, + timedOutDuringCompaction: false, + sessionIdUsed: "test-session", + assistantTexts: ["Handled without tools."], + toolMetas: [], + didSendViaMessagingTool: false, + lastAssistant: { + usage: { input: 10, output: 5, total: 15 }, + stopReason: "end_turn", + }, + attemptUsage: { input: 10, output: 5, total: 15 }, + messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], + messagingToolSentTargets: [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); + + const result = await runEmbeddedPiAgent({ + sessionId: "test-session", + sessionKey: "test-key", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp/workspace", + prompt: "Call the cron tool with action=list.", + requireToolUse: false, + timeoutMs: 30000, + runId: "run-disable-inferred-tool-requirement", + }); + + expect(result.meta.error).toBeUndefined(); + expect(result.payloads?.[0]?.isError).not.toBe(true); + }); + it("allows success when a required tool was called", async () => { mockedRunEmbeddedAttempt.mockResolvedValueOnce({ aborted: false,