Merge f4d4702b84158dc226c29b87b660e578b593c29c into 5bb5d7dab4b29e68b15bb7665d0736f46499a35c
This commit is contained in:
commit
7704cc1a26
@ -82,6 +82,10 @@ import type { EmbeddedPiAgentMeta, EmbeddedPiRunResult } from "./types.js";
|
||||
import { describeUnknownError } from "./utils.js";
|
||||
|
||||
type ApiKeyInfo = ResolvedProviderAuth;
|
||||
type RequiredToolUse = {
|
||||
toolNames?: string[];
|
||||
reason?: string;
|
||||
};
|
||||
|
||||
type RuntimeAuthState = {
|
||||
sourceApiKey: string;
|
||||
@ -166,6 +170,117 @@ const hasUsageValues = (
|
||||
(value) => typeof value === "number" && Number.isFinite(value) && value > 0,
|
||||
);
|
||||
|
||||
function normalizeRequiredToolNames(toolNames: string[] | undefined): string[] | undefined {
|
||||
const normalized = (toolNames ?? [])
|
||||
.map((name) =>
|
||||
String(name ?? "")
|
||||
.trim()
|
||||
.toLowerCase(),
|
||||
)
|
||||
.filter(Boolean);
|
||||
return normalized.length > 0 ? Array.from(new Set(normalized)) : undefined;
|
||||
}
|
||||
|
||||
function inferRequiredToolUseFromPrompt(
|
||||
prompt: string,
|
||||
trigger?: RunEmbeddedPiAgentParams["trigger"],
|
||||
): RequiredToolUse | undefined {
|
||||
if (trigger === "heartbeat") {
|
||||
return {
|
||||
toolNames: ["read"],
|
||||
reason: "Heartbeat runs must check HEARTBEAT.md before replying.",
|
||||
};
|
||||
}
|
||||
|
||||
const trimmedPrompt = prompt.trim();
|
||||
if (trimmedPrompt.startsWith("Read HEARTBEAT.md if it exists")) {
|
||||
return {
|
||||
toolNames: ["read"],
|
||||
reason: "Heartbeat prompts must check HEARTBEAT.md before replying.",
|
||||
};
|
||||
}
|
||||
|
||||
const explicitToolNames = Array.from(
|
||||
trimmedPrompt.matchAll(/\b(?:call|use)\s+(?:the\s+)?([a-z0-9_]+)\s+tool\b/gi),
|
||||
)
|
||||
.map((match) => match[1]?.trim().toLowerCase())
|
||||
.filter((name): name is string => Boolean(name));
|
||||
if (explicitToolNames.length === 0) {
|
||||
return undefined;
|
||||
}
|
||||
const toolNames = Array.from(new Set(explicitToolNames));
|
||||
return {
|
||||
toolNames,
|
||||
reason:
|
||||
toolNames.length === 1
|
||||
? `Prompt explicitly requires the ${toolNames[0]} tool.`
|
||||
: `Prompt explicitly requires one of these tools: ${toolNames.join(", ")}.`,
|
||||
};
|
||||
}
|
||||
|
||||
function resolveRequiredToolUse(
|
||||
requiredToolUse: RunEmbeddedPiAgentParams["requireToolUse"],
|
||||
prompt: string,
|
||||
trigger?: RunEmbeddedPiAgentParams["trigger"],
|
||||
): RequiredToolUse | undefined {
|
||||
if (requiredToolUse === false) {
|
||||
return undefined;
|
||||
}
|
||||
if (requiredToolUse) {
|
||||
if (requiredToolUse === true) {
|
||||
return { reason: "This run requires at least one tool call before success." };
|
||||
}
|
||||
return {
|
||||
toolNames: normalizeRequiredToolNames(requiredToolUse.toolNames),
|
||||
reason: requiredToolUse.reason?.trim() || "This run requires tool use before success.",
|
||||
};
|
||||
}
|
||||
return inferRequiredToolUseFromPrompt(prompt, trigger);
|
||||
}
|
||||
|
||||
function didSatisfyRequiredToolUse(
|
||||
attempt: {
|
||||
toolMetas: Array<{ toolName: string; meta?: string }>;
|
||||
clientToolCall?: { name: string; params: Record<string, unknown> };
|
||||
didSendViaMessagingTool: boolean;
|
||||
successfulCronAdds?: number;
|
||||
},
|
||||
requirement: RequiredToolUse,
|
||||
): boolean {
|
||||
const usedToolNames = new Set(
|
||||
attempt.toolMetas.map((entry) => entry.toolName.trim().toLowerCase()).filter(Boolean),
|
||||
);
|
||||
if (attempt.clientToolCall?.name) {
|
||||
usedToolNames.add(attempt.clientToolCall.name.trim().toLowerCase());
|
||||
}
|
||||
if (attempt.didSendViaMessagingTool) {
|
||||
usedToolNames.add("message");
|
||||
}
|
||||
if ((attempt.successfulCronAdds ?? 0) > 0) {
|
||||
usedToolNames.add("cron");
|
||||
}
|
||||
|
||||
const requiredToolNames = normalizeRequiredToolNames(requirement.toolNames);
|
||||
if (!requiredToolNames) {
|
||||
return usedToolNames.size > 0;
|
||||
}
|
||||
return requiredToolNames.some((toolName) => usedToolNames.has(toolName));
|
||||
}
|
||||
|
||||
function formatRequiredToolUseErrorMessage(requirement: RequiredToolUse): string {
|
||||
const requiredToolNames = normalizeRequiredToolNames(requirement.toolNames);
|
||||
const expected =
|
||||
requiredToolNames && requiredToolNames.length > 0
|
||||
? requiredToolNames.length === 1
|
||||
? `the ${requiredToolNames[0]} tool`
|
||||
: `one of these tools: ${requiredToolNames.join(", ")}`
|
||||
: "at least one tool";
|
||||
const reason = requirement.reason?.trim();
|
||||
return reason
|
||||
? `${reason} The model replied without calling ${expected}.`
|
||||
: `Tool use was required for this run, but the model replied without calling ${expected}.`;
|
||||
}
|
||||
|
||||
const mergeUsageIntoAccumulator = (
|
||||
target: UsageAccumulator,
|
||||
usage: ReturnType<typeof normalizeUsage>,
|
||||
@ -282,6 +397,11 @@ export async function runEmbeddedPiAgent(
|
||||
: "plain"
|
||||
: "markdown");
|
||||
const isProbeSession = params.sessionId?.startsWith("probe-") ?? false;
|
||||
const requiredToolUse = resolveRequiredToolUse(
|
||||
params.requireToolUse,
|
||||
params.prompt,
|
||||
params.trigger,
|
||||
);
|
||||
|
||||
return enqueueSession(() =>
|
||||
enqueueGlobal(async () => {
|
||||
@ -1660,6 +1780,45 @@ export async function runEmbeddedPiAgent(
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
requiredToolUse &&
|
||||
!timedOut &&
|
||||
!didSatisfyRequiredToolUse(
|
||||
{
|
||||
toolMetas: attempt.toolMetas,
|
||||
clientToolCall: attempt.clientToolCall,
|
||||
didSendViaMessagingTool: attempt.didSendViaMessagingTool,
|
||||
successfulCronAdds: attempt.successfulCronAdds,
|
||||
},
|
||||
requiredToolUse,
|
||||
)
|
||||
) {
|
||||
const message = formatRequiredToolUseErrorMessage(requiredToolUse);
|
||||
return {
|
||||
payloads: [
|
||||
{
|
||||
text: message,
|
||||
isError: true,
|
||||
},
|
||||
],
|
||||
meta: {
|
||||
durationMs: Date.now() - started,
|
||||
agentMeta,
|
||||
aborted,
|
||||
systemPromptReport: attempt.systemPromptReport,
|
||||
error: {
|
||||
kind: "required_tool_use",
|
||||
message,
|
||||
},
|
||||
},
|
||||
didSendViaMessagingTool: attempt.didSendViaMessagingTool,
|
||||
messagingToolSentTexts: attempt.messagingToolSentTexts,
|
||||
messagingToolSentMediaUrls: attempt.messagingToolSentMediaUrls,
|
||||
messagingToolSentTargets: attempt.messagingToolSentTargets,
|
||||
successfulCronAdds: attempt.successfulCronAdds,
|
||||
};
|
||||
}
|
||||
|
||||
log.debug(
|
||||
`embedded run done: runId=${params.runId} sessionId=${params.sessionId} durationMs=${Date.now() - started} aborted=${aborted}`,
|
||||
);
|
||||
|
||||
@ -76,6 +76,11 @@ export type RunEmbeddedPiAgentParams = {
|
||||
clientTools?: ClientToolDefinition[];
|
||||
/** Disable built-in tools for this run (LLM-only mode). */
|
||||
disableTools?: boolean;
|
||||
/**
|
||||
* Require at least one tool call before treating the run as successful.
|
||||
* When `toolNames` is provided, at least one of those tools must be called.
|
||||
*/
|
||||
requireToolUse?: boolean | { toolNames?: string[]; reason?: string };
|
||||
provider?: string;
|
||||
model?: string;
|
||||
authProfileId?: string;
|
||||
|
||||
@ -41,6 +41,7 @@ export type EmbeddedPiRunMeta = {
|
||||
| "compaction_failure"
|
||||
| "role_ordering"
|
||||
| "image_size"
|
||||
| "required_tool_use"
|
||||
| "retry_limit";
|
||||
message: string;
|
||||
};
|
||||
|
||||
@ -188,4 +188,185 @@ describe("runEmbeddedPiAgent usage reporting", () => {
|
||||
// If the bug exists, it will likely be 350
|
||||
expect(usage?.total).toBe(200);
|
||||
});
|
||||
|
||||
it("returns an error when a run explicitly requires tool use and no tool was called", async () => {
|
||||
mockedRunEmbeddedAttempt.mockResolvedValueOnce({
|
||||
aborted: false,
|
||||
promptError: null,
|
||||
timedOut: false,
|
||||
timedOutDuringCompaction: false,
|
||||
sessionIdUsed: "test-session",
|
||||
assistantTexts: ["I checked it."],
|
||||
toolMetas: [],
|
||||
didSendViaMessagingTool: false,
|
||||
lastAssistant: {
|
||||
usage: { input: 10, output: 5, total: 15 },
|
||||
stopReason: "end_turn",
|
||||
},
|
||||
attemptUsage: { input: 10, output: 5, total: 15 },
|
||||
messagingToolSentTexts: [],
|
||||
messagingToolSentMediaUrls: [],
|
||||
messagingToolSentTargets: [],
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
} as any);
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
sessionId: "test-session",
|
||||
sessionKey: "test-key",
|
||||
sessionFile: "/tmp/session.json",
|
||||
workspaceDir: "/tmp/workspace",
|
||||
prompt: "Do the thing.",
|
||||
requireToolUse: true,
|
||||
timeoutMs: 30000,
|
||||
runId: "run-require-tool",
|
||||
});
|
||||
|
||||
expect(result.payloads?.[0]?.isError).toBe(true);
|
||||
expect(result.meta.error?.kind).toBe("required_tool_use");
|
||||
expect(result.payloads?.[0]?.text).toContain("requires at least one tool call");
|
||||
});
|
||||
|
||||
it("infers required read tool use for heartbeat prompts", async () => {
|
||||
mockedRunEmbeddedAttempt.mockResolvedValueOnce({
|
||||
aborted: false,
|
||||
promptError: null,
|
||||
timedOut: false,
|
||||
timedOutDuringCompaction: false,
|
||||
sessionIdUsed: "test-session",
|
||||
assistantTexts: ["HEARTBEAT_OK"],
|
||||
toolMetas: [],
|
||||
didSendViaMessagingTool: false,
|
||||
lastAssistant: {
|
||||
usage: { input: 10, output: 5, total: 15 },
|
||||
stopReason: "end_turn",
|
||||
},
|
||||
attemptUsage: { input: 10, output: 5, total: 15 },
|
||||
messagingToolSentTexts: [],
|
||||
messagingToolSentMediaUrls: [],
|
||||
messagingToolSentTargets: [],
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
} as any);
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
sessionId: "test-session",
|
||||
sessionKey: "test-key",
|
||||
sessionFile: "/tmp/session.json",
|
||||
workspaceDir: "/tmp/workspace",
|
||||
prompt: "Use the heartbeat instructions for this workspace. Reply HEARTBEAT_OK if idle.",
|
||||
trigger: "heartbeat",
|
||||
timeoutMs: 30000,
|
||||
runId: "run-heartbeat-tool-requirement",
|
||||
});
|
||||
|
||||
expect(result.payloads?.[0]?.isError).toBe(true);
|
||||
expect(result.meta.error?.kind).toBe("required_tool_use");
|
||||
expect(result.payloads?.[0]?.text).toContain("HEARTBEAT.md");
|
||||
expect(result.payloads?.[0]?.text).toContain("read tool");
|
||||
});
|
||||
|
||||
it("lets timeout errors win over required-tool-use errors", async () => {
|
||||
mockedRunEmbeddedAttempt.mockResolvedValueOnce({
|
||||
aborted: true,
|
||||
promptError: null,
|
||||
timedOut: true,
|
||||
timedOutDuringCompaction: false,
|
||||
sessionIdUsed: "test-session",
|
||||
assistantTexts: [],
|
||||
toolMetas: [],
|
||||
didSendViaMessagingTool: false,
|
||||
lastAssistant: null,
|
||||
attemptUsage: { input: 10, output: 0, total: 10 },
|
||||
messagingToolSentTexts: [],
|
||||
messagingToolSentMediaUrls: [],
|
||||
messagingToolSentTargets: [],
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
} as any);
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
sessionId: "test-session",
|
||||
sessionKey: "test-key",
|
||||
sessionFile: "/tmp/session.json",
|
||||
workspaceDir: "/tmp/workspace",
|
||||
prompt: "Call the cron tool with action=list.",
|
||||
timeoutMs: 30000,
|
||||
runId: "run-timeout-before-tool",
|
||||
});
|
||||
|
||||
expect(result.meta.error?.kind).toBeUndefined();
|
||||
expect(result.payloads?.[0]?.isError).toBe(true);
|
||||
expect(result.payloads?.[0]?.text).toContain(
|
||||
"Request timed out before a response was generated",
|
||||
);
|
||||
});
|
||||
|
||||
it("allows callers to disable prompt-inferred tool requirements explicitly", async () => {
|
||||
mockedRunEmbeddedAttempt.mockResolvedValueOnce({
|
||||
aborted: false,
|
||||
promptError: null,
|
||||
timedOut: false,
|
||||
timedOutDuringCompaction: false,
|
||||
sessionIdUsed: "test-session",
|
||||
assistantTexts: ["Handled without tools."],
|
||||
toolMetas: [],
|
||||
didSendViaMessagingTool: false,
|
||||
lastAssistant: {
|
||||
usage: { input: 10, output: 5, total: 15 },
|
||||
stopReason: "end_turn",
|
||||
},
|
||||
attemptUsage: { input: 10, output: 5, total: 15 },
|
||||
messagingToolSentTexts: [],
|
||||
messagingToolSentMediaUrls: [],
|
||||
messagingToolSentTargets: [],
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
} as any);
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
sessionId: "test-session",
|
||||
sessionKey: "test-key",
|
||||
sessionFile: "/tmp/session.json",
|
||||
workspaceDir: "/tmp/workspace",
|
||||
prompt: "Call the cron tool with action=list.",
|
||||
requireToolUse: false,
|
||||
timeoutMs: 30000,
|
||||
runId: "run-disable-inferred-tool-requirement",
|
||||
});
|
||||
|
||||
expect(result.meta.error).toBeUndefined();
|
||||
expect(result.payloads?.[0]?.isError).not.toBe(true);
|
||||
});
|
||||
|
||||
it("allows success when a required tool was called", async () => {
|
||||
mockedRunEmbeddedAttempt.mockResolvedValueOnce({
|
||||
aborted: false,
|
||||
promptError: null,
|
||||
timedOut: false,
|
||||
timedOutDuringCompaction: false,
|
||||
sessionIdUsed: "test-session",
|
||||
assistantTexts: ["Done."],
|
||||
toolMetas: [{ toolName: "cron", meta: "action=list" }],
|
||||
didSendViaMessagingTool: false,
|
||||
lastAssistant: {
|
||||
usage: { input: 10, output: 5, total: 15 },
|
||||
stopReason: "end_turn",
|
||||
},
|
||||
attemptUsage: { input: 10, output: 5, total: 15 },
|
||||
messagingToolSentTexts: [],
|
||||
messagingToolSentMediaUrls: [],
|
||||
messagingToolSentTargets: [],
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
} as any);
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
sessionId: "test-session",
|
||||
sessionKey: "test-key",
|
||||
sessionFile: "/tmp/session.json",
|
||||
workspaceDir: "/tmp/workspace",
|
||||
prompt: "Call the cron tool with action=list.",
|
||||
timeoutMs: 30000,
|
||||
runId: "run-explicit-cron-tool",
|
||||
});
|
||||
|
||||
expect(result.meta.error).toBeUndefined();
|
||||
expect(result.payloads?.[0]?.isError).not.toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user