diff --git a/src/agents/pi-embedded-runner/run.ts b/src/agents/pi-embedded-runner/run.ts index de2274cc3f4..965cb093b74 100644 --- a/src/agents/pi-embedded-runner/run.ts +++ b/src/agents/pi-embedded-runner/run.ts @@ -65,6 +65,10 @@ import type { EmbeddedPiAgentMeta, EmbeddedPiRunResult } from "./types.js"; import { describeUnknownError } from "./utils.js"; type ApiKeyInfo = ResolvedProviderAuth; +type RequiredToolUse = { + toolNames?: string[]; + reason?: string; +}; type CopilotTokenState = { githubToken: string; @@ -139,6 +143,103 @@ const hasUsageValues = ( (value) => typeof value === "number" && Number.isFinite(value) && value > 0, ); +function normalizeRequiredToolNames(toolNames: string[] | undefined): string[] | undefined { + const normalized = (toolNames ?? []) + .map((name) => + String(name ?? "") + .trim() + .toLowerCase(), + ) + .filter(Boolean); + return normalized.length > 0 ? Array.from(new Set(normalized)) : undefined; +} + +function inferRequiredToolUseFromPrompt(prompt: string): RequiredToolUse | undefined { + const trimmedPrompt = prompt.trim(); + if (trimmedPrompt.startsWith("Read HEARTBEAT.md if it exists")) { + return { + toolNames: ["read"], + reason: "Heartbeat prompts must check HEARTBEAT.md before replying.", + }; + } + + const explicitToolNames = Array.from( + trimmedPrompt.matchAll(/\b(?:call|use)\s+(?:the\s+)?([a-z0-9_]+)\s+tool\b/gi), + ) + .map((match) => match[1]?.trim().toLowerCase()) + .filter((name): name is string => Boolean(name)); + if (explicitToolNames.length === 0) { + return undefined; + } + const toolNames = Array.from(new Set(explicitToolNames)); + return { + toolNames, + reason: + toolNames.length === 1 + ? `Prompt explicitly requires the ${toolNames[0]} tool.` + : `Prompt explicitly requires one of these tools: ${toolNames.join(", ")}.`, + }; +} + +function resolveRequiredToolUse( + requiredToolUse: RunEmbeddedPiAgentParams["requireToolUse"], + prompt: string, +): RequiredToolUse | undefined { + if (requiredToolUse) { + if (requiredToolUse === true) { + return { reason: "This run requires at least one tool call before success." }; + } + return { + toolNames: normalizeRequiredToolNames(requiredToolUse.toolNames), + reason: requiredToolUse.reason?.trim() || "This run requires tool use before success.", + }; + } + return inferRequiredToolUseFromPrompt(prompt); +} + +function didSatisfyRequiredToolUse( + attempt: { + toolMetas: Array<{ toolName: string; meta?: string }>; + clientToolCall?: { name: string; params: Record }; + didSendViaMessagingTool: boolean; + successfulCronAdds?: number; + }, + requirement: RequiredToolUse, +): boolean { + const usedToolNames = new Set( + attempt.toolMetas.map((entry) => entry.toolName.trim().toLowerCase()).filter(Boolean), + ); + if (attempt.clientToolCall?.name) { + usedToolNames.add(attempt.clientToolCall.name.trim().toLowerCase()); + } + if (attempt.didSendViaMessagingTool) { + usedToolNames.add("message"); + } + if ((attempt.successfulCronAdds ?? 0) > 0) { + usedToolNames.add("cron"); + } + + const requiredToolNames = normalizeRequiredToolNames(requirement.toolNames); + if (!requiredToolNames) { + return usedToolNames.size > 0; + } + return requiredToolNames.some((toolName) => usedToolNames.has(toolName)); +} + +function formatRequiredToolUseErrorMessage(requirement: RequiredToolUse): string { + const requiredToolNames = normalizeRequiredToolNames(requirement.toolNames); + const expected = + requiredToolNames && requiredToolNames.length > 0 + ? requiredToolNames.length === 1 + ? `the ${requiredToolNames[0]} tool` + : `one of these tools: ${requiredToolNames.join(", ")}` + : "at least one tool"; + const reason = requirement.reason?.trim(); + return reason + ? `${reason} The model replied without calling ${expected}.` + : `Tool use was required for this run, but the model replied without calling ${expected}.`; +} + const mergeUsageIntoAccumulator = ( target: UsageAccumulator, usage: ReturnType, @@ -255,6 +356,7 @@ export async function runEmbeddedPiAgent( : "plain" : "markdown"); const isProbeSession = params.sessionId?.startsWith("probe-") ?? false; + const requiredToolUse = resolveRequiredToolUse(params.requireToolUse, params.prompt); return enqueueSession(() => enqueueGlobal(async () => { @@ -1297,6 +1399,43 @@ export async function runEmbeddedPiAgent( compactionCount: autoCompactionCount > 0 ? autoCompactionCount : undefined, }; + if ( + requiredToolUse && + !didSatisfyRequiredToolUse( + { + toolMetas: attempt.toolMetas, + clientToolCall: attempt.clientToolCall, + didSendViaMessagingTool: attempt.didSendViaMessagingTool, + successfulCronAdds: attempt.successfulCronAdds, + }, + requiredToolUse, + ) + ) { + return { + payloads: [ + { + text: formatRequiredToolUseErrorMessage(requiredToolUse), + isError: true, + }, + ], + meta: { + durationMs: Date.now() - started, + agentMeta, + aborted, + systemPromptReport: attempt.systemPromptReport, + error: { + kind: "required_tool_use", + message: formatRequiredToolUseErrorMessage(requiredToolUse), + }, + }, + didSendViaMessagingTool: attempt.didSendViaMessagingTool, + messagingToolSentTexts: attempt.messagingToolSentTexts, + messagingToolSentMediaUrls: attempt.messagingToolSentMediaUrls, + messagingToolSentTargets: attempt.messagingToolSentTargets, + successfulCronAdds: attempt.successfulCronAdds, + }; + } + const payloads = buildEmbeddedRunPayloads({ assistantTexts: attempt.assistantTexts, toolMetas: attempt.toolMetas, diff --git a/src/agents/pi-embedded-runner/run/params.ts b/src/agents/pi-embedded-runner/run/params.ts index 048efd2cbe4..437ae3f42a7 100644 --- a/src/agents/pi-embedded-runner/run/params.ts +++ b/src/agents/pi-embedded-runner/run/params.ts @@ -71,6 +71,11 @@ export type RunEmbeddedPiAgentParams = { clientTools?: ClientToolDefinition[]; /** Disable built-in tools for this run (LLM-only mode). */ disableTools?: boolean; + /** + * Require at least one tool call before treating the run as successful. + * When `toolNames` is provided, at least one of those tools must be called. + */ + requireToolUse?: boolean | { toolNames?: string[]; reason?: string }; provider?: string; model?: string; authProfileId?: string; diff --git a/src/agents/pi-embedded-runner/types.ts b/src/agents/pi-embedded-runner/types.ts index 722abbf2a9a..b712e22a805 100644 --- a/src/agents/pi-embedded-runner/types.ts +++ b/src/agents/pi-embedded-runner/types.ts @@ -41,6 +41,7 @@ export type EmbeddedPiRunMeta = { | "compaction_failure" | "role_ordering" | "image_size" + | "required_tool_use" | "retry_limit"; message: string; }; diff --git a/src/agents/pi-embedded-runner/usage-reporting.test.ts b/src/agents/pi-embedded-runner/usage-reporting.test.ts index ed8d1227225..4359c200b72 100644 --- a/src/agents/pi-embedded-runner/usage-reporting.test.ts +++ b/src/agents/pi-embedded-runner/usage-reporting.test.ts @@ -58,4 +58,114 @@ describe("runEmbeddedPiAgent usage reporting", () => { // If the bug exists, it will likely be 350 expect(usage?.total).toBe(200); }); + + it("returns an error when a run explicitly requires tool use and no tool was called", async () => { + mockedRunEmbeddedAttempt.mockResolvedValueOnce({ + aborted: false, + promptError: null, + timedOut: false, + timedOutDuringCompaction: false, + sessionIdUsed: "test-session", + assistantTexts: ["I checked it."], + toolMetas: [], + didSendViaMessagingTool: false, + lastAssistant: { + usage: { input: 10, output: 5, total: 15 }, + stopReason: "end_turn", + }, + attemptUsage: { input: 10, output: 5, total: 15 }, + messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], + messagingToolSentTargets: [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); + + const result = await runEmbeddedPiAgent({ + sessionId: "test-session", + sessionKey: "test-key", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp/workspace", + prompt: "Do the thing.", + requireToolUse: true, + timeoutMs: 30000, + runId: "run-require-tool", + }); + + expect(result.payloads?.[0]?.isError).toBe(true); + expect(result.meta.error?.kind).toBe("required_tool_use"); + expect(result.payloads?.[0]?.text).toContain("requires at least one tool call"); + }); + + it("infers required read tool use for heartbeat prompts", async () => { + mockedRunEmbeddedAttempt.mockResolvedValueOnce({ + aborted: false, + promptError: null, + timedOut: false, + timedOutDuringCompaction: false, + sessionIdUsed: "test-session", + assistantTexts: ["HEARTBEAT_OK"], + toolMetas: [], + didSendViaMessagingTool: false, + lastAssistant: { + usage: { input: 10, output: 5, total: 15 }, + stopReason: "end_turn", + }, + attemptUsage: { input: 10, output: 5, total: 15 }, + messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], + messagingToolSentTargets: [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); + + const result = await runEmbeddedPiAgent({ + sessionId: "test-session", + sessionKey: "test-key", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp/workspace", + prompt: + "Read HEARTBEAT.md if it exists (workspace context). Follow it strictly. If nothing needs attention, reply HEARTBEAT_OK.", + timeoutMs: 30000, + runId: "run-heartbeat-tool-requirement", + }); + + expect(result.payloads?.[0]?.isError).toBe(true); + expect(result.meta.error?.kind).toBe("required_tool_use"); + expect(result.payloads?.[0]?.text).toContain("HEARTBEAT.md"); + expect(result.payloads?.[0]?.text).toContain("read tool"); + }); + + it("allows success when a required tool was called", async () => { + mockedRunEmbeddedAttempt.mockResolvedValueOnce({ + aborted: false, + promptError: null, + timedOut: false, + timedOutDuringCompaction: false, + sessionIdUsed: "test-session", + assistantTexts: ["Done."], + toolMetas: [{ toolName: "cron", meta: "action=list" }], + didSendViaMessagingTool: false, + lastAssistant: { + usage: { input: 10, output: 5, total: 15 }, + stopReason: "end_turn", + }, + attemptUsage: { input: 10, output: 5, total: 15 }, + messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], + messagingToolSentTargets: [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); + + const result = await runEmbeddedPiAgent({ + sessionId: "test-session", + sessionKey: "test-key", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp/workspace", + prompt: "Call the cron tool with action=list.", + timeoutMs: 30000, + runId: "run-explicit-cron-tool", + }); + + expect(result.meta.error).toBeUndefined(); + expect(result.payloads?.[0]?.isError).not.toBe(true); + }); });