GigaChat: reuse cached auth tokens

This commit is contained in:
Alexander Davydov 2026-03-18 17:38:54 +03:00
parent 6f6bf7eff1
commit db44a27801
3 changed files with 266 additions and 73 deletions

View File

@ -54,7 +54,7 @@ describe("GigaChat leaked function-call prelude cleanup", () => {
const event = await stream.result();
expect(updateToken).toHaveBeenCalled();
expect(updateToken).not.toHaveBeenCalled();
expect(event.stopReason).toBe("toolUse");
expect(event.content).toEqual([
expect.objectContaining({

View File

@ -1,16 +1,23 @@
import { Readable } from "node:stream";
import { beforeEach, describe, expect, it, vi } from "vitest";
const updateToken = vi.fn(async () => {});
let initialAccessToken: { access_token: string } | undefined = { access_token: "test-token" };
let refreshedAccessToken = "refreshed-token";
const updateToken = vi.fn(async function (this: { _accessToken?: { access_token: string } }) {
this._accessToken = { access_token: refreshedAccessToken };
});
const request = vi.fn();
const clientConfigs: Array<Record<string, unknown>> = [];
vi.mock("gigachat", () => {
class MockGigaChat {
_client = { request };
_accessToken = { access_token: "test-token" };
_accessToken = initialAccessToken ? { ...initialAccessToken } : undefined;
updateToken = updateToken;
resetToken() {
this._accessToken = undefined;
}
constructor(config: Record<string, unknown>) {
clientConfigs.push(config);
@ -34,6 +41,8 @@ describe("createGigachatStreamFn tool calling", () => {
beforeEach(() => {
vi.clearAllMocks();
clientConfigs.length = 0;
initialAccessToken = { access_token: "test-token" };
refreshedAccessToken = "refreshed-token";
vi.unstubAllEnvs();
});
@ -74,7 +83,7 @@ describe("createGigachatStreamFn tool calling", () => {
const event = await stream.result();
expect(updateToken).toHaveBeenCalled();
expect(updateToken).not.toHaveBeenCalled();
expect(request).toHaveBeenCalledWith(
expect.objectContaining({
data: expect.objectContaining({
@ -178,6 +187,117 @@ describe("createGigachatStreamFn tool calling", () => {
expect(event.content).toEqual([{ type: "text", text: "final tail" }]);
});
it("reuses a cached token across turns for the same GigaChat credentials", async () => {
initialAccessToken = undefined;
refreshedAccessToken = "cached-after-refresh";
request
.mockResolvedValueOnce({
status: 200,
data: createSseStream([
'data: {"choices":[{"delta":{"content":"first"}}]}',
"data: [DONE]",
]),
})
.mockResolvedValueOnce({
status: 200,
data: createSseStream([
'data: {"choices":[{"delta":{"content":"second"}}]}',
"data: [DONE]",
]),
});
const streamFn = createGigachatStreamFn({
baseUrl: "https://gigachat.devices.sberbank.ru/api/v1",
authMode: "oauth",
});
const firstStream = await streamFn(
{ api: "gigachat", provider: "gigachat", id: "GigaChat-2-Max" } as never,
{ messages: [], tools: [] } as never,
{ apiKey: "token" } as never,
);
await expect(firstStream.result()).resolves.toMatchObject({
content: [{ type: "text", text: "first" }],
});
const secondStream = await streamFn(
{ api: "gigachat", provider: "gigachat", id: "GigaChat-2-Max" } as never,
{ messages: [], tools: [] } as never,
{ apiKey: "token" } as never,
);
await expect(secondStream.result()).resolves.toMatchObject({
content: [{ type: "text", text: "second" }],
});
expect(updateToken).toHaveBeenCalledTimes(1);
expect(clientConfigs).toHaveLength(1);
expect(request).toHaveBeenNthCalledWith(
1,
expect.objectContaining({
headers: expect.objectContaining({
Authorization: "Bearer cached-after-refresh",
}),
}),
);
expect(request).toHaveBeenNthCalledWith(
2,
expect.objectContaining({
headers: expect.objectContaining({
Authorization: "Bearer cached-after-refresh",
}),
}),
);
});
it("refreshes once and retries the chat request after a 401", async () => {
refreshedAccessToken = "fresh-token";
request
.mockResolvedValueOnce({
status: 401,
data: "expired token",
})
.mockResolvedValueOnce({
status: 200,
data: createSseStream([
'data: {"choices":[{"delta":{"content":"recovered"}}]}',
"data: [DONE]",
]),
});
const streamFn = createGigachatStreamFn({
baseUrl: "https://gigachat.devices.sberbank.ru/api/v1",
authMode: "oauth",
});
const stream = await streamFn(
{ api: "gigachat", provider: "gigachat", id: "GigaChat-2-Max" } as never,
{ messages: [], tools: [] } as never,
{ apiKey: "token" } as never,
);
await expect(stream.result()).resolves.toMatchObject({
content: [{ type: "text", text: "recovered" }],
});
expect(updateToken).toHaveBeenCalledTimes(1);
expect(request).toHaveBeenNthCalledWith(
1,
expect.objectContaining({
headers: expect.objectContaining({
Authorization: "Bearer test-token",
}),
}),
);
expect(request).toHaveBeenNthCalledWith(
2,
expect.objectContaining({
headers: expect.objectContaining({
Authorization: "Bearer fresh-token",
}),
}),
);
});
it("prefers the resolved GigaChat baseUrl over the env override", async () => {
vi.stubEnv("GIGACHAT_BASE_URL", "https://env-host.example/api/v1");
request.mockResolvedValueOnce({

View File

@ -440,6 +440,80 @@ async function withRetry<T>(
throw lastError;
}
type GigachatAccessToken = {
access_token?: string;
};
type GigachatTransportResponse = {
status: number;
data: AsyncIterable<string | Uint8Array> | string | { pipe?: unknown };
};
type GigachatRuntimeClient = GigaChat & {
_client: {
request: (config: {
method: "POST";
url: string;
data: Chat & { stream: true };
responseType: "stream";
headers: Record<string, string>;
signal?: AbortSignal;
}) => Promise<GigachatTransportResponse>;
};
_accessToken?: GigachatAccessToken;
updateToken: () => Promise<void>;
resetToken?: () => void;
};
function getGigachatAccessToken(client: GigachatRuntimeClient): string | undefined {
return client._accessToken?.access_token?.trim() || undefined;
}
async function ensureGigachatAccessToken(client: GigachatRuntimeClient): Promise<string> {
const accessToken = getGigachatAccessToken(client);
if (accessToken) {
return accessToken;
}
await withRetry(() => client.updateToken(), "token refresh");
const refreshedToken = getGigachatAccessToken(client);
if (!refreshedToken) {
throw new Error("GigaChat: failed to obtain access token after retries");
}
return refreshedToken;
}
function resetGigachatAccessToken(client: GigachatRuntimeClient): void {
if (typeof client.resetToken === "function") {
client.resetToken();
return;
}
delete client._accessToken;
}
async function readGigachatErrorText(
responseData: GigachatTransportResponse["data"],
status: number,
): Promise<string> {
try {
if (typeof responseData === "string") {
return responseData;
}
if (responseData && typeof responseData === "object" && "pipe" in responseData) {
const chunks: Buffer[] = [];
for await (const chunk of responseData as AsyncIterable<string | Uint8Array>) {
chunks.push(typeof chunk === "string" ? Buffer.from(chunk) : Buffer.from(chunk));
}
return Buffer.concat(chunks).toString();
}
} catch {
return `status ${status}`;
}
return "unknown error";
}
// ── Stream function ─────────────────────────────────────────────────────────
export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn {
@ -460,6 +534,45 @@ export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn {
);
}
let cachedClient: GigachatRuntimeClient | null = null;
let cachedApiKey: string | null = null;
const buildClientConfig = (apiKey: string): GigaChatClientConfig => {
const clientConfig: GigaChatClientConfig = {
baseUrl: effectiveBaseUrl,
// Explicitly set to undefined to prevent the library from adding profanity_check
profanityCheck: undefined,
timeout: 120,
};
if (insecureTls) {
clientConfig.httpsAgent = new https.Agent({ rejectUnauthorized: false });
}
if (opts.authMode === "basic") {
const { user, password } = parseGigachatBasicCredentials(apiKey);
clientConfig.user = user;
clientConfig.password = password;
log.debug(`GigaChat auth: basic mode`);
} else {
clientConfig.credentials = apiKey;
clientConfig.scope = opts.scope ?? "GIGACHAT_API_PERS";
log.debug(`GigaChat auth: oauth scope=${clientConfig.scope}`);
}
return clientConfig;
};
const getClientForApiKey = (apiKey: string): GigachatRuntimeClient => {
if (cachedClient && cachedApiKey === apiKey) {
return cachedClient;
}
cachedClient = new GigaChat(buildClientConfig(apiKey)) as GigachatRuntimeClient;
cachedApiKey = apiKey;
return cachedClient;
};
return (model, context, options) => {
const stream = createAssistantMessageEventStream();
@ -596,32 +709,7 @@ export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn {
// Build auth config
const apiKey = options?.apiKey ?? "";
const clientConfig: GigaChatClientConfig = {
baseUrl: effectiveBaseUrl,
// Explicitly set to undefined to prevent the library from adding profanity_check
profanityCheck: undefined,
timeout: 120,
};
// Configure TLS
if (insecureTls) {
clientConfig.httpsAgent = new https.Agent({ rejectUnauthorized: false });
}
// Set credentials based on auth mode
if (opts.authMode === "basic") {
const { user, password } = parseGigachatBasicCredentials(apiKey);
clientConfig.user = user;
clientConfig.password = password;
log.debug(`GigaChat auth: basic mode`);
} else {
clientConfig.credentials = apiKey;
clientConfig.scope = opts.scope ?? "GIGACHAT_API_PERS";
log.debug(`GigaChat auth: oauth scope=${clientConfig.scope}`);
}
const client = new GigaChat(clientConfig);
const client = getClientForApiKey(apiKey);
// Build chat request - explicitly omit profanity_check
const chatRequest: Chat = {
@ -644,56 +732,41 @@ export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn {
log.debug(`GigaChat request: ${messages.length} messages, ${functions.length} functions`);
// Use the library for auth, but our own SSE parsing (library's parseChunk is buggy)
// Wrap token refresh in retry logic for transient failures
await withRetry(() => client.updateToken(), "token refresh");
const axiosClient = client._client;
// Access the token (protected property, so we cast)
const accessToken = (client as unknown as { _accessToken?: { access_token: string } })
._accessToken?.access_token;
if (!accessToken) {
throw new Error("GigaChat: failed to obtain access token after retries");
}
const requestId = randomUUID();
log.debug(`GigaChat request ${requestId}: starting`);
const headers: Record<string, string> = {
...resolveGigachatModelHeaders(model),
...options?.headers,
Authorization: `Bearer ${accessToken}`,
Accept: "text/event-stream",
"Cache-Control": "no-store",
"X-Request-ID": requestId,
const axiosClient = client._client;
const sendChatCompletionsRequest = async (): Promise<GigachatTransportResponse> => {
const accessToken = await ensureGigachatAccessToken(client);
return axiosClient.request({
method: "POST",
url: "/chat/completions",
data: { ...chatRequest, stream: true },
responseType: "stream",
headers: {
...resolveGigachatModelHeaders(model),
...options?.headers,
Authorization: `Bearer ${accessToken}`,
Accept: "text/event-stream",
"Cache-Control": "no-store",
"X-Request-ID": requestId,
},
signal: options?.signal,
});
};
const response = await axiosClient.request({
method: "POST",
url: "/chat/completions",
data: { ...chatRequest, stream: true },
responseType: "stream",
headers,
signal: options?.signal,
});
let response = await sendChatCompletionsRequest();
if (response.status === 401) {
log.warn(
`GigaChat request ${requestId}: received 401 from chat endpoint, refreshing token and retrying`,
);
resetGigachatAccessToken(client);
await ensureGigachatAccessToken(client);
response = await sendChatCompletionsRequest();
}
if (response.status !== 200) {
let errorText = "unknown error";
try {
if (typeof response.data === "string") {
errorText = response.data;
} else if (response.data && typeof response.data.pipe === "function") {
// It's a stream, try to read it
const chunks: Buffer[] = [];
for await (const chunk of response.data) {
chunks.push(chunk);
}
errorText = Buffer.concat(chunks).toString();
}
} catch {
errorText = `status ${response.status}`;
}
const errorText = await readGigachatErrorText(response.data, response.status);
throw new Error(
`GigaChat API error ${response.status} (${effectiveBaseUrl}/chat/completions): ${errorText}`,
);