feat: add DC pension ETF-only filter to strategy API
Add dc_only parameter to all strategy endpoints. When true, filters results to include only tickers present in the ETF table, supporting DC pension investment constraints where only ETFs are allowed. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
62ac92eaaf
commit
4483f6e4ba
@ -1,11 +1,14 @@
|
||||
"""
|
||||
Quant strategy API endpoints.
|
||||
"""
|
||||
from typing import Set
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.deps import CurrentUser
|
||||
from app.models.stock import ETF
|
||||
from app.schemas.strategy import (
|
||||
MultiFactorRequest, QualityRequest, ValueMomentumRequest, KJBRequest, StrategyResult,
|
||||
)
|
||||
@ -14,6 +17,27 @@ from app.services.strategy import MultiFactorStrategy, QualityStrategy, ValueMom
|
||||
router = APIRouter(prefix="/api/strategy", tags=["strategy"])
|
||||
|
||||
|
||||
def _filter_dc_only(result: StrategyResult, db: Session) -> StrategyResult:
|
||||
"""Filter strategy result to include only ETFs (DC pension investable)."""
|
||||
tickers = [s.ticker for s in result.stocks]
|
||||
etf_tickers: Set[str] = set(
|
||||
row[0] for row in db.query(ETF.ticker).filter(ETF.ticker.in_(tickers)).all()
|
||||
) if tickers else set()
|
||||
|
||||
filtered = [s for s in result.stocks if s.ticker in etf_tickers]
|
||||
# Re-rank
|
||||
for i, stock in enumerate(filtered, 1):
|
||||
stock.rank = i
|
||||
|
||||
return StrategyResult(
|
||||
strategy_name=result.strategy_name,
|
||||
base_date=result.base_date,
|
||||
universe_count=result.universe_count,
|
||||
result_count=len(filtered),
|
||||
stocks=filtered,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/multi-factor", response_model=StrategyResult)
|
||||
async def run_multi_factor(
|
||||
request: MultiFactorRequest,
|
||||
@ -22,12 +46,13 @@ async def run_multi_factor(
|
||||
):
|
||||
"""Run multi-factor strategy."""
|
||||
strategy = MultiFactorStrategy(db)
|
||||
return strategy.run(
|
||||
result = strategy.run(
|
||||
universe_filter=request.universe,
|
||||
top_n=request.top_n,
|
||||
base_date=request.base_date,
|
||||
weights=request.weights,
|
||||
)
|
||||
return _filter_dc_only(result, db) if request.dc_only else result
|
||||
|
||||
|
||||
@router.post("/quality", response_model=StrategyResult)
|
||||
@ -38,12 +63,13 @@ async def run_quality(
|
||||
):
|
||||
"""Run super quality strategy."""
|
||||
strategy = QualityStrategy(db)
|
||||
return strategy.run(
|
||||
result = strategy.run(
|
||||
universe_filter=request.universe,
|
||||
top_n=request.top_n,
|
||||
base_date=request.base_date,
|
||||
min_fscore=request.min_fscore,
|
||||
)
|
||||
return _filter_dc_only(result, db) if request.dc_only else result
|
||||
|
||||
|
||||
@router.post("/value-momentum", response_model=StrategyResult)
|
||||
@ -54,13 +80,14 @@ async def run_value_momentum(
|
||||
):
|
||||
"""Run value-momentum strategy."""
|
||||
strategy = ValueMomentumStrategy(db)
|
||||
return strategy.run(
|
||||
result = strategy.run(
|
||||
universe_filter=request.universe,
|
||||
top_n=request.top_n,
|
||||
base_date=request.base_date,
|
||||
value_weight=request.value_weight,
|
||||
momentum_weight=request.momentum_weight,
|
||||
)
|
||||
return _filter_dc_only(result, db) if request.dc_only else result
|
||||
|
||||
|
||||
@router.post("/kjb", response_model=StrategyResult)
|
||||
@ -71,8 +98,9 @@ async def run_kjb(
|
||||
):
|
||||
"""Run KJB strategy."""
|
||||
strategy = KJBStrategy(db)
|
||||
return strategy.run(
|
||||
result = strategy.run(
|
||||
universe_filter=request.universe,
|
||||
top_n=request.top_n,
|
||||
base_date=request.base_date,
|
||||
)
|
||||
return _filter_dc_only(result, db) if request.dc_only else result
|
||||
|
||||
@ -32,6 +32,7 @@ class StrategyRequest(BaseModel):
|
||||
universe: UniverseFilter = UniverseFilter()
|
||||
top_n: int = Field(default=30, ge=1, le=100)
|
||||
base_date: Optional[date] = None
|
||||
dc_only: bool = False
|
||||
|
||||
|
||||
class MultiFactorRequest(StrategyRequest):
|
||||
|
||||
@ -79,6 +79,50 @@ def test_value_momentum_strategy(client: TestClient, auth_headers):
|
||||
assert data["strategy_name"] == "value_momentum"
|
||||
|
||||
|
||||
def test_dc_only_filter(client: TestClient, auth_headers):
|
||||
"""Test dc_only parameter filters to ETFs only."""
|
||||
response = client.post(
|
||||
"/api/strategy/multi-factor",
|
||||
json={
|
||||
"universe": {
|
||||
"markets": ["KOSPI"],
|
||||
},
|
||||
"top_n": 20,
|
||||
"dc_only": True,
|
||||
"weights": {
|
||||
"value": 0.3,
|
||||
"quality": 0.3,
|
||||
"momentum": 0.2,
|
||||
"low_vol": 0.2,
|
||||
},
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
# May fail if no data, just check it accepts the parameter
|
||||
assert response.status_code in [200, 400, 500]
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
assert "stocks" in data
|
||||
# All returned stocks should be ETFs (or empty if no ETFs in universe)
|
||||
|
||||
|
||||
def test_dc_only_false_returns_all(client: TestClient, auth_headers):
|
||||
"""Test dc_only=false returns all stocks (default behavior)."""
|
||||
response = client.post(
|
||||
"/api/strategy/multi-factor",
|
||||
json={
|
||||
"universe": {
|
||||
"markets": ["KOSPI"],
|
||||
},
|
||||
"top_n": 20,
|
||||
"dc_only": False,
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code in [200, 400, 500]
|
||||
|
||||
|
||||
def test_strategy_requires_auth(client: TestClient):
|
||||
"""Test that strategy endpoints require authentication."""
|
||||
response = client.post(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user