fix: honor trigger-based required tool use
This commit is contained in:
parent
6a52c4f6f2
commit
f4d4702b84
@ -154,7 +154,17 @@ function normalizeRequiredToolNames(toolNames: string[] | undefined): string[] |
|
||||
return normalized.length > 0 ? Array.from(new Set(normalized)) : undefined;
|
||||
}
|
||||
|
||||
function inferRequiredToolUseFromPrompt(prompt: string): RequiredToolUse | undefined {
|
||||
function inferRequiredToolUseFromPrompt(
|
||||
prompt: string,
|
||||
trigger?: RunEmbeddedPiAgentParams["trigger"],
|
||||
): RequiredToolUse | undefined {
|
||||
if (trigger === "heartbeat") {
|
||||
return {
|
||||
toolNames: ["read"],
|
||||
reason: "Heartbeat runs must check HEARTBEAT.md before replying.",
|
||||
};
|
||||
}
|
||||
|
||||
const trimmedPrompt = prompt.trim();
|
||||
if (trimmedPrompt.startsWith("Read HEARTBEAT.md if it exists")) {
|
||||
return {
|
||||
@ -184,7 +194,11 @@ function inferRequiredToolUseFromPrompt(prompt: string): RequiredToolUse | undef
|
||||
function resolveRequiredToolUse(
|
||||
requiredToolUse: RunEmbeddedPiAgentParams["requireToolUse"],
|
||||
prompt: string,
|
||||
trigger?: RunEmbeddedPiAgentParams["trigger"],
|
||||
): RequiredToolUse | undefined {
|
||||
if (requiredToolUse === false) {
|
||||
return undefined;
|
||||
}
|
||||
if (requiredToolUse) {
|
||||
if (requiredToolUse === true) {
|
||||
return { reason: "This run requires at least one tool call before success." };
|
||||
@ -194,7 +208,7 @@ function resolveRequiredToolUse(
|
||||
reason: requiredToolUse.reason?.trim() || "This run requires tool use before success.",
|
||||
};
|
||||
}
|
||||
return inferRequiredToolUseFromPrompt(prompt);
|
||||
return inferRequiredToolUseFromPrompt(prompt, trigger);
|
||||
}
|
||||
|
||||
function didSatisfyRequiredToolUse(
|
||||
@ -356,7 +370,11 @@ export async function runEmbeddedPiAgent(
|
||||
: "plain"
|
||||
: "markdown");
|
||||
const isProbeSession = params.sessionId?.startsWith("probe-") ?? false;
|
||||
const requiredToolUse = resolveRequiredToolUse(params.requireToolUse, params.prompt);
|
||||
const requiredToolUse = resolveRequiredToolUse(
|
||||
params.requireToolUse,
|
||||
params.prompt,
|
||||
params.trigger,
|
||||
);
|
||||
|
||||
return enqueueSession(() =>
|
||||
enqueueGlobal(async () => {
|
||||
@ -1399,43 +1417,6 @@ export async function runEmbeddedPiAgent(
|
||||
compactionCount: autoCompactionCount > 0 ? autoCompactionCount : undefined,
|
||||
};
|
||||
|
||||
if (
|
||||
requiredToolUse &&
|
||||
!didSatisfyRequiredToolUse(
|
||||
{
|
||||
toolMetas: attempt.toolMetas,
|
||||
clientToolCall: attempt.clientToolCall,
|
||||
didSendViaMessagingTool: attempt.didSendViaMessagingTool,
|
||||
successfulCronAdds: attempt.successfulCronAdds,
|
||||
},
|
||||
requiredToolUse,
|
||||
)
|
||||
) {
|
||||
return {
|
||||
payloads: [
|
||||
{
|
||||
text: formatRequiredToolUseErrorMessage(requiredToolUse),
|
||||
isError: true,
|
||||
},
|
||||
],
|
||||
meta: {
|
||||
durationMs: Date.now() - started,
|
||||
agentMeta,
|
||||
aborted,
|
||||
systemPromptReport: attempt.systemPromptReport,
|
||||
error: {
|
||||
kind: "required_tool_use",
|
||||
message: formatRequiredToolUseErrorMessage(requiredToolUse),
|
||||
},
|
||||
},
|
||||
didSendViaMessagingTool: attempt.didSendViaMessagingTool,
|
||||
messagingToolSentTexts: attempt.messagingToolSentTexts,
|
||||
messagingToolSentMediaUrls: attempt.messagingToolSentMediaUrls,
|
||||
messagingToolSentTargets: attempt.messagingToolSentTargets,
|
||||
successfulCronAdds: attempt.successfulCronAdds,
|
||||
};
|
||||
}
|
||||
|
||||
const payloads = buildEmbeddedRunPayloads({
|
||||
assistantTexts: attempt.assistantTexts,
|
||||
toolMetas: attempt.toolMetas,
|
||||
@ -1480,6 +1461,45 @@ export async function runEmbeddedPiAgent(
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
requiredToolUse &&
|
||||
!timedOut &&
|
||||
!didSatisfyRequiredToolUse(
|
||||
{
|
||||
toolMetas: attempt.toolMetas,
|
||||
clientToolCall: attempt.clientToolCall,
|
||||
didSendViaMessagingTool: attempt.didSendViaMessagingTool,
|
||||
successfulCronAdds: attempt.successfulCronAdds,
|
||||
},
|
||||
requiredToolUse,
|
||||
)
|
||||
) {
|
||||
const message = formatRequiredToolUseErrorMessage(requiredToolUse);
|
||||
return {
|
||||
payloads: [
|
||||
{
|
||||
text: message,
|
||||
isError: true,
|
||||
},
|
||||
],
|
||||
meta: {
|
||||
durationMs: Date.now() - started,
|
||||
agentMeta,
|
||||
aborted,
|
||||
systemPromptReport: attempt.systemPromptReport,
|
||||
error: {
|
||||
kind: "required_tool_use",
|
||||
message,
|
||||
},
|
||||
},
|
||||
didSendViaMessagingTool: attempt.didSendViaMessagingTool,
|
||||
messagingToolSentTexts: attempt.messagingToolSentTexts,
|
||||
messagingToolSentMediaUrls: attempt.messagingToolSentMediaUrls,
|
||||
messagingToolSentTargets: attempt.messagingToolSentTargets,
|
||||
successfulCronAdds: attempt.successfulCronAdds,
|
||||
};
|
||||
}
|
||||
|
||||
log.debug(
|
||||
`embedded run done: runId=${params.runId} sessionId=${params.sessionId} durationMs=${Date.now() - started} aborted=${aborted}`,
|
||||
);
|
||||
|
||||
@ -122,8 +122,8 @@ describe("runEmbeddedPiAgent usage reporting", () => {
|
||||
sessionKey: "test-key",
|
||||
sessionFile: "/tmp/session.json",
|
||||
workspaceDir: "/tmp/workspace",
|
||||
prompt:
|
||||
"Read HEARTBEAT.md if it exists (workspace context). Follow it strictly. If nothing needs attention, reply HEARTBEAT_OK.",
|
||||
prompt: "Use the heartbeat instructions for this workspace. Reply HEARTBEAT_OK if idle.",
|
||||
trigger: "heartbeat",
|
||||
timeoutMs: 30000,
|
||||
runId: "run-heartbeat-tool-requirement",
|
||||
});
|
||||
@ -134,6 +134,77 @@ describe("runEmbeddedPiAgent usage reporting", () => {
|
||||
expect(result.payloads?.[0]?.text).toContain("read tool");
|
||||
});
|
||||
|
||||
it("lets timeout errors win over required-tool-use errors", async () => {
|
||||
mockedRunEmbeddedAttempt.mockResolvedValueOnce({
|
||||
aborted: true,
|
||||
promptError: null,
|
||||
timedOut: true,
|
||||
timedOutDuringCompaction: false,
|
||||
sessionIdUsed: "test-session",
|
||||
assistantTexts: [],
|
||||
toolMetas: [],
|
||||
didSendViaMessagingTool: false,
|
||||
lastAssistant: null,
|
||||
attemptUsage: { input: 10, output: 0, total: 10 },
|
||||
messagingToolSentTexts: [],
|
||||
messagingToolSentMediaUrls: [],
|
||||
messagingToolSentTargets: [],
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
} as any);
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
sessionId: "test-session",
|
||||
sessionKey: "test-key",
|
||||
sessionFile: "/tmp/session.json",
|
||||
workspaceDir: "/tmp/workspace",
|
||||
prompt: "Call the cron tool with action=list.",
|
||||
timeoutMs: 30000,
|
||||
runId: "run-timeout-before-tool",
|
||||
});
|
||||
|
||||
expect(result.meta.error?.kind).toBeUndefined();
|
||||
expect(result.payloads?.[0]?.isError).toBe(true);
|
||||
expect(result.payloads?.[0]?.text).toContain(
|
||||
"Request timed out before a response was generated",
|
||||
);
|
||||
});
|
||||
|
||||
it("allows callers to disable prompt-inferred tool requirements explicitly", async () => {
|
||||
mockedRunEmbeddedAttempt.mockResolvedValueOnce({
|
||||
aborted: false,
|
||||
promptError: null,
|
||||
timedOut: false,
|
||||
timedOutDuringCompaction: false,
|
||||
sessionIdUsed: "test-session",
|
||||
assistantTexts: ["Handled without tools."],
|
||||
toolMetas: [],
|
||||
didSendViaMessagingTool: false,
|
||||
lastAssistant: {
|
||||
usage: { input: 10, output: 5, total: 15 },
|
||||
stopReason: "end_turn",
|
||||
},
|
||||
attemptUsage: { input: 10, output: 5, total: 15 },
|
||||
messagingToolSentTexts: [],
|
||||
messagingToolSentMediaUrls: [],
|
||||
messagingToolSentTargets: [],
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
} as any);
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
sessionId: "test-session",
|
||||
sessionKey: "test-key",
|
||||
sessionFile: "/tmp/session.json",
|
||||
workspaceDir: "/tmp/workspace",
|
||||
prompt: "Call the cron tool with action=list.",
|
||||
requireToolUse: false,
|
||||
timeoutMs: 30000,
|
||||
runId: "run-disable-inferred-tool-requirement",
|
||||
});
|
||||
|
||||
expect(result.meta.error).toBeUndefined();
|
||||
expect(result.payloads?.[0]?.isError).not.toBe(true);
|
||||
});
|
||||
|
||||
it("allows success when a required tool was called", async () => {
|
||||
mockedRunEmbeddedAttempt.mockResolvedValueOnce({
|
||||
aborted: false,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user