fix: honor trigger-based required tool use

This commit is contained in:
Youyou972 2026-03-05 16:26:14 -05:00
parent 6a52c4f6f2
commit f4d4702b84
2 changed files with 133 additions and 42 deletions

View File

@ -154,7 +154,17 @@ function normalizeRequiredToolNames(toolNames: string[] | undefined): string[] |
return normalized.length > 0 ? Array.from(new Set(normalized)) : undefined;
}
function inferRequiredToolUseFromPrompt(prompt: string): RequiredToolUse | undefined {
function inferRequiredToolUseFromPrompt(
prompt: string,
trigger?: RunEmbeddedPiAgentParams["trigger"],
): RequiredToolUse | undefined {
if (trigger === "heartbeat") {
return {
toolNames: ["read"],
reason: "Heartbeat runs must check HEARTBEAT.md before replying.",
};
}
const trimmedPrompt = prompt.trim();
if (trimmedPrompt.startsWith("Read HEARTBEAT.md if it exists")) {
return {
@ -184,7 +194,11 @@ function inferRequiredToolUseFromPrompt(prompt: string): RequiredToolUse | undef
function resolveRequiredToolUse(
requiredToolUse: RunEmbeddedPiAgentParams["requireToolUse"],
prompt: string,
trigger?: RunEmbeddedPiAgentParams["trigger"],
): RequiredToolUse | undefined {
if (requiredToolUse === false) {
return undefined;
}
if (requiredToolUse) {
if (requiredToolUse === true) {
return { reason: "This run requires at least one tool call before success." };
@ -194,7 +208,7 @@ function resolveRequiredToolUse(
reason: requiredToolUse.reason?.trim() || "This run requires tool use before success.",
};
}
return inferRequiredToolUseFromPrompt(prompt);
return inferRequiredToolUseFromPrompt(prompt, trigger);
}
function didSatisfyRequiredToolUse(
@ -356,7 +370,11 @@ export async function runEmbeddedPiAgent(
: "plain"
: "markdown");
const isProbeSession = params.sessionId?.startsWith("probe-") ?? false;
const requiredToolUse = resolveRequiredToolUse(params.requireToolUse, params.prompt);
const requiredToolUse = resolveRequiredToolUse(
params.requireToolUse,
params.prompt,
params.trigger,
);
return enqueueSession(() =>
enqueueGlobal(async () => {
@ -1399,43 +1417,6 @@ export async function runEmbeddedPiAgent(
compactionCount: autoCompactionCount > 0 ? autoCompactionCount : undefined,
};
if (
requiredToolUse &&
!didSatisfyRequiredToolUse(
{
toolMetas: attempt.toolMetas,
clientToolCall: attempt.clientToolCall,
didSendViaMessagingTool: attempt.didSendViaMessagingTool,
successfulCronAdds: attempt.successfulCronAdds,
},
requiredToolUse,
)
) {
return {
payloads: [
{
text: formatRequiredToolUseErrorMessage(requiredToolUse),
isError: true,
},
],
meta: {
durationMs: Date.now() - started,
agentMeta,
aborted,
systemPromptReport: attempt.systemPromptReport,
error: {
kind: "required_tool_use",
message: formatRequiredToolUseErrorMessage(requiredToolUse),
},
},
didSendViaMessagingTool: attempt.didSendViaMessagingTool,
messagingToolSentTexts: attempt.messagingToolSentTexts,
messagingToolSentMediaUrls: attempt.messagingToolSentMediaUrls,
messagingToolSentTargets: attempt.messagingToolSentTargets,
successfulCronAdds: attempt.successfulCronAdds,
};
}
const payloads = buildEmbeddedRunPayloads({
assistantTexts: attempt.assistantTexts,
toolMetas: attempt.toolMetas,
@ -1480,6 +1461,45 @@ export async function runEmbeddedPiAgent(
};
}
// Required-tool-use gate: when this run carries a tool requirement and the
// attempt is not flagged as timed out, return an error result if the attempt's
// recorded tool activity does not satisfy the requirement.
if (
requiredToolUse &&
// Timed-out attempts skip this gate, so the required-tool-use error is never
// returned for them from here.
!timedOut &&
!didSatisfyRequiredToolUse(
{
toolMetas: attempt.toolMetas,
clientToolCall: attempt.clientToolCall,
didSendViaMessagingTool: attempt.didSendViaMessagingTool,
successfulCronAdds: attempt.successfulCronAdds,
},
requiredToolUse,
)
) {
// Format the message once; the same text is used for the error payload and
// for meta.error.message below.
const message = formatRequiredToolUseErrorMessage(requiredToolUse);
return {
payloads: [
{
text: message,
isError: true,
},
],
meta: {
durationMs: Date.now() - started,
agentMeta,
aborted,
systemPromptReport: attempt.systemPromptReport,
error: {
kind: "required_tool_use",
message,
},
},
// Pass through the attempt's messaging-tool and cron bookkeeping unchanged
// on this error result.
didSendViaMessagingTool: attempt.didSendViaMessagingTool,
messagingToolSentTexts: attempt.messagingToolSentTexts,
messagingToolSentMediaUrls: attempt.messagingToolSentMediaUrls,
messagingToolSentTargets: attempt.messagingToolSentTargets,
successfulCronAdds: attempt.successfulCronAdds,
};
}
log.debug(
`embedded run done: runId=${params.runId} sessionId=${params.sessionId} durationMs=${Date.now() - started} aborted=${aborted}`,
);

View File

@ -122,8 +122,8 @@ describe("runEmbeddedPiAgent usage reporting", () => {
sessionKey: "test-key",
sessionFile: "/tmp/session.json",
workspaceDir: "/tmp/workspace",
prompt:
"Read HEARTBEAT.md if it exists (workspace context). Follow it strictly. If nothing needs attention, reply HEARTBEAT_OK.",
prompt: "Use the heartbeat instructions for this workspace. Reply HEARTBEAT_OK if idle.",
trigger: "heartbeat",
timeoutMs: 30000,
runId: "run-heartbeat-tool-requirement",
});
@ -134,6 +134,77 @@ describe("runEmbeddedPiAgent usage reporting", () => {
expect(result.payloads?.[0]?.text).toContain("read tool");
});
it("lets timeout errors win over required-tool-use errors", async () => {
  // Simulated attempt: aborted by a timeout, with no assistant text and no
  // tool calls recorded.
  const timedOutAttempt = {
    aborted: true,
    promptError: null,
    timedOut: true,
    timedOutDuringCompaction: false,
    sessionIdUsed: "test-session",
    assistantTexts: [],
    toolMetas: [],
    didSendViaMessagingTool: false,
    lastAssistant: null,
    attemptUsage: { input: 10, output: 0, total: 10 },
    messagingToolSentTexts: [],
    messagingToolSentMediaUrls: [],
    messagingToolSentTargets: [],
  };
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  mockedRunEmbeddedAttempt.mockResolvedValueOnce(timedOutAttempt as any);

  const outcome = await runEmbeddedPiAgent({
    sessionId: "test-session",
    sessionKey: "test-key",
    sessionFile: "/tmp/session.json",
    workspaceDir: "/tmp/workspace",
    prompt: "Call the cron tool with action=list.",
    timeoutMs: 30000,
    runId: "run-timeout-before-tool",
  });

  // The run must surface the timeout message as an error payload and must not
  // attach any error kind to the result metadata.
  const firstPayload = outcome.payloads?.[0];
  expect(outcome.meta.error?.kind).toBeUndefined();
  expect(firstPayload?.isError).toBe(true);
  expect(firstPayload?.text).toContain(
    "Request timed out before a response was generated",
  );
});
it("allows callers to disable prompt-inferred tool requirements explicitly", async () => {
  // Simulated attempt: finished normally with a plain text answer and no
  // tool calls recorded.
  const toollessAttempt = {
    aborted: false,
    promptError: null,
    timedOut: false,
    timedOutDuringCompaction: false,
    sessionIdUsed: "test-session",
    assistantTexts: ["Handled without tools."],
    toolMetas: [],
    didSendViaMessagingTool: false,
    lastAssistant: {
      usage: { input: 10, output: 5, total: 15 },
      stopReason: "end_turn",
    },
    attemptUsage: { input: 10, output: 5, total: 15 },
    messagingToolSentTexts: [],
    messagingToolSentMediaUrls: [],
    messagingToolSentTargets: [],
  };
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  mockedRunEmbeddedAttempt.mockResolvedValueOnce(toollessAttempt as any);

  // requireToolUse: false is the explicit opt-out being exercised here.
  const outcome = await runEmbeddedPiAgent({
    sessionId: "test-session",
    sessionKey: "test-key",
    sessionFile: "/tmp/session.json",
    workspaceDir: "/tmp/workspace",
    prompt: "Call the cron tool with action=list.",
    requireToolUse: false,
    timeoutMs: 30000,
    runId: "run-disable-inferred-tool-requirement",
  });

  // With the opt-out, the tool-less answer must come back without an error.
  const firstPayload = outcome.payloads?.[0];
  expect(outcome.meta.error).toBeUndefined();
  expect(firstPayload?.isError).not.toBe(true);
});
it("allows success when a required tool was called", async () => {
mockedRunEmbeddedAttempt.mockResolvedValueOnce({
aborted: false,