Alishahed fe8f06db4e fix(agent): forward configured tool choice into stream params
Ensure agent model params can force tool_choice through the embedded runner so OpenAI-compatible providers actually receive the required tool policy.

Made-with: Cursor
2026-03-20 22:44:36 -07:00

325 lines
11 KiB
TypeScript

import type { StreamFn } from "@mariozechner/pi-agent-core";
import type { SimpleStreamOptions } from "@mariozechner/pi-ai";
import { streamSimple } from "@mariozechner/pi-ai";
import type { ThinkLevel } from "../../auto-reply/thinking.js";
import type { OpenClawConfig } from "../../config/config.js";
import {
prepareProviderExtraParams,
wrapProviderStreamFn,
} from "../../plugins/provider-runtime.js";
import {
createAnthropicBetaHeadersWrapper,
createAnthropicFastModeWrapper,
createAnthropicToolPayloadCompatibilityWrapper,
resolveAnthropicFastMode,
resolveAnthropicBetas,
resolveCacheRetention,
} from "./anthropic-stream-wrappers.js";
import { log } from "./logger.js";
import {
createMoonshotThinkingWrapper,
resolveMoonshotThinkingType,
createSiliconFlowThinkingWrapper,
shouldApplyMoonshotPayloadCompat,
shouldApplySiliconFlowThinkingOffCompat,
} from "./moonshot-stream-wrappers.js";
import {
createOpenAIFastModeWrapper,
createOpenAIResponsesContextManagementWrapper,
createOpenAIServiceTierWrapper,
resolveOpenAIFastMode,
resolveOpenAIServiceTier,
} from "./openai-stream-wrappers.js";
/**
 * Resolve provider-specific extra params from model config.
 * Used to pass through stream params like temperature/maxTokens.
 *
 * @internal Exported for testing only
 */
export function resolveExtraParams(params: {
  cfg: OpenClawConfig | undefined;
  provider: string;
  modelId: string;
  agentId?: string;
}): Record<string, unknown> | undefined {
  const { cfg, provider, modelId, agentId } = params;
  // Global defaults are keyed by "<provider>/<modelId>".
  const perModelConfig = cfg?.agents?.defaults?.models?.[`${provider}/${modelId}`];
  const globalParams = perModelConfig?.params ? { ...perModelConfig.params } : undefined;
  // Per-agent params apply only when an agent id is given and found in the list.
  const agentParams =
    agentId && cfg?.agents?.list
      ? cfg.agents.list.find((entry) => entry.id === agentId)?.params
      : undefined;
  if (!globalParams && !agentParams) {
    return undefined;
  }
  // Agent params win over global per-model params on key collisions.
  const merged = { ...globalParams, ...agentParams };
  // Normalize both spellings of parallel tool calls onto the snake_case key.
  const parallelToolCalls = resolveAliasedParamValue(
    [globalParams, agentParams],
    "parallel_tool_calls",
    "parallelToolCalls",
  );
  if (parallelToolCalls !== undefined) {
    merged.parallel_tool_calls = parallelToolCalls;
    delete merged.parallelToolCalls;
  }
  return merged;
}
// Stream options extended with OpenClaw-specific passthrough fields:
// `cacheRetention` (forwarded when resolveCacheRetention yields a value) and
// `openaiWsWarmup` (boolean flag copied from extra params; presumably
// pre-warms the OpenAI websocket transport — TODO confirm downstream usage).
type CacheRetentionStreamOptions = Partial<SimpleStreamOptions> & {
  cacheRetention?: "none" | "short" | "long";
  openaiWsWarmup?: boolean;
};
/**
 * Build a StreamFn that injects configured extra params (temperature,
 * maxTokens, toolChoice, transport, websocket warm-up, cache retention) as
 * defaults on every stream call. Per-call options still take precedence
 * because they are spread after the configured defaults.
 *
 * Returns undefined when there is nothing to inject, so the caller keeps the
 * agent's original streamFn untouched.
 */
function createStreamFnWithExtraParams(
  baseStreamFn: StreamFn | undefined,
  extraParams: Record<string, unknown> | undefined,
  provider: string,
): StreamFn | undefined {
  if (!extraParams || Object.keys(extraParams).length === 0) {
    return undefined;
  }
  const defaults: CacheRetentionStreamOptions = {};
  const { temperature, maxTokens, transport, openaiWsWarmup } = extraParams;
  if (typeof temperature === "number") {
    defaults.temperature = temperature;
  }
  if (typeof maxTokens === "number") {
    defaults.maxTokens = maxTokens;
  }
  // tool_choice may be spelled either way in config; snake_case wins.
  const toolChoice = resolveAliasedParamValue([extraParams], "tool_choice", "toolChoice");
  if (toolChoice !== undefined) {
    defaults.toolChoice = toolChoice as CacheRetentionStreamOptions["toolChoice"];
  }
  // Only the three known transports are forwarded; anything else non-null is
  // logged and dropped rather than sent to the provider.
  if (transport === "sse" || transport === "websocket" || transport === "auto") {
    defaults.transport = transport;
  } else if (transport != null) {
    const transportSummary = typeof transport === "string" ? transport : typeof transport;
    log.warn(`ignoring invalid transport param: ${transportSummary}`);
  }
  if (typeof openaiWsWarmup === "boolean") {
    defaults.openaiWsWarmup = openaiWsWarmup;
  }
  const cacheRetention = resolveCacheRetention(extraParams, provider);
  if (cacheRetention) {
    defaults.cacheRetention = cacheRetention;
  }
  // Every candidate param was invalid/absent — nothing to wrap.
  if (Object.keys(defaults).length === 0) {
    return undefined;
  }
  log.debug(`creating streamFn wrapper with params: ${JSON.stringify(defaults)}`);
  const underlying = baseStreamFn ?? streamSimple;
  return (model, context, options) => underlying(model, context, { ...defaults, ...options });
}
/**
 * Pick a param value that may be spelled in snake_case or camelCase.
 * Later sources override earlier ones; within a single source the snake_case
 * spelling wins. Returns undefined when no source defines either key.
 */
