From 3a456678ee3516679185a0814ac61d91128fcb9d Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Tue, 17 Mar 2026 01:09:24 -0700 Subject: [PATCH] feat(image-generation): add image_generate tool --- .../openclaw-tools.image-generation.test.ts | 39 ++ src/agents/openclaw-tools.ts | 9 + .../pi-embedded-subscribe.tools.media.test.ts | 9 +- src/agents/pi-embedded-subscribe.tools.ts | 1 + src/agents/system-prompt.ts | 2 + src/agents/tool-catalog.test.ts | 1 + src/agents/tool-catalog.ts | 8 + src/agents/tools/image-generate-tool.test.ts | 280 ++++++++++++ src/agents/tools/image-generate-tool.ts | 424 ++++++++++++++++++ src/image-generation/providers/google.test.ts | 74 +++ src/image-generation/providers/google.ts | 21 +- src/image-generation/providers/openai.test.ts | 30 +- src/image-generation/providers/openai.ts | 7 +- src/image-generation/runtime.test.ts | 19 +- src/image-generation/runtime.ts | 14 +- src/image-generation/types.ts | 17 + src/plugin-sdk/image-generation.ts | 2 + 17 files changed, 949 insertions(+), 8 deletions(-) create mode 100644 src/agents/openclaw-tools.image-generation.test.ts create mode 100644 src/agents/tools/image-generate-tool.test.ts create mode 100644 src/agents/tools/image-generate-tool.ts diff --git a/src/agents/openclaw-tools.image-generation.test.ts b/src/agents/openclaw-tools.image-generation.test.ts new file mode 100644 index 00000000000..dd237115ab7 --- /dev/null +++ b/src/agents/openclaw-tools.image-generation.test.ts @@ -0,0 +1,39 @@ +import { describe, expect, it, vi } from "vitest"; +import type { OpenClawConfig } from "../config/config.js"; +import { createOpenClawTools } from "./openclaw-tools.js"; + +vi.mock("../plugins/tools.js", () => ({ + resolvePluginTools: () => [], +})); + +function asConfig(value: unknown): OpenClawConfig { + return value as OpenClawConfig; +} + +describe("openclaw tools image generation registration", () => { + it("registers image_generate when image-generation config is present", () => { + const tools = createOpenClawTools({ + config: asConfig({ + agents: { + defaults: { + imageGenerationModel: { + primary: "openai/gpt-image-1", + }, + }, + }, + }), + agentDir: "/tmp/openclaw-agent-main", + }); + + expect(tools.map((tool) => tool.name)).toContain("image_generate"); + }); + + it("omits image_generate when image-generation config is absent", () => { + const tools = createOpenClawTools({ + config: asConfig({}), + agentDir: "/tmp/openclaw-agent-main", + }); + + expect(tools.map((tool) => tool.name)).not.toContain("image_generate"); + }); +}); diff --git a/src/agents/openclaw-tools.ts b/src/agents/openclaw-tools.ts index 32bd92f4207..6f4929d288a 100644 --- a/src/agents/openclaw-tools.ts +++ b/src/agents/openclaw-tools.ts @@ -12,6 +12,7 @@ import { createCanvasTool } from "./tools/canvas-tool.js"; import type { AnyAgentTool } from "./tools/common.js"; import { createCronTool } from "./tools/cron-tool.js"; import { createGatewayTool } from "./tools/gateway-tool.js"; +import { createImageGenerateTool } from "./tools/image-generate-tool.js"; import { createImageTool } from "./tools/image-tool.js"; import { createMessageTool } from "./tools/message-tool.js"; import { createNodesTool } from "./tools/nodes-tool.js"; @@ -103,6 +104,13 @@ export function createOpenClawTools( modelHasVision: options?.modelHasVision, }) : null; + const imageGenerateTool = createImageGenerateTool({ + config: options?.config, + agentDir: options?.agentDir, + workspaceDir, + sandbox, + fsPolicy: options?.fsPolicy, + }); const pdfTool = options?.agentDir?.trim() ? createPdfTool({ config: options?.config, @@ -163,6 +171,7 @@ export function createOpenClawTools( agentChannel: options?.agentChannel, config: options?.config, }), + ...(imageGenerateTool ? [imageGenerateTool] : []), createGatewayTool({ agentSessionKey: options?.agentSessionKey, config: options?.config, diff --git a/src/agents/pi-embedded-subscribe.tools.media.test.ts b/src/agents/pi-embedded-subscribe.tools.media.test.ts index a07ed71473d..7cf51bb7c1c 100644 --- a/src/agents/pi-embedded-subscribe.tools.media.test.ts +++ b/src/agents/pi-embedded-subscribe.tools.media.test.ts @@ -1,5 +1,8 @@ import { describe, expect, it } from "vitest"; -import { extractToolResultMediaPaths } from "./pi-embedded-subscribe.tools.js"; +import { + extractToolResultMediaPaths, + isToolResultMediaTrusted, +} from "./pi-embedded-subscribe.tools.js"; describe("extractToolResultMediaPaths", () => { it("returns empty array for null/undefined", () => { @@ -229,4 +232,8 @@ describe("extractToolResultMediaPaths", () => { }; expect(extractToolResultMediaPaths(result)).toEqual(["/tmp/page1.png", "/tmp/page2.png"]); }); + + it("trusts image_generate local MEDIA paths", () => { + expect(isToolResultMediaTrusted("image_generate")).toBe(true); + }); }); diff --git a/src/agents/pi-embedded-subscribe.tools.ts b/src/agents/pi-embedded-subscribe.tools.ts index 08a5e5f80c4..925f56fa6ee 100644 --- a/src/agents/pi-embedded-subscribe.tools.ts +++ b/src/agents/pi-embedded-subscribe.tools.ts @@ -142,6 +142,7 @@ const TRUSTED_TOOL_RESULT_MEDIA = new Set([ "exec", "gateway", "image", + "image_generate", "memory_get", "memory_search", "message", diff --git a/src/agents/system-prompt.ts b/src/agents/system-prompt.ts index 5f4ee932bd7..3ee438db2d4 100644 --- a/src/agents/system-prompt.ts +++ b/src/agents/system-prompt.ts @@ -268,6 +268,7 @@ export function buildAgentSystemPrompt(params: { session_status: "Show a /status-equivalent status card (usage + time + Reasoning/Verbose/Elevated); use for model-use questions (📊 session_status); optional per-session model override", image: "Analyze an image with the configured image model", + image_generate: "Generate images with the configured image-generation model", }; const toolOrder = [ @@ -295,6 +296,7 @@ export function buildAgentSystemPrompt(params: { "subagents", "session_status", "image", + "image_generate", ]; const rawToolNames = (params.toolNames ?? []).map((tool) => tool.trim()); diff --git a/src/agents/tool-catalog.test.ts b/src/agents/tool-catalog.test.ts index 120a744432c..2f7fa0fc5d6 100644 --- a/src/agents/tool-catalog.test.ts +++ b/src/agents/tool-catalog.test.ts @@ -7,5 +7,6 @@ describe("tool-catalog", () => { expect(policy).toBeDefined(); expect(policy!.allow).toContain("web_search"); expect(policy!.allow).toContain("web_fetch"); + expect(policy!.allow).toContain("image_generate"); }); }); diff --git a/src/agents/tool-catalog.ts b/src/agents/tool-catalog.ts index 445cdc5f10b..0d58c066928 100644 --- a/src/agents/tool-catalog.ts +++ b/src/agents/tool-catalog.ts @@ -233,6 +233,14 @@ const CORE_TOOL_DEFINITIONS: CoreToolDefinition[] = [ profiles: ["coding"], includeInOpenClawGroup: true, }, + { + id: "image_generate", + label: "image_generate", + description: "Image generation", + sectionId: "media", + profiles: ["coding"], + includeInOpenClawGroup: true, + }, { id: "tts", label: "tts", diff --git a/src/agents/tools/image-generate-tool.test.ts b/src/agents/tools/image-generate-tool.test.ts new file mode 100644 index 00000000000..97f324921e3 --- /dev/null +++ b/src/agents/tools/image-generate-tool.test.ts @@ -0,0 +1,280 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; +import * as imageGenerationRuntime from "../../image-generation/runtime.js"; +import * as imageOps from "../../media/image-ops.js"; +import * as mediaStore from "../../media/store.js"; +import * as webMedia from "../../plugin-sdk/web-media.js"; +import { createImageGenerateTool } from "./image-generate-tool.js"; + +describe("createImageGenerateTool", () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + it("returns null when image-generation model is not configured", () => { + expect(createImageGenerateTool({ config: {} })).toBeNull(); + }); + + it("generates images and returns MEDIA paths", async () => { + const generateImage = vi.spyOn(imageGenerationRuntime, "generateImage").mockResolvedValue({ + provider: "openai", + model: "gpt-image-1", + attempts: [], + images: [ + { + buffer: Buffer.from("png-1"), + mimeType: "image/png", + fileName: "cat-one.png", + }, + { + buffer: Buffer.from("png-2"), + mimeType: "image/png", + fileName: "cat-two.png", + revisedPrompt: "A more cinematic cat", + }, + ], + }); + const saveMediaBuffer = vi.spyOn(mediaStore, "saveMediaBuffer"); + saveMediaBuffer.mockResolvedValueOnce({ + path: "/tmp/generated-1.png", + id: "generated-1.png", + size: 5, + contentType: "image/png", + }); + saveMediaBuffer.mockResolvedValueOnce({ + path: "/tmp/generated-2.png", + id: "generated-2.png", + size: 5, + contentType: "image/png", + }); + + const tool = createImageGenerateTool({ + config: { + agents: { + defaults: { + imageGenerationModel: { + primary: "openai/gpt-image-1", + }, + }, + }, + }, + agentDir: "/tmp/agent", + }); + + expect(tool).not.toBeNull(); + if (!tool) { + throw new Error("expected image_generate tool"); + } + + const result = await tool.execute("call-1", { + prompt: "A cat wearing sunglasses", + model: "openai/gpt-image-1", + count: 2, + size: "1024x1024", + }); + + expect(generateImage).toHaveBeenCalledWith( + expect.objectContaining({ + cfg: { + agents: { + defaults: { + imageGenerationModel: { + primary: "openai/gpt-image-1", + }, + }, + }, + }, + prompt: "A cat wearing sunglasses", + agentDir: "/tmp/agent", + modelOverride: "openai/gpt-image-1", + size: "1024x1024", + count: 2, + inputImages: [], + }), + ); + expect(saveMediaBuffer).toHaveBeenNthCalledWith( + 1, + Buffer.from("png-1"), + "image/png", + "tool-image-generation", + undefined, + "cat-one.png", + ); + expect(saveMediaBuffer).toHaveBeenNthCalledWith( + 2, + Buffer.from("png-2"), + "image/png", + "tool-image-generation", + undefined, + "cat-two.png", + ); + expect(result).toMatchObject({ + content: [ + { + type: "text", + text: expect.stringContaining("Generated 2 images with openai/gpt-image-1."), + }, + ], + details: { + provider: "openai", + model: "gpt-image-1", + count: 2, + paths: ["/tmp/generated-1.png", "/tmp/generated-2.png"], + revisedPrompts: ["A more cinematic cat"], + }, + }); + const text = (result.content?.[0] as { text: string } | undefined)?.text ?? ""; + expect(text).toContain("MEDIA:/tmp/generated-1.png"); + expect(text).toContain("MEDIA:/tmp/generated-2.png"); + }); + + it("rejects counts outside the supported range", async () => { + const tool = createImageGenerateTool({ + config: { + agents: { + defaults: { + imageGenerationModel: { + primary: "google/gemini-3.1-flash-image-preview", + }, + }, + }, + }, + }); + expect(tool).not.toBeNull(); + if (!tool) { + throw new Error("expected image_generate tool"); + } + + await expect(tool.execute("call-2", { prompt: "too many cats", count: 5 })).rejects.toThrow( + "count must be between 1 and 4", + ); + }); + + it("forwards reference images and inferred resolution for edit mode", async () => { + const generateImage = vi.spyOn(imageGenerationRuntime, "generateImage").mockResolvedValue({ + provider: "google", + model: "gemini-3-pro-image-preview", + attempts: [], + images: [ + { + buffer: Buffer.from("png-out"), + mimeType: "image/png", + fileName: "edited.png", + }, + ], + }); + vi.spyOn(webMedia, "loadWebMedia").mockResolvedValue({ + kind: "image", + buffer: Buffer.from("input-image"), + contentType: "image/png", + }); + vi.spyOn(imageOps, "getImageMetadata").mockResolvedValue({ + width: 3200, + height: 1800, + }); + vi.spyOn(mediaStore, "saveMediaBuffer").mockResolvedValue({ + path: "/tmp/edited.png", + id: "edited.png", + size: 7, + contentType: "image/png", + }); + + const tool = createImageGenerateTool({ + config: { + agents: { + defaults: { + imageGenerationModel: { + primary: "google/gemini-3-pro-image-preview", + }, + }, + }, + }, + workspaceDir: process.cwd(), + }); + + expect(tool).not.toBeNull(); + if (!tool) { + throw new Error("expected image_generate tool"); + } + + await tool.execute("call-edit", { + prompt: "Add a dramatic stormy sky but keep everything else identical.", + image: "./fixtures/reference.png", + }); + + expect(generateImage).toHaveBeenCalledWith( + expect.objectContaining({ + resolution: "4K", + inputImages: [ + expect.objectContaining({ + buffer: Buffer.from("input-image"), + mimeType: "image/png", + }), + ], + }), + ); + }); + + it("lists registered provider and model options", async () => { + vi.spyOn(imageGenerationRuntime, "listRuntimeImageGenerationProviders").mockReturnValue([ + { + id: "google", + defaultModel: "gemini-3.1-flash-image-preview", + models: ["gemini-3.1-flash-image-preview", "gemini-3-pro-image-preview"], + supportedResolutions: ["1K", "2K", "4K"], + supportsImageEditing: true, + generateImage: vi.fn(async () => { + throw new Error("not used"); + }), + }, + { + id: "openai", + defaultModel: "gpt-image-1", + models: ["gpt-image-1"], + supportedSizes: ["1024x1024", "1024x1536", "1536x1024"], + supportsImageEditing: false, + generateImage: vi.fn(async () => { + throw new Error("not used"); + }), + }, + ]); + + const tool = createImageGenerateTool({ + config: { + agents: { + defaults: { + imageGenerationModel: { + primary: "google/gemini-3.1-flash-image-preview", + }, + }, + }, + }, + }); + + expect(tool).not.toBeNull(); + if (!tool) { + throw new Error("expected image_generate tool"); + } + + const result = await tool.execute("call-list", { action: "list" }); + const text = (result.content?.[0] as { text: string } | undefined)?.text ?? ""; + + expect(text).toContain("google (default gemini-3.1-flash-image-preview)"); + expect(text).toContain("gemini-3.1-flash-image-preview"); + expect(text).toContain("gemini-3-pro-image-preview"); + expect(text).toContain("editing"); + expect(result).toMatchObject({ + details: { + providers: expect.arrayContaining([ + expect.objectContaining({ + id: "google", + defaultModel: "gemini-3.1-flash-image-preview", + models: expect.arrayContaining([ + "gemini-3.1-flash-image-preview", + "gemini-3-pro-image-preview", + ]), + }), + ]), + }, + }); + }); +}); diff --git a/src/agents/tools/image-generate-tool.ts b/src/agents/tools/image-generate-tool.ts new file mode 100644 index 00000000000..810bfe3ba6f --- /dev/null +++ b/src/agents/tools/image-generate-tool.ts @@ -0,0 +1,424 @@ +import { Type } from "@sinclair/typebox"; +import type { OpenClawConfig } from "../../config/config.js"; +import { loadConfig } from "../../config/config.js"; +import { + generateImage, + listRuntimeImageGenerationProviders, +} from "../../image-generation/runtime.js"; +import type { + ImageGenerationResolution, + ImageGenerationSourceImage, +} from "../../image-generation/types.js"; +import { getImageMetadata } from "../../media/image-ops.js"; +import { saveMediaBuffer } from "../../media/store.js"; +import { loadWebMedia } from "../../plugin-sdk/web-media.js"; +import { resolveUserPath } from "../../utils.js"; +import { ToolInputError, readNumberParam, readStringParam } from "./common.js"; +import { decodeDataUrl } from "./image-tool.helpers.js"; +import { resolveMediaToolLocalRoots } from "./media-tool-shared.js"; +import { + createSandboxBridgeReadFile, + resolveSandboxedBridgeMediaPath, + type AnyAgentTool, + type SandboxFsBridge, + type ToolFsPolicy, +} from "./tool-runtime.helpers.js"; + +const DEFAULT_COUNT = 1; +const MAX_COUNT = 4; +const MAX_INPUT_IMAGES = 4; +const DEFAULT_RESOLUTION: ImageGenerationResolution = "1K"; + +const ImageGenerateToolSchema = Type.Object({ + action: Type.Optional( + Type.String({ + description: + 'Optional action: "generate" (default) or "list" to inspect available providers/models.', + }), + ), + prompt: Type.Optional(Type.String({ description: "Image generation prompt." })), + image: Type.Optional( + Type.String({ + description: "Optional reference image path or URL for edit mode.", + }), + ), + images: Type.Optional( + Type.Array(Type.String(), { + description: `Optional reference images for edit mode (up to ${MAX_INPUT_IMAGES}).`, + }), + ), + model: Type.Optional( + Type.String({ description: "Optional provider/model override, e.g. openai/gpt-image-1." }), + ), + size: Type.Optional( + Type.String({ + description: + "Optional size hint like 1024x1024, 1536x1024, 1024x1536, 1024x1792, or 1792x1024.", + }), + ), + resolution: Type.Optional( + Type.String({ + description: + "Optional resolution hint: 1K, 2K, or 4K. Useful for Google edit/generation flows.", + }), + ), + count: Type.Optional( + Type.Number({ + description: `Optional number of images to request (1-${MAX_COUNT}).`, + minimum: 1, + maximum: MAX_COUNT, + }), + ), +}); + +function hasConfiguredImageGenerationModel(cfg: OpenClawConfig): boolean { + const configured = cfg.agents?.defaults?.imageGenerationModel; + if (typeof configured === "string") { + return configured.trim().length > 0; + } + if (configured?.primary?.trim()) { + return true; + } + return (configured?.fallbacks ?? []).some((entry) => entry.trim().length > 0); +} + +function resolveAction(args: Record): "generate" | "list" { + const raw = readStringParam(args, "action"); + if (!raw) { + return "generate"; + } + const normalized = raw.trim().toLowerCase(); + if (normalized === "generate" || normalized === "list") { + return normalized; + } + throw new ToolInputError('action must be "generate" or "list"'); +} + +function resolveRequestedCount(args: Record): number { + const count = readNumberParam(args, "count", { integer: true }); + if (count === undefined) { + return DEFAULT_COUNT; + } + if (count < 1 || count > MAX_COUNT) { + throw new ToolInputError(`count must be between 1 and ${MAX_COUNT}`); + } + return count; +} + +function normalizeResolution(raw: string | undefined): ImageGenerationResolution | undefined { + const normalized = raw?.trim().toUpperCase(); + if (!normalized) { + return undefined; + } + if (normalized === "1K" || normalized === "2K" || normalized === "4K") { + return normalized; + } + throw new ToolInputError("resolution must be one of 1K, 2K, or 4K"); +} + +function normalizeReferenceImages(args: Record): string[] { + const imageCandidates: string[] = []; + if (typeof args.image === "string") { + imageCandidates.push(args.image); + } + if (Array.isArray(args.images)) { + imageCandidates.push( + ...args.images.filter((value): value is string => typeof value === "string"), + ); + } + + const seen = new Set(); + const normalized: string[] = []; + for (const candidate of imageCandidates) { + const trimmed = candidate.trim(); + const dedupe = trimmed.startsWith("@") ? trimmed.slice(1).trim() : trimmed; + if (!dedupe || seen.has(dedupe)) { + continue; + } + seen.add(dedupe); + normalized.push(trimmed); + } + if (normalized.length > MAX_INPUT_IMAGES) { + throw new ToolInputError( + `Too many reference images: ${normalized.length} provided, maximum is ${MAX_INPUT_IMAGES}.`, + ); + } + return normalized; +} + +type ImageGenerateSandboxConfig = { + root: string; + bridge: SandboxFsBridge; +}; + +async function loadReferenceImages(params: { + imageInputs: string[]; + maxBytes?: number; + localRoots: string[]; + sandboxConfig: { root: string; bridge: SandboxFsBridge; workspaceOnly: boolean } | null; +}): Promise< + Array<{ + sourceImage: ImageGenerationSourceImage; + resolvedImage: string; + rewrittenFrom?: string; + }> +> { + const loaded: Array<{ + sourceImage: ImageGenerationSourceImage; + resolvedImage: string; + rewrittenFrom?: string; + }> = []; + + for (const imageRawInput of params.imageInputs) { + const trimmed = imageRawInput.trim(); + const imageRaw = trimmed.startsWith("@") ? trimmed.slice(1).trim() : trimmed; + if (!imageRaw) { + throw new ToolInputError("image required (empty string in array)"); + } + const looksLikeWindowsDrivePath = /^[a-zA-Z]:[\\/]/.test(imageRaw); + const hasScheme = /^[a-z][a-z0-9+.-]*:/i.test(imageRaw); + const isFileUrl = /^file:/i.test(imageRaw); + const isHttpUrl = /^https?:\/\//i.test(imageRaw); + const isDataUrl = /^data:/i.test(imageRaw); + if (hasScheme && !looksLikeWindowsDrivePath && !isFileUrl && !isHttpUrl && !isDataUrl) { + throw new ToolInputError( + `Unsupported image reference: ${imageRawInput}. Use a file path, a file:// URL, a data: URL, or an http(s) URL.`, + ); + } + if (params.sandboxConfig && isHttpUrl) { + throw new ToolInputError("Sandboxed image_generate does not allow remote URLs."); + } + + const resolvedImage = (() => { + if (params.sandboxConfig) { + return imageRaw; + } + if (imageRaw.startsWith("~")) { + return resolveUserPath(imageRaw); + } + return imageRaw; + })(); + + const resolvedPathInfo: { resolved: string; rewrittenFrom?: string } = isDataUrl + ? { resolved: "" } + : params.sandboxConfig + ? await resolveSandboxedBridgeMediaPath({ + sandbox: params.sandboxConfig, + mediaPath: resolvedImage, + inboundFallbackDir: "media/inbound", + }) + : { + resolved: resolvedImage.startsWith("file://") + ? resolvedImage.slice("file://".length) + : resolvedImage, + }; + const resolvedPath = isDataUrl ? null : resolvedPathInfo.resolved; + + const media = isDataUrl + ? decodeDataUrl(resolvedImage) + : params.sandboxConfig + ? await loadWebMedia(resolvedPath ?? resolvedImage, { + maxBytes: params.maxBytes, + sandboxValidated: true, + readFile: createSandboxBridgeReadFile({ sandbox: params.sandboxConfig }), + }) + : await loadWebMedia(resolvedPath ?? resolvedImage, { + maxBytes: params.maxBytes, + localRoots: params.localRoots, + }); + if (media.kind !== "image") { + throw new ToolInputError(`Unsupported media type: ${media.kind}`); + } + + const mimeType = + ("contentType" in media && media.contentType) || + ("mimeType" in media && media.mimeType) || + "image/png"; + + loaded.push({ + sourceImage: { + buffer: media.buffer, + mimeType, + }, + resolvedImage, + ...(resolvedPathInfo.rewrittenFrom ? { rewrittenFrom: resolvedPathInfo.rewrittenFrom } : {}), + }); + } + + return loaded; +} + +async function inferResolutionFromInputImages( + images: ImageGenerationSourceImage[], +): Promise { + let maxDimension = 0; + for (const image of images) { + const meta = await getImageMetadata(image.buffer); + const dimension = Math.max(meta?.width ?? 0, meta?.height ?? 0); + maxDimension = Math.max(maxDimension, dimension); + } + if (maxDimension >= 3000) { + return "4K"; + } + if (maxDimension >= 1500) { + return "2K"; + } + return DEFAULT_RESOLUTION; +} + +export function createImageGenerateTool(options?: { + config?: OpenClawConfig; + agentDir?: string; + workspaceDir?: string; + sandbox?: ImageGenerateSandboxConfig; + fsPolicy?: ToolFsPolicy; +}): AnyAgentTool | null { + const cfg = options?.config ?? loadConfig(); + if (!hasConfiguredImageGenerationModel(cfg)) { + return null; + } + const localRoots = resolveMediaToolLocalRoots(options?.workspaceDir, { + workspaceOnly: options?.fsPolicy?.workspaceOnly === true, + }); + const sandboxConfig = + options?.sandbox && options.sandbox.root.trim() + ? { + root: options.sandbox.root.trim(), + bridge: options.sandbox.bridge, + workspaceOnly: options.fsPolicy?.workspaceOnly === true, + } + : null; + + return { + label: "Image Generation", + name: "image_generate", + description: + 'Generate new images or edit reference images with the configured image-generation model. Use action="list" to inspect available providers/models. Generated images are delivered automatically from the tool result as MEDIA paths.', + parameters: ImageGenerateToolSchema, + execute: async (_toolCallId, args) => { + const params = args as Record; + const action = resolveAction(params); + if (action === "list") { + const providers = listRuntimeImageGenerationProviders({ config: cfg }).map((provider) => ({ + id: provider.id, + ...(provider.label ? { label: provider.label } : {}), + ...(provider.defaultModel ? { defaultModel: provider.defaultModel } : {}), + models: provider.models ?? (provider.defaultModel ? [provider.defaultModel] : []), + ...(provider.supportedSizes ? { supportedSizes: [...provider.supportedSizes] } : {}), + ...(provider.supportedResolutions + ? { supportedResolutions: [...provider.supportedResolutions] } + : {}), + ...(typeof provider.supportsImageEditing === "boolean" + ? { supportsImageEditing: provider.supportsImageEditing } + : {}), + })); + const lines = providers.flatMap((provider) => { + const caps: string[] = []; + if (provider.supportsImageEditing) { + caps.push("editing"); + } + if ((provider.supportedResolutions?.length ?? 0) > 0) { + caps.push(`resolutions ${provider.supportedResolutions?.join("/")}`); + } + if ((provider.supportedSizes?.length ?? 0) > 0) { + caps.push(`sizes ${provider.supportedSizes?.join(", ")}`); + } + const modelLine = + provider.models.length > 0 + ? `models: ${provider.models.join(", ")}` + : "models: unknown"; + return [ + `${provider.id}${provider.defaultModel ? ` (default ${provider.defaultModel})` : ""}`, + ` ${modelLine}`, + ...(caps.length > 0 ? [` capabilities: ${caps.join("; ")}`] : []), + ]; + }); + return { + content: [{ type: "text", text: lines.join("\n") }], + details: { providers }, + }; + } + + const prompt = readStringParam(params, "prompt", { required: true }); + const imageInputs = normalizeReferenceImages(params); + const model = readStringParam(params, "model"); + const size = readStringParam(params, "size"); + const explicitResolution = normalizeResolution(readStringParam(params, "resolution")); + const count = resolveRequestedCount(params); + const loadedReferenceImages = await loadReferenceImages({ + imageInputs, + localRoots, + sandboxConfig, + }); + const inputImages = loadedReferenceImages.map((entry) => entry.sourceImage); + const resolution = + explicitResolution ?? + (size + ? undefined + : inputImages.length > 0 + ? await inferResolutionFromInputImages(inputImages) + : undefined); + + const result = await generateImage({ + cfg, + prompt, + agentDir: options?.agentDir, + modelOverride: model, + size, + resolution, + count, + inputImages, + }); + + const savedImages = await Promise.all( + result.images.map((image) => + saveMediaBuffer( + image.buffer, + image.mimeType, + "tool-image-generation", + undefined, + image.fileName, + ), + ), + ); + + const revisedPrompts = result.images + .map((image) => image.revisedPrompt?.trim()) + .filter((entry): entry is string => Boolean(entry)); + const lines = [ + `Generated ${savedImages.length} image${savedImages.length === 1 ? "" : "s"} with ${result.provider}/${result.model}.`, + ...savedImages.map((image) => `MEDIA:${image.path}`), + ]; + + return { + content: [{ type: "text", text: lines.join("\n") }], + details: { + provider: result.provider, + model: result.model, + count: savedImages.length, + paths: savedImages.map((image) => image.path), + ...(imageInputs.length === 1 + ? { + image: loadedReferenceImages[0]?.resolvedImage, + ...(loadedReferenceImages[0]?.rewrittenFrom + ? { rewrittenFrom: loadedReferenceImages[0].rewrittenFrom } + : {}), + } + : imageInputs.length > 1 + ? { + images: loadedReferenceImages.map((entry) => ({ + image: entry.resolvedImage, + ...(entry.rewrittenFrom ? { rewrittenFrom: entry.rewrittenFrom } : {}), + })), + } + : {}), + ...(resolution ? { resolution } : {}), + ...(size ? { size } : {}), + attempts: result.attempts, + metadata: result.metadata, + ...(revisedPrompts.length > 0 ? { revisedPrompts } : {}), + }, + }; + }, + }; +} diff --git a/src/image-generation/providers/google.test.ts b/src/image-generation/providers/google.test.ts index 83f7e565a80..224779f3429 100644 --- a/src/image-generation/providers/google.test.ts +++ b/src/image-generation/providers/google.test.ts @@ -131,4 +131,78 @@ describe("Google image-generation provider", () => { model: "gemini-3.1-flash-image-preview", }); }); + + it("sends reference images and explicit resolution for edit flows", async () => { + vi.spyOn(modelAuth, "resolveApiKeyForProvider").mockResolvedValue({ + apiKey: "google-test-key", + source: "env", + mode: "api-key", + }); + const fetchMock = vi.fn().mockResolvedValue({ + ok: true, + json: async () => ({ + candidates: [ + { + content: { + parts: [ + { + inlineData: { + mimeType: "image/png", + data: Buffer.from("png-data").toString("base64"), + }, + }, + ], + }, + }, + ], + }), + }); + vi.stubGlobal("fetch", fetchMock); + + const provider = buildGoogleImageGenerationProvider(); + await provider.generateImage({ + provider: "google", + model: "gemini-3-pro-image-preview", + prompt: "Change only the sky to a sunset.", + cfg: {}, + resolution: "4K", + inputImages: [ + { + buffer: Buffer.from("reference-bytes"), + mimeType: "image/png", + fileName: "reference.png", + }, + ], + }); + + expect(fetchMock).toHaveBeenCalledWith( + "https://generativelanguage.googleapis.com/v1beta/models/gemini-3-pro-image-preview:generateContent", + expect.objectContaining({ + method: "POST", + body: JSON.stringify({ + contents: [ + { + role: "user", + parts: [ + { + inlineData: { + mimeType: "image/png", + data: Buffer.from("reference-bytes").toString("base64"), + }, + }, + { text: "Change only the sky to a sunset." }, + ], + }, + ], + generationConfig: { + responseModalities: ["TEXT", "IMAGE"], + imageConfig: { + aspectRatio: "1:1", + imageSize: "4K", + }, + }, + }), + }), + ); + }); }); diff --git a/src/image-generation/providers/google.ts b/src/image-generation/providers/google.ts index 0519aef7bc3..f7469b147fa 100644 --- a/src/image-generation/providers/google.ts +++ b/src/image-generation/providers/google.ts @@ -79,11 +79,16 @@ export function buildGoogleImageGenerationProvider(): ImageGenerationProviderPlu return { id: "google", label: "Google", + defaultModel: DEFAULT_GOOGLE_IMAGE_MODEL, + models: [DEFAULT_GOOGLE_IMAGE_MODEL, "gemini-3-pro-image-preview"], + supportedResolutions: ["1K", "2K", "4K"], + supportsImageEditing: true, async generateImage(req) { const auth = await resolveApiKeyForProvider({ provider: "google", cfg: req.cfg, agentDir: req.agentDir, + store: req.authStore, }); if (!auth.apiKey) { throw new Error("Google API key missing"); @@ -98,6 +103,16 @@ export function buildGoogleImageGenerationProvider(): ImageGenerationProviderPlu const authHeaders = parseGeminiAuth(auth.apiKey); const headers = new Headers(authHeaders.headers); const imageConfig = mapSizeToImageConfig(req.size); + const inputParts = (req.inputImages ?? []).map((image) => ({ + inlineData: { + mimeType: image.mimeType, + data: image.buffer.toString("base64"), + }, + })); + const resolvedImageConfig = { + ...imageConfig, + ...(req.resolution ? { imageSize: req.resolution } : {}), + }; const { response: res, release } = await postJsonRequest({ url: `${baseUrl}/models/${model}:generateContent`, @@ -106,12 +121,14 @@ export function buildGoogleImageGenerationProvider(): ImageGenerationProviderPlu contents: [ { role: "user", - parts: [{ text: req.prompt }], + parts: [...inputParts, { text: req.prompt }], }, ], generationConfig: { responseModalities: ["TEXT", "IMAGE"], - ...(imageConfig ? { imageConfig } : {}), + ...(Object.keys(resolvedImageConfig).length > 0 + ? { imageConfig: resolvedImageConfig } + : {}), }, }, timeoutMs: 60_000, diff --git a/src/image-generation/providers/openai.test.ts b/src/image-generation/providers/openai.test.ts index a55e6107d3b..a128d6c6e04 100644 --- a/src/image-generation/providers/openai.test.ts +++ b/src/image-generation/providers/openai.test.ts @@ -8,7 +8,7 @@ describe("OpenAI image-generation provider", () => { }); it("generates PNG buffers from the OpenAI Images API", async () => { - vi.spyOn(modelAuth, "resolveApiKeyForProvider").mockResolvedValue({ + const resolveApiKeySpy = vi.spyOn(modelAuth, "resolveApiKeyForProvider").mockResolvedValue({ apiKey: "sk-test", source: "env", mode: "api-key", @@ -27,17 +27,31 @@ describe("OpenAI image-generation provider", () => { vi.stubGlobal("fetch", fetchMock); const provider = buildOpenAIImageGenerationProvider(); + const authStore = { version: 1, profiles: {} }; const result = await provider.generateImage({ provider: "openai", model: "gpt-image-1", prompt: "draw a cat", cfg: {}, + authStore, }); + expect(resolveApiKeySpy).toHaveBeenCalledWith( + expect.objectContaining({ + provider: "openai", + store: authStore, + }), + ); expect(fetchMock).toHaveBeenCalledWith( "https://api.openai.com/v1/images/generations", expect.objectContaining({ method: "POST", + body: JSON.stringify({ + model: "gpt-image-1", + prompt: "draw a cat", + n: 1, + size: "1024x1024", + }), }), ); expect(result).toEqual({ @@ -52,4 +66,18 @@ describe("OpenAI image-generation provider", () => { model: "gpt-image-1", }); }); + + it("rejects reference-image edits for now", async () => { + const provider = buildOpenAIImageGenerationProvider(); + + await expect( + provider.generateImage({ + provider: "openai", + model: "gpt-image-1", + prompt: "Edit this image", + cfg: {}, + inputImages: [{ buffer: Buffer.from("x"), mimeType: "image/png" }], + }), + ).rejects.toThrow("does not support reference-image edits"); + }); }); diff --git a/src/image-generation/providers/openai.ts b/src/image-generation/providers/openai.ts index 0c7788fb5d5..1a0afe1f67d 100644 --- a/src/image-generation/providers/openai.ts +++ b/src/image-generation/providers/openai.ts @@ -22,12 +22,18 @@ export function buildOpenAIImageGenerationProvider(): ImageGenerationProviderPlu return { id: "openai", label: "OpenAI", + defaultModel: DEFAULT_OPENAI_IMAGE_MODEL, + models: [DEFAULT_OPENAI_IMAGE_MODEL], supportedSizes: ["1024x1024", "1024x1536", "1536x1024"], async generateImage(req) { + if ((req.inputImages?.length ?? 0) > 0) { + throw new Error("OpenAI image generation provider does not support reference-image edits"); + } const auth = await resolveApiKeyForProvider({ provider: "openai", cfg: req.cfg, agentDir: req.agentDir, + store: req.authStore, }); if (!auth.apiKey) { throw new Error("OpenAI API key missing"); @@ -44,7 +50,6 @@ export function buildOpenAIImageGenerationProvider(): ImageGenerationProviderPlu prompt: req.prompt, n: req.count ?? 1, size: req.size ?? DEFAULT_SIZE, - response_format: "b64_json", }), }); diff --git a/src/image-generation/runtime.test.ts b/src/image-generation/runtime.test.ts index 4ef478b3349..b044c899c60 100644 --- a/src/image-generation/runtime.test.ts +++ b/src/image-generation/runtime.test.ts @@ -11,13 +11,16 @@ describe("image-generation runtime helpers", () => { it("generates images through the active image-generation registry", async () => { const pluginRegistry = createEmptyPluginRegistry(); + const authStore = { version: 1, profiles: {} } as const; + let seenAuthStore: unknown; pluginRegistry.imageGenerationProviders.push({ pluginId: "image-plugin", pluginName: "Image Plugin", source: "test", provider: { id: "image-plugin", - async generateImage() { + async generateImage(req) { + seenAuthStore = req.authStore; return { images: [ { @@ -47,11 +50,13 @@ describe("image-generation runtime helpers", () => { cfg, prompt: "draw a cat", agentDir: "/tmp/agent", + authStore, }); expect(result.provider).toBe("image-plugin"); expect(result.model).toBe("img-v1"); expect(result.attempts).toEqual([]); + expect(seenAuthStore).toEqual(authStore); expect(result.images).toEqual([ { buffer: Buffer.from("png-bytes"), @@ -69,6 +74,9 @@ describe("image-generation runtime helpers", () => { source: "test", provider: { id: "image-plugin", + defaultModel: "img-v1", + models: ["img-v1", "img-v2"], + supportedResolutions: ["1K", "2K"], generateImage: async () => ({ images: [{ buffer: Buffer.from("x"), mimeType: "image/png" }], }), @@ -76,6 +84,13 @@ describe("image-generation runtime helpers", () => { }); setActivePluginRegistry(pluginRegistry); - expect(listRuntimeImageGenerationProviders()).toMatchObject([{ id: "image-plugin" }]); + expect(listRuntimeImageGenerationProviders()).toMatchObject([ + { + id: "image-plugin", + defaultModel: "img-v1", + models: ["img-v1", "img-v2"], + supportedResolutions: ["1K", "2K"], + }, + ]); }); }); diff --git a/src/image-generation/runtime.ts b/src/image-generation/runtime.ts index 8c9104edd5d..f25048cd0b1 100644 --- a/src/image-generation/runtime.ts +++ b/src/image-generation/runtime.ts @@ -1,3 +1,4 @@ +import type { AuthProfileStore } from "../agents/auth-profiles.js"; import { describeFailoverError, isFailoverError } from "../agents/failover-error.js"; import type { FallbackAttempt } from "../agents/model-fallback.types.js"; import type { OpenClawConfig } from "../config/config.js"; @@ -7,7 +8,12 @@ import { } from "../config/model-input.js"; import { createSubsystemLogger } from "../logging/subsystem.js"; import { getImageGenerationProvider, listImageGenerationProviders } from "./provider-registry.js"; -import type { GeneratedImageAsset, ImageGenerationResult } from "./types.js"; +import type { + GeneratedImageAsset, + ImageGenerationResolution, + ImageGenerationResult, + ImageGenerationSourceImage, +} from "./types.js"; const log = createSubsystemLogger("image-generation"); @@ -15,9 +21,12 @@ export type GenerateImageParams = { cfg: OpenClawConfig; prompt: string; agentDir?: string; + authStore?: AuthProfileStore; modelOverride?: string; count?: number; size?: string; + resolution?: ImageGenerationResolution; + inputImages?: ImageGenerationSourceImage[]; }; export type GenerateImageRuntimeResult = { @@ -130,8 +139,11 @@ export async function generateImage( prompt: params.prompt, cfg: params.cfg, agentDir: params.agentDir, + authStore: params.authStore, count: params.count, size: params.size, + resolution: params.resolution, + inputImages: params.inputImages, }); if (!Array.isArray(result.images) || result.images.length === 0) { throw new Error("Image generation provider returned no images."); diff --git a/src/image-generation/types.ts b/src/image-generation/types.ts index ff33d6079ee..7ea530ac2b9 100644 --- a/src/image-generation/types.ts +++ b/src/image-generation/types.ts @@ -1,3 +1,4 @@ +import type { AuthProfileStore } from "../agents/auth-profiles.js"; import type { OpenClawConfig } from "../config/config.js"; export type GeneratedImageAsset = { @@ -8,14 +9,26 @@ export type GeneratedImageAsset = { metadata?: Record; }; +export type ImageGenerationResolution = "1K" | "2K" | "4K"; + +export type ImageGenerationSourceImage = { + buffer: Buffer; + mimeType: string; + fileName?: string; + metadata?: Record; +}; + export type ImageGenerationRequest = { provider: string; model: string; prompt: string; cfg: OpenClawConfig; agentDir?: string; + authStore?: AuthProfileStore; count?: number; size?: string; + resolution?: ImageGenerationResolution; + inputImages?: ImageGenerationSourceImage[]; }; export type ImageGenerationResult = { @@ -28,6 +41,10 @@ export type ImageGenerationProvider = { id: string; aliases?: string[]; label?: string; + defaultModel?: string; + models?: string[]; supportedSizes?: string[]; + supportedResolutions?: ImageGenerationResolution[]; + supportsImageEditing?: boolean; generateImage: (req: ImageGenerationRequest) => Promise; }; diff --git a/src/plugin-sdk/image-generation.ts b/src/plugin-sdk/image-generation.ts index d9afa8b3a3d..25fde2e9d2b 100644 --- a/src/plugin-sdk/image-generation.ts +++ b/src/plugin-sdk/image-generation.ts @@ -3,8 +3,10 @@ export type { GeneratedImageAsset, ImageGenerationProvider, + ImageGenerationResolution, ImageGenerationRequest, ImageGenerationResult, + ImageGenerationSourceImage, } from "../image-generation/types.js"; export { buildGoogleImageGenerationProvider } from "../image-generation/providers/google.js";