diff --git a/src/agents/gigachat-stream.leaked-prelude.test.ts b/src/agents/gigachat-stream.leaked-prelude.test.ts index e41c852f4d2..8f18f37672c 100644 --- a/src/agents/gigachat-stream.leaked-prelude.test.ts +++ b/src/agents/gigachat-stream.leaked-prelude.test.ts @@ -54,7 +54,7 @@ describe("GigaChat leaked function-call prelude cleanup", () => { const event = await stream.result(); - expect(updateToken).toHaveBeenCalled(); + expect(updateToken).not.toHaveBeenCalled(); expect(event.stopReason).toBe("toolUse"); expect(event.content).toEqual([ expect.objectContaining({ diff --git a/src/agents/gigachat-stream.tool-calls.test.ts b/src/agents/gigachat-stream.tool-calls.test.ts index 873cc02e4e7..a724a471c4c 100644 --- a/src/agents/gigachat-stream.tool-calls.test.ts +++ b/src/agents/gigachat-stream.tool-calls.test.ts @@ -1,16 +1,23 @@ import { Readable } from "node:stream"; import { beforeEach, describe, expect, it, vi } from "vitest"; -const updateToken = vi.fn(async () => {}); +let initialAccessToken: { access_token: string } | undefined = { access_token: "test-token" }; +let refreshedAccessToken = "refreshed-token"; +const updateToken = vi.fn(async function (this: { _accessToken?: { access_token: string } }) { + this._accessToken = { access_token: refreshedAccessToken }; +}); const request = vi.fn(); const clientConfigs: Array> = []; vi.mock("gigachat", () => { class MockGigaChat { _client = { request }; - _accessToken = { access_token: "test-token" }; + _accessToken = initialAccessToken ? { ...initialAccessToken } : undefined; updateToken = updateToken; + resetToken() { + this._accessToken = undefined; + } constructor(config: Record) { clientConfigs.push(config); @@ -34,6 +41,8 @@ describe("createGigachatStreamFn tool calling", () => { beforeEach(() => { vi.clearAllMocks(); clientConfigs.length = 0; + initialAccessToken = { access_token: "test-token" }; + refreshedAccessToken = "refreshed-token"; vi.unstubAllEnvs(); }); @@ -74,7 +83,7 @@ describe("createGigachatStreamFn tool calling", () => { const event = await stream.result(); - expect(updateToken).toHaveBeenCalled(); + expect(updateToken).not.toHaveBeenCalled(); expect(request).toHaveBeenCalledWith( expect.objectContaining({ data: expect.objectContaining({ @@ -178,6 +187,117 @@ describe("createGigachatStreamFn tool calling", () => { expect(event.content).toEqual([{ type: "text", text: "final tail" }]); }); + it("reuses a cached token across turns for the same GigaChat credentials", async () => { + initialAccessToken = undefined; + refreshedAccessToken = "cached-after-refresh"; + request + .mockResolvedValueOnce({ + status: 200, + data: createSseStream([ + 'data: {"choices":[{"delta":{"content":"first"}}]}', + "data: [DONE]", + ]), + }) + .mockResolvedValueOnce({ + status: 200, + data: createSseStream([ + 'data: {"choices":[{"delta":{"content":"second"}}]}', + "data: [DONE]", + ]), + }); + + const streamFn = createGigachatStreamFn({ + baseUrl: "https://gigachat.devices.sberbank.ru/api/v1", + authMode: "oauth", + }); + + const firstStream = await streamFn( + { api: "gigachat", provider: "gigachat", id: "GigaChat-2-Max" } as never, + { messages: [], tools: [] } as never, + { apiKey: "token" } as never, + ); + await expect(firstStream.result()).resolves.toMatchObject({ + content: [{ type: "text", text: "first" }], + }); + + const secondStream = await streamFn( + { api: "gigachat", provider: "gigachat", id: "GigaChat-2-Max" } as never, + { messages: [], tools: [] } as never, + { apiKey: "token" } as never, + ); + await expect(secondStream.result()).resolves.toMatchObject({ + content: [{ type: "text", text: "second" }], + }); + + expect(updateToken).toHaveBeenCalledTimes(1); + expect(clientConfigs).toHaveLength(1); + expect(request).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: "Bearer cached-after-refresh", + }), + }), + ); + expect(request).toHaveBeenNthCalledWith( + 2, + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: "Bearer cached-after-refresh", + }), + }), + ); + }); + + it("refreshes once and retries the chat request after a 401", async () => { + refreshedAccessToken = "fresh-token"; + request + .mockResolvedValueOnce({ + status: 401, + data: "expired token", + }) + .mockResolvedValueOnce({ + status: 200, + data: createSseStream([ + 'data: {"choices":[{"delta":{"content":"recovered"}}]}', + "data: [DONE]", + ]), + }); + + const streamFn = createGigachatStreamFn({ + baseUrl: "https://gigachat.devices.sberbank.ru/api/v1", + authMode: "oauth", + }); + + const stream = await streamFn( + { api: "gigachat", provider: "gigachat", id: "GigaChat-2-Max" } as never, + { messages: [], tools: [] } as never, + { apiKey: "token" } as never, + ); + + await expect(stream.result()).resolves.toMatchObject({ + content: [{ type: "text", text: "recovered" }], + }); + + expect(updateToken).toHaveBeenCalledTimes(1); + expect(request).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: "Bearer test-token", + }), + }), + ); + expect(request).toHaveBeenNthCalledWith( + 2, + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: "Bearer fresh-token", + }), + }), + ); + }); + it("prefers the resolved GigaChat baseUrl over the env override", async () => { vi.stubEnv("GIGACHAT_BASE_URL", "https://env-host.example/api/v1"); request.mockResolvedValueOnce({ diff --git a/src/agents/gigachat-stream.ts b/src/agents/gigachat-stream.ts index 3eb640a1802..7ab42a27381 100644 --- a/src/agents/gigachat-stream.ts +++ b/src/agents/gigachat-stream.ts @@ -440,6 +440,80 @@ async function withRetry( throw lastError; } +type GigachatAccessToken = { + access_token?: string; +}; + +type GigachatTransportResponse = { + status: number; + data: AsyncIterable | string | { pipe?: unknown }; +}; + +type GigachatRuntimeClient = GigaChat & { + _client: { + request: (config: { + method: "POST"; + url: string; + data: Chat & { stream: true }; + responseType: "stream"; + headers: Record; + signal?: AbortSignal; + }) => Promise; + }; + _accessToken?: GigachatAccessToken; + updateToken: () => Promise; + resetToken?: () => void; +}; + +function getGigachatAccessToken(client: GigachatRuntimeClient): string | undefined { + return client._accessToken?.access_token?.trim() || undefined; +} + +async function ensureGigachatAccessToken(client: GigachatRuntimeClient): Promise { + const accessToken = getGigachatAccessToken(client); + if (accessToken) { + return accessToken; + } + + await withRetry(() => client.updateToken(), "token refresh"); + + const refreshedToken = getGigachatAccessToken(client); + if (!refreshedToken) { + throw new Error("GigaChat: failed to obtain access token after retries"); + } + return refreshedToken; +} + +function resetGigachatAccessToken(client: GigachatRuntimeClient): void { + if (typeof client.resetToken === "function") { + client.resetToken(); + return; + } + delete client._accessToken; +} + +async function readGigachatErrorText( + responseData: GigachatTransportResponse["data"], + status: number, +): Promise { + try { + if (typeof responseData === "string") { + return responseData; + } + if (responseData && typeof responseData === "object" && "pipe" in responseData) { + const chunks: Buffer[] = []; + for await (const chunk of responseData as AsyncIterable) { + chunks.push(typeof chunk === "string" ? Buffer.from(chunk) : Buffer.from(chunk)); + } + return Buffer.concat(chunks).toString(); + } + } catch { + return `status ${status}`; + } + + return "unknown error"; +} + // ── Stream function ───────────────────────────────────────────────────────── export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn { @@ -460,6 +534,45 @@ export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn { ); } + let cachedClient: GigachatRuntimeClient | null = null; + let cachedApiKey: string | null = null; + + const buildClientConfig = (apiKey: string): GigaChatClientConfig => { + const clientConfig: GigaChatClientConfig = { + baseUrl: effectiveBaseUrl, + // Explicitly set to undefined to prevent the library from adding profanity_check + profanityCheck: undefined, + timeout: 120, + }; + + if (insecureTls) { + clientConfig.httpsAgent = new https.Agent({ rejectUnauthorized: false }); + } + + if (opts.authMode === "basic") { + const { user, password } = parseGigachatBasicCredentials(apiKey); + clientConfig.user = user; + clientConfig.password = password; + log.debug(`GigaChat auth: basic mode`); + } else { + clientConfig.credentials = apiKey; + clientConfig.scope = opts.scope ?? "GIGACHAT_API_PERS"; + log.debug(`GigaChat auth: oauth scope=${clientConfig.scope}`); + } + + return clientConfig; + }; + + const getClientForApiKey = (apiKey: string): GigachatRuntimeClient => { + if (cachedClient && cachedApiKey === apiKey) { + return cachedClient; + } + + cachedClient = new GigaChat(buildClientConfig(apiKey)) as GigachatRuntimeClient; + cachedApiKey = apiKey; + return cachedClient; + }; + return (model, context, options) => { const stream = createAssistantMessageEventStream(); @@ -596,32 +709,7 @@ export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn { // Build auth config const apiKey = options?.apiKey ?? ""; - - const clientConfig: GigaChatClientConfig = { - baseUrl: effectiveBaseUrl, - // Explicitly set to undefined to prevent the library from adding profanity_check - profanityCheck: undefined, - timeout: 120, - }; - - // Configure TLS - if (insecureTls) { - clientConfig.httpsAgent = new https.Agent({ rejectUnauthorized: false }); - } - - // Set credentials based on auth mode - if (opts.authMode === "basic") { - const { user, password } = parseGigachatBasicCredentials(apiKey); - clientConfig.user = user; - clientConfig.password = password; - log.debug(`GigaChat auth: basic mode`); - } else { - clientConfig.credentials = apiKey; - clientConfig.scope = opts.scope ?? "GIGACHAT_API_PERS"; - log.debug(`GigaChat auth: oauth scope=${clientConfig.scope}`); - } - - const client = new GigaChat(clientConfig); + const client = getClientForApiKey(apiKey); // Build chat request - explicitly omit profanity_check const chatRequest: Chat = { @@ -644,56 +732,41 @@ export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn { log.debug(`GigaChat request: ${messages.length} messages, ${functions.length} functions`); - // Use the library for auth, but our own SSE parsing (library's parseChunk is buggy) - // Wrap token refresh in retry logic for transient failures - await withRetry(() => client.updateToken(), "token refresh"); - - const axiosClient = client._client; - // Access the token (protected property, so we cast) - const accessToken = (client as unknown as { _accessToken?: { access_token: string } }) - ._accessToken?.access_token; - - if (!accessToken) { - throw new Error("GigaChat: failed to obtain access token after retries"); - } - const requestId = randomUUID(); log.debug(`GigaChat request ${requestId}: starting`); - const headers: Record = { - ...resolveGigachatModelHeaders(model), - ...options?.headers, - Authorization: `Bearer ${accessToken}`, - Accept: "text/event-stream", - "Cache-Control": "no-store", - "X-Request-ID": requestId, + const axiosClient = client._client; + const sendChatCompletionsRequest = async (): Promise => { + const accessToken = await ensureGigachatAccessToken(client); + return axiosClient.request({ + method: "POST", + url: "/chat/completions", + data: { ...chatRequest, stream: true }, + responseType: "stream", + headers: { + ...resolveGigachatModelHeaders(model), + ...options?.headers, + Authorization: `Bearer ${accessToken}`, + Accept: "text/event-stream", + "Cache-Control": "no-store", + "X-Request-ID": requestId, + }, + signal: options?.signal, + }); }; - const response = await axiosClient.request({ - method: "POST", - url: "/chat/completions", - data: { ...chatRequest, stream: true }, - responseType: "stream", - headers, - signal: options?.signal, - }); + let response = await sendChatCompletionsRequest(); + if (response.status === 401) { + log.warn( + `GigaChat request ${requestId}: received 401 from chat endpoint, refreshing token and retrying`, + ); + resetGigachatAccessToken(client); + await ensureGigachatAccessToken(client); + response = await sendChatCompletionsRequest(); + } if (response.status !== 200) { - let errorText = "unknown error"; - try { - if (typeof response.data === "string") { - errorText = response.data; - } else if (response.data && typeof response.data.pipe === "function") { - // It's a stream, try to read it - const chunks: Buffer[] = []; - for await (const chunk of response.data) { - chunks.push(chunk); - } - errorText = Buffer.concat(chunks).toString(); - } - } catch { - errorText = `status ${response.status}`; - } + const errorText = await readGigachatErrorText(response.data, response.status); throw new Error( `GigaChat API error ${response.status} (${effectiveBaseUrl}/chat/completions): ${errorText}`, );