From f4d4702b84158dc226c29b87b660e578b593c29c Mon Sep 17 00:00:00 2001 From: Youyou972 <50808411+Youyou972@users.noreply.github.com> Date: Thu, 5 Mar 2026 16:26:14 -0500 Subject: [PATCH] fix: honor trigger-based required tool use --- src/agents/pi-embedded-runner/run.ts | 100 +++++++++++------- .../usage-reporting.test.ts | 75 ++++++++++++- 2 files changed, 133 insertions(+), 42 deletions(-) diff --git a/src/agents/pi-embedded-runner/run.ts b/src/agents/pi-embedded-runner/run.ts index 965cb093b74..77312e65ccb 100644 --- a/src/agents/pi-embedded-runner/run.ts +++ b/src/agents/pi-embedded-runner/run.ts @@ -154,7 +154,17 @@ function normalizeRequiredToolNames(toolNames: string[] | undefined): string[] | return normalized.length > 0 ? Array.from(new Set(normalized)) : undefined; } -function inferRequiredToolUseFromPrompt(prompt: string): RequiredToolUse | undefined { +function inferRequiredToolUseFromPrompt( + prompt: string, + trigger?: RunEmbeddedPiAgentParams["trigger"], +): RequiredToolUse | undefined { + if (trigger === "heartbeat") { + return { + toolNames: ["read"], + reason: "Heartbeat runs must check HEARTBEAT.md before replying.", + }; + } + const trimmedPrompt = prompt.trim(); if (trimmedPrompt.startsWith("Read HEARTBEAT.md if it exists")) { return { @@ -184,7 +194,11 @@ function inferRequiredToolUseFromPrompt(prompt: string): RequiredToolUse | undef function resolveRequiredToolUse( requiredToolUse: RunEmbeddedPiAgentParams["requireToolUse"], prompt: string, + trigger?: RunEmbeddedPiAgentParams["trigger"], ): RequiredToolUse | undefined { + if (requiredToolUse === false) { + return undefined; + } if (requiredToolUse) { if (requiredToolUse === true) { return { reason: "This run requires at least one tool call before success." }; @@ -194,7 +208,7 @@ function resolveRequiredToolUse( reason: requiredToolUse.reason?.trim() || "This run requires tool use before success.", }; } - return inferRequiredToolUseFromPrompt(prompt); + return inferRequiredToolUseFromPrompt(prompt, trigger); } function didSatisfyRequiredToolUse( @@ -356,7 +370,11 @@ export async function runEmbeddedPiAgent( : "plain" : "markdown"); const isProbeSession = params.sessionId?.startsWith("probe-") ?? false; - const requiredToolUse = resolveRequiredToolUse(params.requireToolUse, params.prompt); + const requiredToolUse = resolveRequiredToolUse( + params.requireToolUse, + params.prompt, + params.trigger, + ); return enqueueSession(() => enqueueGlobal(async () => { @@ -1399,43 +1417,6 @@ export async function runEmbeddedPiAgent( compactionCount: autoCompactionCount > 0 ? autoCompactionCount : undefined, }; - if ( - requiredToolUse && - !didSatisfyRequiredToolUse( - { - toolMetas: attempt.toolMetas, - clientToolCall: attempt.clientToolCall, - didSendViaMessagingTool: attempt.didSendViaMessagingTool, - successfulCronAdds: attempt.successfulCronAdds, - }, - requiredToolUse, - ) - ) { - return { - payloads: [ - { - text: formatRequiredToolUseErrorMessage(requiredToolUse), - isError: true, - }, - ], - meta: { - durationMs: Date.now() - started, - agentMeta, - aborted, - systemPromptReport: attempt.systemPromptReport, - error: { - kind: "required_tool_use", - message: formatRequiredToolUseErrorMessage(requiredToolUse), - }, - }, - didSendViaMessagingTool: attempt.didSendViaMessagingTool, - messagingToolSentTexts: attempt.messagingToolSentTexts, - messagingToolSentMediaUrls: attempt.messagingToolSentMediaUrls, - messagingToolSentTargets: attempt.messagingToolSentTargets, - successfulCronAdds: attempt.successfulCronAdds, - }; - } - const payloads = buildEmbeddedRunPayloads({ assistantTexts: attempt.assistantTexts, toolMetas: attempt.toolMetas, @@ -1480,6 +1461,45 @@ export async function runEmbeddedPiAgent( }; } + if ( + requiredToolUse && + !timedOut && + !didSatisfyRequiredToolUse( + { + toolMetas: attempt.toolMetas, + clientToolCall: attempt.clientToolCall, + didSendViaMessagingTool: attempt.didSendViaMessagingTool, + successfulCronAdds: attempt.successfulCronAdds, + }, + requiredToolUse, + ) + ) { + const message = formatRequiredToolUseErrorMessage(requiredToolUse); + return { + payloads: [ + { + text: message, + isError: true, + }, + ], + meta: { + durationMs: Date.now() - started, + agentMeta, + aborted, + systemPromptReport: attempt.systemPromptReport, + error: { + kind: "required_tool_use", + message, + }, + }, + didSendViaMessagingTool: attempt.didSendViaMessagingTool, + messagingToolSentTexts: attempt.messagingToolSentTexts, + messagingToolSentMediaUrls: attempt.messagingToolSentMediaUrls, + messagingToolSentTargets: attempt.messagingToolSentTargets, + successfulCronAdds: attempt.successfulCronAdds, + }; + } + log.debug( `embedded run done: runId=${params.runId} sessionId=${params.sessionId} durationMs=${Date.now() - started} aborted=${aborted}`, ); diff --git a/src/agents/pi-embedded-runner/usage-reporting.test.ts b/src/agents/pi-embedded-runner/usage-reporting.test.ts index 4359c200b72..a9ccee193c5 100644 --- a/src/agents/pi-embedded-runner/usage-reporting.test.ts +++ b/src/agents/pi-embedded-runner/usage-reporting.test.ts @@ -122,8 +122,8 @@ describe("runEmbeddedPiAgent usage reporting", () => { sessionKey: "test-key", sessionFile: "/tmp/session.json", workspaceDir: "/tmp/workspace", - prompt: - "Read HEARTBEAT.md if it exists (workspace context). Follow it strictly. If nothing needs attention, reply HEARTBEAT_OK.", + prompt: "Use the heartbeat instructions for this workspace. Reply HEARTBEAT_OK if idle.", + trigger: "heartbeat", timeoutMs: 30000, runId: "run-heartbeat-tool-requirement", }); @@ -134,6 +134,77 @@ describe("runEmbeddedPiAgent usage reporting", () => { expect(result.payloads?.[0]?.text).toContain("read tool"); }); + it("lets timeout errors win over required-tool-use errors", async () => { + mockedRunEmbeddedAttempt.mockResolvedValueOnce({ + aborted: true, + promptError: null, + timedOut: true, + timedOutDuringCompaction: false, + sessionIdUsed: "test-session", + assistantTexts: [], + toolMetas: [], + didSendViaMessagingTool: false, + lastAssistant: null, + attemptUsage: { input: 10, output: 0, total: 10 }, + messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], + messagingToolSentTargets: [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); + + const result = await runEmbeddedPiAgent({ + sessionId: "test-session", + sessionKey: "test-key", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp/workspace", + prompt: "Call the cron tool with action=list.", + timeoutMs: 30000, + runId: "run-timeout-before-tool", + }); + + expect(result.meta.error?.kind).toBeUndefined(); + expect(result.payloads?.[0]?.isError).toBe(true); + expect(result.payloads?.[0]?.text).toContain( + "Request timed out before a response was generated", + ); + }); + + it("allows callers to disable prompt-inferred tool requirements explicitly", async () => { + mockedRunEmbeddedAttempt.mockResolvedValueOnce({ + aborted: false, + promptError: null, + timedOut: false, + timedOutDuringCompaction: false, + sessionIdUsed: "test-session", + assistantTexts: ["Handled without tools."], + toolMetas: [], + didSendViaMessagingTool: false, + lastAssistant: { + usage: { input: 10, output: 5, total: 15 }, + stopReason: "end_turn", + }, + attemptUsage: { input: 10, output: 5, total: 15 }, + messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], + messagingToolSentTargets: [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); + + const result = await runEmbeddedPiAgent({ + sessionId: "test-session", + sessionKey: "test-key", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp/workspace", + prompt: "Call the cron tool with action=list.", + requireToolUse: false, + timeoutMs: 30000, + runId: "run-disable-inferred-tool-requirement", + }); + + expect(result.meta.error).toBeUndefined(); + expect(result.payloads?.[0]?.isError).not.toBe(true); + }); + it("allows success when a required tool was called", async () => { mockedRunEmbeddedAttempt.mockResolvedValueOnce({ aborted: false,