feat(agents): infer image generation defaults

This commit is contained in:
Peter Steinberger 2026-03-17 09:23:15 -07:00
parent 9f8cf7f71a
commit 0aff1c7630
No known key found for this signature in database
7 changed files with 308 additions and 141 deletions

View File

@ -1,5 +1,6 @@
import { describe, expect, it, vi } from "vitest";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import type { OpenClawConfig } from "../config/config.js";
import * as imageGenerationRuntime from "../image-generation/runtime.js";
import { createOpenClawTools } from "./openclaw-tools.js";
vi.mock("../plugins/tools.js", () => ({
@ -10,7 +11,33 @@ function asConfig(value: unknown): OpenClawConfig {
return value as OpenClawConfig;
}
function stubImageGenerationProviders() {
vi.spyOn(imageGenerationRuntime, "listRuntimeImageGenerationProviders").mockReturnValue([
{
id: "openai",
defaultModel: "gpt-image-1",
models: ["gpt-image-1"],
supportedSizes: ["1024x1024"],
generateImage: vi.fn(async () => {
throw new Error("not used");
}),
},
]);
}
describe("openclaw tools image generation registration", () => {
beforeEach(() => {
vi.stubEnv("OPENAI_API_KEY", "");
vi.stubEnv("OPENAI_API_KEYS", "");
vi.stubEnv("GEMINI_API_KEY", "");
vi.stubEnv("GEMINI_API_KEYS", "");
});
afterEach(() => {
vi.restoreAllMocks();
vi.unstubAllEnvs();
});
it("registers image_generate when image-generation config is present", () => {
const tools = createOpenClawTools({
config: asConfig({
@ -28,7 +55,21 @@ describe("openclaw tools image generation registration", () => {
expect(tools.map((tool) => tool.name)).toContain("image_generate");
});
it("omits image_generate when image-generation config is absent", () => {
it("registers image_generate when a compatible provider has env-backed auth", () => {
stubImageGenerationProviders();
vi.stubEnv("OPENAI_API_KEY", "openai-test");
const tools = createOpenClawTools({
config: asConfig({}),
agentDir: "/tmp/openclaw-agent-main",
});
expect(tools.map((tool) => tool.name)).toContain("image_generate");
});
it("omits image_generate when config is absent and no compatible provider auth exists", () => {
stubImageGenerationProviders();
const tools = createOpenClawTools({
config: asConfig({}),
agentDir: "/tmp/openclaw-agent-main",

View File

@ -1,19 +1,89 @@
import { afterEach, describe, expect, it, vi } from "vitest";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import * as imageGenerationRuntime from "../../image-generation/runtime.js";
import * as imageOps from "../../media/image-ops.js";
import * as mediaStore from "../../media/store.js";
import * as webMedia from "../../plugin-sdk/web-media.js";
import { createImageGenerateTool } from "./image-generate-tool.js";
import {
createImageGenerateTool,
resolveImageGenerationModelConfigForTool,
} from "./image-generate-tool.js";
function stubImageGenerationProviders() {
vi.spyOn(imageGenerationRuntime, "listRuntimeImageGenerationProviders").mockReturnValue([
{
id: "google",
defaultModel: "gemini-3.1-flash-image-preview",
models: ["gemini-3.1-flash-image-preview", "gemini-3-pro-image-preview"],
supportedResolutions: ["1K", "2K", "4K"],
supportsImageEditing: true,
generateImage: vi.fn(async () => {
throw new Error("not used");
}),
},
{
id: "openai",
defaultModel: "gpt-image-1",
models: ["gpt-image-1"],
supportedSizes: ["1024x1024", "1024x1536", "1536x1024"],
supportsImageEditing: false,
generateImage: vi.fn(async () => {
throw new Error("not used");
}),
},
]);
}
describe("createImageGenerateTool", () => {
afterEach(() => {
vi.restoreAllMocks();
beforeEach(() => {
vi.stubEnv("OPENAI_API_KEY", "");
vi.stubEnv("OPENAI_API_KEYS", "");
vi.stubEnv("GEMINI_API_KEY", "");
vi.stubEnv("GEMINI_API_KEYS", "");
});
it("returns null when image-generation model is not configured", () => {
afterEach(() => {
vi.restoreAllMocks();
vi.unstubAllEnvs();
});
it("returns null when no image-generation model can be inferred", () => {
stubImageGenerationProviders();
expect(createImageGenerateTool({ config: {} })).toBeNull();
});
it("infers an OpenAI image-generation model from env-backed auth", () => {
stubImageGenerationProviders();
vi.stubEnv("OPENAI_API_KEY", "openai-test");
expect(resolveImageGenerationModelConfigForTool({ cfg: {} })).toEqual({
primary: "openai/gpt-image-1",
});
expect(createImageGenerateTool({ config: {} })).not.toBeNull();
});
it("prefers the primary model provider when multiple image providers have auth", () => {
stubImageGenerationProviders();
vi.stubEnv("OPENAI_API_KEY", "openai-test");
vi.stubEnv("GEMINI_API_KEY", "gemini-test");
expect(
resolveImageGenerationModelConfigForTool({
cfg: {
agents: {
defaults: {
model: {
primary: "google/gemini-3.1-pro-preview",
},
},
},
},
}),
).toEqual({
primary: "google/gemini-3.1-flash-image-preview",
fallbacks: ["openai/gpt-image-1"],
});
});
it("generates images and returns MEDIA paths", async () => {
const generateImage = vi.spyOn(imageGenerationRuntime, "generateImage").mockResolvedValue({
provider: "openai",
@ -215,28 +285,7 @@ describe("createImageGenerateTool", () => {
});
it("lists registered provider and model options", async () => {
vi.spyOn(imageGenerationRuntime, "listRuntimeImageGenerationProviders").mockReturnValue([
{
id: "google",
defaultModel: "gemini-3.1-flash-image-preview",
models: ["gemini-3.1-flash-image-preview", "gemini-3-pro-image-preview"],
supportedResolutions: ["1K", "2K", "4K"],
supportsImageEditing: true,
generateImage: vi.fn(async () => {
throw new Error("not used");
}),
},
{
id: "openai",
defaultModel: "gpt-image-1",
models: ["gpt-image-1"],
supportedSizes: ["1024x1024", "1024x1536", "1536x1024"],
supportsImageEditing: false,
generateImage: vi.fn(async () => {
throw new Error("not used");
}),
},
]);
stubImageGenerationProviders();
const tool = createImageGenerateTool({
config: {

View File

@ -15,7 +15,17 @@ import { loadWebMedia } from "../../plugin-sdk/web-media.js";
import { resolveUserPath } from "../../utils.js";
import { ToolInputError, readNumberParam, readStringParam } from "./common.js";
import { decodeDataUrl } from "./image-tool.helpers.js";
import { resolveMediaToolLocalRoots } from "./media-tool-shared.js";
import {
applyImageGenerationModelConfigDefaults,
resolveMediaToolLocalRoots,
} from "./media-tool-shared.js";
import {
buildToolModelConfigFromCandidates,
coerceToolModelConfig,
hasToolModelConfig,
resolveDefaultModelRef,
type ToolModelConfig,
} from "./model-config.helpers.js";
import {
createSandboxBridgeReadFile,
resolveSandboxedBridgeMediaPath,
@ -71,15 +81,51 @@ const ImageGenerateToolSchema = Type.Object({
),
});
function hasConfiguredImageGenerationModel(cfg: OpenClawConfig): boolean {
const configured = cfg.agents?.defaults?.imageGenerationModel;
if (typeof configured === "string") {
return configured.trim().length > 0;
function resolveImageGenerationModelCandidates(
cfg: OpenClawConfig | undefined,
): Array<string | undefined> {
const providerDefaults = new Map<string, string>();
for (const provider of listRuntimeImageGenerationProviders({ config: cfg })) {
const providerId = provider.id.trim();
const modelId = provider.defaultModel?.trim();
if (!providerId || !modelId || providerDefaults.has(providerId)) {
continue;
}
providerDefaults.set(providerId, `${providerId}/${modelId}`);
}
if (configured?.primary?.trim()) {
return true;
const orderedProviders = [
resolveDefaultModelRef(cfg).provider,
"openai",
"google",
...providerDefaults.keys(),
];
const orderedRefs: string[] = [];
const seen = new Set<string>();
for (const providerId of orderedProviders) {
const ref = providerDefaults.get(providerId);
if (!ref || seen.has(ref)) {
continue;
}
seen.add(ref);
orderedRefs.push(ref);
}
return (configured?.fallbacks ?? []).some((entry) => entry.trim().length > 0);
return orderedRefs;
}
export function resolveImageGenerationModelConfigForTool(params: {
cfg?: OpenClawConfig;
agentDir?: string;
}): ToolModelConfig | null {
const explicit = coerceToolModelConfig(params.cfg?.agents?.defaults?.imageGenerationModel);
if (hasToolModelConfig(explicit)) {
return explicit;
}
return buildToolModelConfigFromCandidates({
explicit,
agentDir: params.agentDir,
candidates: resolveImageGenerationModelCandidates(params.cfg),
});
}
function resolveAction(args: Record<string, unknown>): "generate" | "list" {
@ -274,9 +320,15 @@ export function createImageGenerateTool(options?: {
fsPolicy?: ToolFsPolicy;
}): AnyAgentTool | null {
const cfg = options?.config ?? loadConfig();
if (!hasConfiguredImageGenerationModel(cfg)) {
const imageGenerationModelConfig = resolveImageGenerationModelConfigForTool({
cfg,
agentDir: options?.agentDir,
});
if (!imageGenerationModelConfig) {
return null;
}
const effectiveCfg =
applyImageGenerationModelConfigDefaults(cfg, imageGenerationModelConfig) ?? cfg;
const localRoots = resolveMediaToolLocalRoots(options?.workspaceDir, {
workspaceOnly: options?.fsPolicy?.workspaceOnly === true,
});
@ -293,25 +345,27 @@ export function createImageGenerateTool(options?: {
label: "Image Generation",
name: "image_generate",
description:
'Generate new images or edit reference images with the configured image-generation model. Use action="list" to inspect available providers/models. Generated images are delivered automatically from the tool result as MEDIA paths.',
'Generate new images or edit reference images with the configured or inferred image-generation model. Use action="list" to inspect available providers/models. Generated images are delivered automatically from the tool result as MEDIA paths.',
parameters: ImageGenerateToolSchema,
execute: async (_toolCallId, args) => {
const params = args as Record<string, unknown>;
const action = resolveAction(params);
if (action === "list") {
const providers = listRuntimeImageGenerationProviders({ config: cfg }).map((provider) => ({
id: provider.id,
...(provider.label ? { label: provider.label } : {}),
...(provider.defaultModel ? { defaultModel: provider.defaultModel } : {}),
models: provider.models ?? (provider.defaultModel ? [provider.defaultModel] : []),
...(provider.supportedSizes ? { supportedSizes: [...provider.supportedSizes] } : {}),
...(provider.supportedResolutions
? { supportedResolutions: [...provider.supportedResolutions] }
: {}),
...(typeof provider.supportsImageEditing === "boolean"
? { supportsImageEditing: provider.supportsImageEditing }
: {}),
}));
const providers = listRuntimeImageGenerationProviders({ config: effectiveCfg }).map(
(provider) => ({
id: provider.id,
...(provider.label ? { label: provider.label } : {}),
...(provider.defaultModel ? { defaultModel: provider.defaultModel } : {}),
models: provider.models ?? (provider.defaultModel ? [provider.defaultModel] : []),
...(provider.supportedSizes ? { supportedSizes: [...provider.supportedSizes] } : {}),
...(provider.supportedResolutions
? { supportedResolutions: [...provider.supportedResolutions] }
: {}),
...(typeof provider.supportsImageEditing === "boolean"
? { supportsImageEditing: provider.supportsImageEditing }
: {}),
}),
);
const lines = providers.flatMap((provider) => {
const caps: string[] = [];
if (provider.supportsImageEditing) {
@ -360,7 +414,7 @@ export function createImageGenerateTool(options?: {
: undefined);
const result = await generateImage({
cfg,
cfg: effectiveCfg,
prompt,
agentDir: options?.agentDir,
modelOverride: model,

View File

@ -1,12 +1,9 @@
import type { AssistantMessage } from "@mariozechner/pi-ai";
import type { OpenClawConfig } from "../../config/config.js";
import {
resolveAgentModelFallbackValues,
resolveAgentModelPrimaryValue,
} from "../../config/model-input.js";
import { extractAssistantText } from "../pi-embedded-utils.js";
import { coerceToolModelConfig, type ToolModelConfig } from "./model-config.helpers.js";
export type ImageModelConfig = { primary?: string; fallbacks?: string[] };
export type ImageModelConfig = ToolModelConfig;
export function decodeDataUrl(dataUrl: string): {
buffer: Buffer;
@ -55,12 +52,7 @@ export function coerceImageAssistantText(params: {
}
export function coerceImageModelConfig(cfg?: OpenClawConfig): ImageModelConfig {
const primary = resolveAgentModelPrimaryValue(cfg?.agents?.defaults?.imageModel);
const fallbacks = resolveAgentModelFallbackValues(cfg?.agents?.defaults?.imageModel);
return {
...(primary?.trim() ? { primary: primary.trim() } : {}),
...(fallbacks.length > 0 ? { fallbacks } : {}),
};
return coerceToolModelConfig(cfg?.agents?.defaults?.imageModel);
}
export function resolveProviderVisionModelFromConfig(params: {

View File

@ -18,7 +18,11 @@ import {
resolveMediaToolLocalRoots,
resolvePromptAndModelOverride,
} from "./media-tool-shared.js";
import { hasAuthForProvider, resolveDefaultModelRef } from "./model-config.helpers.js";
import {
buildToolModelConfigFromCandidates,
hasToolModelConfig,
resolveDefaultModelRef,
} from "./model-config.helpers.js";
import {
createSandboxBridgeReadFile,
resolveSandboxedBridgeMediaPath,
@ -68,89 +72,40 @@ export function resolveImageModelConfigForTool(params: {
// because images are auto-injected into prompts (see attempt.ts detectAndLoadPromptImages).
// The tool description is adjusted via modelHasVision to discourage redundant usage.
const explicit = coerceImageModelConfig(params.cfg);
if (explicit.primary?.trim() || (explicit.fallbacks?.length ?? 0) > 0) {
if (hasToolModelConfig(explicit)) {
return explicit;
}
const primary = resolveDefaultModelRef(params.cfg);
const openaiOk = hasAuthForProvider({
provider: "openai",
agentDir: params.agentDir,
});
const anthropicOk = hasAuthForProvider({
provider: "anthropic",
agentDir: params.agentDir,
});
const fallbacks: string[] = [];
const addFallback = (modelRef: string | null) => {
const ref = (modelRef ?? "").trim();
if (!ref) {
return;
}
if (fallbacks.includes(ref)) {
return;
}
fallbacks.push(ref);
};
const providerVisionFromConfig = resolveProviderVisionModelFromConfig({
cfg: params.cfg,
provider: primary.provider,
});
const providerOk = hasAuthForProvider({
provider: primary.provider,
const primaryCandidates = (() => {
if (isMinimaxVlmProvider(primary.provider)) {
return [`${primary.provider}/MiniMax-VL-01`];
}
if (providerVisionFromConfig) {
return [providerVisionFromConfig];
}
if (primary.provider === "zai") {
return ["zai/glm-4.6v"];
}
if (primary.provider === "openai") {
return ["openai/gpt-5-mini"];
}
if (primary.provider === "anthropic") {
return [ANTHROPIC_IMAGE_PRIMARY];
}
return [];
})();
return buildToolModelConfigFromCandidates({
explicit,
agentDir: params.agentDir,
candidates: [...primaryCandidates, "openai/gpt-5-mini", ANTHROPIC_IMAGE_FALLBACK],
});
let preferred: string | null = null;
// MiniMax users: always try the canonical vision model first when auth exists.
if (isMinimaxVlmProvider(primary.provider) && providerOk) {
preferred = `${primary.provider}/MiniMax-VL-01`;
} else if (providerOk && providerVisionFromConfig) {
preferred = providerVisionFromConfig;
} else if (primary.provider === "zai" && providerOk) {
preferred = "zai/glm-4.6v";
} else if (primary.provider === "openai" && openaiOk) {
preferred = "openai/gpt-5-mini";
} else if (primary.provider === "anthropic" && anthropicOk) {
preferred = ANTHROPIC_IMAGE_PRIMARY;
}
if (preferred?.trim()) {
if (openaiOk) {
addFallback("openai/gpt-5-mini");
}
if (anthropicOk) {
addFallback(ANTHROPIC_IMAGE_FALLBACK);
}
// Don't duplicate primary in fallbacks.
const pruned = fallbacks.filter((ref) => ref !== preferred);
return {
primary: preferred,
...(pruned.length > 0 ? { fallbacks: pruned } : {}),
};
}
// Cross-provider fallback when we can't pair with the primary provider.
if (openaiOk) {
if (anthropicOk) {
addFallback(ANTHROPIC_IMAGE_FALLBACK);
}
return {
primary: "openai/gpt-5-mini",
...(fallbacks.length ? { fallbacks } : {}),
};
}
if (anthropicOk) {
return {
primary: ANTHROPIC_IMAGE_PRIMARY,
fallbacks: [ANTHROPIC_IMAGE_FALLBACK],
};
}
return null;
}
function pickMaxBytes(cfg?: OpenClawConfig, maxBytesMb?: number): number | undefined {
@ -279,7 +234,7 @@ export function createImageTool(options?: {
const agentDir = options?.agentDir?.trim();
if (!agentDir) {
const explicit = coerceImageModelConfig(options?.config);
if (explicit.primary?.trim() || (explicit.fallbacks?.length ?? 0) > 0) {
if (hasToolModelConfig(explicit)) {
throw new Error("createImageTool requires agentDir when enabled");
}
return null;

View File

@ -2,6 +2,7 @@ import { type Api, type Model } from "@mariozechner/pi-ai";
import type { OpenClawConfig } from "../../config/config.js";
import { getDefaultLocalRoots } from "../../plugin-sdk/web-media.js";
import type { ImageModelConfig } from "./image-tool.helpers.js";
import type { ToolModelConfig } from "./model-config.helpers.js";
import { getApiKeyForModel, normalizeWorkspaceDir, requireApiKey } from "./tool-runtime.helpers.js";
type TextToolAttempt = {
@ -20,6 +21,21 @@ type TextToolResult = {
export function applyImageModelConfigDefaults(
cfg: OpenClawConfig | undefined,
imageModelConfig: ImageModelConfig,
): OpenClawConfig | undefined {
return applyAgentDefaultModelConfig(cfg, "imageModel", imageModelConfig);
}
export function applyImageGenerationModelConfigDefaults(
cfg: OpenClawConfig | undefined,
imageGenerationModelConfig: ToolModelConfig,
): OpenClawConfig | undefined {
return applyAgentDefaultModelConfig(cfg, "imageGenerationModel", imageGenerationModelConfig);
}
function applyAgentDefaultModelConfig(
cfg: OpenClawConfig | undefined,
key: "imageModel" | "imageGenerationModel",
modelConfig: ToolModelConfig,
): OpenClawConfig | undefined {
if (!cfg) {
return undefined;
@ -30,7 +46,7 @@ export function applyImageModelConfigDefaults(
...cfg.agents,
defaults: {
...cfg.agents?.defaults,
imageModel: imageModelConfig,
[key]: modelConfig,
},
},
};

View File

@ -1,9 +1,22 @@
import type { OpenClawConfig } from "../../config/config.js";
import {
resolveAgentModelFallbackValues,
resolveAgentModelPrimaryValue,
} from "../../config/model-input.js";
import type { AgentModelConfig } from "../../config/types.agents-shared.js";
import { ensureAuthProfileStore, listProfilesForProvider } from "../auth-profiles.js";
import { DEFAULT_MODEL, DEFAULT_PROVIDER } from "../defaults.js";
import { resolveEnvApiKey } from "../model-auth.js";
import { resolveConfiguredModelRef } from "../model-selection.js";
export type ToolModelConfig = { primary?: string; fallbacks?: string[] };
export function hasToolModelConfig(model: ToolModelConfig | undefined): boolean {
return Boolean(
model?.primary?.trim() || (model?.fallbacks ?? []).some((entry) => entry.trim().length > 0),
);
}
export function resolveDefaultModelRef(cfg?: OpenClawConfig): { provider: string; model: string } {
if (cfg) {
const resolved = resolveConfiguredModelRef({
@ -16,12 +29,59 @@ export function resolveDefaultModelRef(cfg?: OpenClawConfig): { provider: string
return { provider: DEFAULT_PROVIDER, model: DEFAULT_MODEL };
}
export function hasAuthForProvider(params: { provider: string; agentDir: string }): boolean {
export function hasAuthForProvider(params: { provider: string; agentDir?: string }): boolean {
if (resolveEnvApiKey(params.provider)?.apiKey) {
return true;
}
const store = ensureAuthProfileStore(params.agentDir, {
const agentDir = params.agentDir?.trim();
if (!agentDir) {
return false;
}
const store = ensureAuthProfileStore(agentDir, {
allowKeychainPrompt: false,
});
return listProfilesForProvider(store, params.provider).length > 0;
}
export function coerceToolModelConfig(model?: AgentModelConfig): ToolModelConfig {
const primary = resolveAgentModelPrimaryValue(model);
const fallbacks = resolveAgentModelFallbackValues(model);
return {
...(primary?.trim() ? { primary: primary.trim() } : {}),
...(fallbacks.length > 0 ? { fallbacks } : {}),
};
}
export function buildToolModelConfigFromCandidates(params: {
explicit: ToolModelConfig;
agentDir?: string;
candidates: Array<string | null | undefined>;
}): ToolModelConfig | null {
if (hasToolModelConfig(params.explicit)) {
return params.explicit;
}
const deduped: string[] = [];
for (const candidate of params.candidates) {
const trimmed = candidate?.trim();
if (!trimmed || !trimmed.includes("/")) {
continue;
}
const provider = trimmed.slice(0, trimmed.indexOf("/")).trim();
if (!provider || !hasAuthForProvider({ provider, agentDir: params.agentDir })) {
continue;
}
if (!deduped.includes(trimmed)) {
deduped.push(trimmed);
}
}
if (deduped.length === 0) {
return null;
}
return {
primary: deduped[0],
...(deduped.length > 1 ? { fallbacks: deduped.slice(1) } : {}),
};
}