function resolveAliasedParamValue(
  sources: Array<Record<string, unknown> | undefined>,
  snakeCaseKey: string,
  camelCaseKey: string,
): unknown {
  // Walk from the end so the first match found is the highest-priority source.
  for (let i = sources.length - 1; i >= 0; i--) {
    const source = sources[i];
    if (!source) {
      continue;
    }
    if (Object.hasOwn(source, snakeCaseKey)) {
      return source[snakeCaseKey];
    }
    if (Object.hasOwn(source, camelCaseKey)) {
      return source[camelCaseKey];
    }
  }
  return undefined;
}
/**
 * Wrap a StreamFn so OpenAI-compatible request payloads carry an explicit
 * `parallel_tool_calls` flag. Requests to any other API pass through
 * untouched, and a caller-supplied onPayload hook still runs after the flag
 * has been injected.
 */
function createParallelToolCallsWrapper(
  baseStreamFn: StreamFn | undefined,
  enabled: boolean,
): StreamFn {
  const next = baseStreamFn ?? streamSimple;
  return (model, context, options) => {
    const isOpenAIApi =
      model.api === "openai-completions" || model.api === "openai-responses";
    if (!isOpenAIApi) {
      return next(model, context, options);
    }
    log.debug(
      `applying parallel_tool_calls=${enabled} for ${model.provider ?? "unknown"}/${model.id ?? "unknown"} api=${model.api}`,
    );
    const chainedOnPayload = options?.onPayload;
    return next(model, context, {
      ...options,
      onPayload: (payload) => {
        if (payload && typeof payload === "object") {
          (payload as Record<string, unknown>).parallel_tool_calls = enabled;
        }
        return chainedOnPayload?.(payload, model);
      },
    });
  };
}
/**
 * Apply extra params (like temperature) to an agent's streamFn.
 * Also applies verified provider-specific request wrappers, such as OpenRouter attribution.
 *
 * Each stage below re-wraps `agent.streamFn` in place, so ordering matters:
 * wrappers applied later sit outermost and see each stream call first.
 *
 * @param agent - Mutated in place; `agent.streamFn` is repeatedly replaced.
 * @param extraParamsOverride - Per-call params merged over config-resolved ones.
 *
 * @internal Exported for testing
 */
