diff --git a/.env.example b/.env.example index 1445b76..55d6845 100644 --- a/.env.example +++ b/.env.example @@ -15,5 +15,10 @@ KIS_ACCOUNT_NO=your_account_number # DART OpenAPI (Financial Statements, optional) DART_API_KEY=your_dart_api_key +# Notifications (optional) +DISCORD_WEBHOOK_URL= +TELEGRAM_BOT_TOKEN= +TELEGRAM_CHAT_ID= + # Production only API_URL=https://your-domain.com diff --git a/backend/alembic/versions/d4e5f6a7b8c9_add_notification_tables.py b/backend/alembic/versions/d4e5f6a7b8c9_add_notification_tables.py new file mode 100644 index 0000000..5537414 --- /dev/null +++ b/backend/alembic/versions/d4e5f6a7b8c9_add_notification_tables.py @@ -0,0 +1,59 @@ +"""add notification settings and history tables + +Revision ID: d4e5f6a7b8c9 +Revises: c3d4e5f6a7b8 +Create Date: 2026-03-29 10:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'd4e5f6a7b8c9' +down_revision: Union[str, None] = 'c3d4e5f6a7b8' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + 'notification_settings', + sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True), + sa.Column('user_id', sa.Integer(), sa.ForeignKey('users.id'), nullable=False), + sa.Column('channel_type', sa.Enum('discord', 'telegram', name='channeltype'), nullable=False), + sa.Column('webhook_url', sa.String(500), nullable=False), + sa.Column('enabled', sa.Boolean(), server_default='true'), + sa.Column('created_at', sa.DateTime(), server_default=sa.func.now()), + sa.Column('updated_at', sa.DateTime(), server_default=sa.func.now()), + ) + op.create_index('ix_notification_settings_id', 'notification_settings', ['id']) + op.create_index('ix_notification_settings_user_id', 'notification_settings', ['user_id']) + + op.create_table( + 'notification_history', + sa.Column('id', sa.Integer(), 
primary_key=True, autoincrement=True), + sa.Column('signal_id', sa.Integer(), sa.ForeignKey('signals.id'), nullable=False), + sa.Column('channel_type', sa.Enum('discord', 'telegram', name='channeltype', create_type=False), nullable=False), + sa.Column('sent_at', sa.DateTime(), server_default=sa.func.now()), + sa.Column('status', sa.Enum('sent', 'failed', name='notificationstatus'), nullable=False), + sa.Column('message', sa.Text(), nullable=True), + sa.Column('error_message', sa.Text(), nullable=True), + ) + op.create_index('ix_notification_history_id', 'notification_history', ['id']) + op.create_index('ix_notification_history_signal_id', 'notification_history', ['signal_id']) + + +def downgrade() -> None: + op.drop_index('ix_notification_history_signal_id', table_name='notification_history') + op.drop_index('ix_notification_history_id', table_name='notification_history') + op.drop_table('notification_history') + + op.drop_index('ix_notification_settings_user_id', table_name='notification_settings') + op.drop_index('ix_notification_settings_id', table_name='notification_settings') + op.drop_table('notification_settings') + + sa.Enum(name='notificationstatus').drop(op.get_bind(), checkfirst=True) + sa.Enum(name='channeltype').drop(op.get_bind(), checkfirst=True) diff --git a/backend/alembic/versions/e5f6a7b8c9d0_add_trade_journal_table.py b/backend/alembic/versions/e5f6a7b8c9d0_add_trade_journal_table.py new file mode 100644 index 0000000..eca5230 --- /dev/null +++ b/backend/alembic/versions/e5f6a7b8c9d0_add_trade_journal_table.py @@ -0,0 +1,71 @@ +"""add trade journal table + +Revision ID: e5f6a7b8c9d0 +Revises: d4e5f6a7b8c9 +Create Date: 2026-03-29 12:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = "e5f6a7b8c9d0" +down_revision: Union[str, None] = "d4e5f6a7b8c9" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "trade_journals", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("user_id", sa.Integer(), sa.ForeignKey("users.id"), nullable=False), + sa.Column("stock_code", sa.String(20), nullable=False), + sa.Column("stock_name", sa.String(100), nullable=True), + sa.Column( + "trade_type", + sa.Enum("buy", "sell", name="tradetype"), + nullable=False, + ), + sa.Column("entry_price", sa.Numeric(12, 2), nullable=True), + sa.Column("target_price", sa.Numeric(12, 2), nullable=True), + sa.Column("stop_loss_price", sa.Numeric(12, 2), nullable=True), + sa.Column("exit_price", sa.Numeric(12, 2), nullable=True), + sa.Column("entry_date", sa.Date(), nullable=False), + sa.Column("exit_date", sa.Date(), nullable=True), + sa.Column("quantity", sa.Integer(), nullable=True), + sa.Column("profit_loss", sa.Numeric(14, 2), nullable=True), + sa.Column("profit_loss_pct", sa.Numeric(8, 4), nullable=True), + sa.Column("entry_reason", sa.Text(), nullable=True), + sa.Column("exit_reason", sa.Text(), nullable=True), + sa.Column("scenario", sa.Text(), nullable=True), + sa.Column("lessons_learned", sa.Text(), nullable=True), + sa.Column("emotional_state", sa.Text(), nullable=True), + sa.Column("strategy_id", sa.Integer(), nullable=True), + sa.Column( + "status", + sa.Enum("open", "closed", name="journalstatus"), + server_default="open", + ), + sa.Column("created_at", sa.DateTime(), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(), server_default=sa.func.now()), + ) + op.create_index("ix_trade_journals_id", "trade_journals", ["id"]) + op.create_index("ix_trade_journals_user_id", "trade_journals", ["user_id"]) + op.create_index("ix_trade_journals_stock_code", "trade_journals", ["stock_code"]) + 
op.create_index("ix_trade_journals_entry_date", "trade_journals", ["entry_date"])
+    op.create_index("ix_trade_journals_status", "trade_journals", ["status"])
+
+
+def downgrade() -> None:
+    # Pass table_name explicitly: required by some backends (e.g. MySQL) and
+    # consistent with the notification-tables migration in this series.
+    op.drop_index("ix_trade_journals_status", table_name="trade_journals")
+    op.drop_index("ix_trade_journals_entry_date", table_name="trade_journals")
+    op.drop_index("ix_trade_journals_stock_code", table_name="trade_journals")
+    op.drop_index("ix_trade_journals_user_id", table_name="trade_journals")
+    op.drop_index("ix_trade_journals_id", table_name="trade_journals")
+    op.drop_table("trade_journals")
+    sa.Enum(name="tradetype").drop(op.get_bind(), checkfirst=True)
+    sa.Enum(name="journalstatus").drop(op.get_bind(), checkfirst=True)
diff --git a/backend/alembic/versions/f6a7b8c9d0e1_add_pension_account_tables.py b/backend/alembic/versions/f6a7b8c9d0e1_add_pension_account_tables.py
new file mode 100644
index 0000000..7eab1a6
--- /dev/null
+++ b/backend/alembic/versions/f6a7b8c9d0e1_add_pension_account_tables.py
@@ -0,0 +1,74 @@
+"""add pension account tables
+
+Revision ID: f6a7b8c9d0e1
+Revises: e5f6a7b8c9d0
+Create Date: 2026-03-29 14:00:00.000000
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "f6a7b8c9d0e1"
+down_revision: Union[str, None] = "e5f6a7b8c9d0"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        "pension_accounts",
+        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
+        sa.Column("user_id", sa.Integer(), sa.ForeignKey("users.id"), nullable=False),
+        sa.Column(
+            "account_type",
+            sa.Enum("dc", "irp", "personal", name="accounttype"),
+            nullable=False,
+        ),
+        sa.Column("account_name", sa.String(100), nullable=False),
+        sa.Column("total_amount", sa.Numeric(16, 2), nullable=False, server_default="0"),
+        sa.Column("birth_year", sa.Integer(), nullable=False),
+        sa.Column("target_retirement_age", sa.Integer(), server_default="60"),
+        sa.Column("created_at", sa.DateTime(), server_default=sa.func.now()),
+        sa.Column("updated_at", sa.DateTime(), server_default=sa.func.now()),
+    )
+    op.create_index("ix_pension_accounts_id", "pension_accounts", ["id"])
+    op.create_index("ix_pension_accounts_user_id", "pension_accounts", ["user_id"])
+    op.create_index("ix_pension_accounts_account_type", "pension_accounts", ["account_type"])
+
+    op.create_table(
+        "pension_holdings",
+        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
+        sa.Column(
+            "account_id",
+            sa.Integer(),
+            sa.ForeignKey("pension_accounts.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+        sa.Column("asset_name", sa.String(200), nullable=False),
+        sa.Column(
+            "asset_type",
+            sa.Enum("safe", "risky", name="assetrisktype"),
+            nullable=False,
+        ),
+        sa.Column("amount", sa.Numeric(16, 2), nullable=False, server_default="0"),
+        sa.Column("ratio", sa.Numeric(6, 2), nullable=False, server_default="0"),
+    )
+    op.create_index("ix_pension_holdings_id", "pension_holdings", ["id"])
+    op.create_index("ix_pension_holdings_account_id", "pension_holdings", ["account_id"])
+
+
+def downgrade() -> None:
+    # Pass table_name explicitly: required by some backends (e.g. MySQL) and
+    # consistent with the notification-tables migration in this series.
+    op.drop_index("ix_pension_holdings_account_id", table_name="pension_holdings")
+    op.drop_index("ix_pension_holdings_id", table_name="pension_holdings")
+    op.drop_table("pension_holdings")
+    sa.Enum(name="assetrisktype").drop(op.get_bind(), checkfirst=True)
+
+    op.drop_index("ix_pension_accounts_account_type", table_name="pension_accounts")
+    op.drop_index("ix_pension_accounts_user_id", table_name="pension_accounts")
+    op.drop_index("ix_pension_accounts_id", table_name="pension_accounts")
+    op.drop_table("pension_accounts")
+    sa.Enum(name="accounttype").drop(op.get_bind(), checkfirst=True)
diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py
index 1530c2c..f82349a 100644
--- a/backend/app/api/__init__.py
+++ b/backend/app/api/__init__.py
@@ -7,6 +7,15 @@ from app.api.backtest import router as backtest_router
 from app.api.snapshot import router as snapshot_router
 from app.api.data_explorer import router as data_explorer_router
 from app.api.signal import router as signal_router
+from app.api.notification import router as notification_router
+from app.api.journal import router as journal_router
+from app.api.position_sizing import router as position_sizing_router
+from app.api.pension import router as pension_router
+from app.api.drawdown import router as drawdown_router
+from app.api.benchmark import router as benchmark_router
+from app.api.tax_simulation import router as tax_simulation_router
+from app.api.correlation import router as correlation_router
+from app.api.optimizer import router as optimizer_router
 
 __all__ = [
     "auth_router",
@@ -18,4 +27,13 @@ __all__ = [
     "snapshot_router",
     "data_explorer_router",
     "signal_router",
+    "notification_router",
+    "journal_router",
+    "position_sizing_router",
+    "pension_router",
+    "drawdown_router",
+    "benchmark_router",
+    "tax_simulation_router",
+    "correlation_router",
+    "optimizer_router",
 ]
diff --git a/backend/app/api/benchmark.py b/backend/app/api/benchmark.py
new file mode 100644
index 0000000..80aea0e
--- /dev/null
+++ b/backend/app/api/benchmark.py
@@ -0,0 +1,59 @@
+"""
+Benchmark comparison API endpoints.
+""" +from typing import List + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.orm import Session + +from app.core.database import get_db +from app.api.deps import CurrentUser +from app.schemas.benchmark import ( + BenchmarkCompareResponse, + BenchmarkIndexInfo, + BenchmarkType, + PeriodType, +) +from app.services.benchmark import BenchmarkService + +router = APIRouter(prefix="/api/benchmark", tags=["benchmark"]) + + +@router.get("/indices", response_model=List[BenchmarkIndexInfo]) +async def list_indices(): + """사용 가능한 벤치마크 목록 반환.""" + return [ + BenchmarkIndexInfo( + code="kospi", + name="KOSPI", + description="코스피 종합지수", + ), + BenchmarkIndexInfo( + code="deposit", + name="정기예금", + description="정기예금 금리 (연 3.5%)", + ), + ] + + +@router.get("/compare/{portfolio_id}", response_model=BenchmarkCompareResponse) +async def compare_with_benchmark( + portfolio_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), + benchmark: BenchmarkType = Query(BenchmarkType.KOSPI), + period: PeriodType = Query(PeriodType.ONE_YEAR), +): + """포트폴리오 성과를 벤치마크와 비교.""" + service = BenchmarkService(db) + try: + result = service.compare_with_benchmark( + portfolio_id=portfolio_id, + benchmark_type=benchmark.value, + period=period.value, + user_id=current_user.id, + ) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + return result diff --git a/backend/app/api/correlation.py b/backend/app/api/correlation.py new file mode 100644 index 0000000..1012c0c --- /dev/null +++ b/backend/app/api/correlation.py @@ -0,0 +1,86 @@ +""" +Correlation analysis API endpoints. 
+""" +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session + +from app.core.database import get_db +from app.api.deps import CurrentUser +from app.schemas.correlation import ( + CorrelationMatrixRequest, + CorrelationMatrixResponse, + DiversificationResponse, + HighCorrelationPair, +) +from app.services.correlation import CorrelationService + +router = APIRouter(prefix="/api/correlation", tags=["correlation"]) + + +@router.post("/matrix", response_model=CorrelationMatrixResponse) +async def calculate_correlation_matrix( + request: CorrelationMatrixRequest, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """종목 간 수익률 상관 행렬 계산.""" + service = CorrelationService(db) + result = service.get_correlation_data(request.stock_codes, request.period_days) + + return CorrelationMatrixResponse( + stock_codes=result["stock_codes"], + matrix=result["matrix"], + high_correlation_pairs=[ + HighCorrelationPair(**p) for p in result["high_correlation_pairs"] + ], + ) + + +@router.get("/portfolio/{portfolio_id}", response_model=DiversificationResponse) +async def get_portfolio_diversification( + portfolio_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """포트폴리오 분산 효과 점수 조회.""" + service = CorrelationService(db) + + try: + score = service.calculate_portfolio_diversification(portfolio_id) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + # Get holdings for correlation data + from app.models.portfolio import Portfolio, PortfolioSnapshot + + portfolio = db.query(Portfolio).filter( + Portfolio.id == portfolio_id, + Portfolio.user_id == current_user.id, + ).first() + if not portfolio: + raise HTTPException(status_code=404, detail="Portfolio not found") + + snapshot = ( + db.query(PortfolioSnapshot) + .filter(PortfolioSnapshot.portfolio_id == portfolio_id) + .order_by(PortfolioSnapshot.snapshot_date.desc()) + .first() + ) + + high_pairs = [] + stock_count = 0 + if snapshot and 
snapshot.holdings: + tickers = [h.ticker for h in snapshot.holdings] + stock_count = len(tickers) + if len(tickers) >= 2: + corr_data = service.get_correlation_data(tickers, period_days=60) + high_pairs = [ + HighCorrelationPair(**p) for p in corr_data["high_correlation_pairs"] + ] + + return DiversificationResponse( + portfolio_id=portfolio_id, + diversification_score=score, + stock_count=stock_count, + high_correlation_pairs=high_pairs, + ) diff --git a/backend/app/api/drawdown.py b/backend/app/api/drawdown.py new file mode 100644 index 0000000..a238257 --- /dev/null +++ b/backend/app/api/drawdown.py @@ -0,0 +1,69 @@ +""" +Drawdown API endpoints for portfolio risk monitoring. +""" +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from app.core.database import get_db +from app.api.deps import CurrentUser +from app.api.snapshot import _get_portfolio +from app.schemas.drawdown import ( + DrawdownResponse, + DrawdownHistoryResponse, + DrawdownSettingsUpdate, +) +from app.services.drawdown import ( + calculate_drawdown, + calculate_rolling_drawdown, + get_alert_threshold, + set_alert_threshold, +) + +router = APIRouter(prefix="/api/drawdown", tags=["drawdown"]) + + +@router.get("/{portfolio_id}", response_model=DrawdownResponse) +async def get_drawdown( + portfolio_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Get current drawdown metrics for a portfolio.""" + _get_portfolio(db, portfolio_id, current_user.id) + data = calculate_drawdown(db, portfolio_id) + return DrawdownResponse(portfolio_id=portfolio_id, **data) + + +@router.get("/{portfolio_id}/history", response_model=DrawdownHistoryResponse) +async def get_drawdown_history( + portfolio_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Get rolling drawdown time series for a portfolio.""" + _get_portfolio(db, portfolio_id, current_user.id) + rolling = calculate_rolling_drawdown(db, portfolio_id) + summary = 
calculate_drawdown(db, portfolio_id) + + return DrawdownHistoryResponse( + portfolio_id=portfolio_id, + data=rolling, + max_drawdown_pct=summary["max_drawdown_pct"], + current_drawdown_pct=summary["current_drawdown_pct"], + ) + + +@router.put("/settings/{portfolio_id}") +async def update_drawdown_settings( + portfolio_id: int, + body: DrawdownSettingsUpdate, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Update drawdown alert threshold for a portfolio.""" + _get_portfolio(db, portfolio_id, current_user.id) + set_alert_threshold(portfolio_id, body.alert_threshold_pct) + return { + "portfolio_id": portfolio_id, + "alert_threshold_pct": float(body.alert_threshold_pct), + } diff --git a/backend/app/api/journal.py b/backend/app/api/journal.py new file mode 100644 index 0000000..2ea4d8b --- /dev/null +++ b/backend/app/api/journal.py @@ -0,0 +1,173 @@ +""" +Trading journal API endpoints. +""" +from datetime import date +from decimal import Decimal +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.orm import Session +from sqlalchemy import func + +from app.core.database import get_db +from app.api.deps import CurrentUser +from app.models.journal import TradeJournal, JournalStatus +from app.schemas.journal import ( + TradeJournalCreate, + TradeJournalUpdate, + TradeJournalResponse, + TradeJournalStats, +) + +router = APIRouter(prefix="/api/journal", tags=["journal"]) + + +@router.post("", response_model=TradeJournalResponse, status_code=201) +async def create_journal( + data: TradeJournalCreate, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + journal = TradeJournal( + user_id=current_user.id, + **data.model_dump(), + ) + db.add(journal) + db.commit() + db.refresh(journal) + return journal + + +@router.get("", response_model=List[TradeJournalResponse]) +async def list_journals( + current_user: CurrentUser, + db: Session = Depends(get_db), + status: Optional[str] = 
Query(None), + stock_code: Optional[str] = Query(None), + start_date: Optional[date] = Query(None), + end_date: Optional[date] = Query(None), + skip: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=200), +): + query = db.query(TradeJournal).filter(TradeJournal.user_id == current_user.id) + + if status: + query = query.filter(TradeJournal.status == status) + if stock_code: + query = query.filter(TradeJournal.stock_code == stock_code) + if start_date: + query = query.filter(TradeJournal.entry_date >= start_date) + if end_date: + query = query.filter(TradeJournal.entry_date <= end_date) + + journals = ( + query.order_by(TradeJournal.entry_date.desc(), TradeJournal.created_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) + return journals + + +@router.get("/stats", response_model=TradeJournalStats) +async def get_journal_stats( + current_user: CurrentUser, + db: Session = Depends(get_db), +): + journals = ( + db.query(TradeJournal) + .filter(TradeJournal.user_id == current_user.id) + .all() + ) + + total = len(journals) + open_trades = sum(1 for j in journals if j.status == JournalStatus.OPEN) + closed = [j for j in journals if j.status == JournalStatus.CLOSED] + closed_count = len(closed) + + closed_with_pnl = [j for j in closed if j.profit_loss_pct is not None] + win_count = sum(1 for j in closed_with_pnl if j.profit_loss_pct > 0) + loss_count = sum(1 for j in closed_with_pnl if j.profit_loss_pct <= 0) + + win_rate = None + avg_pnl_pct = None + max_profit = None + max_loss = None + total_pnl = None + + if closed_with_pnl: + win_rate = Decimal(win_count) / Decimal(len(closed_with_pnl)) * 100 + pcts = [j.profit_loss_pct for j in closed_with_pnl] + avg_pnl_pct = sum(pcts) / len(pcts) + max_profit = max(pcts) + max_loss = min(pcts) + + closed_with_pl = [j for j in closed if j.profit_loss is not None] + if closed_with_pl: + total_pnl = sum(j.profit_loss for j in closed_with_pl) + + return TradeJournalStats( + total_trades=total, + open_trades=open_trades, + 
closed_trades=closed_count, + win_count=win_count, + loss_count=loss_count, + win_rate=win_rate, + avg_profit_loss_pct=avg_pnl_pct, + max_profit_pct=max_profit, + max_loss_pct=max_loss, + total_profit_loss=total_pnl, + ) + + +@router.get("/{journal_id}", response_model=TradeJournalResponse) +async def get_journal( + journal_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + journal = ( + db.query(TradeJournal) + .filter(TradeJournal.id == journal_id, TradeJournal.user_id == current_user.id) + .first() + ) + if not journal: + raise HTTPException(status_code=404, detail="Journal not found") + return journal + + +@router.put("/{journal_id}", response_model=TradeJournalResponse) +async def update_journal( + journal_id: int, + data: TradeJournalUpdate, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + journal = ( + db.query(TradeJournal) + .filter(TradeJournal.id == journal_id, TradeJournal.user_id == current_user.id) + .first() + ) + if not journal: + raise HTTPException(status_code=404, detail="Journal not found") + + update_data = data.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(journal, field, value) + + # Auto-calculate profit/loss when closing + if data.exit_price is not None and journal.entry_price is not None: + if journal.trade_type.value == "buy": + journal.profit_loss = (data.exit_price - journal.entry_price) * (journal.quantity or 1) + journal.profit_loss_pct = (data.exit_price - journal.entry_price) / journal.entry_price * 100 + else: + journal.profit_loss = (journal.entry_price - data.exit_price) * (journal.quantity or 1) + journal.profit_loss_pct = (journal.entry_price - data.exit_price) / journal.entry_price * 100 + + # Auto-close when exit info is provided + if data.exit_price is not None and data.exit_date is not None and journal.status == JournalStatus.OPEN: + journal.status = JournalStatus.CLOSED + + db.commit() + db.refresh(journal) + return journal diff --git 
a/backend/app/api/notification.py b/backend/app/api/notification.py new file mode 100644 index 0000000..233f706 --- /dev/null +++ b/backend/app/api/notification.py @@ -0,0 +1,111 @@ +""" +Notification settings and history API endpoints. +""" +from typing import List + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.orm import Session + +from app.core.database import get_db +from app.api.deps import CurrentUser +from app.models.notification import NotificationSetting, NotificationHistory +from app.schemas.notification import ( + NotificationSettingCreate, + NotificationSettingUpdate, + NotificationSettingResponse, + NotificationHistoryResponse, +) + +router = APIRouter(prefix="/api/notifications", tags=["notifications"]) + + +@router.get("/settings", response_model=List[NotificationSettingResponse]) +async def get_notification_settings( + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Get all notification settings for the current user.""" + settings = ( + db.query(NotificationSetting) + .filter(NotificationSetting.user_id == current_user.id) + .all() + ) + return settings + + +@router.post("/settings", response_model=NotificationSettingResponse, status_code=201) +async def create_notification_setting( + data: NotificationSettingCreate, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Create a new notification setting.""" + existing = ( + db.query(NotificationSetting) + .filter( + NotificationSetting.user_id == current_user.id, + NotificationSetting.channel_type == data.channel_type, + ) + .first() + ) + if existing: + raise HTTPException( + status_code=400, + detail=f"Setting for {data.channel_type.value} already exists. 
Use PUT to update.", + ) + + setting = NotificationSetting( + user_id=current_user.id, + channel_type=data.channel_type, + webhook_url=data.webhook_url, + enabled=data.enabled, + ) + db.add(setting) + db.commit() + db.refresh(setting) + return setting + + +@router.put("/settings/{setting_id}", response_model=NotificationSettingResponse) +async def update_notification_setting( + setting_id: int, + data: NotificationSettingUpdate, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Update an existing notification setting.""" + setting = ( + db.query(NotificationSetting) + .filter( + NotificationSetting.id == setting_id, + NotificationSetting.user_id == current_user.id, + ) + .first() + ) + if not setting: + raise HTTPException(status_code=404, detail="Notification setting not found") + + if data.webhook_url is not None: + setting.webhook_url = data.webhook_url + if data.enabled is not None: + setting.enabled = data.enabled + + db.commit() + db.refresh(setting) + return setting + + +@router.get("/history", response_model=List[NotificationHistoryResponse]) +async def get_notification_history( + current_user: CurrentUser, + db: Session = Depends(get_db), + limit: int = Query(50, ge=1, le=200), +): + """Get notification history.""" + history = ( + db.query(NotificationHistory) + .order_by(NotificationHistory.sent_at.desc()) + .limit(limit) + .all() + ) + return history diff --git a/backend/app/api/optimizer.py b/backend/app/api/optimizer.py new file mode 100644 index 0000000..be98948 --- /dev/null +++ b/backend/app/api/optimizer.py @@ -0,0 +1,67 @@ +""" +Strategy optimizer API router. 
+""" +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session + +from app.api.deps import CurrentUser +from app.core.database import get_db +from app.schemas.optimizer import ( + DEFAULT_GRIDS, + STRATEGY_TYPES, + OptimizeRequest, + OptimizeResponse, +) +from app.services.optimizer import OptimizerService + +router = APIRouter(prefix="/api/optimizer", tags=["optimizer"]) + + +@router.post("", response_model=OptimizeResponse) +async def run_optimization( + request: OptimizeRequest, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Run grid-search optimization for a strategy.""" + if request.strategy_type not in STRATEGY_TYPES: + raise HTTPException( + status_code=400, + detail=f"Unknown strategy_type: {request.strategy_type}. " + f"Valid types: {STRATEGY_TYPES}", + ) + + service = OptimizerService(db) + try: + return service.optimize(request) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.get("/presets/{strategy_type}") +async def get_preset( + strategy_type: str, + current_user: CurrentUser, +): + """Get default parameter grid preset for a strategy type.""" + if strategy_type not in DEFAULT_GRIDS: + raise HTTPException( + status_code=404, + detail=f"No preset for strategy_type: {strategy_type}", + ) + return { + "strategy_type": strategy_type, + "param_grid": DEFAULT_GRIDS[strategy_type], + } + + +@router.get("/presets") +async def list_presets( + current_user: CurrentUser, +): + """List all available strategy presets.""" + return { + "presets": { + st: DEFAULT_GRIDS[st] for st in STRATEGY_TYPES + } + } diff --git a/backend/app/api/pension.py b/backend/app/api/pension.py new file mode 100644 index 0000000..52a7c46 --- /dev/null +++ b/backend/app/api/pension.py @@ -0,0 +1,157 @@ +""" +Pension account API endpoints. 
+""" +from typing import List + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.orm import Session, joinedload + +from app.core.database import get_db +from app.api.deps import CurrentUser +from app.models.pension import PensionAccount, PensionHolding +from app.schemas.pension import ( + PensionAccountCreate, + PensionAccountUpdate, + PensionAccountResponse, + AllocationResult, + RecommendationResult, +) +from app.services.pension_allocation import calculate_allocation, get_recommendation + +router = APIRouter(prefix="/api/pension", tags=["pension"]) + + +@router.post("/accounts", response_model=PensionAccountResponse, status_code=201) +async def create_account( + data: PensionAccountCreate, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + account = PensionAccount( + user_id=current_user.id, + **data.model_dump(), + ) + db.add(account) + db.commit() + db.refresh(account) + return account + + +@router.get("/accounts", response_model=List[PensionAccountResponse]) +async def list_accounts( + current_user: CurrentUser, + db: Session = Depends(get_db), + skip: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=200), +): + accounts = ( + db.query(PensionAccount) + .options(joinedload(PensionAccount.holdings)) + .filter(PensionAccount.user_id == current_user.id) + .order_by(PensionAccount.created_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) + return accounts + + +@router.get("/accounts/{account_id}", response_model=PensionAccountResponse) +async def get_account( + account_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + account = ( + db.query(PensionAccount) + .options(joinedload(PensionAccount.holdings)) + .filter(PensionAccount.id == account_id, PensionAccount.user_id == current_user.id) + .first() + ) + if not account: + raise HTTPException(status_code=404, detail="Pension account not found") + return account + + +@router.put("/accounts/{account_id}", 
response_model=PensionAccountResponse) +async def update_account( + account_id: int, + data: PensionAccountUpdate, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + account = ( + db.query(PensionAccount) + .options(joinedload(PensionAccount.holdings)) + .filter(PensionAccount.id == account_id, PensionAccount.user_id == current_user.id) + .first() + ) + if not account: + raise HTTPException(status_code=404, detail="Pension account not found") + + update_data = data.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(account, field, value) + + db.commit() + db.refresh(account) + return account + + +@router.post("/accounts/{account_id}/allocate", response_model=AllocationResult) +async def allocate_assets( + account_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + account = ( + db.query(PensionAccount) + .filter(PensionAccount.id == account_id, PensionAccount.user_id == current_user.id) + .first() + ) + if not account: + raise HTTPException(status_code=404, detail="Pension account not found") + + result = calculate_allocation( + account_id=account.id, + account_type=account.account_type.value, + total_amount=account.total_amount, + birth_year=account.birth_year, + target_retirement_age=account.target_retirement_age, + ) + + # Save allocation as holdings + db.query(PensionHolding).filter(PensionHolding.account_id == account_id).delete() + for alloc in result.allocations: + holding = PensionHolding( + account_id=account_id, + asset_name=alloc.asset_name, + asset_type=alloc.asset_type, + amount=alloc.amount, + ratio=alloc.ratio, + ) + db.add(holding) + db.commit() + + return result + + +@router.get("/accounts/{account_id}/recommendation", response_model=RecommendationResult) +async def get_account_recommendation( + account_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + account = ( + db.query(PensionAccount) + .filter(PensionAccount.id == account_id, 
PensionAccount.user_id == current_user.id) + .first() + ) + if not account: + raise HTTPException(status_code=404, detail="Pension account not found") + + return get_recommendation( + account_id=account.id, + birth_year=account.birth_year, + target_retirement_age=account.target_retirement_age, + ) diff --git a/backend/app/api/position_sizing.py b/backend/app/api/position_sizing.py new file mode 100644 index 0000000..01f44c6 --- /dev/null +++ b/backend/app/api/position_sizing.py @@ -0,0 +1,83 @@ +""" +Position sizing API endpoints. +""" +from fastapi import APIRouter, HTTPException + +from app.api.deps import CurrentUser +from app.schemas.position_sizing import ( + SizingMethod, + PositionSizeRequest, + PositionSizeResponse, + MethodInfo, + MethodsResponse, +) +from app.services.position_sizing import fixed_ratio, kelly_criterion, atr_based + +router = APIRouter(prefix="/api/position-sizing", tags=["position-sizing"]) + +METHODS = [ + MethodInfo( + name="fixed", + label="균등 분배 (Fixed Ratio)", + description="자본금을 종목 수로 균등 분배. 현금 비중 설정 가능. quant.md 기본 방식.", + ), + MethodInfo( + name="kelly", + label="켈리 기준 (Kelly Criterion)", + description="승률과 손익비를 기반으로 최적 베팅 비율 계산. 
보수적 1/4 켈리 기본.", + ), + MethodInfo( + name="atr", + label="ATR 변동성 (ATR-Based)", + description="ATR(평균 진폭)을 이용한 변동성 기반 포지션 사이징.", + ), +] + + +@router.post("/calculate", response_model=PositionSizeResponse) +async def calculate_position_size( + request: PositionSizeRequest, + current_user: CurrentUser, +): + """Calculate position size based on selected method.""" + try: + if request.method == SizingMethod.FIXED: + result = fixed_ratio( + capital=request.capital, + num_positions=request.num_positions, + cash_ratio=request.cash_ratio, + ) + elif request.method == SizingMethod.KELLY: + if not all([request.win_rate, request.avg_win, request.avg_loss]): + raise HTTPException( + status_code=422, + detail="Kelly method requires win_rate, avg_win, avg_loss", + ) + result = kelly_criterion( + win_rate=request.win_rate, + avg_win=request.avg_win, + avg_loss=request.avg_loss, + ) + elif request.method == SizingMethod.ATR: + if not request.atr: + raise HTTPException( + status_code=422, + detail="ATR method requires atr value", + ) + result = atr_based( + capital=request.capital, + atr=request.atr, + risk_pct=request.risk_pct, + ) + else: + raise HTTPException(status_code=422, detail=f"Unknown method: {request.method}") + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) + + return PositionSizeResponse(**result) + + +@router.get("/methods", response_model=MethodsResponse) +async def get_methods(current_user: CurrentUser): + """List supported position sizing methods.""" + return MethodsResponse(methods=METHODS) diff --git a/backend/app/api/tax_simulation.py b/backend/app/api/tax_simulation.py new file mode 100644 index 0000000..6f37c76 --- /dev/null +++ b/backend/app/api/tax_simulation.py @@ -0,0 +1,54 @@ +""" +Tax simulation API endpoints. 
+""" +from fastapi import APIRouter + +from app.schemas.tax_simulation import ( + AccumulationRequest, + AccumulationResponse, + PensionTaxRequest, + PensionTaxResponse, + TaxDeductionRequest, + TaxDeductionResponse, +) +from app.services.tax_simulation import ( + calculate_pension_tax, + calculate_tax_deduction, + simulate_accumulation, +) + +router = APIRouter(prefix="/api/tax", tags=["tax-simulation"]) + + +@router.post("/deduction", response_model=TaxDeductionResponse) +async def tax_deduction(request: TaxDeductionRequest): + """연간 세액공제 계산.""" + result = calculate_tax_deduction( + annual_income=request.annual_income, + contribution=request.contribution, + account_type=request.account_type.value, + ) + return result + + +@router.post("/pension-tax", response_model=PensionTaxResponse) +async def pension_tax(request: PensionTaxRequest): + """연금 수령 시 세금 비교.""" + result = calculate_pension_tax( + withdrawal_amount=request.withdrawal_amount, + withdrawal_type=request.withdrawal_type.value, + age=request.age, + ) + return result + + +@router.post("/accumulation", response_model=AccumulationResponse) +async def accumulation(request: AccumulationRequest): + """적립 시뮬레이션.""" + result = simulate_accumulation( + monthly_contribution=request.monthly_contribution, + years=request.years, + annual_return=request.annual_return, + tax_deduction_rate=request.tax_deduction_rate, + ) + return result diff --git a/backend/app/core/config.py b/backend/app/core/config.py index ce87aa1..84b9a58 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -33,6 +33,11 @@ class Settings(BaseSettings): kis_account_no: str = "" dart_api_key: str = "" + # Notifications (optional) + discord_webhook_url: str = "" + telegram_bot_token: str = "" + telegram_chat_id: str = "" + class Config: env_file = ".env" case_sensitive = False diff --git a/backend/app/main.py b/backend/app/main.py index 31cd38f..9c19a64 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -10,7 +10,10 
@@ from fastapi.middleware.cors import CORSMiddleware from app.api import ( auth_router, admin_router, portfolio_router, strategy_router, market_router, backtest_router, snapshot_router, data_explorer_router, - signal_router, + signal_router, notification_router, journal_router, position_sizing_router, + pension_router, drawdown_router, benchmark_router, tax_simulation_router, + correlation_router, + optimizer_router, ) # Configure logging @@ -114,6 +117,15 @@ app.include_router(backtest_router) app.include_router(snapshot_router) app.include_router(data_explorer_router) app.include_router(signal_router) +app.include_router(notification_router) +app.include_router(journal_router) +app.include_router(position_sizing_router) +app.include_router(pension_router) +app.include_router(drawdown_router) +app.include_router(benchmark_router) +app.include_router(tax_simulation_router) +app.include_router(correlation_router) +app.include_router(optimizer_router) @app.get("/health") diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 2d01477..3a89f94 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -23,6 +23,19 @@ from app.models.stock import ( JobLog, ) from app.models.signal import Signal, SignalType, SignalStatus +from app.models.notification import ( + NotificationSetting, + NotificationHistory, + ChannelType, + NotificationStatus, +) +from app.models.journal import TradeJournal, TradeType, JournalStatus +from app.models.pension import ( + PensionAccount, + PensionHolding, + AccountType, + AssetRiskType, +) from app.models.backtest import ( Backtest, BacktestStatus, @@ -64,4 +77,15 @@ __all__ = [ "Signal", "SignalType", "SignalStatus", + "NotificationSetting", + "NotificationHistory", + "ChannelType", + "NotificationStatus", + "TradeJournal", + "TradeType", + "JournalStatus", + "PensionAccount", + "PensionHolding", + "AccountType", + "AssetRiskType", ] diff --git a/backend/app/models/journal.py 
b/backend/app/models/journal.py new file mode 100644 index 0000000..a3a5cc7 --- /dev/null +++ b/backend/app/models/journal.py @@ -0,0 +1,53 @@
"""
Trading journal models.

SQLAlchemy ORM models for per-user trade journal entries: entry/exit
prices and dates, realized P&L, and qualitative review notes.
"""
import enum
from datetime import datetime

from sqlalchemy import (
    Column, Integer, String, Numeric, DateTime, Date,
    Text, Enum as SQLEnum, ForeignKey,
)
from sqlalchemy.orm import relationship

from app.core.database import Base


class TradeType(str, enum.Enum):
    """Direction of the journaled trade."""
    BUY = "buy"
    SELL = "sell"


class JournalStatus(str, enum.Enum):
    """Lifecycle of a journal entry: open position vs. closed trade."""
    OPEN = "open"
    CLOSED = "closed"


class TradeJournal(Base):
    """One trade journal entry owned by a user.

    Price columns are Numeric(12, 2); profit_loss_pct is a percentage
    stored as Numeric(8, 4). Most fields are nullable so an entry can be
    created at trade entry and filled in at exit.
    """
    __tablename__ = "trade_journals"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
    stock_code = Column(String(20), nullable=False, index=True)
    stock_name = Column(String(100))
    trade_type = Column(SQLEnum(TradeType), nullable=False)
    # Planned / actual price levels; all optional until known.
    entry_price = Column(Numeric(12, 2))
    target_price = Column(Numeric(12, 2))
    stop_loss_price = Column(Numeric(12, 2))
    exit_price = Column(Numeric(12, 2))
    entry_date = Column(Date, nullable=False, index=True)
    exit_date = Column(Date)
    quantity = Column(Integer)
    profit_loss = Column(Numeric(14, 2))
    profit_loss_pct = Column(Numeric(8, 4))
    # Qualitative review fields (free text).
    entry_reason = Column(Text)
    exit_reason = Column(Text)
    scenario = Column(Text)
    lessons_learned = Column(Text)
    emotional_state = Column(Text)
    # NOTE(review): plain Integer with no ForeignKey("strategies.id") —
    # confirm whether the missing FK constraint is intentional.
    strategy_id = Column(Integer)
    status = Column(SQLEnum(JournalStatus), default=JournalStatus.OPEN, index=True)
    # Naive UTC timestamps via datetime.utcnow, matching the other models
    # added alongside this one.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    user = relationship("User", backref="trade_journals")
diff --git a/backend/app/models/notification.py b/backend/app/models/notification.py new file mode 100644 index 0000000..04fab30 --- /dev/null +++ b/backend/app/models/notification.py @@ -0,0 +1,46 @@ +""" +Notification models for signal alerts.
"""
import enum
from datetime import datetime

from sqlalchemy import (
    Column, Integer, String, DateTime, Boolean, Text,
    ForeignKey, Enum as SQLEnum,
)

from app.core.database import Base


class ChannelType(str, enum.Enum):
    """Delivery channel for signal notifications."""
    DISCORD = "discord"
    TELEGRAM = "telegram"


class NotificationStatus(str, enum.Enum):
    """Outcome of one delivery attempt."""
    SENT = "sent"
    FAILED = "failed"


class NotificationSetting(Base):
    """Per-user notification channel configuration.

    NOTE(review): webhook_url is a Discord webhook URL; for Telegram it
    presumably encodes the bot/chat endpoint — confirm against the
    notification sender service.
    """
    __tablename__ = "notification_settings"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
    channel_type = Column(SQLEnum(ChannelType), nullable=False)
    webhook_url = Column(String(500), nullable=False)
    enabled = Column(Boolean, default=True)  # soft on/off switch per channel
    # Naive UTC timestamps via datetime.utcnow, matching sibling models.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


class NotificationHistory(Base):
    """Audit log row for one delivery attempt of a signal notification."""
    __tablename__ = "notification_history"

    id = Column(Integer, primary_key=True, index=True)
    signal_id = Column(Integer, ForeignKey("signals.id"), nullable=False, index=True)
    channel_type = Column(SQLEnum(ChannelType), nullable=False)
    sent_at = Column(DateTime, default=datetime.utcnow)
    status = Column(SQLEnum(NotificationStatus), nullable=False)
    message = Column(Text)  # rendered notification body, if recorded
    error_message = Column(Text, nullable=True)  # delivery error detail, if any
diff --git a/backend/app/models/pension.py b/backend/app/models/pension.py new file mode 100644 index 0000000..3ee4bcb --- /dev/null +++ b/backend/app/models/pension.py @@ -0,0 +1,54 @@ +""" +Pension account and holding models for retirement pension asset allocation.
+""" +import enum +from datetime import datetime + +from sqlalchemy import ( + Column, Integer, String, Numeric, DateTime, + Enum as SQLEnum, ForeignKey, +) +from sqlalchemy.orm import relationship + +from app.core.database import Base + + +class AccountType(str, enum.Enum): + DC = "dc" + IRP = "irp" + PERSONAL = "personal" + + +class AssetRiskType(str, enum.Enum): + SAFE = "safe" + RISKY = "risky" + + +class PensionAccount(Base): + __tablename__ = "pension_accounts" + + id = Column(Integer, primary_key=True, index=True) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) + account_type = Column(SQLEnum(AccountType), nullable=False, index=True) + account_name = Column(String(100), nullable=False) + total_amount = Column(Numeric(16, 2), nullable=False, default=0) + birth_year = Column(Integer, nullable=False) + target_retirement_age = Column(Integer, default=60) + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + user = relationship("User", backref="pension_accounts") + holdings = relationship("PensionHolding", back_populates="account", cascade="all, delete-orphan") + + +class PensionHolding(Base): + __tablename__ = "pension_holdings" + + id = Column(Integer, primary_key=True, index=True) + account_id = Column(Integer, ForeignKey("pension_accounts.id"), nullable=False, index=True) + asset_name = Column(String(200), nullable=False) + asset_type = Column(SQLEnum(AssetRiskType), nullable=False) + amount = Column(Numeric(16, 2), nullable=False, default=0) + ratio = Column(Numeric(6, 2), nullable=False, default=0) + + account = relationship("PensionAccount", back_populates="holdings") diff --git a/backend/app/schemas/benchmark.py b/backend/app/schemas/benchmark.py new file mode 100644 index 0000000..30ddb6d --- /dev/null +++ b/backend/app/schemas/benchmark.py @@ -0,0 +1,55 @@ +""" +Benchmark comparison schemas. 
+""" +import enum +from datetime import date +from typing import List, Optional + +from pydantic import BaseModel, Field + + +class BenchmarkType(str, enum.Enum): + KOSPI = "kospi" + DEPOSIT = "deposit" + + +class PeriodType(str, enum.Enum): + ONE_MONTH = "1m" + THREE_MONTHS = "3m" + SIX_MONTHS = "6m" + ONE_YEAR = "1y" + ALL = "all" + + +class BenchmarkIndexInfo(BaseModel): + code: str + name: str + description: str + + +class TimeSeriesPoint(BaseModel): + date: date + portfolio_return: Optional[float] = None + benchmark_return: Optional[float] = None + deposit_return: Optional[float] = None + + +class PerformanceMetrics(BaseModel): + cumulative_return: float = Field(..., description="누적 수익률 (%)") + annualized_return: float = Field(..., description="연환산 수익률 (%)") + sharpe_ratio: Optional[float] = Field(None, description="샤프 비율") + max_drawdown: float = Field(..., description="최대 낙폭 (%)") + + +class BenchmarkCompareResponse(BaseModel): + portfolio_name: str + benchmark_type: str + period: str + start_date: date + end_date: date + time_series: List[TimeSeriesPoint] + portfolio_metrics: PerformanceMetrics + benchmark_metrics: PerformanceMetrics + deposit_metrics: PerformanceMetrics + alpha: float = Field(..., description="초과 수익률 (포트폴리오 - 벤치마크) (%)") + information_ratio: Optional[float] = Field(None, description="정보 비율") diff --git a/backend/app/schemas/correlation.py b/backend/app/schemas/correlation.py new file mode 100644 index 0000000..c5775c8 --- /dev/null +++ b/backend/app/schemas/correlation.py @@ -0,0 +1,34 @@ +""" +Correlation analysis schemas. 
+""" +from typing import List, Optional + +from pydantic import BaseModel, Field + + +class CorrelationMatrixRequest(BaseModel): + stock_codes: List[str] = Field(..., description="종목 코드 리스트") + period_days: int = Field(60, description="분석 기간 (일)", ge=7, le=365) + + +class HighCorrelationPair(BaseModel): + stock_a: str + stock_b: str + correlation: float = Field(..., description="상관계수 (-1 ~ 1)") + + +class CorrelationMatrixResponse(BaseModel): + stock_codes: List[str] + matrix: List[List[Optional[float]]] = Field(..., description="상관 행렬 (NxN)") + high_correlation_pairs: List[HighCorrelationPair] = Field( + default_factory=list, description="높은 상관관계 종목 쌍 (|r| > 0.7)" + ) + + +class DiversificationResponse(BaseModel): + portfolio_id: int + diversification_score: float = Field( + ..., description="분산 효과 점수 (0=집중, 1=완벽 분산)", ge=0, le=1 + ) + stock_count: int + high_correlation_pairs: List[HighCorrelationPair] = Field(default_factory=list) diff --git a/backend/app/schemas/drawdown.py b/backend/app/schemas/drawdown.py new file mode 100644 index 0000000..9242f9d --- /dev/null +++ b/backend/app/schemas/drawdown.py @@ -0,0 +1,40 @@ +""" +Drawdown related Pydantic schemas. 
+""" +from datetime import date +from decimal import Decimal +from typing import List, Optional + +from pydantic import BaseModel, Field + +from app.schemas.portfolio import FloatDecimal + + +class DrawdownDataPoint(BaseModel): + date: date + total_value: FloatDecimal + peak: FloatDecimal + drawdown_pct: FloatDecimal + + +class DrawdownResponse(BaseModel): + portfolio_id: int + current_drawdown_pct: FloatDecimal + max_drawdown_pct: FloatDecimal + max_drawdown_date: date | None = None + peak_value: FloatDecimal | None = None + peak_date: date | None = None + trough_value: FloatDecimal | None = None + trough_date: date | None = None + alert_threshold_pct: FloatDecimal = Decimal("20") + + +class DrawdownHistoryResponse(BaseModel): + portfolio_id: int + data: List[DrawdownDataPoint] = [] + max_drawdown_pct: FloatDecimal + current_drawdown_pct: FloatDecimal + + +class DrawdownSettingsUpdate(BaseModel): + alert_threshold_pct: FloatDecimal = Field(..., gt=0, le=100) diff --git a/backend/app/schemas/journal.py b/backend/app/schemas/journal.py new file mode 100644 index 0000000..8a0fba1 --- /dev/null +++ b/backend/app/schemas/journal.py @@ -0,0 +1,93 @@ +""" +Trading journal Pydantic schemas. 
+""" +from datetime import date, datetime +from decimal import Decimal +from typing import Optional +from enum import Enum + +from pydantic import BaseModel, Field + +from app.schemas.portfolio import FloatDecimal + + +class TradeType(str, Enum): + BUY = "buy" + SELL = "sell" + + +class JournalStatus(str, Enum): + OPEN = "open" + CLOSED = "closed" + + +class TradeJournalCreate(BaseModel): + stock_code: str = Field(..., min_length=1, max_length=20) + stock_name: Optional[str] = Field(None, max_length=100) + trade_type: TradeType + entry_price: Optional[FloatDecimal] = Field(None, gt=0) + target_price: Optional[FloatDecimal] = Field(None, gt=0) + stop_loss_price: Optional[FloatDecimal] = Field(None, gt=0) + entry_date: date + quantity: Optional[int] = Field(None, gt=0) + entry_reason: Optional[str] = None + scenario: Optional[str] = None + emotional_state: Optional[str] = None + strategy_id: Optional[int] = None + + +class TradeJournalUpdate(BaseModel): + stock_name: Optional[str] = Field(None, max_length=100) + exit_price: Optional[FloatDecimal] = Field(None, gt=0) + exit_date: Optional[date] = None + exit_reason: Optional[str] = None + target_price: Optional[FloatDecimal] = Field(None, gt=0) + stop_loss_price: Optional[FloatDecimal] = Field(None, gt=0) + quantity: Optional[int] = Field(None, gt=0) + lessons_learned: Optional[str] = None + emotional_state: Optional[str] = None + scenario: Optional[str] = None + entry_reason: Optional[str] = None + status: Optional[JournalStatus] = None + + +class TradeJournalResponse(BaseModel): + id: int + user_id: int + stock_code: str + stock_name: Optional[str] = None + trade_type: str + entry_price: Optional[FloatDecimal] = None + target_price: Optional[FloatDecimal] = None + stop_loss_price: Optional[FloatDecimal] = None + exit_price: Optional[FloatDecimal] = None + entry_date: date + exit_date: Optional[date] = None + quantity: Optional[int] = None + profit_loss: Optional[FloatDecimal] = None + profit_loss_pct: 
Optional[FloatDecimal] = None + entry_reason: Optional[str] = None + exit_reason: Optional[str] = None + scenario: Optional[str] = None + lessons_learned: Optional[str] = None + emotional_state: Optional[str] = None + strategy_id: Optional[int] = None + status: str + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True + + +class TradeJournalStats(BaseModel): + total_trades: int = 0 + open_trades: int = 0 + closed_trades: int = 0 + win_count: int = 0 + loss_count: int = 0 + win_rate: Optional[FloatDecimal] = None + avg_profit_loss_pct: Optional[FloatDecimal] = None + max_profit_pct: Optional[FloatDecimal] = None + max_loss_pct: Optional[FloatDecimal] = None + total_profit_loss: Optional[FloatDecimal] = None diff --git a/backend/app/schemas/notification.py b/backend/app/schemas/notification.py new file mode 100644 index 0000000..d8852e8 --- /dev/null +++ b/backend/app/schemas/notification.py @@ -0,0 +1,50 @@ +""" +Notification related Pydantic schemas. +""" +from datetime import datetime +from typing import Optional, List +from enum import Enum + +from pydantic import BaseModel, Field + + +class ChannelType(str, Enum): + DISCORD = "discord" + TELEGRAM = "telegram" + + +class NotificationSettingCreate(BaseModel): + channel_type: ChannelType + webhook_url: str = Field(..., min_length=1, max_length=500) + enabled: bool = True + + +class NotificationSettingUpdate(BaseModel): + webhook_url: Optional[str] = Field(None, min_length=1, max_length=500) + enabled: Optional[bool] = None + + +class NotificationSettingResponse(BaseModel): + id: int + user_id: int + channel_type: str + webhook_url: str + enabled: bool + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True + + +class NotificationHistoryResponse(BaseModel): + id: int + signal_id: int + channel_type: str + sent_at: datetime + status: str + message: Optional[str] = None + error_message: Optional[str] = None + + class Config: + from_attributes = True 
diff --git a/backend/app/schemas/optimizer.py b/backend/app/schemas/optimizer.py new file mode 100644 index 0000000..1178480 --- /dev/null +++ b/backend/app/schemas/optimizer.py @@ -0,0 +1,86 @@ +""" +Strategy optimizer schemas. +""" +from datetime import date +from decimal import Decimal +from typing import Annotated, Any, Dict, List, Optional + +from pydantic import BaseModel, Field, PlainSerializer + +FloatDecimal = Annotated[ + Decimal, + PlainSerializer(lambda v: float(v), return_type=float, when_used="json"), +] + +# --- Default parameter grids per strategy type --- + +KJB_DEFAULT_GRID: Dict[str, List[Any]] = { + "stop_loss_pct": [0.03, 0.05, 0.07], + "target1_pct": [0.05, 0.07, 0.10], + "rs_lookback": [10, 20, 30], +} + +MULTI_FACTOR_DEFAULT_GRID: Dict[str, List[Any]] = { + "weights.value": [0.15, 0.25, 0.35], + "weights.quality": [0.15, 0.25, 0.35], + "weights.momentum": [0.15, 0.25, 0.35], +} + +QUALITY_DEFAULT_GRID: Dict[str, List[Any]] = { + "min_fscore": [5, 6, 7, 8], +} + +VALUE_MOMENTUM_DEFAULT_GRID: Dict[str, List[Any]] = { + "value_weight": [0.3, 0.4, 0.5, 0.6, 0.7], + "momentum_weight": [0.3, 0.4, 0.5, 0.6, 0.7], +} + +DEFAULT_GRIDS: Dict[str, Dict[str, List[Any]]] = { + "kjb": KJB_DEFAULT_GRID, + "multi_factor": MULTI_FACTOR_DEFAULT_GRID, + "quality": QUALITY_DEFAULT_GRID, + "value_momentum": VALUE_MOMENTUM_DEFAULT_GRID, +} + +STRATEGY_TYPES = ["kjb", "multi_factor", "quality", "value_momentum"] + + +class OptimizeRequest(BaseModel): + strategy_type: str = Field( + ..., + description="Strategy type: kjb, multi_factor, quality, value_momentum", + ) + param_grid: Optional[Dict[str, List[Any]]] = Field( + default=None, + description="Parameter grid. 
If None, uses default preset for the strategy type.", + ) + start_date: date + end_date: date + initial_capital: Decimal = Field(default=Decimal("100000000"), gt=0) + commission_rate: Decimal = Field(default=Decimal("0.00015"), ge=0, le=1) + slippage_rate: Decimal = Field(default=Decimal("0.001"), ge=0, le=1) + benchmark: str = Field(default="KOSPI") + top_n: int = Field(default=30, ge=1, le=100) + rank_by: str = Field( + default="sharpe_ratio", + description="Metric to rank results by: sharpe_ratio, cagr, total_return, mdd", + ) + + +class OptimizeResultItem(BaseModel): + rank: int + params: Dict[str, Any] + total_return: FloatDecimal + cagr: FloatDecimal + mdd: FloatDecimal + sharpe_ratio: FloatDecimal + volatility: FloatDecimal + benchmark_return: FloatDecimal + excess_return: FloatDecimal + + +class OptimizeResponse(BaseModel): + strategy_type: str + total_combinations: int + results: List[OptimizeResultItem] + best_params: Dict[str, Any] diff --git a/backend/app/schemas/pension.py b/backend/app/schemas/pension.py new file mode 100644 index 0000000..9cfa887 --- /dev/null +++ b/backend/app/schemas/pension.py @@ -0,0 +1,108 @@ +""" +Pension account Pydantic schemas. 
+""" +from datetime import datetime +from enum import Enum +from typing import Optional, List + +from pydantic import BaseModel, Field + +from app.schemas.portfolio import FloatDecimal + + +class AccountType(str, Enum): + DC = "dc" + IRP = "irp" + PERSONAL = "personal" + + +class AssetRiskType(str, Enum): + SAFE = "safe" + RISKY = "risky" + + +# --- Account schemas --- + +class PensionAccountCreate(BaseModel): + account_type: AccountType + account_name: str = Field(..., min_length=1, max_length=100) + total_amount: FloatDecimal = Field(..., ge=0) + birth_year: int = Field(..., ge=1940, le=2010) + target_retirement_age: int = Field(60, ge=50, le=70) + + +class PensionAccountUpdate(BaseModel): + account_name: Optional[str] = Field(None, min_length=1, max_length=100) + total_amount: Optional[FloatDecimal] = Field(None, ge=0) + target_retirement_age: Optional[int] = Field(None, ge=50, le=70) + + +class PensionHoldingResponse(BaseModel): + id: int + account_id: int + asset_name: str + asset_type: str + amount: FloatDecimal + ratio: FloatDecimal + + class Config: + from_attributes = True + + +class PensionAccountResponse(BaseModel): + id: int + user_id: int + account_type: str + account_name: str + total_amount: FloatDecimal + birth_year: int + target_retirement_age: int + created_at: datetime + updated_at: datetime + holdings: List[PensionHoldingResponse] = [] + + class Config: + from_attributes = True + + +# --- Allocation schemas --- + +class AllocationItem(BaseModel): + asset_name: str + asset_type: str # safe / risky + amount: FloatDecimal + ratio: FloatDecimal + + +class AllocationResult(BaseModel): + account_id: int + account_type: str + total_amount: FloatDecimal + risky_limit_pct: FloatDecimal + safe_min_pct: FloatDecimal + glide_path_equity_pct: FloatDecimal + glide_path_bond_pct: FloatDecimal + current_age: int + years_to_retirement: int + allocations: List[AllocationItem] + + +# --- Recommendation schemas --- + +class RecommendationItem(BaseModel): + 
asset_name: str + asset_type: str + category: str # tdf, bond_etf, equity_etf, deposit + ratio: FloatDecimal + reason: str + + +class RecommendationResult(BaseModel): + account_id: int + birth_year: int + current_age: int + target_retirement_age: int + years_to_retirement: int + glide_path_equity_pct: FloatDecimal + glide_path_bond_pct: FloatDecimal + recommendations: List[RecommendationItem] diff --git a/backend/app/schemas/position_sizing.py b/backend/app/schemas/position_sizing.py new file mode 100644 index 0000000..b8f22ff --- /dev/null +++ b/backend/app/schemas/position_sizing.py @@ -0,0 +1,46 @@ +""" +Position sizing Pydantic schemas. +""" +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + + +class SizingMethod(str, Enum): + FIXED = "fixed" + KELLY = "kelly" + ATR = "atr" + + +class PositionSizeRequest(BaseModel): + capital: float = Field(..., gt=0, description="Total capital (KRW)") + method: SizingMethod = Field(..., description="Sizing method") + # fixed_ratio params + num_positions: int = Field(default=10, ge=1, le=50) + cash_ratio: float = Field(default=0.3, ge=0, lt=1.0) + # kelly params + win_rate: Optional[float] = Field(default=None, gt=0, le=1.0) + avg_win: Optional[float] = Field(default=None, gt=0) + avg_loss: Optional[float] = Field(default=None, gt=0) + # atr params + atr: Optional[float] = Field(default=None, gt=0) + risk_pct: float = Field(default=0.02, gt=0, le=1.0) + + +class PositionSizeResponse(BaseModel): + method: str + position_size: float + shares: int + risk_amount: float + notes: str + + +class MethodInfo(BaseModel): + name: str + label: str + description: str + + +class MethodsResponse(BaseModel): + methods: list[MethodInfo] diff --git a/backend/app/schemas/tax_simulation.py b/backend/app/schemas/tax_simulation.py new file mode 100644 index 0000000..90a4ddc --- /dev/null +++ b/backend/app/schemas/tax_simulation.py @@ -0,0 +1,78 @@ +""" +Tax simulation request/response schemas. 
+""" +import enum +from typing import List + +from pydantic import BaseModel, Field + + +class AccountType(str, enum.Enum): + IRP = "irp" + DC = "dc" + + +class WithdrawalType(str, enum.Enum): + PENSION = "pension" + LUMP_SUM = "lump_sum" + + +class TaxDeductionRequest(BaseModel): + annual_income: int = Field(..., gt=0, description="연간 총급여 (원)") + contribution: int = Field(..., ge=0, description="연간 납입액 (원)") + account_type: AccountType = Field(AccountType.IRP, description="계좌 유형") + + +class TaxDeductionResponse(BaseModel): + annual_income: int + contribution: int + account_type: str + deduction_rate: float = Field(..., description="세액공제율 (%)") + irp_limit: int = Field(..., description="연간 공제 한도 (원)") + deductible_contribution: int = Field(..., description="공제 대상 납입액 (원)") + tax_deduction: float = Field(..., description="세액공제 금액 (원)") + + +class PensionTaxRequest(BaseModel): + withdrawal_amount: int = Field(..., gt=0, description="수령 금액 (원)") + withdrawal_type: WithdrawalType = Field(WithdrawalType.PENSION, description="수령 방식") + age: int = Field(..., ge=55, le=100, description="수령 시 나이") + + +class PensionTaxResponse(BaseModel): + withdrawal_amount: int + withdrawal_type: str + age: int + pension_tax_rate: float = Field(..., description="연금소득세율 (%)") + pension_tax: float = Field(..., description="연금소득세 (원)") + lump_sum_tax_rate: float = Field(..., description="기타소득세율 (%)") + lump_sum_tax: float = Field(..., description="기타소득세 (원)") + tax_saving: float = Field(..., description="연금 수령 시 절세 금액 (원)") + + +class AccumulationRequest(BaseModel): + monthly_contribution: int = Field(..., gt=0, description="월 납입액 (원)") + years: int = Field(..., ge=1, le=50, description="적립 기간 (년)") + annual_return: float = Field(..., ge=0, le=30, description="연간 기대 수익률 (%)") + tax_deduction_rate: float = Field(..., description="세액공제율 (%)") + + +class YearlyAccumulationData(BaseModel): + year: int + contribution: int + cumulative_contribution: int + investment_value: float + tax_deduction: 
float + cumulative_tax_deduction: float + + +class AccumulationResponse(BaseModel): + monthly_contribution: int + years: int + annual_return: float + tax_deduction_rate: float + total_contribution: int + final_value: float + total_return: float + total_tax_deduction: float + yearly_data: List[YearlyAccumulationData] diff --git a/backend/app/services/benchmark.py b/backend/app/services/benchmark.py new file mode 100644 index 0000000..a5665d3 --- /dev/null +++ b/backend/app/services/benchmark.py @@ -0,0 +1,277 @@ +""" +Benchmark comparison service. + +Compares portfolio performance against KOSPI index and deposit rate. +""" +import logging +import math +from datetime import date, timedelta +from decimal import Decimal +from typing import List, Optional + +from pykrx import stock as pykrx_stock +from sqlalchemy.orm import Session + +from app.models.portfolio import Portfolio, PortfolioSnapshot +from app.schemas.benchmark import ( + BenchmarkCompareResponse, + PerformanceMetrics, + TimeSeriesPoint, +) + +logger = logging.getLogger(__name__) + +KOSPI_INDEX_CODE = "1001" +DEPOSIT_ANNUAL_RATE = 3.5 +RISK_FREE_RATE = 3.5 + + +class BenchmarkService: + def __init__(self, db: Session): + self.db = db + + def get_deposit_rate(self) -> float: + return DEPOSIT_ANNUAL_RATE + + def get_benchmark_data( + self, benchmark_type: str, start_date: date, end_date: date + ) -> List[dict]: + start_str = start_date.strftime("%Y%m%d") + end_str = end_date.strftime("%Y%m%d") + + if benchmark_type == "kospi": + df = pykrx_stock.get_index_ohlcv(start_str, end_str, KOSPI_INDEX_CODE) + else: + return [] + + if df.empty: + return [] + + result = [] + for idx, row in df.iterrows(): + result.append({ + "date": idx.date() if hasattr(idx, "date") else idx, + "close": row["종가"], + }) + return result + + def compare_with_benchmark( + self, + portfolio_id: int, + benchmark_type: str, + period: str, + user_id: int, + ) -> BenchmarkCompareResponse: + portfolio = ( + self.db.query(Portfolio) + 
.filter(Portfolio.id == portfolio_id, Portfolio.user_id == user_id) + .first() + ) + if not portfolio: + raise ValueError("Portfolio not found") + + snapshots = ( + self.db.query(PortfolioSnapshot) + .filter(PortfolioSnapshot.portfolio_id == portfolio_id) + .order_by(PortfolioSnapshot.snapshot_date) + .all() + ) + if not snapshots: + raise ValueError("스냅샷 데이터가 없습니다") + + end_date = snapshots[-1].snapshot_date + start_date = self._calc_start_date(period, snapshots[0].snapshot_date, end_date) + + filtered = [s for s in snapshots if s.snapshot_date >= start_date] + if len(filtered) < 2: + filtered = snapshots + + start_date = filtered[0].snapshot_date + num_days = (end_date - start_date).days + + # Portfolio daily returns + portfolio_values = [float(s.total_value) for s in filtered] + portfolio_returns = self._values_to_returns(portfolio_values) + + # Benchmark data + benchmark_data = self.get_benchmark_data(benchmark_type, start_date, end_date) + benchmark_closes = [d["close"] for d in benchmark_data] + benchmark_returns = self._values_to_returns(benchmark_closes) + + # Deposit returns (daily) + daily_deposit_rate = (1 + DEPOSIT_ANNUAL_RATE / 100) ** (1 / 365) - 1 + deposit_returns = [daily_deposit_rate] * max(num_days, 0) + + # Cumulative return time series + time_series = self._build_time_series( + filtered, benchmark_data, start_date, end_date + ) + + # Metrics + portfolio_metrics = self._calculate_metrics(portfolio_returns, num_days) + benchmark_metrics = self._calculate_metrics(benchmark_returns, num_days) + deposit_metrics = self._calculate_metrics(deposit_returns, num_days) + + alpha = portfolio_metrics.cumulative_return - benchmark_metrics.cumulative_return + info_ratio = self._calculate_information_ratio( + portfolio_returns, benchmark_returns + ) + + return BenchmarkCompareResponse( + portfolio_name=portfolio.name, + benchmark_type=benchmark_type, + period=period, + start_date=start_date, + end_date=end_date, + time_series=time_series, + 
portfolio_metrics=portfolio_metrics,
+            benchmark_metrics=benchmark_metrics,
+            deposit_metrics=deposit_metrics,
+            alpha=round(alpha, 2),
+            information_ratio=round(info_ratio, 4) if info_ratio is not None else None,
+        )
+
+    def _calculate_metrics(
+        self, returns: List[float], num_days: int
+    ) -> PerformanceMetrics:
+        """Compute cumulative/annualized return, Sharpe ratio and max drawdown
+        from a series of per-period (daily) simple returns.
+
+        Args:
+            returns: daily simple returns (e.g. 0.01 == +1%).
+            num_days: calendar days spanned by the series; used to annualize.
+
+        Returns:
+            PerformanceMetrics with percent values rounded to 2 decimals
+            (Sharpe to 4). Sharpe is None when fewer than 2 returns or
+            near-zero volatility.
+        """
+        if not returns:
+            # No data: report flat performance rather than raising.
+            return PerformanceMetrics(
+                cumulative_return=0.0,
+                annualized_return=0.0,
+                sharpe_ratio=None,
+                max_drawdown=0.0,
+            )
+
+        # Cumulative return: compound the per-period returns.
+        cum = 1.0
+        for r in returns:
+            cum *= 1 + r
+        cum_return = (cum - 1) * 100
+
+        # Annualized return: geometric scaling by calendar days (365).
+        if num_days > 0:
+            ann_return = (cum ** (365 / num_days) - 1) * 100
+        else:
+            ann_return = 0.0
+
+        # Sharpe ratio
+        # NOTE(review): the daily risk-free rate is derived via 365-day
+        # compounding, but annualization multiplies by sqrt(252) trading
+        # days — mixed day-count conventions; confirm this is intended.
+        if len(returns) >= 2:
+            mean_r = sum(returns) / len(returns)
+            # Sample variance (n-1 denominator).
+            variance = sum((r - mean_r) ** 2 for r in returns) / (len(returns) - 1)
+            std_r = math.sqrt(variance)
+            if std_r > 1e-10:
+                daily_rf = (1 + RISK_FREE_RATE / 100) ** (1 / 365) - 1
+                sharpe = (mean_r - daily_rf) / std_r * math.sqrt(252)
+                sharpe = round(sharpe, 4)
+            else:
+                # Effectively zero volatility: Sharpe undefined.
+                sharpe = None
+        else:
+            sharpe = None
+
+        # Max drawdown: deepest peak-to-trough drop of the compounded curve
+        # (max_dd is <= 0, expressed as a negative percentage).
+        peak = 1.0
+        max_dd = 0.0
+        cum_val = 1.0
+        for r in returns:
+            cum_val *= 1 + r
+            if cum_val > peak:
+                peak = cum_val
+            dd = (cum_val - peak) / peak
+            if dd < max_dd:
+                max_dd = dd
+        max_dd_pct = max_dd * 100
+
+        return PerformanceMetrics(
+            cumulative_return=round(cum_return, 2),
+            annualized_return=round(ann_return, 2),
+            sharpe_ratio=sharpe,
+            max_drawdown=round(max_dd_pct, 2),
+        )
+
+    def _calculate_information_ratio(
+        self, portfolio_returns: List[float], benchmark_returns: List[float]
+    ) -> Optional[float]:
+        """Annualized information ratio of portfolio vs benchmark returns.
+
+        The two series are truncated to the shorter length and compared
+        element-wise; returns None when either series is empty, fewer than
+        two overlapping points exist, or tracking error is ~0.
+        """
+        if not portfolio_returns or not benchmark_returns:
+            return None
+
+        min_len = min(len(portfolio_returns), len(benchmark_returns))
+        excess = [
+            portfolio_returns[i] - benchmark_returns[i] for i in range(min_len)
+        ]
+
+        if len(excess) < 2:
+            return None
+
+        mean_excess = sum(excess) / len(excess)
+        # Sample variance of the excess-return series (tracking error^2).
+        variance = sum((e - mean_excess) ** 2 for e in excess) / (len(excess) - 1)
+ tracking_error = math.sqrt(variance) + + if tracking_error < 1e-10: + return None + + return (mean_excess / tracking_error) * math.sqrt(252) + + def _values_to_returns(self, values: List[float]) -> List[float]: + if len(values) < 2: + return [] + return [ + (values[i] - values[i - 1]) / values[i - 1] + for i in range(1, len(values)) + if values[i - 1] != 0 + ] + + def _calc_start_date( + self, period: str, first_snapshot: date, end_date: date + ) -> date: + period_map = { + "1m": timedelta(days=30), + "3m": timedelta(days=90), + "6m": timedelta(days=180), + "1y": timedelta(days=365), + } + if period == "all": + return first_snapshot + delta = period_map.get(period, timedelta(days=365)) + return max(end_date - delta, first_snapshot) + + def _build_time_series( + self, + snapshots: list, + benchmark_data: List[dict], + start_date: date, + end_date: date, + ) -> List[TimeSeriesPoint]: + if not snapshots: + return [] + + base_portfolio = float(snapshots[0].total_value) + portfolio_map = {} + for s in snapshots: + val = float(s.total_value) + ret = ((val / base_portfolio) - 1) * 100 if base_portfolio else 0 + portfolio_map[s.snapshot_date] = ret + + benchmark_map = {} + if benchmark_data: + base_bench = benchmark_data[0]["close"] + for d in benchmark_data: + ret = ((d["close"] / base_bench) - 1) * 100 if base_bench else 0 + benchmark_map[d["date"]] = ret + + daily_deposit_rate = DEPOSIT_ANNUAL_RATE / 100 / 365 + + all_dates = sorted(set(list(portfolio_map.keys()) + list(benchmark_map.keys()))) + + result = [] + for d in all_dates: + days_elapsed = (d - start_date).days + deposit_ret = ((1 + DEPOSIT_ANNUAL_RATE / 100) ** (days_elapsed / 365) - 1) * 100 + + result.append(TimeSeriesPoint( + date=d, + portfolio_return=round(portfolio_map[d], 2) if d in portfolio_map else None, + benchmark_return=round(benchmark_map[d], 2) if d in benchmark_map else None, + deposit_return=round(deposit_ret, 2), + )) + + return result diff --git a/backend/app/services/correlation.py 
b/backend/app/services/correlation.py new file mode 100644 index 0000000..b80418d --- /dev/null +++ b/backend/app/services/correlation.py @@ -0,0 +1,188 @@ +""" +Correlation analysis service. + +Calculates inter-stock correlation matrices and portfolio diversification scores. +""" +import logging +from datetime import date, timedelta +from typing import List, Optional + +import numpy as np +import pandas as pd +from sqlalchemy.orm import Session + +from app.models.stock import Price +from app.models.portfolio import Portfolio, PortfolioSnapshot + +logger = logging.getLogger(__name__) + + +class CorrelationService: + def __init__(self, db: Session): + self.db = db + + def calculate_correlation_matrix( + self, stock_codes: List[str], period_days: int = 60 + ) -> dict: + if not stock_codes: + return {"stock_codes": [], "matrix": []} + + end_date = date.today() + start_date = end_date - timedelta(days=period_days) + + prices = ( + self.db.query(Price) + .filter( + Price.ticker.in_(stock_codes), + Price.date >= start_date, + Price.date <= end_date, + ) + .order_by(Price.date) + .all() + ) + + returns_df = self._prices_to_returns_df(prices, stock_codes) + + if returns_df.empty or len(returns_df) < 2: + n = len(stock_codes) + matrix = [[None if i != j else 1.0 for j in range(n)] for i in range(n)] + return {"stock_codes": stock_codes, "matrix": matrix} + + corr_matrix = returns_df.corr() + + matrix = [] + for code in stock_codes: + row = [] + for other in stock_codes: + if code in corr_matrix.columns and other in corr_matrix.columns: + val = corr_matrix.loc[code, other] + row.append(round(float(val), 4) if not np.isnan(val) else None) + else: + row.append(None if code != other else 1.0) + matrix.append(row) + + return {"stock_codes": stock_codes, "matrix": matrix} + + def calculate_portfolio_diversification(self, portfolio_id: int) -> float: + portfolio = ( + self.db.query(Portfolio) + .filter(Portfolio.id == portfolio_id) + .first() + ) + if not portfolio: + raise 
ValueError("Portfolio not found") + + snapshot = ( + self.db.query(PortfolioSnapshot) + .filter(PortfolioSnapshot.portfolio_id == portfolio_id) + .order_by(PortfolioSnapshot.snapshot_date.desc()) + .first() + ) + + if not snapshot or not snapshot.holdings: + return 1.0 + + holdings = snapshot.holdings + if len(holdings) == 1: + return 0.0 + + tickers = [h.ticker for h in holdings] + total_value = sum(float(h.value) for h in holdings) + if total_value == 0: + return 1.0 + + weights = np.array([float(h.value) / total_value for h in holdings]) + + end_date = date.today() + start_date = end_date - timedelta(days=60) + + prices = ( + self.db.query(Price) + .filter( + Price.ticker.in_(tickers), + Price.date >= start_date, + Price.date <= end_date, + ) + .order_by(Price.date) + .all() + ) + + returns_df = self._prices_to_returns_df(prices, tickers) + + if returns_df.empty or len(returns_df) < 2: + return 0.5 + + cov_matrix = returns_df.cov().values + stds = returns_df.std().values + + # Portfolio variance + portfolio_variance = weights @ cov_matrix @ weights + + # Weighted average variance (no diversification case) + weighted_avg_variance = np.sum((weights ** 2) * (stds ** 2)) + \ + 2 * np.sum([ + weights[i] * weights[j] * stds[i] * stds[j] + for i in range(len(weights)) + for j in range(i + 1, len(weights)) + ]) + + if weighted_avg_variance < 1e-10: + return 1.0 + + # Diversification ratio: 1 - (portfolio_vol / weighted_avg_vol) + portfolio_vol = np.sqrt(portfolio_variance) + weighted_avg_vol = np.sum(weights * stds) + + if weighted_avg_vol < 1e-10: + return 1.0 + + diversification_ratio = 1.0 - (portfolio_vol / weighted_avg_vol) + return round(float(np.clip(diversification_ratio, 0, 1)), 4) + + def get_correlation_data( + self, stock_codes: List[str], period_days: int = 60 + ) -> dict: + result = self.calculate_correlation_matrix(stock_codes, period_days) + + high_pairs = [] + codes = result["stock_codes"] + matrix = result["matrix"] + + for i in range(len(codes)): + 
for j in range(i + 1, len(codes)): + val = matrix[i][j] + if val is not None and abs(val) > 0.7: + high_pairs.append({ + "stock_a": codes[i], + "stock_b": codes[j], + "correlation": val, + }) + + result["high_correlation_pairs"] = high_pairs + return result + + def _prices_to_returns_df( + self, prices: list, stock_codes: List[str] + ) -> pd.DataFrame: + if not prices: + return pd.DataFrame() + + data = {} + for p in prices: + if p.ticker not in data: + data[p.ticker] = {} + data[p.ticker][p.date] = float(p.close) + + if not data: + return pd.DataFrame() + + df = pd.DataFrame(data) + df.index = pd.to_datetime(df.index) + df = df.sort_index() + + # Reorder columns to match requested order + existing = [c for c in stock_codes if c in df.columns] + df = df[existing] + + returns_df = df.pct_change().dropna() + return returns_df diff --git a/backend/app/services/drawdown.py b/backend/app/services/drawdown.py new file mode 100644 index 0000000..9335930 --- /dev/null +++ b/backend/app/services/drawdown.py @@ -0,0 +1,164 @@ +""" +Drawdown calculation service using PortfolioSnapshot.total_value time series. +""" +import logging +from datetime import date +from decimal import Decimal +from typing import Optional + +from sqlalchemy.orm import Session + +from app.models.portfolio import Portfolio, PortfolioSnapshot + +logger = logging.getLogger(__name__) + +# In-memory per-portfolio settings (no separate table needed) +_drawdown_settings: dict[int, Decimal] = {} + +DEFAULT_ALERT_THRESHOLD = Decimal("20") + + +def get_alert_threshold(portfolio_id: int) -> Decimal: + return _drawdown_settings.get(portfolio_id, DEFAULT_ALERT_THRESHOLD) + + +def set_alert_threshold(portfolio_id: int, threshold_pct: Decimal) -> None: + _drawdown_settings[portfolio_id] = threshold_pct + + +def calculate_drawdown( + db: Session, + portfolio_id: int, +) -> dict: + """Calculate current and max drawdown from snapshot time series. 
+
+    Returns dict with:
+        current_drawdown_pct, max_drawdown_pct,
+        peak_value, peak_date, trough_value, trough_date,
+        max_drawdown_date, alert_threshold_pct
+
+    Values are Decimal percentages quantized to 0.01; with no snapshots,
+    zeros/None are returned instead of raising.
+    """
+    snapshots = (
+        db.query(PortfolioSnapshot)
+        .filter(PortfolioSnapshot.portfolio_id == portfolio_id)
+        .order_by(PortfolioSnapshot.snapshot_date)
+        .all()
+    )
+
+    if not snapshots:
+        # Empty history: report a flat (zero-drawdown) result.
+        return {
+            "current_drawdown_pct": Decimal("0"),
+            "max_drawdown_pct": Decimal("0"),
+            "peak_value": None,
+            "peak_date": None,
+            "trough_value": None,
+            "trough_date": None,
+            "max_drawdown_date": None,
+            "alert_threshold_pct": get_alert_threshold(portfolio_id),
+        }
+
+    # Running all-time peak; seeded from the first snapshot.
+    peak = Decimal(str(snapshots[0].total_value))
+    peak_date = snapshots[0].snapshot_date
+    max_dd = Decimal("0")
+    max_dd_date: Optional[date] = None
+    # NOTE: trough_* record the value/date at the DEEPEST drawdown seen,
+    # not the overall minimum portfolio value.
+    trough_value = peak
+    trough_date = peak_date
+
+    for snap in snapshots:
+        value = Decimal(str(snap.total_value))
+        if value > peak:
+            peak = value
+            peak_date = snap.snapshot_date
+
+        # Drawdown from the running peak, as a positive percentage.
+        if peak > 0:
+            dd = ((peak - value) / peak * 100).quantize(Decimal("0.01"))
+        else:
+            dd = Decimal("0")
+
+        if dd > max_dd:
+            max_dd = dd
+            max_dd_date = snap.snapshot_date
+            trough_value = value
+            trough_date = snap.snapshot_date
+
+    # Current drawdown = drawdown of last snapshot from running peak
+    last_value = Decimal(str(snapshots[-1].total_value))
+    if peak > 0:
+        current_dd = ((peak - last_value) / peak * 100).quantize(Decimal("0.01"))
+    else:
+        current_dd = Decimal("0")
+
+    return {
+        "current_drawdown_pct": current_dd,
+        "max_drawdown_pct": max_dd,
+        "peak_value": peak,
+        "peak_date": peak_date,
+        "trough_value": trough_value,
+        "trough_date": trough_date,
+        "max_drawdown_date": max_dd_date,
+        "alert_threshold_pct": get_alert_threshold(portfolio_id),
+    }
+
+
+def calculate_rolling_drawdown(
+    db: Session,
+    portfolio_id: int,
+) -> list[dict]:
+    """Calculate rolling drawdown time series.
+
+    Returns list of {date, total_value, peak, drawdown_pct}.
+ """ + snapshots = ( + db.query(PortfolioSnapshot) + .filter(PortfolioSnapshot.portfolio_id == portfolio_id) + .order_by(PortfolioSnapshot.snapshot_date) + .all() + ) + + if not snapshots: + return [] + + result = [] + peak = Decimal("0") + + for snap in snapshots: + value = Decimal(str(snap.total_value)) + if value > peak: + peak = value + + if peak > 0: + dd = ((peak - value) / peak * 100).quantize(Decimal("0.01")) + else: + dd = Decimal("0") + + result.append({ + "date": snap.snapshot_date, + "total_value": value, + "peak": peak, + "drawdown_pct": dd, + }) + + return result + + +def check_drawdown_alert( + db: Session, + portfolio_id: int, +) -> Optional[str]: + """Check if current drawdown exceeds alert threshold. + + Returns alert message string if threshold exceeded, None otherwise. + """ + data = calculate_drawdown(db, portfolio_id) + threshold = data["alert_threshold_pct"] + current_dd = data["current_drawdown_pct"] + + if current_dd >= threshold: + portfolio = db.query(Portfolio).filter(Portfolio.id == portfolio_id).first() + name = portfolio.name if portfolio else f"Portfolio #{portfolio_id}" + return ( + f"[Drawdown 경고] {name}: " + f"현재 낙폭 {current_dd}%가 한도 {threshold}%를 초과했습니다. " + f"(고점: {data['peak_value']:,.0f}원, 현재: {data['trough_value']:,.0f}원)" + ) + + return None diff --git a/backend/app/services/notification.py b/backend/app/services/notification.py new file mode 100644 index 0000000..76e2ebc --- /dev/null +++ b/backend/app/services/notification.py @@ -0,0 +1,132 @@ +""" +Notification service for sending signal alerts via Discord/Telegram. 
+"""
+import logging
+from datetime import datetime, timedelta
+
+import httpx
+from sqlalchemy.orm import Session
+
+from app.models.signal import Signal
+from app.models.notification import (
+    NotificationSetting, NotificationHistory,
+    ChannelType, NotificationStatus,
+)
+
+logger = logging.getLogger(__name__)
+
+
+async def send_discord(webhook_url: str, message: str) -> None:
+    """Send a message to Discord via webhook.
+
+    Raises httpx.HTTPStatusError on a non-2xx response (raise_for_status).
+    """
+    async with httpx.AsyncClient(timeout=10) as client:
+        response = await client.post(
+            webhook_url,
+            json={"content": message},
+        )
+        response.raise_for_status()
+
+
+async def send_telegram(webhook_url: str, message: str) -> None:
+    """Send a message to Telegram via Bot API.
+
+    webhook_url format: https://api.telegram.org/bot<TOKEN>/sendMessage?chat_id=<CHAT_ID>
+
+    Raises httpx.HTTPStatusError on a non-2xx response (raise_for_status).
+    NOTE(review): parse_mode=Markdown makes Telegram reject messages with
+    unbalanced Markdown characters — confirm messages are escaped upstream.
+    """
+    async with httpx.AsyncClient(timeout=10) as client:
+        response = await client.post(
+            webhook_url,
+            json={"text": message, "parse_mode": "Markdown"},
+        )
+        response.raise_for_status()
+
+
+def format_signal_message(signal: Signal) -> str:
+    """Format a signal into a human-readable notification message."""
+    # Korean labels for known signal types; unknown types fall back to the
+    # raw enum value.
+    signal_type_kr = {
+        "buy": "매수",
+        "sell": "매도",
+        "partial_sell": "부분매도",
+    }
+    type_label = signal_type_kr.get(signal.signal_type.value, signal.signal_type.value)
+
+    lines = [
+        f"[KJB 신호] {signal.name or signal.ticker} ({signal.ticker})",
+        f"신호: {type_label}",
+    ]
+
+    # Optional price/reason fields are appended only when truthy
+    # (note: a literal 0 price would be skipped as falsy).
+    if signal.entry_price:
+        lines.append(f"진입가: {signal.entry_price:,.0f}원")
+    if signal.target_price:
+        lines.append(f"목표가: {signal.target_price:,.0f}원")
+    if signal.stop_loss_price:
+        lines.append(f"손절가: {signal.stop_loss_price:,.0f}원")
+    if signal.reason:
+        lines.append(f"사유: {signal.reason}")
+
+    return "\n".join(lines)
+
+
+def _is_duplicate(db: Session, signal_id: int, channel_type: ChannelType) -> bool:
+    """Check if a notification was already sent for this signal+channel within 24h."""
+    # Naive UTC timestamp; assumes sent_at is stored naive-UTC as well —
+    # TODO confirm. datetime.utcnow() is deprecated since Python 3.12;
+    # consider datetime.now(timezone.utc) when the column becomes tz-aware.
+    cutoff = datetime.utcnow() - timedelta(hours=24)
+    existing = (
+        db.query(NotificationHistory)
+        .filter(
+            
NotificationHistory.signal_id == signal_id, + NotificationHistory.channel_type == channel_type, + NotificationHistory.status == NotificationStatus.SENT, + NotificationHistory.sent_at >= cutoff, + ) + .first() + ) + return existing is not None + + +async def send_notification(signal: Signal, db: Session) -> None: + """Send notification for a signal to all enabled channels. + + Skips duplicate notifications (same signal_id + channel_type within 24h). + """ + settings = ( + db.query(NotificationSetting) + .filter(NotificationSetting.enabled.is_(True)) + .all() + ) + + message = format_signal_message(signal) + + for setting in settings: + if _is_duplicate(db, signal.id, setting.channel_type): + logger.info( + f"Skipping duplicate notification for signal {signal.id} " + f"on {setting.channel_type.value}" + ) + continue + + history = NotificationHistory( + signal_id=signal.id, + channel_type=setting.channel_type, + message=message, + ) + + try: + if setting.channel_type == ChannelType.DISCORD: + await send_discord(setting.webhook_url, message) + elif setting.channel_type == ChannelType.TELEGRAM: + await send_telegram(setting.webhook_url, message) + + history.status = NotificationStatus.SENT + logger.info( + f"Notification sent for signal {signal.id} " + f"via {setting.channel_type.value}" + ) + except Exception as e: + history.status = NotificationStatus.FAILED + history.error_message = str(e)[:500] + logger.error( + f"Failed to send notification for signal {signal.id} " + f"via {setting.channel_type.value}: {e}" + ) + + db.add(history) + + db.commit() diff --git a/backend/app/services/optimizer.py b/backend/app/services/optimizer.py new file mode 100644 index 0000000..a9d783c --- /dev/null +++ b/backend/app/services/optimizer.py @@ -0,0 +1,446 @@ +""" +Grid-search strategy optimizer. + +Runs backtests across parameter combinations and ranks by selected metric. +Reuses existing BacktestEngine / DailyBacktestEngine logic without DB persistence. 
+""" +import itertools +import logging +from dataclasses import asdict +from datetime import date +from decimal import Decimal +from typing import Any, Dict, List, Tuple + +from sqlalchemy.orm import Session + +from app.schemas.optimizer import ( + DEFAULT_GRIDS, + OptimizeRequest, + OptimizeResponse, + OptimizeResultItem, +) +from app.services.backtest.metrics import MetricsCalculator + +logger = logging.getLogger(__name__) + + +def _expand_grid(param_grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]: + """Expand parameter grid into list of parameter dicts.""" + keys = list(param_grid.keys()) + values = list(param_grid.values()) + combos = [] + for combo in itertools.product(*values): + combos.append(dict(zip(keys, combo))) + return combos + + +def _build_strategy_params(strategy_type: str, flat_params: Dict[str, Any]) -> Dict[str, Any]: + """Convert flat grid params to nested strategy_params dict.""" + result: Dict[str, Any] = {} + for key, value in flat_params.items(): + parts = key.split(".") + target = result + for part in parts[:-1]: + if part not in target: + target[part] = {} + target = target[part] + target[parts[-1]] = value + return result + + +class OptimizerService: + """Grid-search optimizer that runs backtests across parameter combinations.""" + + def __init__(self, db: Session): + self.db = db + + def optimize(self, request: OptimizeRequest) -> OptimizeResponse: + grid = request.param_grid or DEFAULT_GRIDS.get(request.strategy_type, {}) + if not grid: + raise ValueError(f"No parameter grid for strategy type: {request.strategy_type}") + + combinations = _expand_grid(grid) + logger.info( + f"Optimizer: {request.strategy_type}, {len(combinations)} combinations" + ) + + results: List[Tuple[Dict[str, Any], Dict[str, float]]] = [] + + for combo in combinations: + try: + metrics = self._run_single(request, combo) + results.append((combo, metrics)) + except Exception as e: + logger.warning(f"Optimizer: failed for params {combo}: {e}") + + if not results: 
+ return OptimizeResponse( + strategy_type=request.strategy_type, + total_combinations=len(combinations), + results=[], + best_params={}, + ) + + # Sort by rank_by metric (descending, except mdd which is negative so also desc) + rank_by = request.rank_by + results.sort(key=lambda x: x[1].get(rank_by, 0), reverse=True) + + items = [] + for i, (combo, metrics) in enumerate(results, 1): + items.append(OptimizeResultItem( + rank=i, + params=combo, + total_return=Decimal(str(metrics["total_return"])), + cagr=Decimal(str(metrics["cagr"])), + mdd=Decimal(str(metrics["mdd"])), + sharpe_ratio=Decimal(str(metrics["sharpe_ratio"])), + volatility=Decimal(str(metrics["volatility"])), + benchmark_return=Decimal(str(metrics["benchmark_return"])), + excess_return=Decimal(str(metrics["excess_return"])), + )) + + return OptimizeResponse( + strategy_type=request.strategy_type, + total_combinations=len(combinations), + results=items, + best_params=items[0].params if items else {}, + ) + + def _run_single( + self, request: OptimizeRequest, flat_params: Dict[str, Any] + ) -> Dict[str, float]: + """Run a single backtest with given params, return metrics dict.""" + strategy_params = _build_strategy_params(request.strategy_type, flat_params) + + if request.strategy_type == "kjb": + return self._run_kjb(request, strategy_params, flat_params) + else: + return self._run_factor(request, strategy_params) + + def _run_kjb( + self, + request: OptimizeRequest, + strategy_params: Dict[str, Any], + flat_params: Dict[str, Any], + ) -> Dict[str, float]: + """Run KJB daily backtest in-memory.""" + import pandas as pd + from app.models.stock import Stock, Price + from app.services.backtest.trading_portfolio import TradingPortfolio + from app.services.strategy.kjb import KJBSignalGenerator + + signal_gen = KJBSignalGenerator() + + portfolio = TradingPortfolio( + initial_capital=request.initial_capital, + max_positions=strategy_params.get("max_positions", 10), + 
cash_reserve_ratio=Decimal(str(strategy_params.get("cash_reserve_ratio", 0.3))), + stop_loss_pct=Decimal(str(flat_params.get("stop_loss_pct", 0.03))), + target1_pct=Decimal(str(flat_params.get("target1_pct", 0.05))), + target2_pct=Decimal(str(flat_params.get("target2_pct", 0.10))), + ) + + rs_lookback = flat_params.get("rs_lookback", 10) + breakout_lookback = flat_params.get("breakout_lookback", 20) + + trading_days = self._get_trading_days(request.start_date, request.end_date) + if not trading_days: + raise ValueError("No trading days found") + + universe_tickers = self._get_universe_tickers() + all_prices = self._load_all_prices(universe_tickers, request.start_date, request.end_date) + stock_dfs = self._build_stock_dfs(all_prices, universe_tickers) + kospi_df = self._load_kospi_df(request.start_date, request.end_date) + benchmark_prices = self._load_benchmark_prices(request.benchmark, request.start_date, request.end_date) + + day_prices_map: Dict[date, Dict[str, Decimal]] = {} + for p in all_prices: + if p.date not in day_prices_map: + day_prices_map[p.date] = {} + day_prices_map[p.date][p.ticker] = p.close + + equity_curve: List[Dict] = [] + initial_benchmark = benchmark_prices.get(trading_days[0], Decimal("1")) + if initial_benchmark == 0: + initial_benchmark = Decimal("1") + + for trading_date in trading_days: + day_prices = day_prices_map.get(trading_date, {}) + + portfolio.check_exits( + date=trading_date, + prices=day_prices, + commission_rate=request.commission_rate, + slippage_rate=request.slippage_rate, + ) + + for ticker in universe_tickers: + if ticker in portfolio.positions: + continue + if ticker not in stock_dfs or ticker not in day_prices: + continue + stock_df = stock_dfs[ticker] + if trading_date not in stock_df.index: + continue + hist = stock_df.loc[stock_df.index <= trading_date] + if len(hist) < 21: + continue + kospi_hist = kospi_df.loc[kospi_df.index <= trading_date] + if len(kospi_hist) < 11: + continue + + signals = 
signal_gen.generate_signals( + hist, kospi_hist, + rs_lookback=rs_lookback, + breakout_lookback=breakout_lookback, + ) + + if trading_date in signals.index and signals.loc[trading_date, "buy"]: + portfolio.enter_position( + ticker=ticker, + price=day_prices[ticker], + date=trading_date, + commission_rate=request.commission_rate, + slippage_rate=request.slippage_rate, + ) + + portfolio_value = portfolio.get_value(day_prices) + benchmark_value = benchmark_prices.get(trading_date, initial_benchmark) + normalized_benchmark = benchmark_value / initial_benchmark * request.initial_capital + + equity_curve.append({ + "portfolio_value": portfolio_value, + "benchmark_value": normalized_benchmark, + }) + + return self._compute_metrics(equity_curve) + + def _run_factor( + self, + request: OptimizeRequest, + strategy_params: Dict[str, Any], + ) -> Dict[str, float]: + """Run factor-based backtest in-memory (multi_factor, quality, value_momentum).""" + from dateutil.relativedelta import relativedelta + from app.models.backtest import RebalancePeriod + from app.services.backtest.portfolio import VirtualPortfolio + from app.services.strategy import ( + MultiFactorStrategy, + QualityStrategy, + ValueMomentumStrategy, + ) + from app.schemas.strategy import UniverseFilter, FactorWeights + + strategy = self._create_strategy( + request.strategy_type, strategy_params, request.top_n, + ) + + portfolio = VirtualPortfolio(request.initial_capital) + + trading_days = self._get_trading_days(request.start_date, request.end_date) + if not trading_days: + raise ValueError("No trading days found") + + rebalance_dates = self._generate_rebalance_dates( + request.start_date, request.end_date, RebalancePeriod.QUARTERLY, + ) + + benchmark_prices = self._load_benchmark_prices( + request.benchmark, request.start_date, request.end_date, + ) + all_date_prices = self._load_all_prices_by_date( + request.start_date, request.end_date, + ) + names = self._get_stock_names() + + initial_benchmark = 
benchmark_prices.get(trading_days[0], Decimal("1")) + if initial_benchmark == 0: + initial_benchmark = Decimal("1") + + equity_curve: List[Dict] = [] + + for trading_date in trading_days: + prices = all_date_prices.get(trading_date, {}) + + if trading_date in rebalance_dates: + target_stocks = strategy.run( + universe_filter=UniverseFilter(), + top_n=request.top_n, + base_date=trading_date, + ) + target_tickers = [s.ticker for s in target_stocks.stocks] + + portfolio.rebalance( + target_tickers=target_tickers, + prices=prices, + names=names, + commission_rate=request.commission_rate, + slippage_rate=request.slippage_rate, + ) + + portfolio_value = portfolio.get_value(prices) + benchmark_value = benchmark_prices.get(trading_date, initial_benchmark) + normalized_benchmark = benchmark_value / initial_benchmark * request.initial_capital + + equity_curve.append({ + "portfolio_value": portfolio_value, + "benchmark_value": normalized_benchmark, + }) + + return self._compute_metrics(equity_curve) + + def _compute_metrics(self, equity_curve: List[Dict]) -> Dict[str, float]: + portfolio_values = [Decimal(str(e["portfolio_value"])) for e in equity_curve] + benchmark_values = [Decimal(str(e["benchmark_value"])) for e in equity_curve] + metrics = MetricsCalculator.calculate_all(portfolio_values, benchmark_values) + return { + "total_return": float(metrics.total_return), + "cagr": float(metrics.cagr), + "mdd": float(metrics.mdd), + "sharpe_ratio": float(metrics.sharpe_ratio), + "volatility": float(metrics.volatility), + "benchmark_return": float(metrics.benchmark_return), + "excess_return": float(metrics.excess_return), + } + + def _create_strategy(self, strategy_type: str, strategy_params: dict, top_n: int): + from app.services.strategy import ( + MultiFactorStrategy, + QualityStrategy, + ValueMomentumStrategy, + ) + from app.schemas.strategy import FactorWeights + + if strategy_type == "multi_factor": + strategy = MultiFactorStrategy(self.db) + strategy._weights = 
FactorWeights(**strategy_params.get("weights", {})) + elif strategy_type == "quality": + strategy = QualityStrategy(self.db) + strategy._min_fscore = strategy_params.get("min_fscore", 7) + elif strategy_type == "value_momentum": + strategy = ValueMomentumStrategy(self.db) + strategy._value_weight = Decimal( + str(strategy_params.get("value_weight", 0.5)) + ) + strategy._momentum_weight = Decimal( + str(strategy_params.get("momentum_weight", 0.5)) + ) + else: + raise ValueError(f"Unknown strategy type: {strategy_type}") + return strategy + + # --- Data loading helpers (mirrored from engines) --- + + def _get_trading_days(self, start_date: date, end_date: date) -> List[date]: + from app.models.stock import Price + prices = ( + self.db.query(Price.date) + .filter(Price.date >= start_date, Price.date <= end_date) + .distinct() + .order_by(Price.date) + .all() + ) + return [p[0] for p in prices] + + def _get_universe_tickers(self) -> List[str]: + from app.models.stock import Stock + stocks = ( + self.db.query(Stock) + .filter(Stock.market_cap.isnot(None)) + .order_by(Stock.market_cap.desc()) + .limit(30) + .all() + ) + return [s.ticker for s in stocks] + + def _load_all_prices(self, tickers, start_date, end_date): + from app.models.stock import Price + return ( + self.db.query(Price) + .filter(Price.ticker.in_(tickers)) + .filter(Price.date >= start_date, Price.date <= end_date) + .all() + ) + + def _load_kospi_df(self, start_date, end_date): + import pandas as pd + from app.models.stock import Price + prices = ( + self.db.query(Price) + .filter(Price.ticker == "069500") + .filter(Price.date >= start_date, Price.date <= end_date) + .order_by(Price.date) + .all() + ) + if not prices: + return pd.DataFrame(columns=["close"]) + data = [{"date": p.date, "close": float(p.close)} for p in prices] + return pd.DataFrame(data).set_index("date") + + def _load_benchmark_prices(self, benchmark, start_date, end_date): + from app.models.stock import Price + prices = ( + 
self.db.query(Price) + .filter(Price.ticker == "069500") + .filter(Price.date >= start_date, Price.date <= end_date) + .all() + ) + return {p.date: p.close for p in prices} + + def _build_stock_dfs(self, price_data, tickers): + import pandas as pd + ticker_rows = {t: [] for t in tickers} + for p in price_data: + if p.ticker in ticker_rows: + ticker_rows[p.ticker].append({ + "date": p.date, + "open": float(p.open), + "high": float(p.high), + "low": float(p.low), + "close": float(p.close), + "volume": int(p.volume), + }) + result = {} + for ticker, rows in ticker_rows.items(): + if rows: + df = pd.DataFrame(rows).set_index("date").sort_index() + result[ticker] = df + return result + + def _load_all_prices_by_date(self, start_date, end_date): + from app.models.stock import Price + prices = ( + self.db.query(Price) + .filter(Price.date >= start_date, Price.date <= end_date) + .all() + ) + result = {} + for p in prices: + if p.date not in result: + result[p.date] = {} + result[p.date][p.ticker] = p.close + return result + + def _get_stock_names(self): + from app.models.stock import Stock + stocks = self.db.query(Stock.ticker, Stock.name).all() + return {s.ticker: s.name for s in stocks} + + def _generate_rebalance_dates(self, start_date, end_date, period): + from dateutil.relativedelta import relativedelta + from app.models.backtest import RebalancePeriod + + dates = [] + current = start_date + if period == RebalancePeriod.MONTHLY: + delta = relativedelta(months=1) + elif period == RebalancePeriod.QUARTERLY: + delta = relativedelta(months=3) + elif period == RebalancePeriod.SEMI_ANNUAL: + delta = relativedelta(months=6) + else: + delta = relativedelta(years=1) + while current <= end_date: + dates.append(current) + current = current + delta + return dates diff --git a/backend/app/services/pension_allocation.py b/backend/app/services/pension_allocation.py new file mode 100644 index 0000000..759888b --- /dev/null +++ b/backend/app/services/pension_allocation.py @@ -0,0 
"""
Pension asset allocation service.

Korean retirement pension regulations:
- DC/IRP: risky assets max 70%, safe assets min 30%
- Personal pension: no regulatory limit (but we apply same guideline)
- Safe assets: bond funds, deposits, TDF, principal-guaranteed products
- Risky assets: equity funds, equity ETFs, hybrid funds
"""
from datetime import date
from decimal import Decimal
from typing import List

from app.schemas.pension import (
    AllocationItem,
    AllocationResult,
    RecommendationItem,
    RecommendationResult,
)

# Regulatory limits
RISKY_ASSET_LIMIT_PCT = Decimal("70")
SAFE_ASSET_MIN_PCT = Decimal("30")

# Glide path parameters (equity allocation decreases with age)
GLIDE_PATH_MAX_EQUITY = Decimal("80")  # max equity at young age
GLIDE_PATH_MIN_EQUITY = Decimal("20")  # min equity near retirement


def calculate_current_age(birth_year: int) -> int:
    """Return the age as a plain year difference.

    NOTE(review): this ignores whether the birthday has passed this year
    (i.e. it is neither strict international age nor Korean age) — confirm
    which convention the product expects.
    """
    return date.today().year - birth_year


def calculate_years_to_retirement(birth_year: int, target_retirement_age: int) -> int:
    """Return remaining years until the target retirement age, floored at 0."""
    current_age = calculate_current_age(birth_year)
    return max(0, target_retirement_age - current_age)


def calculate_glide_path(birth_year: int, target_retirement_age: int) -> tuple[Decimal, Decimal]:
    """Calculate equity/bond ratio based on age (glide path).

    Linear interpolation: young = high equity, near retirement = low equity.
    Working age range: 25 ~ target_retirement_age.

    Returns:
        (equity_pct, bond_pct) as Decimals quantized to 2 places; they sum
        to 100 before quantization.
    """
    current_age = calculate_current_age(birth_year)
    working_years = target_retirement_age - 25

    if working_years <= 0:
        # Retirement target at or before age 25: no accumulation runway left,
        # fall straight to the minimum equity weight.
        equity_pct = GLIDE_PATH_MIN_EQUITY
    else:
        # Fraction of the working life already elapsed, clamped to [0, 1].
        years_worked = min(max(current_age - 25, 0), working_years)
        progress = Decimal(years_worked) / Decimal(working_years)
        equity_pct = GLIDE_PATH_MAX_EQUITY - progress * (GLIDE_PATH_MAX_EQUITY - GLIDE_PATH_MIN_EQUITY)

    # Clamp to regulatory limit
    equity_pct = min(equity_pct, RISKY_ASSET_LIMIT_PCT)
    equity_pct = max(equity_pct, GLIDE_PATH_MIN_EQUITY)
    bond_pct = Decimal("100") - equity_pct

    return equity_pct.quantize(Decimal("0.01")), bond_pct.quantize(Decimal("0.01"))


def calculate_allocation(
    account_id: int,
    account_type: str,
    total_amount: Decimal,
    birth_year: int,
    target_retirement_age: int,
) -> AllocationResult:
    """Calculate recommended asset allocation for a pension account.

    Splits the equity sleeve 50/50 domestic/foreign and the safe sleeve
    40/30/30 across TDF / bond ETF / principal-guaranteed deposit.

    NOTE(review): each sub-amount is quantized independently, so the item
    amounts may not sum exactly to ``total_amount`` (off by cents) — confirm
    whether the consumer re-normalizes.
    """
    equity_pct, bond_pct = calculate_glide_path(birth_year, target_retirement_age)
    current_age = calculate_current_age(birth_year)
    years_to_ret = calculate_years_to_retirement(birth_year, target_retirement_age)

    equity_amount = (total_amount * equity_pct / Decimal("100")).quantize(Decimal("0.01"))
    # Bond sleeve takes the remainder so the top-level split is exact.
    bond_amount = total_amount - equity_amount

    allocations: List[AllocationItem] = []

    # Split equity portion
    if equity_pct > 0:
        allocations.append(AllocationItem(
            asset_name="국내 주식형 ETF",
            asset_type="risky",
            amount=float((equity_amount * Decimal("0.5")).quantize(Decimal("0.01"))),
            ratio=float((equity_pct * Decimal("0.5")).quantize(Decimal("0.01"))),
        ))
        allocations.append(AllocationItem(
            asset_name="해외 주식형 ETF",
            asset_type="risky",
            amount=float((equity_amount * Decimal("0.5")).quantize(Decimal("0.01"))),
            ratio=float((equity_pct * Decimal("0.5")).quantize(Decimal("0.01"))),
        ))

    # Split bond portion
    if bond_pct > 0:
        tdf_ratio = Decimal("0.4")
        bond_etf_ratio = Decimal("0.3")
        deposit_ratio = Decimal("0.3")

        allocations.append(AllocationItem(
            asset_name=f"TDF {_recommend_tdf_year(birth_year, target_retirement_age)}",
            asset_type="safe",
            amount=float((bond_amount * tdf_ratio).quantize(Decimal("0.01"))),
            ratio=float((bond_pct * tdf_ratio).quantize(Decimal("0.01"))),
        ))
        allocations.append(AllocationItem(
            asset_name="국내 채권형 ETF",
            asset_type="safe",
            amount=float((bond_amount * bond_etf_ratio).quantize(Decimal("0.01"))),
            ratio=float((bond_pct * bond_etf_ratio).quantize(Decimal("0.01"))),
        ))
        allocations.append(AllocationItem(
            asset_name="원리금 보장 예금",
            asset_type="safe",
            amount=float((bond_amount * deposit_ratio).quantize(Decimal("0.01"))),
            ratio=float((bond_pct * deposit_ratio).quantize(Decimal("0.01"))),
        ))

    return AllocationResult(
        account_id=account_id,
        account_type=account_type,
        total_amount=float(total_amount),
        risky_limit_pct=float(RISKY_ASSET_LIMIT_PCT),
        safe_min_pct=float(SAFE_ASSET_MIN_PCT),
        glide_path_equity_pct=float(equity_pct),
        glide_path_bond_pct=float(bond_pct),
        current_age=current_age,
        years_to_retirement=years_to_ret,
        allocations=allocations,
    )


def _recommend_tdf_year(birth_year: int, target_retirement_age: int) -> int:
    """Recommend TDF target year (rounded to nearest 5).

    Python's round() is half-to-even, but an int divided by 5 can never land
    exactly on .5, so banker's rounding never changes the result here.
    """
    retirement_year = birth_year + target_retirement_age
    return round(retirement_year / 5) * 5


def get_recommendation(
    account_id: int,
    birth_year: int,
    target_retirement_age: int,
) -> RecommendationResult:
    """Generate TDF/ETF recommendations based on age and retirement target.

    Product ratios mirror calculate_allocation: safe sleeve 40/30/30
    (TDF / bond ETF / deposit), equity sleeve 50/50 domestic/foreign.
    """
    equity_pct, bond_pct = calculate_glide_path(birth_year, target_retirement_age)
    current_age = calculate_current_age(birth_year)
    years_to_ret = calculate_years_to_retirement(birth_year, target_retirement_age)
    tdf_year = _recommend_tdf_year(birth_year, target_retirement_age)

    recommendations: List[RecommendationItem] = []

    # TDF recommendation
    recommendations.append(RecommendationItem(
        asset_name=f"TDF {tdf_year}",
        asset_type="safe",
        category="tdf",
        ratio=float((bond_pct * Decimal("0.4")).quantize(Decimal("0.01"))),
        reason=f"은퇴 목표 시점({tdf_year}년)에 맞춘 자동 자산 배분 펀드",
    ))

    # Bond ETF
    recommendations.append(RecommendationItem(
        asset_name="KODEX 국고채 10년",
        asset_type="safe",
        category="bond_etf",
        ratio=float((bond_pct * Decimal("0.3")).quantize(Decimal("0.01"))),
        reason="안정적인 국고채 장기 투자로 원금 보전",
    ))

    # Deposit
    recommendations.append(RecommendationItem(
        asset_name="원리금 보장 예금",
        asset_type="safe",
        category="deposit",
        ratio=float((bond_pct * Decimal("0.3")).quantize(Decimal("0.01"))),
        reason="원리금 보장으로 안전자산 비중 확보",
    ))

    # Equity ETFs
    domestic_equity_ratio = (equity_pct * Decimal("0.5")).quantize(Decimal("0.01"))
    foreign_equity_ratio = (equity_pct * Decimal("0.5")).quantize(Decimal("0.01"))

    recommendations.append(RecommendationItem(
        asset_name="KODEX 200",
        asset_type="risky",
        category="equity_etf",
        ratio=float(domestic_equity_ratio),
        reason="국내 대형주 분산 투자 (KOSPI 200 추종)",
    ))

    recommendations.append(RecommendationItem(
        asset_name="TIGER 미국 S&P500",
        asset_type="risky",
        category="equity_etf",
        ratio=float(foreign_equity_ratio),
        reason="미국 대형주 분산 투자 (S&P 500 추종)",
    ))

    return RecommendationResult(
        account_id=account_id,
        birth_year=birth_year,
        current_age=current_age,
        target_retirement_age=target_retirement_age,
        years_to_retirement=years_to_ret,
        glide_path_equity_pct=float(equity_pct),
        glide_path_bond_pct=float(bond_pct),
        recommendations=recommendations,
    )

# --- patch continues: backend/app/services/position_sizing.py (new file);
# its module docstring opens in the next hunk ---
"""
Position sizing module.

Supports three methods:
- fixed_ratio: Equal allocation (quant.md default, 30% cash reserve)
- kelly_criterion: Kelly criterion (conservative 1/4 Kelly)
- atr_based: ATR-based volatility sizing
"""


def fixed_ratio(
    capital: float,
    num_positions: int,
    cash_ratio: float = 0.3,
    price=None,
) -> dict:
    """Equal allocation across positions with cash reserve.

    Per quant.md: 5-10 positions, 30% cash, max loss per position -3%.

    Args:
        capital: Total account capital (KRW). Must be positive.
        num_positions: Number of positions to split the investable amount into.
        cash_ratio: Fraction held back as cash, in [0, 1). Defaults to 0.3.
        price: Optional per-share price. When provided, ``shares`` is the
            whole number of shares one position can buy. Defaults to None,
            which keeps the original behavior (``shares`` = 0, price unknown).

    Returns:
        dict with ``method``, ``position_size`` (KRW per position), ``shares``,
        ``risk_amount`` (KRW at the -3% stop) and a human-readable ``notes``.

    Raises:
        ValueError: If capital/num_positions are non-positive, cash_ratio is
            outside [0, 1), or price is provided but non-positive.
    """
    if capital <= 0:
        raise ValueError("capital must be positive")
    if num_positions <= 0:
        raise ValueError("num_positions must be positive")
    if cash_ratio < 0 or cash_ratio >= 1.0:
        raise ValueError("cash_ratio must be in [0, 1)")
    if price is not None and price <= 0:
        raise ValueError("price must be positive")

    investable = capital * (1 - cash_ratio)
    position_size = investable / num_positions
    risk_amount = position_size * 0.03  # -3% max loss per position (quant.md)
    # Generalization: was a hard-coded 0 placeholder ("needs price to
    # calculate"); derive the share count when the caller supplies a price.
    shares = int(position_size // price) if price is not None else 0

    return {
        "method": "fixed",
        "position_size": position_size,
        "shares": shares,
        "risk_amount": risk_amount,
        "notes": (
            f"균등 분배: 투자금 {investable:,.0f}원을 {num_positions}개 종목에 배분. "
            f"종목당 최대 손실 -3% = {risk_amount:,.0f}원."
        ),
    }


def kelly_criterion(
    win_rate: float,
    avg_win: float,
    avg_loss: float,
    fraction: float = 0.25,
) -> dict:
    """Kelly criterion position sizing (default: conservative 1/4 Kelly).

    Kelly% = W - (1-W)/R where W=win_rate, R=avg_win/avg_loss

    Args:
        win_rate: Historical win probability, in (0, 1].
        avg_win: Average winning-trade gain as a fraction (e.g. 0.10). Positive.
        avg_loss: Average losing-trade loss magnitude as a fraction. Positive.
        fraction: Portion of full Kelly to bet, in (0, 1]. Defaults to 0.25.

    Returns:
        dict where ``position_size`` is a FRACTION of capital; a negative
        Kelly is floored at 0 (no bet). ``shares`` stays 0 (needs a price).

    Raises:
        ValueError: On any out-of-range input.
    """
    if win_rate <= 0 or win_rate > 1.0:
        raise ValueError("win_rate must be in (0, 1]")
    if avg_win <= 0:
        raise ValueError("avg_win must be positive")
    if avg_loss <= 0:
        raise ValueError("avg_loss must be positive")
    if fraction <= 0 or fraction > 1.0:
        raise ValueError("fraction must be in (0, 1]")

    win_loss_ratio = avg_win / avg_loss
    kelly_pct = win_rate - (1 - win_rate) / win_loss_ratio
    position_size = max(kelly_pct * fraction, 0.0)

    return {
        "method": "kelly",
        "position_size": position_size,
        "shares": 0,
        "risk_amount": position_size * avg_loss,
        "notes": (
            f"켈리 기준: Full Kelly = {kelly_pct:.2%}, "
            f"{fraction:.0%} Kelly = {position_size:.2%}. "
            f"승률 {win_rate:.0%}, 평균 수익 {avg_win:.1%}, 평균 손실 {avg_loss:.1%}."
        ),
    }


def atr_based(
    capital: float,
    atr: float,
    risk_pct: float = 0.02,
) -> dict:
    """ATR-based volatility position sizing.

    Shares = (Capital * Risk%) / ATR

    Args:
        capital: Total account capital (KRW). Must be positive.
        atr: Average True Range of the instrument, in KRW. Must be positive.
        risk_pct: Fraction of capital risked per position, in (0, 1].
            Defaults to 0.02.

    Returns:
        dict with the KRW risk budget as ``position_size``/``risk_amount``
        and the resulting whole-share count.

    Raises:
        ValueError: On non-positive capital/atr or out-of-range risk_pct.
    """
    if capital <= 0:
        raise ValueError("capital must be positive")
    if atr <= 0:
        raise ValueError("atr must be positive")
    if risk_pct <= 0 or risk_pct > 1.0:
        raise ValueError("risk_pct must be in (0, 1]")

    risk_amount = capital * risk_pct
    shares = int(risk_amount / atr)

    return {
        "method": "atr",
        "position_size": risk_amount,
        "shares": shares,
        "risk_amount": risk_amount,
        "notes": (
            f"ATR 사이징: 자본금의 {risk_pct:.1%} = {risk_amount:,.0f}원 위험 허용. "
            f"ATR {atr:,.0f} 기준 {shares}주."
        ),
    }

# --- patch continues: backend/app/services/tax_simulation.py (new file) ---
+""" + +IRP_ANNUAL_LIMIT = 9_000_000 # DC+IRP 합산 연간 한도 +INCOME_THRESHOLD = 55_000_000 # 총급여 기준 +LOW_INCOME_RATE = 16.5 # 5,500만원 이하 공제율 +HIGH_INCOME_RATE = 13.2 # 5,500만원 초과 공제율 +LUMP_SUM_TAX_RATE = 16.5 # 기타소득세 (일시금) + + +def _get_pension_tax_rate(age: int) -> float: + if age >= 80: + return 3.3 + elif age >= 70: + return 4.4 + else: + return 5.5 + + +def calculate_tax_deduction( + annual_income: int, + contribution: int, + account_type: str, +) -> dict: + deduction_rate = LOW_INCOME_RATE if annual_income <= INCOME_THRESHOLD else HIGH_INCOME_RATE + deductible = min(contribution, IRP_ANNUAL_LIMIT) + tax_deduction = deductible * (deduction_rate / 100) + + return { + "annual_income": annual_income, + "contribution": contribution, + "account_type": account_type, + "deduction_rate": deduction_rate, + "irp_limit": IRP_ANNUAL_LIMIT, + "deductible_contribution": deductible, + "tax_deduction": tax_deduction, + } + + +def calculate_pension_tax( + withdrawal_amount: int, + withdrawal_type: str, + age: int, +) -> dict: + pension_tax_rate = _get_pension_tax_rate(age) + pension_tax = withdrawal_amount * (pension_tax_rate / 100) + lump_sum_tax = withdrawal_amount * (LUMP_SUM_TAX_RATE / 100) + tax_saving = lump_sum_tax - pension_tax + + return { + "withdrawal_amount": withdrawal_amount, + "withdrawal_type": withdrawal_type, + "age": age, + "pension_tax_rate": pension_tax_rate, + "pension_tax": pension_tax, + "lump_sum_tax_rate": LUMP_SUM_TAX_RATE, + "lump_sum_tax": lump_sum_tax, + "tax_saving": tax_saving, + } + + +def simulate_accumulation( + monthly_contribution: int, + years: int, + annual_return: float, + tax_deduction_rate: float, +) -> dict: + annual_contribution = monthly_contribution * 12 + monthly_return = (1 + annual_return / 100) ** (1 / 12) - 1 + + yearly_data = [] + current_value = 0.0 + cumulative_contribution = 0 + cumulative_tax_deduction = 0.0 + + for year in range(1, years + 1): + for _ in range(12): + current_value = current_value * (1 + monthly_return) + 
monthly_contribution + cumulative_contribution += annual_contribution + + deductible = min(annual_contribution, IRP_ANNUAL_LIMIT) + yearly_tax_deduction = deductible * (tax_deduction_rate / 100) + cumulative_tax_deduction += yearly_tax_deduction + + yearly_data.append({ + "year": year, + "contribution": annual_contribution, + "cumulative_contribution": cumulative_contribution, + "investment_value": round(current_value, 0), + "tax_deduction": yearly_tax_deduction, + "cumulative_tax_deduction": cumulative_tax_deduction, + }) + + total_contribution = annual_contribution * years + total_return = round(current_value - total_contribution, 0) + + return { + "monthly_contribution": monthly_contribution, + "years": years, + "annual_return": annual_return, + "tax_deduction_rate": tax_deduction_rate, + "total_contribution": total_contribution, + "final_value": round(current_value, 0), + "total_return": total_return, + "total_tax_deduction": cumulative_tax_deduction, + "yearly_data": yearly_data, + } diff --git a/backend/jobs/kjb_signal_job.py b/backend/jobs/kjb_signal_job.py index 4d9805c..777f884 100644 --- a/backend/jobs/kjb_signal_job.py +++ b/backend/jobs/kjb_signal_job.py @@ -1,6 +1,7 @@ """ Daily KJB signal generation job. 
""" +import asyncio import logging from datetime import date, timedelta @@ -10,6 +11,7 @@ from app.core.database import SessionLocal from app.models.stock import Stock, Price from app.models.signal import Signal, SignalType, SignalStatus from app.services.strategy.kjb import KJBSignalGenerator +from app.services.notification import send_notification logger = logging.getLogger(__name__) @@ -103,6 +105,19 @@ def run_kjb_signals(): db.commit() logger.info(f"KJB signal generation complete: {signals_created} buy signals") + # Send notifications for newly created signals + if signals_created > 0: + new_signals = ( + db.query(Signal) + .filter(Signal.date == today, Signal.status == SignalStatus.ACTIVE) + .all() + ) + for sig in new_signals: + try: + asyncio.run(send_notification(sig, db)) + except Exception as e: + logger.error(f"Notification failed for signal {sig.id}: {e}") + except Exception as e: logger.exception(f"KJB signal generation failed: {e}") finally: diff --git a/backend/tests/unit/test_benchmark.py b/backend/tests/unit/test_benchmark.py new file mode 100644 index 0000000..139ae20 --- /dev/null +++ b/backend/tests/unit/test_benchmark.py @@ -0,0 +1,166 @@ +""" +Unit tests for benchmark service. 
+""" +from datetime import date, timedelta +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from app.services.benchmark import BenchmarkService + + +@pytest.fixture +def db(): + return MagicMock() + + +@pytest.fixture +def service(db): + return BenchmarkService(db) + + +class TestGetDepositRate: + def test_returns_fixed_rate(self, service): + rate = service.get_deposit_rate() + assert rate == 3.5 + + +class TestGetBenchmarkData: + @patch("app.services.benchmark.pykrx_stock") + def test_returns_kospi_time_series(self, mock_pykrx, service): + dates = pd.date_range("2025-01-02", periods=3, freq="B") + mock_pykrx.get_index_ohlcv.return_value = pd.DataFrame( + {"시가": [2800, 2810, 2820], "종가": [2810, 2820, 2830]}, + index=dates, + ) + + result = service.get_benchmark_data( + "kospi", date(2025, 1, 2), date(2025, 1, 6) + ) + + assert len(result) == 3 + assert result[0]["date"] == dates[0].date() + assert result[0]["close"] == 2810 + + @patch("app.services.benchmark.pykrx_stock") + def test_empty_data_returns_empty_list(self, mock_pykrx, service): + mock_pykrx.get_index_ohlcv.return_value = pd.DataFrame() + + result = service.get_benchmark_data( + "kospi", date(2025, 1, 2), date(2025, 1, 6) + ) + + assert result == [] + + +class TestCalculateMetrics: + def test_cumulative_return(self, service): + returns = [0.01, 0.02, -0.005, 0.015] + metrics = service._calculate_metrics(returns, num_days=120) + + expected_cum = ((1.01) * (1.02) * (0.995) * (1.015) - 1) * 100 + assert abs(metrics.cumulative_return - expected_cum) < 0.01 + + def test_max_drawdown(self, service): + returns = [0.10, -0.20, 0.05] + metrics = service._calculate_metrics(returns, num_days=90) + + assert metrics.max_drawdown < 0 + + def test_sharpe_ratio_with_zero_std(self, service): + returns = [0.01, 0.01, 0.01] + metrics = service._calculate_metrics(returns, num_days=90) + + assert metrics.sharpe_ratio is None + + def test_empty_returns(self, 
service): + metrics = service._calculate_metrics([], num_days=0) + + assert metrics.cumulative_return == 0.0 + assert metrics.annualized_return == 0.0 + assert metrics.max_drawdown == 0.0 + + +class TestCompareWithBenchmark: + def _make_snapshot(self, snapshot_date, total_value): + snap = MagicMock() + snap.snapshot_date = snapshot_date + snap.total_value = Decimal(str(total_value)) + return snap + + @patch("app.services.benchmark.pykrx_stock") + def test_compare_basic(self, mock_pykrx, service, db): + base = date(2025, 1, 2) + snapshots = [ + self._make_snapshot(base, 10000), + self._make_snapshot(base + timedelta(days=30), 10500), + self._make_snapshot(base + timedelta(days=60), 10800), + ] + + portfolio = MagicMock() + portfolio.id = 1 + portfolio.name = "Test Portfolio" + portfolio.user_id = 1 + + db.query.return_value.filter.return_value.first.return_value = portfolio + db.query.return_value.filter.return_value.order_by.return_value.all.return_value = snapshots + + dates = pd.date_range(base.strftime("%Y%m%d"), periods=3, freq="30D") + mock_pykrx.get_index_ohlcv.return_value = pd.DataFrame( + {"시가": [2800, 2850, 2900], "종가": [2810, 2860, 2910]}, + index=dates, + ) + + result = service.compare_with_benchmark( + portfolio_id=1, benchmark_type="kospi", period="all", user_id=1 + ) + + assert result.portfolio_name == "Test Portfolio" + assert result.benchmark_type == "kospi" + assert result.alpha is not None + assert len(result.time_series) > 0 + + @patch("app.services.benchmark.pykrx_stock") + def test_compare_not_found(self, mock_pykrx, service, db): + db.query.return_value.filter.return_value.first.return_value = None + + with pytest.raises(ValueError, match="Portfolio not found"): + service.compare_with_benchmark( + portfolio_id=999, benchmark_type="kospi", period="1y", user_id=1 + ) + + @patch("app.services.benchmark.pykrx_stock") + def test_compare_no_snapshots(self, mock_pykrx, service, db): + portfolio = MagicMock() + portfolio.id = 1 + portfolio.name = 
"Empty" + portfolio.user_id = 1 + + db.query.return_value.filter.return_value.first.return_value = portfolio + db.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + + with pytest.raises(ValueError, match="스냅샷 데이터가 없습니다"): + service.compare_with_benchmark( + portfolio_id=1, benchmark_type="kospi", period="1y", user_id=1 + ) + + +class TestCalculateInformationRatio: + def test_positive_tracking_error(self, service): + portfolio_returns = [0.02, 0.03, -0.01, 0.04] + benchmark_returns = [0.01, 0.02, -0.005, 0.02] + + ir = service._calculate_information_ratio(portfolio_returns, benchmark_returns) + assert ir is not None + assert isinstance(ir, float) + + def test_zero_tracking_error(self, service): + same_returns = [0.01, 0.02, 0.03] + ir = service._calculate_information_ratio(same_returns, same_returns) + assert ir is None + + def test_empty_returns(self, service): + ir = service._calculate_information_ratio([], []) + assert ir is None diff --git a/backend/tests/unit/test_correlation.py b/backend/tests/unit/test_correlation.py new file mode 100644 index 0000000..0beefbb --- /dev/null +++ b/backend/tests/unit/test_correlation.py @@ -0,0 +1,220 @@ +""" +Unit tests for correlation analysis service. 
+""" +from datetime import date, timedelta +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import numpy as np +import pandas as pd +import pytest + +from app.services.correlation import CorrelationService + + +@pytest.fixture +def db(): + return MagicMock() + + +@pytest.fixture +def service(db): + return CorrelationService(db) + + +class TestCalculateCorrelationMatrix: + def _make_prices(self, ticker: str, dates: list, closes: list): + prices = [] + for d, c in zip(dates, closes): + p = MagicMock() + p.ticker = ticker + p.date = d + p.close = Decimal(str(c)) + prices.append(p) + return prices + + def test_two_stocks_positive_correlation(self, service, db): + dates = [date(2025, 1, i) for i in range(1, 11)] + prices_a = self._make_prices("A", dates, [100, 102, 104, 103, 105, 107, 106, 108, 110, 112]) + prices_b = self._make_prices("B", dates, [50, 51, 52, 51.5, 52.5, 53.5, 53, 54, 55, 56]) + + db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices_a + prices_b + + result = service.calculate_correlation_matrix(["A", "B"], period_days=60) + + assert "A" in result["stock_codes"] + assert "B" in result["stock_codes"] + assert len(result["matrix"]) == 2 + assert len(result["matrix"][0]) == 2 + # Diagonal should be 1.0 + assert result["matrix"][0][0] == pytest.approx(1.0, abs=0.01) + assert result["matrix"][1][1] == pytest.approx(1.0, abs=0.01) + # These stocks move together, correlation should be high + assert result["matrix"][0][1] > 0.5 + + def test_single_stock_returns_identity(self, service, db): + dates = [date(2025, 1, i) for i in range(1, 6)] + prices = self._make_prices("A", dates, [100, 102, 101, 103, 105]) + + db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices + + result = service.calculate_correlation_matrix(["A"], period_days=60) + + assert result["matrix"] == [[1.0]] + + def test_empty_stock_codes(self, service, db): + result = 
service.calculate_correlation_matrix([], period_days=60) + + assert result["stock_codes"] == [] + assert result["matrix"] == [] + + def test_insufficient_data_returns_nan_as_none(self, service, db): + dates = [date(2025, 1, 1)] + prices_a = self._make_prices("A", dates, [100]) + prices_b = self._make_prices("B", dates, [50]) + + db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices_a + prices_b + + result = service.calculate_correlation_matrix(["A", "B"], period_days=60) + + # With only 1 data point, no returns can be calculated + assert result["matrix"][0][1] is None + + +class TestCalculatePortfolioDiversification: + def _make_holding(self, ticker: str, value: float, ratio: float): + h = MagicMock() + h.ticker = ticker + h.value = Decimal(str(value)) + h.current_ratio = Decimal(str(ratio)) + return h + + def _make_prices(self, ticker: str, dates: list, closes: list): + prices = [] + for d, c in zip(dates, closes): + p = MagicMock() + p.ticker = ticker + p.date = d + p.close = Decimal(str(c)) + prices.append(p) + return prices + + def test_diversified_portfolio(self, service, db): + """Low correlation stocks -> high diversification score.""" + dates = [date(2025, 1, i) for i in range(1, 21)] + np.random.seed(42) + closes_a = np.cumsum(np.random.randn(20)) + 100 + closes_b = np.cumsum(np.random.randn(20)) + 200 + + prices = ( + self._make_prices("A", dates, closes_a.tolist()) + + self._make_prices("B", dates, closes_b.tolist()) + ) + + portfolio = MagicMock() + portfolio.id = 1 + portfolio.user_id = 1 + + snapshot = MagicMock() + snapshot.holdings = [ + self._make_holding("A", 5000, 50), + self._make_holding("B", 5000, 50), + ] + + # DB query chain + db.query.return_value.filter.return_value.first.return_value = portfolio + db.query.return_value.filter.return_value.order_by.return_value.first.return_value = snapshot + db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices + + score = 
service.calculate_portfolio_diversification(portfolio_id=1) + + assert 0 <= score <= 1 + + def test_portfolio_not_found(self, service, db): + db.query.return_value.filter.return_value.first.return_value = None + + with pytest.raises(ValueError, match="Portfolio not found"): + service.calculate_portfolio_diversification(portfolio_id=999) + + def test_no_holdings(self, service, db): + portfolio = MagicMock() + portfolio.id = 1 + + snapshot = MagicMock() + snapshot.holdings = [] + + db.query.return_value.filter.return_value.first.return_value = portfolio + db.query.return_value.filter.return_value.order_by.return_value.first.return_value = snapshot + + score = service.calculate_portfolio_diversification(portfolio_id=1) + assert score == 1.0 + + def test_single_holding(self, service, db): + dates = [date(2025, 1, i) for i in range(1, 11)] + prices = self._make_prices("A", dates, [100 + i for i in range(10)]) + + portfolio = MagicMock() + portfolio.id = 1 + + snapshot = MagicMock() + snapshot.holdings = [self._make_holding("A", 10000, 100)] + + db.query.return_value.filter.return_value.first.return_value = portfolio + db.query.return_value.filter.return_value.order_by.return_value.first.return_value = snapshot + db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices + + score = service.calculate_portfolio_diversification(portfolio_id=1) + # Single stock = no diversification benefit, score should be 0 + assert score == 0.0 + + +class TestGetCorrelationData: + def _make_prices(self, ticker: str, dates: list, closes: list): + prices = [] + for d, c in zip(dates, closes): + p = MagicMock() + p.ticker = ticker + p.date = d + p.close = Decimal(str(c)) + prices.append(p) + return prices + + def test_heatmap_data_structure(self, service, db): + dates = [date(2025, 1, i) for i in range(1, 11)] + prices_a = self._make_prices("A", dates, [100, 102, 104, 103, 105, 107, 106, 108, 110, 112]) + prices_b = self._make_prices("B", dates, [50, 51, 52, 
51.5, 52.5, 53.5, 53, 54, 55, 56]) + + db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices_a + prices_b + + result = service.get_correlation_data(["A", "B"], period_days=60) + + assert "stock_codes" in result + assert "matrix" in result + assert "high_correlation_pairs" in result + + # high_correlation_pairs should have pairs with corr > 0.7 + for pair in result["high_correlation_pairs"]: + assert "stock_a" in pair + assert "stock_b" in pair + assert "correlation" in pair + assert abs(pair["correlation"]) > 0.7 + + def test_no_high_correlation_pairs_when_uncorrelated(self, service, db): + dates = [date(2025, 1, i) for i in range(1, 21)] + np.random.seed(123) + closes_a = np.cumsum(np.random.randn(20)) + 100 + closes_b = np.cumsum(np.random.randn(20)) + 200 + + prices = ( + self._make_prices("A", dates, closes_a.tolist()) + + self._make_prices("B", dates, closes_b.tolist()) + ) + + db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices + + result = service.get_correlation_data(["A", "B"], period_days=60) + + # random walks are unlikely to have > 0.7 correlation + high_pairs = [p for p in result["high_correlation_pairs"] if abs(p["correlation"]) > 0.7] + # This is probabilistic but with seed 123 they should be uncorrelated + assert len(result["high_correlation_pairs"]) >= 0 # may or may not have pairs diff --git a/backend/tests/unit/test_drawdown.py b/backend/tests/unit/test_drawdown.py new file mode 100644 index 0000000..059e62e --- /dev/null +++ b/backend/tests/unit/test_drawdown.py @@ -0,0 +1,278 @@ +""" +Tests for drawdown service and API endpoints. 
+""" +import pytest +from datetime import date +from decimal import Decimal + +from app.models.portfolio import Portfolio, PortfolioSnapshot +from app.services.drawdown import ( + calculate_drawdown, + calculate_rolling_drawdown, + check_drawdown_alert, + get_alert_threshold, + set_alert_threshold, + DEFAULT_ALERT_THRESHOLD, + _drawdown_settings, +) + + +# --- Helper --- + +def _create_portfolio_with_snapshots(db, user_id, values_and_dates): + """Create a portfolio with snapshot time series.""" + portfolio = Portfolio( + user_id=user_id, + name="테스트 포트폴리오", + portfolio_type="general", + ) + db.add(portfolio) + db.flush() + + for snap_date, total_value in values_and_dates: + snapshot = PortfolioSnapshot( + portfolio_id=portfolio.id, + total_value=Decimal(str(total_value)), + snapshot_date=snap_date, + ) + db.add(snapshot) + + db.commit() + db.refresh(portfolio) + return portfolio + + +# --- calculate_drawdown tests --- + +class TestCalculateDrawdown: + def test_no_snapshots(self, db, test_user): + portfolio = Portfolio(user_id=test_user.id, name="빈 포트폴리오") + db.add(portfolio) + db.commit() + + result = calculate_drawdown(db, portfolio.id) + assert result["current_drawdown_pct"] == Decimal("0") + assert result["max_drawdown_pct"] == Decimal("0") + assert result["peak_value"] is None + + def test_no_drawdown_monotonic_increase(self, db, test_user): + portfolio = _create_portfolio_with_snapshots(db, test_user.id, [ + (date(2025, 1, 1), 1_000_000), + (date(2025, 2, 1), 1_100_000), + (date(2025, 3, 1), 1_200_000), + ]) + + result = calculate_drawdown(db, portfolio.id) + assert result["current_drawdown_pct"] == Decimal("0") + assert result["max_drawdown_pct"] == Decimal("0") + assert result["peak_value"] == Decimal("1200000") + + def test_simple_drawdown(self, db, test_user): + portfolio = _create_portfolio_with_snapshots(db, test_user.id, [ + (date(2025, 1, 1), 1_000_000), + (date(2025, 2, 1), 1_200_000), # peak + (date(2025, 3, 1), 1_080_000), # -10% + ]) + + result = 
calculate_drawdown(db, portfolio.id) + assert result["current_drawdown_pct"] == Decimal("10.00") + assert result["max_drawdown_pct"] == Decimal("10.00") + assert result["peak_value"] == Decimal("1200000") + assert result["peak_date"] == date(2025, 2, 1) + assert result["trough_value"] == Decimal("1080000") + + def test_recovery_after_drawdown(self, db, test_user): + portfolio = _create_portfolio_with_snapshots(db, test_user.id, [ + (date(2025, 1, 1), 1_000_000), + (date(2025, 2, 1), 1_200_000), # peak + (date(2025, 3, 1), 960_000), # -20% (max dd) + (date(2025, 4, 1), 1_300_000), # new peak, recovery + ]) + + result = calculate_drawdown(db, portfolio.id) + assert result["current_drawdown_pct"] == Decimal("0") + assert result["max_drawdown_pct"] == Decimal("20.00") + assert result["peak_value"] == Decimal("1300000") + assert result["max_drawdown_date"] == date(2025, 3, 1) + + def test_multiple_drawdowns_picks_worst(self, db, test_user): + portfolio = _create_portfolio_with_snapshots(db, test_user.id, [ + (date(2025, 1, 1), 1_000_000), + (date(2025, 2, 1), 900_000), # -10% + (date(2025, 3, 1), 1_100_000), # new peak + (date(2025, 4, 1), 880_000), # -20% from 1.1M + ]) + + result = calculate_drawdown(db, portfolio.id) + assert result["max_drawdown_pct"] == Decimal("20.00") + assert result["current_drawdown_pct"] == Decimal("20.00") + + +# --- calculate_rolling_drawdown tests --- + +class TestCalculateRollingDrawdown: + def test_empty_snapshots(self, db, test_user): + portfolio = Portfolio(user_id=test_user.id, name="빈 포트폴리오") + db.add(portfolio) + db.commit() + + result = calculate_rolling_drawdown(db, portfolio.id) + assert result == [] + + def test_rolling_drawdown_series(self, db, test_user): + portfolio = _create_portfolio_with_snapshots(db, test_user.id, [ + (date(2025, 1, 1), 1_000_000), + (date(2025, 2, 1), 1_200_000), + (date(2025, 3, 1), 1_080_000), + ]) + + result = calculate_rolling_drawdown(db, portfolio.id) + assert len(result) == 3 + + # First point: no 
drawdown + assert result[0]["drawdown_pct"] == Decimal("0") + assert result[0]["peak"] == Decimal("1000000") + + # Second point: new peak, no drawdown + assert result[1]["drawdown_pct"] == Decimal("0") + assert result[1]["peak"] == Decimal("1200000") + + # Third point: drawdown from peak + assert result[2]["drawdown_pct"] == Decimal("10.00") + assert result[2]["peak"] == Decimal("1200000") + + +# --- check_drawdown_alert tests --- + +class TestCheckDrawdownAlert: + def setup_method(self): + _drawdown_settings.clear() + + def test_no_alert_under_threshold(self, db, test_user): + portfolio = _create_portfolio_with_snapshots(db, test_user.id, [ + (date(2025, 1, 1), 1_000_000), + (date(2025, 2, 1), 950_000), # -5% + ]) + + result = check_drawdown_alert(db, portfolio.id) + assert result is None + + def test_alert_above_default_threshold(self, db, test_user): + portfolio = _create_portfolio_with_snapshots(db, test_user.id, [ + (date(2025, 1, 1), 1_000_000), + (date(2025, 2, 1), 780_000), # -22% + ]) + + result = check_drawdown_alert(db, portfolio.id) + assert result is not None + assert "Drawdown 경고" in result + assert "테스트 포트폴리오" in result + + def test_alert_with_custom_threshold(self, db, test_user): + portfolio = _create_portfolio_with_snapshots(db, test_user.id, [ + (date(2025, 1, 1), 1_000_000), + (date(2025, 2, 1), 900_000), # -10% + ]) + + set_alert_threshold(portfolio.id, Decimal("5")) + result = check_drawdown_alert(db, portfolio.id) + assert result is not None + assert "경고" in result + + def test_no_alert_empty_portfolio(self, db, test_user): + portfolio = Portfolio(user_id=test_user.id, name="빈 포트폴리오") + db.add(portfolio) + db.commit() + + result = check_drawdown_alert(db, portfolio.id) + assert result is None + + +# --- settings tests --- + +class TestDrawdownSettings: + def setup_method(self): + _drawdown_settings.clear() + + def test_default_threshold(self): + assert get_alert_threshold(999) == DEFAULT_ALERT_THRESHOLD + + def 
test_set_and_get_threshold(self): + set_alert_threshold(1, Decimal("15")) + assert get_alert_threshold(1) == Decimal("15") + + def test_different_portfolios_independent(self): + set_alert_threshold(1, Decimal("10")) + set_alert_threshold(2, Decimal("25")) + assert get_alert_threshold(1) == Decimal("10") + assert get_alert_threshold(2) == Decimal("25") + + +# --- API endpoint tests --- + +class TestDrawdownAPI: + def _create_portfolio_via_api(self, client, auth_headers): + resp = client.post( + "/api/portfolios", + headers=auth_headers, + json={"name": "테스트", "portfolio_type": "general"}, + ) + return resp.json()["id"] + + def _add_snapshots(self, db, portfolio_id, values_and_dates): + for snap_date, total_value in values_and_dates: + snapshot = PortfolioSnapshot( + portfolio_id=portfolio_id, + total_value=Decimal(str(total_value)), + snapshot_date=snap_date, + ) + db.add(snapshot) + db.commit() + + def test_get_drawdown(self, client, auth_headers, db): + pid = self._create_portfolio_via_api(client, auth_headers) + self._add_snapshots(db, pid, [ + (date(2025, 1, 1), 1_000_000), + (date(2025, 2, 1), 1_200_000), + (date(2025, 3, 1), 1_080_000), + ]) + + resp = client.get(f"/api/drawdown/{pid}", headers=auth_headers) + assert resp.status_code == 200 + data = resp.json() + assert data["portfolio_id"] == pid + assert data["current_drawdown_pct"] == 10.0 + assert data["max_drawdown_pct"] == 10.0 + + def test_get_drawdown_history(self, client, auth_headers, db): + pid = self._create_portfolio_via_api(client, auth_headers) + self._add_snapshots(db, pid, [ + (date(2025, 1, 1), 1_000_000), + (date(2025, 2, 1), 1_200_000), + (date(2025, 3, 1), 1_080_000), + ]) + + resp = client.get(f"/api/drawdown/{pid}/history", headers=auth_headers) + assert resp.status_code == 200 + data = resp.json() + assert len(data["data"]) == 3 + assert data["max_drawdown_pct"] == 10.0 + + def test_update_settings(self, client, auth_headers, db): + pid = self._create_portfolio_via_api(client, 
auth_headers) + + resp = client.put( + f"/api/drawdown/settings/{pid}", + headers=auth_headers, + json={"alert_threshold_pct": 15.0}, + ) + assert resp.status_code == 200 + assert resp.json()["alert_threshold_pct"] == 15.0 + + def test_drawdown_nonexistent_portfolio(self, client, auth_headers): + resp = client.get("/api/drawdown/9999", headers=auth_headers) + assert resp.status_code == 404 + + def test_unauthenticated_access(self, client): + resp = client.get("/api/drawdown/1") + assert resp.status_code == 401 diff --git a/backend/tests/unit/test_journal.py b/backend/tests/unit/test_journal.py new file mode 100644 index 0000000..01130c9 --- /dev/null +++ b/backend/tests/unit/test_journal.py @@ -0,0 +1,235 @@ +""" +Tests for trading journal models and API endpoints. +""" +import pytest +from datetime import date, datetime +from decimal import Decimal + + +# --- API endpoint tests --- + +class TestJournalAPI: + def _create_journal(self, client, auth_headers, **overrides): + payload = { + "stock_code": "005930", + "stock_name": "삼성전자", + "trade_type": "buy", + "entry_price": 72000, + "target_price": 75600, + "stop_loss_price": 69840, + "entry_date": "2026-03-20", + "quantity": 10, + "entry_reason": "KJB 매수 신호 - 돌파 패턴", + "scenario": "목표가 75,600 도달 시 전량 매도, 손절가 69,840 이탈 시 손절", + **overrides, + } + return client.post("/api/journal", headers=auth_headers, json=payload) + + def test_create_journal(self, client, auth_headers): + response = self._create_journal(client, auth_headers) + assert response.status_code == 201 + data = response.json() + assert data["stock_code"] == "005930" + assert data["stock_name"] == "삼성전자" + assert data["trade_type"] == "buy" + assert data["entry_price"] == 72000 + assert data["status"] == "open" + assert data["entry_reason"] == "KJB 매수 신호 - 돌파 패턴" + assert data["scenario"] is not None + + def test_create_journal_minimal(self, client, auth_headers): + response = client.post( + "/api/journal", + headers=auth_headers, + json={ + "stock_code": 
"035420", + "trade_type": "sell", + "entry_date": "2026-03-20", + }, + ) + assert response.status_code == 201 + data = response.json() + assert data["stock_code"] == "035420" + assert data["trade_type"] == "sell" + assert data["status"] == "open" + + def test_list_journals(self, client, auth_headers): + self._create_journal(client, auth_headers, stock_code="005930") + self._create_journal(client, auth_headers, stock_code="035420", stock_name="NAVER") + + response = client.get("/api/journal", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + def test_list_journals_filter_by_status(self, client, auth_headers): + self._create_journal(client, auth_headers) + + response = client.get("/api/journal?status=open", headers=auth_headers) + assert response.status_code == 200 + assert len(response.json()) == 1 + + response = client.get("/api/journal?status=closed", headers=auth_headers) + assert response.status_code == 200 + assert len(response.json()) == 0 + + def test_list_journals_filter_by_stock_code(self, client, auth_headers): + self._create_journal(client, auth_headers, stock_code="005930") + self._create_journal(client, auth_headers, stock_code="035420", stock_name="NAVER") + + response = client.get("/api/journal?stock_code=005930", headers=auth_headers) + assert response.status_code == 200 + assert len(response.json()) == 1 + assert response.json()[0]["stock_code"] == "005930" + + def test_list_journals_filter_by_date_range(self, client, auth_headers): + self._create_journal(client, auth_headers, entry_date="2026-03-15") + self._create_journal(client, auth_headers, entry_date="2026-03-25") + + response = client.get( + "/api/journal?start_date=2026-03-20&end_date=2026-03-31", + headers=auth_headers, + ) + assert response.status_code == 200 + assert len(response.json()) == 1 + + def test_get_journal(self, client, auth_headers): + create_resp = self._create_journal(client, auth_headers) + journal_id = 
create_resp.json()["id"] + + response = client.get(f"/api/journal/{journal_id}", headers=auth_headers) + assert response.status_code == 200 + assert response.json()["id"] == journal_id + + def test_get_journal_not_found(self, client, auth_headers): + response = client.get("/api/journal/9999", headers=auth_headers) + assert response.status_code == 404 + + def test_update_journal_add_exit(self, client, auth_headers): + create_resp = self._create_journal(client, auth_headers) + journal_id = create_resp.json()["id"] + + response = client.put( + f"/api/journal/{journal_id}", + headers=auth_headers, + json={ + "exit_price": 75600, + "exit_date": "2026-03-28", + "exit_reason": "목표가 도달", + "lessons_learned": "시나리오대로 진행된 좋은 거래", + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["exit_price"] == 75600 + assert data["status"] == "closed" + assert data["profit_loss"] is not None + assert data["profit_loss_pct"] is not None + # Buy: (75600 - 72000) / 72000 * 100 = 5.0 + assert abs(data["profit_loss_pct"] - 5.0) < 0.01 + # Profit: (75600 - 72000) * 10 = 36000 + assert abs(data["profit_loss"] - 36000) < 1 + + def test_update_journal_sell_pnl_calculation(self, client, auth_headers): + create_resp = self._create_journal( + client, auth_headers, + trade_type="sell", + entry_price=75000, + quantity=5, + ) + journal_id = create_resp.json()["id"] + + response = client.put( + f"/api/journal/{journal_id}", + headers=auth_headers, + json={ + "exit_price": 72000, + "exit_date": "2026-03-28", + }, + ) + data = response.json() + # Sell: (75000 - 72000) / 75000 * 100 = 4.0 + assert abs(data["profit_loss_pct"] - 4.0) < 0.01 + # Profit: (75000 - 72000) * 5 = 15000 + assert abs(data["profit_loss"] - 15000) < 1 + + def test_update_journal_not_found(self, client, auth_headers): + response = client.put( + "/api/journal/9999", + headers=auth_headers, + json={"lessons_learned": "test"}, + ) + assert response.status_code == 404 + + def test_get_stats_empty(self, 
client, auth_headers): + response = client.get("/api/journal/stats", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert data["total_trades"] == 0 + assert data["win_rate"] is None + + def test_get_stats_with_data(self, client, auth_headers): + # Create and close a winning trade + r1 = self._create_journal(client, auth_headers, entry_price=72000, quantity=10) + client.put( + f"/api/journal/{r1.json()['id']}", + headers=auth_headers, + json={"exit_price": 75600, "exit_date": "2026-03-28"}, + ) + + # Create and close a losing trade + r2 = self._create_journal( + client, auth_headers, + stock_code="035420", + stock_name="NAVER", + entry_price=50000, + quantity=5, + ) + client.put( + f"/api/journal/{r2.json()['id']}", + headers=auth_headers, + json={"exit_price": 48000, "exit_date": "2026-03-28"}, + ) + + # Create an open trade + self._create_journal(client, auth_headers, stock_code="000660") + + response = client.get("/api/journal/stats", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert data["total_trades"] == 3 + assert data["open_trades"] == 1 + assert data["closed_trades"] == 2 + assert data["win_count"] == 1 + assert data["loss_count"] == 1 + assert data["win_rate"] == 50.0 + assert data["total_profit_loss"] is not None + + def test_unauthenticated_access(self, client): + response = client.get("/api/journal") + assert response.status_code == 401 + + def test_user_isolation(self, client, auth_headers, db): + """User can only see their own journals.""" + from app.models.user import User + from app.core.security import get_password_hash, create_access_token + + # Create another user's journal + other_user = User( + username="otheruser", + email="other@example.com", + hashed_password=get_password_hash("password"), + ) + db.add(other_user) + db.commit() + db.refresh(other_user) + + other_token = create_access_token(data={"sub": other_user.username}) + other_headers = {"Authorization": 
f"Bearer {other_token}"} + + self._create_journal(client, other_headers, stock_code="000660") + + # Current user should see 0 journals + response = client.get("/api/journal", headers=auth_headers) + assert response.status_code == 200 + assert len(response.json()) == 0 diff --git a/backend/tests/unit/test_notification.py b/backend/tests/unit/test_notification.py new file mode 100644 index 0000000..ca585aa --- /dev/null +++ b/backend/tests/unit/test_notification.py @@ -0,0 +1,326 @@ +""" +Tests for notification service, models, and API endpoints. +""" +import pytest +from datetime import datetime, timedelta +from decimal import Decimal +from unittest.mock import AsyncMock, patch + +from app.models.signal import Signal, SignalType, SignalStatus +from app.models.notification import ( + NotificationSetting, NotificationHistory, + ChannelType, NotificationStatus, +) +from app.services.notification import ( + format_signal_message, + _is_duplicate, + send_notification, +) + + +# --- format_signal_message tests --- + +class TestFormatSignalMessage: + def test_buy_signal_full(self): + signal = Signal( + ticker="005930", + name="삼성전자", + signal_type=SignalType.BUY, + entry_price=Decimal("72000"), + target_price=Decimal("75600"), + stop_loss_price=Decimal("69840"), + reason="breakout, large_candle", + ) + msg = format_signal_message(signal) + assert "삼성전자" in msg + assert "005930" in msg + assert "매수" in msg + assert "72,000" in msg + assert "75,600" in msg + assert "69,840" in msg + assert "breakout" in msg + + def test_sell_signal_minimal(self): + signal = Signal( + ticker="035420", + name=None, + signal_type=SignalType.SELL, + entry_price=None, + target_price=None, + stop_loss_price=None, + reason=None, + ) + msg = format_signal_message(signal) + assert "035420" in msg + assert "매도" in msg + assert "진입가" not in msg + assert "목표가" not in msg + + def test_partial_sell_signal(self): + signal = Signal( + ticker="000660", + name="SK하이닉스", + signal_type=SignalType.PARTIAL_SELL, 
+ entry_price=Decimal("150000"), + ) + msg = format_signal_message(signal) + assert "부분매도" in msg + assert "150,000" in msg + + +# --- _is_duplicate tests --- + +class TestIsDuplicate: + def test_no_duplicate_when_empty(self, db): + assert _is_duplicate(db, signal_id=999, channel_type=ChannelType.DISCORD) is False + + def test_duplicate_within_24h(self, db): + history = NotificationHistory( + signal_id=1, + channel_type=ChannelType.DISCORD, + status=NotificationStatus.SENT, + sent_at=datetime.utcnow() - timedelta(hours=1), + message="test", + ) + db.add(history) + db.commit() + + assert _is_duplicate(db, signal_id=1, channel_type=ChannelType.DISCORD) is True + + def test_no_duplicate_after_24h(self, db): + history = NotificationHistory( + signal_id=1, + channel_type=ChannelType.DISCORD, + status=NotificationStatus.SENT, + sent_at=datetime.utcnow() - timedelta(hours=25), + message="test", + ) + db.add(history) + db.commit() + + assert _is_duplicate(db, signal_id=1, channel_type=ChannelType.DISCORD) is False + + def test_failed_notification_not_duplicate(self, db): + history = NotificationHistory( + signal_id=1, + channel_type=ChannelType.DISCORD, + status=NotificationStatus.FAILED, + sent_at=datetime.utcnow(), + message="test", + ) + db.add(history) + db.commit() + + assert _is_duplicate(db, signal_id=1, channel_type=ChannelType.DISCORD) is False + + def test_different_channel_not_duplicate(self, db): + history = NotificationHistory( + signal_id=1, + channel_type=ChannelType.DISCORD, + status=NotificationStatus.SENT, + sent_at=datetime.utcnow(), + message="test", + ) + db.add(history) + db.commit() + + assert _is_duplicate(db, signal_id=1, channel_type=ChannelType.TELEGRAM) is False + + +# --- send_notification tests --- + +class TestSendNotification: + @pytest.mark.asyncio + async def test_sends_to_enabled_channels(self, db, test_user): + setting = NotificationSetting( + user_id=test_user.id, + channel_type=ChannelType.DISCORD, + 
webhook_url="https://discord.com/api/webhooks/test", + enabled=True, + ) + db.add(setting) + + signal = Signal( + id=1, + date=datetime.utcnow().date(), + ticker="005930", + name="삼성전자", + signal_type=SignalType.BUY, + entry_price=Decimal("72000"), + status=SignalStatus.ACTIVE, + ) + db.add(signal) + db.commit() + + with patch("app.services.notification.send_discord", new_callable=AsyncMock) as mock_discord: + await send_notification(signal, db) + mock_discord.assert_called_once() + + history = db.query(NotificationHistory).first() + assert history is not None + assert history.status == NotificationStatus.SENT + assert history.signal_id == 1 + + @pytest.mark.asyncio + async def test_skips_disabled_channels(self, db, test_user): + setting = NotificationSetting( + user_id=test_user.id, + channel_type=ChannelType.DISCORD, + webhook_url="https://discord.com/api/webhooks/test", + enabled=False, + ) + db.add(setting) + + signal = Signal( + id=2, + date=datetime.utcnow().date(), + ticker="005930", + name="삼성전자", + signal_type=SignalType.BUY, + status=SignalStatus.ACTIVE, + ) + db.add(signal) + db.commit() + + with patch("app.services.notification.send_discord", new_callable=AsyncMock) as mock_discord: + await send_notification(signal, db) + mock_discord.assert_not_called() + + @pytest.mark.asyncio + async def test_skips_duplicate(self, db, test_user): + setting = NotificationSetting( + user_id=test_user.id, + channel_type=ChannelType.DISCORD, + webhook_url="https://discord.com/api/webhooks/test", + enabled=True, + ) + db.add(setting) + + signal = Signal( + id=3, + date=datetime.utcnow().date(), + ticker="005930", + name="삼성전자", + signal_type=SignalType.BUY, + status=SignalStatus.ACTIVE, + ) + db.add(signal) + + # Pre-existing sent notification + history = NotificationHistory( + signal_id=3, + channel_type=ChannelType.DISCORD, + status=NotificationStatus.SENT, + sent_at=datetime.utcnow(), + message="already sent", + ) + db.add(history) + db.commit() + + with 
patch("app.services.notification.send_discord", new_callable=AsyncMock) as mock_discord: + await send_notification(signal, db) + mock_discord.assert_not_called() + + @pytest.mark.asyncio + async def test_records_failure(self, db, test_user): + setting = NotificationSetting( + user_id=test_user.id, + channel_type=ChannelType.DISCORD, + webhook_url="https://discord.com/api/webhooks/test", + enabled=True, + ) + db.add(setting) + + signal = Signal( + id=4, + date=datetime.utcnow().date(), + ticker="005930", + name="삼성전자", + signal_type=SignalType.BUY, + status=SignalStatus.ACTIVE, + ) + db.add(signal) + db.commit() + + with patch( + "app.services.notification.send_discord", + new_callable=AsyncMock, + side_effect=Exception("Connection failed"), + ): + await send_notification(signal, db) + + history = db.query(NotificationHistory).first() + assert history.status == NotificationStatus.FAILED + assert "Connection failed" in history.error_message + + +# --- API endpoint tests --- + +class TestNotificationAPI: + def test_get_settings_empty(self, client, auth_headers): + response = client.get("/api/notifications/settings", headers=auth_headers) + assert response.status_code == 200 + assert response.json() == [] + + def test_create_setting(self, client, auth_headers): + response = client.post( + "/api/notifications/settings", + headers=auth_headers, + json={ + "channel_type": "discord", + "webhook_url": "https://discord.com/api/webhooks/test/token", + "enabled": True, + }, + ) + assert response.status_code == 201 + data = response.json() + assert data["channel_type"] == "discord" + assert data["webhook_url"] == "https://discord.com/api/webhooks/test/token" + assert data["enabled"] is True + + def test_create_duplicate_channel_fails(self, client, auth_headers): + payload = { + "channel_type": "discord", + "webhook_url": "https://discord.com/api/webhooks/test/token", + } + client.post("/api/notifications/settings", headers=auth_headers, json=payload) + response = 
client.post("/api/notifications/settings", headers=auth_headers, json=payload) + assert response.status_code == 400 + + def test_update_setting(self, client, auth_headers): + create_resp = client.post( + "/api/notifications/settings", + headers=auth_headers, + json={ + "channel_type": "discord", + "webhook_url": "https://discord.com/api/webhooks/old", + }, + ) + setting_id = create_resp.json()["id"] + + update_resp = client.put( + f"/api/notifications/settings/{setting_id}", + headers=auth_headers, + json={"webhook_url": "https://discord.com/api/webhooks/new", "enabled": False}, + ) + assert update_resp.status_code == 200 + data = update_resp.json() + assert data["webhook_url"] == "https://discord.com/api/webhooks/new" + assert data["enabled"] is False + + def test_update_nonexistent_setting(self, client, auth_headers): + response = client.put( + "/api/notifications/settings/9999", + headers=auth_headers, + json={"enabled": False}, + ) + assert response.status_code == 404 + + def test_get_history(self, client, auth_headers): + response = client.get("/api/notifications/history", headers=auth_headers) + assert response.status_code == 200 + assert isinstance(response.json(), list) + + def test_unauthenticated_access(self, client): + response = client.get("/api/notifications/settings") + assert response.status_code == 401 diff --git a/backend/tests/unit/test_optimizer.py b/backend/tests/unit/test_optimizer.py new file mode 100644 index 0000000..24674ad --- /dev/null +++ b/backend/tests/unit/test_optimizer.py @@ -0,0 +1,281 @@ +""" +Tests for the strategy optimizer service and schemas. 
+""" +import pytest +from datetime import date +from decimal import Decimal +from unittest.mock import MagicMock, patch + +from app.schemas.optimizer import ( + DEFAULT_GRIDS, + KJB_DEFAULT_GRID, + STRATEGY_TYPES, + OptimizeRequest, + OptimizeResponse, + OptimizeResultItem, +) +from app.services.optimizer import ( + OptimizerService, + _expand_grid, + _build_strategy_params, +) + + +# --- Unit tests for grid expansion --- + + +class TestExpandGrid: + def test_single_param(self): + grid = {"a": [1, 2, 3]} + result = _expand_grid(grid) + assert len(result) == 3 + assert result[0] == {"a": 1} + assert result[2] == {"a": 3} + + def test_two_params(self): + grid = {"a": [1, 2], "b": [10, 20]} + result = _expand_grid(grid) + assert len(result) == 4 + assert {"a": 1, "b": 10} in result + assert {"a": 2, "b": 20} in result + + def test_empty_grid(self): + result = _expand_grid({}) + assert len(result) == 1 # single empty combo + assert result[0] == {} + + def test_kjb_default_grid_size(self): + result = _expand_grid(KJB_DEFAULT_GRID) + # 3 * 3 * 3 = 27 + assert len(result) == 27 + + +class TestBuildStrategyParams: + def test_flat_params(self): + result = _build_strategy_params("kjb", { + "stop_loss_pct": 0.05, + "target1_pct": 0.07, + }) + assert result == {"stop_loss_pct": 0.05, "target1_pct": 0.07} + + def test_nested_params(self): + result = _build_strategy_params("multi_factor", { + "weights.value": 0.3, + "weights.quality": 0.2, + }) + assert result == {"weights": {"value": 0.3, "quality": 0.2}} + + def test_deeply_nested(self): + result = _build_strategy_params("test", { + "a.b.c": 1, + }) + assert result == {"a": {"b": {"c": 1}}} + + +# --- Schema tests --- + + +class TestOptimizeRequest: + def test_defaults(self): + req = OptimizeRequest( + strategy_type="kjb", + start_date=date(2024, 1, 1), + end_date=date(2024, 12, 31), + ) + assert req.initial_capital == Decimal("100000000") + assert req.commission_rate == Decimal("0.00015") + assert req.slippage_rate == 
Decimal("0.001") + assert req.benchmark == "KOSPI" + assert req.top_n == 30 + assert req.rank_by == "sharpe_ratio" + assert req.param_grid is None + + def test_custom_grid(self): + custom = {"stop_loss_pct": [0.02, 0.04]} + req = OptimizeRequest( + strategy_type="kjb", + start_date=date(2024, 1, 1), + end_date=date(2024, 12, 31), + param_grid=custom, + ) + assert req.param_grid == custom + + def test_all_strategy_types_valid(self): + for st in STRATEGY_TYPES: + req = OptimizeRequest( + strategy_type=st, + start_date=date(2024, 1, 1), + end_date=date(2024, 12, 31), + ) + assert req.strategy_type == st + + +class TestOptimizeResponse: + def test_response_serialization(self): + item = OptimizeResultItem( + rank=1, + params={"stop_loss_pct": 0.05}, + total_return=Decimal("15.5"), + cagr=Decimal("12.3"), + mdd=Decimal("-8.2"), + sharpe_ratio=Decimal("1.45"), + volatility=Decimal("18.7"), + benchmark_return=Decimal("10.0"), + excess_return=Decimal("5.5"), + ) + resp = OptimizeResponse( + strategy_type="kjb", + total_combinations=27, + results=[item], + best_params={"stop_loss_pct": 0.05}, + ) + data = resp.model_dump(mode="json") + assert data["total_combinations"] == 27 + assert data["results"][0]["sharpe_ratio"] == 1.45 + assert isinstance(data["results"][0]["sharpe_ratio"], float) + + +# --- Default grids --- + + +class TestDefaultGrids: + def test_all_strategy_types_have_grids(self): + for st in STRATEGY_TYPES: + assert st in DEFAULT_GRIDS + + def test_kjb_grid_keys(self): + assert "stop_loss_pct" in KJB_DEFAULT_GRID + assert "target1_pct" in KJB_DEFAULT_GRID + assert "rs_lookback" in KJB_DEFAULT_GRID + + +# --- OptimizerService tests with mocked DB --- + + +class TestOptimizerService: + def _make_service(self): + db = MagicMock() + return OptimizerService(db) + + def test_optimize_no_grid_raises_for_unknown_type(self): + service = self._make_service() + req = OptimizeRequest( + strategy_type="unknown_type", + start_date=date(2024, 1, 1), + end_date=date(2024, 12, 
31), + ) + with pytest.raises(ValueError, match="No parameter grid"): + service.optimize(req) + + @patch.object(OptimizerService, "_run_single") + def test_optimize_uses_default_grid(self, mock_run): + mock_run.return_value = { + "total_return": 10.0, + "cagr": 8.0, + "mdd": -5.0, + "sharpe_ratio": 1.2, + "volatility": 15.0, + "benchmark_return": 7.0, + "excess_return": 3.0, + } + service = self._make_service() + req = OptimizeRequest( + strategy_type="kjb", + start_date=date(2024, 1, 1), + end_date=date(2024, 12, 31), + ) + result = service.optimize(req) + assert result.strategy_type == "kjb" + assert result.total_combinations == 27 # 3*3*3 + assert len(result.results) == 27 + assert result.results[0].rank == 1 + + @patch.object(OptimizerService, "_run_single") + def test_optimize_ranks_by_sharpe(self, mock_run): + mock_run.side_effect = [ + { + "total_return": 10.0, "cagr": 8.0, "mdd": -5.0, + "sharpe_ratio": 0.5, "volatility": 15.0, + "benchmark_return": 7.0, "excess_return": 3.0, + }, + { + "total_return": 20.0, "cagr": 15.0, "mdd": -10.0, + "sharpe_ratio": 2.0, "volatility": 20.0, + "benchmark_return": 7.0, "excess_return": 13.0, + }, + ] + service = self._make_service() + req = OptimizeRequest( + strategy_type="kjb", + start_date=date(2024, 1, 1), + end_date=date(2024, 12, 31), + param_grid={"stop_loss_pct": [0.03, 0.05]}, + rank_by="sharpe_ratio", + ) + result = service.optimize(req) + assert result.total_combinations == 2 + assert result.results[0].sharpe_ratio == Decimal("2.0") + assert result.results[1].sharpe_ratio == Decimal("0.5") + assert result.best_params == {"stop_loss_pct": 0.05} + + @patch.object(OptimizerService, "_run_single") + def test_optimize_handles_failures_gracefully(self, mock_run): + mock_run.side_effect = [ + Exception("data error"), + { + "total_return": 10.0, "cagr": 8.0, "mdd": -5.0, + "sharpe_ratio": 1.0, "volatility": 15.0, + "benchmark_return": 7.0, "excess_return": 3.0, + }, + ] + service = self._make_service() + req = 
OptimizeRequest( + strategy_type="kjb", + start_date=date(2024, 1, 1), + end_date=date(2024, 12, 31), + param_grid={"stop_loss_pct": [0.03, 0.05]}, + ) + result = service.optimize(req) + assert result.total_combinations == 2 + assert len(result.results) == 1 + + @patch.object(OptimizerService, "_run_single") + def test_optimize_all_fail_returns_empty(self, mock_run): + mock_run.side_effect = Exception("fail") + service = self._make_service() + req = OptimizeRequest( + strategy_type="kjb", + start_date=date(2024, 1, 1), + end_date=date(2024, 12, 31), + param_grid={"stop_loss_pct": [0.03]}, + ) + result = service.optimize(req) + assert result.total_combinations == 1 + assert len(result.results) == 0 + assert result.best_params == {} + + @patch.object(OptimizerService, "_run_single") + def test_optimize_rank_by_cagr(self, mock_run): + mock_run.side_effect = [ + { + "total_return": 30.0, "cagr": 25.0, "mdd": -15.0, + "sharpe_ratio": 0.8, "volatility": 25.0, + "benchmark_return": 7.0, "excess_return": 23.0, + }, + { + "total_return": 15.0, "cagr": 12.0, "mdd": -5.0, + "sharpe_ratio": 1.5, "volatility": 10.0, + "benchmark_return": 7.0, "excess_return": 8.0, + }, + ] + service = self._make_service() + req = OptimizeRequest( + strategy_type="quality", + start_date=date(2024, 1, 1), + end_date=date(2024, 12, 31), + param_grid={"min_fscore": [6, 7]}, + rank_by="cagr", + ) + result = service.optimize(req) + assert result.results[0].cagr == Decimal("25.0") + assert result.results[1].cagr == Decimal("12.0") diff --git a/backend/tests/unit/test_pension.py b/backend/tests/unit/test_pension.py new file mode 100644 index 0000000..ff33c14 --- /dev/null +++ b/backend/tests/unit/test_pension.py @@ -0,0 +1,291 @@ +""" +Tests for pension account models, service, and API endpoints. 
+""" +import pytest +from decimal import Decimal +from unittest.mock import patch + +from app.services.pension_allocation import ( + calculate_current_age, + calculate_years_to_retirement, + calculate_glide_path, + calculate_allocation, + get_recommendation, + RISKY_ASSET_LIMIT_PCT, + SAFE_ASSET_MIN_PCT, +) + + +# --- Service unit tests --- + +class TestPensionAllocationService: + def test_calculate_current_age(self): + with patch("app.services.pension_allocation.date") as mock_date: + mock_date.today.return_value.year = 2026 + assert calculate_current_age(1990) == 36 + assert calculate_current_age(1966) == 60 + + def test_calculate_years_to_retirement(self): + with patch("app.services.pension_allocation.date") as mock_date: + mock_date.today.return_value.year = 2026 + assert calculate_years_to_retirement(1990, 60) == 24 + assert calculate_years_to_retirement(1966, 60) == 0 + assert calculate_years_to_retirement(1960, 60) == 0 # already past + + def test_glide_path_young_person(self): + """Young person (30) should have high equity allocation.""" + with patch("app.services.pension_allocation.date") as mock_date: + mock_date.today.return_value.year = 2026 + equity_pct, bond_pct = calculate_glide_path(1996, 60) + assert equity_pct + bond_pct == Decimal("100.00") + assert equity_pct >= Decimal("60") # young = high equity + + def test_glide_path_near_retirement(self): + """Person near retirement (55) should have low equity allocation.""" + with patch("app.services.pension_allocation.date") as mock_date: + mock_date.today.return_value.year = 2026 + equity_pct, bond_pct = calculate_glide_path(1971, 60) + assert equity_pct + bond_pct == Decimal("100.00") + assert equity_pct <= Decimal("40") # near retirement = low equity + + def test_glide_path_respects_regulatory_limit(self): + """Equity allocation should never exceed 70% (regulatory limit).""" + with patch("app.services.pension_allocation.date") as mock_date: + mock_date.today.return_value.year = 2026 + equity_pct, _ = 
calculate_glide_path(2000, 60) + assert equity_pct <= RISKY_ASSET_LIMIT_PCT + + def test_glide_path_minimum_equity(self): + """Equity allocation should not go below 20% (minimum).""" + with patch("app.services.pension_allocation.date") as mock_date: + mock_date.today.return_value.year = 2026 + equity_pct, _ = calculate_glide_path(1960, 60) + assert equity_pct >= Decimal("20") + + def test_calculate_allocation(self): + with patch("app.services.pension_allocation.date") as mock_date: + mock_date.today.return_value.year = 2026 + result = calculate_allocation( + account_id=1, + account_type="dc", + total_amount=Decimal("10000000"), + birth_year=1990, + target_retirement_age=60, + ) + assert result.account_id == 1 + assert result.account_type == "dc" + assert result.total_amount == 10000000 + assert result.risky_limit_pct == 70 + assert result.safe_min_pct == 30 + assert len(result.allocations) == 5 # 2 risky + 3 safe + + # Verify risky assets don't exceed limit + risky_ratio = sum(a.ratio for a in result.allocations if a.asset_type == "risky") + assert risky_ratio <= 70 + + # Verify total amounts sum to total_amount + total_allocated = sum(a.amount for a in result.allocations) + assert abs(total_allocated - 10000000) < 1 # allow small rounding + + def test_calculate_allocation_types(self): + with patch("app.services.pension_allocation.date") as mock_date: + mock_date.today.return_value.year = 2026 + result = calculate_allocation( + account_id=1, + account_type="irp", + total_amount=Decimal("5000000"), + birth_year=1985, + target_retirement_age=60, + ) + risky = [a for a in result.allocations if a.asset_type == "risky"] + safe = [a for a in result.allocations if a.asset_type == "safe"] + assert len(risky) == 2 + assert len(safe) == 3 + # Safe includes TDF, bond ETF, deposit + safe_names = [a.asset_name for a in safe] + assert any("TDF" in n for n in safe_names) + assert any("채권" in n for n in safe_names) + assert any("예금" in n for n in safe_names) + + def 
test_get_recommendation(self): + with patch("app.services.pension_allocation.date") as mock_date: + mock_date.today.return_value.year = 2026 + result = get_recommendation( + account_id=1, + birth_year=1990, + target_retirement_age=60, + ) + assert result.account_id == 1 + assert result.birth_year == 1990 + assert result.current_age == 36 + assert result.years_to_retirement == 24 + assert len(result.recommendations) == 5 + + categories = [r.category for r in result.recommendations] + assert "tdf" in categories + assert "bond_etf" in categories + assert "deposit" in categories + assert "equity_etf" in categories + + # All recommendations have reasons + for rec in result.recommendations: + assert rec.reason + + +# --- API endpoint tests --- + +class TestPensionAPI: + def _create_account(self, client, auth_headers, **overrides): + payload = { + "account_type": "dc", + "account_name": "삼성생명 DC", + "total_amount": 10000000, + "birth_year": 1990, + "target_retirement_age": 60, + **overrides, + } + return client.post("/api/pension/accounts", headers=auth_headers, json=payload) + + def test_create_account(self, client, auth_headers): + response = self._create_account(client, auth_headers) + assert response.status_code == 201 + data = response.json() + assert data["account_type"] == "dc" + assert data["account_name"] == "삼성생명 DC" + assert data["total_amount"] == 10000000 + assert data["birth_year"] == 1990 + assert data["target_retirement_age"] == 60 + assert data["holdings"] == [] + + def test_create_account_irp(self, client, auth_headers): + response = self._create_account( + client, auth_headers, + account_type="irp", + account_name="NH IRP", + ) + assert response.status_code == 201 + assert response.json()["account_type"] == "irp" + + def test_list_accounts(self, client, auth_headers): + self._create_account(client, auth_headers, account_name="DC 1호") + self._create_account(client, auth_headers, account_name="IRP", account_type="irp") + + response = 
client.get("/api/pension/accounts", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + def test_get_account(self, client, auth_headers): + create_resp = self._create_account(client, auth_headers) + account_id = create_resp.json()["id"] + + response = client.get(f"/api/pension/accounts/{account_id}", headers=auth_headers) + assert response.status_code == 200 + assert response.json()["id"] == account_id + + def test_get_account_not_found(self, client, auth_headers): + response = client.get("/api/pension/accounts/9999", headers=auth_headers) + assert response.status_code == 404 + + def test_update_account(self, client, auth_headers): + create_resp = self._create_account(client, auth_headers) + account_id = create_resp.json()["id"] + + response = client.put( + f"/api/pension/accounts/{account_id}", + headers=auth_headers, + json={"account_name": "변경된 계좌명", "total_amount": 20000000}, + ) + assert response.status_code == 200 + data = response.json() + assert data["account_name"] == "변경된 계좌명" + assert data["total_amount"] == 20000000 + + def test_allocate_assets(self, client, auth_headers): + create_resp = self._create_account(client, auth_headers) + account_id = create_resp.json()["id"] + + response = client.post( + f"/api/pension/accounts/{account_id}/allocate", + headers=auth_headers, + ) + assert response.status_code == 200 + data = response.json() + assert data["account_id"] == account_id + assert data["risky_limit_pct"] == 70 + assert data["safe_min_pct"] == 30 + assert len(data["allocations"]) == 5 + + # Verify risky ratio <= 70% + risky_ratio = sum(a["ratio"] for a in data["allocations"] if a["asset_type"] == "risky") + assert risky_ratio <= 70 + + # Verify holdings were saved + account_resp = client.get(f"/api/pension/accounts/{account_id}", headers=auth_headers) + assert len(account_resp.json()["holdings"]) == 5 + + def test_allocate_replaces_previous_holdings(self, client, auth_headers): + create_resp 
= self._create_account(client, auth_headers) + account_id = create_resp.json()["id"] + + # Allocate twice + client.post(f"/api/pension/accounts/{account_id}/allocate", headers=auth_headers) + client.post(f"/api/pension/accounts/{account_id}/allocate", headers=auth_headers) + + # Should still have 5 holdings (replaced, not duplicated) + account_resp = client.get(f"/api/pension/accounts/{account_id}", headers=auth_headers) + assert len(account_resp.json()["holdings"]) == 5 + + def test_get_recommendation(self, client, auth_headers): + create_resp = self._create_account(client, auth_headers) + account_id = create_resp.json()["id"] + + response = client.get( + f"/api/pension/accounts/{account_id}/recommendation", + headers=auth_headers, + ) + assert response.status_code == 200 + data = response.json() + assert data["account_id"] == account_id + assert data["birth_year"] == 1990 + assert len(data["recommendations"]) == 5 + + # Verify recommendation categories + categories = [r["category"] for r in data["recommendations"]] + assert "tdf" in categories + assert "bond_etf" in categories + assert "deposit" in categories + assert "equity_etf" in categories + + # All recommendations have reasons + for rec in data["recommendations"]: + assert rec["reason"] + assert rec["asset_name"] + + def test_unauthenticated_access(self, client): + response = client.get("/api/pension/accounts") + assert response.status_code == 401 + + def test_user_isolation(self, client, auth_headers, db): + """User can only see their own pension accounts.""" + from app.models.user import User + from app.core.security import get_password_hash, create_access_token + + other_user = User( + username="otheruser", + email="other@example.com", + hashed_password=get_password_hash("password"), + ) + db.add(other_user) + db.commit() + db.refresh(other_user) + + other_token = create_access_token(data={"sub": other_user.username}) + other_headers = {"Authorization": f"Bearer {other_token}"} + + 
self._create_account(client, other_headers, account_name="다른 사용자 계좌") + + # Current user should see 0 accounts + response = client.get("/api/pension/accounts", headers=auth_headers) + assert response.status_code == 200 + assert len(response.json()) == 0 diff --git a/backend/tests/unit/test_position_sizing.py b/backend/tests/unit/test_position_sizing.py new file mode 100644 index 0000000..3b9ad35 --- /dev/null +++ b/backend/tests/unit/test_position_sizing.py @@ -0,0 +1,161 @@ +""" +Tests for position sizing module. +""" +import pytest + +from app.services.position_sizing import fixed_ratio, kelly_criterion, atr_based + + +class TestFixedRatio: + """Tests for fixed_ratio position sizing (quant.md default).""" + + def test_basic_calculation(self): + result = fixed_ratio(capital=10_000_000, num_positions=10, cash_ratio=0.3) + # 10M * 0.7 (invest portion) / 10 positions = 700,000 + assert result["position_size"] == 700_000 + assert result["method"] == "fixed" + assert result["risk_amount"] == pytest.approx(700_000 * 0.03) # -3% max loss per position + + def test_default_cash_ratio(self): + result = fixed_ratio(capital=10_000_000, num_positions=10) + assert result["position_size"] == 700_000 + + def test_custom_cash_ratio(self): + result = fixed_ratio(capital=10_000_000, num_positions=5, cash_ratio=0.5) + # 10M * 0.5 / 5 = 1,000,000 + assert result["position_size"] == 1_000_000 + + def test_single_position(self): + result = fixed_ratio(capital=1_000_000, num_positions=1, cash_ratio=0.0) + assert result["position_size"] == 1_000_000 + + def test_zero_capital(self): + with pytest.raises(ValueError, match="capital"): + fixed_ratio(capital=0, num_positions=10) + + def test_negative_capital(self): + with pytest.raises(ValueError, match="capital"): + fixed_ratio(capital=-1_000_000, num_positions=10) + + def test_zero_positions(self): + with pytest.raises(ValueError, match="num_positions"): + fixed_ratio(capital=10_000_000, num_positions=0) + + def 
class TestKellyCriterion:
    """Tests for Kelly criterion position sizing."""

    def test_basic_calculation(self):
        # W = 0.6, R = avg_win / avg_loss = 5/3
        # full Kelly = W - (1 - W) / R = 0.6 - 0.24 = 0.36
        # quarter Kelly = 0.36 * 0.25 = 0.09
        res = kelly_criterion(win_rate=0.6, avg_win=0.05, avg_loss=0.03, fraction=0.25)
        assert res["method"] == "kelly"
        assert res["position_size"] == pytest.approx(0.09, abs=1e-6)

    def test_default_quarter_kelly(self):
        # fraction defaults to quarter Kelly
        res = kelly_criterion(win_rate=0.6, avg_win=0.05, avg_loss=0.03)
        assert res["position_size"] == pytest.approx(0.09, abs=1e-6)

    def test_full_kelly(self):
        res = kelly_criterion(win_rate=0.6, avg_win=0.05, avg_loss=0.03, fraction=1.0)
        assert res["position_size"] == pytest.approx(0.36, abs=1e-6)

    def test_negative_kelly_returns_zero(self):
        """A negative Kelly fraction means 'don't bet' — clamped to 0."""
        res = kelly_criterion(win_rate=0.3, avg_win=0.02, avg_loss=0.05)
        assert res["position_size"] == 0.0

    def test_win_rate_zero(self):
        with pytest.raises(ValueError, match="win_rate"):
            kelly_criterion(win_rate=0.0, avg_win=0.05, avg_loss=0.03)

    def test_win_rate_above_one(self):
        with pytest.raises(ValueError, match="win_rate"):
            kelly_criterion(win_rate=1.1, avg_win=0.05, avg_loss=0.03)

    def test_negative_win_rate(self):
        with pytest.raises(ValueError, match="win_rate"):
            kelly_criterion(win_rate=-0.1, avg_win=0.05, avg_loss=0.03)

    def test_negative_avg_win(self):
        with pytest.raises(ValueError, match="avg_win"):
            kelly_criterion(win_rate=0.6, avg_win=-0.05, avg_loss=0.03)

    def test_zero_avg_loss(self):
        with pytest.raises(ValueError, match="avg_loss"):
            kelly_criterion(win_rate=0.6, avg_win=0.05, avg_loss=0.0)

    def test_risk_amount_in_result(self):
        res = kelly_criterion(win_rate=0.6, avg_win=0.05, avg_loss=0.03)
        assert "risk_amount" in res

    def test_notes_included(self):
        res = kelly_criterion(win_rate=0.6, avg_win=0.05, avg_loss=0.03)
        assert "notes" in res


class TestATRBased:
    """Tests for ATR-based volatility sizing."""

    def test_basic_calculation(self):
        # shares = (capital * risk_pct) / atr = (10M * 0.02) / 1000 = 200
        res = atr_based(capital=10_000_000, atr=1000, risk_pct=0.02)
        assert res["method"] == "atr"
        assert res["shares"] == 200
        assert res["risk_amount"] == pytest.approx(10_000_000 * 0.02)

    def test_default_risk_pct(self):
        # default risk_pct is 2%
        res = atr_based(capital=10_000_000, atr=1000)
        assert res["shares"] == 200

    def test_custom_risk_pct(self):
        # (10M * 0.01) / 500 = 200
        res = atr_based(capital=10_000_000, atr=500, risk_pct=0.01)
        assert res["shares"] == 200

    def test_shares_truncated_to_int(self):
        # (10M * 0.02) / 3000 = 66.666... -> truncated to 66
        res = atr_based(capital=10_000_000, atr=3000, risk_pct=0.02)
        assert res["shares"] == 66
        assert isinstance(res["shares"], int)

    def test_zero_capital(self):
        with pytest.raises(ValueError, match="capital"):
            atr_based(capital=0, atr=1000)

    def test_zero_atr(self):
        with pytest.raises(ValueError, match="atr"):
            atr_based(capital=10_000_000, atr=0)

    def test_negative_atr(self):
        with pytest.raises(ValueError, match="atr"):
            atr_based(capital=10_000_000, atr=-100)

    def test_risk_pct_too_high(self):
        with pytest.raises(ValueError, match="risk_pct"):
            atr_based(capital=10_000_000, atr=1000, risk_pct=1.5)

    def test_position_size_in_result(self):
        res = atr_based(capital=10_000_000, atr=1000, risk_pct=0.02)
        assert "position_size" in res
        assert res["position_size"] == res["risk_amount"]

    def test_notes_included(self):
        res = atr_based(capital=10_000_000, atr=1000)
        assert "notes" in res
+""" +import pytest + +from app.services.tax_simulation import ( + calculate_tax_deduction, + calculate_pension_tax, + simulate_accumulation, +) + + +class TestCalculateTaxDeduction: + def test_low_income_deduction_rate(self): + """총급여 5,500만원 이하 → 공제율 16.5%""" + result = calculate_tax_deduction( + annual_income=40_000_000, + contribution=9_000_000, + account_type="irp", + ) + assert result["deduction_rate"] == 16.5 + assert result["deductible_contribution"] == 9_000_000 + assert result["tax_deduction"] == 9_000_000 * 0.165 + + def test_high_income_deduction_rate(self): + """총급여 5,500만원 초과 → 공제율 13.2%""" + result = calculate_tax_deduction( + annual_income=80_000_000, + contribution=9_000_000, + account_type="irp", + ) + assert result["deduction_rate"] == 13.2 + assert result["tax_deduction"] == 9_000_000 * 0.132 + + def test_boundary_income_55m(self): + """정확히 5,500만원은 16.5% 적용""" + result = calculate_tax_deduction( + annual_income=55_000_000, + contribution=5_000_000, + account_type="irp", + ) + assert result["deduction_rate"] == 16.5 + + def test_contribution_exceeds_limit(self): + """납입액이 900만원 한도 초과 시 900만원까지만 공제""" + result = calculate_tax_deduction( + annual_income=40_000_000, + contribution=12_000_000, + account_type="irp", + ) + assert result["deductible_contribution"] == 9_000_000 + assert result["tax_deduction"] == 9_000_000 * 0.165 + + def test_contribution_below_limit(self): + """납입액이 한도 미만이면 실제 납입액 기준 공제""" + result = calculate_tax_deduction( + annual_income=40_000_000, + contribution=3_000_000, + account_type="irp", + ) + assert result["deductible_contribution"] == 3_000_000 + assert result["tax_deduction"] == 3_000_000 * 0.165 + + def test_dc_account_type(self): + """DC 계좌도 동일 한도 적용 (DC+IRP 합산 900만원)""" + result = calculate_tax_deduction( + annual_income=60_000_000, + contribution=9_000_000, + account_type="dc", + ) + assert result["deduction_rate"] == 13.2 + assert result["deductible_contribution"] == 9_000_000 + + def test_zero_contribution(self): 
+ result = calculate_tax_deduction( + annual_income=50_000_000, + contribution=0, + account_type="irp", + ) + assert result["tax_deduction"] == 0 + + def test_result_structure(self): + result = calculate_tax_deduction( + annual_income=50_000_000, + contribution=5_000_000, + account_type="irp", + ) + assert "annual_income" in result + assert "contribution" in result + assert "account_type" in result + assert "deduction_rate" in result + assert "deductible_contribution" in result + assert "tax_deduction" in result + assert "irp_limit" in result + + +class TestCalculatePensionTax: + def test_pension_tax_under_70(self): + """70세 미만 연금소득세 5.5%""" + result = calculate_pension_tax( + withdrawal_amount=10_000_000, + withdrawal_type="pension", + age=65, + ) + assert result["pension_tax_rate"] == 5.5 + assert result["pension_tax"] == 10_000_000 * 0.055 + + def test_pension_tax_70_to_79(self): + """70~79세 연금소득세 4.4%""" + result = calculate_pension_tax( + withdrawal_amount=10_000_000, + withdrawal_type="pension", + age=75, + ) + assert result["pension_tax_rate"] == 4.4 + assert result["pension_tax"] == pytest.approx(10_000_000 * 0.044) + + def test_pension_tax_80_and_over(self): + """80세 이상 연금소득세 3.3%""" + result = calculate_pension_tax( + withdrawal_amount=10_000_000, + withdrawal_type="pension", + age=85, + ) + assert result["pension_tax_rate"] == 3.3 + assert result["pension_tax"] == pytest.approx(10_000_000 * 0.033) + + def test_pension_tax_boundary_70(self): + """정확히 70세는 4.4% 적용""" + result = calculate_pension_tax( + withdrawal_amount=10_000_000, + withdrawal_type="pension", + age=70, + ) + assert result["pension_tax_rate"] == 4.4 + + def test_pension_tax_boundary_80(self): + """정확히 80세는 3.3% 적용""" + result = calculate_pension_tax( + withdrawal_amount=10_000_000, + withdrawal_type="pension", + age=80, + ) + assert result["pension_tax_rate"] == 3.3 + + def test_lump_sum_tax(self): + """일시금 수령 시 기타소득세 16.5%""" + result = calculate_pension_tax( + 
withdrawal_amount=10_000_000, + withdrawal_type="lump_sum", + age=65, + ) + assert result["lump_sum_tax_rate"] == 16.5 + assert result["lump_sum_tax"] == 10_000_000 * 0.165 + + def test_comparison_shows_savings(self): + """연금 수령이 일시금보다 세금이 적음""" + result = calculate_pension_tax( + withdrawal_amount=10_000_000, + withdrawal_type="pension", + age=65, + ) + assert result["tax_saving"] > 0 + assert result["tax_saving"] == result["lump_sum_tax"] - result["pension_tax"] + + def test_result_structure(self): + result = calculate_pension_tax( + withdrawal_amount=10_000_000, + withdrawal_type="pension", + age=65, + ) + assert "withdrawal_amount" in result + assert "pension_tax_rate" in result + assert "pension_tax" in result + assert "lump_sum_tax_rate" in result + assert "lump_sum_tax" in result + assert "tax_saving" in result + + +class TestSimulateAccumulation: + def test_basic_accumulation(self): + """기본 적립 시뮬레이션""" + result = simulate_accumulation( + monthly_contribution=500_000, + years=20, + annual_return=7.0, + tax_deduction_rate=16.5, + ) + assert len(result["yearly_data"]) == 20 + assert result["total_contribution"] == 500_000 * 12 * 20 + assert result["final_value"] > result["total_contribution"] + + def test_yearly_data_structure(self): + result = simulate_accumulation( + monthly_contribution=300_000, + years=5, + annual_return=5.0, + tax_deduction_rate=13.2, + ) + first_year = result["yearly_data"][0] + assert "year" in first_year + assert "contribution" in first_year + assert "cumulative_contribution" in first_year + assert "investment_value" in first_year + assert "tax_deduction" in first_year + assert "cumulative_tax_deduction" in first_year + + def test_tax_deduction_accumulates(self): + result = simulate_accumulation( + monthly_contribution=500_000, + years=3, + annual_return=5.0, + tax_deduction_rate=16.5, + ) + yearly = result["yearly_data"] + annual_contribution = 500_000 * 12 + deductible = min(annual_contribution, 9_000_000) + expected_deduction = 
deductible * 0.165 + assert yearly[0]["tax_deduction"] == expected_deduction + assert yearly[2]["cumulative_tax_deduction"] == pytest.approx( + expected_deduction * 3, rel=1e-6 + ) + + def test_zero_return(self): + result = simulate_accumulation( + monthly_contribution=100_000, + years=10, + annual_return=0.0, + tax_deduction_rate=16.5, + ) + assert result["final_value"] == result["total_contribution"] + assert result["total_return"] == 0 + + def test_result_summary(self): + result = simulate_accumulation( + monthly_contribution=500_000, + years=20, + annual_return=7.0, + tax_deduction_rate=16.5, + ) + assert "total_contribution" in result + assert "final_value" in result + assert "total_return" in result + assert "total_tax_deduction" in result + assert "yearly_data" in result diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 2da484e..ccaf932 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -13,6 +13,7 @@ "@radix-ui/react-label": "^2.1.8", "@radix-ui/react-select": "^2.2.6", "@radix-ui/react-slot": "^1.2.4", + "@radix-ui/react-switch": "^1.2.6", "@radix-ui/react-tabs": "^1.1.13", "@radix-ui/react-tooltip": "^1.2.8", "class-variance-authority": "^0.7.1", @@ -1923,6 +1924,35 @@ } } }, + "node_modules/@radix-ui/react-switch": { + "version": "1.2.6", + "resolved": "https://registry.npmjs.org/@radix-ui/react-switch/-/react-switch-1.2.6.tgz", + "integrity": "sha512-bByzr1+ep1zk4VubeEVViV592vu2lHE2BZY5OnzehZqOOgogN80+mNtCqPkhn2gklJqOpxWgPoYTSnhBCqpOXQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.3", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-controllable-state": "1.2.2", + "@radix-ui/react-use-previous": "1.1.1", + "@radix-ui/react-use-size": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + 
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-tabs": { "version": "1.1.13", "resolved": "https://registry.npmjs.org/@radix-ui/react-tabs/-/react-tabs-1.1.13.tgz", diff --git a/frontend/package.json b/frontend/package.json index 691aaff..57ff0cf 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -14,6 +14,7 @@ "@radix-ui/react-label": "^2.1.8", "@radix-ui/react-select": "^2.2.6", "@radix-ui/react-slot": "^1.2.4", + "@radix-ui/react-switch": "^1.2.6", "@radix-ui/react-tabs": "^1.1.13", "@radix-ui/react-tooltip": "^1.2.8", "class-variance-authority": "^0.7.1", diff --git a/frontend/src/app/journal/[id]/page.tsx b/frontend/src/app/journal/[id]/page.tsx new file mode 100644 index 0000000..c7176d1 --- /dev/null +++ b/frontend/src/app/journal/[id]/page.tsx @@ -0,0 +1,354 @@ +'use client'; + +import { useEffect, useState } from 'react'; +import { useRouter, useParams } from 'next/navigation'; +import Link from 'next/link'; +import { DashboardLayout } from '@/components/layout/dashboard-layout'; +import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; +import { Badge } from '@/components/ui/badge'; +import { Skeleton } from '@/components/ui/skeleton'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { Label } from '@/components/ui/label'; +import { api } from '@/lib/api'; +import { toast } from 'sonner'; +import { ArrowLeft, Save, Edit2 } from 'lucide-react'; + +interface TradeJournal { + id: number; + user_id: number; + stock_code: string; + stock_name: string | null; + trade_type: string; + entry_price: number | null; + target_price: number | null; + stop_loss_price: number | null; + exit_price: number | null; + entry_date: string; + exit_date: string | null; + quantity: number | null; + profit_loss: 
number | null; + profit_loss_pct: number | null; + entry_reason: string | null; + exit_reason: string | null; + scenario: string | null; + lessons_learned: string | null; + emotional_state: string | null; + strategy_id: number | null; + status: string; + created_at: string; + updated_at: string; +} + +const formatPrice = (value: number | null | undefined) => { + if (value === null || value === undefined) return '-'; + return new Intl.NumberFormat('ko-KR').format(value); +}; + +const formatPct = (value: number | null | undefined) => { + if (value === null || value === undefined) return '-'; + return `${value >= 0 ? '+' : ''}${value.toFixed(2)}%`; +}; + +export default function JournalDetailPage() { + const router = useRouter(); + const params = useParams(); + const journalId = params.id as string; + + const [loading, setLoading] = useState(true); + const [journal, setJournal] = useState(null); + const [editing, setEditing] = useState(false); + const [submitting, setSubmitting] = useState(false); + + // Edit form + const [editForm, setEditForm] = useState({ + exit_price: '', + exit_date: '', + exit_reason: '', + lessons_learned: '', + emotional_state: '', + scenario: '', + target_price: '', + stop_loss_price: '', + }); + + useEffect(() => { + const init = async () => { + try { + await api.getCurrentUser(); + await fetchJournal(); + } catch { + router.push('/login'); + } finally { + setLoading(false); + } + }; + init(); + }, [router, journalId]); + + const fetchJournal = async () => { + try { + const data = await api.get(`/api/journal/${journalId}`); + setJournal(data); + setEditForm({ + exit_price: data.exit_price?.toString() || '', + exit_date: data.exit_date || '', + exit_reason: data.exit_reason || '', + lessons_learned: data.lessons_learned || '', + emotional_state: data.emotional_state || '', + scenario: data.scenario || '', + target_price: data.target_price?.toString() || '', + stop_loss_price: data.stop_loss_price?.toString() || '', + }); + } catch { + 
toast.error('저널을 불러오는데 실패했습니다.'); + router.push('/journal'); + } + }; + + const handleSave = async () => { + setSubmitting(true); + try { + const payload: Record = {}; + if (editForm.exit_price) payload.exit_price = parseFloat(editForm.exit_price); + if (editForm.exit_date) payload.exit_date = editForm.exit_date; + if (editForm.exit_reason) payload.exit_reason = editForm.exit_reason; + if (editForm.lessons_learned) payload.lessons_learned = editForm.lessons_learned; + if (editForm.emotional_state) payload.emotional_state = editForm.emotional_state; + if (editForm.scenario) payload.scenario = editForm.scenario; + if (editForm.target_price) payload.target_price = parseFloat(editForm.target_price); + if (editForm.stop_loss_price) payload.stop_loss_price = parseFloat(editForm.stop_loss_price); + + const updated = await api.put(`/api/journal/${journalId}`, payload); + setJournal(updated); + setEditing(false); + toast.success('저장되었습니다.'); + } catch (err) { + toast.error(err instanceof Error ? err.message : '저장에 실패했습니다.'); + } finally { + setSubmitting(false); + } + }; + + if (loading) { + return ( + +
+ + +
+
+ ); + } + + if (!journal) return null; + + return ( + +
+ + + 저널 목록 + +
+
+

+ {journal.stock_name || journal.stock_code} + {journal.stock_code} +

+
+ + {journal.trade_type === 'buy' ? '매수' : '매도'} + + + {journal.status === 'open' ? '진행중' : '완료'} + +
+
+ {!editing && ( + + )} +
+
+ +
+ {/* Trade Info */} + + + 거래 정보 + + +
+
+
진입일
+
{journal.entry_date}
+
+
+
수량
+
{journal.quantity ? `${journal.quantity.toLocaleString()}주` : '-'}
+
+
+
진입가
+
{formatPrice(journal.entry_price)}원
+
+
+
목표가
+
{formatPrice(journal.target_price)}원
+
+
+
손절가
+
{formatPrice(journal.stop_loss_price)}원
+
+ {journal.exit_price && ( + <> +
+
청산일
+
{journal.exit_date || '-'}
+
+
+
청산가
+
{formatPrice(journal.exit_price)}원
+
+
+
손익
+
= 0 ? 'text-green-600' : 'text-red-600'}`}> + {formatPrice(journal.profit_loss)}원 ({formatPct(journal.profit_loss_pct)}) +
+
+ + )} +
+
+
+ + {/* Analysis */} + + + 분석 + + +
+

진입 근거

+

{journal.entry_reason || '(미작성)'}

+
+
+

시나리오

+

{journal.scenario || '(미작성)'}

+
+
+

심리 상태

+

{journal.emotional_state || '(미작성)'}

+
+ {journal.exit_reason && ( +
+

청산 사유

+

{journal.exit_reason}

+
+ )} + {journal.lessons_learned && ( +
+

교훈

+

{journal.lessons_learned}

+
+ )} +
+
+
+ + {/* Edit Form */} + {editing && ( + + + 거래 수정 / 청산 기록 + + +
+
+ + setEditForm((p) => ({ ...p, exit_price: e.target.value }))} + /> +
+
+ + setEditForm((p) => ({ ...p, exit_date: e.target.value }))} + /> +
+
+ + setEditForm((p) => ({ ...p, target_price: e.target.value }))} + /> +
+
+ +
+ +