GigaChat: preserve streamed tool call sequences
This commit is contained in:
parent
52e371fa33
commit
daf9cad38f
@ -163,6 +163,71 @@ describe("createGigachatStreamFn tool calling", () => {
|
||||
]);
|
||||
});
|
||||
|
||||
it("preserves every streamed function call from a single assistant turn", async () => {
|
||||
request.mockResolvedValueOnce({
|
||||
status: 200,
|
||||
data: createSseStream([
|
||||
'data: {"choices":[{"delta":{"function_call":{"name":"llm_"}}}]}',
|
||||
'data: {"choices":[{"delta":{"function_call":{"name":"task"}}}]}',
|
||||
'data: {"choices":[{"delta":{"function_call":{"arguments":"{\\"prompt\\":\\"first\\"}"}},"finish_reason":"function_call"}]}',
|
||||
'data: {"choices":[{"delta":{"function_call":{"name":"__gpt2giga_user_search_web"}}}]}',
|
||||
'data: {"choices":[{"delta":{"function_call":{"arguments":"{\\"query\\":\\"second\\"}"}}}]}',
|
||||
"data: [DONE]",
|
||||
]),
|
||||
});
|
||||
|
||||
const streamFn = createGigachatStreamFn({
|
||||
baseUrl: "https://gigachat.devices.sberbank.ru/api/v1",
|
||||
authMode: "oauth",
|
||||
});
|
||||
|
||||
const stream = await streamFn(
|
||||
{ api: "gigachat", provider: "gigachat", id: "GigaChat-2-Max" } as never,
|
||||
{
|
||||
messages: [],
|
||||
tools: [
|
||||
{
|
||||
name: "llm-task",
|
||||
description: "Run a task",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
prompt: { type: "string" },
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "web_search",
|
||||
description: "Search the web",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
query: { type: "string" },
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
} as never,
|
||||
{ apiKey: "token" } as never,
|
||||
);
|
||||
|
||||
const event = await stream.result();
|
||||
|
||||
expect(event.stopReason).toBe("toolUse");
|
||||
expect(event.content).toEqual([
|
||||
expect.objectContaining({
|
||||
type: "toolCall",
|
||||
name: "llm-task",
|
||||
arguments: { prompt: "first" },
|
||||
}),
|
||||
expect.objectContaining({
|
||||
type: "toolCall",
|
||||
name: "web_search",
|
||||
arguments: { query: "second" },
|
||||
}),
|
||||
]);
|
||||
});
|
||||
|
||||
it("parses a final SSE frame even when the stream closes without a trailing newline", async () => {
|
||||
request.mockResolvedValueOnce({
|
||||
status: 200,
|
||||
|
||||
@ -774,6 +774,7 @@ export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn {
|
||||
|
||||
let accumulatedContent = "";
|
||||
const accumulatedToolCalls: ToolCall[] = [];
|
||||
const resolvedFunctionCalls: Array<{ name: string; arguments: string }> = [];
|
||||
let functionCallBuffer: { name: string; arguments: string } | null = null;
|
||||
let promptTokens = 0;
|
||||
let completionTokens = 0;
|
||||
@ -782,6 +783,14 @@ export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn {
|
||||
// UTF-8 code points intact when TCP chunks split multibyte characters.
|
||||
let sseBuffer = "";
|
||||
const sseDecoder = new TextDecoder();
|
||||
const flushFunctionCallBuffer = () => {
|
||||
if (!functionCallBuffer?.name) {
|
||||
functionCallBuffer = null;
|
||||
return;
|
||||
}
|
||||
resolvedFunctionCalls.push(functionCallBuffer);
|
||||
functionCallBuffer = null;
|
||||
};
|
||||
const consumeSseLine = (line: string) => {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed || trimmed.startsWith(":")) {
|
||||
@ -801,6 +810,11 @@ export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn {
|
||||
accumulatedContent += choice.delta.content;
|
||||
}
|
||||
if (choice?.delta?.function_call) {
|
||||
if (choice.delta.function_call.name && functionCallBuffer?.arguments) {
|
||||
// A new tool name after arguments indicates the previous streamed
|
||||
// function call is complete and a new call has begun.
|
||||
flushFunctionCallBuffer();
|
||||
}
|
||||
if (!functionCallBuffer) {
|
||||
functionCallBuffer = { name: "", arguments: "" };
|
||||
}
|
||||
@ -813,6 +827,9 @@ export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn {
|
||||
typeof args === "string" ? args : JSON.stringify(args);
|
||||
}
|
||||
}
|
||||
if (choice?.finish_reason === "function_call") {
|
||||
flushFunctionCallBuffer();
|
||||
}
|
||||
if (parsed.usage) {
|
||||
promptTokens = parsed.usage.prompt_tokens ?? 0;
|
||||
completionTokens = parsed.usage.completion_tokens ?? 0;
|
||||
@ -839,38 +856,37 @@ export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn {
|
||||
consumeSseLine(sseBuffer);
|
||||
}
|
||||
|
||||
const resolvedFunctionCall = functionCallBuffer as unknown as {
|
||||
name: string;
|
||||
arguments: string;
|
||||
} | null;
|
||||
if (resolvedFunctionCall && resolvedFunctionCall.name) {
|
||||
flushFunctionCallBuffer();
|
||||
if (resolvedFunctionCalls.length > 0) {
|
||||
accumulatedContent = stripLeakedFunctionCallPrelude(accumulatedContent);
|
||||
let parsedArgs: Record<string, unknown> = {};
|
||||
try {
|
||||
if (resolvedFunctionCall.arguments) {
|
||||
parsedArgs = JSON.parse(resolvedFunctionCall.arguments) as Record<string, unknown>;
|
||||
for (const resolvedFunctionCall of resolvedFunctionCalls) {
|
||||
let parsedArgs: Record<string, unknown> = {};
|
||||
try {
|
||||
if (resolvedFunctionCall.arguments) {
|
||||
parsedArgs = JSON.parse(resolvedFunctionCall.arguments) as Record<string, unknown>;
|
||||
}
|
||||
} catch (parseErr) {
|
||||
const errMsg = parseErr instanceof Error ? parseErr.message : String(parseErr);
|
||||
log.error(
|
||||
`GigaChat: failed to parse function arguments for "${resolvedFunctionCall.name}": ${errMsg}. ` +
|
||||
`Raw arguments: ${resolvedFunctionCall.arguments.slice(0, 500)}`,
|
||||
);
|
||||
// Return error instead of continuing with empty args
|
||||
throw new Error(
|
||||
`Failed to parse function call arguments for "${resolvedFunctionCall.name}": ${errMsg}`,
|
||||
{ cause: parseErr },
|
||||
);
|
||||
}
|
||||
} catch (parseErr) {
|
||||
const errMsg = parseErr instanceof Error ? parseErr.message : String(parseErr);
|
||||
log.error(
|
||||
`GigaChat: failed to parse function arguments for "${resolvedFunctionCall.name}": ${errMsg}. ` +
|
||||
`Raw arguments: ${resolvedFunctionCall.arguments.slice(0, 500)}`,
|
||||
);
|
||||
// Return error instead of continuing with empty args
|
||||
throw new Error(
|
||||
`Failed to parse function call arguments for "${resolvedFunctionCall.name}": ${errMsg}`,
|
||||
{ cause: parseErr },
|
||||
);
|
||||
const clientName =
|
||||
gigaToToolName.get(resolvedFunctionCall.name) ??
|
||||
mapToolNameFromGigaChat(resolvedFunctionCall.name);
|
||||
accumulatedToolCalls.push({
|
||||
type: "toolCall",
|
||||
id: randomUUID(),
|
||||
name: clientName,
|
||||
arguments: parsedArgs,
|
||||
});
|
||||
}
|
||||
const clientName =
|
||||
gigaToToolName.get(resolvedFunctionCall.name) ??
|
||||
mapToolNameFromGigaChat(resolvedFunctionCall.name);
|
||||
accumulatedToolCalls.push({
|
||||
type: "toolCall",
|
||||
id: randomUUID(),
|
||||
name: clientName,
|
||||
arguments: parsedArgs,
|
||||
});
|
||||
}
|
||||
|
||||
const content: AssistantMessage["content"] = [];
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user