export function applyExtraParamsToAgent(
  agent: { streamFn?: StreamFn },
  cfg: OpenClawConfig | undefined,
  provider: string,
  modelId: string,
  extraParamsOverride?: Record<string, unknown>,
  thinkingLevel?: ThinkLevel,
  agentId?: string,
  workspaceDir?: string,
): void {
  // Params from config: global per-model defaults merged with per-agent params.
  const resolvedExtraParams = resolveExtraParams({
    cfg,
    provider,
    modelId,
    agentId,
  });
  // Drop explicit `undefined` entries from the override so they cannot
  // clobber config-resolved values in the Object.assign merge below.
  const override =
    extraParamsOverride && Object.keys(extraParamsOverride).length > 0
      ? Object.fromEntries(
          Object.entries(extraParamsOverride).filter(([, value]) => value !== undefined),
        )
      : undefined;
  const merged = Object.assign({}, resolvedExtraParams, override);
  // Let provider plugins rewrite the merged params; fall back to the merged
  // params unchanged when no plugin returns a replacement.
  const effectiveExtraParams =
    prepareProviderExtraParams({
      provider,
      config: cfg,
      context: {
        config: cfg,
        provider,
        modelId,
        extraParams: merged,
        thinkingLevel,
      },
    }) ?? merged;
  // Base wrapper injecting temperature/maxTokens/toolChoice/etc. as defaults;
  // undefined means there was nothing to inject, so the streamFn is kept as-is.
  const wrappedStreamFn = createStreamFnWithExtraParams(
    agent.streamFn,
    effectiveExtraParams,
    provider,
  );
  if (wrappedStreamFn) {
    log.debug(`applying extraParams to agent streamFn for ${provider}/${modelId}`);
    agent.streamFn = wrappedStreamFn;
  }
  // Anthropic beta headers requested via params/config, if any resolve.
  const anthropicBetas = resolveAnthropicBetas(effectiveExtraParams, provider, modelId);
  if (anthropicBetas?.length) {
    log.debug(
      `applying Anthropic beta header for ${provider}/${modelId}: ${anthropicBetas.join(",")}`,
    );
    agent.streamFn = createAnthropicBetaHeadersWrapper(agent.streamFn, anthropicBetas);
  }
  if (shouldApplySiliconFlowThinkingOffCompat({ provider, modelId, thinkingLevel })) {
    log.debug(
      `normalizing thinking=off to thinking=null for SiliconFlow compatibility (${provider}/${modelId})`,
    );
    agent.streamFn = createSiliconFlowThinkingWrapper(agent.streamFn);
  }
  // Applied unconditionally; the wrapper presumably decides internally whether
  // compat handling is needed — confirm in anthropic-stream-wrappers.
  agent.streamFn = createAnthropicToolPayloadCompatibilityWrapper(agent.streamFn, {
    config: cfg,
    workspaceDir,
  });
  // Give provider plugins a chance to wrap the stream function.
  const providerStreamBase = agent.streamFn;
  const pluginWrappedStreamFn = wrapProviderStreamFn({
    provider,
    config: cfg,
    context: {
      config: cfg,
      provider,
      modelId,
      extraParams: effectiveExtraParams,
      thinkingLevel,
      streamFn: providerStreamBase,
    },
  });
  agent.streamFn = pluginWrappedStreamFn ?? providerStreamBase;
  // A plugin only counts as "handled" when it returned a different function
  // (returning the base function unchanged is treated as a no-op).
  const providerWrapperHandled =
    pluginWrappedStreamFn !== undefined && pluginWrappedStreamFn !== providerStreamBase;
  if (!providerWrapperHandled && shouldApplyMoonshotPayloadCompat({ provider, modelId })) {
    // Preserve the legacy Moonshot compatibility path when no plugin wrapper
    // actually handled the stream function. This covers tests/disabled plugins
    // and Ollama Cloud Kimi models until they gain a dedicated runtime hook.
    const thinkingType = resolveMoonshotThinkingType({
      configuredThinking: effectiveExtraParams?.thinking,
      thinkingLevel,
    });
    agent.streamFn = createMoonshotThinkingWrapper(agent.streamFn, thinkingType);
  }
  // Fast-mode toggles. Note the differing checks: the Anthropic variant wraps
  // on any non-undefined value (so an explicit falsy mode still applies the
  // wrapper), while the OpenAI variant wraps only on a truthy result.
  const anthropicFastMode = resolveAnthropicFastMode(effectiveExtraParams);
  if (anthropicFastMode !== undefined) {
    log.debug(`applying Anthropic fast mode=${anthropicFastMode} for ${provider}/${modelId}`);
    agent.streamFn = createAnthropicFastModeWrapper(agent.streamFn, anthropicFastMode);
  }
  const openAIFastMode = resolveOpenAIFastMode(effectiveExtraParams);
  if (openAIFastMode) {
    log.debug(`applying OpenAI fast mode for ${provider}/${modelId}`);
    agent.streamFn = createOpenAIFastModeWrapper(agent.streamFn);
  }
  const openAIServiceTier = resolveOpenAIServiceTier(effectiveExtraParams);
  if (openAIServiceTier) {
    log.debug(`applying OpenAI service_tier=${openAIServiceTier} for ${provider}/${modelId}`);
    agent.streamFn = createOpenAIServiceTierWrapper(agent.streamFn, openAIServiceTier);
  }
  // Work around upstream pi-ai hardcoding `store: false` for Responses API.
  // Force `store=true` for direct OpenAI Responses models and auto-enable
  // server-side compaction for compatible OpenAI Responses payloads.
  agent.streamFn = createOpenAIResponsesContextManagementWrapper(
    agent.streamFn,
    effectiveExtraParams,
  );
  // NOTE(review): parallel_tool_calls is re-resolved from the raw pre-plugin
  // sources (resolvedExtraParams + override), not from effectiveExtraParams —
  // presumably deliberate so a plugin rewrite cannot alter injection; confirm.
  const rawParallelToolCalls = resolveAliasedParamValue(
    [resolvedExtraParams, override],
    "parallel_tool_calls",
    "parallelToolCalls",
  );
  if (rawParallelToolCalls !== undefined) {
    if (typeof rawParallelToolCalls === "boolean") {
      agent.streamFn = createParallelToolCallsWrapper(agent.streamFn, rawParallelToolCalls);
    } else if (rawParallelToolCalls === null) {
      // An explicit null suppresses injection entirely (distinct from false,
      // which actively injects parallel_tool_calls=false).
      log.debug("parallel_tool_calls suppressed by null override, skipping injection");
    } else {
      // Any other type is rejected; summarize strings verbatim, others by type.
      const summary =
        typeof rawParallelToolCalls === "string"
          ? rawParallelToolCalls
          : typeof rawParallelToolCalls;
      log.warn(`ignoring invalid parallel_tool_calls param: ${summary}`);
    }
  }
}