GigaChat: reuse cached auth tokens
This commit is contained in:
parent
6f6bf7eff1
commit
db44a27801
@ -54,7 +54,7 @@ describe("GigaChat leaked function-call prelude cleanup", () => {
|
||||
|
||||
const event = await stream.result();
|
||||
|
||||
expect(updateToken).toHaveBeenCalled();
|
||||
expect(updateToken).not.toHaveBeenCalled();
|
||||
expect(event.stopReason).toBe("toolUse");
|
||||
expect(event.content).toEqual([
|
||||
expect.objectContaining({
|
||||
|
||||
@ -1,16 +1,23 @@
|
||||
import { Readable } from "node:stream";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
const updateToken = vi.fn(async () => {});
|
||||
let initialAccessToken: { access_token: string } | undefined = { access_token: "test-token" };
|
||||
let refreshedAccessToken = "refreshed-token";
|
||||
const updateToken = vi.fn(async function (this: { _accessToken?: { access_token: string } }) {
|
||||
this._accessToken = { access_token: refreshedAccessToken };
|
||||
});
|
||||
const request = vi.fn();
|
||||
const clientConfigs: Array<Record<string, unknown>> = [];
|
||||
|
||||
vi.mock("gigachat", () => {
|
||||
class MockGigaChat {
|
||||
_client = { request };
|
||||
_accessToken = { access_token: "test-token" };
|
||||
_accessToken = initialAccessToken ? { ...initialAccessToken } : undefined;
|
||||
|
||||
updateToken = updateToken;
|
||||
resetToken() {
|
||||
this._accessToken = undefined;
|
||||
}
|
||||
|
||||
constructor(config: Record<string, unknown>) {
|
||||
clientConfigs.push(config);
|
||||
@ -34,6 +41,8 @@ describe("createGigachatStreamFn tool calling", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
clientConfigs.length = 0;
|
||||
initialAccessToken = { access_token: "test-token" };
|
||||
refreshedAccessToken = "refreshed-token";
|
||||
vi.unstubAllEnvs();
|
||||
});
|
||||
|
||||
@ -74,7 +83,7 @@ describe("createGigachatStreamFn tool calling", () => {
|
||||
|
||||
const event = await stream.result();
|
||||
|
||||
expect(updateToken).toHaveBeenCalled();
|
||||
expect(updateToken).not.toHaveBeenCalled();
|
||||
expect(request).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
data: expect.objectContaining({
|
||||
@ -178,6 +187,117 @@ describe("createGigachatStreamFn tool calling", () => {
|
||||
expect(event.content).toEqual([{ type: "text", text: "final tail" }]);
|
||||
});
|
||||
|
||||
it("reuses a cached token across turns for the same GigaChat credentials", async () => {
|
||||
initialAccessToken = undefined;
|
||||
refreshedAccessToken = "cached-after-refresh";
|
||||
request
|
||||
.mockResolvedValueOnce({
|
||||
status: 200,
|
||||
data: createSseStream([
|
||||
'data: {"choices":[{"delta":{"content":"first"}}]}',
|
||||
"data: [DONE]",
|
||||
]),
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
status: 200,
|
||||
data: createSseStream([
|
||||
'data: {"choices":[{"delta":{"content":"second"}}]}',
|
||||
"data: [DONE]",
|
||||
]),
|
||||
});
|
||||
|
||||
const streamFn = createGigachatStreamFn({
|
||||
baseUrl: "https://gigachat.devices.sberbank.ru/api/v1",
|
||||
authMode: "oauth",
|
||||
});
|
||||
|
||||
const firstStream = await streamFn(
|
||||
{ api: "gigachat", provider: "gigachat", id: "GigaChat-2-Max" } as never,
|
||||
{ messages: [], tools: [] } as never,
|
||||
{ apiKey: "token" } as never,
|
||||
);
|
||||
await expect(firstStream.result()).resolves.toMatchObject({
|
||||
content: [{ type: "text", text: "first" }],
|
||||
});
|
||||
|
||||
const secondStream = await streamFn(
|
||||
{ api: "gigachat", provider: "gigachat", id: "GigaChat-2-Max" } as never,
|
||||
{ messages: [], tools: [] } as never,
|
||||
{ apiKey: "token" } as never,
|
||||
);
|
||||
await expect(secondStream.result()).resolves.toMatchObject({
|
||||
content: [{ type: "text", text: "second" }],
|
||||
});
|
||||
|
||||
expect(updateToken).toHaveBeenCalledTimes(1);
|
||||
expect(clientConfigs).toHaveLength(1);
|
||||
expect(request).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: "Bearer cached-after-refresh",
|
||||
}),
|
||||
}),
|
||||
);
|
||||
expect(request).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: "Bearer cached-after-refresh",
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("refreshes once and retries the chat request after a 401", async () => {
|
||||
refreshedAccessToken = "fresh-token";
|
||||
request
|
||||
.mockResolvedValueOnce({
|
||||
status: 401,
|
||||
data: "expired token",
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
status: 200,
|
||||
data: createSseStream([
|
||||
'data: {"choices":[{"delta":{"content":"recovered"}}]}',
|
||||
"data: [DONE]",
|
||||
]),
|
||||
});
|
||||
|
||||
const streamFn = createGigachatStreamFn({
|
||||
baseUrl: "https://gigachat.devices.sberbank.ru/api/v1",
|
||||
authMode: "oauth",
|
||||
});
|
||||
|
||||
const stream = await streamFn(
|
||||
{ api: "gigachat", provider: "gigachat", id: "GigaChat-2-Max" } as never,
|
||||
{ messages: [], tools: [] } as never,
|
||||
{ apiKey: "token" } as never,
|
||||
);
|
||||
|
||||
await expect(stream.result()).resolves.toMatchObject({
|
||||
content: [{ type: "text", text: "recovered" }],
|
||||
});
|
||||
|
||||
expect(updateToken).toHaveBeenCalledTimes(1);
|
||||
expect(request).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: "Bearer test-token",
|
||||
}),
|
||||
}),
|
||||
);
|
||||
expect(request).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: "Bearer fresh-token",
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("prefers the resolved GigaChat baseUrl over the env override", async () => {
|
||||
vi.stubEnv("GIGACHAT_BASE_URL", "https://env-host.example/api/v1");
|
||||
request.mockResolvedValueOnce({
|
||||
|
||||
@ -440,6 +440,80 @@ async function withRetry<T>(
|
||||
throw lastError;
|
||||
}
|
||||
|
||||
type GigachatAccessToken = {
|
||||
access_token?: string;
|
||||
};
|
||||
|
||||
type GigachatTransportResponse = {
|
||||
status: number;
|
||||
data: AsyncIterable<string | Uint8Array> | string | { pipe?: unknown };
|
||||
};
|
||||
|
||||
type GigachatRuntimeClient = GigaChat & {
|
||||
_client: {
|
||||
request: (config: {
|
||||
method: "POST";
|
||||
url: string;
|
||||
data: Chat & { stream: true };
|
||||
responseType: "stream";
|
||||
headers: Record<string, string>;
|
||||
signal?: AbortSignal;
|
||||
}) => Promise<GigachatTransportResponse>;
|
||||
};
|
||||
_accessToken?: GigachatAccessToken;
|
||||
updateToken: () => Promise<void>;
|
||||
resetToken?: () => void;
|
||||
};
|
||||
|
||||
function getGigachatAccessToken(client: GigachatRuntimeClient): string | undefined {
|
||||
return client._accessToken?.access_token?.trim() || undefined;
|
||||
}
|
||||
|
||||
async function ensureGigachatAccessToken(client: GigachatRuntimeClient): Promise<string> {
|
||||
const accessToken = getGigachatAccessToken(client);
|
||||
if (accessToken) {
|
||||
return accessToken;
|
||||
}
|
||||
|
||||
await withRetry(() => client.updateToken(), "token refresh");
|
||||
|
||||
const refreshedToken = getGigachatAccessToken(client);
|
||||
if (!refreshedToken) {
|
||||
throw new Error("GigaChat: failed to obtain access token after retries");
|
||||
}
|
||||
return refreshedToken;
|
||||
}
|
||||
|
||||
function resetGigachatAccessToken(client: GigachatRuntimeClient): void {
|
||||
if (typeof client.resetToken === "function") {
|
||||
client.resetToken();
|
||||
return;
|
||||
}
|
||||
delete client._accessToken;
|
||||
}
|
||||
|
||||
async function readGigachatErrorText(
|
||||
responseData: GigachatTransportResponse["data"],
|
||||
status: number,
|
||||
): Promise<string> {
|
||||
try {
|
||||
if (typeof responseData === "string") {
|
||||
return responseData;
|
||||
}
|
||||
if (responseData && typeof responseData === "object" && "pipe" in responseData) {
|
||||
const chunks: Buffer[] = [];
|
||||
for await (const chunk of responseData as AsyncIterable<string | Uint8Array>) {
|
||||
chunks.push(typeof chunk === "string" ? Buffer.from(chunk) : Buffer.from(chunk));
|
||||
}
|
||||
return Buffer.concat(chunks).toString();
|
||||
}
|
||||
} catch {
|
||||
return `status ${status}`;
|
||||
}
|
||||
|
||||
return "unknown error";
|
||||
}
|
||||
|
||||
// ── Stream function ─────────────────────────────────────────────────────────
|
||||
|
||||
export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn {
|
||||
@ -460,6 +534,45 @@ export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn {
|
||||
);
|
||||
}
|
||||
|
||||
let cachedClient: GigachatRuntimeClient | null = null;
|
||||
let cachedApiKey: string | null = null;
|
||||
|
||||
const buildClientConfig = (apiKey: string): GigaChatClientConfig => {
|
||||
const clientConfig: GigaChatClientConfig = {
|
||||
baseUrl: effectiveBaseUrl,
|
||||
// Explicitly set to undefined to prevent the library from adding profanity_check
|
||||
profanityCheck: undefined,
|
||||
timeout: 120,
|
||||
};
|
||||
|
||||
if (insecureTls) {
|
||||
clientConfig.httpsAgent = new https.Agent({ rejectUnauthorized: false });
|
||||
}
|
||||
|
||||
if (opts.authMode === "basic") {
|
||||
const { user, password } = parseGigachatBasicCredentials(apiKey);
|
||||
clientConfig.user = user;
|
||||
clientConfig.password = password;
|
||||
log.debug(`GigaChat auth: basic mode`);
|
||||
} else {
|
||||
clientConfig.credentials = apiKey;
|
||||
clientConfig.scope = opts.scope ?? "GIGACHAT_API_PERS";
|
||||
log.debug(`GigaChat auth: oauth scope=${clientConfig.scope}`);
|
||||
}
|
||||
|
||||
return clientConfig;
|
||||
};
|
||||
|
||||
const getClientForApiKey = (apiKey: string): GigachatRuntimeClient => {
|
||||
if (cachedClient && cachedApiKey === apiKey) {
|
||||
return cachedClient;
|
||||
}
|
||||
|
||||
cachedClient = new GigaChat(buildClientConfig(apiKey)) as GigachatRuntimeClient;
|
||||
cachedApiKey = apiKey;
|
||||
return cachedClient;
|
||||
};
|
||||
|
||||
return (model, context, options) => {
|
||||
const stream = createAssistantMessageEventStream();
|
||||
|
||||
@ -596,32 +709,7 @@ export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn {
|
||||
|
||||
// Build auth config
|
||||
const apiKey = options?.apiKey ?? "";
|
||||
|
||||
const clientConfig: GigaChatClientConfig = {
|
||||
baseUrl: effectiveBaseUrl,
|
||||
// Explicitly set to undefined to prevent the library from adding profanity_check
|
||||
profanityCheck: undefined,
|
||||
timeout: 120,
|
||||
};
|
||||
|
||||
// Configure TLS
|
||||
if (insecureTls) {
|
||||
clientConfig.httpsAgent = new https.Agent({ rejectUnauthorized: false });
|
||||
}
|
||||
|
||||
// Set credentials based on auth mode
|
||||
if (opts.authMode === "basic") {
|
||||
const { user, password } = parseGigachatBasicCredentials(apiKey);
|
||||
clientConfig.user = user;
|
||||
clientConfig.password = password;
|
||||
log.debug(`GigaChat auth: basic mode`);
|
||||
} else {
|
||||
clientConfig.credentials = apiKey;
|
||||
clientConfig.scope = opts.scope ?? "GIGACHAT_API_PERS";
|
||||
log.debug(`GigaChat auth: oauth scope=${clientConfig.scope}`);
|
||||
}
|
||||
|
||||
const client = new GigaChat(clientConfig);
|
||||
const client = getClientForApiKey(apiKey);
|
||||
|
||||
// Build chat request - explicitly omit profanity_check
|
||||
const chatRequest: Chat = {
|
||||
@ -644,56 +732,41 @@ export function createGigachatStreamFn(opts: GigachatStreamOptions): StreamFn {
|
||||
|
||||
log.debug(`GigaChat request: ${messages.length} messages, ${functions.length} functions`);
|
||||
|
||||
// Use the library for auth, but our own SSE parsing (library's parseChunk is buggy)
|
||||
// Wrap token refresh in retry logic for transient failures
|
||||
await withRetry(() => client.updateToken(), "token refresh");
|
||||
|
||||
const axiosClient = client._client;
|
||||
// Access the token (protected property, so we cast)
|
||||
const accessToken = (client as unknown as { _accessToken?: { access_token: string } })
|
||||
._accessToken?.access_token;
|
||||
|
||||
if (!accessToken) {
|
||||
throw new Error("GigaChat: failed to obtain access token after retries");
|
||||
}
|
||||
|
||||
const requestId = randomUUID();
|
||||
log.debug(`GigaChat request ${requestId}: starting`);
|
||||
|
||||
const headers: Record<string, string> = {
|
||||
...resolveGigachatModelHeaders(model),
|
||||
...options?.headers,
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
Accept: "text/event-stream",
|
||||
"Cache-Control": "no-store",
|
||||
"X-Request-ID": requestId,
|
||||
const axiosClient = client._client;
|
||||
const sendChatCompletionsRequest = async (): Promise<GigachatTransportResponse> => {
|
||||
const accessToken = await ensureGigachatAccessToken(client);
|
||||
return axiosClient.request({
|
||||
method: "POST",
|
||||
url: "/chat/completions",
|
||||
data: { ...chatRequest, stream: true },
|
||||
responseType: "stream",
|
||||
headers: {
|
||||
...resolveGigachatModelHeaders(model),
|
||||
...options?.headers,
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
Accept: "text/event-stream",
|
||||
"Cache-Control": "no-store",
|
||||
"X-Request-ID": requestId,
|
||||
},
|
||||
signal: options?.signal,
|
||||
});
|
||||
};
|
||||
|
||||
const response = await axiosClient.request({
|
||||
method: "POST",
|
||||
url: "/chat/completions",
|
||||
data: { ...chatRequest, stream: true },
|
||||
responseType: "stream",
|
||||
headers,
|
||||
signal: options?.signal,
|
||||
});
|
||||
let response = await sendChatCompletionsRequest();
|
||||
if (response.status === 401) {
|
||||
log.warn(
|
||||
`GigaChat request ${requestId}: received 401 from chat endpoint, refreshing token and retrying`,
|
||||
);
|
||||
resetGigachatAccessToken(client);
|
||||
await ensureGigachatAccessToken(client);
|
||||
response = await sendChatCompletionsRequest();
|
||||
}
|
||||
|
||||
if (response.status !== 200) {
|
||||
let errorText = "unknown error";
|
||||
try {
|
||||
if (typeof response.data === "string") {
|
||||
errorText = response.data;
|
||||
} else if (response.data && typeof response.data.pipe === "function") {
|
||||
// It's a stream, try to read it
|
||||
const chunks: Buffer[] = [];
|
||||
for await (const chunk of response.data) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
errorText = Buffer.concat(chunks).toString();
|
||||
}
|
||||
} catch {
|
||||
errorText = `status ${response.status}`;
|
||||
}
|
||||
const errorText = await readGigachatErrorText(response.data, response.status);
|
||||
throw new Error(
|
||||
`GigaChat API error ${response.status} (${effectiveBaseUrl}/chat/completions): ${errorText}`,
|
||||
);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user