Merge f4d4702b84158dc226c29b87b660e578b593c29c into 5bb5d7dab4b29e68b15bb7665d0736f46499a35c

This commit is contained in:
Youyou972 2026-03-21 05:28:41 +00:00 committed by GitHub
commit 7704cc1a26
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 346 additions and 0 deletions

View File

@ -82,6 +82,10 @@ import type { EmbeddedPiAgentMeta, EmbeddedPiRunResult } from "./types.js";
import { describeUnknownError } from "./utils.js";
type ApiKeyInfo = ResolvedProviderAuth;
type RequiredToolUse = {
toolNames?: string[];
reason?: string;
};
type RuntimeAuthState = {
sourceApiKey: string;
@ -166,6 +170,117 @@ const hasUsageValues = (
(value) => typeof value === "number" && Number.isFinite(value) && value > 0,
);
function normalizeRequiredToolNames(toolNames: string[] | undefined): string[] | undefined {
const normalized = (toolNames ?? [])
.map((name) =>
String(name ?? "")
.trim()
.toLowerCase(),
)
.filter(Boolean);
return normalized.length > 0 ? Array.from(new Set(normalized)) : undefined;
}
function inferRequiredToolUseFromPrompt(
prompt: string,
trigger?: RunEmbeddedPiAgentParams["trigger"],
): RequiredToolUse | undefined {
if (trigger === "heartbeat") {
return {
toolNames: ["read"],
reason: "Heartbeat runs must check HEARTBEAT.md before replying.",
};
}
const trimmedPrompt = prompt.trim();
if (trimmedPrompt.startsWith("Read HEARTBEAT.md if it exists")) {
return {
toolNames: ["read"],
reason: "Heartbeat prompts must check HEARTBEAT.md before replying.",
};
}
const explicitToolNames = Array.from(
trimmedPrompt.matchAll(/\b(?:call|use)\s+(?:the\s+)?([a-z0-9_]+)\s+tool\b/gi),
)
.map((match) => match[1]?.trim().toLowerCase())
.filter((name): name is string => Boolean(name));
if (explicitToolNames.length === 0) {
return undefined;
}
const toolNames = Array.from(new Set(explicitToolNames));
return {
toolNames,
reason:
toolNames.length === 1
? `Prompt explicitly requires the ${toolNames[0]} tool.`
: `Prompt explicitly requires one of these tools: ${toolNames.join(", ")}.`,
};
}
function resolveRequiredToolUse(
requiredToolUse: RunEmbeddedPiAgentParams["requireToolUse"],
prompt: string,
trigger?: RunEmbeddedPiAgentParams["trigger"],
): RequiredToolUse | undefined {
if (requiredToolUse === false) {
return undefined;
}
if (requiredToolUse) {
if (requiredToolUse === true) {
return { reason: "This run requires at least one tool call before success." };
}
return {
toolNames: normalizeRequiredToolNames(requiredToolUse.toolNames),
reason: requiredToolUse.reason?.trim() || "This run requires tool use before success.",
};
}
return inferRequiredToolUseFromPrompt(prompt, trigger);
}
function didSatisfyRequiredToolUse(
attempt: {
toolMetas: Array<{ toolName: string; meta?: string }>;
clientToolCall?: { name: string; params: Record<string, unknown> };
didSendViaMessagingTool: boolean;
successfulCronAdds?: number;
},
requirement: RequiredToolUse,
): boolean {
const usedToolNames = new Set(
attempt.toolMetas.map((entry) => entry.toolName.trim().toLowerCase()).filter(Boolean),
);
if (attempt.clientToolCall?.name) {
usedToolNames.add(attempt.clientToolCall.name.trim().toLowerCase());
}
if (attempt.didSendViaMessagingTool) {
usedToolNames.add("message");
}
if ((attempt.successfulCronAdds ?? 0) > 0) {
usedToolNames.add("cron");
}
const requiredToolNames = normalizeRequiredToolNames(requirement.toolNames);
if (!requiredToolNames) {
return usedToolNames.size > 0;
}
return requiredToolNames.some((toolName) => usedToolNames.has(toolName));
}
function formatRequiredToolUseErrorMessage(requirement: RequiredToolUse): string {
const requiredToolNames = normalizeRequiredToolNames(requirement.toolNames);
const expected =
requiredToolNames && requiredToolNames.length > 0
? requiredToolNames.length === 1
? `the ${requiredToolNames[0]} tool`
: `one of these tools: ${requiredToolNames.join(", ")}`
: "at least one tool";
const reason = requirement.reason?.trim();
return reason
? `${reason} The model replied without calling ${expected}.`
: `Tool use was required for this run, but the model replied without calling ${expected}.`;
}
const mergeUsageIntoAccumulator = (
target: UsageAccumulator,
usage: ReturnType<typeof normalizeUsage>,
@ -282,6 +397,11 @@ export async function runEmbeddedPiAgent(
: "plain"
: "markdown");
const isProbeSession = params.sessionId?.startsWith("probe-") ?? false;
const requiredToolUse = resolveRequiredToolUse(
params.requireToolUse,
params.prompt,
params.trigger,
);
return enqueueSession(() =>
enqueueGlobal(async () => {
@ -1660,6 +1780,45 @@ export async function runEmbeddedPiAgent(
};
}
if (
requiredToolUse &&
!timedOut &&
!didSatisfyRequiredToolUse(
{
toolMetas: attempt.toolMetas,
clientToolCall: attempt.clientToolCall,
didSendViaMessagingTool: attempt.didSendViaMessagingTool,
successfulCronAdds: attempt.successfulCronAdds,
},
requiredToolUse,
)
) {
const message = formatRequiredToolUseErrorMessage(requiredToolUse);
return {
payloads: [
{
text: message,
isError: true,
},
],
meta: {
durationMs: Date.now() - started,
agentMeta,
aborted,
systemPromptReport: attempt.systemPromptReport,
error: {
kind: "required_tool_use",
message,
},
},
didSendViaMessagingTool: attempt.didSendViaMessagingTool,
messagingToolSentTexts: attempt.messagingToolSentTexts,
messagingToolSentMediaUrls: attempt.messagingToolSentMediaUrls,
messagingToolSentTargets: attempt.messagingToolSentTargets,
successfulCronAdds: attempt.successfulCronAdds,
};
}
log.debug(
`embedded run done: runId=${params.runId} sessionId=${params.sessionId} durationMs=${Date.now() - started} aborted=${aborted}`,
);

View File

@ -76,6 +76,11 @@ export type RunEmbeddedPiAgentParams = {
clientTools?: ClientToolDefinition[];
/** Disable built-in tools for this run (LLM-only mode). */
disableTools?: boolean;
/**
* Require at least one tool call before treating the run as successful.
* When `toolNames` is provided, at least one of those tools must be called.
*/
requireToolUse?: boolean | { toolNames?: string[]; reason?: string };
provider?: string;
model?: string;
authProfileId?: string;

View File

@ -41,6 +41,7 @@ export type EmbeddedPiRunMeta = {
| "compaction_failure"
| "role_ordering"
| "image_size"
| "required_tool_use"
| "retry_limit";
message: string;
};

View File

@ -188,4 +188,185 @@ describe("runEmbeddedPiAgent usage reporting", () => {
// If the bug exists, it will likely be 350
expect(usage?.total).toBe(200);
});
it("returns an error when a run explicitly requires tool use and no tool was called", async () => {
mockedRunEmbeddedAttempt.mockResolvedValueOnce({
aborted: false,
promptError: null,
timedOut: false,
timedOutDuringCompaction: false,
sessionIdUsed: "test-session",
assistantTexts: ["I checked it."],
toolMetas: [],
didSendViaMessagingTool: false,
lastAssistant: {
usage: { input: 10, output: 5, total: 15 },
stopReason: "end_turn",
},
attemptUsage: { input: 10, output: 5, total: 15 },
messagingToolSentTexts: [],
messagingToolSentMediaUrls: [],
messagingToolSentTargets: [],
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
const result = await runEmbeddedPiAgent({
sessionId: "test-session",
sessionKey: "test-key",
sessionFile: "/tmp/session.json",
workspaceDir: "/tmp/workspace",
prompt: "Do the thing.",
requireToolUse: true,
timeoutMs: 30000,
runId: "run-require-tool",
});
expect(result.payloads?.[0]?.isError).toBe(true);
expect(result.meta.error?.kind).toBe("required_tool_use");
expect(result.payloads?.[0]?.text).toContain("requires at least one tool call");
});
it("infers required read tool use for heartbeat prompts", async () => {
mockedRunEmbeddedAttempt.mockResolvedValueOnce({
aborted: false,
promptError: null,
timedOut: false,
timedOutDuringCompaction: false,
sessionIdUsed: "test-session",
assistantTexts: ["HEARTBEAT_OK"],
toolMetas: [],
didSendViaMessagingTool: false,
lastAssistant: {
usage: { input: 10, output: 5, total: 15 },
stopReason: "end_turn",
},
attemptUsage: { input: 10, output: 5, total: 15 },
messagingToolSentTexts: [],
messagingToolSentMediaUrls: [],
messagingToolSentTargets: [],
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
const result = await runEmbeddedPiAgent({
sessionId: "test-session",
sessionKey: "test-key",
sessionFile: "/tmp/session.json",
workspaceDir: "/tmp/workspace",
prompt: "Use the heartbeat instructions for this workspace. Reply HEARTBEAT_OK if idle.",
trigger: "heartbeat",
timeoutMs: 30000,
runId: "run-heartbeat-tool-requirement",
});
expect(result.payloads?.[0]?.isError).toBe(true);
expect(result.meta.error?.kind).toBe("required_tool_use");
expect(result.payloads?.[0]?.text).toContain("HEARTBEAT.md");
expect(result.payloads?.[0]?.text).toContain("read tool");
});
it("lets timeout errors win over required-tool-use errors", async () => {
mockedRunEmbeddedAttempt.mockResolvedValueOnce({
aborted: true,
promptError: null,
timedOut: true,
timedOutDuringCompaction: false,
sessionIdUsed: "test-session",
assistantTexts: [],
toolMetas: [],
didSendViaMessagingTool: false,
lastAssistant: null,
attemptUsage: { input: 10, output: 0, total: 10 },
messagingToolSentTexts: [],
messagingToolSentMediaUrls: [],
messagingToolSentTargets: [],
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
const result = await runEmbeddedPiAgent({
sessionId: "test-session",
sessionKey: "test-key",
sessionFile: "/tmp/session.json",
workspaceDir: "/tmp/workspace",
prompt: "Call the cron tool with action=list.",
timeoutMs: 30000,
runId: "run-timeout-before-tool",
});
expect(result.meta.error?.kind).toBeUndefined();
expect(result.payloads?.[0]?.isError).toBe(true);
expect(result.payloads?.[0]?.text).toContain(
"Request timed out before a response was generated",
);
});
it("allows callers to disable prompt-inferred tool requirements explicitly", async () => {
mockedRunEmbeddedAttempt.mockResolvedValueOnce({
aborted: false,
promptError: null,
timedOut: false,
timedOutDuringCompaction: false,
sessionIdUsed: "test-session",
assistantTexts: ["Handled without tools."],
toolMetas: [],
didSendViaMessagingTool: false,
lastAssistant: {
usage: { input: 10, output: 5, total: 15 },
stopReason: "end_turn",
},
attemptUsage: { input: 10, output: 5, total: 15 },
messagingToolSentTexts: [],
messagingToolSentMediaUrls: [],
messagingToolSentTargets: [],
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
const result = await runEmbeddedPiAgent({
sessionId: "test-session",
sessionKey: "test-key",
sessionFile: "/tmp/session.json",
workspaceDir: "/tmp/workspace",
prompt: "Call the cron tool with action=list.",
requireToolUse: false,
timeoutMs: 30000,
runId: "run-disable-inferred-tool-requirement",
});
expect(result.meta.error).toBeUndefined();
expect(result.payloads?.[0]?.isError).not.toBe(true);
});
it("allows success when a required tool was called", async () => {
mockedRunEmbeddedAttempt.mockResolvedValueOnce({
aborted: false,
promptError: null,
timedOut: false,
timedOutDuringCompaction: false,
sessionIdUsed: "test-session",
assistantTexts: ["Done."],
toolMetas: [{ toolName: "cron", meta: "action=list" }],
didSendViaMessagingTool: false,
lastAssistant: {
usage: { input: 10, output: 5, total: 15 },
stopReason: "end_turn",
},
attemptUsage: { input: 10, output: 5, total: 15 },
messagingToolSentTexts: [],
messagingToolSentMediaUrls: [],
messagingToolSentTargets: [],
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
const result = await runEmbeddedPiAgent({
sessionId: "test-session",
sessionKey: "test-key",
sessionFile: "/tmp/session.json",
workspaceDir: "/tmp/workspace",
prompt: "Call the cron tool with action=list.",
timeoutMs: 30000,
runId: "run-explicit-cron-tool",
});
expect(result.meta.error).toBeUndefined();
expect(result.payloads?.[0]?.isError).not.toBe(true);
});
});