diff --git a/backend/app/api/strategy.py b/backend/app/api/strategy.py index 3f04bca..abb6f46 100644 --- a/backend/app/api/strategy.py +++ b/backend/app/api/strategy.py @@ -1,11 +1,14 @@ """ Quant strategy API endpoints. """ +from typing import Set + from fastapi import APIRouter, Depends from sqlalchemy.orm import Session from app.core.database import get_db from app.api.deps import CurrentUser +from app.models.stock import ETF from app.schemas.strategy import ( MultiFactorRequest, QualityRequest, ValueMomentumRequest, KJBRequest, StrategyResult, ) @@ -14,6 +17,27 @@ from app.services.strategy import MultiFactorStrategy, QualityStrategy, ValueMom router = APIRouter(prefix="/api/strategy", tags=["strategy"]) +def _filter_dc_only(result: StrategyResult, db: Session) -> StrategyResult: + """Filter strategy result to include only ETFs (DC pension investable).""" + tickers = [s.ticker for s in result.stocks] + etf_tickers: Set[str] = set( + row[0] for row in db.query(ETF.ticker).filter(ETF.ticker.in_(tickers)).all() + ) if tickers else set() + + filtered = [s for s in result.stocks if s.ticker in etf_tickers] + # Re-rank + for i, stock in enumerate(filtered, 1): + stock.rank = i + + return StrategyResult( + strategy_name=result.strategy_name, + base_date=result.base_date, + universe_count=result.universe_count, + result_count=len(filtered), + stocks=filtered, + ) + + @router.post("/multi-factor", response_model=StrategyResult) async def run_multi_factor( request: MultiFactorRequest, @@ -22,12 +46,13 @@ async def run_multi_factor( ): """Run multi-factor strategy.""" strategy = MultiFactorStrategy(db) - return strategy.run( + result = strategy.run( universe_filter=request.universe, top_n=request.top_n, base_date=request.base_date, weights=request.weights, ) + return _filter_dc_only(result, db) if request.dc_only else result @router.post("/quality", response_model=StrategyResult) @@ -38,12 +63,13 @@ async def run_quality( ): """Run super quality strategy.""" strategy = QualityStrategy(db) - return strategy.run( + result = strategy.run( universe_filter=request.universe, top_n=request.top_n, base_date=request.base_date, min_fscore=request.min_fscore, ) + return _filter_dc_only(result, db) if request.dc_only else result @router.post("/value-momentum", response_model=StrategyResult) @@ -54,13 +80,14 @@ async def run_value_momentum( ): """Run value-momentum strategy.""" strategy = ValueMomentumStrategy(db) - return strategy.run( + result = strategy.run( universe_filter=request.universe, top_n=request.top_n, base_date=request.base_date, value_weight=request.value_weight, momentum_weight=request.momentum_weight, ) + return _filter_dc_only(result, db) if request.dc_only else result @router.post("/kjb", response_model=StrategyResult) @@ -71,8 +98,9 @@ async def run_kjb( ): """Run KJB strategy.""" strategy = KJBStrategy(db) - return strategy.run( + result = strategy.run( universe_filter=request.universe, top_n=request.top_n, base_date=request.base_date, ) + return _filter_dc_only(result, db) if request.dc_only else result diff --git a/backend/app/schemas/strategy.py b/backend/app/schemas/strategy.py index d208227..0feb4b4 100644 --- a/backend/app/schemas/strategy.py +++ b/backend/app/schemas/strategy.py @@ -32,6 +32,7 @@ class StrategyRequest(BaseModel): universe: UniverseFilter = UniverseFilter() top_n: int = Field(default=30, ge=1, le=100) base_date: Optional[date] = None + dc_only: bool = False class MultiFactorRequest(StrategyRequest): diff --git a/backend/tests/e2e/test_strategy_flow.py b/backend/tests/e2e/test_strategy_flow.py index 6775c9f..4e3421c 100644 --- a/backend/tests/e2e/test_strategy_flow.py +++ b/backend/tests/e2e/test_strategy_flow.py @@ -79,6 +79,50 @@ def test_value_momentum_strategy(client: TestClient, auth_headers): assert data["strategy_name"] == "value_momentum" +def test_dc_only_filter(client: TestClient, auth_headers): + """Test dc_only parameter filters to ETFs only.""" + response = client.post( + "/api/strategy/multi-factor", + json={ + "universe": { + "markets": ["KOSPI"], + }, + "top_n": 20, + "dc_only": True, + "weights": { + "value": 0.3, + "quality": 0.3, + "momentum": 0.2, + "low_vol": 0.2, + }, + }, + headers=auth_headers, + ) + # May fail if no data, just check it accepts the parameter + assert response.status_code in [200, 400, 500] + + if response.status_code == 200: + data = response.json() + assert "stocks" in data + # All returned stocks should be ETFs (or empty if no ETFs in universe) + + +def test_dc_only_false_returns_all(client: TestClient, auth_headers): + """Test dc_only=false returns all stocks (default behavior).""" + response = client.post( + "/api/strategy/multi-factor", + json={ + "universe": { + "markets": ["KOSPI"], + }, + "top_n": 20, + "dc_only": False, + }, + headers=auth_headers, + ) + assert response.status_code in [200, 400, 500] + + def test_strategy_requires_auth(client: TestClient): """Test that strategy endpoints require authentication.""" response = client.post(