"""Factory helpers that build LangChain chat models from application settings."""

from enum import Enum
from functools import lru_cache

from langchain_core.language_models.chat_models import BaseChatModel

from app.core.config import get_settings


class LLMProvider(str, Enum):
    """LLM vendors this module can build chat models for."""

    ANTHROPIC = "anthropic"
    OPENAI = "openai"


class ModelTier(str, Enum):
    """Capability tier used to pick a model name for a provider."""

    FAST = "fast"
    STRONG = "strong"


# Built-in model names per provider and tier, used when the settings do not
# supply an explicit override for the requested tier.
_DEFAULTS: dict[LLMProvider, dict[ModelTier, str]] = {
    LLMProvider.ANTHROPIC: {
        ModelTier.FAST: "claude-haiku-4-5-20251001",
        ModelTier.STRONG: "claude-sonnet-4-6",
    },
    LLMProvider.OPENAI: {
        ModelTier.FAST: "gpt-4.1-mini",
        ModelTier.STRONG: "gpt-4.1",
    },
}


def detect_provider() -> LLMProvider:
    """Return the provider whose API key is configured in the settings.

    Anthropic takes precedence when both keys are set.

    Raises:
        RuntimeError: If neither ``anthropic_api_key`` nor ``openai_api_key``
            is set.
    """
    settings = get_settings()
    if settings.anthropic_api_key:
        return LLMProvider.ANTHROPIC
    if settings.openai_api_key:
        return LLMProvider.OPENAI
    raise RuntimeError(
        "No LLM API key configured. Set ANTHROPIC_API_KEY or OPENAI_API_KEY."
    )


def _resolve_model_name(provider: LLMProvider, tier: ModelTier) -> str:
    """Return the model name to use for ``provider`` at ``tier``.

    ``settings.llm_fast_model`` / ``settings.llm_strong_model`` win when they
    are set (truthy); otherwise the name comes from ``_DEFAULTS``.

    Raises:
        KeyError: If no override applies and ``provider``/``tier`` is not a
            key of ``_DEFAULTS``.
    """
    settings = get_settings()
    # NOTE(review): the overrides are not keyed by provider, so an override
    # configured for one vendor is also sent to the other when a caller
    # passes ``provider`` explicitly — confirm this is intended.
    if tier == ModelTier.FAST and settings.llm_fast_model:
        return settings.llm_fast_model
    if tier == ModelTier.STRONG and settings.llm_strong_model:
        return settings.llm_strong_model
    return _DEFAULTS[provider][tier]


def _build_chat_model(
    provider: LLMProvider, model_name: str, temperature: float
) -> BaseChatModel:
    """Instantiate the LangChain chat client for ``provider``.

    Each vendor package is imported inside its own branch, so only the
    package for the selected provider has to be installed.

    Raises:
        ValueError: If ``provider`` is neither Anthropic nor OpenAI.
    """
    settings = get_settings()
    if provider == LLMProvider.ANTHROPIC:
        from langchain_anthropic import ChatAnthropic

        return ChatAnthropic(
            model=model_name,
            api_key=settings.anthropic_api_key,
            temperature=temperature,
            # Response-length cap; only the Anthropic client is given one here.
            max_tokens=4096,
        )
    if provider == LLMProvider.OPENAI:
        from langchain_openai import ChatOpenAI

        return ChatOpenAI(
            model=model_name,
            api_key=settings.openai_api_key,
            temperature=temperature,
        )
    # Fail loudly instead of silently building an OpenAI client with the
    # wrong credentials for a provider this factory does not know about.
    raise ValueError(f"Unsupported LLM provider: {provider!r}")


def get_chat_model(
    tier: ModelTier = ModelTier.STRONG,
    temperature: float = 0.0,
    provider: LLMProvider | None = None,
) -> BaseChatModel:
    """Build a new (uncached) chat model for the requested tier.

    Args:
        tier: Capability tier; selects the model name from the settings
            overrides or ``_DEFAULTS``.
        temperature: Sampling temperature forwarded to the client.
        provider: Vendor to use. When ``None``, ``detect_provider()`` chooses
            one from the configured API keys. NOTE(review): an explicitly
            passed provider's API key is not checked here and may be unset.

    Returns:
        A freshly constructed chat model instance.

    Raises:
        RuntimeError: If ``provider`` is ``None`` and no API key is set.
        KeyError | ValueError: If ``provider`` is not a supported
            ``LLMProvider``.
    """
    resolved_provider = provider if provider is not None else detect_provider()
    model_name = _resolve_model_name(resolved_provider, tier)
    return _build_chat_model(resolved_provider, model_name, temperature)


@lru_cache
def get_fast_model() -> BaseChatModel:
    """Return the shared FAST-tier model (auto-detected provider, temperature 0.0).

    The instance is built on the first call and memoized by ``lru_cache``;
    settings changes made afterwards are not picked up.
    """
    return get_chat_model(tier=ModelTier.FAST)


@lru_cache
def get_strong_model() -> BaseChatModel:
    """Return the shared STRONG-tier model (auto-detected provider, temperature 0.0).

    The instance is built on the first call and memoized by ``lru_cache``;
    settings changes made afterwards are not picked up.
    """
    return get_chat_model(tier=ModelTier.STRONG)