import type { StreamFn } from "@mariozechner/pi-agent-core";
import { streamSimple } from "@mariozechner/pi-ai";

/**
 * Removes the boolean `strict` flag from an OpenAI-style tool definition
 * (`{ function: { strict, ... } }`).
 *
 * The input is never mutated. When a boolean `strict` is present, shallow
 * copies of the tool and its `function` object are returned without it.
 * Anything else is returned unchanged (same reference): a non-object, a tool
 * without an object-valued `function`, or a tool whose `strict` is not a
 * boolean.
 */
function stripUnsupportedStrictFlag(tool: unknown): unknown {
  if (!tool || typeof tool !== "object") {
    return tool;
  }
  const toolObj = tool as Record<string, unknown>;
  const fn = toolObj.function;
  if (!fn || typeof fn !== "object") {
    return tool;
  }
  const fnObj = fn as Record<string, unknown>;
  if (typeof fnObj.strict !== "boolean") {
    return tool;
  }
  const nextFunction = { ...fnObj };
  delete nextFunction.strict;
  return { ...toolObj, function: nextFunction };
}

/**
 * Wraps a stream function so the `strict` flag is removed from every entry of
 * an outgoing payload's `tools` array before any caller-supplied `onPayload`
 * hook runs.
 *
 * NOTE(review): the name implies xAI rejects `strict` on tool definitions —
 * confirm against the xAI API documentation.
 *
 * @param baseStreamFn - Stream function to delegate to; `streamSimple` when undefined.
 * @returns A stream function with the same call signature as the delegate.
 */
export function createXaiToolPayloadCompatibilityWrapper(
  baseStreamFn: StreamFn | undefined,
): StreamFn {
  const underlying = baseStreamFn ?? streamSimple;
  return (model, context, options) => {
    const originalOnPayload = options?.onPayload;
    return underlying(model, context, {
      ...options,
      onPayload: (payload) => {
        if (payload && typeof payload === "object") {
          const payloadObj = payload as Record<string, unknown>;
          if (Array.isArray(payloadObj.tools)) {
            // Replaces `tools` on the payload object itself, so the caller's hook
            // (if any) already sees the sanitized array.
            payloadObj.tools = payloadObj.tools.map((tool) => stripUnsupportedStrictFlag(tool));
          }
        }
        return originalOnPayload?.(payload, model);
      },
    });
  };
}

/** Characters restored from named entities by `decodeHtmlEntities`. */
const NAMED_HTML_ENTITIES: ReadonlyMap<string, string> = new Map([
  ["quot", '"'],
  ["apos", "'"],
  ["lt", "<"],
  ["gt", ">"],
  ["amp", "&"],
]);

/** Code points restored when written as numeric entities (`&#NN;` or `&#xHH;`). */
const DECODED_HTML_CODE_POINTS: ReadonlySet<number> = new Set([
  0x22, // "
  0x27, // '
  0x3c, // <
  0x3e, // >
  0x26, // &
]);

/**
 * Matches one entity. Group 1: named entity, group 2: decimal code point,
 * group 3: hex code point. The named alternatives are listed explicitly so
 * arbitrary `&word;` sequences are never looked up.
 */
