feat: improve security, performance, and add missing features

- Remove hardcoded database_url/jwt_secret defaults, require env vars
- Add DB indexes for stocks.market, market_cap, backtests.user_id
- Optimize backtest engine: preload all prices, move stock_names out of loop
- Fix backtest API auth: filter by user_id at query level (6 endpoints)
- Add manual transaction entry modal on portfolio detail page
- Replace console.error with toast.error in signals, backtest, data explorer
- Add backtest delete button with confirmation dialog
- Replace simulated sine chart with real snapshot data
- Add strategy-to-portfolio apply flow with dialog
- Add DC pension risk asset ratio >70% warning on rebalance page
- Add backtest comparison page with metrics table and overlay chart
This commit is contained in:
머니페니 2026-03-20 12:27:05 +09:00
parent 49bd0d8745
commit f6db08c9bd
15 changed files with 823 additions and 212 deletions

View File

@ -2,9 +2,7 @@
# Copy this file to .env and fill in the values # Copy this file to .env and fill in the values
# Database # Database
DB_USER=galaxy DATABASE_URL=postgresql://galaxy:your_secure_password_here@localhost:5432/galaxy_po
DB_PASSWORD=your_secure_password_here
DB_NAME=galaxy_po
# JWT Authentication # JWT Authentication
JWT_SECRET=your_jwt_secret_key_here_at_least_32_characters JWT_SECRET=your_jwt_secret_key_here_at_least_32_characters

View File

