feat(google): add image generation provider
This commit is contained in:
parent
c1ef5748eb
commit
618d35f933
@ -1,4 +1,5 @@
|
||||
import { emptyPluginConfigSchema, type OpenClawPluginApi } from "openclaw/plugin-sdk/core";
|
||||
import { buildGoogleImageGenerationProvider } from "openclaw/plugin-sdk/image-generation";
|
||||
import { createProviderApiKeyAuthMethod } from "openclaw/plugin-sdk/provider-auth";
|
||||
import {
|
||||
GOOGLE_GEMINI_DEFAULT_MODEL,
|
||||
@ -51,6 +52,7 @@ const googlePlugin = {
|
||||
isModernModelRef: ({ modelId }) => isModernGoogleModel(modelId),
|
||||
});
|
||||
registerGoogleGeminiCliProvider(api);
|
||||
api.registerImageGenerationProvider(buildGoogleImageGenerationProvider());
|
||||
api.registerMediaUnderstandingProvider(googleMediaUnderstandingProvider);
|
||||
api.registerWebSearchProvider(
|
||||
createPluginBackedWebSearchProvider({
|
||||
|
||||
51
src/image-generation/providers/google.live.test.ts
Normal file
51
src/image-generation/providers/google.live.test.ts
Normal file
@ -0,0 +1,51 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import type { OpenClawConfig } from "../../config/config.js";
|
||||
import { isTruthyEnvValue } from "../../infra/env.js";
|
||||
import { buildGoogleImageGenerationProvider } from "./google.js";
|
||||
|
||||
const LIVE =
|
||||
isTruthyEnvValue(process.env.GOOGLE_LIVE_TEST) ||
|
||||
isTruthyEnvValue(process.env.LIVE) ||
|
||||
isTruthyEnvValue(process.env.OPENCLAW_LIVE_TEST);
|
||||
const HAS_KEY = Boolean(process.env.GEMINI_API_KEY?.trim() || process.env.GOOGLE_API_KEY?.trim());
|
||||
const MODEL =
|
||||
process.env.GOOGLE_IMAGE_GENERATION_MODEL?.trim() ||
|
||||
process.env.GEMINI_IMAGE_GENERATION_MODEL?.trim() ||
|
||||
"gemini-3.1-flash-image-preview";
|
||||
const BASE_URL = process.env.GOOGLE_IMAGE_BASE_URL?.trim();
|
||||
|
||||
const describeLive = LIVE && HAS_KEY ? describe : describe.skip;
|
||||
|
||||
function buildLiveConfig(): OpenClawConfig {
|
||||
if (!BASE_URL) {
|
||||
return {};
|
||||
}
|
||||
return {
|
||||
models: {
|
||||
providers: {
|
||||
google: {
|
||||
baseUrl: BASE_URL,
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as OpenClawConfig;
|
||||
}
|
||||
|
||||
describeLive("google image-generation live", () => {
|
||||
it("generates a real image", async () => {
|
||||
const provider = buildGoogleImageGenerationProvider();
|
||||
const result = await provider.generateImage({
|
||||
provider: "google",
|
||||
model: MODEL,
|
||||
prompt:
|
||||
"Create a minimal flat illustration of an orange cat face sticker on a white background.",
|
||||
cfg: buildLiveConfig(),
|
||||
size: "1024x1024",
|
||||
});
|
||||
|
||||
expect(result.model).toBeTruthy();
|
||||
expect(result.images.length).toBeGreaterThan(0);
|
||||
expect(result.images[0]?.mimeType.startsWith("image/")).toBe(true);
|
||||
expect(result.images[0]?.buffer.byteLength).toBeGreaterThan(512);
|
||||
}, 120_000);
|
||||
});
|
||||
134
src/image-generation/providers/google.test.ts
Normal file
134
src/image-generation/providers/google.test.ts
Normal file
@ -0,0 +1,134 @@
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import * as modelAuth from "../../agents/model-auth.js";
|
||||
import { buildGoogleImageGenerationProvider } from "./google.js";
|
||||
|
||||
describe("Google image-generation provider", () => {
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it("generates image buffers from the Gemini generateContent API", async () => {
|
||||
vi.spyOn(modelAuth, "resolveApiKeyForProvider").mockResolvedValue({
|
||||
apiKey: "google-test-key",
|
||||
source: "env",
|
||||
mode: "api-key",
|
||||
});
|
||||
const fetchMock = vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [
|
||||
{ text: "generated" },
|
||||
{
|
||||
inlineData: {
|
||||
mimeType: "image/png",
|
||||
data: Buffer.from("png-data").toString("base64"),
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
});
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
|
||||
const provider = buildGoogleImageGenerationProvider();
|
||||
const result = await provider.generateImage({
|
||||
provider: "google",
|
||||
model: "gemini-3.1-flash-image-preview",
|
||||
prompt: "draw a cat",
|
||||
cfg: {},
|
||||
size: "1536x1024",
|
||||
});
|
||||
|
||||
expect(fetchMock).toHaveBeenCalledWith(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/gemini-3.1-flash-image-preview:generateContent",
|
||||
expect.objectContaining({
|
||||
method: "POST",
|
||||
body: JSON.stringify({
|
||||
contents: [
|
||||
{
|
||||
role: "user",
|
||||
parts: [{ text: "draw a cat" }],
|
||||
},
|
||||
],
|
||||
generationConfig: {
|
||||
responseModalities: ["TEXT", "IMAGE"],
|
||||
imageConfig: {
|
||||
aspectRatio: "3:2",
|
||||
imageSize: "2K",
|
||||
},
|
||||
},
|
||||
}),
|
||||
}),
|
||||
);
|
||||
expect(result).toEqual({
|
||||
images: [
|
||||
{
|
||||
buffer: Buffer.from("png-data"),
|
||||
mimeType: "image/png",
|
||||
fileName: "image-1.png",
|
||||
},
|
||||
],
|
||||
model: "gemini-3.1-flash-image-preview",
|
||||
});
|
||||
});
|
||||
|
||||
it("accepts OAuth JSON auth and inline_data responses", async () => {
|
||||
vi.spyOn(modelAuth, "resolveApiKeyForProvider").mockResolvedValue({
|
||||
apiKey: JSON.stringify({ token: "oauth-token" }),
|
||||
source: "profile",
|
||||
mode: "token",
|
||||
});
|
||||
const fetchMock = vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [
|
||||
{
|
||||
inline_data: {
|
||||
mime_type: "image/jpeg",
|
||||
data: Buffer.from("jpg-data").toString("base64"),
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
});
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
|
||||
const provider = buildGoogleImageGenerationProvider();
|
||||
const result = await provider.generateImage({
|
||||
provider: "google",
|
||||
model: "gemini-3.1-flash-image-preview",
|
||||
prompt: "draw a dog",
|
||||
cfg: {},
|
||||
});
|
||||
|
||||
expect(fetchMock).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
headers: expect.any(Headers),
|
||||
}),
|
||||
);
|
||||
const [, init] = fetchMock.mock.calls[0];
|
||||
expect(new Headers(init.headers).get("authorization")).toBe("Bearer oauth-token");
|
||||
expect(result).toEqual({
|
||||
images: [
|
||||
{
|
||||
buffer: Buffer.from("jpg-data"),
|
||||
mimeType: "image/jpeg",
|
||||
fileName: "image-1.jpg",
|
||||
},
|
||||
],
|
||||
model: "gemini-3.1-flash-image-preview",
|
||||
});
|
||||
});
|
||||
});
|
||||
159
src/image-generation/providers/google.ts
Normal file
159
src/image-generation/providers/google.ts
Normal file
@ -0,0 +1,159 @@
|
||||
import { resolveApiKeyForProvider } from "../../agents/model-auth.js";
|
||||
import { normalizeGoogleModelId } from "../../agents/model-id-normalization.js";
|
||||
import { parseGeminiAuth } from "../../infra/gemini-auth.js";
|
||||
import {
|
||||
assertOkOrThrowHttpError,
|
||||
normalizeBaseUrl,
|
||||
postJsonRequest,
|
||||
} from "../../media-understanding/providers/shared.js";
|
||||
import type { ImageGenerationProviderPlugin } from "../../plugins/types.js";
|
||||
|
||||
const DEFAULT_GOOGLE_IMAGE_BASE_URL = "https://generativelanguage.googleapis.com/v1beta";
|
||||
const DEFAULT_GOOGLE_IMAGE_MODEL = "gemini-3.1-flash-image-preview";
|
||||
const DEFAULT_OUTPUT_MIME = "image/png";
|
||||
const DEFAULT_ASPECT_RATIO = "1:1";
|
||||
|
||||
type GoogleInlineDataPart = {
|
||||
mimeType?: string;
|
||||
mime_type?: string;
|
||||
data?: string;
|
||||
};
|
||||
|
||||
type GoogleGenerateImageResponse = {
|
||||
candidates?: Array<{
|
||||
content?: {
|
||||
parts?: Array<{
|
||||
text?: string;
|
||||
inlineData?: GoogleInlineDataPart;
|
||||
inline_data?: GoogleInlineDataPart;
|
||||
}>;
|
||||
};
|
||||
}>;
|
||||
};
|
||||
|
||||
function resolveGoogleBaseUrl(cfg: Parameters<typeof resolveApiKeyForProvider>[0]["cfg"]): string {
|
||||
const direct = cfg?.models?.providers?.google?.baseUrl?.trim();
|
||||
return direct || DEFAULT_GOOGLE_IMAGE_BASE_URL;
|
||||
}
|
||||
|
||||
function normalizeGoogleImageModel(model: string | undefined): string {
|
||||
const trimmed = model?.trim();
|
||||
return normalizeGoogleModelId(trimmed || DEFAULT_GOOGLE_IMAGE_MODEL);
|
||||
}
|
||||
|
||||
function mapSizeToImageConfig(
|
||||
size: string | undefined,
|
||||
): { aspectRatio?: string; imageSize?: "2K" | "4K" } | undefined {
|
||||
const trimmed = size?.trim();
|
||||
if (!trimmed) {
|
||||
return { aspectRatio: DEFAULT_ASPECT_RATIO };
|
||||
}
|
||||
|
||||
const normalized = trimmed.toLowerCase();
|
||||
const mapping = new Map<string, string>([
|
||||
["1024x1024", "1:1"],
|
||||
["1024x1536", "2:3"],
|
||||
["1536x1024", "3:2"],
|
||||
["1024x1792", "9:16"],
|
||||
["1792x1024", "16:9"],
|
||||
]);
|
||||
const aspectRatio = mapping.get(normalized);
|
||||
|
||||
const [widthRaw, heightRaw] = normalized.split("x");
|
||||
const width = Number.parseInt(widthRaw ?? "", 10);
|
||||
const height = Number.parseInt(heightRaw ?? "", 10);
|
||||
const longestEdge = Math.max(width, height);
|
||||
const imageSize = longestEdge >= 3072 ? "4K" : longestEdge >= 1536 ? "2K" : undefined;
|
||||
|
||||
if (!aspectRatio && !imageSize) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
...(aspectRatio ? { aspectRatio } : {}),
|
||||
...(imageSize ? { imageSize } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
export function buildGoogleImageGenerationProvider(): ImageGenerationProviderPlugin {
|
||||
return {
|
||||
id: "google",
|
||||
label: "Google",
|
||||
async generateImage(req) {
|
||||
const auth = await resolveApiKeyForProvider({
|
||||
provider: "google",
|
||||
cfg: req.cfg,
|
||||
agentDir: req.agentDir,
|
||||
});
|
||||
if (!auth.apiKey) {
|
||||
throw new Error("Google API key missing");
|
||||
}
|
||||
|
||||
const model = normalizeGoogleImageModel(req.model);
|
||||
const baseUrl = normalizeBaseUrl(
|
||||
resolveGoogleBaseUrl(req.cfg),
|
||||
DEFAULT_GOOGLE_IMAGE_BASE_URL,
|
||||
);
|
||||
const allowPrivate = Boolean(req.cfg?.models?.providers?.google?.baseUrl?.trim());
|
||||
const authHeaders = parseGeminiAuth(auth.apiKey);
|
||||
const headers = new Headers(authHeaders.headers);
|
||||
const imageConfig = mapSizeToImageConfig(req.size);
|
||||
|
||||
const { response: res, release } = await postJsonRequest({
|
||||
url: `${baseUrl}/models/${model}:generateContent`,
|
||||
headers,
|
||||
body: {
|
||||
contents: [
|
||||
{
|
||||
role: "user",
|
||||
parts: [{ text: req.prompt }],
|
||||
},
|
||||
],
|
||||
generationConfig: {
|
||||
responseModalities: ["TEXT", "IMAGE"],
|
||||
...(imageConfig ? { imageConfig } : {}),
|
||||
},
|
||||
},
|
||||
timeoutMs: 60_000,
|
||||
fetchFn: fetch,
|
||||
allowPrivateNetwork: allowPrivate,
|
||||
});
|
||||
|
||||
try {
|
||||
await assertOkOrThrowHttpError(res, "Google image generation failed");
|
||||
|
||||
const payload = (await res.json()) as GoogleGenerateImageResponse;
|
||||
let imageIndex = 0;
|
||||
const images = (payload.candidates ?? [])
|
||||
.flatMap((candidate) => candidate.content?.parts ?? [])
|
||||
.map((part) => {
|
||||
const inline = part.inlineData ?? part.inline_data;
|
||||
const data = inline?.data?.trim();
|
||||
if (!data) {
|
||||
return null;
|
||||
}
|
||||
const mimeType = inline?.mimeType ?? inline?.mime_type ?? DEFAULT_OUTPUT_MIME;
|
||||
const extension = mimeType.includes("jpeg") ? "jpg" : (mimeType.split("/")[1] ?? "png");
|
||||
imageIndex += 1;
|
||||
return {
|
||||
buffer: Buffer.from(data, "base64"),
|
||||
mimeType,
|
||||
fileName: `image-${imageIndex}.${extension}`,
|
||||
};
|
||||
})
|
||||
.filter((entry): entry is NonNullable<typeof entry> => entry !== null);
|
||||
|
||||
if (images.length === 0) {
|
||||
throw new Error("Google image generation response missing image data");
|
||||
}
|
||||
|
||||
return {
|
||||
images,
|
||||
model,
|
||||
};
|
||||
} finally {
|
||||
await release();
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
@ -7,4 +7,5 @@ export type {
|
||||
ImageGenerationResult,
|
||||
} from "../image-generation/types.js";
|
||||
|
||||
export { buildGoogleImageGenerationProvider } from "../image-generation/providers/google.js";
|
||||
export { buildOpenAIImageGenerationProvider } from "../image-generation/providers/openai.js";
|
||||
|
||||
@ -165,6 +165,7 @@ describe("plugin contract registry", () => {
|
||||
});
|
||||
|
||||
it("keeps bundled image-generation ownership explicit", () => {
|
||||
expect(findImageGenerationProviderIdsForPlugin("google")).toEqual(["google"]);
|
||||
expect(findImageGenerationProviderIdsForPlugin("openai")).toEqual(["openai"]);
|
||||
});
|
||||
|
||||
@ -180,6 +181,13 @@ describe("plugin contract registry", () => {
|
||||
});
|
||||
|
||||
it("tracks speech registrations on bundled provider plugins", () => {
|
||||
expect(findRegistrationForPlugin("google")).toMatchObject({
|
||||
providerIds: ["google", "google-gemini-cli"],
|
||||
speechProviderIds: [],
|
||||
mediaUnderstandingProviderIds: ["google"],
|
||||
imageGenerationProviderIds: ["google"],
|
||||
webSearchProviderIds: ["gemini"],
|
||||
});
|
||||
expect(findRegistrationForPlugin("openai")).toMatchObject({
|
||||
providerIds: ["openai", "openai-codex"],
|
||||
speechProviderIds: ["openai"],
|
||||
@ -245,6 +253,9 @@ describe("plugin contract registry", () => {
|
||||
});
|
||||
|
||||
it("keeps bundled image-generation support explicit", () => {
|
||||
expect(findImageGenerationProviderForPlugin("google").generateImage).toEqual(
|
||||
expect.any(Function),
|
||||
);
|
||||
expect(findImageGenerationProviderForPlugin("openai").generateImage).toEqual(
|
||||
expect.any(Function),
|
||||
);
|
||||
|
||||
@ -131,7 +131,7 @@ const bundledMediaUnderstandingPlugins: RegistrablePlugin[] = [
|
||||
zaiPlugin,
|
||||
];
|
||||
|
||||
const bundledImageGenerationPlugins: RegistrablePlugin[] = [openAIPlugin];
|
||||
const bundledImageGenerationPlugins: RegistrablePlugin[] = [googlePlugin, openAIPlugin];
|
||||
|
||||
function captureRegistrations(plugin: RegistrablePlugin) {
|
||||
const captured = createCapturedPluginRegistration();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user