const HTML_ENTITY_PATTERN = /&(?:(quot|apos|lt|gt|amp)|#(\d+)|#[xX]([0-9a-fA-F]+));/g;

/**
 * Decodes the HTML entities for `"`, `'`, `<`, `>` and `&` in a single pass.
 *
 * Recognized forms:
 * - the named entities `&quot;`, `&apos;`, `&lt;`, `&gt;` and `&amp;`;
 * - the decimal and hex numeric entities for those same five characters
 *   (e.g. `&#39;`, `&#x27;`).
 *
 * Every other entity is left as-is. Because decoding happens in one pass,
 * exactly one level is removed: `&amp;lt;` becomes `&lt;` and `&amp;#38;`
 * becomes `&#38;`. A chain of sequential `replaceAll` calls can decode such
 * input twice.
 */
function decodeHtmlEntities(value: string): string {
  return value.replace(
    HTML_ENTITY_PATTERN,
    (match: string, named?: string, decimal?: string, hex?: string): string => {
      if (named !== undefined) {
        return NAMED_HTML_ENTITIES.get(named) ?? match;
      }
      const codePoint =
        decimal !== undefined ? Number.parseInt(decimal, 10) : Number.parseInt(hex ?? "", 16);
      return DECODED_HTML_CODE_POINTS.has(codePoint) ? String.fromCharCode(codePoint) : match;
    },
  );
}

/**
 * Recursively decodes HTML entities in every string contained in `value`.
 *
 * - Strings: a decoded copy is returned.
 * - Arrays: mapped into a new array.
 * - Other objects: their own enumerable properties are overwritten in place,
 *   and the same object is returned.
 * - Everything else: returned unchanged.
 *
 * Property keys are not decoded.
 */
function decodeHtmlEntitiesInObject(value: unknown): unknown {
  if (typeof value === "string") {
    return decodeHtmlEntities(value);
  }
  if (!value || typeof value !== "object") {
    return value;
  }
  if (Array.isArray(value)) {
    return value.map((entry) => decodeHtmlEntitiesInObject(entry));
  }
  const record = value as Record<string, unknown>;
  for (const [key, entry] of Object.entries(record)) {
    record[key] = decodeHtmlEntitiesInObject(entry);
  }
  return record;
}

/**
 * Decodes HTML entities in the object-valued `arguments` of every
 * `type: "toolCall"` block in `message.content`, mutating the message in place.
 *
 * Non-object messages, a non-array `content`, non-object blocks, and falsy or
 * non-object `arguments` (e.g. a raw string) are skipped.
 */
function decodeXaiToolCallArgumentsInMessage(message: unknown): void {
  if (!message || typeof message !== "object") {
    return;
  }
  const content = (message as { content?: unknown }).content;
  if (!Array.isArray(content)) {
    return;
  }
  for (const block of content) {
    if (!block || typeof block !== "object") {
      continue;
    }
    const typedBlock = block as { type?: unknown; arguments?: unknown };
    if (typedBlock.type !== "toolCall" || !typedBlock.arguments) {
      continue;
    }
    if (typeof typedBlock.arguments === "object") {
      typedBlock.arguments = decodeHtmlEntitiesInObject(typedBlock.arguments);
    }
  }
}

/**
 * Patches `stream` in place so that both `result()` and async iteration
 * surface messages whose tool-call arguments have had HTML entities decoded.
 *
 * @returns The same stream object, patched.
 */
function wrapStreamDecodeXaiToolCallArguments(
  stream: ReturnType<typeof streamSimple>,
): ReturnType<typeof streamSimple> {
  const originalResult = stream.result.bind(stream);
  stream.result = async () => {
    const message = await originalResult();
    decodeXaiToolCallArgumentsInMessage(message);
    return message;
  };

  const originalAsyncIterator = stream[Symbol.asyncIterator].bind(stream);
  (stream as { [Symbol.asyncIterator]: typeof originalAsyncIterator })[Symbol.asyncIterator] =
    function () {
      const iterator = originalAsyncIterator();
      return {
        async next() {
          const result = await iterator.next();
          if (!result.done && result.value && typeof result.value === "object") {
            const event = result.value as { partial?: unknown; message?: unknown };
            // NOTE(review): decoding mutates messages in place. If a provider reuses
            // the same message object across events (and for `result()`), arguments
            // that were already decoded get decoded again, which would turn a
            // legitimately decoded `&lt;` into `<`. Confirm provider behavior.
            decodeXaiToolCallArgumentsInMessage(event.partial);
            decodeXaiToolCallArgumentsInMessage(event.message);
          }
          return result;
        },
        async return(value?: unknown) {
          return iterator.return?.(value) ?? { done: true as const, value: undefined };
        },
        async throw(error?: unknown) {
          // Propagate the error when the inner iterator cannot handle it, instead
          // of silently reporting completion.
          if (iterator.throw) {
            return iterator.throw(error);
          }
          throw error;
        },
      };
    };

  return stream;
}

/**
 * Wraps a stream function so that tool-call arguments in the streamed events
 * and in the final result have HTML entities decoded.
 *
 * NOTE(review): presumably xAI returns HTML-escaped tool-call arguments —
 * confirm this is the motivation.
 *
 * Both synchronous and promise-returning (thenable) stream functions are
 * supported. In the thenable case, the stream is patched once it resolves.
 */
export function createXaiToolCallArgumentDecodingWrapper(baseStreamFn: StreamFn): StreamFn {
  return (model, context, options) => {
    const maybeStream = baseStreamFn(model, context, options);
    if (maybeStream && typeof maybeStream === "object" && "then" in maybeStream) {
      return Promise.resolve(maybeStream).then((stream) =>
        wrapStreamDecodeXaiToolCallArguments(stream),
      );
    }
    return wrapStreamDecodeXaiToolCallArguments(maybeStream);
  };
}