@ -0,0 +1,42 @@
"""add missing performance indexes
Revision ID: c3d4e5f6a7b8
Revises: b7c8d9e0f1a2
Create Date: 2026-03-19 10:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "c3d4e5f6a7b8"
down_revision: Union[str, None] = "59807c4e84ee"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Stock universe filtering (strategy engine uses market + market_cap frequently)
op.create_index("idx_stocks_market", "stocks", ["market"])
op.create_index(
"idx_stocks_market_cap", "stocks", [sa.text("market_cap DESC NULLS LAST")]
)
# Backtest listing by user (always filtered by user_id + ordered by created_at)
op.create_index(
"idx_backtests_user_created",
"backtests",
["user_id", sa.text("created_at DESC")],
)
op.create_index("idx_backtests_status", "backtests", ["status"])
def downgrade() -> None:
op.drop_index("idx_backtests_status", table_name="backtests")
op.drop_index("idx_backtests_user_created", table_name="backtests")
op.drop_index("idx_stocks_market_cap", table_name="stocks")
op.drop_index("idx_stocks_market", table_name="stocks")

View File

@ -1,6 +1,7 @@
""" """
Backtest API endpoints. Backtest API endpoints.
""" """
from typing import List from typing import List
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
@ -9,14 +10,26 @@ from sqlalchemy.orm import Session, joinedload
from app.core.database import get_db from app.core.database import get_db
from app.api.deps import CurrentUser from app.api.deps import CurrentUser
from app.models.backtest import ( from app.models.backtest import (
Backtest, BacktestResult, BacktestEquityCurve, Backtest,
BacktestHolding, BacktestTransaction, BacktestStatus, BacktestResult,
BacktestEquityCurve,
BacktestHolding,
BacktestTransaction,
BacktestStatus,
WalkForwardResult, WalkForwardResult,
) )
from app.schemas.backtest import ( from app.schemas.backtest import (
BacktestCreate, BacktestResponse, BacktestListItem, BacktestMetrics, BacktestCreate,
EquityCurvePoint, RebalanceHoldings, HoldingItem, TransactionItem, BacktestResponse,
WalkForwardRequest, WalkForwardWindowResult, WalkForwardResponse, BacktestListItem,
BacktestMetrics,
EquityCurvePoint,
RebalanceHoldings,
HoldingItem,
TransactionItem,
WalkForwardRequest,
WalkForwardWindowResult,
WalkForwardResponse,
) )
from app.services.backtest import submit_backtest from app.services.backtest import submit_backtest
from app.services.backtest.walkforward_engine import WalkForwardEngine from app.services.backtest.walkforward_engine import WalkForwardEngine
@ -97,14 +110,15 @@ async def get_backtest(
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Get backtest details and results.""" """Get backtest details and results."""
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first() backtest = (
db.query(Backtest)
.filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
.first()
)
if not backtest: if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found") raise HTTPException(status_code=404, detail="Backtest not found")
if backtest.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
result_metrics = None result_metrics = None
if backtest.result: if backtest.result:
result_metrics = BacktestMetrics( result_metrics = BacktestMetrics(
@ -144,14 +158,15 @@ async def get_equity_curve(
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Get equity curve data for chart.""" """Get equity curve data for chart."""
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first() backtest = (
db.query(Backtest)
.filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
.first()
)
if not backtest: if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found") raise HTTPException(status_code=404, detail="Backtest not found")
if backtest.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
curve = ( curve = (
db.query(BacktestEquityCurve) db.query(BacktestEquityCurve)
.filter(BacktestEquityCurve.backtest_id == backtest_id) .filter(BacktestEquityCurve.backtest_id == backtest_id)
@ -177,14 +192,15 @@ async def get_holdings(
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Get holdings at each rebalance date.""" """Get holdings at each rebalance date."""
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first() backtest = (
db.query(Backtest)
.filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
.first()
)
if not backtest: if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found") raise HTTPException(status_code=404, detail="Backtest not found")
if backtest.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
holdings = ( holdings = (
db.query(BacktestHolding) db.query(BacktestHolding)
.filter(BacktestHolding.backtest_id == backtest_id) .filter(BacktestHolding.backtest_id == backtest_id)
@ -197,13 +213,15 @@ async def get_holdings(
for h in holdings: for h in holdings:
if h.rebalance_date not in grouped: if h.rebalance_date not in grouped:
grouped[h.rebalance_date] = [] grouped[h.rebalance_date] = []
grouped[h.rebalance_date].append(HoldingItem( grouped[h.rebalance_date].append(
ticker=h.ticker, HoldingItem(
name=h.name, ticker=h.ticker,
weight=h.weight, name=h.name,
shares=h.shares, weight=h.weight,
price=h.price, shares=h.shares,
)) price=h.price,
)
)
return [ return [
RebalanceHoldings(rebalance_date=date, holdings=items) RebalanceHoldings(rebalance_date=date, holdings=items)
@ -218,14 +236,15 @@ async def get_transactions(
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Get all transactions.""" """Get all transactions."""
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first() backtest = (
db.query(Backtest)
.filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
.first()
)
if not backtest: if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found") raise HTTPException(status_code=404, detail="Backtest not found")
if backtest.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
transactions = ( transactions = (
db.query(BacktestTransaction) db.query(BacktestTransaction)
.filter(BacktestTransaction.backtest_id == backtest_id) .filter(BacktestTransaction.backtest_id == backtest_id)
@ -261,14 +280,15 @@ async def run_walkforward(
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Run walk-forward analysis on a completed backtest.""" """Run walk-forward analysis on a completed backtest."""
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first() backtest = (
db.query(Backtest)
.filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
.first()
)
if not backtest: if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found") raise HTTPException(status_code=404, detail="Backtest not found")
if backtest.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
if backtest.status != BacktestStatus.COMPLETED: if backtest.status != BacktestStatus.COMPLETED:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
@ -296,14 +316,15 @@ async def get_walkforward(
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Get walk-forward analysis results.""" """Get walk-forward analysis results."""
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first() backtest = (
db.query(Backtest)
.filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
.first()
)
if not backtest: if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found") raise HTTPException(status_code=404, detail="Backtest not found")
if backtest.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
results = ( results = (
db.query(WalkForwardResult) db.query(WalkForwardResult)
.filter(WalkForwardResult.backtest_id == backtest_id) .filter(WalkForwardResult.backtest_id == backtest_id)
@ -336,14 +357,15 @@ async def delete_backtest(
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Delete a backtest and all its data.""" """Delete a backtest and all its data."""
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first() backtest = (
db.query(Backtest)
.filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
.first()
)
if not backtest: if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found") raise HTTPException(status_code=404, detail="Backtest not found")
if backtest.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
# Delete related data # Delete related data
db.query(WalkForwardResult).filter( db.query(WalkForwardResult).filter(
WalkForwardResult.backtest_id == backtest_id WalkForwardResult.backtest_id == backtest_id
@ -357,9 +379,7 @@ async def delete_backtest(
db.query(BacktestEquityCurve).filter( db.query(BacktestEquityCurve).filter(
BacktestEquityCurve.backtest_id == backtest_id BacktestEquityCurve.backtest_id == backtest_id
).delete() ).delete()
db.query(BacktestResult).filter( db.query(BacktestResult).filter(BacktestResult.backtest_id == backtest_id).delete()
BacktestResult.backtest_id == backtest_id
).delete()
db.delete(backtest) db.delete(backtest)
db.commit() db.commit()

View File

@ -1,6 +1,7 @@
""" """
Application configuration using Pydantic Settings. Application configuration using Pydantic Settings.
""" """
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from functools import lru_cache from functools import lru_cache
@ -11,10 +12,10 @@ class Settings(BaseSettings):
debug: bool = False debug: bool = False
# Database # Database
database_url: str = "postgresql://galaxy:devpassword@localhost:5432/galaxy_po" database_url: str
# JWT # JWT
jwt_secret: str = "dev-jwt-secret-change-in-production" jwt_secret: str
jwt_algorithm: str = "HS256" jwt_algorithm: str = "HS256"
access_token_expire_minutes: int = 60 * 24 # 24 hours access_token_expire_minutes: int = 60 * 24 # 24 hours

View File

@ -1,18 +1,18 @@
""" """
Database connection and session management. Database connection and session management.
""" """
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base from sqlalchemy.orm import sessionmaker, declarative_base
from app.core.config import get_settings from app.core.config import get_settings
settings = get_settings() settings = get_settings()
engine = create_engine( _engine_kwargs = {"pool_pre_ping": True}
settings.database_url, if settings.database_url.startswith("postgresql"):
pool_pre_ping=True, _engine_kwargs.update(pool_size=10, max_overflow=20)
pool_size=10,
max_overflow=20, engine = create_engine(settings.database_url, **_engine_kwargs)
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

View File

@ -1,6 +1,7 @@
""" """
Main backtest engine. Main backtest engine.
""" """
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import date, timedelta from datetime import date, timedelta
@ -12,13 +13,21 @@ from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.models.backtest import ( from app.models.backtest import (
Backtest, BacktestResult, BacktestEquityCurve, Backtest,
BacktestHolding, BacktestTransaction, RebalancePeriod, BacktestResult,
BacktestEquityCurve,
BacktestHolding,
BacktestTransaction,
RebalancePeriod,
) )
from app.models.stock import Stock, Price from app.models.stock import Stock, Price
from app.services.backtest.portfolio import VirtualPortfolio, Transaction from app.services.backtest.portfolio import VirtualPortfolio, Transaction
from app.services.backtest.metrics import MetricsCalculator from app.services.backtest.metrics import MetricsCalculator
from app.services.strategy import MultiFactorStrategy, QualityStrategy, ValueMomentumStrategy from app.services.strategy import (
MultiFactorStrategy,
QualityStrategy,
ValueMomentumStrategy,
)
from app.schemas.strategy import UniverseFilter, FactorWeights from app.schemas.strategy import UniverseFilter, FactorWeights
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,6 +36,7 @@ logger = logging.getLogger(__name__)
@dataclass @dataclass
class DataValidationResult: class DataValidationResult:
"""Result of pre-backtest data validation.""" """Result of pre-backtest data validation."""
is_valid: bool = True is_valid: bool = True
errors: List[str] = field(default_factory=list) errors: List[str] = field(default_factory=list)
warnings: List[str] = field(default_factory=list) warnings: List[str] = field(default_factory=list)
@ -85,9 +95,7 @@ class BacktestEngine:
logger.warning(f"Backtest {backtest_id}: {warning}") logger.warning(f"Backtest {backtest_id}: {warning}")
if not validation.is_valid: if not validation.is_valid:
raise ValueError( raise ValueError("데이터 검증 실패:\n" + "\n".join(validation.errors))
"데이터 검증 실패:\n" + "\n".join(validation.errors)
)
# Create strategy instance # Create strategy instance
strategy = self._create_strategy( strategy = self._create_strategy(
@ -105,20 +113,24 @@ class BacktestEngine:
if initial_benchmark == 0: if initial_benchmark == 0:
initial_benchmark = Decimal("1") initial_benchmark = Decimal("1")
names = self._get_stock_names()
all_date_prices = self._load_all_prices_by_date(
backtest.start_date,
backtest.end_date,
)
for trading_date in trading_days: for trading_date in trading_days:
# Get prices for this date prices = all_date_prices.get(trading_date, {})
prices = self._get_prices_for_date(trading_date)
names = self._get_stock_names()
# Warn about holdings with missing prices # Warn about holdings with missing prices
missing = [ missing = [
t for t in portfolio.holdings t
for t in portfolio.holdings
if portfolio.holdings[t] > 0 and t not in prices if portfolio.holdings[t] > 0 and t not in prices
] ]
if missing: if missing:
logger.warning( logger.warning(
f"{trading_date}: 보유 종목 가격 누락 {missing} " f"{trading_date}: 보유 종목 가격 누락 {missing} (0원으로 처리됨)"
f"(0원으로 처리됨)"
) )
# Rebalance if needed # Rebalance if needed
@ -141,16 +153,16 @@ class BacktestEngine:
slippage_rate=backtest.slippage_rate, slippage_rate=backtest.slippage_rate,
) )
all_transactions.extend([ all_transactions.extend([(trading_date, txn) for txn in transactions])
(trading_date, txn) for txn in transactions
])
# Record holdings # Record holdings
holdings = portfolio.get_holdings_with_weights(prices, names) holdings = portfolio.get_holdings_with_weights(prices, names)
holdings_history.append({ holdings_history.append(
'date': trading_date, {
'holdings': holdings, "date": trading_date,
}) "holdings": holdings,
}
)
# Record daily value # Record daily value
portfolio_value = portfolio.get_value(prices) portfolio_value = portfolio.get_value(prices)
@ -161,15 +173,21 @@ class BacktestEngine:
benchmark_value / initial_benchmark * backtest.initial_capital benchmark_value / initial_benchmark * backtest.initial_capital
) )
equity_curve_data.append({ equity_curve_data.append(
'date': trading_date, {
'portfolio_value': portfolio_value, "date": trading_date,
'benchmark_value': normalized_benchmark, "portfolio_value": portfolio_value,
}) "benchmark_value": normalized_benchmark,
}
)
# Calculate metrics # Calculate metrics
portfolio_values = [Decimal(str(e['portfolio_value'])) for e in equity_curve_data] portfolio_values = [
benchmark_values = [Decimal(str(e['benchmark_value'])) for e in equity_curve_data] Decimal(str(e["portfolio_value"])) for e in equity_curve_data
]
benchmark_values = [
Decimal(str(e["benchmark_value"])) for e in equity_curve_data
]
metrics = MetricsCalculator.calculate_all(portfolio_values, benchmark_values) metrics = MetricsCalculator.calculate_all(portfolio_values, benchmark_values)
drawdowns = MetricsCalculator.calculate_drawdown_series(portfolio_values) drawdowns = MetricsCalculator.calculate_drawdown_series(portfolio_values)
@ -221,18 +239,13 @@ class BacktestEngine:
# 2. Benchmark data coverage # 2. Benchmark data coverage
benchmark_ticker = "069500" if benchmark == "KOSPI" else "069500" benchmark_ticker = "069500" if benchmark == "KOSPI" else "069500"
benchmark_coverage = sum( benchmark_coverage = sum(1 for d in total_days if d in benchmark_prices)
1 for d in total_days if d in benchmark_prices
)
benchmark_pct = ( benchmark_pct = (
benchmark_coverage / num_trading_days * 100 benchmark_coverage / num_trading_days * 100 if num_trading_days > 0 else 0
if num_trading_days > 0 else 0
) )
if benchmark_coverage == 0: if benchmark_coverage == 0:
result.errors.append( result.errors.append(f"벤치마크({benchmark_ticker}) 가격 데이터 없음")
f"벤치마크({benchmark_ticker}) 가격 데이터 없음"
)
result.is_valid = False result.is_valid = False
elif benchmark_pct < 90: elif benchmark_pct < 90:
result.warnings.append( result.warnings.append(
@ -254,22 +267,17 @@ class BacktestEngine:
.scalar() .scalar()
) )
if ticker_count == 0: if ticker_count == 0:
result.errors.append( result.errors.append(f"{sample_date} 가격 데이터 없음 (종목 0개)")
f"{sample_date} 가격 데이터 없음 (종목 0개)"
)
result.is_valid = False result.is_valid = False
elif ticker_count < 100: elif ticker_count < 100:
result.warnings.append( result.warnings.append(f"{sample_date} 종목 수 적음: {ticker_count}")
f"{sample_date} 종목 수 적음: {ticker_count}"
)
# 4. Large gaps in trading days (> 7 calendar days excluding normal weekends) # 4. Large gaps in trading days (> 7 calendar days excluding normal weekends)
for i in range(1, num_trading_days): for i in range(1, num_trading_days):
gap = (total_days[i] - total_days[i - 1]).days gap = (total_days[i] - total_days[i - 1]).days
if gap > 7: if gap > 7:
result.warnings.append( result.warnings.append(
f"거래일 갭 발견: {total_days[i-1]} ~ {total_days[i]} " f"거래일 갭 발견: {total_days[i - 1]} ~ {total_days[i]} ({gap}일)"
f"({gap}일)"
) )
if result.is_valid and not result.warnings: if result.is_valid and not result.warnings:
@ -338,18 +346,25 @@ class BacktestEngine:
return {p.date: p.close for p in prices} return {p.date: p.close for p in prices}
def _get_prices_for_date(self, trading_date: date) -> Dict[str, Decimal]: def _load_all_prices_by_date(
"""Get all stock prices for a specific date.""" self,
start_date: date,
end_date: date,
) -> Dict[date, Dict[str, Decimal]]:
prices = ( prices = (
self.db.query(Price) self.db.query(Price)
.filter(Price.date == trading_date) .filter(Price.date >= start_date, Price.date <= end_date)
.all() .all()
) )
return {p.ticker: p.close for p in prices} result: Dict[date, Dict[str, Decimal]] = {}
for p in prices:
if p.date not in result:
result[p.date] = {}
result[p.date][p.ticker] = p.close
return result
def _get_stock_names(self) -> Dict[str, str]: def _get_stock_names(self) -> Dict[str, str]:
"""Get all stock names.""" stocks = self.db.query(Stock.ticker, Stock.name).all()
stocks = self.db.query(Stock).all()
return {s.ticker: s.name for s in stocks} return {s.ticker: s.name for s in stocks}
def _create_strategy( def _create_strategy(
@ -367,8 +382,12 @@ class BacktestEngine:
strategy._min_fscore = strategy_params.get("min_fscore", 7) strategy._min_fscore = strategy_params.get("min_fscore", 7)
elif strategy_type == "value_momentum": elif strategy_type == "value_momentum":
strategy = ValueMomentumStrategy(self.db) strategy = ValueMomentumStrategy(self.db)
strategy._value_weight = Decimal(str(strategy_params.get("value_weight", 0.5))) strategy._value_weight = Decimal(
strategy._momentum_weight = Decimal(str(strategy_params.get("momentum_weight", 0.5))) str(strategy_params.get("value_weight", 0.5))
)
strategy._momentum_weight = Decimal(
str(strategy_params.get("momentum_weight", 0.5))
)
else: else:
raise ValueError(f"Unknown strategy type: {strategy_type}") raise ValueError(f"Unknown strategy type: {strategy_type}")
@ -401,19 +420,19 @@ class BacktestEngine:
for i, point in enumerate(equity_curve_data): for i, point in enumerate(equity_curve_data):
curve_point = BacktestEquityCurve( curve_point = BacktestEquityCurve(
backtest_id=backtest_id, backtest_id=backtest_id,
date=point['date'], date=point["date"],
portfolio_value=point['portfolio_value'], portfolio_value=point["portfolio_value"],
benchmark_value=point['benchmark_value'], benchmark_value=point["benchmark_value"],
drawdown=drawdowns[i] if i < len(drawdowns) else Decimal("0"), drawdown=drawdowns[i] if i < len(drawdowns) else Decimal("0"),
) )
self.db.add(curve_point) self.db.add(curve_point)
# Save holdings history # Save holdings history
for record in holdings_history: for record in holdings_history:
for holding in record['holdings']: for holding in record["holdings"]:
h = BacktestHolding( h = BacktestHolding(
backtest_id=backtest_id, backtest_id=backtest_id,
rebalance_date=record['date'], rebalance_date=record["date"],
ticker=holding.ticker, ticker=holding.ticker,
name=holding.name, name=holding.name,
weight=holding.weight, weight=holding.weight,

View File

@ -1,10 +1,14 @@
""" """
Pytest configuration and fixtures for E2E tests. Pytest configuration and fixtures for E2E tests.
""" """
import os import os
import pytest import pytest
from typing import Generator from typing import Generator
os.environ.setdefault("DATABASE_URL", "sqlite:///:memory:")
os.environ.setdefault("JWT_SECRET", "test-secret-key-for-pytest-only")
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.orm import sessionmaker, Session

View File

@ -97,8 +97,7 @@ export default function DataExplorerPage() {
const endpoint = `/api/data/${tab}?${params}`; const endpoint = `/api/data/${tab}?${params}`;
const result = await api.get<PaginatedResponse<unknown>>(endpoint); const result = await api.get<PaginatedResponse<unknown>>(endpoint);
setData(result); setData(result);
} catch (err) { } catch {
console.error('Failed to fetch data:', err);
toast.error('데이터를 불러오는데 실패했습니다.'); toast.error('데이터를 불러오는데 실패했습니다.');
setData(null); setData(null);
} finally { } finally {
@ -131,8 +130,7 @@ export default function DataExplorerPage() {
: `/api/data/etfs/${ticker}/prices`; : `/api/data/etfs/${ticker}/prices`;
const result = await api.get<PricePoint[]>(endpoint); const result = await api.get<PricePoint[]>(endpoint);
setPrices(result); setPrices(result);
} catch (err) { } catch {
console.error('Failed to fetch prices:', err);
toast.error('가격 데이터를 불러오는데 실패했습니다.'); toast.error('가격 데이터를 불러오는데 실패했습니다.');
setPrices([]); setPrices([]);
} finally { } finally {

View File

@ -0,0 +1,419 @@
'use client';
import { useEffect, useState } from 'react';
import { useRouter } from 'next/navigation';
import Link from 'next/link';
import {
Area,
AreaChart as RechartsAreaChart,
CartesianGrid,
Legend,
ResponsiveContainer,
Tooltip,
XAxis,
YAxis,
} from 'recharts';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Button } from '@/components/ui/button';
import { Skeleton } from '@/components/ui/skeleton';
import { api } from '@/lib/api';
import { toast } from 'sonner';
interface BacktestListItem {
id: number;
strategy_type: string;
start_date: string;
end_date: string;
rebalance_period: string;
status: string;
created_at: string;
total_return: number | null;
cagr: number | null;
mdd: number | null;
}
interface BacktestDetail {
id: number;
strategy_type: string;
start_date: string;
end_date: string;
status: string;
result: {
total_return: number;
cagr: number;
mdd: number;
sharpe_ratio: number;
volatility: number;
benchmark_return: number;
excess_return: number;
} | null;
}
interface EquityCurvePoint {
date: string;
portfolio_value: number;
benchmark_value: number;
drawdown: number;
}
interface CompareData {
detail: BacktestDetail;
equityCurve: EquityCurvePoint[];
}
const STRATEGY_LABELS: Record<string, string> = {
multi_factor: '멀티 팩터',
quality: '슈퍼 퀄리티',
value_momentum: '밸류 모멘텀',
kjb: '김종봉 단기매매',
};
const COMPARE_COLORS = ['#3b82f6', '#ef4444', '#22c55e'];
const METRICS = [
{ key: 'total_return', label: '총 수익률', suffix: '%' },
{ key: 'cagr', label: 'CAGR', suffix: '%' },
{ key: 'mdd', label: 'MDD', suffix: '%' },
{ key: 'sharpe_ratio', label: '샤프 비율', suffix: '' },
{ key: 'volatility', label: '변동성', suffix: '%' },
{ key: 'benchmark_return', label: '벤치마크 수익률', suffix: '%' },
{ key: 'excess_return', label: '초과 수익률', suffix: '%' },
] as const;
export default function BacktestComparePage() {
const router = useRouter();
const [loading, setLoading] = useState(true);
const [backtests, setBacktests] = useState<BacktestListItem[]>([]);
const [selectedIds, setSelectedIds] = useState<Set<number>>(new Set());
const [compareData, setCompareData] = useState<CompareData[]>([]);
const [comparing, setComparing] = useState(false);
useEffect(() => {
const init = async () => {
try {
await api.getCurrentUser();
const data = await api.get<BacktestListItem[]>('/api/backtest');
setBacktests(data.filter((bt) => bt.status === 'completed'));
} catch {
router.push('/login');
} finally {
setLoading(false);
}
};
init();
}, [router]);
const toggleSelection = (id: number) => {
setSelectedIds((prev) => {
const next = new Set(prev);
if (next.has(id)) {
next.delete(id);
} else if (next.size < 3) {
next.add(id);
} else {
toast.error('최대 3개까지 선택할 수 있습니다.');
}
return next;
});
};
const handleCompare = async () => {
if (selectedIds.size < 2) {
toast.error('비교할 백테스트를 2개 이상 선택하세요.');
return;
}
setComparing(true);
try {
const ids = Array.from(selectedIds);
const results = await Promise.all(
ids.map(async (id) => {
const [detail, equityCurve] = await Promise.all([
api.get<BacktestDetail>(`/api/backtest/${id}`),
api.get<EquityCurvePoint[]>(`/api/backtest/${id}/equity-curve`),
]);
return { detail, equityCurve };
})
);
setCompareData(results);
} catch (err) {
toast.error(err instanceof Error ? err.message : '비교 데이터를 불러오는데 실패했습니다.');
} finally {
setComparing(false);
}
};
const getStrategyLabel = (type: string) => STRATEGY_LABELS[type] || type;
const formatNumber = (value: number | null | undefined, decimals: number = 2) => {
if (value === null || value === undefined) return '-';
return value.toFixed(decimals);
};
const formatCurrency = (value: number) => {
return new Intl.NumberFormat('ko-KR', {
style: 'currency',
currency: 'KRW',
maximumFractionDigits: 0,
}).format(value);
};
const buildChartData = () => {
if (compareData.length === 0) return [];
const dateMap = new Map<string, Record<string, number>>();
compareData.forEach((cd, idx) => {
for (const point of cd.equityCurve) {
const existing = dateMap.get(point.date) || {};
existing[`value_${idx}`] = point.portfolio_value;
dateMap.set(point.date, existing);
}
});
return Array.from(dateMap.entries())
.sort(([a], [b]) => a.localeCompare(b))
.map(([date, values]) => ({ date, ...values }));
};
const getCompareLabel = (idx: number) => {
const cd = compareData[idx];
return `${getStrategyLabel(cd.detail.strategy_type)} (${cd.detail.start_date.slice(0, 4)}~${cd.detail.end_date.slice(0, 4)})`;
};
if (loading) {
return (
<DashboardLayout>
<Skeleton className="h-8 w-48 mb-6" />
<Skeleton className="h-96 rounded-xl" />
</DashboardLayout>
);
}
const chartData = buildChartData();
return (
<DashboardLayout>
<div className="flex items-center justify-between mb-6">
<div>
<h1 className="text-2xl font-bold text-foreground"> </h1>
<p className="mt-1 text-muted-foreground">
( 3)
</p>
</div>
<Button variant="outline" asChild>
<Link href="/backtest"> </Link>
</Button>
</div>
<Card className="mb-6">
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
{backtests.length === 0 ? (
<p className="text-muted-foreground text-center py-8">
.
</p>
) : (
<>
<div className="space-y-2 max-h-64 overflow-y-auto mb-4">
{backtests.map((bt) => (
<label
key={bt.id}
className={`flex items-center gap-3 p-3 border rounded cursor-pointer transition-colors ${
selectedIds.has(bt.id)
? 'border-primary bg-primary/5'
: 'border-border hover:bg-muted/50'
}`}
>
<input
type="checkbox"
checked={selectedIds.has(bt.id)}
onChange={() => toggleSelection(bt.id)}
className="h-4 w-4 rounded border-input"
/>
<div className="flex-1 min-w-0">
<span className="font-medium text-sm">
{getStrategyLabel(bt.strategy_type)}
</span>
<span className="text-xs text-muted-foreground ml-2">
{bt.start_date} ~ {bt.end_date}
</span>
</div>
<div className="flex gap-4 text-xs text-muted-foreground">
<span>: <span className={(bt.total_return ?? 0) >= 0 ? 'text-green-600' : 'text-red-600'}>{formatNumber(bt.total_return)}%</span></span>
<span>CAGR: <span className={(bt.cagr ?? 0) >= 0 ? 'text-green-600' : 'text-red-600'}>{formatNumber(bt.cagr)}%</span></span>
<span>MDD: <span className="text-red-600">{formatNumber(bt.mdd)}%</span></span>
</div>
</label>
))}
</div>
<Button
onClick={handleCompare}
disabled={selectedIds.size < 2 || comparing}
>
{comparing ? '비교 중...' : `비교하기 (${selectedIds.size}개 선택)`}
</Button>
</>
)}
</CardContent>
</Card>
{compareData.length >= 2 && (
<>
<Card className="mb-6">
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent className="p-0">
<div className="overflow-x-auto">
<table className="w-full">
<thead className="bg-muted">
<tr>
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground">
</th>
{compareData.map((cd, idx) => (
<th
key={cd.detail.id}
scope="col"
className="px-4 py-3 text-right text-sm font-medium"
style={{ color: COMPARE_COLORS[idx] }}
>
{getCompareLabel(idx)}
</th>
))}
</tr>
</thead>
<tbody className="divide-y divide-border">
{METRICS.map((metric) => (
<tr key={metric.key}>
<td className="px-4 py-3 text-sm font-medium text-muted-foreground">
{metric.label}
</td>
{compareData.map((cd) => {
const value = cd.detail.result
? cd.detail.result[metric.key as keyof typeof cd.detail.result]
: null;
const numValue = typeof value === 'number' ? value : null;
const isNegativeMetric = metric.key === 'mdd';
const colorClass = numValue !== null
? isNegativeMetric
? 'text-red-600'
: numValue >= 0
? 'text-green-600'
: 'text-red-600'
: '';
return (
<td
key={cd.detail.id}
className={`px-4 py-3 text-sm text-right font-medium ${colorClass}`}
>
{formatNumber(numValue)}{numValue !== null ? metric.suffix : ''}
</td>
);
})}
</tr>
))}
</tbody>
</table>
</div>
</CardContent>
</Card>
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
{chartData.length > 0 ? (
<div style={{ height: 400 }}>
<ResponsiveContainer width="100%" height="100%">
<RechartsAreaChart
data={chartData}
margin={{ top: 10, right: 30, left: 0, bottom: 0 }}
>
<defs>
{compareData.map((_, idx) => (
<linearGradient
key={idx}
id={`compareGradient_${idx}`}
x1="0"
y1="0"
x2="0"
y2="1"
>
<stop offset="5%" stopColor={COMPARE_COLORS[idx]} stopOpacity={0.15} />
<stop offset="95%" stopColor={COMPARE_COLORS[idx]} stopOpacity={0} />
</linearGradient>
))}
</defs>
<CartesianGrid
strokeDasharray="3 3"
stroke="hsl(var(--border))"
vertical={false}
/>
<XAxis
dataKey="date"
tickFormatter={(v: string) => v.slice(5)}
stroke="hsl(var(--muted-foreground))"
fontSize={12}
tickLine={false}
axisLine={false}
/>
<YAxis
tickFormatter={(v: number) => formatCurrency(v)}
stroke="hsl(var(--muted-foreground))"
fontSize={12}
tickLine={false}
axisLine={false}
width={100}
/>
<Tooltip
contentStyle={{
backgroundColor: 'hsl(var(--popover))',
border: '1px solid hsl(var(--border))',
borderRadius: '8px',
color: 'hsl(var(--popover-foreground))',
}}
labelStyle={{ color: 'hsl(var(--popover-foreground))' }}
formatter={(value, name) => {
const idx = parseInt(String(name).replace('value_', ''));
return [formatCurrency(Number(value)), getCompareLabel(idx)];
}}
/>
<Legend
wrapperStyle={{ color: 'hsl(var(--foreground))' }}
formatter={(value: string) => {
const idx = parseInt(value.replace('value_', ''));
return getCompareLabel(idx);
}}
/>
{compareData.map((_, idx) => (
<Area
key={idx}
type="monotone"
dataKey={`value_${idx}`}
stroke={COMPARE_COLORS[idx]}
strokeWidth={2}
fill={`url(#compareGradient_${idx})`}
name={`value_${idx}`}
connectNulls
/>
))}
</RechartsAreaChart>
</ResponsiveContainer>
</div>
) : (
<div className="flex items-center justify-center h-64 bg-muted/50 rounded-lg">
<p className="text-muted-foreground"> </p>
</div>
)}
</CardContent>
</Card>
</>
)}
</DashboardLayout>
);
}

View File

@ -17,9 +17,18 @@ import {
SelectTrigger, SelectTrigger,
SelectValue, SelectValue,
} from '@/components/ui/select'; } from '@/components/ui/select';
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogDescription,
DialogFooter,
} from '@/components/ui/dialog';
import { AreaChart } from '@/components/charts/area-chart'; import { AreaChart } from '@/components/charts/area-chart';
import { api } from '@/lib/api'; import { api } from '@/lib/api';
import { TrendingUp, TrendingDown, Activity, Target, Calendar, Settings } from 'lucide-react'; import { toast } from 'sonner';
import { TrendingUp, TrendingDown, Activity, Target, Calendar, Settings, Trash2, GitCompareArrows } from 'lucide-react';
interface BacktestResult { interface BacktestResult {
id: number; id: number;
@ -148,7 +157,7 @@ export default function BacktestPage() {
} }
} }
} catch (err) { } catch (err) {
console.error('Failed to fetch backtests:', err); toast.error(err instanceof Error ? err.message : '백테스트 목록을 불러오는데 실패했습니다.');
} }
}; };
@ -238,6 +247,34 @@ export default function BacktestPage() {
}).format(value); }).format(value);
}; };
const [deleteConfirmOpen, setDeleteConfirmOpen] = useState(false);
const [deleteTargetId, setDeleteTargetId] = useState<number | null>(null);
const [deleting, setDeleting] = useState(false);
const handleDeleteClick = (id: number) => {
setDeleteTargetId(id);
setDeleteConfirmOpen(true);
};
const handleConfirmDelete = async () => {
if (deleteTargetId === null) return;
setDeleting(true);
try {
await api.delete(`/api/backtest/${deleteTargetId}`);
setBacktests((prev) => prev.filter((bt) => bt.id !== deleteTargetId));
if (currentResult?.id === deleteTargetId) {
setCurrentResult(null);
}
setDeleteConfirmOpen(false);
setDeleteTargetId(null);
toast.success('백테스트가 삭제되었습니다.');
} catch (err) {
toast.error(err instanceof Error ? err.message : '백테스트 삭제에 실패했습니다.');
} finally {
setDeleting(false);
}
};
const displayResult = currentResult; const displayResult = currentResult;
if (loading) { if (loading) {
@ -263,13 +300,21 @@ export default function BacktestPage() {
</p> </p>
</div> </div>
<Button <div className="flex gap-2">
variant="outline" <Button variant="outline" asChild>
onClick={() => setShowHistory(!showHistory)} <Link href="/backtest/compare">
> <GitCompareArrows className="mr-2 h-4 w-4" />
<Calendar className="mr-2 h-4 w-4" />
{showHistory ? '새 백테스트' : '이전 기록'} </Link>
</Button> </Button>
<Button
variant="outline"
onClick={() => setShowHistory(!showHistory)}
>
<Calendar className="mr-2 h-4 w-4" />
{showHistory ? '새 백테스트' : '이전 기록'}
</Button>
</div>
</div> </div>
{error && ( {error && (
@ -297,6 +342,7 @@ export default function BacktestPage() {
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground">MDD</th> <th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground">MDD</th>
<th scope="col" className="px-4 py-3 text-center text-sm font-medium text-muted-foreground"></th> <th scope="col" className="px-4 py-3 text-center text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground"></th> <th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-center text-sm font-medium text-muted-foreground"></th>
</tr> </tr>
</thead> </thead>
<tbody className="divide-y divide-border"> <tbody className="divide-y divide-border">
@ -328,11 +374,21 @@ export default function BacktestPage() {
<td className="px-4 py-3 text-sm text-muted-foreground"> <td className="px-4 py-3 text-sm text-muted-foreground">
{new Date(bt.created_at).toLocaleDateString('ko-KR')} {new Date(bt.created_at).toLocaleDateString('ko-KR')}
</td> </td>
<td className="px-4 py-3 text-center">
<Button
variant="ghost"
size="icon"
onClick={() => handleDeleteClick(bt.id)}
className="h-7 w-7 text-muted-foreground hover:text-destructive"
>
<Trash2 className="h-4 w-4" />
</Button>
</td>
</tr> </tr>
))} ))}
{backtests.length === 0 && ( {backtests.length === 0 && (
<tr> <tr>
<td colSpan={8} className="px-4 py-8 text-center text-muted-foreground"> <td colSpan={9} className="px-4 py-8 text-center text-muted-foreground">
. .
</td> </td>
</tr> </tr>
@ -771,6 +827,24 @@ export default function BacktestPage() {
</div> </div>
</div> </div>
)} )}
<Dialog open={deleteConfirmOpen} onOpenChange={setDeleteConfirmOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle> </DialogTitle>
<DialogDescription>
. ?
</DialogDescription>
</DialogHeader>
<DialogFooter>
<Button variant="outline" onClick={() => setDeleteConfirmOpen(false)} disabled={deleting}>
</Button>
<Button variant="destructive" onClick={handleConfirmDelete} disabled={deleting}>
{deleting ? '삭제 중...' : '삭제'}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</DashboardLayout> </DashboardLayout>
); );
} }

View File

@ -127,6 +127,7 @@ export default function PortfolioDetailPage() {
tx_type: 'buy', tx_type: 'buy',
quantity: '', quantity: '',
price: '', price: '',
executed_at: '',
memo: '', memo: '',
}); });
@ -237,7 +238,7 @@ export default function PortfolioDetailPage() {
}; };
const handleAddTransaction = async () => { const handleAddTransaction = async () => {
if (!txForm.ticker || !txForm.quantity || !txForm.price) return; if (!txForm.ticker || !txForm.quantity || !txForm.price || !txForm.executed_at) return;
setTxSubmitting(true); setTxSubmitting(true);
try { try {
await api.post(`/api/portfolios/${portfolioId}/transactions`, { await api.post(`/api/portfolios/${portfolioId}/transactions`, {
@ -245,11 +246,11 @@ export default function PortfolioDetailPage() {
tx_type: txForm.tx_type, tx_type: txForm.tx_type,
quantity: parseInt(txForm.quantity, 10), quantity: parseInt(txForm.quantity, 10),
price: parseFloat(txForm.price), price: parseFloat(txForm.price),
executed_at: new Date().toISOString(), executed_at: new Date(txForm.executed_at).toISOString(),
memo: txForm.memo || null, memo: txForm.memo || null,
}); });
setTxModalOpen(false); setTxModalOpen(false);
setTxForm({ ticker: '', tx_type: 'buy', quantity: '', price: '', memo: '' }); setTxForm({ ticker: '', tx_type: 'buy', quantity: '', price: '', executed_at: '', memo: '' });
await Promise.all([fetchPortfolio(), fetchTransactions()]); await Promise.all([fetchPortfolio(), fetchTransactions()]);
} catch (err) { } catch (err) {
const message = err instanceof Error ? err.message : '거래 추가 실패'; const message = err instanceof Error ? err.message : '거래 추가 실패';
@ -790,6 +791,15 @@ export default function PortfolioDetailPage() {
/> />
</div> </div>
</div> </div>
<div className="space-y-2">
<Label htmlFor="tx-executed-at"> </Label>
<Input
id="tx-executed-at"
type="datetime-local"
value={txForm.executed_at}
onChange={(e) => setTxForm({ ...txForm, executed_at: e.target.value })}
/>
</div>
<div className="space-y-2"> <div className="space-y-2">
<Label htmlFor="tx-memo"> ()</Label> <Label htmlFor="tx-memo"> ()</Label>
<Input <Input
@ -806,7 +816,7 @@ export default function PortfolioDetailPage() {
</Button> </Button>
<Button <Button
onClick={handleAddTransaction} onClick={handleAddTransaction}
disabled={txSubmitting || !txForm.ticker || !txForm.quantity || !txForm.price} disabled={txSubmitting || !txForm.ticker || !txForm.quantity || !txForm.price || !txForm.executed_at}
> >
{txSubmitting ? '저장 중...' : '저장'} {txSubmitting ? '저장 중...' : '저장'}
</Button> </Button>

View File

@ -64,6 +64,8 @@ export default function RebalancePage() {
const [applyPrices, setApplyPrices] = useState<Record<string, string>>({}); const [applyPrices, setApplyPrices] = useState<Record<string, string>>({});
const [applying, setApplying] = useState(false); const [applying, setApplying] = useState(false);
const [applyError, setApplyError] = useState<string | null>(null); const [applyError, setApplyError] = useState<string | null>(null);
const [portfolioType, setPortfolioType] = useState<string>('general');
const [currentRiskRatio, setCurrentRiskRatio] = useState<number | null>(null);
useEffect(() => { useEffect(() => {
const init = async () => { const init = async () => {
@ -87,16 +89,21 @@ export default function RebalancePage() {
}); });
setPrices(initialPrices); setPrices(initialPrices);
// Fetch stock names from portfolio detail
try { try {
const detail = await api.get<{ holdings: { ticker: string; name: string | null }[] }>(`/api/portfolios/${portfolioId}/detail`); const detail = await api.get<{
portfolio_type: string;
risk_asset_ratio: number | null;
holdings: { ticker: string; name: string | null }[];
}>(`/api/portfolios/${portfolioId}/detail`);
const names: Record<string, string> = {}; const names: Record<string, string> = {};
for (const h of detail.holdings) { for (const h of detail.holdings) {
if (h.name) names[h.ticker] = h.name; if (h.name) names[h.ticker] = h.name;
} }
setNameMap(names); setNameMap(names);
setPortfolioType(detail.portfolio_type);
setCurrentRiskRatio(detail.risk_asset_ratio);
} catch { } catch {
// Names are optional, continue without // ignore
} }
} catch { } catch {
router.push('/login'); router.push('/login');
@ -316,7 +323,19 @@ export default function RebalancePage() {
</CardContent> </CardContent>
</Card> </Card>
{/* Results */} {result && portfolioType === 'pension' && currentRiskRatio !== null && currentRiskRatio > 70 && (
<div className="bg-amber-50 border border-amber-300 text-amber-800 dark:bg-amber-950 dark:border-amber-700 dark:text-amber-200 px-4 py-3 rounded mb-4 flex items-start gap-2">
<span className="text-lg">&#9888;</span>
<div>
<p className="font-medium">DC형 </p>
<p className="text-sm mt-1">
: <strong>{currentRiskRatio.toFixed(1)}%</strong> ( 한도: 70%).
/ ETF .
</p>
</div>
</div>
)}
{result && ( {result && (
<> <>
<Card className="mb-6"> <Card className="mb-6">

View File

@ -158,8 +158,7 @@ export default function SignalsPage() {
try { try {
const data = await api.get<Signal[]>('/api/signal/kjb/today'); const data = await api.get<Signal[]>('/api/signal/kjb/today');
setTodaySignals(data); setTodaySignals(data);
} catch (err) { } catch {
console.error('Failed to fetch today signals:', err);
toast.error('오늘의 신호를 불러오는데 실패했습니다.'); toast.error('오늘의 신호를 불러오는데 실패했습니다.');
} }
}; };
@ -174,8 +173,7 @@ export default function SignalsPage() {
const url = `/api/signal/kjb/history${query ? `?${query}` : ''}`; const url = `/api/signal/kjb/history${query ? `?${query}` : ''}`;
const data = await api.get<Signal[]>(url); const data = await api.get<Signal[]>(url);
setHistorySignals(data); setHistorySignals(data);
} catch (err) { } catch {
console.error('Failed to fetch signal history:', err);
toast.error('신호 이력을 불러오는데 실패했습니다.'); toast.error('신호 이력을 불러오는데 실패했습니다.');
} }
}; };
@ -184,8 +182,7 @@ export default function SignalsPage() {
try { try {
const data = await api.get<Portfolio[]>('/api/portfolios'); const data = await api.get<Portfolio[]>('/api/portfolios');
setPortfolios(data); setPortfolios(data);
} catch (err) { } catch {
console.error('Failed to fetch portfolios:', err);
toast.error('포트폴리오 목록을 불러오는데 실패했습니다.'); toast.error('포트폴리오 목록을 불러오는데 실패했습니다.');
} }
}; };
@ -254,8 +251,7 @@ export default function SignalsPage() {
if (ps.recommended_quantity > 0) { if (ps.recommended_quantity > 0) {
setExecuteQuantity(String(ps.recommended_quantity)); setExecuteQuantity(String(ps.recommended_quantity));
} }
} catch (err) { } catch {
console.error('Failed to fetch position sizing:', err);
toast.error('포지션 사이징 정보를 불러오는데 실패했습니다.'); toast.error('포지션 사이징 정보를 불러오는데 실패했습니다.');
} }
} }

View File

@ -9,6 +9,7 @@ import { Input } from '@/components/ui/input';
import { Label } from '@/components/ui/label'; import { Label } from '@/components/ui/label';
import { Skeleton } from '@/components/ui/skeleton'; import { Skeleton } from '@/components/ui/skeleton';
import { api } from '@/lib/api'; import { api } from '@/lib/api';
import { ApplyToPortfolio } from '@/components/strategy/apply-to-portfolio';
interface StockFactor { interface StockFactor {
ticker: string; ticker: string;
@ -191,6 +192,9 @@ export default function KJBStrategyPage() {
</table> </table>
</div> </div>
</CardContent> </CardContent>
<div className="px-4 pb-4">
<ApplyToPortfolio stocks={result.stocks.map((s) => ({ ticker: s.ticker, name: s.name }))} />
</div>
</Card> </Card>
)} )}
</DashboardLayout> </DashboardLayout>

View File

@ -1,9 +1,26 @@
'use client'; 'use client';
import { useEffect, useState } from 'react'; import { useEffect, useState } from 'react';
import { useRouter } from 'next/navigation';
import { Button } from '@/components/ui/button'; import { Button } from '@/components/ui/button';
import { Label } from '@/components/ui/label'; import { Label } from '@/components/ui/label';
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogDescription,
DialogFooter,
} from '@/components/ui/dialog';
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select';
import { api } from '@/lib/api'; import { api } from '@/lib/api';
import { toast } from 'sonner';
interface Portfolio { interface Portfolio {
id: number; id: number;
@ -21,19 +38,18 @@ interface ApplyToPortfolioProps {
} }
export function ApplyToPortfolio({ stocks }: ApplyToPortfolioProps) { export function ApplyToPortfolio({ stocks }: ApplyToPortfolioProps) {
const router = useRouter();
const [portfolios, setPortfolios] = useState<Portfolio[]>([]); const [portfolios, setPortfolios] = useState<Portfolio[]>([]);
const [selectedId, setSelectedId] = useState<number | null>(null); const [selectedId, setSelectedId] = useState<string>('');
const [showConfirm, setShowConfirm] = useState(false); const [dialogOpen, setDialogOpen] = useState(false);
const [applying, setApplying] = useState(false); const [applying, setApplying] = useState(false);
const [error, setError] = useState<string | null>(null);
const [success, setSuccess] = useState(false);
useEffect(() => { useEffect(() => {
const load = async () => { const load = async () => {
try { try {
const data = await api.get<Portfolio[]>('/api/portfolios'); const data = await api.get<Portfolio[]>('/api/portfolios');
setPortfolios(data); setPortfolios(data);
if (data.length > 0) setSelectedId(data[0].id); if (data.length > 0) setSelectedId(String(data[0].id));
} catch { } catch {
// ignore // ignore
} }
@ -41,10 +57,10 @@ export function ApplyToPortfolio({ stocks }: ApplyToPortfolioProps) {
load(); load();
}, []); }, []);
const apply = async () => { const handleApply = async () => {
if (!selectedId || stocks.length === 0) return; const portfolioId = Number(selectedId);
if (!portfolioId || stocks.length === 0) return;
setApplying(true); setApplying(true);
setError(null);
try { try {
const ratio = parseFloat((100 / stocks.length).toFixed(2)); const ratio = parseFloat((100 / stocks.length).toFixed(2));
const targets: TargetItem[] = stocks.map((s, i) => ({ const targets: TargetItem[] = stocks.map((s, i) => ({
@ -54,12 +70,16 @@ export function ApplyToPortfolio({ stocks }: ApplyToPortfolioProps) {
: ratio, : ratio,
})); }));
await api.put(`/api/portfolios/${selectedId}/targets`, targets); await api.put(`/api/portfolios/${portfolioId}/targets`, targets);
setShowConfirm(false); setDialogOpen(false);
setSuccess(true); toast.success('목표 배분이 적용되었습니다.', {
setTimeout(() => setSuccess(false), 3000); action: {
label: '리밸런싱으로 이동',
onClick: () => router.push(`/portfolio/${portfolioId}/rebalance`),
},
});
} catch (err) { } catch (err) {
setError(err instanceof Error ? err.message : '적용 실패'); toast.error(err instanceof Error ? err.message : '목표 배분 적용 실패했습니다.');
} finally { } finally {
setApplying(false); setApplying(false);
} }
@ -68,71 +88,58 @@ export function ApplyToPortfolio({ stocks }: ApplyToPortfolioProps) {
if (portfolios.length === 0) return null; if (portfolios.length === 0) return null;
return ( return (
<div className="mt-4"> <>
<div className="flex items-end gap-3"> <Button className="mt-4" onClick={() => setDialogOpen(true)}>
<div>
<Label htmlFor="portfolio-select"> </Label> </Button>
<select
id="portfolio-select"
className="mt-1 block w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
value={selectedId ?? ''}
onChange={(e) => setSelectedId(Number(e.target.value))}
>
{portfolios.map((p) => (
<option key={p.id} value={p.id}>
{p.name} ({p.portfolio_type === 'pension' ? '퇴직연금' : '일반'})
</option>
))}
</select>
</div>
<Button onClick={() => setShowConfirm(true)}>
</Button>
</div>
{success && ( <Dialog open={dialogOpen} onOpenChange={setDialogOpen}>
<div className="mt-2 text-sm text-green-600 dark:text-green-400"> <DialogContent>
. <DialogHeader>
</div> <DialogTitle> </DialogTitle>
)} <DialogDescription>
.
{stocks.length} ({(100 / stocks.length).toFixed(2)}%) .
</DialogDescription>
</DialogHeader>
{showConfirm && ( <div className="space-y-4">
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/50"> <div className="space-y-2">
<div className="bg-background rounded-lg shadow-lg max-w-md w-full mx-4 max-h-[80vh] overflow-y-auto"> <Label> </Label>
<div className="p-6"> <Select value={selectedId} onValueChange={setSelectedId}>
<h2 className="text-lg font-bold mb-2"> </h2> <SelectTrigger>
<p className="text-sm text-muted-foreground mb-4"> <SelectValue placeholder="포트폴리오를 선택하세요" />
. </SelectTrigger>
{stocks.length} ({(100 / stocks.length).toFixed(2)}%) . <SelectContent>
</p> {portfolios.map((p) => (
<SelectItem key={p.id} value={String(p.id)}>
{p.name} ({p.portfolio_type === 'pension' ? '퇴직연금' : '일반'})
</SelectItem>
))}
</SelectContent>
</Select>
</div>
{error && ( <div className="max-h-48 overflow-y-auto border rounded p-2">
<div className="bg-destructive/10 border border-destructive text-destructive px-3 py-2 rounded mb-4 text-sm"> {stocks.map((s) => (
{error} <div key={s.ticker} className="text-sm py-1 flex justify-between">
<span>{s.name || s.ticker}</span>
<span className="text-muted-foreground">{(100 / stocks.length).toFixed(2)}%</span>
</div> </div>
)} ))}
<div className="max-h-48 overflow-y-auto mb-4 border rounded p-2">
{stocks.map((s) => (
<div key={s.ticker} className="text-sm py-1 flex justify-between">
<span>{s.name || s.ticker}</span>
<span className="text-muted-foreground">{(100 / stocks.length).toFixed(2)}%</span>
</div>
))}
</div>
<div className="flex justify-end gap-2">
<Button variant="outline" onClick={() => setShowConfirm(false)}>
</Button>
<Button onClick={apply} disabled={applying}>
{applying ? '적용 중...' : '적용'}
</Button>
</div>
</div> </div>
</div> </div>
</div>
)} <DialogFooter>
</div> <Button variant="outline" onClick={() => setDialogOpen(false)}>
</Button>
<Button onClick={handleApply} disabled={applying || !selectedId}>
{applying ? '적용 중...' : '적용'}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</>
); );
} }