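"""Provider-agnostic chat model factory.

Detects whether Anthropic or OpenAI credentials are configured and
builds LangChain chat models by speed/quality tier.
"""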
from enum import Enum
from functools import lru_cache

from langchain_core.language_models.chat_models import BaseChatModel

from app.core.config import get_settings
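# get_settings() is assumed (defined in app.core.config, not shown here) to
# return an object exposing at least these fields, each falsy when unset:
#
#     anthropic_api_key, openai_api_key   -- provider credentials
#     llm_fast_model, llm_strong_model    -- optional model-name overrides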
class LLMProvider(str, Enum):
    ANTHROPIC = "anthropic"
    OPENAI = "openai"


class ModelTier(str, Enum):
    FAST = "fast"
    STRONG = "strong"


# Built-in default model names per provider and tier; explicit settings
# overrides take precedence (see _resolve_model_name).
_DEFAULTS = {
    LLMProvider.ANTHROPIC: {
        ModelTier.FAST: "claude-haiku-4-5-20251001",
        ModelTier.STRONG: "claude-sonnet-4-6",
    },
    LLMProvider.OPENAI: {
        ModelTier.FAST: "gpt-4.1-mini",
        ModelTier.STRONG: "gpt-4.1",
    },
}


def detect_provider() -> LLMProvider:
    """Return the first provider with a configured API key.

    Anthropic takes precedence when both keys are set.
    """
    settings = get_settings()
    if settings.anthropic_api_key:
        return LLMProvider.ANTHROPIC
    if settings.openai_api_key:
        return LLMProvider.OPENAI
    raise RuntimeError(
        "No LLM API key configured. Set ANTHROPIC_API_KEY or OPENAI_API_KEY."
    )


def _resolve_model_name(provider: LLMProvider, tier: ModelTier) -> str:
    """Resolve a model name: explicit settings override, else the default."""
    settings = get_settings()
    if tier == ModelTier.FAST and settings.llm_fast_model:
        return settings.llm_fast_model
    if tier == ModelTier.STRONG and settings.llm_strong_model:
        return settings.llm_strong_model
    return _DEFAULTS[provider][tier]


def _build_chat_model(
    provider: LLMProvider, model_name: str, temperature: float
) -> BaseChatModel:
    """Instantiate the LangChain chat model for the given provider."""
    settings = get_settings()

    if provider == LLMProvider.ANTHROPIC:
        # Imported lazily so langchain-anthropic is only required when used.
        from langchain_anthropic import ChatAnthropic

        return ChatAnthropic(
            model=model_name,
            api_key=settings.anthropic_api_key,
            temperature=temperature,
            max_tokens=4096,
        )

    # Imported lazily so langchain-openai is only required when used.
    from langchain_openai import ChatOpenAI

    return ChatOpenAI(
        model=model_name,
        api_key=settings.openai_api_key,
        temperature=temperature,
    )


def get_chat_model(
    tier: ModelTier = ModelTier.STRONG,
    temperature: float = 0.0,
    provider: LLMProvider | None = None,
) -> BaseChatModel:
    """Build a chat model for the requested tier, provider, and temperature."""
    resolved_provider = provider or detect_provider()
    model_name = _resolve_model_name(resolved_provider, tier)
    return _build_chat_model(resolved_provider, model_name, temperature)


@lru_cache
def get_fast_model() -> BaseChatModel:
    """FAST-tier model with default settings, built once per process."""
    return get_chat_model(tier=ModelTier.FAST)


@lru_cache
def get_strong_model() -> BaseChatModel:
    """STRONG-tier model with default settings, built once per process."""
    return get_chat_model(tier=ModelTier.STRONG)
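# Usage sketch (assumptions: this module lives at app.core.llm, and an
# ANTHROPIC_API_KEY or OPENAI_API_KEY is configured in settings):
#
#     from app.core.llm import ModelTier, get_chat_model, get_strong_model
#
#     model = get_strong_model()  # cached, default temperature
#     draft = get_chat_model(tier=ModelTier.FAST, temperature=0.7)  # uncached
#     print(model.invoke("Say hello in one word.").content)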