"""Tests for app.services.llm: provider detection, model-name resolution, and chat-model construction."""
import pytest

from unittest.mock import patch, MagicMock

from app.services.llm import (
    LLMProvider,
    ModelTier,
    detect_provider,
    get_chat_model,
    _resolve_model_name,
    _DEFAULTS,
)
class TestDetectProvider:
    """Provider auto-detection: Anthropic wins when both keys exist, no keys raises."""

    @patch("app.services.llm.get_settings")
    def test_anthropic_preferred_when_both_keys_set(self, mock_get_settings):
        # With both keys configured, Anthropic takes precedence.
        mock_get_settings.return_value = MagicMock(
            anthropic_api_key="sk-ant-xxx", openai_api_key="sk-xxx"
        )
        assert detect_provider() == LLMProvider.ANTHROPIC

    @patch("app.services.llm.get_settings")
    def test_anthropic_when_only_anthropic_key(self, mock_get_settings):
        mock_get_settings.return_value = MagicMock(
            anthropic_api_key="sk-ant-xxx", openai_api_key=""
        )
        assert detect_provider() == LLMProvider.ANTHROPIC

    @patch("app.services.llm.get_settings")
    def test_openai_when_only_openai_key(self, mock_get_settings):
        mock_get_settings.return_value = MagicMock(
            anthropic_api_key="", openai_api_key="sk-xxx"
        )
        assert detect_provider() == LLMProvider.OPENAI

    @patch("app.services.llm.get_settings")
    def test_raises_when_no_keys(self, mock_get_settings):
        # No configured key at all is a hard error, not a silent default.
        mock_get_settings.return_value = MagicMock(
            anthropic_api_key="", openai_api_key=""
        )
        with pytest.raises(RuntimeError, match="No LLM API key configured"):
            detect_provider()
class TestResolveModelName:
    """Model-name resolution: explicit settings overrides beat per-provider defaults."""

    @patch("app.services.llm.get_settings")
    def test_defaults_anthropic_fast(self, mock_get_settings):
        # No overrides configured -> fall back to the provider/tier default table.
        mock_get_settings.return_value = MagicMock(llm_fast_model="", llm_strong_model="")
        resolved = _resolve_model_name(LLMProvider.ANTHROPIC, ModelTier.FAST)
        assert resolved == _DEFAULTS[LLMProvider.ANTHROPIC][ModelTier.FAST]

    @patch("app.services.llm.get_settings")
    def test_defaults_openai_strong(self, mock_get_settings):
        mock_get_settings.return_value = MagicMock(llm_fast_model="", llm_strong_model="")
        resolved = _resolve_model_name(LLMProvider.OPENAI, ModelTier.STRONG)
        assert resolved == _DEFAULTS[LLMProvider.OPENAI][ModelTier.STRONG]

    @patch("app.services.llm.get_settings")
    def test_override_fast_model(self, mock_get_settings):
        # A configured fast-model name wins over the default for the FAST tier.
        mock_get_settings.return_value = MagicMock(
            llm_fast_model="custom-fast", llm_strong_model=""
        )
        resolved = _resolve_model_name(LLMProvider.ANTHROPIC, ModelTier.FAST)
        assert resolved == "custom-fast"

    @patch("app.services.llm.get_settings")
    def test_override_strong_model(self, mock_get_settings):
        mock_get_settings.return_value = MagicMock(
            llm_fast_model="", llm_strong_model="custom-strong"
        )
        resolved = _resolve_model_name(LLMProvider.OPENAI, ModelTier.STRONG)
        assert resolved == "custom-strong"
class TestGetChatModel:
    """get_chat_model wires the detected (or explicit) provider into the right chat class."""

    @staticmethod
    def _fake_settings(anthropic_key, openai_key):
        # Settings stub with the given API keys and no model-name overrides.
        return MagicMock(
            anthropic_api_key=anthropic_key,
            openai_api_key=openai_key,
            llm_fast_model="",
            llm_strong_model="",
        )

    def test_returns_anthropic_model(self):
        with patch("app.services.llm.get_settings") as mock_get_settings:
            mock_get_settings.return_value = self._fake_settings("sk-ant-xxx", "")
            model = get_chat_model(tier=ModelTier.STRONG)
            assert type(model).__name__ == "ChatAnthropic"

    def test_returns_openai_model(self):
        with patch("app.services.llm.get_settings") as mock_get_settings:
            mock_get_settings.return_value = self._fake_settings("", "sk-xxx")
            model = get_chat_model(tier=ModelTier.FAST)
            assert type(model).__name__ == "ChatOpenAI"

    def test_explicit_provider_override(self):
        # Explicit provider argument beats auto-detection (which would pick Anthropic).
        with patch("app.services.llm.get_settings") as mock_get_settings:
            mock_get_settings.return_value = self._fake_settings("sk-ant-xxx", "sk-xxx")
            model = get_chat_model(tier=ModelTier.FAST, provider=LLMProvider.OPENAI)
            assert type(model).__name__ == "ChatOpenAI"

    def test_custom_temperature(self):
        with patch("app.services.llm.get_settings") as mock_get_settings:
            mock_get_settings.return_value = self._fake_settings("sk-ant-xxx", "")
            model = get_chat_model(tier=ModelTier.STRONG, temperature=0.7)
            assert model.temperature == 0.7