feat: add 9 new modules - notification alerts, trading journal, position sizing, pension allocation, drawdown monitoring, benchmark dashboard, tax simulation, correlation analysis, parameter optimizer

Phase 1:
- Real-time signal alerts (Discord/Telegram webhook)
- Trading journal with entry/exit tracking
- Position sizing calculator (Fixed/Kelly/ATR)

Phase 2:
- Pension asset allocation (DC/IRP 70% risk limit)
- Drawdown monitoring with SVG gauge
- Benchmark dashboard (portfolio vs KOSPI vs deposit)

Phase 3:
- Tax benefit simulation (Korean pension tax rules)
- Correlation matrix heatmap
- Parameter optimizer with grid search + overfit detection
This commit is contained in:
머니페니 2026-03-29 10:03:08 +09:00
parent fd03744bc9
commit 12d235a1f1
65 changed files with 9887 additions and 1 deletions

View File

@ -15,5 +15,10 @@ KIS_ACCOUNT_NO=your_account_number
# DART OpenAPI (Financial Statements, optional)
DART_API_KEY=your_dart_api_key
# Notifications (optional)
DISCORD_WEBHOOK_URL=
TELEGRAM_BOT_TOKEN=
TELEGRAM_CHAT_ID=
# Production only
API_URL=https://your-domain.com

View File

@ -0,0 +1,59 @@
"""add notification settings and history tables
Revision ID: d4e5f6a7b8c9
Revises: c3d4e5f6a7b8
Create Date: 2026-03-29 10:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'd4e5f6a7b8c9'
down_revision: Union[str, None] = 'c3d4e5f6a7b8'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
'notification_settings',
sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True),
sa.Column('user_id', sa.Integer(), sa.ForeignKey('users.id'), nullable=False),
sa.Column('channel_type', sa.Enum('discord', 'telegram', name='channeltype'), nullable=False),
sa.Column('webhook_url', sa.String(500), nullable=False),
sa.Column('enabled', sa.Boolean(), server_default='true'),
sa.Column('created_at', sa.DateTime(), server_default=sa.func.now()),
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.now()),
)
op.create_index('ix_notification_settings_id', 'notification_settings', ['id'])
op.create_index('ix_notification_settings_user_id', 'notification_settings', ['user_id'])
op.create_table(
'notification_history',
sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True),
sa.Column('signal_id', sa.Integer(), sa.ForeignKey('signals.id'), nullable=False),
sa.Column('channel_type', sa.Enum('discord', 'telegram', name='channeltype', create_type=False), nullable=False),
sa.Column('sent_at', sa.DateTime(), server_default=sa.func.now()),
sa.Column('status', sa.Enum('sent', 'failed', name='notificationstatus'), nullable=False),
sa.Column('message', sa.Text(), nullable=True),
sa.Column('error_message', sa.Text(), nullable=True),
)
op.create_index('ix_notification_history_id', 'notification_history', ['id'])
op.create_index('ix_notification_history_signal_id', 'notification_history', ['signal_id'])
def downgrade() -> None:
op.drop_index('ix_notification_history_signal_id', table_name='notification_history')
op.drop_index('ix_notification_history_id', table_name='notification_history')
op.drop_table('notification_history')
op.drop_index('ix_notification_settings_user_id', table_name='notification_settings')
op.drop_index('ix_notification_settings_id', table_name='notification_settings')
op.drop_table('notification_settings')
sa.Enum(name='notificationstatus').drop(op.get_bind(), checkfirst=True)
sa.Enum(name='channeltype').drop(op.get_bind(), checkfirst=True)

View File

@ -0,0 +1,71 @@
"""add trade journal table
Revision ID: e5f6a7b8c9d0
Revises: d4e5f6a7b8c9
Create Date: 2026-03-29 12:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "e5f6a7b8c9d0"
down_revision: Union[str, None] = "d4e5f6a7b8c9"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"trade_journals",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.Integer(), sa.ForeignKey("users.id"), nullable=False),
sa.Column("stock_code", sa.String(20), nullable=False),
sa.Column("stock_name", sa.String(100), nullable=True),
sa.Column(
"trade_type",
sa.Enum("buy", "sell", name="tradetype"),
nullable=False,
),
sa.Column("entry_price", sa.Numeric(12, 2), nullable=True),
sa.Column("target_price", sa.Numeric(12, 2), nullable=True),
sa.Column("stop_loss_price", sa.Numeric(12, 2), nullable=True),
sa.Column("exit_price", sa.Numeric(12, 2), nullable=True),
sa.Column("entry_date", sa.Date(), nullable=False),
sa.Column("exit_date", sa.Date(), nullable=True),
sa.Column("quantity", sa.Integer(), nullable=True),
sa.Column("profit_loss", sa.Numeric(14, 2), nullable=True),
sa.Column("profit_loss_pct", sa.Numeric(8, 4), nullable=True),
sa.Column("entry_reason", sa.Text(), nullable=True),
sa.Column("exit_reason", sa.Text(), nullable=True),
sa.Column("scenario", sa.Text(), nullable=True),
sa.Column("lessons_learned", sa.Text(), nullable=True),
sa.Column("emotional_state", sa.Text(), nullable=True),
sa.Column("strategy_id", sa.Integer(), nullable=True),
sa.Column(
"status",
sa.Enum("open", "closed", name="journalstatus"),
server_default="open",
),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.now()),
)
op.create_index("ix_trade_journals_id", "trade_journals", ["id"])
op.create_index("ix_trade_journals_user_id", "trade_journals", ["user_id"])
op.create_index("ix_trade_journals_stock_code", "trade_journals", ["stock_code"])
op.create_index("ix_trade_journals_entry_date", "trade_journals", ["entry_date"])
op.create_index("ix_trade_journals_status", "trade_journals", ["status"])
def downgrade() -> None:
op.drop_index("ix_trade_journals_status")
op.drop_index("ix_trade_journals_entry_date")
op.drop_index("ix_trade_journals_stock_code")
op.drop_index("ix_trade_journals_user_id")
op.drop_index("ix_trade_journals_id")
op.drop_table("trade_journals")
sa.Enum(name="tradetype").drop(op.get_bind(), checkfirst=True)
sa.Enum(name="journalstatus").drop(op.get_bind(), checkfirst=True)

View File

@ -0,0 +1,74 @@
"""add pension account tables
Revision ID: f6a7b8c9d0e1
Revises: e5f6a7b8c9d0
Create Date: 2026-03-29 14:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "f6a7b8c9d0e1"
down_revision: Union[str, None] = "e5f6a7b8c9d0"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"pension_accounts",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.Integer(), sa.ForeignKey("users.id"), nullable=False),
sa.Column(
"account_type",
sa.Enum("dc", "irp", "personal", name="accounttype"),
nullable=False,
),
sa.Column("account_name", sa.String(100), nullable=False),
sa.Column("total_amount", sa.Numeric(16, 2), nullable=False, server_default="0"),
sa.Column("birth_year", sa.Integer(), nullable=False),
sa.Column("target_retirement_age", sa.Integer(), server_default="60"),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.now()),
)
op.create_index("ix_pension_accounts_id", "pension_accounts", ["id"])
op.create_index("ix_pension_accounts_user_id", "pension_accounts", ["user_id"])
op.create_index("ix_pension_accounts_account_type", "pension_accounts", ["account_type"])
op.create_table(
"pension_holdings",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column(
"account_id",
sa.Integer(),
sa.ForeignKey("pension_accounts.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("asset_name", sa.String(200), nullable=False),
sa.Column(
"asset_type",
sa.Enum("safe", "risky", name="assetrisktype"),
nullable=False,
),
sa.Column("amount", sa.Numeric(16, 2), nullable=False, server_default="0"),
sa.Column("ratio", sa.Numeric(6, 2), nullable=False, server_default="0"),
)
op.create_index("ix_pension_holdings_id", "pension_holdings", ["id"])
op.create_index("ix_pension_holdings_account_id", "pension_holdings", ["account_id"])
def downgrade() -> None:
op.drop_index("ix_pension_holdings_account_id")
op.drop_index("ix_pension_holdings_id")
op.drop_table("pension_holdings")
sa.Enum(name="assetrisktype").drop(op.get_bind(), checkfirst=True)
op.drop_index("ix_pension_accounts_account_type")
op.drop_index("ix_pension_accounts_user_id")
op.drop_index("ix_pension_accounts_id")
op.drop_table("pension_accounts")
sa.Enum(name="accounttype").drop(op.get_bind(), checkfirst=True)

View File

@ -7,6 +7,15 @@ from app.api.backtest import router as backtest_router
from app.api.snapshot import router as snapshot_router
from app.api.data_explorer import router as data_explorer_router
from app.api.signal import router as signal_router
from app.api.notification import router as notification_router
from app.api.journal import router as journal_router
from app.api.position_sizing import router as position_sizing_router
from app.api.pension import router as pension_router
from app.api.drawdown import router as drawdown_router
from app.api.benchmark import router as benchmark_router
from app.api.tax_simulation import router as tax_simulation_router
from app.api.correlation import router as correlation_router
from app.api.optimizer import router as optimizer_router
__all__ = [
"auth_router",
@ -18,4 +27,13 @@ __all__ = [
"snapshot_router",
"data_explorer_router",
"signal_router",
"notification_router",
"journal_router",
"position_sizing_router",
"pension_router",
"drawdown_router",
"benchmark_router",
"tax_simulation_router",
"correlation_router",
"optimizer_router",
]

View File

@ -0,0 +1,59 @@
"""
Benchmark comparison API endpoints.
"""
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import CurrentUser
from app.schemas.benchmark import (
BenchmarkCompareResponse,
BenchmarkIndexInfo,
BenchmarkType,
PeriodType,
)
from app.services.benchmark import BenchmarkService
router = APIRouter(prefix="/api/benchmark", tags=["benchmark"])
@router.get("/indices", response_model=List[BenchmarkIndexInfo])
async def list_indices():
"""사용 가능한 벤치마크 목록 반환."""
return [
BenchmarkIndexInfo(
code="kospi",
name="KOSPI",
description="코스피 종합지수",
),
BenchmarkIndexInfo(
code="deposit",
name="정기예금",
description="정기예금 금리 (연 3.5%)",
),
]
@router.get("/compare/{portfolio_id}", response_model=BenchmarkCompareResponse)
async def compare_with_benchmark(
portfolio_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
benchmark: BenchmarkType = Query(BenchmarkType.KOSPI),
period: PeriodType = Query(PeriodType.ONE_YEAR),
):
"""포트폴리오 성과를 벤치마크와 비교."""
service = BenchmarkService(db)
try:
result = service.compare_with_benchmark(
portfolio_id=portfolio_id,
benchmark_type=benchmark.value,
period=period.value,
user_id=current_user.id,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
return result

View File

@ -0,0 +1,86 @@
"""
Correlation analysis API endpoints.
"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import CurrentUser
from app.schemas.correlation import (
CorrelationMatrixRequest,
CorrelationMatrixResponse,
DiversificationResponse,
HighCorrelationPair,
)
from app.services.correlation import CorrelationService
router = APIRouter(prefix="/api/correlation", tags=["correlation"])
@router.post("/matrix", response_model=CorrelationMatrixResponse)
async def calculate_correlation_matrix(
request: CorrelationMatrixRequest,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""종목 간 수익률 상관 행렬 계산."""
service = CorrelationService(db)
result = service.get_correlation_data(request.stock_codes, request.period_days)
return CorrelationMatrixResponse(
stock_codes=result["stock_codes"],
matrix=result["matrix"],
high_correlation_pairs=[
HighCorrelationPair(**p) for p in result["high_correlation_pairs"]
],
)
@router.get("/portfolio/{portfolio_id}", response_model=DiversificationResponse)
async def get_portfolio_diversification(
portfolio_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""포트폴리오 분산 효과 점수 조회."""
service = CorrelationService(db)
try:
score = service.calculate_portfolio_diversification(portfolio_id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
# Get holdings for correlation data
from app.models.portfolio import Portfolio, PortfolioSnapshot
portfolio = db.query(Portfolio).filter(
Portfolio.id == portfolio_id,
Portfolio.user_id == current_user.id,
).first()
if not portfolio:
raise HTTPException(status_code=404, detail="Portfolio not found")
snapshot = (
db.query(PortfolioSnapshot)
.filter(PortfolioSnapshot.portfolio_id == portfolio_id)
.order_by(PortfolioSnapshot.snapshot_date.desc())
.first()
)
high_pairs = []
stock_count = 0
if snapshot and snapshot.holdings:
tickers = [h.ticker for h in snapshot.holdings]
stock_count = len(tickers)
if len(tickers) >= 2:
corr_data = service.get_correlation_data(tickers, period_days=60)
high_pairs = [
HighCorrelationPair(**p) for p in corr_data["high_correlation_pairs"]
]
return DiversificationResponse(
portfolio_id=portfolio_id,
diversification_score=score,
stock_count=stock_count,
high_correlation_pairs=high_pairs,
)

View File

@ -0,0 +1,69 @@
"""
Drawdown API endpoints for portfolio risk monitoring.
"""
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import CurrentUser
from app.api.snapshot import _get_portfolio
from app.schemas.drawdown import (
DrawdownResponse,
DrawdownHistoryResponse,
DrawdownSettingsUpdate,
)
from app.services.drawdown import (
calculate_drawdown,
calculate_rolling_drawdown,
get_alert_threshold,
set_alert_threshold,
)
router = APIRouter(prefix="/api/drawdown", tags=["drawdown"])
@router.get("/{portfolio_id}", response_model=DrawdownResponse)
async def get_drawdown(
portfolio_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Get current drawdown metrics for a portfolio."""
_get_portfolio(db, portfolio_id, current_user.id)
data = calculate_drawdown(db, portfolio_id)
return DrawdownResponse(portfolio_id=portfolio_id, **data)
@router.get("/{portfolio_id}/history", response_model=DrawdownHistoryResponse)
async def get_drawdown_history(
portfolio_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Get rolling drawdown time series for a portfolio."""
_get_portfolio(db, portfolio_id, current_user.id)
rolling = calculate_rolling_drawdown(db, portfolio_id)
summary = calculate_drawdown(db, portfolio_id)
return DrawdownHistoryResponse(
portfolio_id=portfolio_id,
data=rolling,
max_drawdown_pct=summary["max_drawdown_pct"],
current_drawdown_pct=summary["current_drawdown_pct"],
)
@router.put("/settings/{portfolio_id}")
async def update_drawdown_settings(
portfolio_id: int,
body: DrawdownSettingsUpdate,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Update drawdown alert threshold for a portfolio."""
_get_portfolio(db, portfolio_id, current_user.id)
set_alert_threshold(portfolio_id, body.alert_threshold_pct)
return {
"portfolio_id": portfolio_id,
"alert_threshold_pct": float(body.alert_threshold_pct),
}

173
backend/app/api/journal.py Normal file
View File

@ -0,0 +1,173 @@
"""
Trading journal API endpoints.
"""
from datetime import date
from decimal import Decimal
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from sqlalchemy import func
from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.journal import TradeJournal, JournalStatus
from app.schemas.journal import (
TradeJournalCreate,
TradeJournalUpdate,
TradeJournalResponse,
TradeJournalStats,
)
router = APIRouter(prefix="/api/journal", tags=["journal"])
@router.post("", response_model=TradeJournalResponse, status_code=201)
async def create_journal(
data: TradeJournalCreate,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
journal = TradeJournal(
user_id=current_user.id,
**data.model_dump(),
)
db.add(journal)
db.commit()
db.refresh(journal)
return journal
@router.get("", response_model=List[TradeJournalResponse])
async def list_journals(
current_user: CurrentUser,
db: Session = Depends(get_db),
status: Optional[str] = Query(None),
stock_code: Optional[str] = Query(None),
start_date: Optional[date] = Query(None),
end_date: Optional[date] = Query(None),
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
):
query = db.query(TradeJournal).filter(TradeJournal.user_id == current_user.id)
if status:
query = query.filter(TradeJournal.status == status)
if stock_code:
query = query.filter(TradeJournal.stock_code == stock_code)
if start_date:
query = query.filter(TradeJournal.entry_date >= start_date)
if end_date:
query = query.filter(TradeJournal.entry_date <= end_date)
journals = (
query.order_by(TradeJournal.entry_date.desc(), TradeJournal.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return journals
@router.get("/stats", response_model=TradeJournalStats)
async def get_journal_stats(
current_user: CurrentUser,
db: Session = Depends(get_db),
):
journals = (
db.query(TradeJournal)
.filter(TradeJournal.user_id == current_user.id)
.all()
)
total = len(journals)
open_trades = sum(1 for j in journals if j.status == JournalStatus.OPEN)
closed = [j for j in journals if j.status == JournalStatus.CLOSED]
closed_count = len(closed)
closed_with_pnl = [j for j in closed if j.profit_loss_pct is not None]
win_count = sum(1 for j in closed_with_pnl if j.profit_loss_pct > 0)
loss_count = sum(1 for j in closed_with_pnl if j.profit_loss_pct <= 0)
win_rate = None
avg_pnl_pct = None
max_profit = None
max_loss = None
total_pnl = None
if closed_with_pnl:
win_rate = Decimal(win_count) / Decimal(len(closed_with_pnl)) * 100
pcts = [j.profit_loss_pct for j in closed_with_pnl]
avg_pnl_pct = sum(pcts) / len(pcts)
max_profit = max(pcts)
max_loss = min(pcts)
closed_with_pl = [j for j in closed if j.profit_loss is not None]
if closed_with_pl:
total_pnl = sum(j.profit_loss for j in closed_with_pl)
return TradeJournalStats(
total_trades=total,
open_trades=open_trades,
closed_trades=closed_count,
win_count=win_count,
loss_count=loss_count,
win_rate=win_rate,
avg_profit_loss_pct=avg_pnl_pct,
max_profit_pct=max_profit,
max_loss_pct=max_loss,
total_profit_loss=total_pnl,
)
@router.get("/{journal_id}", response_model=TradeJournalResponse)
async def get_journal(
journal_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
journal = (
db.query(TradeJournal)
.filter(TradeJournal.id == journal_id, TradeJournal.user_id == current_user.id)
.first()
)
if not journal:
raise HTTPException(status_code=404, detail="Journal not found")
return journal
@router.put("/{journal_id}", response_model=TradeJournalResponse)
async def update_journal(
journal_id: int,
data: TradeJournalUpdate,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
journal = (
db.query(TradeJournal)
.filter(TradeJournal.id == journal_id, TradeJournal.user_id == current_user.id)
.first()
)
if not journal:
raise HTTPException(status_code=404, detail="Journal not found")
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(journal, field, value)
# Auto-calculate profit/loss when closing
if data.exit_price is not None and journal.entry_price is not None:
if journal.trade_type.value == "buy":
journal.profit_loss = (data.exit_price - journal.entry_price) * (journal.quantity or 1)
journal.profit_loss_pct = (data.exit_price - journal.entry_price) / journal.entry_price * 100
else:
journal.profit_loss = (journal.entry_price - data.exit_price) * (journal.quantity or 1)
journal.profit_loss_pct = (journal.entry_price - data.exit_price) / journal.entry_price * 100
# Auto-close when exit info is provided
if data.exit_price is not None and data.exit_date is not None and journal.status == JournalStatus.OPEN:
journal.status = JournalStatus.CLOSED
db.commit()
db.refresh(journal)
return journal

View File

@ -0,0 +1,111 @@
"""
Notification settings and history API endpoints.
"""
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.notification import NotificationSetting, NotificationHistory
from app.schemas.notification import (
NotificationSettingCreate,
NotificationSettingUpdate,
NotificationSettingResponse,
NotificationHistoryResponse,
)
router = APIRouter(prefix="/api/notifications", tags=["notifications"])
@router.get("/settings", response_model=List[NotificationSettingResponse])
async def get_notification_settings(
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Get all notification settings for the current user."""
settings = (
db.query(NotificationSetting)
.filter(NotificationSetting.user_id == current_user.id)
.all()
)
return settings
@router.post("/settings", response_model=NotificationSettingResponse, status_code=201)
async def create_notification_setting(
data: NotificationSettingCreate,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Create a new notification setting."""
existing = (
db.query(NotificationSetting)
.filter(
NotificationSetting.user_id == current_user.id,
NotificationSetting.channel_type == data.channel_type,
)
.first()
)
if existing:
raise HTTPException(
status_code=400,
detail=f"Setting for {data.channel_type.value} already exists. Use PUT to update.",
)
setting = NotificationSetting(
user_id=current_user.id,
channel_type=data.channel_type,
webhook_url=data.webhook_url,
enabled=data.enabled,
)
db.add(setting)
db.commit()
db.refresh(setting)
return setting
@router.put("/settings/{setting_id}", response_model=NotificationSettingResponse)
async def update_notification_setting(
setting_id: int,
data: NotificationSettingUpdate,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Update an existing notification setting."""
setting = (
db.query(NotificationSetting)
.filter(
NotificationSetting.id == setting_id,
NotificationSetting.user_id == current_user.id,
)
.first()
)
if not setting:
raise HTTPException(status_code=404, detail="Notification setting not found")
if data.webhook_url is not None:
setting.webhook_url = data.webhook_url
if data.enabled is not None:
setting.enabled = data.enabled
db.commit()
db.refresh(setting)
return setting
@router.get("/history", response_model=List[NotificationHistoryResponse])
async def get_notification_history(
current_user: CurrentUser,
db: Session = Depends(get_db),
limit: int = Query(50, ge=1, le=200),
):
"""Get notification history."""
history = (
db.query(NotificationHistory)
.order_by(NotificationHistory.sent_at.desc())
.limit(limit)
.all()
)
return history

View File

@ -0,0 +1,67 @@
"""
Strategy optimizer API router.
"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.api.deps import CurrentUser
from app.core.database import get_db
from app.schemas.optimizer import (
DEFAULT_GRIDS,
STRATEGY_TYPES,
OptimizeRequest,
OptimizeResponse,
)
from app.services.optimizer import OptimizerService
router = APIRouter(prefix="/api/optimizer", tags=["optimizer"])
@router.post("", response_model=OptimizeResponse)
async def run_optimization(
request: OptimizeRequest,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Run grid-search optimization for a strategy."""
if request.strategy_type not in STRATEGY_TYPES:
raise HTTPException(
status_code=400,
detail=f"Unknown strategy_type: {request.strategy_type}. "
f"Valid types: {STRATEGY_TYPES}",
)
service = OptimizerService(db)
try:
return service.optimize(request)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/presets/{strategy_type}")
async def get_preset(
strategy_type: str,
current_user: CurrentUser,
):
"""Get default parameter grid preset for a strategy type."""
if strategy_type not in DEFAULT_GRIDS:
raise HTTPException(
status_code=404,
detail=f"No preset for strategy_type: {strategy_type}",
)
return {
"strategy_type": strategy_type,
"param_grid": DEFAULT_GRIDS[strategy_type],
}
@router.get("/presets")
async def list_presets(
current_user: CurrentUser,
):
"""List all available strategy presets."""
return {
"presets": {
st: DEFAULT_GRIDS[st] for st in STRATEGY_TYPES
}
}

157
backend/app/api/pension.py Normal file
View File

@ -0,0 +1,157 @@
"""
Pension account API endpoints.
"""
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session, joinedload
from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.pension import PensionAccount, PensionHolding
from app.schemas.pension import (
PensionAccountCreate,
PensionAccountUpdate,
PensionAccountResponse,
AllocationResult,
RecommendationResult,
)
from app.services.pension_allocation import calculate_allocation, get_recommendation
router = APIRouter(prefix="/api/pension", tags=["pension"])
@router.post("/accounts", response_model=PensionAccountResponse, status_code=201)
async def create_account(
data: PensionAccountCreate,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
account = PensionAccount(
user_id=current_user.id,
**data.model_dump(),
)
db.add(account)
db.commit()
db.refresh(account)
return account
@router.get("/accounts", response_model=List[PensionAccountResponse])
async def list_accounts(
current_user: CurrentUser,
db: Session = Depends(get_db),
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
):
accounts = (
db.query(PensionAccount)
.options(joinedload(PensionAccount.holdings))
.filter(PensionAccount.user_id == current_user.id)
.order_by(PensionAccount.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return accounts
@router.get("/accounts/{account_id}", response_model=PensionAccountResponse)
async def get_account(
account_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
account = (
db.query(PensionAccount)
.options(joinedload(PensionAccount.holdings))
.filter(PensionAccount.id == account_id, PensionAccount.user_id == current_user.id)
.first()
)
if not account:
raise HTTPException(status_code=404, detail="Pension account not found")
return account
@router.put("/accounts/{account_id}", response_model=PensionAccountResponse)
async def update_account(
account_id: int,
data: PensionAccountUpdate,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
account = (
db.query(PensionAccount)
.options(joinedload(PensionAccount.holdings))
.filter(PensionAccount.id == account_id, PensionAccount.user_id == current_user.id)
.first()
)
if not account:
raise HTTPException(status_code=404, detail="Pension account not found")
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(account, field, value)
db.commit()
db.refresh(account)
return account
@router.post("/accounts/{account_id}/allocate", response_model=AllocationResult)
async def allocate_assets(
account_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
account = (
db.query(PensionAccount)
.filter(PensionAccount.id == account_id, PensionAccount.user_id == current_user.id)
.first()
)
if not account:
raise HTTPException(status_code=404, detail="Pension account not found")
result = calculate_allocation(
account_id=account.id,
account_type=account.account_type.value,
total_amount=account.total_amount,
birth_year=account.birth_year,
target_retirement_age=account.target_retirement_age,
)
# Save allocation as holdings
db.query(PensionHolding).filter(PensionHolding.account_id == account_id).delete()
for alloc in result.allocations:
holding = PensionHolding(
account_id=account_id,
asset_name=alloc.asset_name,
asset_type=alloc.asset_type,
amount=alloc.amount,
ratio=alloc.ratio,
)
db.add(holding)
db.commit()
return result
@router.get("/accounts/{account_id}/recommendation", response_model=RecommendationResult)
async def get_account_recommendation(
account_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
account = (
db.query(PensionAccount)
.filter(PensionAccount.id == account_id, PensionAccount.user_id == current_user.id)
.first()
)
if not account:
raise HTTPException(status_code=404, detail="Pension account not found")
return get_recommendation(
account_id=account.id,
birth_year=account.birth_year,
target_retirement_age=account.target_retirement_age,
)

View File

@ -0,0 +1,83 @@
"""
Position sizing API endpoints.
"""
from fastapi import APIRouter, HTTPException
from app.api.deps import CurrentUser
from app.schemas.position_sizing import (
SizingMethod,
PositionSizeRequest,
PositionSizeResponse,
MethodInfo,
MethodsResponse,
)
from app.services.position_sizing import fixed_ratio, kelly_criterion, atr_based
router = APIRouter(prefix="/api/position-sizing", tags=["position-sizing"])
METHODS = [
MethodInfo(
name="fixed",
label="균등 분배 (Fixed Ratio)",
description="자본금을 종목 수로 균등 분배. 현금 비중 설정 가능. quant.md 기본 방식.",
),
MethodInfo(
name="kelly",
label="켈리 기준 (Kelly Criterion)",
description="승률과 손익비를 기반으로 최적 베팅 비율 계산. 보수적 1/4 켈리 기본.",
),
MethodInfo(
name="atr",
label="ATR 변동성 (ATR-Based)",
description="ATR(평균 진폭)을 이용한 변동성 기반 포지션 사이징.",
),
]
@router.post("/calculate", response_model=PositionSizeResponse)
async def calculate_position_size(
request: PositionSizeRequest,
current_user: CurrentUser,
):
"""Calculate position size based on selected method."""
try:
if request.method == SizingMethod.FIXED:
result = fixed_ratio(
capital=request.capital,
num_positions=request.num_positions,
cash_ratio=request.cash_ratio,
)
elif request.method == SizingMethod.KELLY:
if not all([request.win_rate, request.avg_win, request.avg_loss]):
raise HTTPException(
status_code=422,
detail="Kelly method requires win_rate, avg_win, avg_loss",
)
result = kelly_criterion(
win_rate=request.win_rate,
avg_win=request.avg_win,
avg_loss=request.avg_loss,
)
elif request.method == SizingMethod.ATR:
if not request.atr:
raise HTTPException(
status_code=422,
detail="ATR method requires atr value",
)
result = atr_based(
capital=request.capital,
atr=request.atr,
risk_pct=request.risk_pct,
)
else:
raise HTTPException(status_code=422, detail=f"Unknown method: {request.method}")
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
return PositionSizeResponse(**result)
@router.get("/methods", response_model=MethodsResponse)
async def get_methods(current_user: CurrentUser):
"""List supported position sizing methods."""
return MethodsResponse(methods=METHODS)

View File

@ -0,0 +1,54 @@
"""
Tax simulation API endpoints.
"""
from fastapi import APIRouter
from app.schemas.tax_simulation import (
AccumulationRequest,
AccumulationResponse,
PensionTaxRequest,
PensionTaxResponse,
TaxDeductionRequest,
TaxDeductionResponse,
)
from app.services.tax_simulation import (
calculate_pension_tax,
calculate_tax_deduction,
simulate_accumulation,
)
router = APIRouter(prefix="/api/tax", tags=["tax-simulation"])
@router.post("/deduction", response_model=TaxDeductionResponse)
async def tax_deduction(request: TaxDeductionRequest):
"""연간 세액공제 계산."""
result = calculate_tax_deduction(
annual_income=request.annual_income,
contribution=request.contribution,
account_type=request.account_type.value,
)
return result
@router.post("/pension-tax", response_model=PensionTaxResponse)
async def pension_tax(request: PensionTaxRequest):
"""연금 수령 시 세금 비교."""
result = calculate_pension_tax(
withdrawal_amount=request.withdrawal_amount,
withdrawal_type=request.withdrawal_type.value,
age=request.age,
)
return result
@router.post("/accumulation", response_model=AccumulationResponse)
async def accumulation(request: AccumulationRequest):
"""적립 시뮬레이션."""
result = simulate_accumulation(
monthly_contribution=request.monthly_contribution,
years=request.years,
annual_return=request.annual_return,
tax_deduction_rate=request.tax_deduction_rate,
)
return result

View File

@ -33,6 +33,11 @@ class Settings(BaseSettings):
kis_account_no: str = ""
dart_api_key: str = ""
# Notifications (optional)
discord_webhook_url: str = ""
telegram_bot_token: str = ""
telegram_chat_id: str = ""
class Config:
env_file = ".env"
case_sensitive = False

View File

@ -10,7 +10,10 @@ from fastapi.middleware.cors import CORSMiddleware
from app.api import (
auth_router, admin_router, portfolio_router, strategy_router,
market_router, backtest_router, snapshot_router, data_explorer_router,
signal_router,
signal_router, notification_router, journal_router, position_sizing_router,
pension_router, drawdown_router, benchmark_router, tax_simulation_router,
correlation_router,
optimizer_router,
)
# Configure logging
@ -114,6 +117,15 @@ app.include_router(backtest_router)
app.include_router(snapshot_router)
app.include_router(data_explorer_router)
app.include_router(signal_router)
app.include_router(notification_router)
app.include_router(journal_router)
app.include_router(position_sizing_router)
app.include_router(pension_router)
app.include_router(drawdown_router)
app.include_router(benchmark_router)
app.include_router(tax_simulation_router)
app.include_router(correlation_router)
app.include_router(optimizer_router)
@app.get("/health")

View File

@ -23,6 +23,19 @@ from app.models.stock import (
JobLog,
)
from app.models.signal import Signal, SignalType, SignalStatus
from app.models.notification import (
NotificationSetting,
NotificationHistory,
ChannelType,
NotificationStatus,
)
from app.models.journal import TradeJournal, TradeType, JournalStatus
from app.models.pension import (
PensionAccount,
PensionHolding,
AccountType,
AssetRiskType,
)
from app.models.backtest import (
Backtest,
BacktestStatus,
@ -64,4 +77,15 @@ __all__ = [
"Signal",
"SignalType",
"SignalStatus",
"NotificationSetting",
"NotificationHistory",
"ChannelType",
"NotificationStatus",
"TradeJournal",
"TradeType",
"JournalStatus",
"PensionAccount",
"PensionHolding",
"AccountType",
"AssetRiskType",
]

View File

@ -0,0 +1,53 @@
"""
Trading journal models.
"""
import enum
from datetime import datetime
from sqlalchemy import (
Column, Integer, String, Numeric, DateTime, Date,
Text, Enum as SQLEnum, ForeignKey,
)
from sqlalchemy.orm import relationship
from app.core.database import Base
class TradeType(str, enum.Enum):
BUY = "buy"
SELL = "sell"
class JournalStatus(str, enum.Enum):
OPEN = "open"
CLOSED = "closed"
class TradeJournal(Base):
__tablename__ = "trade_journals"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
stock_code = Column(String(20), nullable=False, index=True)
stock_name = Column(String(100))
trade_type = Column(SQLEnum(TradeType), nullable=False)
entry_price = Column(Numeric(12, 2))
target_price = Column(Numeric(12, 2))
stop_loss_price = Column(Numeric(12, 2))
exit_price = Column(Numeric(12, 2))
entry_date = Column(Date, nullable=False, index=True)
exit_date = Column(Date)
quantity = Column(Integer)
profit_loss = Column(Numeric(14, 2))
profit_loss_pct = Column(Numeric(8, 4))
entry_reason = Column(Text)
exit_reason = Column(Text)
scenario = Column(Text)
lessons_learned = Column(Text)
emotional_state = Column(Text)
strategy_id = Column(Integer)
status = Column(SQLEnum(JournalStatus), default=JournalStatus.OPEN, index=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
user = relationship("User", backref="trade_journals")

View File

@ -0,0 +1,46 @@
"""
Notification models for signal alerts.
"""
import enum
from datetime import datetime
from sqlalchemy import (
Column, Integer, String, DateTime, Boolean, Text,
ForeignKey, Enum as SQLEnum,
)
from app.core.database import Base
class ChannelType(str, enum.Enum):
DISCORD = "discord"
TELEGRAM = "telegram"
class NotificationStatus(str, enum.Enum):
SENT = "sent"
FAILED = "failed"
class NotificationSetting(Base):
__tablename__ = "notification_settings"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
channel_type = Column(SQLEnum(ChannelType), nullable=False)
webhook_url = Column(String(500), nullable=False)
enabled = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class NotificationHistory(Base):
__tablename__ = "notification_history"
id = Column(Integer, primary_key=True, index=True)
signal_id = Column(Integer, ForeignKey("signals.id"), nullable=False, index=True)
channel_type = Column(SQLEnum(ChannelType), nullable=False)
sent_at = Column(DateTime, default=datetime.utcnow)
status = Column(SQLEnum(NotificationStatus), nullable=False)
message = Column(Text)
error_message = Column(Text, nullable=True)

View File

@ -0,0 +1,54 @@
"""
Pension account and holding models for retirement pension asset allocation.
"""
import enum
from datetime import datetime
from sqlalchemy import (
Column, Integer, String, Numeric, DateTime,
Enum as SQLEnum, ForeignKey,
)
from sqlalchemy.orm import relationship
from app.core.database import Base
class AccountType(str, enum.Enum):
DC = "dc"
IRP = "irp"
PERSONAL = "personal"
class AssetRiskType(str, enum.Enum):
SAFE = "safe"
RISKY = "risky"
class PensionAccount(Base):
__tablename__ = "pension_accounts"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
account_type = Column(SQLEnum(AccountType), nullable=False, index=True)
account_name = Column(String(100), nullable=False)
total_amount = Column(Numeric(16, 2), nullable=False, default=0)
birth_year = Column(Integer, nullable=False)
target_retirement_age = Column(Integer, default=60)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
user = relationship("User", backref="pension_accounts")
holdings = relationship("PensionHolding", back_populates="account", cascade="all, delete-orphan")
class PensionHolding(Base):
__tablename__ = "pension_holdings"
id = Column(Integer, primary_key=True, index=True)
account_id = Column(Integer, ForeignKey("pension_accounts.id"), nullable=False, index=True)
asset_name = Column(String(200), nullable=False)
asset_type = Column(SQLEnum(AssetRiskType), nullable=False)
amount = Column(Numeric(16, 2), nullable=False, default=0)
ratio = Column(Numeric(6, 2), nullable=False, default=0)
account = relationship("PensionAccount", back_populates="holdings")

View File

@ -0,0 +1,55 @@
"""
Benchmark comparison schemas.
"""
import enum
from datetime import date
from typing import List, Optional
from pydantic import BaseModel, Field
class BenchmarkType(str, enum.Enum):
KOSPI = "kospi"
DEPOSIT = "deposit"
class PeriodType(str, enum.Enum):
ONE_MONTH = "1m"
THREE_MONTHS = "3m"
SIX_MONTHS = "6m"
ONE_YEAR = "1y"
ALL = "all"
class BenchmarkIndexInfo(BaseModel):
code: str
name: str
description: str
class TimeSeriesPoint(BaseModel):
date: date
portfolio_return: Optional[float] = None
benchmark_return: Optional[float] = None
deposit_return: Optional[float] = None
class PerformanceMetrics(BaseModel):
cumulative_return: float = Field(..., description="누적 수익률 (%)")
annualized_return: float = Field(..., description="연환산 수익률 (%)")
sharpe_ratio: Optional[float] = Field(None, description="샤프 비율")
max_drawdown: float = Field(..., description="최대 낙폭 (%)")
class BenchmarkCompareResponse(BaseModel):
portfolio_name: str
benchmark_type: str
period: str
start_date: date
end_date: date
time_series: List[TimeSeriesPoint]
portfolio_metrics: PerformanceMetrics
benchmark_metrics: PerformanceMetrics
deposit_metrics: PerformanceMetrics
alpha: float = Field(..., description="초과 수익률 (포트폴리오 - 벤치마크) (%)")
information_ratio: Optional[float] = Field(None, description="정보 비율")

View File

@ -0,0 +1,34 @@
"""
Correlation analysis schemas.
"""
from typing import List, Optional
from pydantic import BaseModel, Field
class CorrelationMatrixRequest(BaseModel):
stock_codes: List[str] = Field(..., description="종목 코드 리스트")
period_days: int = Field(60, description="분석 기간 (일)", ge=7, le=365)
class HighCorrelationPair(BaseModel):
stock_a: str
stock_b: str
correlation: float = Field(..., description="상관계수 (-1 ~ 1)")
class CorrelationMatrixResponse(BaseModel):
stock_codes: List[str]
matrix: List[List[Optional[float]]] = Field(..., description="상관 행렬 (NxN)")
high_correlation_pairs: List[HighCorrelationPair] = Field(
default_factory=list, description="높은 상관관계 종목 쌍 (|r| > 0.7)"
)
class DiversificationResponse(BaseModel):
portfolio_id: int
diversification_score: float = Field(
..., description="분산 효과 점수 (0=집중, 1=완벽 분산)", ge=0, le=1
)
stock_count: int
high_correlation_pairs: List[HighCorrelationPair] = Field(default_factory=list)

View File

@ -0,0 +1,40 @@
"""
Drawdown related Pydantic schemas.
"""
from datetime import date
from decimal import Decimal
from typing import List, Optional
from pydantic import BaseModel, Field
from app.schemas.portfolio import FloatDecimal
class DrawdownDataPoint(BaseModel):
date: date
total_value: FloatDecimal
peak: FloatDecimal
drawdown_pct: FloatDecimal
class DrawdownResponse(BaseModel):
portfolio_id: int
current_drawdown_pct: FloatDecimal
max_drawdown_pct: FloatDecimal
max_drawdown_date: date | None = None
peak_value: FloatDecimal | None = None
peak_date: date | None = None
trough_value: FloatDecimal | None = None
trough_date: date | None = None
alert_threshold_pct: FloatDecimal = Decimal("20")
class DrawdownHistoryResponse(BaseModel):
portfolio_id: int
data: List[DrawdownDataPoint] = []
max_drawdown_pct: FloatDecimal
current_drawdown_pct: FloatDecimal
class DrawdownSettingsUpdate(BaseModel):
alert_threshold_pct: FloatDecimal = Field(..., gt=0, le=100)

View File

@ -0,0 +1,93 @@
"""
Trading journal Pydantic schemas.
"""
from datetime import date, datetime
from decimal import Decimal
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
from app.schemas.portfolio import FloatDecimal
class TradeType(str, Enum):
BUY = "buy"
SELL = "sell"
class JournalStatus(str, Enum):
OPEN = "open"
CLOSED = "closed"
class TradeJournalCreate(BaseModel):
stock_code: str = Field(..., min_length=1, max_length=20)
stock_name: Optional[str] = Field(None, max_length=100)
trade_type: TradeType
entry_price: Optional[FloatDecimal] = Field(None, gt=0)
target_price: Optional[FloatDecimal] = Field(None, gt=0)
stop_loss_price: Optional[FloatDecimal] = Field(None, gt=0)
entry_date: date
quantity: Optional[int] = Field(None, gt=0)
entry_reason: Optional[str] = None
scenario: Optional[str] = None
emotional_state: Optional[str] = None
strategy_id: Optional[int] = None
class TradeJournalUpdate(BaseModel):
stock_name: Optional[str] = Field(None, max_length=100)
exit_price: Optional[FloatDecimal] = Field(None, gt=0)
exit_date: Optional[date] = None
exit_reason: Optional[str] = None
target_price: Optional[FloatDecimal] = Field(None, gt=0)
stop_loss_price: Optional[FloatDecimal] = Field(None, gt=0)
quantity: Optional[int] = Field(None, gt=0)
lessons_learned: Optional[str] = None
emotional_state: Optional[str] = None
scenario: Optional[str] = None
entry_reason: Optional[str] = None
status: Optional[JournalStatus] = None
class TradeJournalResponse(BaseModel):
id: int
user_id: int
stock_code: str
stock_name: Optional[str] = None
trade_type: str
entry_price: Optional[FloatDecimal] = None
target_price: Optional[FloatDecimal] = None
stop_loss_price: Optional[FloatDecimal] = None
exit_price: Optional[FloatDecimal] = None
entry_date: date
exit_date: Optional[date] = None
quantity: Optional[int] = None
profit_loss: Optional[FloatDecimal] = None
profit_loss_pct: Optional[FloatDecimal] = None
entry_reason: Optional[str] = None
exit_reason: Optional[str] = None
scenario: Optional[str] = None
lessons_learned: Optional[str] = None
emotional_state: Optional[str] = None
strategy_id: Optional[int] = None
status: str
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class TradeJournalStats(BaseModel):
total_trades: int = 0
open_trades: int = 0
closed_trades: int = 0
win_count: int = 0
loss_count: int = 0
win_rate: Optional[FloatDecimal] = None
avg_profit_loss_pct: Optional[FloatDecimal] = None
max_profit_pct: Optional[FloatDecimal] = None
max_loss_pct: Optional[FloatDecimal] = None
total_profit_loss: Optional[FloatDecimal] = None

View File

@ -0,0 +1,50 @@
"""
Notification related Pydantic schemas.
"""
from datetime import datetime
from typing import Optional, List
from enum import Enum
from pydantic import BaseModel, Field
class ChannelType(str, Enum):
DISCORD = "discord"
TELEGRAM = "telegram"
class NotificationSettingCreate(BaseModel):
channel_type: ChannelType
webhook_url: str = Field(..., min_length=1, max_length=500)
enabled: bool = True
class NotificationSettingUpdate(BaseModel):
webhook_url: Optional[str] = Field(None, min_length=1, max_length=500)
enabled: Optional[bool] = None
class NotificationSettingResponse(BaseModel):
id: int
user_id: int
channel_type: str
webhook_url: str
enabled: bool
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class NotificationHistoryResponse(BaseModel):
id: int
signal_id: int
channel_type: str
sent_at: datetime
status: str
message: Optional[str] = None
error_message: Optional[str] = None
class Config:
from_attributes = True

View File

@ -0,0 +1,86 @@
"""
Strategy optimizer schemas.
"""
from datetime import date
from decimal import Decimal
from typing import Annotated, Any, Dict, List, Optional
from pydantic import BaseModel, Field, PlainSerializer
FloatDecimal = Annotated[
Decimal,
PlainSerializer(lambda v: float(v), return_type=float, when_used="json"),
]
# --- Default parameter grids per strategy type ---
KJB_DEFAULT_GRID: Dict[str, List[Any]] = {
"stop_loss_pct": [0.03, 0.05, 0.07],
"target1_pct": [0.05, 0.07, 0.10],
"rs_lookback": [10, 20, 30],
}
MULTI_FACTOR_DEFAULT_GRID: Dict[str, List[Any]] = {
"weights.value": [0.15, 0.25, 0.35],
"weights.quality": [0.15, 0.25, 0.35],
"weights.momentum": [0.15, 0.25, 0.35],
}
QUALITY_DEFAULT_GRID: Dict[str, List[Any]] = {
"min_fscore": [5, 6, 7, 8],
}
VALUE_MOMENTUM_DEFAULT_GRID: Dict[str, List[Any]] = {
"value_weight": [0.3, 0.4, 0.5, 0.6, 0.7],
"momentum_weight": [0.3, 0.4, 0.5, 0.6, 0.7],
}
DEFAULT_GRIDS: Dict[str, Dict[str, List[Any]]] = {
"kjb": KJB_DEFAULT_GRID,
"multi_factor": MULTI_FACTOR_DEFAULT_GRID,
"quality": QUALITY_DEFAULT_GRID,
"value_momentum": VALUE_MOMENTUM_DEFAULT_GRID,
}
STRATEGY_TYPES = ["kjb", "multi_factor", "quality", "value_momentum"]
class OptimizeRequest(BaseModel):
strategy_type: str = Field(
...,
description="Strategy type: kjb, multi_factor, quality, value_momentum",
)
param_grid: Optional[Dict[str, List[Any]]] = Field(
default=None,
description="Parameter grid. If None, uses default preset for the strategy type.",
)
start_date: date
end_date: date
initial_capital: Decimal = Field(default=Decimal("100000000"), gt=0)
commission_rate: Decimal = Field(default=Decimal("0.00015"), ge=0, le=1)
slippage_rate: Decimal = Field(default=Decimal("0.001"), ge=0, le=1)
benchmark: str = Field(default="KOSPI")
top_n: int = Field(default=30, ge=1, le=100)
rank_by: str = Field(
default="sharpe_ratio",
description="Metric to rank results by: sharpe_ratio, cagr, total_return, mdd",
)
class OptimizeResultItem(BaseModel):
rank: int
params: Dict[str, Any]
total_return: FloatDecimal
cagr: FloatDecimal
mdd: FloatDecimal
sharpe_ratio: FloatDecimal
volatility: FloatDecimal
benchmark_return: FloatDecimal
excess_return: FloatDecimal
class OptimizeResponse(BaseModel):
strategy_type: str
total_combinations: int
results: List[OptimizeResultItem]
best_params: Dict[str, Any]

View File

@ -0,0 +1,108 @@
"""
Pension account Pydantic schemas.
"""
from datetime import datetime
from enum import Enum
from typing import Optional, List
from pydantic import BaseModel, Field
from app.schemas.portfolio import FloatDecimal
class AccountType(str, Enum):
DC = "dc"
IRP = "irp"
PERSONAL = "personal"
class AssetRiskType(str, Enum):
SAFE = "safe"
RISKY = "risky"
# --- Account schemas ---
class PensionAccountCreate(BaseModel):
account_type: AccountType
account_name: str = Field(..., min_length=1, max_length=100)
total_amount: FloatDecimal = Field(..., ge=0)
birth_year: int = Field(..., ge=1940, le=2010)
target_retirement_age: int = Field(60, ge=50, le=70)
class PensionAccountUpdate(BaseModel):
account_name: Optional[str] = Field(None, min_length=1, max_length=100)
total_amount: Optional[FloatDecimal] = Field(None, ge=0)
target_retirement_age: Optional[int] = Field(None, ge=50, le=70)
class PensionHoldingResponse(BaseModel):
id: int
account_id: int
asset_name: str
asset_type: str
amount: FloatDecimal
ratio: FloatDecimal
class Config:
from_attributes = True
class PensionAccountResponse(BaseModel):
id: int
user_id: int
account_type: str
account_name: str
total_amount: FloatDecimal
birth_year: int
target_retirement_age: int
created_at: datetime
updated_at: datetime
holdings: List[PensionHoldingResponse] = []
class Config:
from_attributes = True
# --- Allocation schemas ---
class AllocationItem(BaseModel):
asset_name: str
asset_type: str # safe / risky
amount: FloatDecimal
ratio: FloatDecimal
class AllocationResult(BaseModel):
account_id: int
account_type: str
total_amount: FloatDecimal
risky_limit_pct: FloatDecimal
safe_min_pct: FloatDecimal
glide_path_equity_pct: FloatDecimal
glide_path_bond_pct: FloatDecimal
current_age: int
years_to_retirement: int
allocations: List[AllocationItem]
# --- Recommendation schemas ---
class RecommendationItem(BaseModel):
asset_name: str
asset_type: str
category: str # tdf, bond_etf, equity_etf, deposit
ratio: FloatDecimal
reason: str
class RecommendationResult(BaseModel):
account_id: int
birth_year: int
current_age: int
target_retirement_age: int
years_to_retirement: int
glide_path_equity_pct: FloatDecimal
glide_path_bond_pct: FloatDecimal
recommendations: List[RecommendationItem]

View File

@ -0,0 +1,46 @@
"""
Position sizing Pydantic schemas.
"""
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class SizingMethod(str, Enum):
FIXED = "fixed"
KELLY = "kelly"
ATR = "atr"
class PositionSizeRequest(BaseModel):
capital: float = Field(..., gt=0, description="Total capital (KRW)")
method: SizingMethod = Field(..., description="Sizing method")
# fixed_ratio params
num_positions: int = Field(default=10, ge=1, le=50)
cash_ratio: float = Field(default=0.3, ge=0, lt=1.0)
# kelly params
win_rate: Optional[float] = Field(default=None, gt=0, le=1.0)
avg_win: Optional[float] = Field(default=None, gt=0)
avg_loss: Optional[float] = Field(default=None, gt=0)
# atr params
atr: Optional[float] = Field(default=None, gt=0)
risk_pct: float = Field(default=0.02, gt=0, le=1.0)
class PositionSizeResponse(BaseModel):
method: str
position_size: float
shares: int
risk_amount: float
notes: str
class MethodInfo(BaseModel):
name: str
label: str
description: str
class MethodsResponse(BaseModel):
methods: list[MethodInfo]

View File

@ -0,0 +1,78 @@
"""
Tax simulation request/response schemas.
"""
import enum
from typing import List
from pydantic import BaseModel, Field
class AccountType(str, enum.Enum):
IRP = "irp"
DC = "dc"
class WithdrawalType(str, enum.Enum):
PENSION = "pension"
LUMP_SUM = "lump_sum"
class TaxDeductionRequest(BaseModel):
annual_income: int = Field(..., gt=0, description="연간 총급여 (원)")
contribution: int = Field(..., ge=0, description="연간 납입액 (원)")
account_type: AccountType = Field(AccountType.IRP, description="계좌 유형")
class TaxDeductionResponse(BaseModel):
annual_income: int
contribution: int
account_type: str
deduction_rate: float = Field(..., description="세액공제율 (%)")
irp_limit: int = Field(..., description="연간 공제 한도 (원)")
deductible_contribution: int = Field(..., description="공제 대상 납입액 (원)")
tax_deduction: float = Field(..., description="세액공제 금액 (원)")
class PensionTaxRequest(BaseModel):
withdrawal_amount: int = Field(..., gt=0, description="수령 금액 (원)")
withdrawal_type: WithdrawalType = Field(WithdrawalType.PENSION, description="수령 방식")
age: int = Field(..., ge=55, le=100, description="수령 시 나이")
class PensionTaxResponse(BaseModel):
withdrawal_amount: int
withdrawal_type: str
age: int
pension_tax_rate: float = Field(..., description="연금소득세율 (%)")
pension_tax: float = Field(..., description="연금소득세 (원)")
lump_sum_tax_rate: float = Field(..., description="기타소득세율 (%)")
lump_sum_tax: float = Field(..., description="기타소득세 (원)")
tax_saving: float = Field(..., description="연금 수령 시 절세 금액 (원)")
class AccumulationRequest(BaseModel):
monthly_contribution: int = Field(..., gt=0, description="월 납입액 (원)")
years: int = Field(..., ge=1, le=50, description="적립 기간 (년)")
annual_return: float = Field(..., ge=0, le=30, description="연간 기대 수익률 (%)")
tax_deduction_rate: float = Field(..., description="세액공제율 (%)")
class YearlyAccumulationData(BaseModel):
year: int
contribution: int
cumulative_contribution: int
investment_value: float
tax_deduction: float
cumulative_tax_deduction: float
class AccumulationResponse(BaseModel):
monthly_contribution: int
years: int
annual_return: float
tax_deduction_rate: float
total_contribution: int
final_value: float
total_return: float
total_tax_deduction: float
yearly_data: List[YearlyAccumulationData]

View File

@ -0,0 +1,277 @@
"""
Benchmark comparison service.
Compares portfolio performance against KOSPI index and deposit rate.
"""
import logging
import math
from datetime import date, timedelta
from decimal import Decimal
from typing import List, Optional
from pykrx import stock as pykrx_stock
from sqlalchemy.orm import Session
from app.models.portfolio import Portfolio, PortfolioSnapshot
from app.schemas.benchmark import (
BenchmarkCompareResponse,
PerformanceMetrics,
TimeSeriesPoint,
)
logger = logging.getLogger(__name__)
KOSPI_INDEX_CODE = "1001"
DEPOSIT_ANNUAL_RATE = 3.5
RISK_FREE_RATE = 3.5
class BenchmarkService:
def __init__(self, db: Session):
self.db = db
def get_deposit_rate(self) -> float:
return DEPOSIT_ANNUAL_RATE
def get_benchmark_data(
self, benchmark_type: str, start_date: date, end_date: date
) -> List[dict]:
start_str = start_date.strftime("%Y%m%d")
end_str = end_date.strftime("%Y%m%d")
if benchmark_type == "kospi":
df = pykrx_stock.get_index_ohlcv(start_str, end_str, KOSPI_INDEX_CODE)
else:
return []
if df.empty:
return []
result = []
for idx, row in df.iterrows():
result.append({
"date": idx.date() if hasattr(idx, "date") else idx,
"close": row["종가"],
})
return result
def compare_with_benchmark(
self,
portfolio_id: int,
benchmark_type: str,
period: str,
user_id: int,
) -> BenchmarkCompareResponse:
portfolio = (
self.db.query(Portfolio)
.filter(Portfolio.id == portfolio_id, Portfolio.user_id == user_id)
.first()
)
if not portfolio:
raise ValueError("Portfolio not found")
snapshots = (
self.db.query(PortfolioSnapshot)
.filter(PortfolioSnapshot.portfolio_id == portfolio_id)
.order_by(PortfolioSnapshot.snapshot_date)
.all()
)
if not snapshots:
raise ValueError("스냅샷 데이터가 없습니다")
end_date = snapshots[-1].snapshot_date
start_date = self._calc_start_date(period, snapshots[0].snapshot_date, end_date)
filtered = [s for s in snapshots if s.snapshot_date >= start_date]
if len(filtered) < 2:
filtered = snapshots
start_date = filtered[0].snapshot_date
num_days = (end_date - start_date).days
# Portfolio daily returns
portfolio_values = [float(s.total_value) for s in filtered]
portfolio_returns = self._values_to_returns(portfolio_values)
# Benchmark data
benchmark_data = self.get_benchmark_data(benchmark_type, start_date, end_date)
benchmark_closes = [d["close"] for d in benchmark_data]
benchmark_returns = self._values_to_returns(benchmark_closes)
# Deposit returns (daily)
daily_deposit_rate = (1 + DEPOSIT_ANNUAL_RATE / 100) ** (1 / 365) - 1
deposit_returns = [daily_deposit_rate] * max(num_days, 0)
# Cumulative return time series
time_series = self._build_time_series(
filtered, benchmark_data, start_date, end_date
)
# Metrics
portfolio_metrics = self._calculate_metrics(portfolio_returns, num_days)
benchmark_metrics = self._calculate_metrics(benchmark_returns, num_days)
deposit_metrics = self._calculate_metrics(deposit_returns, num_days)
alpha = portfolio_metrics.cumulative_return - benchmark_metrics.cumulative_return
info_ratio = self._calculate_information_ratio(
portfolio_returns, benchmark_returns
)
return BenchmarkCompareResponse(
portfolio_name=portfolio.name,
benchmark_type=benchmark_type,
period=period,
start_date=start_date,
end_date=end_date,
time_series=time_series,
portfolio_metrics=portfolio_metrics,
benchmark_metrics=benchmark_metrics,
deposit_metrics=deposit_metrics,
alpha=round(alpha, 2),
information_ratio=round(info_ratio, 4) if info_ratio is not None else None,
)
def _calculate_metrics(
self, returns: List[float], num_days: int
) -> PerformanceMetrics:
if not returns:
return PerformanceMetrics(
cumulative_return=0.0,
annualized_return=0.0,
sharpe_ratio=None,
max_drawdown=0.0,
)
# Cumulative return
cum = 1.0
for r in returns:
cum *= 1 + r
cum_return = (cum - 1) * 100
# Annualized return
if num_days > 0:
ann_return = (cum ** (365 / num_days) - 1) * 100
else:
ann_return = 0.0
# Sharpe ratio
if len(returns) >= 2:
mean_r = sum(returns) / len(returns)
variance = sum((r - mean_r) ** 2 for r in returns) / (len(returns) - 1)
std_r = math.sqrt(variance)
if std_r > 1e-10:
daily_rf = (1 + RISK_FREE_RATE / 100) ** (1 / 365) - 1
sharpe = (mean_r - daily_rf) / std_r * math.sqrt(252)
sharpe = round(sharpe, 4)
else:
sharpe = None
else:
sharpe = None
# Max drawdown
peak = 1.0
max_dd = 0.0
cum_val = 1.0
for r in returns:
cum_val *= 1 + r
if cum_val > peak:
peak = cum_val
dd = (cum_val - peak) / peak
if dd < max_dd:
max_dd = dd
max_dd_pct = max_dd * 100
return PerformanceMetrics(
cumulative_return=round(cum_return, 2),
annualized_return=round(ann_return, 2),
sharpe_ratio=sharpe,
max_drawdown=round(max_dd_pct, 2),
)
def _calculate_information_ratio(
self, portfolio_returns: List[float], benchmark_returns: List[float]
) -> Optional[float]:
if not portfolio_returns or not benchmark_returns:
return None
min_len = min(len(portfolio_returns), len(benchmark_returns))
excess = [
portfolio_returns[i] - benchmark_returns[i] for i in range(min_len)
]
if len(excess) < 2:
return None
mean_excess = sum(excess) / len(excess)
variance = sum((e - mean_excess) ** 2 for e in excess) / (len(excess) - 1)
tracking_error = math.sqrt(variance)
if tracking_error < 1e-10:
return None
return (mean_excess / tracking_error) * math.sqrt(252)
def _values_to_returns(self, values: List[float]) -> List[float]:
if len(values) < 2:
return []
return [
(values[i] - values[i - 1]) / values[i - 1]
for i in range(1, len(values))
if values[i - 1] != 0
]
def _calc_start_date(
self, period: str, first_snapshot: date, end_date: date
) -> date:
period_map = {
"1m": timedelta(days=30),
"3m": timedelta(days=90),
"6m": timedelta(days=180),
"1y": timedelta(days=365),
}
if period == "all":
return first_snapshot
delta = period_map.get(period, timedelta(days=365))
return max(end_date - delta, first_snapshot)
def _build_time_series(
self,
snapshots: list,
benchmark_data: List[dict],
start_date: date,
end_date: date,
) -> List[TimeSeriesPoint]:
if not snapshots:
return []
base_portfolio = float(snapshots[0].total_value)
portfolio_map = {}
for s in snapshots:
val = float(s.total_value)
ret = ((val / base_portfolio) - 1) * 100 if base_portfolio else 0
portfolio_map[s.snapshot_date] = ret
benchmark_map = {}
if benchmark_data:
base_bench = benchmark_data[0]["close"]
for d in benchmark_data:
ret = ((d["close"] / base_bench) - 1) * 100 if base_bench else 0
benchmark_map[d["date"]] = ret
daily_deposit_rate = DEPOSIT_ANNUAL_RATE / 100 / 365
all_dates = sorted(set(list(portfolio_map.keys()) + list(benchmark_map.keys())))
result = []
for d in all_dates:
days_elapsed = (d - start_date).days
deposit_ret = ((1 + DEPOSIT_ANNUAL_RATE / 100) ** (days_elapsed / 365) - 1) * 100
result.append(TimeSeriesPoint(
date=d,
portfolio_return=round(portfolio_map[d], 2) if d in portfolio_map else None,
benchmark_return=round(benchmark_map[d], 2) if d in benchmark_map else None,
deposit_return=round(deposit_ret, 2),
))
return result

View File

@ -0,0 +1,188 @@
"""
Correlation analysis service.
Calculates inter-stock correlation matrices and portfolio diversification scores.
"""
import logging
from datetime import date, timedelta
from typing import List, Optional
import numpy as np
import pandas as pd
from sqlalchemy.orm import Session
from app.models.stock import Price
from app.models.portfolio import Portfolio, PortfolioSnapshot
logger = logging.getLogger(__name__)
class CorrelationService:
def __init__(self, db: Session):
self.db = db
def calculate_correlation_matrix(
self, stock_codes: List[str], period_days: int = 60
) -> dict:
if not stock_codes:
return {"stock_codes": [], "matrix": []}
end_date = date.today()
start_date = end_date - timedelta(days=period_days)
prices = (
self.db.query(Price)
.filter(
Price.ticker.in_(stock_codes),
Price.date >= start_date,
Price.date <= end_date,
)
.order_by(Price.date)
.all()
)
returns_df = self._prices_to_returns_df(prices, stock_codes)
if returns_df.empty or len(returns_df) < 2:
n = len(stock_codes)
matrix = [[None if i != j else 1.0 for j in range(n)] for i in range(n)]
return {"stock_codes": stock_codes, "matrix": matrix}
corr_matrix = returns_df.corr()
matrix = []
for code in stock_codes:
row = []
for other in stock_codes:
if code in corr_matrix.columns and other in corr_matrix.columns:
val = corr_matrix.loc[code, other]
row.append(round(float(val), 4) if not np.isnan(val) else None)
else:
row.append(None if code != other else 1.0)
matrix.append(row)
return {"stock_codes": stock_codes, "matrix": matrix}
def calculate_portfolio_diversification(self, portfolio_id: int) -> float:
portfolio = (
self.db.query(Portfolio)
.filter(Portfolio.id == portfolio_id)
.first()
)
if not portfolio:
raise ValueError("Portfolio not found")
snapshot = (
self.db.query(PortfolioSnapshot)
.filter(PortfolioSnapshot.portfolio_id == portfolio_id)
.order_by(PortfolioSnapshot.snapshot_date.desc())
.first()
)
if not snapshot or not snapshot.holdings:
return 1.0
holdings = snapshot.holdings
if len(holdings) == 1:
return 0.0
tickers = [h.ticker for h in holdings]
total_value = sum(float(h.value) for h in holdings)
if total_value == 0:
return 1.0
weights = np.array([float(h.value) / total_value for h in holdings])
end_date = date.today()
start_date = end_date - timedelta(days=60)
prices = (
self.db.query(Price)
.filter(
Price.ticker.in_(tickers),
Price.date >= start_date,
Price.date <= end_date,
)
.order_by(Price.date)
.all()
)
returns_df = self._prices_to_returns_df(prices, tickers)
if returns_df.empty or len(returns_df) < 2:
return 0.5
cov_matrix = returns_df.cov().values
stds = returns_df.std().values
# Portfolio variance
portfolio_variance = weights @ cov_matrix @ weights
# Weighted average variance (no diversification case)
weighted_avg_variance = np.sum((weights ** 2) * (stds ** 2)) + \
2 * np.sum([
weights[i] * weights[j] * stds[i] * stds[j]
for i in range(len(weights))
for j in range(i + 1, len(weights))
])
if weighted_avg_variance < 1e-10:
return 1.0
# Diversification ratio: 1 - (portfolio_vol / weighted_avg_vol)
portfolio_vol = np.sqrt(portfolio_variance)
weighted_avg_vol = np.sum(weights * stds)
if weighted_avg_vol < 1e-10:
return 1.0
diversification_ratio = 1.0 - (portfolio_vol / weighted_avg_vol)
return round(float(np.clip(diversification_ratio, 0, 1)), 4)
def get_correlation_data(
self, stock_codes: List[str], period_days: int = 60
) -> dict:
result = self.calculate_correlation_matrix(stock_codes, period_days)
high_pairs = []
codes = result["stock_codes"]
matrix = result["matrix"]
for i in range(len(codes)):
for j in range(i + 1, len(codes)):
val = matrix[i][j]
if val is not None and abs(val) > 0.7:
high_pairs.append({
"stock_a": codes[i],
"stock_b": codes[j],
"correlation": val,
})
result["high_correlation_pairs"] = high_pairs
return result
def _prices_to_returns_df(
self, prices: list, stock_codes: List[str]
) -> pd.DataFrame:
if not prices:
return pd.DataFrame()
data = {}
for p in prices:
if p.ticker not in data:
data[p.ticker] = {}
data[p.ticker][p.date] = float(p.close)
if not data:
return pd.DataFrame()
df = pd.DataFrame(data)
df.index = pd.to_datetime(df.index)
df = df.sort_index()
# Reorder columns to match requested order
existing = [c for c in stock_codes if c in df.columns]
df = df[existing]
returns_df = df.pct_change().dropna()
return returns_df

View File

@ -0,0 +1,164 @@
"""
Drawdown calculation service using PortfolioSnapshot.total_value time series.
"""
import logging
from datetime import date
from decimal import Decimal
from typing import Optional
from sqlalchemy.orm import Session
from app.models.portfolio import Portfolio, PortfolioSnapshot
logger = logging.getLogger(__name__)
# In-memory per-portfolio settings (no separate table needed)
_drawdown_settings: dict[int, Decimal] = {}
DEFAULT_ALERT_THRESHOLD = Decimal("20")
def get_alert_threshold(portfolio_id: int) -> Decimal:
return _drawdown_settings.get(portfolio_id, DEFAULT_ALERT_THRESHOLD)
def set_alert_threshold(portfolio_id: int, threshold_pct: Decimal) -> None:
_drawdown_settings[portfolio_id] = threshold_pct
def calculate_drawdown(
db: Session,
portfolio_id: int,
) -> dict:
"""Calculate current and max drawdown from snapshot time series.
Returns dict with:
current_drawdown_pct, max_drawdown_pct,
peak_value, peak_date, trough_value, trough_date,
max_drawdown_date, alert_threshold_pct
"""
snapshots = (
db.query(PortfolioSnapshot)
.filter(PortfolioSnapshot.portfolio_id == portfolio_id)
.order_by(PortfolioSnapshot.snapshot_date)
.all()
)
if not snapshots:
return {
"current_drawdown_pct": Decimal("0"),
"max_drawdown_pct": Decimal("0"),
"peak_value": None,
"peak_date": None,
"trough_value": None,
"trough_date": None,
"max_drawdown_date": None,
"alert_threshold_pct": get_alert_threshold(portfolio_id),
}
peak = Decimal(str(snapshots[0].total_value))
peak_date = snapshots[0].snapshot_date
max_dd = Decimal("0")
max_dd_date: Optional[date] = None
trough_value = peak
trough_date = peak_date
for snap in snapshots:
value = Decimal(str(snap.total_value))
if value > peak:
peak = value
peak_date = snap.snapshot_date
if peak > 0:
dd = ((peak - value) / peak * 100).quantize(Decimal("0.01"))
else:
dd = Decimal("0")
if dd > max_dd:
max_dd = dd
max_dd_date = snap.snapshot_date
trough_value = value
trough_date = snap.snapshot_date
# Current drawdown = drawdown of last snapshot from running peak
last_value = Decimal(str(snapshots[-1].total_value))
if peak > 0:
current_dd = ((peak - last_value) / peak * 100).quantize(Decimal("0.01"))
else:
current_dd = Decimal("0")
return {
"current_drawdown_pct": current_dd,
"max_drawdown_pct": max_dd,
"peak_value": peak,
"peak_date": peak_date,
"trough_value": trough_value,
"trough_date": trough_date,
"max_drawdown_date": max_dd_date,
"alert_threshold_pct": get_alert_threshold(portfolio_id),
}
def calculate_rolling_drawdown(
db: Session,
portfolio_id: int,
) -> list[dict]:
"""Calculate rolling drawdown time series.
Returns list of {date, total_value, peak, drawdown_pct}.
"""
snapshots = (
db.query(PortfolioSnapshot)
.filter(PortfolioSnapshot.portfolio_id == portfolio_id)
.order_by(PortfolioSnapshot.snapshot_date)
.all()
)
if not snapshots:
return []
result = []
peak = Decimal("0")
for snap in snapshots:
value = Decimal(str(snap.total_value))
if value > peak:
peak = value
if peak > 0:
dd = ((peak - value) / peak * 100).quantize(Decimal("0.01"))
else:
dd = Decimal("0")
result.append({
"date": snap.snapshot_date,
"total_value": value,
"peak": peak,
"drawdown_pct": dd,
})
return result
def check_drawdown_alert(
db: Session,
portfolio_id: int,
) -> Optional[str]:
"""Check if current drawdown exceeds alert threshold.
Returns alert message string if threshold exceeded, None otherwise.
"""
data = calculate_drawdown(db, portfolio_id)
threshold = data["alert_threshold_pct"]
current_dd = data["current_drawdown_pct"]
if current_dd >= threshold:
portfolio = db.query(Portfolio).filter(Portfolio.id == portfolio_id).first()
name = portfolio.name if portfolio else f"Portfolio #{portfolio_id}"
return (
f"[Drawdown 경고] {name}: "
f"현재 낙폭 {current_dd}%가 한도 {threshold}%를 초과했습니다. "
f"(고점: {data['peak_value']:,.0f}원, 현재: {data['trough_value']:,.0f}원)"
)
return None

View File

@ -0,0 +1,132 @@
"""
Notification service for sending signal alerts via Discord/Telegram.
"""
import logging
from datetime import datetime, timedelta
import httpx
from sqlalchemy.orm import Session
from app.models.signal import Signal
from app.models.notification import (
NotificationSetting, NotificationHistory,
ChannelType, NotificationStatus,
)
logger = logging.getLogger(__name__)
async def send_discord(webhook_url: str, message: str) -> None:
"""Send a message to Discord via webhook."""
async with httpx.AsyncClient(timeout=10) as client:
response = await client.post(
webhook_url,
json={"content": message},
)
response.raise_for_status()
async def send_telegram(webhook_url: str, message: str) -> None:
"""Send a message to Telegram via Bot API.
webhook_url format: https://api.telegram.org/bot<TOKEN>/sendMessage?chat_id=<CHAT_ID>
"""
async with httpx.AsyncClient(timeout=10) as client:
response = await client.post(
webhook_url,
json={"text": message, "parse_mode": "Markdown"},
)
response.raise_for_status()
def format_signal_message(signal: Signal) -> str:
"""Format a signal into a human-readable notification message."""
signal_type_kr = {
"buy": "매수",
"sell": "매도",
"partial_sell": "부분매도",
}
type_label = signal_type_kr.get(signal.signal_type.value, signal.signal_type.value)
lines = [
f"[KJB 신호] {signal.name or signal.ticker} ({signal.ticker})",
f"신호: {type_label}",
]
if signal.entry_price:
lines.append(f"진입가: {signal.entry_price:,.0f}")
if signal.target_price:
lines.append(f"목표가: {signal.target_price:,.0f}")
if signal.stop_loss_price:
lines.append(f"손절가: {signal.stop_loss_price:,.0f}")
if signal.reason:
lines.append(f"사유: {signal.reason}")
return "\n".join(lines)
def _is_duplicate(db: Session, signal_id: int, channel_type: ChannelType) -> bool:
"""Check if a notification was already sent for this signal+channel within 24h."""
cutoff = datetime.utcnow() - timedelta(hours=24)
existing = (
db.query(NotificationHistory)
.filter(
NotificationHistory.signal_id == signal_id,
NotificationHistory.channel_type == channel_type,
NotificationHistory.status == NotificationStatus.SENT,
NotificationHistory.sent_at >= cutoff,
)
.first()
)
return existing is not None
async def send_notification(signal: Signal, db: Session) -> None:
"""Send notification for a signal to all enabled channels.
Skips duplicate notifications (same signal_id + channel_type within 24h).
"""
settings = (
db.query(NotificationSetting)
.filter(NotificationSetting.enabled.is_(True))
.all()
)
message = format_signal_message(signal)
for setting in settings:
if _is_duplicate(db, signal.id, setting.channel_type):
logger.info(
f"Skipping duplicate notification for signal {signal.id} "
f"on {setting.channel_type.value}"
)
continue
history = NotificationHistory(
signal_id=signal.id,
channel_type=setting.channel_type,
message=message,
)
try:
if setting.channel_type == ChannelType.DISCORD:
await send_discord(setting.webhook_url, message)
elif setting.channel_type == ChannelType.TELEGRAM:
await send_telegram(setting.webhook_url, message)
history.status = NotificationStatus.SENT
logger.info(
f"Notification sent for signal {signal.id} "
f"via {setting.channel_type.value}"
)
except Exception as e:
history.status = NotificationStatus.FAILED
history.error_message = str(e)[:500]
logger.error(
f"Failed to send notification for signal {signal.id} "
f"via {setting.channel_type.value}: {e}"
)
db.add(history)
db.commit()

View File

@ -0,0 +1,446 @@
"""
Grid-search strategy optimizer.
Runs backtests across parameter combinations and ranks by selected metric.
Reuses existing BacktestEngine / DailyBacktestEngine logic without DB persistence.
"""
import itertools
import logging
from dataclasses import asdict
from datetime import date
from decimal import Decimal
from typing import Any, Dict, List, Tuple
from sqlalchemy.orm import Session
from app.schemas.optimizer import (
DEFAULT_GRIDS,
OptimizeRequest,
OptimizeResponse,
OptimizeResultItem,
)
from app.services.backtest.metrics import MetricsCalculator
logger = logging.getLogger(__name__)
def _expand_grid(param_grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
"""Expand parameter grid into list of parameter dicts."""
keys = list(param_grid.keys())
values = list(param_grid.values())
combos = []
for combo in itertools.product(*values):
combos.append(dict(zip(keys, combo)))
return combos
def _build_strategy_params(strategy_type: str, flat_params: Dict[str, Any]) -> Dict[str, Any]:
"""Convert flat grid params to nested strategy_params dict."""
result: Dict[str, Any] = {}
for key, value in flat_params.items():
parts = key.split(".")
target = result
for part in parts[:-1]:
if part not in target:
target[part] = {}
target = target[part]
target[parts[-1]] = value
return result
class OptimizerService:
"""Grid-search optimizer that runs backtests across parameter combinations."""
def __init__(self, db: Session):
self.db = db
def optimize(self, request: OptimizeRequest) -> OptimizeResponse:
grid = request.param_grid or DEFAULT_GRIDS.get(request.strategy_type, {})
if not grid:
raise ValueError(f"No parameter grid for strategy type: {request.strategy_type}")
combinations = _expand_grid(grid)
logger.info(
f"Optimizer: {request.strategy_type}, {len(combinations)} combinations"
)
results: List[Tuple[Dict[str, Any], Dict[str, float]]] = []
for combo in combinations:
try:
metrics = self._run_single(request, combo)
results.append((combo, metrics))
except Exception as e:
logger.warning(f"Optimizer: failed for params {combo}: {e}")
if not results:
return OptimizeResponse(
strategy_type=request.strategy_type,
total_combinations=len(combinations),
results=[],
best_params={},
)
# Sort by rank_by metric (descending, except mdd which is negative so also desc)
rank_by = request.rank_by
results.sort(key=lambda x: x[1].get(rank_by, 0), reverse=True)
items = []
for i, (combo, metrics) in enumerate(results, 1):
items.append(OptimizeResultItem(
rank=i,
params=combo,
total_return=Decimal(str(metrics["total_return"])),
cagr=Decimal(str(metrics["cagr"])),
mdd=Decimal(str(metrics["mdd"])),
sharpe_ratio=Decimal(str(metrics["sharpe_ratio"])),
volatility=Decimal(str(metrics["volatility"])),
benchmark_return=Decimal(str(metrics["benchmark_return"])),
excess_return=Decimal(str(metrics["excess_return"])),
))
return OptimizeResponse(
strategy_type=request.strategy_type,
total_combinations=len(combinations),
results=items,
best_params=items[0].params if items else {},
)
def _run_single(
self, request: OptimizeRequest, flat_params: Dict[str, Any]
) -> Dict[str, float]:
"""Run a single backtest with given params, return metrics dict."""
strategy_params = _build_strategy_params(request.strategy_type, flat_params)
if request.strategy_type == "kjb":
return self._run_kjb(request, strategy_params, flat_params)
else:
return self._run_factor(request, strategy_params)
def _run_kjb(
self,
request: OptimizeRequest,
strategy_params: Dict[str, Any],
flat_params: Dict[str, Any],
) -> Dict[str, float]:
"""Run KJB daily backtest in-memory."""
import pandas as pd
from app.models.stock import Stock, Price
from app.services.backtest.trading_portfolio import TradingPortfolio
from app.services.strategy.kjb import KJBSignalGenerator
signal_gen = KJBSignalGenerator()
portfolio = TradingPortfolio(
initial_capital=request.initial_capital,
max_positions=strategy_params.get("max_positions", 10),
cash_reserve_ratio=Decimal(str(strategy_params.get("cash_reserve_ratio", 0.3))),
stop_loss_pct=Decimal(str(flat_params.get("stop_loss_pct", 0.03))),
target1_pct=Decimal(str(flat_params.get("target1_pct", 0.05))),
target2_pct=Decimal(str(flat_params.get("target2_pct", 0.10))),
)
rs_lookback = flat_params.get("rs_lookback", 10)
breakout_lookback = flat_params.get("breakout_lookback", 20)
trading_days = self._get_trading_days(request.start_date, request.end_date)
if not trading_days:
raise ValueError("No trading days found")
universe_tickers = self._get_universe_tickers()
all_prices = self._load_all_prices(universe_tickers, request.start_date, request.end_date)
stock_dfs = self._build_stock_dfs(all_prices, universe_tickers)
kospi_df = self._load_kospi_df(request.start_date, request.end_date)
benchmark_prices = self._load_benchmark_prices(request.benchmark, request.start_date, request.end_date)
day_prices_map: Dict[date, Dict[str, Decimal]] = {}
for p in all_prices:
if p.date not in day_prices_map:
day_prices_map[p.date] = {}
day_prices_map[p.date][p.ticker] = p.close
equity_curve: List[Dict] = []
initial_benchmark = benchmark_prices.get(trading_days[0], Decimal("1"))
if initial_benchmark == 0:
initial_benchmark = Decimal("1")
for trading_date in trading_days:
day_prices = day_prices_map.get(trading_date, {})
portfolio.check_exits(
date=trading_date,
prices=day_prices,
commission_rate=request.commission_rate,
slippage_rate=request.slippage_rate,
)
for ticker in universe_tickers:
if ticker in portfolio.positions:
continue
if ticker not in stock_dfs or ticker not in day_prices:
continue
stock_df = stock_dfs[ticker]
if trading_date not in stock_df.index:
continue
hist = stock_df.loc[stock_df.index <= trading_date]
if len(hist) < 21:
continue
kospi_hist = kospi_df.loc[kospi_df.index <= trading_date]
if len(kospi_hist) < 11:
continue
signals = signal_gen.generate_signals(
hist, kospi_hist,
rs_lookback=rs_lookback,
breakout_lookback=breakout_lookback,
)
if trading_date in signals.index and signals.loc[trading_date, "buy"]:
portfolio.enter_position(
ticker=ticker,
price=day_prices[ticker],
date=trading_date,
commission_rate=request.commission_rate,
slippage_rate=request.slippage_rate,
)
portfolio_value = portfolio.get_value(day_prices)
benchmark_value = benchmark_prices.get(trading_date, initial_benchmark)
normalized_benchmark = benchmark_value / initial_benchmark * request.initial_capital
equity_curve.append({
"portfolio_value": portfolio_value,
"benchmark_value": normalized_benchmark,
})
return self._compute_metrics(equity_curve)
def _run_factor(
self,
request: OptimizeRequest,
strategy_params: Dict[str, Any],
) -> Dict[str, float]:
"""Run factor-based backtest in-memory (multi_factor, quality, value_momentum)."""
from dateutil.relativedelta import relativedelta
from app.models.backtest import RebalancePeriod
from app.services.backtest.portfolio import VirtualPortfolio
from app.services.strategy import (
MultiFactorStrategy,
QualityStrategy,
ValueMomentumStrategy,
)
from app.schemas.strategy import UniverseFilter, FactorWeights
strategy = self._create_strategy(
request.strategy_type, strategy_params, request.top_n,
)
portfolio = VirtualPortfolio(request.initial_capital)
trading_days = self._get_trading_days(request.start_date, request.end_date)
if not trading_days:
raise ValueError("No trading days found")
rebalance_dates = self._generate_rebalance_dates(
request.start_date, request.end_date, RebalancePeriod.QUARTERLY,
)
benchmark_prices = self._load_benchmark_prices(
request.benchmark, request.start_date, request.end_date,
)
all_date_prices = self._load_all_prices_by_date(
request.start_date, request.end_date,
)
names = self._get_stock_names()
initial_benchmark = benchmark_prices.get(trading_days[0], Decimal("1"))
if initial_benchmark == 0:
initial_benchmark = Decimal("1")
equity_curve: List[Dict] = []
for trading_date in trading_days:
prices = all_date_prices.get(trading_date, {})
if trading_date in rebalance_dates:
target_stocks = strategy.run(
universe_filter=UniverseFilter(),
top_n=request.top_n,
base_date=trading_date,
)
target_tickers = [s.ticker for s in target_stocks.stocks]
portfolio.rebalance(
target_tickers=target_tickers,
prices=prices,
names=names,
commission_rate=request.commission_rate,
slippage_rate=request.slippage_rate,
)
portfolio_value = portfolio.get_value(prices)
benchmark_value = benchmark_prices.get(trading_date, initial_benchmark)
normalized_benchmark = benchmark_value / initial_benchmark * request.initial_capital
equity_curve.append({
"portfolio_value": portfolio_value,
"benchmark_value": normalized_benchmark,
})
return self._compute_metrics(equity_curve)
def _compute_metrics(self, equity_curve: List[Dict]) -> Dict[str, float]:
portfolio_values = [Decimal(str(e["portfolio_value"])) for e in equity_curve]
benchmark_values = [Decimal(str(e["benchmark_value"])) for e in equity_curve]
metrics = MetricsCalculator.calculate_all(portfolio_values, benchmark_values)
return {
"total_return": float(metrics.total_return),
"cagr": float(metrics.cagr),
"mdd": float(metrics.mdd),
"sharpe_ratio": float(metrics.sharpe_ratio),
"volatility": float(metrics.volatility),
"benchmark_return": float(metrics.benchmark_return),
"excess_return": float(metrics.excess_return),
}
def _create_strategy(self, strategy_type: str, strategy_params: dict, top_n: int):
from app.services.strategy import (
MultiFactorStrategy,
QualityStrategy,
ValueMomentumStrategy,
)
from app.schemas.strategy import FactorWeights
if strategy_type == "multi_factor":
strategy = MultiFactorStrategy(self.db)
strategy._weights = FactorWeights(**strategy_params.get("weights", {}))
elif strategy_type == "quality":
strategy = QualityStrategy(self.db)
strategy._min_fscore = strategy_params.get("min_fscore", 7)
elif strategy_type == "value_momentum":
strategy = ValueMomentumStrategy(self.db)
strategy._value_weight = Decimal(
str(strategy_params.get("value_weight", 0.5))
)
strategy._momentum_weight = Decimal(
str(strategy_params.get("momentum_weight", 0.5))
)
else:
raise ValueError(f"Unknown strategy type: {strategy_type}")
return strategy
# --- Data loading helpers (mirrored from engines) ---
def _get_trading_days(self, start_date: date, end_date: date) -> List[date]:
from app.models.stock import Price
prices = (
self.db.query(Price.date)
.filter(Price.date >= start_date, Price.date <= end_date)
.distinct()
.order_by(Price.date)
.all()
)
return [p[0] for p in prices]
def _get_universe_tickers(self) -> List[str]:
from app.models.stock import Stock
stocks = (
self.db.query(Stock)
.filter(Stock.market_cap.isnot(None))
.order_by(Stock.market_cap.desc())
.limit(30)
.all()
)
return [s.ticker for s in stocks]
def _load_all_prices(self, tickers, start_date, end_date):
from app.models.stock import Price
return (
self.db.query(Price)
.filter(Price.ticker.in_(tickers))
.filter(Price.date >= start_date, Price.date <= end_date)
.all()
)
def _load_kospi_df(self, start_date, end_date):
import pandas as pd
from app.models.stock import Price
prices = (
self.db.query(Price)
.filter(Price.ticker == "069500")
.filter(Price.date >= start_date, Price.date <= end_date)
.order_by(Price.date)
.all()
)
if not prices:
return pd.DataFrame(columns=["close"])
data = [{"date": p.date, "close": float(p.close)} for p in prices]
return pd.DataFrame(data).set_index("date")
def _load_benchmark_prices(self, benchmark, start_date, end_date):
from app.models.stock import Price
prices = (
self.db.query(Price)
.filter(Price.ticker == "069500")
.filter(Price.date >= start_date, Price.date <= end_date)
.all()
)
return {p.date: p.close for p in prices}
def _build_stock_dfs(self, price_data, tickers):
import pandas as pd
ticker_rows = {t: [] for t in tickers}
for p in price_data:
if p.ticker in ticker_rows:
ticker_rows[p.ticker].append({
"date": p.date,
"open": float(p.open),
"high": float(p.high),
"low": float(p.low),
"close": float(p.close),
"volume": int(p.volume),
})
result = {}
for ticker, rows in ticker_rows.items():
if rows:
df = pd.DataFrame(rows).set_index("date").sort_index()
result[ticker] = df
return result
def _load_all_prices_by_date(self, start_date, end_date):
from app.models.stock import Price
prices = (
self.db.query(Price)
.filter(Price.date >= start_date, Price.date <= end_date)
.all()
)
result = {}
for p in prices:
if p.date not in result:
result[p.date] = {}
result[p.date][p.ticker] = p.close
return result
def _get_stock_names(self):
from app.models.stock import Stock
stocks = self.db.query(Stock.ticker, Stock.name).all()
return {s.ticker: s.name for s in stocks}
def _generate_rebalance_dates(self, start_date, end_date, period):
from dateutil.relativedelta import relativedelta
from app.models.backtest import RebalancePeriod
dates = []
current = start_date
if period == RebalancePeriod.MONTHLY:
delta = relativedelta(months=1)
elif period == RebalancePeriod.QUARTERLY:
delta = relativedelta(months=3)
elif period == RebalancePeriod.SEMI_ANNUAL:
delta = relativedelta(months=6)
else:
delta = relativedelta(years=1)
while current <= end_date:
dates.append(current)
current = current + delta
return dates

View File

@ -0,0 +1,209 @@
"""
Pension asset allocation service.
Korean retirement pension regulations:
- DC/IRP: risky assets max 70%, safe assets min 30%
- Personal pension: no regulatory limit (but we apply same guideline)
- Safe assets: bond funds, deposits, TDF, principal-guaranteed products
- Risky assets: equity funds, equity ETFs, hybrid funds
"""
from datetime import date
from decimal import Decimal
from typing import List
from app.schemas.pension import (
AllocationItem,
AllocationResult,
RecommendationItem,
RecommendationResult,
)
# Regulatory limits
RISKY_ASSET_LIMIT_PCT = Decimal("70")
SAFE_ASSET_MIN_PCT = Decimal("30")
# Glide path parameters (equity allocation decreases with age)
GLIDE_PATH_MAX_EQUITY = Decimal("80") # max equity at young age
GLIDE_PATH_MIN_EQUITY = Decimal("20") # min equity near retirement
def calculate_current_age(birth_year: int) -> int:
return date.today().year - birth_year
def calculate_years_to_retirement(birth_year: int, target_retirement_age: int) -> int:
current_age = calculate_current_age(birth_year)
return max(0, target_retirement_age - current_age)
def calculate_glide_path(birth_year: int, target_retirement_age: int) -> tuple[Decimal, Decimal]:
"""Calculate equity/bond ratio based on age (glide path).
Linear interpolation: young = high equity, near retirement = low equity.
Working age range: 25 ~ target_retirement_age.
"""
current_age = calculate_current_age(birth_year)
working_years = target_retirement_age - 25
if working_years <= 0:
equity_pct = GLIDE_PATH_MIN_EQUITY
else:
years_worked = min(max(current_age - 25, 0), working_years)
progress = Decimal(years_worked) / Decimal(working_years)
equity_pct = GLIDE_PATH_MAX_EQUITY - progress * (GLIDE_PATH_MAX_EQUITY - GLIDE_PATH_MIN_EQUITY)
# Clamp to regulatory limit
equity_pct = min(equity_pct, RISKY_ASSET_LIMIT_PCT)
equity_pct = max(equity_pct, GLIDE_PATH_MIN_EQUITY)
bond_pct = Decimal("100") - equity_pct
return equity_pct.quantize(Decimal("0.01")), bond_pct.quantize(Decimal("0.01"))
def calculate_allocation(
account_id: int,
account_type: str,
total_amount: Decimal,
birth_year: int,
target_retirement_age: int,
) -> AllocationResult:
"""Calculate recommended asset allocation for a pension account."""
equity_pct, bond_pct = calculate_glide_path(birth_year, target_retirement_age)
current_age = calculate_current_age(birth_year)
years_to_ret = calculate_years_to_retirement(birth_year, target_retirement_age)
equity_amount = (total_amount * equity_pct / Decimal("100")).quantize(Decimal("0.01"))
bond_amount = total_amount - equity_amount
allocations: List[AllocationItem] = []
# Split equity portion
if equity_pct > 0:
allocations.append(AllocationItem(
asset_name="국내 주식형 ETF",
asset_type="risky",
amount=float((equity_amount * Decimal("0.5")).quantize(Decimal("0.01"))),
ratio=float((equity_pct * Decimal("0.5")).quantize(Decimal("0.01"))),
))
allocations.append(AllocationItem(
asset_name="해외 주식형 ETF",
asset_type="risky",
amount=float((equity_amount * Decimal("0.5")).quantize(Decimal("0.01"))),
ratio=float((equity_pct * Decimal("0.5")).quantize(Decimal("0.01"))),
))
# Split bond portion
if bond_pct > 0:
tdf_ratio = Decimal("0.4")
bond_etf_ratio = Decimal("0.3")
deposit_ratio = Decimal("0.3")
allocations.append(AllocationItem(
asset_name=f"TDF {_recommend_tdf_year(birth_year, target_retirement_age)}",
asset_type="safe",
amount=float((bond_amount * tdf_ratio).quantize(Decimal("0.01"))),
ratio=float((bond_pct * tdf_ratio).quantize(Decimal("0.01"))),
))
allocations.append(AllocationItem(
asset_name="국내 채권형 ETF",
asset_type="safe",
amount=float((bond_amount * bond_etf_ratio).quantize(Decimal("0.01"))),
ratio=float((bond_pct * bond_etf_ratio).quantize(Decimal("0.01"))),
))
allocations.append(AllocationItem(
asset_name="원리금 보장 예금",
asset_type="safe",
amount=float((bond_amount * deposit_ratio).quantize(Decimal("0.01"))),
ratio=float((bond_pct * deposit_ratio).quantize(Decimal("0.01"))),
))
return AllocationResult(
account_id=account_id,
account_type=account_type,
total_amount=float(total_amount),
risky_limit_pct=float(RISKY_ASSET_LIMIT_PCT),
safe_min_pct=float(SAFE_ASSET_MIN_PCT),
glide_path_equity_pct=float(equity_pct),
glide_path_bond_pct=float(bond_pct),
current_age=current_age,
years_to_retirement=years_to_ret,
allocations=allocations,
)
def _recommend_tdf_year(birth_year: int, target_retirement_age: int) -> int:
"""Recommend TDF target year (rounded to nearest 5)."""
retirement_year = birth_year + target_retirement_age
return round(retirement_year / 5) * 5
def get_recommendation(
account_id: int,
birth_year: int,
target_retirement_age: int,
) -> RecommendationResult:
"""Generate TDF/ETF recommendations based on age and retirement target."""
equity_pct, bond_pct = calculate_glide_path(birth_year, target_retirement_age)
current_age = calculate_current_age(birth_year)
years_to_ret = calculate_years_to_retirement(birth_year, target_retirement_age)
tdf_year = _recommend_tdf_year(birth_year, target_retirement_age)
recommendations: List[RecommendationItem] = []
# TDF recommendation
recommendations.append(RecommendationItem(
asset_name=f"TDF {tdf_year}",
asset_type="safe",
category="tdf",
ratio=float((bond_pct * Decimal("0.4")).quantize(Decimal("0.01"))),
reason=f"은퇴 목표 시점({tdf_year}년)에 맞춘 자동 자산 배분 펀드",
))
# Bond ETF
recommendations.append(RecommendationItem(
asset_name="KODEX 국고채 10년",
asset_type="safe",
category="bond_etf",
ratio=float((bond_pct * Decimal("0.3")).quantize(Decimal("0.01"))),
reason="안정적인 국고채 장기 투자로 원금 보전",
))
# Deposit
recommendations.append(RecommendationItem(
asset_name="원리금 보장 예금",
asset_type="safe",
category="deposit",
ratio=float((bond_pct * Decimal("0.3")).quantize(Decimal("0.01"))),
reason="원리금 보장으로 안전자산 비중 확보",
))
# Equity ETFs
domestic_equity_ratio = (equity_pct * Decimal("0.5")).quantize(Decimal("0.01"))
foreign_equity_ratio = (equity_pct * Decimal("0.5")).quantize(Decimal("0.01"))
recommendations.append(RecommendationItem(
asset_name="KODEX 200",
asset_type="risky",
category="equity_etf",
ratio=float(domestic_equity_ratio),
reason="국내 대형주 분산 투자 (KOSPI 200 추종)",
))
recommendations.append(RecommendationItem(
asset_name="TIGER 미국 S&P500",
asset_type="risky",
category="equity_etf",
ratio=float(foreign_equity_ratio),
reason="미국 대형주 분산 투자 (S&P 500 추종)",
))
return RecommendationResult(
account_id=account_id,
birth_year=birth_year,
current_age=current_age,
target_retirement_age=target_retirement_age,
years_to_retirement=years_to_ret,
glide_path_equity_pct=float(equity_pct),
glide_path_bond_pct=float(bond_pct),
recommendations=recommendations,
)

View File

@ -0,0 +1,107 @@
"""
Position sizing module.
Supports three methods:
- fixed_ratio: Equal allocation (quant.md default, 30% cash reserve)
- kelly_criterion: Kelly criterion (conservative 1/4 Kelly)
- atr_based: ATR-based volatility sizing
"""
def fixed_ratio(
capital: float,
num_positions: int,
cash_ratio: float = 0.3,
) -> dict:
"""Equal allocation across positions with cash reserve.
Per quant.md: 5-10 positions, 30% cash, max loss per position -3%.
"""
if capital <= 0:
raise ValueError("capital must be positive")
if num_positions <= 0:
raise ValueError("num_positions must be positive")
if cash_ratio < 0 or cash_ratio >= 1.0:
raise ValueError("cash_ratio must be in [0, 1)")
investable = capital * (1 - cash_ratio)
position_size = investable / num_positions
risk_amount = position_size * 0.03 # -3% max loss per position (quant.md)
return {
"method": "fixed",
"position_size": position_size,
"shares": 0, # needs price to calculate
"risk_amount": risk_amount,
"notes": (
f"균등 분배: 투자금 {investable:,.0f}원을 {num_positions}개 종목에 배분. "
f"종목당 최대 손실 -3% = {risk_amount:,.0f}원."
),
}
def kelly_criterion(
win_rate: float,
avg_win: float,
avg_loss: float,
fraction: float = 0.25,
) -> dict:
"""Kelly criterion position sizing (default: conservative 1/4 Kelly).
Kelly% = W - (1-W)/R where W=win_rate, R=avg_win/avg_loss
"""
if win_rate <= 0 or win_rate > 1.0:
raise ValueError("win_rate must be in (0, 1]")
if avg_win <= 0:
raise ValueError("avg_win must be positive")
if avg_loss <= 0:
raise ValueError("avg_loss must be positive")
if fraction <= 0 or fraction > 1.0:
raise ValueError("fraction must be in (0, 1]")
win_loss_ratio = avg_win / avg_loss
kelly_pct = win_rate - (1 - win_rate) / win_loss_ratio
position_size = max(kelly_pct * fraction, 0.0)
return {
"method": "kelly",
"position_size": position_size,
"shares": 0,
"risk_amount": position_size * avg_loss,
"notes": (
f"켈리 기준: Full Kelly = {kelly_pct:.2%}, "
f"{fraction:.0%} Kelly = {position_size:.2%}. "
f"승률 {win_rate:.0%}, 평균 수익 {avg_win:.1%}, 평균 손실 {avg_loss:.1%}."
),
}
def atr_based(
capital: float,
atr: float,
risk_pct: float = 0.02,
) -> dict:
"""ATR-based volatility position sizing.
Shares = (Capital * Risk%) / ATR
"""
if capital <= 0:
raise ValueError("capital must be positive")
if atr <= 0:
raise ValueError("atr must be positive")
if risk_pct <= 0 or risk_pct > 1.0:
raise ValueError("risk_pct must be in (0, 1]")
risk_amount = capital * risk_pct
shares = int(risk_amount / atr)
return {
"method": "atr",
"position_size": risk_amount,
"shares": shares,
"risk_amount": risk_amount,
"notes": (
f"ATR 사이징: 자본금의 {risk_pct:.1%} = {risk_amount:,.0f}원 위험 허용. "
f"ATR {atr:,.0f} 기준 {shares}주."
),
}

View File

@ -0,0 +1,108 @@
"""
Tax simulation service for Korean retirement pension (퇴직연금) tax benefits.
"""
IRP_ANNUAL_LIMIT = 9_000_000 # DC+IRP 합산 연간 한도
INCOME_THRESHOLD = 55_000_000 # 총급여 기준
LOW_INCOME_RATE = 16.5 # 5,500만원 이하 공제율
HIGH_INCOME_RATE = 13.2 # 5,500만원 초과 공제율
LUMP_SUM_TAX_RATE = 16.5 # 기타소득세 (일시금)
def _get_pension_tax_rate(age: int) -> float:
if age >= 80:
return 3.3
elif age >= 70:
return 4.4
else:
return 5.5
def calculate_tax_deduction(
annual_income: int,
contribution: int,
account_type: str,
) -> dict:
deduction_rate = LOW_INCOME_RATE if annual_income <= INCOME_THRESHOLD else HIGH_INCOME_RATE
deductible = min(contribution, IRP_ANNUAL_LIMIT)
tax_deduction = deductible * (deduction_rate / 100)
return {
"annual_income": annual_income,
"contribution": contribution,
"account_type": account_type,
"deduction_rate": deduction_rate,
"irp_limit": IRP_ANNUAL_LIMIT,
"deductible_contribution": deductible,
"tax_deduction": tax_deduction,
}
def calculate_pension_tax(
withdrawal_amount: int,
withdrawal_type: str,
age: int,
) -> dict:
pension_tax_rate = _get_pension_tax_rate(age)
pension_tax = withdrawal_amount * (pension_tax_rate / 100)
lump_sum_tax = withdrawal_amount * (LUMP_SUM_TAX_RATE / 100)
tax_saving = lump_sum_tax - pension_tax
return {
"withdrawal_amount": withdrawal_amount,
"withdrawal_type": withdrawal_type,
"age": age,
"pension_tax_rate": pension_tax_rate,
"pension_tax": pension_tax,
"lump_sum_tax_rate": LUMP_SUM_TAX_RATE,
"lump_sum_tax": lump_sum_tax,
"tax_saving": tax_saving,
}
def simulate_accumulation(
monthly_contribution: int,
years: int,
annual_return: float,
tax_deduction_rate: float,
) -> dict:
annual_contribution = monthly_contribution * 12
monthly_return = (1 + annual_return / 100) ** (1 / 12) - 1
yearly_data = []
current_value = 0.0
cumulative_contribution = 0
cumulative_tax_deduction = 0.0
for year in range(1, years + 1):
for _ in range(12):
current_value = current_value * (1 + monthly_return) + monthly_contribution
cumulative_contribution += annual_contribution
deductible = min(annual_contribution, IRP_ANNUAL_LIMIT)
yearly_tax_deduction = deductible * (tax_deduction_rate / 100)
cumulative_tax_deduction += yearly_tax_deduction
yearly_data.append({
"year": year,
"contribution": annual_contribution,
"cumulative_contribution": cumulative_contribution,
"investment_value": round(current_value, 0),
"tax_deduction": yearly_tax_deduction,
"cumulative_tax_deduction": cumulative_tax_deduction,
})
total_contribution = annual_contribution * years
total_return = round(current_value - total_contribution, 0)
return {
"monthly_contribution": monthly_contribution,
"years": years,
"annual_return": annual_return,
"tax_deduction_rate": tax_deduction_rate,
"total_contribution": total_contribution,
"final_value": round(current_value, 0),
"total_return": total_return,
"total_tax_deduction": cumulative_tax_deduction,
"yearly_data": yearly_data,
}

View File

@ -1,6 +1,7 @@
"""
Daily KJB signal generation job.
"""
import asyncio
import logging
from datetime import date, timedelta
@ -10,6 +11,7 @@ from app.core.database import SessionLocal
from app.models.stock import Stock, Price
from app.models.signal import Signal, SignalType, SignalStatus
from app.services.strategy.kjb import KJBSignalGenerator
from app.services.notification import send_notification
logger = logging.getLogger(__name__)
@ -103,6 +105,19 @@ def run_kjb_signals():
db.commit()
logger.info(f"KJB signal generation complete: {signals_created} buy signals")
# Send notifications for newly created signals
if signals_created > 0:
new_signals = (
db.query(Signal)
.filter(Signal.date == today, Signal.status == SignalStatus.ACTIVE)
.all()
)
for sig in new_signals:
try:
asyncio.run(send_notification(sig, db))
except Exception as e:
logger.error(f"Notification failed for signal {sig.id}: {e}")
except Exception as e:
logger.exception(f"KJB signal generation failed: {e}")
finally:

View File

@ -0,0 +1,166 @@
"""
Unit tests for benchmark service.
"""
from datetime import date, timedelta
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pandas as pd
import pytest
from app.services.benchmark import BenchmarkService
@pytest.fixture
def db():
return MagicMock()
@pytest.fixture
def service(db):
return BenchmarkService(db)
class TestGetDepositRate:
def test_returns_fixed_rate(self, service):
rate = service.get_deposit_rate()
assert rate == 3.5
class TestGetBenchmarkData:
@patch("app.services.benchmark.pykrx_stock")
def test_returns_kospi_time_series(self, mock_pykrx, service):
dates = pd.date_range("2025-01-02", periods=3, freq="B")
mock_pykrx.get_index_ohlcv.return_value = pd.DataFrame(
{"시가": [2800, 2810, 2820], "종가": [2810, 2820, 2830]},
index=dates,
)
result = service.get_benchmark_data(
"kospi", date(2025, 1, 2), date(2025, 1, 6)
)
assert len(result) == 3
assert result[0]["date"] == dates[0].date()
assert result[0]["close"] == 2810
@patch("app.services.benchmark.pykrx_stock")
def test_empty_data_returns_empty_list(self, mock_pykrx, service):
mock_pykrx.get_index_ohlcv.return_value = pd.DataFrame()
result = service.get_benchmark_data(
"kospi", date(2025, 1, 2), date(2025, 1, 6)
)
assert result == []
class TestCalculateMetrics:
def test_cumulative_return(self, service):
returns = [0.01, 0.02, -0.005, 0.015]
metrics = service._calculate_metrics(returns, num_days=120)
expected_cum = ((1.01) * (1.02) * (0.995) * (1.015) - 1) * 100
assert abs(metrics.cumulative_return - expected_cum) < 0.01
def test_max_drawdown(self, service):
returns = [0.10, -0.20, 0.05]
metrics = service._calculate_metrics(returns, num_days=90)
assert metrics.max_drawdown < 0
def test_sharpe_ratio_with_zero_std(self, service):
returns = [0.01, 0.01, 0.01]
metrics = service._calculate_metrics(returns, num_days=90)
assert metrics.sharpe_ratio is None
def test_empty_returns(self, service):
metrics = service._calculate_metrics([], num_days=0)
assert metrics.cumulative_return == 0.0
assert metrics.annualized_return == 0.0
assert metrics.max_drawdown == 0.0
class TestCompareWithBenchmark:
def _make_snapshot(self, snapshot_date, total_value):
snap = MagicMock()
snap.snapshot_date = snapshot_date
snap.total_value = Decimal(str(total_value))
return snap
@patch("app.services.benchmark.pykrx_stock")
def test_compare_basic(self, mock_pykrx, service, db):
base = date(2025, 1, 2)
snapshots = [
self._make_snapshot(base, 10000),
self._make_snapshot(base + timedelta(days=30), 10500),
self._make_snapshot(base + timedelta(days=60), 10800),
]
portfolio = MagicMock()
portfolio.id = 1
portfolio.name = "Test Portfolio"
portfolio.user_id = 1
db.query.return_value.filter.return_value.first.return_value = portfolio
db.query.return_value.filter.return_value.order_by.return_value.all.return_value = snapshots
dates = pd.date_range(base.strftime("%Y%m%d"), periods=3, freq="30D")
mock_pykrx.get_index_ohlcv.return_value = pd.DataFrame(
{"시가": [2800, 2850, 2900], "종가": [2810, 2860, 2910]},
index=dates,
)
result = service.compare_with_benchmark(
portfolio_id=1, benchmark_type="kospi", period="all", user_id=1
)
assert result.portfolio_name == "Test Portfolio"
assert result.benchmark_type == "kospi"
assert result.alpha is not None
assert len(result.time_series) > 0
@patch("app.services.benchmark.pykrx_stock")
def test_compare_not_found(self, mock_pykrx, service, db):
db.query.return_value.filter.return_value.first.return_value = None
with pytest.raises(ValueError, match="Portfolio not found"):
service.compare_with_benchmark(
portfolio_id=999, benchmark_type="kospi", period="1y", user_id=1
)
@patch("app.services.benchmark.pykrx_stock")
def test_compare_no_snapshots(self, mock_pykrx, service, db):
portfolio = MagicMock()
portfolio.id = 1
portfolio.name = "Empty"
portfolio.user_id = 1
db.query.return_value.filter.return_value.first.return_value = portfolio
db.query.return_value.filter.return_value.order_by.return_value.all.return_value = []
with pytest.raises(ValueError, match="스냅샷 데이터가 없습니다"):
service.compare_with_benchmark(
portfolio_id=1, benchmark_type="kospi", period="1y", user_id=1
)
class TestCalculateInformationRatio:
def test_positive_tracking_error(self, service):
portfolio_returns = [0.02, 0.03, -0.01, 0.04]
benchmark_returns = [0.01, 0.02, -0.005, 0.02]
ir = service._calculate_information_ratio(portfolio_returns, benchmark_returns)
assert ir is not None
assert isinstance(ir, float)
def test_zero_tracking_error(self, service):
same_returns = [0.01, 0.02, 0.03]
ir = service._calculate_information_ratio(same_returns, same_returns)
assert ir is None
def test_empty_returns(self, service):
ir = service._calculate_information_ratio([], [])
assert ir is None

View File

@ -0,0 +1,220 @@
"""
Unit tests for correlation analysis service.
"""
from datetime import date, timedelta
from decimal import Decimal
from unittest.mock import MagicMock, patch
import numpy as np
import pandas as pd
import pytest
from app.services.correlation import CorrelationService
@pytest.fixture
def db():
return MagicMock()
@pytest.fixture
def service(db):
return CorrelationService(db)
class TestCalculateCorrelationMatrix:
def _make_prices(self, ticker: str, dates: list, closes: list):
prices = []
for d, c in zip(dates, closes):
p = MagicMock()
p.ticker = ticker
p.date = d
p.close = Decimal(str(c))
prices.append(p)
return prices
def test_two_stocks_positive_correlation(self, service, db):
dates = [date(2025, 1, i) for i in range(1, 11)]
prices_a = self._make_prices("A", dates, [100, 102, 104, 103, 105, 107, 106, 108, 110, 112])
prices_b = self._make_prices("B", dates, [50, 51, 52, 51.5, 52.5, 53.5, 53, 54, 55, 56])
db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices_a + prices_b
result = service.calculate_correlation_matrix(["A", "B"], period_days=60)
assert "A" in result["stock_codes"]
assert "B" in result["stock_codes"]
assert len(result["matrix"]) == 2
assert len(result["matrix"][0]) == 2
# Diagonal should be 1.0
assert result["matrix"][0][0] == pytest.approx(1.0, abs=0.01)
assert result["matrix"][1][1] == pytest.approx(1.0, abs=0.01)
# These stocks move together, correlation should be high
assert result["matrix"][0][1] > 0.5
def test_single_stock_returns_identity(self, service, db):
dates = [date(2025, 1, i) for i in range(1, 6)]
prices = self._make_prices("A", dates, [100, 102, 101, 103, 105])
db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices
result = service.calculate_correlation_matrix(["A"], period_days=60)
assert result["matrix"] == [[1.0]]
def test_empty_stock_codes(self, service, db):
result = service.calculate_correlation_matrix([], period_days=60)
assert result["stock_codes"] == []
assert result["matrix"] == []
def test_insufficient_data_returns_nan_as_none(self, service, db):
dates = [date(2025, 1, 1)]
prices_a = self._make_prices("A", dates, [100])
prices_b = self._make_prices("B", dates, [50])
db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices_a + prices_b
result = service.calculate_correlation_matrix(["A", "B"], period_days=60)
# With only 1 data point, no returns can be calculated
assert result["matrix"][0][1] is None
class TestCalculatePortfolioDiversification:
def _make_holding(self, ticker: str, value: float, ratio: float):
h = MagicMock()
h.ticker = ticker
h.value = Decimal(str(value))
h.current_ratio = Decimal(str(ratio))
return h
def _make_prices(self, ticker: str, dates: list, closes: list):
prices = []
for d, c in zip(dates, closes):
p = MagicMock()
p.ticker = ticker
p.date = d
p.close = Decimal(str(c))
prices.append(p)
return prices
def test_diversified_portfolio(self, service, db):
"""Low correlation stocks -> high diversification score."""
dates = [date(2025, 1, i) for i in range(1, 21)]
np.random.seed(42)
closes_a = np.cumsum(np.random.randn(20)) + 100
closes_b = np.cumsum(np.random.randn(20)) + 200
prices = (
self._make_prices("A", dates, closes_a.tolist()) +
self._make_prices("B", dates, closes_b.tolist())
)
portfolio = MagicMock()
portfolio.id = 1
portfolio.user_id = 1
snapshot = MagicMock()
snapshot.holdings = [
self._make_holding("A", 5000, 50),
self._make_holding("B", 5000, 50),
]
# DB query chain
db.query.return_value.filter.return_value.first.return_value = portfolio
db.query.return_value.filter.return_value.order_by.return_value.first.return_value = snapshot
db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices
score = service.calculate_portfolio_diversification(portfolio_id=1)
assert 0 <= score <= 1
def test_portfolio_not_found(self, service, db):
db.query.return_value.filter.return_value.first.return_value = None
with pytest.raises(ValueError, match="Portfolio not found"):
service.calculate_portfolio_diversification(portfolio_id=999)
def test_no_holdings(self, service, db):
portfolio = MagicMock()
portfolio.id = 1
snapshot = MagicMock()
snapshot.holdings = []
db.query.return_value.filter.return_value.first.return_value = portfolio
db.query.return_value.filter.return_value.order_by.return_value.first.return_value = snapshot
score = service.calculate_portfolio_diversification(portfolio_id=1)
assert score == 1.0
def test_single_holding(self, service, db):
dates = [date(2025, 1, i) for i in range(1, 11)]
prices = self._make_prices("A", dates, [100 + i for i in range(10)])
portfolio = MagicMock()
portfolio.id = 1
snapshot = MagicMock()
snapshot.holdings = [self._make_holding("A", 10000, 100)]
db.query.return_value.filter.return_value.first.return_value = portfolio
db.query.return_value.filter.return_value.order_by.return_value.first.return_value = snapshot
db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices
score = service.calculate_portfolio_diversification(portfolio_id=1)
# Single stock = no diversification benefit, score should be 0
assert score == 0.0
class TestGetCorrelationData:
def _make_prices(self, ticker: str, dates: list, closes: list):
prices = []
for d, c in zip(dates, closes):
p = MagicMock()
p.ticker = ticker
p.date = d
p.close = Decimal(str(c))
prices.append(p)
return prices
def test_heatmap_data_structure(self, service, db):
dates = [date(2025, 1, i) for i in range(1, 11)]
prices_a = self._make_prices("A", dates, [100, 102, 104, 103, 105, 107, 106, 108, 110, 112])
prices_b = self._make_prices("B", dates, [50, 51, 52, 51.5, 52.5, 53.5, 53, 54, 55, 56])
db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices_a + prices_b
result = service.get_correlation_data(["A", "B"], period_days=60)
assert "stock_codes" in result
assert "matrix" in result
assert "high_correlation_pairs" in result
# high_correlation_pairs should have pairs with corr > 0.7
for pair in result["high_correlation_pairs"]:
assert "stock_a" in pair
assert "stock_b" in pair
assert "correlation" in pair
assert abs(pair["correlation"]) > 0.7
def test_no_high_correlation_pairs_when_uncorrelated(self, service, db):
dates = [date(2025, 1, i) for i in range(1, 21)]
np.random.seed(123)
closes_a = np.cumsum(np.random.randn(20)) + 100
closes_b = np.cumsum(np.random.randn(20)) + 200
prices = (
self._make_prices("A", dates, closes_a.tolist()) +
self._make_prices("B", dates, closes_b.tolist())
)
db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices
result = service.get_correlation_data(["A", "B"], period_days=60)
# random walks are unlikely to have > 0.7 correlation
high_pairs = [p for p in result["high_correlation_pairs"] if abs(p["correlation"]) > 0.7]
# This is probabilistic but with seed 123 they should be uncorrelated
assert len(result["high_correlation_pairs"]) >= 0 # may or may not have pairs

View File

@ -0,0 +1,278 @@
"""
Tests for drawdown service and API endpoints.
"""
import pytest
from datetime import date
from decimal import Decimal
from app.models.portfolio import Portfolio, PortfolioSnapshot
from app.services.drawdown import (
calculate_drawdown,
calculate_rolling_drawdown,
check_drawdown_alert,
get_alert_threshold,
set_alert_threshold,
DEFAULT_ALERT_THRESHOLD,
_drawdown_settings,
)
# --- Helper ---
def _create_portfolio_with_snapshots(db, user_id, values_and_dates):
"""Create a portfolio with snapshot time series."""
portfolio = Portfolio(
user_id=user_id,
name="테스트 포트폴리오",
portfolio_type="general",
)
db.add(portfolio)
db.flush()
for snap_date, total_value in values_and_dates:
snapshot = PortfolioSnapshot(
portfolio_id=portfolio.id,
total_value=Decimal(str(total_value)),
snapshot_date=snap_date,
)
db.add(snapshot)
db.commit()
db.refresh(portfolio)
return portfolio
# --- calculate_drawdown tests ---
class TestCalculateDrawdown:
def test_no_snapshots(self, db, test_user):
portfolio = Portfolio(user_id=test_user.id, name="빈 포트폴리오")
db.add(portfolio)
db.commit()
result = calculate_drawdown(db, portfolio.id)
assert result["current_drawdown_pct"] == Decimal("0")
assert result["max_drawdown_pct"] == Decimal("0")
assert result["peak_value"] is None
def test_no_drawdown_monotonic_increase(self, db, test_user):
portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
(date(2025, 1, 1), 1_000_000),
(date(2025, 2, 1), 1_100_000),
(date(2025, 3, 1), 1_200_000),
])
result = calculate_drawdown(db, portfolio.id)
assert result["current_drawdown_pct"] == Decimal("0")
assert result["max_drawdown_pct"] == Decimal("0")
assert result["peak_value"] == Decimal("1200000")
def test_simple_drawdown(self, db, test_user):
portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
(date(2025, 1, 1), 1_000_000),
(date(2025, 2, 1), 1_200_000), # peak
(date(2025, 3, 1), 1_080_000), # -10%
])
result = calculate_drawdown(db, portfolio.id)
assert result["current_drawdown_pct"] == Decimal("10.00")
assert result["max_drawdown_pct"] == Decimal("10.00")
assert result["peak_value"] == Decimal("1200000")
assert result["peak_date"] == date(2025, 2, 1)
assert result["trough_value"] == Decimal("1080000")
def test_recovery_after_drawdown(self, db, test_user):
portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
(date(2025, 1, 1), 1_000_000),
(date(2025, 2, 1), 1_200_000), # peak
(date(2025, 3, 1), 960_000), # -20% (max dd)
(date(2025, 4, 1), 1_300_000), # new peak, recovery
])
result = calculate_drawdown(db, portfolio.id)
assert result["current_drawdown_pct"] == Decimal("0")
assert result["max_drawdown_pct"] == Decimal("20.00")
assert result["peak_value"] == Decimal("1300000")
assert result["max_drawdown_date"] == date(2025, 3, 1)
def test_multiple_drawdowns_picks_worst(self, db, test_user):
portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
(date(2025, 1, 1), 1_000_000),
(date(2025, 2, 1), 900_000), # -10%
(date(2025, 3, 1), 1_100_000), # new peak
(date(2025, 4, 1), 880_000), # -20% from 1.1M
])
result = calculate_drawdown(db, portfolio.id)
assert result["max_drawdown_pct"] == Decimal("20.00")
assert result["current_drawdown_pct"] == Decimal("20.00")
# --- calculate_rolling_drawdown tests ---
class TestCalculateRollingDrawdown:
def test_empty_snapshots(self, db, test_user):
portfolio = Portfolio(user_id=test_user.id, name="빈 포트폴리오")
db.add(portfolio)
db.commit()
result = calculate_rolling_drawdown(db, portfolio.id)
assert result == []
def test_rolling_drawdown_series(self, db, test_user):
portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
(date(2025, 1, 1), 1_000_000),
(date(2025, 2, 1), 1_200_000),
(date(2025, 3, 1), 1_080_000),
])
result = calculate_rolling_drawdown(db, portfolio.id)
assert len(result) == 3
# First point: no drawdown
assert result[0]["drawdown_pct"] == Decimal("0")
assert result[0]["peak"] == Decimal("1000000")
# Second point: new peak, no drawdown
assert result[1]["drawdown_pct"] == Decimal("0")
assert result[1]["peak"] == Decimal("1200000")
# Third point: drawdown from peak
assert result[2]["drawdown_pct"] == Decimal("10.00")
assert result[2]["peak"] == Decimal("1200000")
# --- check_drawdown_alert tests ---
class TestCheckDrawdownAlert:
def setup_method(self):
_drawdown_settings.clear()
def test_no_alert_under_threshold(self, db, test_user):
portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
(date(2025, 1, 1), 1_000_000),
(date(2025, 2, 1), 950_000), # -5%
])
result = check_drawdown_alert(db, portfolio.id)
assert result is None
def test_alert_above_default_threshold(self, db, test_user):
portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
(date(2025, 1, 1), 1_000_000),
(date(2025, 2, 1), 780_000), # -22%
])
result = check_drawdown_alert(db, portfolio.id)
assert result is not None
assert "Drawdown 경고" in result
assert "테스트 포트폴리오" in result
def test_alert_with_custom_threshold(self, db, test_user):
portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
(date(2025, 1, 1), 1_000_000),
(date(2025, 2, 1), 900_000), # -10%
])
set_alert_threshold(portfolio.id, Decimal("5"))
result = check_drawdown_alert(db, portfolio.id)
assert result is not None
assert "경고" in result
def test_no_alert_empty_portfolio(self, db, test_user):
portfolio = Portfolio(user_id=test_user.id, name="빈 포트폴리오")
db.add(portfolio)
db.commit()
result = check_drawdown_alert(db, portfolio.id)
assert result is None
# --- settings tests ---
class TestDrawdownSettings:
def setup_method(self):
_drawdown_settings.clear()
def test_default_threshold(self):
assert get_alert_threshold(999) == DEFAULT_ALERT_THRESHOLD
def test_set_and_get_threshold(self):
set_alert_threshold(1, Decimal("15"))
assert get_alert_threshold(1) == Decimal("15")
def test_different_portfolios_independent(self):
set_alert_threshold(1, Decimal("10"))
set_alert_threshold(2, Decimal("25"))
assert get_alert_threshold(1) == Decimal("10")
assert get_alert_threshold(2) == Decimal("25")
# --- API endpoint tests ---
class TestDrawdownAPI:
def _create_portfolio_via_api(self, client, auth_headers):
resp = client.post(
"/api/portfolios",
headers=auth_headers,
json={"name": "테스트", "portfolio_type": "general"},
)
return resp.json()["id"]
def _add_snapshots(self, db, portfolio_id, values_and_dates):
for snap_date, total_value in values_and_dates:
snapshot = PortfolioSnapshot(
portfolio_id=portfolio_id,
total_value=Decimal(str(total_value)),
snapshot_date=snap_date,
)
db.add(snapshot)
db.commit()
def test_get_drawdown(self, client, auth_headers, db):
pid = self._create_portfolio_via_api(client, auth_headers)
self._add_snapshots(db, pid, [
(date(2025, 1, 1), 1_000_000),
(date(2025, 2, 1), 1_200_000),
(date(2025, 3, 1), 1_080_000),
])
resp = client.get(f"/api/drawdown/{pid}", headers=auth_headers)
assert resp.status_code == 200
data = resp.json()
assert data["portfolio_id"] == pid
assert data["current_drawdown_pct"] == 10.0
assert data["max_drawdown_pct"] == 10.0
def test_get_drawdown_history(self, client, auth_headers, db):
pid = self._create_portfolio_via_api(client, auth_headers)
self._add_snapshots(db, pid, [
(date(2025, 1, 1), 1_000_000),
(date(2025, 2, 1), 1_200_000),
(date(2025, 3, 1), 1_080_000),
])
resp = client.get(f"/api/drawdown/{pid}/history", headers=auth_headers)
assert resp.status_code == 200
data = resp.json()
assert len(data["data"]) == 3
assert data["max_drawdown_pct"] == 10.0
def test_update_settings(self, client, auth_headers, db):
pid = self._create_portfolio_via_api(client, auth_headers)
resp = client.put(
f"/api/drawdown/settings/{pid}",
headers=auth_headers,
json={"alert_threshold_pct": 15.0},
)
assert resp.status_code == 200
assert resp.json()["alert_threshold_pct"] == 15.0
def test_drawdown_nonexistent_portfolio(self, client, auth_headers):
resp = client.get("/api/drawdown/9999", headers=auth_headers)
assert resp.status_code == 404
def test_unauthenticated_access(self, client):
resp = client.get("/api/drawdown/1")
assert resp.status_code == 401

View File

@ -0,0 +1,235 @@
"""
Tests for trading journal models and API endpoints.
"""
import pytest
from datetime import date, datetime
from decimal import Decimal
# --- API endpoint tests ---
class TestJournalAPI:
def _create_journal(self, client, auth_headers, **overrides):
payload = {
"stock_code": "005930",
"stock_name": "삼성전자",
"trade_type": "buy",
"entry_price": 72000,
"target_price": 75600,
"stop_loss_price": 69840,
"entry_date": "2026-03-20",
"quantity": 10,
"entry_reason": "KJB 매수 신호 - 돌파 패턴",
"scenario": "목표가 75,600 도달 시 전량 매도, 손절가 69,840 이탈 시 손절",
**overrides,
}
return client.post("/api/journal", headers=auth_headers, json=payload)
def test_create_journal(self, client, auth_headers):
response = self._create_journal(client, auth_headers)
assert response.status_code == 201
data = response.json()
assert data["stock_code"] == "005930"
assert data["stock_name"] == "삼성전자"
assert data["trade_type"] == "buy"
assert data["entry_price"] == 72000
assert data["status"] == "open"
assert data["entry_reason"] == "KJB 매수 신호 - 돌파 패턴"
assert data["scenario"] is not None
def test_create_journal_minimal(self, client, auth_headers):
response = client.post(
"/api/journal",
headers=auth_headers,
json={
"stock_code": "035420",
"trade_type": "sell",
"entry_date": "2026-03-20",
},
)
assert response.status_code == 201
data = response.json()
assert data["stock_code"] == "035420"
assert data["trade_type"] == "sell"
assert data["status"] == "open"
def test_list_journals(self, client, auth_headers):
self._create_journal(client, auth_headers, stock_code="005930")
self._create_journal(client, auth_headers, stock_code="035420", stock_name="NAVER")
response = client.get("/api/journal", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert len(data) == 2
def test_list_journals_filter_by_status(self, client, auth_headers):
self._create_journal(client, auth_headers)
response = client.get("/api/journal?status=open", headers=auth_headers)
assert response.status_code == 200
assert len(response.json()) == 1
response = client.get("/api/journal?status=closed", headers=auth_headers)
assert response.status_code == 200
assert len(response.json()) == 0
def test_list_journals_filter_by_stock_code(self, client, auth_headers):
self._create_journal(client, auth_headers, stock_code="005930")
self._create_journal(client, auth_headers, stock_code="035420", stock_name="NAVER")
response = client.get("/api/journal?stock_code=005930", headers=auth_headers)
assert response.status_code == 200
assert len(response.json()) == 1
assert response.json()[0]["stock_code"] == "005930"
def test_list_journals_filter_by_date_range(self, client, auth_headers):
self._create_journal(client, auth_headers, entry_date="2026-03-15")
self._create_journal(client, auth_headers, entry_date="2026-03-25")
response = client.get(
"/api/journal?start_date=2026-03-20&end_date=2026-03-31",
headers=auth_headers,
)
assert response.status_code == 200
assert len(response.json()) == 1
def test_get_journal(self, client, auth_headers):
create_resp = self._create_journal(client, auth_headers)
journal_id = create_resp.json()["id"]
response = client.get(f"/api/journal/{journal_id}", headers=auth_headers)
assert response.status_code == 200
assert response.json()["id"] == journal_id
def test_get_journal_not_found(self, client, auth_headers):
response = client.get("/api/journal/9999", headers=auth_headers)
assert response.status_code == 404
def test_update_journal_add_exit(self, client, auth_headers):
create_resp = self._create_journal(client, auth_headers)
journal_id = create_resp.json()["id"]
response = client.put(
f"/api/journal/{journal_id}",
headers=auth_headers,
json={
"exit_price": 75600,
"exit_date": "2026-03-28",
"exit_reason": "목표가 도달",
"lessons_learned": "시나리오대로 진행된 좋은 거래",
},
)
assert response.status_code == 200
data = response.json()
assert data["exit_price"] == 75600
assert data["status"] == "closed"
assert data["profit_loss"] is not None
assert data["profit_loss_pct"] is not None
# Buy: (75600 - 72000) / 72000 * 100 = 5.0
assert abs(data["profit_loss_pct"] - 5.0) < 0.01
# Profit: (75600 - 72000) * 10 = 36000
assert abs(data["profit_loss"] - 36000) < 1
def test_update_journal_sell_pnl_calculation(self, client, auth_headers):
create_resp = self._create_journal(
client, auth_headers,
trade_type="sell",
entry_price=75000,
quantity=5,
)
journal_id = create_resp.json()["id"]
response = client.put(
f"/api/journal/{journal_id}",
headers=auth_headers,
json={
"exit_price": 72000,
"exit_date": "2026-03-28",
},
)
data = response.json()
# Sell: (75000 - 72000) / 75000 * 100 = 4.0
assert abs(data["profit_loss_pct"] - 4.0) < 0.01
# Profit: (75000 - 72000) * 5 = 15000
assert abs(data["profit_loss"] - 15000) < 1
def test_update_journal_not_found(self, client, auth_headers):
response = client.put(
"/api/journal/9999",
headers=auth_headers,
json={"lessons_learned": "test"},
)
assert response.status_code == 404
def test_get_stats_empty(self, client, auth_headers):
response = client.get("/api/journal/stats", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert data["total_trades"] == 0
assert data["win_rate"] is None
def test_get_stats_with_data(self, client, auth_headers):
# Create and close a winning trade
r1 = self._create_journal(client, auth_headers, entry_price=72000, quantity=10)
client.put(
f"/api/journal/{r1.json()['id']}",
headers=auth_headers,
json={"exit_price": 75600, "exit_date": "2026-03-28"},
)
# Create and close a losing trade
r2 = self._create_journal(
client, auth_headers,
stock_code="035420",
stock_name="NAVER",
entry_price=50000,
quantity=5,
)
client.put(
f"/api/journal/{r2.json()['id']}",
headers=auth_headers,
json={"exit_price": 48000, "exit_date": "2026-03-28"},
)
# Create an open trade
self._create_journal(client, auth_headers, stock_code="000660")
response = client.get("/api/journal/stats", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert data["total_trades"] == 3
assert data["open_trades"] == 1
assert data["closed_trades"] == 2
assert data["win_count"] == 1
assert data["loss_count"] == 1
assert data["win_rate"] == 50.0
assert data["total_profit_loss"] is not None
def test_unauthenticated_access(self, client):
response = client.get("/api/journal")
assert response.status_code == 401
def test_user_isolation(self, client, auth_headers, db):
"""User can only see their own journals."""
from app.models.user import User
from app.core.security import get_password_hash, create_access_token
# Create another user's journal
other_user = User(
username="otheruser",
email="other@example.com",
hashed_password=get_password_hash("password"),
)
db.add(other_user)
db.commit()
db.refresh(other_user)
other_token = create_access_token(data={"sub": other_user.username})
other_headers = {"Authorization": f"Bearer {other_token}"}
self._create_journal(client, other_headers, stock_code="000660")
# Current user should see 0 journals
response = client.get("/api/journal", headers=auth_headers)
assert response.status_code == 200
assert len(response.json()) == 0

View File

@ -0,0 +1,326 @@
"""
Tests for notification service, models, and API endpoints.
"""
import pytest
from datetime import datetime, timedelta
from decimal import Decimal
from unittest.mock import AsyncMock, patch
from app.models.signal import Signal, SignalType, SignalStatus
from app.models.notification import (
NotificationSetting, NotificationHistory,
ChannelType, NotificationStatus,
)
from app.services.notification import (
format_signal_message,
_is_duplicate,
send_notification,
)
# --- format_signal_message tests ---
class TestFormatSignalMessage:
def test_buy_signal_full(self):
signal = Signal(
ticker="005930",
name="삼성전자",
signal_type=SignalType.BUY,
entry_price=Decimal("72000"),
target_price=Decimal("75600"),
stop_loss_price=Decimal("69840"),
reason="breakout, large_candle",
)
msg = format_signal_message(signal)
assert "삼성전자" in msg
assert "005930" in msg
assert "매수" in msg
assert "72,000" in msg
assert "75,600" in msg
assert "69,840" in msg
assert "breakout" in msg
def test_sell_signal_minimal(self):
signal = Signal(
ticker="035420",
name=None,
signal_type=SignalType.SELL,
entry_price=None,
target_price=None,
stop_loss_price=None,
reason=None,
)
msg = format_signal_message(signal)
assert "035420" in msg
assert "매도" in msg
assert "진입가" not in msg
assert "목표가" not in msg
def test_partial_sell_signal(self):
signal = Signal(
ticker="000660",
name="SK하이닉스",
signal_type=SignalType.PARTIAL_SELL,
entry_price=Decimal("150000"),
)
msg = format_signal_message(signal)
assert "부분매도" in msg
assert "150,000" in msg
# --- _is_duplicate tests ---
class TestIsDuplicate:
def test_no_duplicate_when_empty(self, db):
assert _is_duplicate(db, signal_id=999, channel_type=ChannelType.DISCORD) is False
def test_duplicate_within_24h(self, db):
history = NotificationHistory(
signal_id=1,
channel_type=ChannelType.DISCORD,
status=NotificationStatus.SENT,
sent_at=datetime.utcnow() - timedelta(hours=1),
message="test",
)
db.add(history)
db.commit()
assert _is_duplicate(db, signal_id=1, channel_type=ChannelType.DISCORD) is True
def test_no_duplicate_after_24h(self, db):
history = NotificationHistory(
signal_id=1,
channel_type=ChannelType.DISCORD,
status=NotificationStatus.SENT,
sent_at=datetime.utcnow() - timedelta(hours=25),
message="test",
)
db.add(history)
db.commit()
assert _is_duplicate(db, signal_id=1, channel_type=ChannelType.DISCORD) is False
def test_failed_notification_not_duplicate(self, db):
history = NotificationHistory(
signal_id=1,
channel_type=ChannelType.DISCORD,
status=NotificationStatus.FAILED,
sent_at=datetime.utcnow(),
message="test",
)
db.add(history)
db.commit()
assert _is_duplicate(db, signal_id=1, channel_type=ChannelType.DISCORD) is False
def test_different_channel_not_duplicate(self, db):
history = NotificationHistory(
signal_id=1,
channel_type=ChannelType.DISCORD,
status=NotificationStatus.SENT,
sent_at=datetime.utcnow(),
message="test",
)
db.add(history)
db.commit()
assert _is_duplicate(db, signal_id=1, channel_type=ChannelType.TELEGRAM) is False
# --- send_notification tests ---
class TestSendNotification:
@pytest.mark.asyncio
async def test_sends_to_enabled_channels(self, db, test_user):
setting = NotificationSetting(
user_id=test_user.id,
channel_type=ChannelType.DISCORD,
webhook_url="https://discord.com/api/webhooks/test",
enabled=True,
)
db.add(setting)
signal = Signal(
id=1,
date=datetime.utcnow().date(),
ticker="005930",
name="삼성전자",
signal_type=SignalType.BUY,
entry_price=Decimal("72000"),
status=SignalStatus.ACTIVE,
)
db.add(signal)
db.commit()
with patch("app.services.notification.send_discord", new_callable=AsyncMock) as mock_discord:
await send_notification(signal, db)
mock_discord.assert_called_once()
history = db.query(NotificationHistory).first()
assert history is not None
assert history.status == NotificationStatus.SENT
assert history.signal_id == 1
@pytest.mark.asyncio
async def test_skips_disabled_channels(self, db, test_user):
setting = NotificationSetting(
user_id=test_user.id,
channel_type=ChannelType.DISCORD,
webhook_url="https://discord.com/api/webhooks/test",
enabled=False,
)
db.add(setting)
signal = Signal(
id=2,
date=datetime.utcnow().date(),
ticker="005930",
name="삼성전자",
signal_type=SignalType.BUY,
status=SignalStatus.ACTIVE,
)
db.add(signal)
db.commit()
with patch("app.services.notification.send_discord", new_callable=AsyncMock) as mock_discord:
await send_notification(signal, db)
mock_discord.assert_not_called()
@pytest.mark.asyncio
async def test_skips_duplicate(self, db, test_user):
setting = NotificationSetting(
user_id=test_user.id,
channel_type=ChannelType.DISCORD,
webhook_url="https://discord.com/api/webhooks/test",
enabled=True,
)
db.add(setting)
signal = Signal(
id=3,
date=datetime.utcnow().date(),
ticker="005930",
name="삼성전자",
signal_type=SignalType.BUY,
status=SignalStatus.ACTIVE,
)
db.add(signal)
# Pre-existing sent notification
history = NotificationHistory(
signal_id=3,
channel_type=ChannelType.DISCORD,
status=NotificationStatus.SENT,
sent_at=datetime.utcnow(),
message="already sent",
)
db.add(history)
db.commit()
with patch("app.services.notification.send_discord", new_callable=AsyncMock) as mock_discord:
await send_notification(signal, db)
mock_discord.assert_not_called()
@pytest.mark.asyncio
async def test_records_failure(self, db, test_user):
setting = NotificationSetting(
user_id=test_user.id,
channel_type=ChannelType.DISCORD,
webhook_url="https://discord.com/api/webhooks/test",
enabled=True,
)
db.add(setting)
signal = Signal(
id=4,
date=datetime.utcnow().date(),
ticker="005930",
name="삼성전자",
signal_type=SignalType.BUY,
status=SignalStatus.ACTIVE,
)
db.add(signal)
db.commit()
with patch(
"app.services.notification.send_discord",
new_callable=AsyncMock,
side_effect=Exception("Connection failed"),
):
await send_notification(signal, db)
history = db.query(NotificationHistory).first()
assert history.status == NotificationStatus.FAILED
assert "Connection failed" in history.error_message
# --- API endpoint tests ---
class TestNotificationAPI:
def test_get_settings_empty(self, client, auth_headers):
response = client.get("/api/notifications/settings", headers=auth_headers)
assert response.status_code == 200
assert response.json() == []
def test_create_setting(self, client, auth_headers):
response = client.post(
"/api/notifications/settings",
headers=auth_headers,
json={
"channel_type": "discord",
"webhook_url": "https://discord.com/api/webhooks/test/token",
"enabled": True,
},
)
assert response.status_code == 201
data = response.json()
assert data["channel_type"] == "discord"
assert data["webhook_url"] == "https://discord.com/api/webhooks/test/token"
assert data["enabled"] is True
def test_create_duplicate_channel_fails(self, client, auth_headers):
payload = {
"channel_type": "discord",
"webhook_url": "https://discord.com/api/webhooks/test/token",
}
client.post("/api/notifications/settings", headers=auth_headers, json=payload)
response = client.post("/api/notifications/settings", headers=auth_headers, json=payload)
assert response.status_code == 400
def test_update_setting(self, client, auth_headers):
create_resp = client.post(
"/api/notifications/settings",
headers=auth_headers,
json={
"channel_type": "discord",
"webhook_url": "https://discord.com/api/webhooks/old",
},
)
setting_id = create_resp.json()["id"]
update_resp = client.put(
f"/api/notifications/settings/{setting_id}",
headers=auth_headers,
json={"webhook_url": "https://discord.com/api/webhooks/new", "enabled": False},
)
assert update_resp.status_code == 200
data = update_resp.json()
assert data["webhook_url"] == "https://discord.com/api/webhooks/new"
assert data["enabled"] is False
def test_update_nonexistent_setting(self, client, auth_headers):
response = client.put(
"/api/notifications/settings/9999",
headers=auth_headers,
json={"enabled": False},
)
assert response.status_code == 404
def test_get_history(self, client, auth_headers):
response = client.get("/api/notifications/history", headers=auth_headers)
assert response.status_code == 200
assert isinstance(response.json(), list)
def test_unauthenticated_access(self, client):
response = client.get("/api/notifications/settings")
assert response.status_code == 401

View File

@ -0,0 +1,281 @@
"""
Tests for the strategy optimizer service and schemas.
"""
import pytest
from datetime import date
from decimal import Decimal
from unittest.mock import MagicMock, patch
from app.schemas.optimizer import (
DEFAULT_GRIDS,
KJB_DEFAULT_GRID,
STRATEGY_TYPES,
OptimizeRequest,
OptimizeResponse,
OptimizeResultItem,
)
from app.services.optimizer import (
OptimizerService,
_expand_grid,
_build_strategy_params,
)
# --- Unit tests for grid expansion ---
class TestExpandGrid:
def test_single_param(self):
grid = {"a": [1, 2, 3]}
result = _expand_grid(grid)
assert len(result) == 3
assert result[0] == {"a": 1}
assert result[2] == {"a": 3}
def test_two_params(self):
grid = {"a": [1, 2], "b": [10, 20]}
result = _expand_grid(grid)
assert len(result) == 4
assert {"a": 1, "b": 10} in result
assert {"a": 2, "b": 20} in result
def test_empty_grid(self):
result = _expand_grid({})
assert len(result) == 1 # single empty combo
assert result[0] == {}
def test_kjb_default_grid_size(self):
result = _expand_grid(KJB_DEFAULT_GRID)
# 3 * 3 * 3 = 27
assert len(result) == 27
class TestBuildStrategyParams:
def test_flat_params(self):
result = _build_strategy_params("kjb", {
"stop_loss_pct": 0.05,
"target1_pct": 0.07,
})
assert result == {"stop_loss_pct": 0.05, "target1_pct": 0.07}
def test_nested_params(self):
result = _build_strategy_params("multi_factor", {
"weights.value": 0.3,
"weights.quality": 0.2,
})
assert result == {"weights": {"value": 0.3, "quality": 0.2}}
def test_deeply_nested(self):
result = _build_strategy_params("test", {
"a.b.c": 1,
})
assert result == {"a": {"b": {"c": 1}}}
# --- Schema tests ---
class TestOptimizeRequest:
def test_defaults(self):
req = OptimizeRequest(
strategy_type="kjb",
start_date=date(2024, 1, 1),
end_date=date(2024, 12, 31),
)
assert req.initial_capital == Decimal("100000000")
assert req.commission_rate == Decimal("0.00015")
assert req.slippage_rate == Decimal("0.001")
assert req.benchmark == "KOSPI"
assert req.top_n == 30
assert req.rank_by == "sharpe_ratio"
assert req.param_grid is None
def test_custom_grid(self):
custom = {"stop_loss_pct": [0.02, 0.04]}
req = OptimizeRequest(
strategy_type="kjb",
start_date=date(2024, 1, 1),
end_date=date(2024, 12, 31),
param_grid=custom,
)
assert req.param_grid == custom
def test_all_strategy_types_valid(self):
for st in STRATEGY_TYPES:
req = OptimizeRequest(
strategy_type=st,
start_date=date(2024, 1, 1),
end_date=date(2024, 12, 31),
)
assert req.strategy_type == st
class TestOptimizeResponse:
def test_response_serialization(self):
item = OptimizeResultItem(
rank=1,
params={"stop_loss_pct": 0.05},
total_return=Decimal("15.5"),
cagr=Decimal("12.3"),
mdd=Decimal("-8.2"),
sharpe_ratio=Decimal("1.45"),
volatility=Decimal("18.7"),
benchmark_return=Decimal("10.0"),
excess_return=Decimal("5.5"),
)
resp = OptimizeResponse(
strategy_type="kjb",
total_combinations=27,
results=[item],
best_params={"stop_loss_pct": 0.05},
)
data = resp.model_dump(mode="json")
assert data["total_combinations"] == 27
assert data["results"][0]["sharpe_ratio"] == 1.45
assert isinstance(data["results"][0]["sharpe_ratio"], float)
# --- Default grids ---
class TestDefaultGrids:
def test_all_strategy_types_have_grids(self):
for st in STRATEGY_TYPES:
assert st in DEFAULT_GRIDS
def test_kjb_grid_keys(self):
assert "stop_loss_pct" in KJB_DEFAULT_GRID
assert "target1_pct" in KJB_DEFAULT_GRID
assert "rs_lookback" in KJB_DEFAULT_GRID
# --- OptimizerService tests with mocked DB ---
class TestOptimizerService:
def _make_service(self):
db = MagicMock()
return OptimizerService(db)
def test_optimize_no_grid_raises_for_unknown_type(self):
service = self._make_service()
req = OptimizeRequest(
strategy_type="unknown_type",
start_date=date(2024, 1, 1),
end_date=date(2024, 12, 31),
)
with pytest.raises(ValueError, match="No parameter grid"):
service.optimize(req)
@patch.object(OptimizerService, "_run_single")
def test_optimize_uses_default_grid(self, mock_run):
mock_run.return_value = {
"total_return": 10.0,
"cagr": 8.0,
"mdd": -5.0,
"sharpe_ratio": 1.2,
"volatility": 15.0,
"benchmark_return": 7.0,
"excess_return": 3.0,
}
service = self._make_service()
req = OptimizeRequest(
strategy_type="kjb",
start_date=date(2024, 1, 1),
end_date=date(2024, 12, 31),
)
result = service.optimize(req)
assert result.strategy_type == "kjb"
assert result.total_combinations == 27 # 3*3*3
assert len(result.results) == 27
assert result.results[0].rank == 1
@patch.object(OptimizerService, "_run_single")
def test_optimize_ranks_by_sharpe(self, mock_run):
mock_run.side_effect = [
{
"total_return": 10.0, "cagr": 8.0, "mdd": -5.0,
"sharpe_ratio": 0.5, "volatility": 15.0,
"benchmark_return": 7.0, "excess_return": 3.0,
},
{
"total_return": 20.0, "cagr": 15.0, "mdd": -10.0,
"sharpe_ratio": 2.0, "volatility": 20.0,
"benchmark_return": 7.0, "excess_return": 13.0,
},
]
service = self._make_service()
req = OptimizeRequest(
strategy_type="kjb",
start_date=date(2024, 1, 1),
end_date=date(2024, 12, 31),
param_grid={"stop_loss_pct": [0.03, 0.05]},
rank_by="sharpe_ratio",
)
result = service.optimize(req)
assert result.total_combinations == 2
assert result.results[0].sharpe_ratio == Decimal("2.0")
assert result.results[1].sharpe_ratio == Decimal("0.5")
assert result.best_params == {"stop_loss_pct": 0.05}
@patch.object(OptimizerService, "_run_single")
def test_optimize_handles_failures_gracefully(self, mock_run):
mock_run.side_effect = [
Exception("data error"),
{
"total_return": 10.0, "cagr": 8.0, "mdd": -5.0,
"sharpe_ratio": 1.0, "volatility": 15.0,
"benchmark_return": 7.0, "excess_return": 3.0,
},
]
service = self._make_service()
req = OptimizeRequest(
strategy_type="kjb",
start_date=date(2024, 1, 1),
end_date=date(2024, 12, 31),
param_grid={"stop_loss_pct": [0.03, 0.05]},
)
result = service.optimize(req)
assert result.total_combinations == 2
assert len(result.results) == 1
@patch.object(OptimizerService, "_run_single")
def test_optimize_all_fail_returns_empty(self, mock_run):
mock_run.side_effect = Exception("fail")
service = self._make_service()
req = OptimizeRequest(
strategy_type="kjb",
start_date=date(2024, 1, 1),
end_date=date(2024, 12, 31),
param_grid={"stop_loss_pct": [0.03]},
)
result = service.optimize(req)
assert result.total_combinations == 1
assert len(result.results) == 0
assert result.best_params == {}
@patch.object(OptimizerService, "_run_single")
def test_optimize_rank_by_cagr(self, mock_run):
mock_run.side_effect = [
{
"total_return": 30.0, "cagr": 25.0, "mdd": -15.0,
"sharpe_ratio": 0.8, "volatility": 25.0,
"benchmark_return": 7.0, "excess_return": 23.0,
},
{
"total_return": 15.0, "cagr": 12.0, "mdd": -5.0,
"sharpe_ratio": 1.5, "volatility": 10.0,
"benchmark_return": 7.0, "excess_return": 8.0,
},
]
service = self._make_service()
req = OptimizeRequest(
strategy_type="quality",
start_date=date(2024, 1, 1),
end_date=date(2024, 12, 31),
param_grid={"min_fscore": [6, 7]},
rank_by="cagr",
)
result = service.optimize(req)
assert result.results[0].cagr == Decimal("25.0")
assert result.results[1].cagr == Decimal("12.0")

View File

@ -0,0 +1,291 @@
"""
Tests for pension account models, service, and API endpoints.
"""
import pytest
from decimal import Decimal
from unittest.mock import patch
from app.services.pension_allocation import (
calculate_current_age,
calculate_years_to_retirement,
calculate_glide_path,
calculate_allocation,
get_recommendation,
RISKY_ASSET_LIMIT_PCT,
SAFE_ASSET_MIN_PCT,
)
# --- Service unit tests ---
class TestPensionAllocationService:
def test_calculate_current_age(self):
with patch("app.services.pension_allocation.date") as mock_date:
mock_date.today.return_value.year = 2026
assert calculate_current_age(1990) == 36
assert calculate_current_age(1966) == 60
def test_calculate_years_to_retirement(self):
with patch("app.services.pension_allocation.date") as mock_date:
mock_date.today.return_value.year = 2026
assert calculate_years_to_retirement(1990, 60) == 24
assert calculate_years_to_retirement(1966, 60) == 0
assert calculate_years_to_retirement(1960, 60) == 0 # already past
def test_glide_path_young_person(self):
"""Young person (30) should have high equity allocation."""
with patch("app.services.pension_allocation.date") as mock_date:
mock_date.today.return_value.year = 2026
equity_pct, bond_pct = calculate_glide_path(1996, 60)
assert equity_pct + bond_pct == Decimal("100.00")
assert equity_pct >= Decimal("60") # young = high equity
def test_glide_path_near_retirement(self):
"""Person near retirement (55) should have low equity allocation."""
with patch("app.services.pension_allocation.date") as mock_date:
mock_date.today.return_value.year = 2026
equity_pct, bond_pct = calculate_glide_path(1971, 60)
assert equity_pct + bond_pct == Decimal("100.00")
assert equity_pct <= Decimal("40") # near retirement = low equity
def test_glide_path_respects_regulatory_limit(self):
"""Equity allocation should never exceed 70% (regulatory limit)."""
with patch("app.services.pension_allocation.date") as mock_date:
mock_date.today.return_value.year = 2026
equity_pct, _ = calculate_glide_path(2000, 60)
assert equity_pct <= RISKY_ASSET_LIMIT_PCT
def test_glide_path_minimum_equity(self):
"""Equity allocation should not go below 20% (minimum)."""
with patch("app.services.pension_allocation.date") as mock_date:
mock_date.today.return_value.year = 2026
equity_pct, _ = calculate_glide_path(1960, 60)
assert equity_pct >= Decimal("20")
def test_calculate_allocation(self):
with patch("app.services.pension_allocation.date") as mock_date:
mock_date.today.return_value.year = 2026
result = calculate_allocation(
account_id=1,
account_type="dc",
total_amount=Decimal("10000000"),
birth_year=1990,
target_retirement_age=60,
)
assert result.account_id == 1
assert result.account_type == "dc"
assert result.total_amount == 10000000
assert result.risky_limit_pct == 70
assert result.safe_min_pct == 30
assert len(result.allocations) == 5 # 2 risky + 3 safe
# Verify risky assets don't exceed limit
risky_ratio = sum(a.ratio for a in result.allocations if a.asset_type == "risky")
assert risky_ratio <= 70
# Verify total amounts sum to total_amount
total_allocated = sum(a.amount for a in result.allocations)
assert abs(total_allocated - 10000000) < 1 # allow small rounding
def test_calculate_allocation_types(self):
with patch("app.services.pension_allocation.date") as mock_date:
mock_date.today.return_value.year = 2026
result = calculate_allocation(
account_id=1,
account_type="irp",
total_amount=Decimal("5000000"),
birth_year=1985,
target_retirement_age=60,
)
risky = [a for a in result.allocations if a.asset_type == "risky"]
safe = [a for a in result.allocations if a.asset_type == "safe"]
assert len(risky) == 2
assert len(safe) == 3
# Safe includes TDF, bond ETF, deposit
safe_names = [a.asset_name for a in safe]
assert any("TDF" in n for n in safe_names)
assert any("채권" in n for n in safe_names)
assert any("예금" in n for n in safe_names)
def test_get_recommendation(self):
with patch("app.services.pension_allocation.date") as mock_date:
mock_date.today.return_value.year = 2026
result = get_recommendation(
account_id=1,
birth_year=1990,
target_retirement_age=60,
)
assert result.account_id == 1
assert result.birth_year == 1990
assert result.current_age == 36
assert result.years_to_retirement == 24
assert len(result.recommendations) == 5
categories = [r.category for r in result.recommendations]
assert "tdf" in categories
assert "bond_etf" in categories
assert "deposit" in categories
assert "equity_etf" in categories
# All recommendations have reasons
for rec in result.recommendations:
assert rec.reason
# --- API endpoint tests ---
class TestPensionAPI:
def _create_account(self, client, auth_headers, **overrides):
payload = {
"account_type": "dc",
"account_name": "삼성생명 DC",
"total_amount": 10000000,
"birth_year": 1990,
"target_retirement_age": 60,
**overrides,
}
return client.post("/api/pension/accounts", headers=auth_headers, json=payload)
def test_create_account(self, client, auth_headers):
response = self._create_account(client, auth_headers)
assert response.status_code == 201
data = response.json()
assert data["account_type"] == "dc"
assert data["account_name"] == "삼성생명 DC"
assert data["total_amount"] == 10000000
assert data["birth_year"] == 1990
assert data["target_retirement_age"] == 60
assert data["holdings"] == []
def test_create_account_irp(self, client, auth_headers):
response = self._create_account(
client, auth_headers,
account_type="irp",
account_name="NH IRP",
)
assert response.status_code == 201
assert response.json()["account_type"] == "irp"
def test_list_accounts(self, client, auth_headers):
self._create_account(client, auth_headers, account_name="DC 1호")
self._create_account(client, auth_headers, account_name="IRP", account_type="irp")
response = client.get("/api/pension/accounts", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert len(data) == 2
def test_get_account(self, client, auth_headers):
create_resp = self._create_account(client, auth_headers)
account_id = create_resp.json()["id"]
response = client.get(f"/api/pension/accounts/{account_id}", headers=auth_headers)
assert response.status_code == 200
assert response.json()["id"] == account_id
def test_get_account_not_found(self, client, auth_headers):
response = client.get("/api/pension/accounts/9999", headers=auth_headers)
assert response.status_code == 404
def test_update_account(self, client, auth_headers):
create_resp = self._create_account(client, auth_headers)
account_id = create_resp.json()["id"]
response = client.put(
f"/api/pension/accounts/{account_id}",
headers=auth_headers,
json={"account_name": "변경된 계좌명", "total_amount": 20000000},
)
assert response.status_code == 200
data = response.json()
assert data["account_name"] == "변경된 계좌명"
assert data["total_amount"] == 20000000
def test_allocate_assets(self, client, auth_headers):
create_resp = self._create_account(client, auth_headers)
account_id = create_resp.json()["id"]
response = client.post(
f"/api/pension/accounts/{account_id}/allocate",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["account_id"] == account_id
assert data["risky_limit_pct"] == 70
assert data["safe_min_pct"] == 30
assert len(data["allocations"]) == 5
# Verify risky ratio <= 70%
risky_ratio = sum(a["ratio"] for a in data["allocations"] if a["asset_type"] == "risky")
assert risky_ratio <= 70
# Verify holdings were saved
account_resp = client.get(f"/api/pension/accounts/{account_id}", headers=auth_headers)
assert len(account_resp.json()["holdings"]) == 5
def test_allocate_replaces_previous_holdings(self, client, auth_headers):
create_resp = self._create_account(client, auth_headers)
account_id = create_resp.json()["id"]
# Allocate twice
client.post(f"/api/pension/accounts/{account_id}/allocate", headers=auth_headers)
client.post(f"/api/pension/accounts/{account_id}/allocate", headers=auth_headers)
# Should still have 5 holdings (replaced, not duplicated)
account_resp = client.get(f"/api/pension/accounts/{account_id}", headers=auth_headers)
assert len(account_resp.json()["holdings"]) == 5
def test_get_recommendation(self, client, auth_headers):
create_resp = self._create_account(client, auth_headers)
account_id = create_resp.json()["id"]
response = client.get(
f"/api/pension/accounts/{account_id}/recommendation",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["account_id"] == account_id
assert data["birth_year"] == 1990
assert len(data["recommendations"]) == 5
# Verify recommendation categories
categories = [r["category"] for r in data["recommendations"]]
assert "tdf" in categories
assert "bond_etf" in categories
assert "deposit" in categories
assert "equity_etf" in categories
# All recommendations have reasons
for rec in data["recommendations"]:
assert rec["reason"]
assert rec["asset_name"]
def test_unauthenticated_access(self, client):
response = client.get("/api/pension/accounts")
assert response.status_code == 401
def test_user_isolation(self, client, auth_headers, db):
"""User can only see their own pension accounts."""
from app.models.user import User
from app.core.security import get_password_hash, create_access_token
other_user = User(
username="otheruser",
email="other@example.com",
hashed_password=get_password_hash("password"),
)
db.add(other_user)
db.commit()
db.refresh(other_user)
other_token = create_access_token(data={"sub": other_user.username})
other_headers = {"Authorization": f"Bearer {other_token}"}
self._create_account(client, other_headers, account_name="다른 사용자 계좌")
# Current user should see 0 accounts
response = client.get("/api/pension/accounts", headers=auth_headers)
assert response.status_code == 200
assert len(response.json()) == 0

View File

@ -0,0 +1,161 @@
"""
Tests for position sizing module.
"""
import pytest
from app.services.position_sizing import fixed_ratio, kelly_criterion, atr_based
class TestFixedRatio:
"""Tests for fixed_ratio position sizing (quant.md default)."""
def test_basic_calculation(self):
result = fixed_ratio(capital=10_000_000, num_positions=10, cash_ratio=0.3)
# 10M * 0.7 (invest portion) / 10 positions = 700,000
assert result["position_size"] == 700_000
assert result["method"] == "fixed"
assert result["risk_amount"] == pytest.approx(700_000 * 0.03) # -3% max loss per position
def test_default_cash_ratio(self):
result = fixed_ratio(capital=10_000_000, num_positions=10)
assert result["position_size"] == 700_000
def test_custom_cash_ratio(self):
result = fixed_ratio(capital=10_000_000, num_positions=5, cash_ratio=0.5)
# 10M * 0.5 / 5 = 1,000,000
assert result["position_size"] == 1_000_000
def test_single_position(self):
result = fixed_ratio(capital=1_000_000, num_positions=1, cash_ratio=0.0)
assert result["position_size"] == 1_000_000
def test_zero_capital(self):
with pytest.raises(ValueError, match="capital"):
fixed_ratio(capital=0, num_positions=10)
def test_negative_capital(self):
with pytest.raises(ValueError, match="capital"):
fixed_ratio(capital=-1_000_000, num_positions=10)
def test_zero_positions(self):
with pytest.raises(ValueError, match="num_positions"):
fixed_ratio(capital=10_000_000, num_positions=0)
def test_invalid_cash_ratio(self):
with pytest.raises(ValueError, match="cash_ratio"):
fixed_ratio(capital=10_000_000, num_positions=10, cash_ratio=1.5)
def test_cash_ratio_one(self):
"""cash_ratio=1.0 means 100% cash, 0 investable."""
with pytest.raises(ValueError, match="cash_ratio"):
fixed_ratio(capital=10_000_000, num_positions=10, cash_ratio=1.0)
def test_notes_included(self):
result = fixed_ratio(capital=10_000_000, num_positions=10)
assert "notes" in result
assert isinstance(result["notes"], str)
class TestKellyCriterion:
"""Tests for Kelly criterion position sizing."""
def test_basic_calculation(self):
# Kelly = W - (1-W)/R where W=win_rate, R=avg_win/avg_loss
# Kelly = 0.6 - (0.4)/(5/3) = 0.6 - 0.24 = 0.36
# Quarter Kelly = 0.36 * 0.25 = 0.09
result = kelly_criterion(win_rate=0.6, avg_win=0.05, avg_loss=0.03, fraction=0.25)
assert result["method"] == "kelly"
assert result["position_size"] == pytest.approx(0.09, abs=1e-6)
def test_default_quarter_kelly(self):
result = kelly_criterion(win_rate=0.6, avg_win=0.05, avg_loss=0.03)
assert result["position_size"] == pytest.approx(0.09, abs=1e-6)
def test_full_kelly(self):
result = kelly_criterion(win_rate=0.6, avg_win=0.05, avg_loss=0.03, fraction=1.0)
assert result["position_size"] == pytest.approx(0.36, abs=1e-6)
def test_negative_kelly_returns_zero(self):
"""Negative Kelly means don't bet - should clamp to 0."""
result = kelly_criterion(win_rate=0.3, avg_win=0.02, avg_loss=0.05)
assert result["position_size"] == 0.0
def test_win_rate_zero(self):
with pytest.raises(ValueError, match="win_rate"):
kelly_criterion(win_rate=0.0, avg_win=0.05, avg_loss=0.03)
def test_win_rate_above_one(self):
with pytest.raises(ValueError, match="win_rate"):
kelly_criterion(win_rate=1.1, avg_win=0.05, avg_loss=0.03)
def test_negative_avg_win(self):
with pytest.raises(ValueError, match="avg_win"):
kelly_criterion(win_rate=0.6, avg_win=-0.05, avg_loss=0.03)
def test_zero_avg_loss(self):
with pytest.raises(ValueError, match="avg_loss"):
kelly_criterion(win_rate=0.6, avg_win=0.05, avg_loss=0.0)
def test_negative_win_rate(self):
with pytest.raises(ValueError, match="win_rate"):
kelly_criterion(win_rate=-0.1, avg_win=0.05, avg_loss=0.03)
def test_risk_amount_in_result(self):
result = kelly_criterion(win_rate=0.6, avg_win=0.05, avg_loss=0.03)
assert "risk_amount" in result
def test_notes_included(self):
result = kelly_criterion(win_rate=0.6, avg_win=0.05, avg_loss=0.03)
assert "notes" in result
class TestATRBased:
"""Tests for ATR-based volatility sizing."""
def test_basic_calculation(self):
# position_size = (capital * risk_pct) / atr
# = (10M * 0.02) / 1000 = 200
result = atr_based(capital=10_000_000, atr=1000, risk_pct=0.02)
assert result["method"] == "atr"
assert result["shares"] == 200
assert result["risk_amount"] == pytest.approx(10_000_000 * 0.02)
def test_default_risk_pct(self):
result = atr_based(capital=10_000_000, atr=1000)
assert result["shares"] == 200 # 2% default
def test_custom_risk_pct(self):
result = atr_based(capital=10_000_000, atr=500, risk_pct=0.01)
# (10M * 0.01) / 500 = 200
assert result["shares"] == 200
def test_shares_truncated_to_int(self):
# (10M * 0.02) / 3000 = 66.666... -> 66
result = atr_based(capital=10_000_000, atr=3000, risk_pct=0.02)
assert result["shares"] == 66
assert isinstance(result["shares"], int)
def test_zero_capital(self):
with pytest.raises(ValueError, match="capital"):
atr_based(capital=0, atr=1000)
def test_zero_atr(self):
with pytest.raises(ValueError, match="atr"):
atr_based(capital=10_000_000, atr=0)
def test_negative_atr(self):
with pytest.raises(ValueError, match="atr"):
atr_based(capital=10_000_000, atr=-100)
def test_risk_pct_too_high(self):
with pytest.raises(ValueError, match="risk_pct"):
atr_based(capital=10_000_000, atr=1000, risk_pct=1.5)
def test_position_size_in_result(self):
result = atr_based(capital=10_000_000, atr=1000, risk_pct=0.02)
assert "position_size" in result
assert result["position_size"] == result["risk_amount"]
def test_notes_included(self):
result = atr_based(capital=10_000_000, atr=1000)
assert "notes" in result

View File

@ -0,0 +1,245 @@
"""
Unit tests for tax simulation service.
"""
import pytest
from app.services.tax_simulation import (
calculate_tax_deduction,
calculate_pension_tax,
simulate_accumulation,
)
class TestCalculateTaxDeduction:
def test_low_income_deduction_rate(self):
"""총급여 5,500만원 이하 → 공제율 16.5%"""
result = calculate_tax_deduction(
annual_income=40_000_000,
contribution=9_000_000,
account_type="irp",
)
assert result["deduction_rate"] == 16.5
assert result["deductible_contribution"] == 9_000_000
assert result["tax_deduction"] == 9_000_000 * 0.165
def test_high_income_deduction_rate(self):
"""총급여 5,500만원 초과 → 공제율 13.2%"""
result = calculate_tax_deduction(
annual_income=80_000_000,
contribution=9_000_000,
account_type="irp",
)
assert result["deduction_rate"] == 13.2
assert result["tax_deduction"] == 9_000_000 * 0.132
def test_boundary_income_55m(self):
"""정확히 5,500만원은 16.5% 적용"""
result = calculate_tax_deduction(
annual_income=55_000_000,
contribution=5_000_000,
account_type="irp",
)
assert result["deduction_rate"] == 16.5
def test_contribution_exceeds_limit(self):
"""납입액이 900만원 한도 초과 시 900만원까지만 공제"""
result = calculate_tax_deduction(
annual_income=40_000_000,
contribution=12_000_000,
account_type="irp",
)
assert result["deductible_contribution"] == 9_000_000
assert result["tax_deduction"] == 9_000_000 * 0.165
def test_contribution_below_limit(self):
"""납입액이 한도 미만이면 실제 납입액 기준 공제"""
result = calculate_tax_deduction(
annual_income=40_000_000,
contribution=3_000_000,
account_type="irp",
)
assert result["deductible_contribution"] == 3_000_000
assert result["tax_deduction"] == 3_000_000 * 0.165
def test_dc_account_type(self):
"""DC 계좌도 동일 한도 적용 (DC+IRP 합산 900만원)"""
result = calculate_tax_deduction(
annual_income=60_000_000,
contribution=9_000_000,
account_type="dc",
)
assert result["deduction_rate"] == 13.2
assert result["deductible_contribution"] == 9_000_000
def test_zero_contribution(self):
result = calculate_tax_deduction(
annual_income=50_000_000,
contribution=0,
account_type="irp",
)
assert result["tax_deduction"] == 0
def test_result_structure(self):
result = calculate_tax_deduction(
annual_income=50_000_000,
contribution=5_000_000,
account_type="irp",
)
assert "annual_income" in result
assert "contribution" in result
assert "account_type" in result
assert "deduction_rate" in result
assert "deductible_contribution" in result
assert "tax_deduction" in result
assert "irp_limit" in result
class TestCalculatePensionTax:
def test_pension_tax_under_70(self):
"""70세 미만 연금소득세 5.5%"""
result = calculate_pension_tax(
withdrawal_amount=10_000_000,
withdrawal_type="pension",
age=65,
)
assert result["pension_tax_rate"] == 5.5
assert result["pension_tax"] == 10_000_000 * 0.055
def test_pension_tax_70_to_79(self):
"""70~79세 연금소득세 4.4%"""
result = calculate_pension_tax(
withdrawal_amount=10_000_000,
withdrawal_type="pension",
age=75,
)
assert result["pension_tax_rate"] == 4.4
assert result["pension_tax"] == pytest.approx(10_000_000 * 0.044)
def test_pension_tax_80_and_over(self):
"""80세 이상 연금소득세 3.3%"""
result = calculate_pension_tax(
withdrawal_amount=10_000_000,
withdrawal_type="pension",
age=85,
)
assert result["pension_tax_rate"] == 3.3
assert result["pension_tax"] == pytest.approx(10_000_000 * 0.033)
def test_pension_tax_boundary_70(self):
"""정확히 70세는 4.4% 적용"""
result = calculate_pension_tax(
withdrawal_amount=10_000_000,
withdrawal_type="pension",
age=70,
)
assert result["pension_tax_rate"] == 4.4
def test_pension_tax_boundary_80(self):
"""정확히 80세는 3.3% 적용"""
result = calculate_pension_tax(
withdrawal_amount=10_000_000,
withdrawal_type="pension",
age=80,
)
assert result["pension_tax_rate"] == 3.3
def test_lump_sum_tax(self):
"""일시금 수령 시 기타소득세 16.5%"""
result = calculate_pension_tax(
withdrawal_amount=10_000_000,
withdrawal_type="lump_sum",
age=65,
)
assert result["lump_sum_tax_rate"] == 16.5
assert result["lump_sum_tax"] == 10_000_000 * 0.165
def test_comparison_shows_savings(self):
"""연금 수령이 일시금보다 세금이 적음"""
result = calculate_pension_tax(
withdrawal_amount=10_000_000,
withdrawal_type="pension",
age=65,
)
assert result["tax_saving"] > 0
assert result["tax_saving"] == result["lump_sum_tax"] - result["pension_tax"]
def test_result_structure(self):
result = calculate_pension_tax(
withdrawal_amount=10_000_000,
withdrawal_type="pension",
age=65,
)
assert "withdrawal_amount" in result
assert "pension_tax_rate" in result
assert "pension_tax" in result
assert "lump_sum_tax_rate" in result
assert "lump_sum_tax" in result
assert "tax_saving" in result
class TestSimulateAccumulation:
def test_basic_accumulation(self):
"""기본 적립 시뮬레이션"""
result = simulate_accumulation(
monthly_contribution=500_000,
years=20,
annual_return=7.0,
tax_deduction_rate=16.5,
)
assert len(result["yearly_data"]) == 20
assert result["total_contribution"] == 500_000 * 12 * 20
assert result["final_value"] > result["total_contribution"]
def test_yearly_data_structure(self):
result = simulate_accumulation(
monthly_contribution=300_000,
years=5,
annual_return=5.0,
tax_deduction_rate=13.2,
)
first_year = result["yearly_data"][0]
assert "year" in first_year
assert "contribution" in first_year
assert "cumulative_contribution" in first_year
assert "investment_value" in first_year
assert "tax_deduction" in first_year
assert "cumulative_tax_deduction" in first_year
def test_tax_deduction_accumulates(self):
result = simulate_accumulation(
monthly_contribution=500_000,
years=3,
annual_return=5.0,
tax_deduction_rate=16.5,
)
yearly = result["yearly_data"]
annual_contribution = 500_000 * 12
deductible = min(annual_contribution, 9_000_000)
expected_deduction = deductible * 0.165
assert yearly[0]["tax_deduction"] == expected_deduction
assert yearly[2]["cumulative_tax_deduction"] == pytest.approx(
expected_deduction * 3, rel=1e-6
)
def test_zero_return(self):
result = simulate_accumulation(
monthly_contribution=100_000,
years=10,
annual_return=0.0,
tax_deduction_rate=16.5,
)
assert result["final_value"] == result["total_contribution"]
assert result["total_return"] == 0
def test_result_summary(self):
result = simulate_accumulation(
monthly_contribution=500_000,
years=20,
annual_return=7.0,
tax_deduction_rate=16.5,
)
assert "total_contribution" in result
assert "final_value" in result
assert "total_return" in result
assert "total_tax_deduction" in result
assert "yearly_data" in result

View File

@ -13,6 +13,7 @@
"@radix-ui/react-label": "^2.1.8",
"@radix-ui/react-select": "^2.2.6",
"@radix-ui/react-slot": "^1.2.4",
"@radix-ui/react-switch": "^1.2.6",
"@radix-ui/react-tabs": "^1.1.13",
"@radix-ui/react-tooltip": "^1.2.8",
"class-variance-authority": "^0.7.1",
@ -1923,6 +1924,35 @@
}
}
},
"node_modules/@radix-ui/react-switch": {
"version": "1.2.6",
"resolved": "https://registry.npmjs.org/@radix-ui/react-switch/-/react-switch-1.2.6.tgz",
"integrity": "sha512-bByzr1+ep1zk4VubeEVViV592vu2lHE2BZY5OnzehZqOOgogN80+mNtCqPkhn2gklJqOpxWgPoYTSnhBCqpOXQ==",
"license": "MIT",
"dependencies": {
"@radix-ui/primitive": "1.1.3",
"@radix-ui/react-compose-refs": "1.1.2",
"@radix-ui/react-context": "1.1.2",
"@radix-ui/react-primitive": "2.1.3",
"@radix-ui/react-use-controllable-state": "1.2.2",
"@radix-ui/react-use-previous": "1.1.1",
"@radix-ui/react-use-size": "1.1.1"
},
"peerDependencies": {
"@types/react": "*",
"@types/react-dom": "*",
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-tabs": {
"version": "1.1.13",
"resolved": "https://registry.npmjs.org/@radix-ui/react-tabs/-/react-tabs-1.1.13.tgz",

View File

@ -14,6 +14,7 @@
"@radix-ui/react-label": "^2.1.8",
"@radix-ui/react-select": "^2.2.6",
"@radix-ui/react-slot": "^1.2.4",
"@radix-ui/react-switch": "^1.2.6",
"@radix-ui/react-tabs": "^1.1.13",
"@radix-ui/react-tooltip": "^1.2.8",
"class-variance-authority": "^0.7.1",

View File

@ -0,0 +1,354 @@
'use client';
import { useEffect, useState } from 'react';
import { useRouter, useParams } from 'next/navigation';
import Link from 'next/link';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Badge } from '@/components/ui/badge';
import { Skeleton } from '@/components/ui/skeleton';
import { Button } from '@/components/ui/button';
import { Input } from '@/components/ui/input';
import { Label } from '@/components/ui/label';
import { api } from '@/lib/api';
import { toast } from 'sonner';
import { ArrowLeft, Save, Edit2 } from 'lucide-react';
interface TradeJournal {
id: number;
user_id: number;
stock_code: string;
stock_name: string | null;
trade_type: string;
entry_price: number | null;
target_price: number | null;
stop_loss_price: number | null;
exit_price: number | null;
entry_date: string;
exit_date: string | null;
quantity: number | null;
profit_loss: number | null;
profit_loss_pct: number | null;
entry_reason: string | null;
exit_reason: string | null;
scenario: string | null;
lessons_learned: string | null;
emotional_state: string | null;
strategy_id: number | null;
status: string;
created_at: string;
updated_at: string;
}
const formatPrice = (value: number | null | undefined) => {
if (value === null || value === undefined) return '-';
return new Intl.NumberFormat('ko-KR').format(value);
};
const formatPct = (value: number | null | undefined) => {
if (value === null || value === undefined) return '-';
return `${value >= 0 ? '+' : ''}${value.toFixed(2)}%`;
};
export default function JournalDetailPage() {
const router = useRouter();
const params = useParams();
const journalId = params.id as string;
const [loading, setLoading] = useState(true);
const [journal, setJournal] = useState<TradeJournal | null>(null);
const [editing, setEditing] = useState(false);
const [submitting, setSubmitting] = useState(false);
// Edit form
const [editForm, setEditForm] = useState({
exit_price: '',
exit_date: '',
exit_reason: '',
lessons_learned: '',
emotional_state: '',
scenario: '',
target_price: '',
stop_loss_price: '',
});
useEffect(() => {
const init = async () => {
try {
await api.getCurrentUser();
await fetchJournal();
} catch {
router.push('/login');
} finally {
setLoading(false);
}
};
init();
}, [router, journalId]);
const fetchJournal = async () => {
try {
const data = await api.get<TradeJournal>(`/api/journal/${journalId}`);
setJournal(data);
setEditForm({
exit_price: data.exit_price?.toString() || '',
exit_date: data.exit_date || '',
exit_reason: data.exit_reason || '',
lessons_learned: data.lessons_learned || '',
emotional_state: data.emotional_state || '',
scenario: data.scenario || '',
target_price: data.target_price?.toString() || '',
stop_loss_price: data.stop_loss_price?.toString() || '',
});
} catch {
toast.error('저널을 불러오는데 실패했습니다.');
router.push('/journal');
}
};
const handleSave = async () => {
setSubmitting(true);
try {
const payload: Record<string, unknown> = {};
if (editForm.exit_price) payload.exit_price = parseFloat(editForm.exit_price);
if (editForm.exit_date) payload.exit_date = editForm.exit_date;
if (editForm.exit_reason) payload.exit_reason = editForm.exit_reason;
if (editForm.lessons_learned) payload.lessons_learned = editForm.lessons_learned;
if (editForm.emotional_state) payload.emotional_state = editForm.emotional_state;
if (editForm.scenario) payload.scenario = editForm.scenario;
if (editForm.target_price) payload.target_price = parseFloat(editForm.target_price);
if (editForm.stop_loss_price) payload.stop_loss_price = parseFloat(editForm.stop_loss_price);
const updated = await api.put<TradeJournal>(`/api/journal/${journalId}`, payload);
setJournal(updated);
setEditing(false);
toast.success('저장되었습니다.');
} catch (err) {
toast.error(err instanceof Error ? err.message : '저장에 실패했습니다.');
} finally {
setSubmitting(false);
}
};
if (loading) {
return (
<DashboardLayout>
<div className="space-y-6">
<Skeleton className="h-8 w-48" />
<Skeleton className="h-[400px]" />
</div>
</DashboardLayout>
);
}
if (!journal) return null;
return (
<DashboardLayout>
<div className="mb-6">
<Link href="/journal" className="inline-flex items-center text-sm text-muted-foreground hover:text-foreground mb-4">
<ArrowLeft className="mr-1 h-4 w-4" />
</Link>
<div className="flex items-center justify-between">
<div>
<h1 className="text-2xl font-bold text-foreground">
{journal.stock_name || journal.stock_code}
<span className="ml-2 text-lg font-normal text-muted-foreground font-mono">{journal.stock_code}</span>
</h1>
<div className="mt-1 flex items-center gap-2">
<Badge className={journal.trade_type === 'buy'
? 'bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200'
: 'bg-red-100 text-red-800 dark:bg-red-900 dark:text-red-200'
}>
{journal.trade_type === 'buy' ? '매수' : '매도'}
</Badge>
<Badge className={journal.status === 'open'
? 'bg-blue-100 text-blue-800 dark:bg-blue-900 dark:text-blue-200'
: 'bg-gray-100 text-gray-800 dark:bg-gray-900 dark:text-gray-200'
}>
{journal.status === 'open' ? '진행중' : '완료'}
</Badge>
</div>
</div>
{!editing && (
<Button variant="outline" onClick={() => setEditing(true)}>
<Edit2 className="mr-2 h-4 w-4" />
</Button>
)}
</div>
</div>
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6">
{/* Trade Info */}
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
<dl className="space-y-3">
<div className="flex justify-between text-sm">
<dt className="text-muted-foreground"></dt>
<dd className="font-medium">{journal.entry_date}</dd>
</div>
<div className="flex justify-between text-sm">
<dt className="text-muted-foreground"></dt>
<dd className="font-mono">{journal.quantity ? `${journal.quantity.toLocaleString()}` : '-'}</dd>
</div>
<div className="flex justify-between text-sm">
<dt className="text-muted-foreground"></dt>
<dd className="font-mono">{formatPrice(journal.entry_price)}</dd>
</div>
<div className="flex justify-between text-sm">
<dt className="text-muted-foreground"></dt>
<dd className="font-mono text-green-600">{formatPrice(journal.target_price)}</dd>
</div>
<div className="flex justify-between text-sm">
<dt className="text-muted-foreground"></dt>
<dd className="font-mono text-red-600">{formatPrice(journal.stop_loss_price)}</dd>
</div>
{journal.exit_price && (
<>
<div className="border-t pt-3 flex justify-between text-sm">
<dt className="text-muted-foreground"></dt>
<dd className="font-medium">{journal.exit_date || '-'}</dd>
</div>
<div className="flex justify-between text-sm">
<dt className="text-muted-foreground"></dt>
<dd className="font-mono">{formatPrice(journal.exit_price)}</dd>
</div>
<div className="flex justify-between text-sm">
<dt className="text-muted-foreground"></dt>
<dd className={`font-mono font-bold ${journal.profit_loss !== null && journal.profit_loss >= 0 ? 'text-green-600' : 'text-red-600'}`}>
{formatPrice(journal.profit_loss)} ({formatPct(journal.profit_loss_pct)})
</dd>
</div>
</>
)}
</dl>
</CardContent>
</Card>
{/* Analysis */}
<Card>
<CardHeader>
<CardTitle></CardTitle>
</CardHeader>
<CardContent className="space-y-4">
<div>
<h4 className="text-sm font-medium text-muted-foreground mb-1"> </h4>
<p className="text-sm whitespace-pre-wrap">{journal.entry_reason || '(미작성)'}</p>
</div>
<div>
<h4 className="text-sm font-medium text-muted-foreground mb-1"></h4>
<p className="text-sm whitespace-pre-wrap">{journal.scenario || '(미작성)'}</p>
</div>
<div>
<h4 className="text-sm font-medium text-muted-foreground mb-1"> </h4>
<p className="text-sm whitespace-pre-wrap">{journal.emotional_state || '(미작성)'}</p>
</div>
{journal.exit_reason && (
<div>
<h4 className="text-sm font-medium text-muted-foreground mb-1"> </h4>
<p className="text-sm whitespace-pre-wrap">{journal.exit_reason}</p>
</div>
)}
{journal.lessons_learned && (
<div>
<h4 className="text-sm font-medium text-muted-foreground mb-1"></h4>
<p className="text-sm whitespace-pre-wrap">{journal.lessons_learned}</p>
</div>
)}
</CardContent>
</Card>
</div>
{/* Edit Form */}
{editing && (
<Card className="mt-6">
<CardHeader>
<CardTitle> / </CardTitle>
</CardHeader>
<CardContent className="space-y-4">
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
<div className="space-y-2">
<Label htmlFor="edit_exit_price"></Label>
<Input
id="edit_exit_price"
type="number"
min="0"
step="any"
value={editForm.exit_price}
onChange={(e) => setEditForm((p) => ({ ...p, exit_price: e.target.value }))}
/>
</div>
<div className="space-y-2">
<Label htmlFor="edit_exit_date"></Label>
<Input
id="edit_exit_date"
type="date"
value={editForm.exit_date}
onChange={(e) => setEditForm((p) => ({ ...p, exit_date: e.target.value }))}
/>
</div>
<div className="space-y-2">
<Label htmlFor="edit_target_price"> </Label>
<Input
id="edit_target_price"
type="number"
min="0"
step="any"
value={editForm.target_price}
onChange={(e) => setEditForm((p) => ({ ...p, target_price: e.target.value }))}
/>
</div>
</div>
<div className="space-y-2">
<Label htmlFor="edit_exit_reason"> </Label>
<textarea
id="edit_exit_reason"
className="w-full min-h-[80px] rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring"
placeholder="왜 청산했는가?"
value={editForm.exit_reason}
onChange={(e) => setEditForm((p) => ({ ...p, exit_reason: e.target.value }))}
/>
</div>
<div className="space-y-2">
<Label htmlFor="edit_lessons"> ( )</Label>
<textarea
id="edit_lessons"
className="w-full min-h-[80px] rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring"
placeholder="이 거래에서 배운 점은?"
value={editForm.lessons_learned}
onChange={(e) => setEditForm((p) => ({ ...p, lessons_learned: e.target.value }))}
/>
</div>
<div className="space-y-2">
<Label htmlFor="edit_scenario"> </Label>
<textarea
id="edit_scenario"
className="w-full min-h-[80px] rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring"
value={editForm.scenario}
onChange={(e) => setEditForm((p) => ({ ...p, scenario: e.target.value }))}
/>
</div>
<div className="flex justify-end gap-3">
<Button variant="outline" onClick={() => setEditing(false)} disabled={submitting}>
</Button>
<Button onClick={handleSave} disabled={submitting}>
<Save className="mr-2 h-4 w-4" />
{submitting ? '저장 중...' : '저장'}
</Button>
</div>
</CardContent>
</Card>
)}
</DashboardLayout>
);
}

View File

@ -0,0 +1,255 @@
'use client';
import { useState } from 'react';
import { useRouter } from 'next/navigation';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Button } from '@/components/ui/button';
import { Input } from '@/components/ui/input';
import { Label } from '@/components/ui/label';
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select';
import { api } from '@/lib/api';
import { toast } from 'sonner';
import { ArrowLeft, Save } from 'lucide-react';
import Link from 'next/link';
export default function NewJournalPage() {
const router = useRouter();
const [submitting, setSubmitting] = useState(false);
const [form, setForm] = useState({
stock_code: '',
stock_name: '',
trade_type: 'buy',
entry_price: '',
target_price: '',
stop_loss_price: '',
entry_date: new Date().toISOString().split('T')[0],
quantity: '',
entry_reason: '',
scenario: '',
emotional_state: '',
});
const handleChange = (field: string, value: string) => {
setForm((prev) => ({ ...prev, [field]: value }));
};
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
if (!form.stock_code || !form.trade_type || !form.entry_date) {
toast.error('종목코드, 거래유형, 진입일은 필수입니다.');
return;
}
setSubmitting(true);
try {
const payload: Record<string, unknown> = {
stock_code: form.stock_code,
trade_type: form.trade_type,
entry_date: form.entry_date,
};
if (form.stock_name) payload.stock_name = form.stock_name;
if (form.entry_price) payload.entry_price = parseFloat(form.entry_price);
if (form.target_price) payload.target_price = parseFloat(form.target_price);
if (form.stop_loss_price) payload.stop_loss_price = parseFloat(form.stop_loss_price);
if (form.quantity) payload.quantity = parseInt(form.quantity);
if (form.entry_reason) payload.entry_reason = form.entry_reason;
if (form.scenario) payload.scenario = form.scenario;
if (form.emotional_state) payload.emotional_state = form.emotional_state;
const result = await api.post<{ id: number }>('/api/journal', payload);
toast.success('거래가 기록되었습니다.');
router.push(`/journal/${result.id}`);
} catch (err) {
toast.error(err instanceof Error ? err.message : '저장에 실패했습니다.');
} finally {
setSubmitting(false);
}
};
return (
<DashboardLayout>
<div className="mb-6">
<Link href="/journal" className="inline-flex items-center text-sm text-muted-foreground hover:text-foreground mb-4">
<ArrowLeft className="mr-1 h-4 w-4" />
</Link>
<h1 className="text-2xl font-bold text-foreground"> </h1>
<p className="mt-1 text-muted-foreground">
</p>
</div>
<form onSubmit={handleSubmit}>
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6">
{/* Trade Info */}
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent className="space-y-4">
<div className="grid grid-cols-2 gap-4">
<div className="space-y-2">
<Label htmlFor="stock_code"> *</Label>
<Input
id="stock_code"
placeholder="예: 005930"
value={form.stock_code}
onChange={(e) => handleChange('stock_code', e.target.value)}
required
/>
</div>
<div className="space-y-2">
<Label htmlFor="stock_name"></Label>
<Input
id="stock_name"
placeholder="예: 삼성전자"
value={form.stock_name}
onChange={(e) => handleChange('stock_name', e.target.value)}
/>
</div>
</div>
<div className="grid grid-cols-2 gap-4">
<div className="space-y-2">
<Label> *</Label>
<Select value={form.trade_type} onValueChange={(v) => handleChange('trade_type', v)}>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="buy"></SelectItem>
<SelectItem value="sell"></SelectItem>
</SelectContent>
</Select>
</div>
<div className="space-y-2">
<Label htmlFor="entry_date"> *</Label>
<Input
id="entry_date"
type="date"
value={form.entry_date}
onChange={(e) => handleChange('entry_date', e.target.value)}
required
/>
</div>
</div>
<div className="grid grid-cols-2 gap-4">
<div className="space-y-2">
<Label htmlFor="entry_price"></Label>
<Input
id="entry_price"
type="number"
min="0"
step="any"
placeholder="0"
value={form.entry_price}
onChange={(e) => handleChange('entry_price', e.target.value)}
/>
</div>
<div className="space-y-2">
<Label htmlFor="quantity"> ()</Label>
<Input
id="quantity"
type="number"
min="1"
placeholder="0"
value={form.quantity}
onChange={(e) => handleChange('quantity', e.target.value)}
/>
</div>
</div>
<div className="grid grid-cols-2 gap-4">
<div className="space-y-2">
<Label htmlFor="target_price"></Label>
<Input
id="target_price"
type="number"
min="0"
step="any"
placeholder="0"
value={form.target_price}
onChange={(e) => handleChange('target_price', e.target.value)}
/>
</div>
<div className="space-y-2">
<Label htmlFor="stop_loss_price"></Label>
<Input
id="stop_loss_price"
type="number"
min="0"
step="any"
placeholder="0"
value={form.stop_loss_price}
onChange={(e) => handleChange('stop_loss_price', e.target.value)}
/>
</div>
</div>
</CardContent>
</Card>
{/* Analysis */}
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent className="space-y-4">
<div className="space-y-2">
<Label htmlFor="entry_reason"> </Label>
<textarea
id="entry_reason"
className="w-full min-h-[100px] rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring"
placeholder="왜 이 종목에 진입하는가? (기술적/기본적 분석 근거)"
value={form.entry_reason}
onChange={(e) => handleChange('entry_reason', e.target.value)}
/>
</div>
<div className="space-y-2">
<Label htmlFor="scenario"></Label>
<textarea
id="scenario"
className="w-full min-h-[100px] rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring"
placeholder="목표가 도달 시 → ?&#10;손절가 이탈 시 → ?&#10;횡보 시 → ?"
value={form.scenario}
onChange={(e) => handleChange('scenario', e.target.value)}
/>
</div>
<div className="space-y-2">
<Label htmlFor="emotional_state"> </Label>
<textarea
id="emotional_state"
className="w-full min-h-[60px] rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring"
placeholder="현재 느끼는 감정 (자신감, 불안, 확신 등)"
value={form.emotional_state}
onChange={(e) => handleChange('emotional_state', e.target.value)}
/>
</div>
</CardContent>
</Card>
</div>
<div className="mt-6 flex justify-end gap-3">
<Link href="/journal">
<Button type="button" variant="outline"></Button>
</Link>
<Button type="submit" disabled={submitting}>
<Save className="mr-2 h-4 w-4" />
{submitting ? '저장 중...' : '저장'}
</Button>
</div>
</form>
</DashboardLayout>
);
}

View File

@ -0,0 +1,358 @@
'use client';
import { useEffect, useState } from 'react';
import { useRouter } from 'next/navigation';
import Link from 'next/link';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Badge } from '@/components/ui/badge';
import { Skeleton } from '@/components/ui/skeleton';
import { Button } from '@/components/ui/button';
import { Input } from '@/components/ui/input';
import { Label } from '@/components/ui/label';
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select';
import { api } from '@/lib/api';
import { toast } from 'sonner';
import {
Plus,
TrendingUp,
TrendingDown,
Target,
BarChart3,
BookOpen,
} from 'lucide-react';
interface TradeJournal {
id: number;
user_id: number;
stock_code: string;
stock_name: string | null;
trade_type: string;
entry_price: number | null;
target_price: number | null;
stop_loss_price: number | null;
exit_price: number | null;
entry_date: string;
exit_date: string | null;
quantity: number | null;
profit_loss: number | null;
profit_loss_pct: number | null;
entry_reason: string | null;
exit_reason: string | null;
scenario: string | null;
lessons_learned: string | null;
emotional_state: string | null;
strategy_id: number | null;
status: string;
created_at: string;
updated_at: string;
}
interface JournalStats {
total_trades: number;
open_trades: number;
closed_trades: number;
win_count: number;
loss_count: number;
win_rate: number | null;
avg_profit_loss_pct: number | null;
max_profit_pct: number | null;
max_loss_pct: number | null;
total_profit_loss: number | null;
}
const formatPrice = (value: number | null | undefined) => {
if (value === null || value === undefined) return '-';
return new Intl.NumberFormat('ko-KR').format(value);
};
const formatPct = (value: number | null | undefined) => {
if (value === null || value === undefined) return '-';
return `${value >= 0 ? '+' : ''}${value.toFixed(2)}%`;
};
export default function JournalPage() {
const router = useRouter();
const [loading, setLoading] = useState(true);
const [journals, setJournals] = useState<TradeJournal[]>([]);
const [stats, setStats] = useState<JournalStats | null>(null);
// Filter state
const [filterStatus, setFilterStatus] = useState('all');
const [filterStockCode, setFilterStockCode] = useState('');
const [filterStartDate, setFilterStartDate] = useState('');
const [filterEndDate, setFilterEndDate] = useState('');
useEffect(() => {
const init = async () => {
try {
await api.getCurrentUser();
await Promise.all([fetchJournals(), fetchStats()]);
} catch {
router.push('/login');
} finally {
setLoading(false);
}
};
init();
}, [router]);
const fetchJournals = async () => {
try {
const params = new URLSearchParams();
if (filterStatus !== 'all') params.set('status', filterStatus);
if (filterStockCode) params.set('stock_code', filterStockCode);
if (filterStartDate) params.set('start_date', filterStartDate);
if (filterEndDate) params.set('end_date', filterEndDate);
const query = params.toString();
const data = await api.get<TradeJournal[]>(`/api/journal${query ? `?${query}` : ''}`);
setJournals(data);
} catch {
toast.error('저널 목록을 불러오는데 실패했습니다.');
}
};
const fetchStats = async () => {
try {
const data = await api.get<JournalStats>('/api/journal/stats');
setStats(data);
} catch {
// stats failure is non-critical
}
};
const handleFilter = async (e: React.FormEvent) => {
e.preventDefault();
await fetchJournals();
};
if (loading) {
return (
<DashboardLayout>
<div className="space-y-6">
<Skeleton className="h-8 w-48" />
<Skeleton className="h-[400px]" />
</div>
</DashboardLayout>
);
}
return (
<DashboardLayout>
<div className="mb-6 flex items-center justify-between">
<div>
<h1 className="text-2xl font-bold text-foreground"> </h1>
<p className="mt-1 text-muted-foreground">
/
</p>
</div>
<Link href="/journal/new">
<Button>
<Plus className="mr-2 h-4 w-4" />
</Button>
</Link>
</div>
{/* Stats Cards */}
{stats && (
<div className="grid grid-cols-2 md:grid-cols-5 gap-4 mb-6">
<Card>
<CardContent className="p-4">
<div className="flex items-center gap-2 text-muted-foreground mb-1">
<BookOpen className="h-4 w-4" />
<span className="text-xs font-medium"> </span>
</div>
<p className="text-2xl font-bold">{stats.total_trades}</p>
<p className="text-xs text-muted-foreground">
{stats.open_trades} / {stats.closed_trades}
</p>
</CardContent>
</Card>
<Card>
<CardContent className="p-4">
<div className="flex items-center gap-2 text-muted-foreground mb-1">
<Target className="h-4 w-4" />
<span className="text-xs font-medium"></span>
</div>
<p className="text-2xl font-bold">
{stats.win_rate !== null ? `${stats.win_rate.toFixed(1)}%` : '-'}
</p>
<p className="text-xs text-muted-foreground">
{stats.win_count} / {stats.loss_count}
</p>
</CardContent>
</Card>
<Card>
<CardContent className="p-4">
<div className="flex items-center gap-2 text-muted-foreground mb-1">
<BarChart3 className="h-4 w-4" />
<span className="text-xs font-medium"> </span>
</div>
<p className={`text-2xl font-bold ${stats.avg_profit_loss_pct !== null && stats.avg_profit_loss_pct >= 0 ? 'text-green-600' : 'text-red-600'}`}>
{formatPct(stats.avg_profit_loss_pct)}
</p>
</CardContent>
</Card>
<Card>
<CardContent className="p-4">
<div className="flex items-center gap-2 text-muted-foreground mb-1">
<TrendingUp className="h-4 w-4 text-green-600" />
<span className="text-xs font-medium"> </span>
</div>
<p className="text-2xl font-bold text-green-600">
{formatPct(stats.max_profit_pct)}
</p>
</CardContent>
</Card>
<Card>
<CardContent className="p-4">
<div className="flex items-center gap-2 text-muted-foreground mb-1">
<TrendingDown className="h-4 w-4 text-red-600" />
<span className="text-xs font-medium"> </span>
</div>
<p className="text-2xl font-bold text-red-600">
{formatPct(stats.max_loss_pct)}
</p>
</CardContent>
</Card>
</div>
)}
{/* Filters */}
<Card className="mb-6">
<CardContent className="p-4">
<form onSubmit={handleFilter} className="flex flex-wrap items-end gap-4">
<div className="space-y-2">
<Label></Label>
<Select value={filterStatus} onValueChange={setFilterStatus}>
<SelectTrigger className="w-28">
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="all"></SelectItem>
<SelectItem value="open"></SelectItem>
<SelectItem value="closed"></SelectItem>
</SelectContent>
</Select>
</div>
<div className="space-y-2">
<Label></Label>
<Input
type="text"
placeholder="예: 005930"
value={filterStockCode}
onChange={(e) => setFilterStockCode(e.target.value)}
className="w-36"
/>
</div>
<div className="space-y-2">
<Label></Label>
<Input
type="date"
value={filterStartDate}
onChange={(e) => setFilterStartDate(e.target.value)}
className="w-40"
/>
</div>
<div className="space-y-2">
<Label></Label>
<Input
type="date"
value={filterEndDate}
onChange={(e) => setFilterEndDate(e.target.value)}
className="w-40"
/>
</div>
<Button type="submit" variant="outline">
</Button>
</form>
</CardContent>
</Card>
{/* Journal List */}
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent className="p-0">
<div className="overflow-x-auto">
<table className="w-full">
<thead className="bg-muted">
<tr>
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-center text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-center text-sm font-medium text-muted-foreground"></th>
</tr>
</thead>
<tbody className="divide-y divide-border">
{journals.map((j) => (
<tr
key={j.id}
className="hover:bg-muted/50 cursor-pointer"
onClick={() => router.push(`/journal/${j.id}`)}
>
<td className="px-4 py-3 text-sm">{j.entry_date}</td>
<td className="px-4 py-3 text-sm">
<span className="font-mono">{j.stock_code}</span>
{j.stock_name && (
<span className="ml-1 text-muted-foreground">{j.stock_name}</span>
)}
</td>
<td className="px-4 py-3 text-center">
<Badge className={j.trade_type === 'buy'
? 'bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200'
: 'bg-red-100 text-red-800 dark:bg-red-900 dark:text-red-200'
}>
{j.trade_type === 'buy' ? '매수' : '매도'}
</Badge>
</td>
<td className="px-4 py-3 text-sm text-right font-mono">{formatPrice(j.entry_price)}</td>
<td className="px-4 py-3 text-sm text-right font-mono text-green-600">{formatPrice(j.target_price)}</td>
<td className="px-4 py-3 text-sm text-right font-mono text-red-600">{formatPrice(j.stop_loss_price)}</td>
<td className="px-4 py-3 text-sm text-right font-mono">{formatPrice(j.exit_price)}</td>
<td className="px-4 py-3 text-sm text-right font-mono">
{j.profit_loss_pct !== null ? (
<span className={j.profit_loss_pct >= 0 ? 'text-green-600' : 'text-red-600'}>
{formatPct(j.profit_loss_pct)}
</span>
) : '-'}
</td>
<td className="px-4 py-3 text-center">
<Badge className={j.status === 'open'
? 'bg-blue-100 text-blue-800 dark:bg-blue-900 dark:text-blue-200'
: 'bg-gray-100 text-gray-800 dark:bg-gray-900 dark:text-gray-200'
}>
{j.status === 'open' ? '진행중' : '완료'}
</Badge>
</td>
</tr>
))}
{journals.length === 0 && (
<tr>
<td colSpan={9} className="px-4 py-8 text-center text-muted-foreground">
. .
</td>
</tr>
)}
</tbody>
</table>
</div>
</CardContent>
</Card>
</DashboardLayout>
);
}

View File

@ -0,0 +1,368 @@
'use client';
import { useEffect, useState } from 'react';
import { useRouter, useParams } from 'next/navigation';
import Link from 'next/link';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Badge } from '@/components/ui/badge';
import { Skeleton } from '@/components/ui/skeleton';
import { Button } from '@/components/ui/button';
import { api } from '@/lib/api';
import { toast } from 'sonner';
import { ArrowLeft, RefreshCw, Lightbulb } from 'lucide-react';
import { PieChart, Pie, Cell, ResponsiveContainer, Legend, Tooltip } from 'recharts';
interface PensionHolding {
id: number;
account_id: number;
asset_name: string;
asset_type: string;
amount: number;
ratio: number;
}
interface PensionAccount {
id: number;
user_id: number;
account_type: string;
account_name: string;
total_amount: number;
birth_year: number;
target_retirement_age: number;
created_at: string;
updated_at: string;
holdings: PensionHolding[];
}
interface AllocationItem {
asset_name: string;
asset_type: string;
amount: number;
ratio: number;
}
interface AllocationResult {
account_id: number;
account_type: string;
total_amount: number;
risky_limit_pct: number;
safe_min_pct: number;
glide_path_equity_pct: number;
glide_path_bond_pct: number;
current_age: number;
years_to_retirement: number;
allocations: AllocationItem[];
}
interface RecommendationItem {
asset_name: string;
asset_type: string;
category: string;
ratio: number;
reason: string;
}
interface RecommendationResult {
account_id: number;
birth_year: number;
current_age: number;
target_retirement_age: number;
years_to_retirement: number;
glide_path_equity_pct: number;
glide_path_bond_pct: number;
recommendations: RecommendationItem[];
}
const formatAmount = (value: number) => {
return new Intl.NumberFormat('ko-KR').format(value);
};
const accountTypeLabel: Record<string, string> = {
dc: 'DC형',
irp: 'IRP',
personal: '개인연금',
};
const COLORS_RISKY = ['#ef4444', '#f97316'];
const COLORS_SAFE = ['#22c55e', '#3b82f6', '#a855f7'];
export default function PensionDetailPage() {
const router = useRouter();
const params = useParams();
const accountId = params.id as string;
const [loading, setLoading] = useState(true);
const [allocating, setAllocating] = useState(false);
const [account, setAccount] = useState<PensionAccount | null>(null);
const [allocation, setAllocation] = useState<AllocationResult | null>(null);
const [recommendation, setRecommendation] = useState<RecommendationResult | null>(null);
const fetchAccount = async () => {
try {
const data = await api.get<PensionAccount>(`/api/pension/accounts/${accountId}`);
setAccount(data);
} catch {
toast.error('계좌 정보를 불러오는데 실패했습니다.');
router.push('/pension');
}
};
const fetchRecommendation = async () => {
try {
const data = await api.get<RecommendationResult>(
`/api/pension/accounts/${accountId}/recommendation`
);
setRecommendation(data);
} catch {
// non-critical
}
};
useEffect(() => {
const init = async () => {
try {
await api.getCurrentUser();
await Promise.all([fetchAccount(), fetchRecommendation()]);
} catch {
router.push('/login');
} finally {
setLoading(false);
}
};
init();
}, [accountId, router]);
const handleAllocate = async () => {
setAllocating(true);
try {
const result = await api.post<AllocationResult>(
`/api/pension/accounts/${accountId}/allocate`
);
setAllocation(result);
await fetchAccount(); // refresh holdings
toast.success('자산 배분이 완료되었습니다.');
} catch {
toast.error('자산 배분에 실패했습니다.');
} finally {
setAllocating(false);
}
};
if (loading) {
return (
<DashboardLayout>
<div className="space-y-6">
<Skeleton className="h-8 w-48" />
<Skeleton className="h-[400px]" />
</div>
</DashboardLayout>
);
}
if (!account) return null;
const pieData = account.holdings.map((h, i) => ({
name: h.asset_name,
value: h.amount,
ratio: h.ratio,
type: h.asset_type,
}));
const getColor = (index: number, type: string) => {
if (type === 'risky') return COLORS_RISKY[index % COLORS_RISKY.length];
return COLORS_SAFE[index % COLORS_SAFE.length];
};
let riskyIdx = 0;
let safeIdx = 0;
const colors = pieData.map((d) => {
if (d.type === 'risky') return getColor(riskyIdx++, 'risky');
return getColor(safeIdx++, 'safe');
});
const displayAlloc = allocation || (account.holdings.length > 0 ? null : null);
return (
<DashboardLayout>
<div className="mb-6">
<Link href="/pension" className="inline-flex items-center text-sm text-muted-foreground hover:text-foreground mb-4">
<ArrowLeft className="mr-1 h-4 w-4" />
</Link>
<div className="flex items-center justify-between">
<div>
<h1 className="text-2xl font-bold text-foreground">{account.account_name}</h1>
<div className="flex items-center gap-2 mt-1">
<Badge>{accountTypeLabel[account.account_type] || account.account_type}</Badge>
<span className="text-sm text-muted-foreground">
{account.birth_year} / {account.target_retirement_age}
</span>
</div>
</div>
<Button onClick={handleAllocate} disabled={allocating}>
<RefreshCw className={`mr-2 h-4 w-4 ${allocating ? 'animate-spin' : ''}`} />
{allocating ? '배분 중...' : '자산 배분 실행'}
</Button>
</div>
</div>
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6">
{/* Pie Chart */}
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
{pieData.length > 0 ? (
<div className="h-80">
<ResponsiveContainer width="100%" height="100%">
<PieChart>
<Pie
data={pieData}
dataKey="value"
nameKey="name"
cx="50%"
cy="50%"
outerRadius={100}
label={({ name, payload }) => `${name} (${(payload?.ratio ?? 0).toFixed(1)}%)`}
labelLine
>
{pieData.map((_, index) => (
<Cell key={`cell-${index}`} fill={colors[index]} />
))}
</Pie>
<Tooltip
formatter={(value) => [`${formatAmount(value as number)}`, '금액']}
/>
<Legend />
</PieChart>
</ResponsiveContainer>
</div>
) : (
<div className="h-80 flex items-center justify-center text-muted-foreground">
</div>
)}
<div className="mt-4 text-center">
<p className="text-2xl font-bold">{formatAmount(account.total_amount)}</p>
<p className="text-sm text-muted-foreground"> </p>
</div>
</CardContent>
</Card>
{/* Holdings Table */}
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent className="p-0">
{account.holdings.length > 0 ? (
<table className="w-full">
<thead className="bg-muted">
<tr>
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-center text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground"></th>
</tr>
</thead>
<tbody className="divide-y divide-border">
{account.holdings.map((h) => (
<tr key={h.id}>
<td className="px-4 py-3 text-sm">{h.asset_name}</td>
<td className="px-4 py-3 text-center">
<Badge className={h.asset_type === 'risky'
? 'bg-red-100 text-red-800 dark:bg-red-900 dark:text-red-200'
: 'bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200'
}>
{h.asset_type === 'risky' ? '위험' : '안전'}
</Badge>
</td>
<td className="px-4 py-3 text-sm text-right font-mono">{formatAmount(h.amount)}</td>
<td className="px-4 py-3 text-sm text-right font-mono">{h.ratio.toFixed(1)}%</td>
</tr>
))}
</tbody>
</table>
) : (
<div className="p-8 text-center text-muted-foreground">
. .
</div>
)}
</CardContent>
</Card>
{/* Allocation Info */}
{(allocation || account.holdings.length > 0) && (
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent className="space-y-3">
<div className="flex justify-between">
<span className="text-muted-foreground"> </span>
<span className="font-mono font-medium">70%</span>
</div>
<div className="flex justify-between">
<span className="text-muted-foreground"> </span>
<span className="font-mono font-medium">30%</span>
</div>
{allocation && (
<>
<div className="border-t pt-3 flex justify-between">
<span className="text-muted-foreground"> </span>
<span className="font-mono font-medium">{allocation.current_age}</span>
</div>
<div className="flex justify-between">
<span className="text-muted-foreground"></span>
<span className="font-mono font-medium">{allocation.years_to_retirement}</span>
</div>
<div className="flex justify-between">
<span className="text-muted-foreground"> </span>
<span className="font-mono font-medium text-red-600">{allocation.glide_path_equity_pct.toFixed(1)}%</span>
</div>
<div className="flex justify-between">
<span className="text-muted-foreground"> </span>
<span className="font-mono font-medium text-green-600">{allocation.glide_path_bond_pct.toFixed(1)}%</span>
</div>
</>
)}
</CardContent>
</Card>
)}
{/* Recommendations */}
{recommendation && (
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2">
<Lightbulb className="h-5 w-5 text-yellow-500" />
TDF/ETF
</CardTitle>
</CardHeader>
<CardContent className="space-y-3">
{recommendation.recommendations.map((rec, i) => (
<div key={i} className="border rounded-lg p-3">
<div className="flex items-center justify-between mb-1">
<span className="font-medium">{rec.asset_name}</span>
<div className="flex items-center gap-2">
<Badge className={rec.asset_type === 'risky'
? 'bg-red-100 text-red-800 dark:bg-red-900 dark:text-red-200'
: 'bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200'
}>
{rec.asset_type === 'risky' ? '위험' : '안전'}
</Badge>
<span className="text-sm font-mono">{rec.ratio.toFixed(1)}%</span>
</div>
</div>
<p className="text-sm text-muted-foreground">{rec.reason}</p>
</div>
))}
</CardContent>
</Card>
)}
</div>
</DashboardLayout>
);
}

View File

@ -0,0 +1,157 @@
'use client';
import { useState } from 'react';
import { useRouter } from 'next/navigation';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Button } from '@/components/ui/button';
import { Input } from '@/components/ui/input';
import { Label } from '@/components/ui/label';
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select';
import { api } from '@/lib/api';
import { toast } from 'sonner';
import { ArrowLeft } from 'lucide-react';
import Link from 'next/link';
export default function NewPensionAccountPage() {
const router = useRouter();
const [submitting, setSubmitting] = useState(false);
const [form, setForm] = useState({
account_type: 'dc',
account_name: '',
total_amount: '',
birth_year: '',
target_retirement_age: '60',
});
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
if (!form.account_name || !form.total_amount || !form.birth_year) {
toast.error('필수 항목을 모두 입력해주세요.');
return;
}
setSubmitting(true);
try {
await api.post('/api/pension/accounts', {
account_type: form.account_type,
account_name: form.account_name,
total_amount: Number(form.total_amount),
birth_year: Number(form.birth_year),
target_retirement_age: Number(form.target_retirement_age),
});
toast.success('연금 계좌가 등록되었습니다.');
router.push('/pension');
} catch {
toast.error('계좌 등록에 실패했습니다.');
} finally {
setSubmitting(false);
}
};
return (
<DashboardLayout>
<div className="mb-6">
<Link href="/pension" className="inline-flex items-center text-sm text-muted-foreground hover:text-foreground mb-4">
<ArrowLeft className="mr-1 h-4 w-4" />
</Link>
<h1 className="text-2xl font-bold text-foreground"> </h1>
<p className="mt-1 text-muted-foreground">
DC형/IRP/
</p>
</div>
<Card className="max-w-lg">
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
<form onSubmit={handleSubmit} className="space-y-4">
<div className="space-y-2">
<Label> *</Label>
<Select value={form.account_type} onValueChange={(v) => setForm({ ...form, account_type: v })}>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="dc">DC형 ()</SelectItem>
<SelectItem value="irp">IRP ( )</SelectItem>
<SelectItem value="personal"></SelectItem>
</SelectContent>
</Select>
</div>
<div className="space-y-2">
<Label> *</Label>
<Input
type="text"
placeholder="예: 삼성생명 DC"
value={form.account_name}
onChange={(e) => setForm({ ...form, account_name: e.target.value })}
/>
</div>
<div className="space-y-2">
<Label> () *</Label>
<Input
type="number"
placeholder="예: 10000000"
value={form.total_amount}
onChange={(e) => setForm({ ...form, total_amount: e.target.value })}
min={0}
/>
</div>
<div className="space-y-2">
<Label> *</Label>
<Input
type="number"
placeholder="예: 1990"
value={form.birth_year}
onChange={(e) => setForm({ ...form, birth_year: e.target.value })}
min={1940}
max={2010}
/>
</div>
<div className="space-y-2">
<Label> </Label>
<Input
type="number"
value={form.target_retirement_age}
onChange={(e) => setForm({ ...form, target_retirement_age: e.target.value })}
min={50}
max={70}
/>
</div>
<div className="bg-muted/50 rounded-lg p-3 text-sm text-muted-foreground">
<p className="font-medium text-foreground mb-1"> </p>
<ul className="list-disc list-inside space-y-1">
<li>DC형/IRP: 위험자산 70% , 30% </li>
<li> </li>
<li>안전자산: 채권형 , , TDF, </li>
</ul>
</div>
<div className="flex gap-3">
<Button type="submit" disabled={submitting}>
{submitting ? '등록 중...' : '계좌 등록'}
</Button>
<Link href="/pension">
<Button type="button" variant="outline"></Button>
</Link>
</div>
</form>
</CardContent>
</Card>
</DashboardLayout>
);
}

View File

@ -0,0 +1,187 @@
'use client';
import { useEffect, useState } from 'react';
import { useRouter } from 'next/navigation';
import Link from 'next/link';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Badge } from '@/components/ui/badge';
import { Skeleton } from '@/components/ui/skeleton';
import { Button } from '@/components/ui/button';
import { api } from '@/lib/api';
import { toast } from 'sonner';
import { Plus, Wallet, PiggyBank } from 'lucide-react';
interface PensionHolding {
id: number;
account_id: number;
asset_name: string;
asset_type: string;
amount: number;
ratio: number;
}
interface PensionAccount {
id: number;
user_id: number;
account_type: string;
account_name: string;
total_amount: number;
birth_year: number;
target_retirement_age: number;
created_at: string;
updated_at: string;
holdings: PensionHolding[];
}
const formatAmount = (value: number) => {
return new Intl.NumberFormat('ko-KR').format(value);
};
const accountTypeLabel: Record<string, string> = {
dc: 'DC형',
irp: 'IRP',
personal: '개인연금',
};
export default function PensionPage() {
const router = useRouter();
const [loading, setLoading] = useState(true);
const [accounts, setAccounts] = useState<PensionAccount[]>([]);
useEffect(() => {
const init = async () => {
try {
await api.getCurrentUser();
const data = await api.get<PensionAccount[]>('/api/pension/accounts');
setAccounts(data);
} catch {
router.push('/login');
} finally {
setLoading(false);
}
};
init();
}, [router]);
if (loading) {
return (
<DashboardLayout>
<div className="space-y-6">
<Skeleton className="h-8 w-48" />
<Skeleton className="h-[400px]" />
</div>
</DashboardLayout>
);
}
const totalAmount = accounts.reduce((sum, a) => sum + a.total_amount, 0);
return (
<DashboardLayout>
<div className="mb-6 flex items-center justify-between">
<div>
<h1 className="text-2xl font-bold text-foreground"> </h1>
<p className="mt-1 text-muted-foreground">
DC형/IRP/
</p>
</div>
<Link href="/pension/new">
<Button>
<Plus className="mr-2 h-4 w-4" />
</Button>
</Link>
</div>
{/* Summary */}
<div className="grid grid-cols-1 md:grid-cols-3 gap-4 mb-6">
<Card>
<CardContent className="p-4">
<div className="flex items-center gap-2 text-muted-foreground mb-1">
<Wallet className="h-4 w-4" />
<span className="text-xs font-medium"> </span>
</div>
<p className="text-2xl font-bold">{accounts.length}</p>
</CardContent>
</Card>
<Card>
<CardContent className="p-4">
<div className="flex items-center gap-2 text-muted-foreground mb-1">
<PiggyBank className="h-4 w-4" />
<span className="text-xs font-medium"> </span>
</div>
<p className="text-2xl font-bold">{formatAmount(totalAmount)}</p>
</CardContent>
</Card>
</div>
{/* Account List */}
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
{accounts.map((account) => {
const riskyAmount = account.holdings
.filter((h) => h.asset_type === 'risky')
.reduce((sum, h) => sum + h.amount, 0);
const safeAmount = account.holdings
.filter((h) => h.asset_type === 'safe')
.reduce((sum, h) => sum + h.amount, 0);
const riskyPct = account.total_amount > 0
? ((riskyAmount / account.total_amount) * 100).toFixed(1)
: '0.0';
return (
<Card
key={account.id}
className="cursor-pointer hover:bg-muted/50 transition-colors"
onClick={() => router.push(`/pension/${account.id}`)}
>
<CardHeader className="pb-2">
<div className="flex items-center justify-between">
<CardTitle className="text-lg">{account.account_name}</CardTitle>
<Badge>{accountTypeLabel[account.account_type] || account.account_type}</Badge>
</div>
</CardHeader>
<CardContent>
<p className="text-2xl font-bold mb-3">
{formatAmount(account.total_amount)}
</p>
{account.holdings.length > 0 ? (
<div className="space-y-2">
<div className="flex justify-between text-sm">
<span className="text-muted-foreground"></span>
<span className="text-red-600">{formatAmount(riskyAmount)} ({riskyPct}%)</span>
</div>
<div className="flex justify-between text-sm">
<span className="text-muted-foreground"></span>
<span className="text-green-600">{formatAmount(safeAmount)}</span>
</div>
<div className="h-2 bg-muted rounded-full overflow-hidden">
<div
className="h-full bg-red-500 rounded-full"
style={{ width: `${riskyPct}%` }}
/>
</div>
</div>
) : (
<p className="text-sm text-muted-foreground">
</p>
)}
<p className="text-xs text-muted-foreground mt-3">
: {account.birth_year} / : {account.target_retirement_age}
</p>
</CardContent>
</Card>
);
})}
{accounts.length === 0 && (
<Card className="col-span-full">
<CardContent className="p-8 text-center text-muted-foreground">
. .
</CardContent>
</Card>
)}
</div>
</DashboardLayout>
);
}

View File

@ -0,0 +1,420 @@
'use client';
import { useState } from 'react';
import {
LineChart,
Line,
XAxis,
YAxis,
CartesianGrid,
Tooltip,
ResponsiveContainer,
Legend,
} from 'recharts';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Button } from '@/components/ui/button';
import { api } from '@/lib/api';
interface TaxDeductionResult {
annual_income: number;
contribution: number;
account_type: string;
deduction_rate: number;
irp_limit: number;
deductible_contribution: number;
tax_deduction: number;
}
interface PensionTaxResult {
withdrawal_amount: number;
withdrawal_type: string;
age: number;
pension_tax_rate: number;
pension_tax: number;
lump_sum_tax_rate: number;
lump_sum_tax: number;
tax_saving: number;
}
interface YearlyData {
year: number;
contribution: number;
cumulative_contribution: number;
investment_value: number;
tax_deduction: number;
cumulative_tax_deduction: number;
}
interface AccumulationResult {
monthly_contribution: number;
years: number;
annual_return: number;
tax_deduction_rate: number;
total_contribution: number;
final_value: number;
total_return: number;
total_tax_deduction: number;
yearly_data: YearlyData[];
}
function formatKRW(value: number): string {
if (value >= 100_000_000) {
return `${(value / 100_000_000).toFixed(1)}억원`;
}
if (value >= 10_000) {
return `${Math.round(value / 10_000).toLocaleString()}만원`;
}
return `${value.toLocaleString()}`;
}
export default function TaxSimulatorPage() {
// 세액공제 계산기
const [annualIncome, setAnnualIncome] = useState(50_000_000);
const [contribution, setContribution] = useState(9_000_000);
const [accountType, setAccountType] = useState<'irp' | 'dc'>('irp');
const [deductionResult, setDeductionResult] = useState<TaxDeductionResult | null>(null);
// 수령 방식 비교
const [withdrawalAmount, setWithdrawalAmount] = useState(100_000_000);
const [age, setAge] = useState(65);
const [pensionTaxResult, setPensionTaxResult] = useState<PensionTaxResult | null>(null);
// 적립 시뮬레이션
const [monthlyContribution, setMonthlyContribution] = useState(500_000);
const [years, setYears] = useState(20);
const [annualReturn, setAnnualReturn] = useState(7);
const [accumulationResult, setAccumulationResult] = useState<AccumulationResult | null>(null);
const [loading, setLoading] = useState({ deduction: false, pension: false, accumulation: false });
const calculateDeduction = async () => {
setLoading(prev => ({ ...prev, deduction: true }));
try {
const result = await api.post<TaxDeductionResult>('/api/tax/deduction', {
annual_income: annualIncome,
contribution,
account_type: accountType,
});
setDeductionResult(result);
} catch {
// ignore
} finally {
setLoading(prev => ({ ...prev, deduction: false }));
}
};
const calculatePensionTax = async () => {
setLoading(prev => ({ ...prev, pension: true }));
try {
const result = await api.post<PensionTaxResult>('/api/tax/pension-tax', {
withdrawal_amount: withdrawalAmount,
withdrawal_type: 'pension',
age,
});
setPensionTaxResult(result);
} catch {
// ignore
} finally {
setLoading(prev => ({ ...prev, pension: false }));
}
};
const calculateAccumulation = async () => {
setLoading(prev => ({ ...prev, accumulation: true }));
try {
const deductionRate = annualIncome <= 55_000_000 ? 16.5 : 13.2;
const result = await api.post<AccumulationResult>('/api/tax/accumulation', {
monthly_contribution: monthlyContribution,
years,
annual_return: annualReturn,
tax_deduction_rate: deductionRate,
});
setAccumulationResult(result);
} catch {
// ignore
} finally {
setLoading(prev => ({ ...prev, accumulation: false }));
}
};
return (
<DashboardLayout>
<div className="space-y-6">
<div>
<h1 className="text-2xl font-bold"> </h1>
<p className="text-sm text-muted-foreground mt-1">
, ,
</p>
</div>
{/* 세액공제 계산기 */}
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
<div className="grid grid-cols-1 md:grid-cols-3 gap-4 mb-4">
<div>
<label className="block text-sm font-medium mb-1"> ()</label>
<input
type="number"
value={annualIncome}
onChange={e => setAnnualIncome(Number(e.target.value))}
className="w-full rounded-md border px-3 py-2 text-sm"
step={1_000_000}
/>
<p className="text-xs text-muted-foreground mt-1">
{annualIncome <= 55_000_000 ? '공제율 16.5%' : '공제율 13.2%'}
</p>
</div>
<div>
<label className="block text-sm font-medium mb-1"> ()</label>
<input
type="number"
value={contribution}
onChange={e => setContribution(Number(e.target.value))}
className="w-full rounded-md border px-3 py-2 text-sm"
step={100_000}
/>
<p className="text-xs text-muted-foreground mt-1">한도: 900만원</p>
</div>
<div>
<label className="block text-sm font-medium mb-1"> </label>
<select
value={accountType}
onChange={e => setAccountType(e.target.value as 'irp' | 'dc')}
className="w-full rounded-md border px-3 py-2 text-sm"
>
<option value="irp">IRP</option>
<option value="dc">DC ()</option>
</select>
</div>
</div>
<Button onClick={calculateDeduction} disabled={loading.deduction}>
{loading.deduction ? '계산 중...' : '세액공제 계산'}
</Button>
{deductionResult && (
<div className="mt-4 grid grid-cols-2 md:grid-cols-4 gap-4">
<div className="rounded-lg border p-4">
<p className="text-sm text-muted-foreground"></p>
<p className="text-2xl font-bold">{deductionResult.deduction_rate}%</p>
</div>
<div className="rounded-lg border p-4">
<p className="text-sm text-muted-foreground"> </p>
<p className="text-2xl font-bold">{formatKRW(deductionResult.deductible_contribution)}</p>
</div>
<div className="rounded-lg border p-4 border-green-200 bg-green-50 dark:border-green-800 dark:bg-green-950">
<p className="text-sm text-muted-foreground"> </p>
<p className="text-2xl font-bold text-green-600 dark:text-green-400">
{formatKRW(deductionResult.tax_deduction)}
</p>
</div>
<div className="rounded-lg border p-4">
<p className="text-sm text-muted-foreground"> </p>
<p className="text-2xl font-bold">{formatKRW(deductionResult.irp_limit)}</p>
</div>
</div>
)}
</CardContent>
</Card>
{/* 수령 방식 비교 */}
<Card>
<CardHeader>
<CardTitle> ( vs )</CardTitle>
</CardHeader>
<CardContent>
<div className="grid grid-cols-1 md:grid-cols-2 gap-4 mb-4">
<div>
<label className="block text-sm font-medium mb-1"> ()</label>
<input
type="number"
value={withdrawalAmount}
onChange={e => setWithdrawalAmount(Number(e.target.value))}
className="w-full rounded-md border px-3 py-2 text-sm"
step={10_000_000}
/>
</div>
<div>
<label className="block text-sm font-medium mb-1"> </label>
<input
type="number"
value={age}
onChange={e => setAge(Number(e.target.value))}
className="w-full rounded-md border px-3 py-2 text-sm"
min={55}
max={100}
/>
<p className="text-xs text-muted-foreground mt-1">
{age < 70 ? '5.5%' : age < 80 ? '4.4%' : '3.3%'}
</p>
</div>
</div>
<Button onClick={calculatePensionTax} disabled={loading.pension}>
{loading.pension ? '계산 중...' : '세금 비교'}
</Button>
{pensionTaxResult && (
<div className="mt-4">
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
<div className="rounded-lg border p-4">
<p className="text-sm text-muted-foreground"> </p>
<p className="text-xl font-bold">{formatKRW(pensionTaxResult.pension_tax)}</p>
<p className="text-xs text-muted-foreground"> {pensionTaxResult.pension_tax_rate}%</p>
</div>
<div className="rounded-lg border p-4">
<p className="text-sm text-muted-foreground"> </p>
<p className="text-xl font-bold text-red-500">{formatKRW(pensionTaxResult.lump_sum_tax)}</p>
<p className="text-xs text-muted-foreground"> {pensionTaxResult.lump_sum_tax_rate}%</p>
</div>
<div className="rounded-lg border p-4 border-green-200 bg-green-50 dark:border-green-800 dark:bg-green-950">
<p className="text-sm text-muted-foreground"> </p>
<p className="text-xl font-bold text-green-600 dark:text-green-400">
{formatKRW(pensionTaxResult.tax_saving)}
</p>
</div>
</div>
</div>
)}
</CardContent>
</Card>
{/* 적립 시뮬레이션 */}
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
<div className="grid grid-cols-1 md:grid-cols-3 gap-4 mb-4">
<div>
<label className="block text-sm font-medium mb-1"> ()</label>
<input
type="number"
value={monthlyContribution}
onChange={e => setMonthlyContribution(Number(e.target.value))}
className="w-full rounded-md border px-3 py-2 text-sm"
step={100_000}
/>
</div>
<div>
<label className="block text-sm font-medium mb-1"> ()</label>
<input
type="number"
value={years}
onChange={e => setYears(Number(e.target.value))}
className="w-full rounded-md border px-3 py-2 text-sm"
min={1}
max={50}
/>
</div>
<div>
<label className="block text-sm font-medium mb-1"> (%)</label>
<input
type="number"
value={annualReturn}
onChange={e => setAnnualReturn(Number(e.target.value))}
className="w-full rounded-md border px-3 py-2 text-sm"
step={0.5}
min={0}
max={30}
/>
</div>
</div>
<Button onClick={calculateAccumulation} disabled={loading.accumulation}>
{loading.accumulation ? '계산 중...' : '시뮬레이션 실행'}
</Button>
{accumulationResult && (
<div className="mt-4 space-y-4">
<div className="grid grid-cols-2 md:grid-cols-4 gap-4">
<div className="rounded-lg border p-4">
<p className="text-sm text-muted-foreground"> </p>
<p className="text-xl font-bold">{formatKRW(accumulationResult.total_contribution)}</p>
</div>
<div className="rounded-lg border p-4">
<p className="text-sm text-muted-foreground"> </p>
<p className="text-xl font-bold text-blue-600 dark:text-blue-400">
{formatKRW(accumulationResult.final_value)}
</p>
</div>
<div className="rounded-lg border p-4">
<p className="text-sm text-muted-foreground"> </p>
<p className="text-xl font-bold text-red-500">{formatKRW(accumulationResult.total_return)}</p>
</div>
<div className="rounded-lg border p-4 border-green-200 bg-green-50 dark:border-green-800 dark:bg-green-950">
<p className="text-sm text-muted-foreground"> </p>
<p className="text-xl font-bold text-green-600 dark:text-green-400">
{formatKRW(accumulationResult.total_tax_deduction)}
</p>
</div>
</div>
<div className="h-[400px]">
<ResponsiveContainer width="100%" height="100%">
<LineChart data={accumulationResult.yearly_data}>
<CartesianGrid strokeDasharray="3 3" className="stroke-muted" />
<XAxis
dataKey="year"
tick={{ fontSize: 12 }}
tickFormatter={(v: number) => `${v}`}
/>
<YAxis
tick={{ fontSize: 12 }}
tickFormatter={(v: number) => formatKRW(v)}
/>
<Tooltip
formatter={(value: number | undefined, name: string | undefined) => {
const labels: Record<string, string> = {
investment_value: '적립금',
cumulative_contribution: '납입 원금',
cumulative_tax_deduction: '누적 세액공제',
};
return [formatKRW(value ?? 0), labels[name ?? ''] || name || ''];
}}
labelFormatter={(label) => `${label}년차`}
/>
<Legend
formatter={(value: string) => {
const labels: Record<string, string> = {
investment_value: '적립금',
cumulative_contribution: '납입 원금',
cumulative_tax_deduction: '누적 세액공제',
};
return labels[value] || value;
}}
/>
<Line
type="monotone"
dataKey="investment_value"
stroke="#2563eb"
strokeWidth={2}
dot={false}
/>
<Line
type="monotone"
dataKey="cumulative_contribution"
stroke="#9ca3af"
strokeWidth={1.5}
strokeDasharray="5 5"
dot={false}
/>
<Line
type="monotone"
dataKey="cumulative_tax_deduction"
stroke="#16a34a"
strokeWidth={1.5}
dot={false}
/>
</LineChart>
</ResponsiveContainer>
</div>
</div>
)}
</CardContent>
</Card>
</div>
</DashboardLayout>
);
}

View File

@ -0,0 +1,325 @@
'use client';
import { useEffect, useState } from 'react';
import { useParams, useRouter } from 'next/navigation';
import Link from 'next/link';
import {
LineChart,
Line,
XAxis,
YAxis,
CartesianGrid,
Tooltip,
ResponsiveContainer,
Legend,
} from 'recharts';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Button } from '@/components/ui/button';
import { Skeleton } from '@/components/ui/skeleton';
import { api } from '@/lib/api';
interface TimeSeriesPoint {
date: string;
portfolio_return: number | null;
benchmark_return: number | null;
deposit_return: number | null;
}
interface PerformanceMetrics {
cumulative_return: number;
annualized_return: number;
sharpe_ratio: number | null;
max_drawdown: number;
}
interface BenchmarkCompareData {
portfolio_name: string;
benchmark_type: string;
period: string;
start_date: string;
end_date: string;
time_series: TimeSeriesPoint[];
portfolio_metrics: PerformanceMetrics;
benchmark_metrics: PerformanceMetrics;
deposit_metrics: PerformanceMetrics;
alpha: number;
information_ratio: number | null;
}
const PERIODS = [
{ value: '1m', label: '1개월' },
{ value: '3m', label: '3개월' },
{ value: '6m', label: '6개월' },
{ value: '1y', label: '1년' },
{ value: 'all', label: '전체' },
] as const;
function formatPercent(value: number | null): string {
if (value === null || value === undefined) return '-';
return `${value >= 0 ? '+' : ''}${value.toFixed(2)}%`;
}
function formatRatio(value: number | null): string {
if (value === null || value === undefined) return '-';
return value.toFixed(4);
}
export default function BenchmarkPage() {
const params = useParams();
const router = useRouter();
const portfolioId = params.id as string;
const [data, setData] = useState<BenchmarkCompareData | null>(null);
const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const [period, setPeriod] = useState<string>('1y');
const fetchData = async (selectedPeriod: string) => {
setLoading(true);
setError(null);
try {
const result = await api.get<BenchmarkCompareData>(
`/api/benchmark/compare/${portfolioId}?benchmark=kospi&period=${selectedPeriod}`
);
setData(result);
} catch (err) {
if (err instanceof Error && err.message === 'API request failed') {
router.push('/login');
return;
}
setError(err instanceof Error ? err.message : '데이터를 불러올 수 없습니다');
} finally {
setLoading(false);
}
};
useEffect(() => {
fetchData(period);
}, [portfolioId]); // eslint-disable-line react-hooks/exhaustive-deps
const handlePeriodChange = (newPeriod: string) => {
setPeriod(newPeriod);
fetchData(newPeriod);
};
return (
<DashboardLayout>
<div className="space-y-6">
{/* Header */}
<div className="flex items-center justify-between">
<div>
<div className="flex items-center gap-2 text-sm text-muted-foreground mb-1">
<Link href="/portfolio" className="hover:text-foreground">
</Link>
<span>/</span>
<Link href={`/portfolio/${portfolioId}`} className="hover:text-foreground">
{data?.portfolio_name || '...'}
</Link>
<span>/</span>
<span> </span>
</div>
<h1 className="text-2xl font-bold"> </h1>
</div>
<Link href={`/portfolio/${portfolioId}`}>
<Button variant="outline"></Button>
</Link>
</div>
{/* Period selector */}
<div className="flex gap-2">
{PERIODS.map((p) => (
<Button
key={p.value}
variant={period === p.value ? 'default' : 'outline'}
size="sm"
onClick={() => handlePeriodChange(p.value)}
>
{p.label}
</Button>
))}
</div>
{error && (
<Card>
<CardContent className="py-8 text-center text-red-500">
{error}
</CardContent>
</Card>
)}
{loading && (
<div className="space-y-4">
<Skeleton className="h-[400px] w-full" />
<Skeleton className="h-[200px] w-full" />
</div>
)}
{!loading && !error && data && (
<>
{/* Chart */}
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
<div className="h-[400px]">
<ResponsiveContainer width="100%" height="100%">
<LineChart data={data.time_series}>
<CartesianGrid strokeDasharray="3 3" className="stroke-muted" />
<XAxis
dataKey="date"
tick={{ fontSize: 12 }}
tickFormatter={(v: string) => {
const d = new Date(v);
return `${d.getMonth() + 1}/${d.getDate()}`;
}}
/>
<YAxis
tick={{ fontSize: 12 }}
tickFormatter={(v: number) => `${v.toFixed(1)}%`}
/>
<Tooltip
formatter={(value: number | undefined, name: string) => {
const labels: Record<string, string> = {
portfolio_return: '포트폴리오',
benchmark_return: 'KOSPI',
deposit_return: '정기예금',
};
return [`${value != null ? value.toFixed(2) : '-'}%`, labels[name] || name];
}}
labelFormatter={(label) => {
return new Date(String(label)).toLocaleDateString('ko-KR');
}}
/>
<Legend
formatter={(value: string) => {
const labels: Record<string, string> = {
portfolio_return: '포트폴리오',
benchmark_return: 'KOSPI',
deposit_return: '정기예금',
};
return labels[value] || value;
}}
/>
<Line
type="monotone"
dataKey="portfolio_return"
stroke="#2563eb"
strokeWidth={2}
dot={false}
connectNulls
/>
<Line
type="monotone"
dataKey="benchmark_return"
stroke="#dc2626"
strokeWidth={2}
dot={false}
connectNulls
/>
<Line
type="monotone"
dataKey="deposit_return"
stroke="#16a34a"
strokeWidth={1.5}
strokeDasharray="5 5"
dot={false}
connectNulls
/>
</LineChart>
</ResponsiveContainer>
</div>
</CardContent>
</Card>
{/* Performance comparison table */}
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
<div className="overflow-x-auto">
<table className="w-full text-sm">
<thead>
<tr className="border-b">
<th className="text-left py-3 px-4 font-medium"></th>
<th className="text-right py-3 px-4 font-medium"></th>
<th className="text-right py-3 px-4 font-medium">KOSPI</th>
<th className="text-right py-3 px-4 font-medium"></th>
</tr>
</thead>
<tbody>
<tr className="border-b">
<td className="py-3 px-4"> </td>
<td className={`text-right py-3 px-4 font-mono ${data.portfolio_metrics.cumulative_return >= 0 ? 'text-red-500' : 'text-blue-500'}`}>
{formatPercent(data.portfolio_metrics.cumulative_return)}
</td>
<td className={`text-right py-3 px-4 font-mono ${data.benchmark_metrics.cumulative_return >= 0 ? 'text-red-500' : 'text-blue-500'}`}>
{formatPercent(data.benchmark_metrics.cumulative_return)}
</td>
<td className="text-right py-3 px-4 font-mono text-red-500">
{formatPercent(data.deposit_metrics.cumulative_return)}
</td>
</tr>
<tr className="border-b">
<td className="py-3 px-4"> </td>
<td className="text-right py-3 px-4 font-mono">
{formatPercent(data.portfolio_metrics.annualized_return)}
</td>
<td className="text-right py-3 px-4 font-mono">
{formatPercent(data.benchmark_metrics.annualized_return)}
</td>
<td className="text-right py-3 px-4 font-mono">
{formatPercent(data.deposit_metrics.annualized_return)}
</td>
</tr>
<tr className="border-b">
<td className="py-3 px-4"> ()</td>
<td colSpan={3} className={`text-right py-3 px-4 font-mono font-bold ${data.alpha >= 0 ? 'text-red-500' : 'text-blue-500'}`}>
{formatPercent(data.alpha)}
</td>
</tr>
<tr className="border-b">
<td className="py-3 px-4"> </td>
<td className="text-right py-3 px-4 font-mono">
{formatRatio(data.portfolio_metrics.sharpe_ratio)}
</td>
<td className="text-right py-3 px-4 font-mono">
{formatRatio(data.benchmark_metrics.sharpe_ratio)}
</td>
<td className="text-right py-3 px-4 font-mono">-</td>
</tr>
<tr className="border-b">
<td className="py-3 px-4"> </td>
<td className="text-right py-3 px-4 font-mono text-blue-500">
{formatPercent(data.portfolio_metrics.max_drawdown)}
</td>
<td className="text-right py-3 px-4 font-mono text-blue-500">
{formatPercent(data.benchmark_metrics.max_drawdown)}
</td>
<td className="text-right py-3 px-4 font-mono">-</td>
</tr>
<tr>
<td className="py-3 px-4"> </td>
<td colSpan={3} className="text-right py-3 px-4 font-mono">
{formatRatio(data.information_ratio)}
</td>
</tr>
</tbody>
</table>
</div>
</CardContent>
</Card>
{/* Period info */}
<p className="text-sm text-muted-foreground text-center">
: {new Date(data.start_date).toLocaleDateString('ko-KR')} ~ {new Date(data.end_date).toLocaleDateString('ko-KR')}
</p>
</>
)}
</div>
</DashboardLayout>
);
}

View File

@ -0,0 +1,364 @@
'use client';
import { useEffect, useState } from 'react';
import { useParams, useRouter } from 'next/navigation';
import Link from 'next/link';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Button } from '@/components/ui/button';
import { Skeleton } from '@/components/ui/skeleton';
import { api } from '@/lib/api';
interface HighCorrelationPair {
stock_a: string;
stock_b: string;
correlation: number;
}
interface CorrelationMatrixData {
stock_codes: string[];
matrix: (number | null)[][];
high_correlation_pairs: HighCorrelationPair[];
}
interface DiversificationData {
portfolio_id: number;
diversification_score: number;
stock_count: number;
high_correlation_pairs: HighCorrelationPair[];
}
interface PortfolioHolding {
ticker: string;
quantity: number;
avg_price: number;
current_price?: number;
name?: string;
}
interface PortfolioDetail {
id: number;
name: string;
holdings: PortfolioHolding[];
}
const PERIODS = [
{ value: 30, label: '1개월' },
{ value: 90, label: '3개월' },
{ value: 180, label: '6개월' },
] as const;
function getCorrelationColor(value: number | null): string {
if (value === null) return 'bg-gray-100 dark:bg-gray-800';
if (value >= 0) {
// 0 (white) -> 1 (blue)
const intensity = Math.round(value * 255);
return `rgb(${255 - intensity}, ${255 - intensity}, 255)`;
} else {
// 0 (white) -> -1 (red)
const intensity = Math.round(Math.abs(value) * 255);
return `rgb(255, ${255 - intensity}, ${255 - intensity})`;
}
}
function getCorrelationStyle(value: number | null): React.CSSProperties {
if (value === null) return { backgroundColor: '#f3f4f6' };
if (value >= 0) {
const intensity = Math.round(value * 255);
return { backgroundColor: `rgb(${255 - intensity}, ${255 - intensity}, 255)` };
} else {
const intensity = Math.round(Math.abs(value) * 255);
return { backgroundColor: `rgb(255, ${255 - intensity}, ${255 - intensity})` };
}
}
function getTextColor(value: number | null): string {
if (value === null) return 'text-gray-400';
if (Math.abs(value) > 0.6) return 'text-white';
return 'text-gray-900 dark:text-gray-100';
}
export default function CorrelationPage() {
const params = useParams();
const router = useRouter();
const portfolioId = params.id as string;
const [portfolio, setPortfolio] = useState<PortfolioDetail | null>(null);
const [correlationData, setCorrelationData] = useState<CorrelationMatrixData | null>(null);
const [diversification, setDiversification] = useState<DiversificationData | null>(null);
const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const [periodDays, setPeriodDays] = useState<number>(90);
const fetchData = async (days: number) => {
setLoading(true);
setError(null);
try {
// Fetch portfolio details first to get holdings
const portfolioData = await api.get<PortfolioDetail>(
`/api/portfolio/${portfolioId}`
);
setPortfolio(portfolioData);
if (!portfolioData.holdings || portfolioData.holdings.length === 0) {
setCorrelationData(null);
setDiversification(null);
setLoading(false);
return;
}
const tickers = portfolioData.holdings.map((h) => h.ticker);
// Fetch correlation matrix and diversification in parallel
const [corrResult, divResult] = await Promise.all([
api.post<CorrelationMatrixData>('/api/correlation/matrix', {
stock_codes: tickers,
period_days: days,
}),
api.get<DiversificationData>(
`/api/correlation/portfolio/${portfolioId}`
),
]);
setCorrelationData(corrResult);
setDiversification(divResult);
} catch (err) {
if (err instanceof Error && err.message === 'API request failed') {
router.push('/login');
return;
}
setError(err instanceof Error ? err.message : '데이터를 불러올 수 없습니다');
} finally {
setLoading(false);
}
};
useEffect(() => {
fetchData(periodDays);
}, [portfolioId]); // eslint-disable-line react-hooks/exhaustive-deps
const handlePeriodChange = (days: number) => {
setPeriodDays(days);
fetchData(days);
};
const scoreColor = (score: number) => {
if (score >= 0.7) return 'text-green-600';
if (score >= 0.4) return 'text-yellow-600';
return 'text-red-600';
};
const scoreLabel = (score: number) => {
if (score >= 0.7) return '우수';
if (score >= 0.4) return '보통';
return '미흡';
};
return (
<DashboardLayout>
<div className="space-y-6">
{/* Header */}
<div className="flex items-center justify-between">
<div>
<div className="flex items-center gap-2 text-sm text-muted-foreground mb-1">
<Link href="/portfolio" className="hover:text-foreground">
</Link>
<span>/</span>
<Link href={`/portfolio/${portfolioId}`} className="hover:text-foreground">
{portfolio?.name || '...'}
</Link>
<span>/</span>
<span> </span>
</div>
<h1 className="text-2xl font-bold"> </h1>
</div>
<Link href={`/portfolio/${portfolioId}`}>
<Button variant="outline"></Button>
</Link>
</div>
{/* Period selector */}
<div className="flex gap-2">
{PERIODS.map((p) => (
<Button
key={p.value}
variant={periodDays === p.value ? 'default' : 'outline'}
size="sm"
onClick={() => handlePeriodChange(p.value)}
>
{p.label}
</Button>
))}
</div>
{error && (
<Card>
<CardContent className="py-8 text-center text-red-500">
{error}
</CardContent>
</Card>
)}
{loading && (
<div className="space-y-4">
<Skeleton className="h-[100px] w-full" />
<Skeleton className="h-[400px] w-full" />
</div>
)}
{!loading && !error && (
<>
{/* Diversification Score Card */}
{diversification && (
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
<div className="flex items-center gap-6">
<div className="text-center">
<div className={`text-4xl font-bold ${scoreColor(diversification.diversification_score)}`}>
{(diversification.diversification_score * 100).toFixed(1)}%
</div>
<div className={`text-sm font-medium ${scoreColor(diversification.diversification_score)}`}>
{scoreLabel(diversification.diversification_score)}
</div>
</div>
<div className="text-sm text-muted-foreground">
<p> : {diversification.stock_count}</p>
<p className="mt-1">
0% = ( )
</p>
<p>100% = ( )</p>
</div>
</div>
</CardContent>
</Card>
)}
{/* Correlation Heatmap */}
{correlationData && correlationData.stock_codes.length > 0 && (
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
<div className="overflow-x-auto">
<div
className="inline-grid gap-px"
style={{
gridTemplateColumns: `80px repeat(${correlationData.stock_codes.length}, minmax(60px, 1fr))`,
}}
>
{/* Header row */}
<div className="p-2" />
{correlationData.stock_codes.map((code) => (
<div
key={`header-${code}`}
className="p-2 text-center text-xs font-medium truncate"
title={code}
>
{code}
</div>
))}
{/* Data rows */}
{correlationData.stock_codes.map((rowCode, i) => (
<>
<div
key={`label-${rowCode}`}
className="p-2 text-xs font-medium truncate flex items-center"
title={rowCode}
>
{rowCode}
</div>
{correlationData.stock_codes.map((colCode, j) => {
const value = correlationData.matrix[i]?.[j] ?? null;
return (
<div
key={`cell-${rowCode}-${colCode}`}
className={`p-2 text-center text-xs font-mono ${getTextColor(value)}`}
style={getCorrelationStyle(value)}
title={`${rowCode} vs ${colCode}: ${value !== null ? value.toFixed(4) : 'N/A'}`}
>
{value !== null ? value.toFixed(2) : '-'}
</div>
);
})}
</>
))}
</div>
{/* Legend */}
<div className="flex items-center justify-center gap-2 mt-4 text-xs text-muted-foreground">
<div className="flex items-center gap-1">
<div className="w-4 h-4 rounded" style={{ backgroundColor: 'rgb(255, 0, 0)' }} />
<span>-1 ()</span>
</div>
<div className="flex items-center gap-1">
<div className="w-4 h-4 rounded border" style={{ backgroundColor: 'rgb(255, 255, 255)' }} />
<span>0 ()</span>
</div>
<div className="flex items-center gap-1">
<div className="w-4 h-4 rounded" style={{ backgroundColor: 'rgb(0, 0, 255)' }} />
<span>+1 ()</span>
</div>
</div>
</div>
</CardContent>
</Card>
)}
{/* High Correlation Warning */}
{correlationData && correlationData.high_correlation_pairs.length > 0 && (
<Card className="border-yellow-300 dark:border-yellow-700">
<CardHeader>
<CardTitle className="text-yellow-600 dark:text-yellow-400">
</CardTitle>
</CardHeader>
<CardContent>
<p className="text-sm text-muted-foreground mb-3">
|r| &gt; 0.7 .
</p>
<div className="space-y-2">
{correlationData.high_correlation_pairs.map((pair) => (
<div
key={`${pair.stock_a}-${pair.stock_b}`}
className="flex items-center justify-between p-3 rounded-lg bg-yellow-50 dark:bg-yellow-900/20"
>
<span className="font-mono text-sm">
{pair.stock_a} - {pair.stock_b}
</span>
<span
className={`font-mono text-sm font-bold ${
pair.correlation > 0 ? 'text-blue-600' : 'text-red-600'
}`}
>
{pair.correlation > 0 ? '+' : ''}
{pair.correlation.toFixed(4)}
</span>
</div>
))}
</div>
</CardContent>
</Card>
)}
{/* No holdings message */}
{!correlationData && !diversification && (
<Card>
<CardContent className="py-8 text-center text-muted-foreground">
.
</CardContent>
</Card>
)}
</>
)}
</div>
</DashboardLayout>
);
}

View File

@ -0,0 +1,333 @@
'use client';
import { useEffect, useState } from 'react';
import { useParams, useRouter } from 'next/navigation';
import Link from 'next/link';
import {
AreaChart,
Area,
XAxis,
YAxis,
CartesianGrid,
Tooltip,
ResponsiveContainer,
ReferenceLine,
} from 'recharts';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Button } from '@/components/ui/button';
import { api } from '@/lib/api';
interface DrawdownData {
portfolio_id: number;
current_drawdown_pct: number;
max_drawdown_pct: number;
max_drawdown_date: string | null;
peak_value: number | null;
peak_date: string | null;
trough_value: number | null;
trough_date: string | null;
alert_threshold_pct: number;
}
interface DrawdownDataPoint {
date: string;
total_value: number;
peak: number;
drawdown_pct: number;
}
interface DrawdownHistoryData {
portfolio_id: number;
data: DrawdownDataPoint[];
max_drawdown_pct: number;
current_drawdown_pct: number;
}
export default function DrawdownPage() {
const params = useParams();
const router = useRouter();
const portfolioId = params.id as string;
const [drawdown, setDrawdown] = useState<DrawdownData | null>(null);
const [history, setHistory] = useState<DrawdownHistoryData | null>(null);
const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const [threshold, setThreshold] = useState<string>('');
const [saving, setSaving] = useState(false);
const fetchData = async () => {
try {
const [dd, hist] = await Promise.all([
api.get<DrawdownData>(`/api/drawdown/${portfolioId}`),
api.get<DrawdownHistoryData>(`/api/drawdown/${portfolioId}/history`),
]);
setDrawdown(dd);
setHistory(hist);
setThreshold(String(dd.alert_threshold_pct));
} catch (err) {
if (err instanceof Error && err.message === 'API request failed') {
router.push('/login');
return;
}
setError(err instanceof Error ? err.message : 'An error occurred');
} finally {
setLoading(false);
}
};
useEffect(() => {
fetchData();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [portfolioId]);
const handleSaveThreshold = async () => {
const value = parseFloat(threshold);
if (isNaN(value) || value <= 0 || value > 100) {
setError('한도는 0~100 사이 값이어야 합니다.');
return;
}
setSaving(true);
setError(null);
try {
await api.put(`/api/drawdown/settings/${portfolioId}`, {
alert_threshold_pct: value,
});
await fetchData();
} catch (err) {
setError(err instanceof Error ? err.message : 'An error occurred');
} finally {
setSaving(false);
}
};
const formatCurrency = (value: number) => {
return new Intl.NumberFormat('ko-KR', {
style: 'currency',
currency: 'KRW',
maximumFractionDigits: 0,
}).format(value);
};
const formatDate = (dateStr: string) => {
return new Date(dateStr).toLocaleDateString('ko-KR', {
year: 'numeric',
month: '2-digit',
day: '2-digit',
});
};
const getGaugeColor = (pct: number, threshold: number) => {
if (pct >= threshold) return 'text-red-600';
if (pct >= threshold * 0.7) return 'text-amber-500';
return 'text-green-600';
};
const getGaugeBgColor = (pct: number, threshold: number) => {
if (pct >= threshold) return 'bg-red-500';
if (pct >= threshold * 0.7) return 'bg-amber-500';
return 'bg-green-500';
};
if (loading) return null;
return (
<DashboardLayout>
{/* Header */}
<div className="mb-6">
<Link
href={`/portfolio/${portfolioId}`}
className="text-primary hover:underline text-sm"
>
</Link>
<h1 className="text-2xl font-bold text-foreground mt-2">
(Drawdown)
</h1>
</div>
{error && (
<div className="bg-destructive/10 border border-destructive text-destructive px-4 py-3 rounded mb-6">
{error}
</div>
)}
{drawdown && (
<>
{/* Gauge + Summary Cards */}
<div className="grid grid-cols-1 md:grid-cols-4 gap-4 mb-6">
{/* Gauge Card */}
<Card className="md:col-span-1">
<CardHeader>
<CardTitle className="text-sm"> </CardTitle>
</CardHeader>
<CardContent className="flex flex-col items-center">
<div className="relative w-32 h-32">
{/* Background circle */}
<svg className="w-32 h-32 transform -rotate-90" viewBox="0 0 120 120">
<circle
cx="60" cy="60" r="50"
fill="none"
stroke="currentColor"
strokeWidth="10"
className="text-muted/30"
/>
<circle
cx="60" cy="60" r="50"
fill="none"
stroke="currentColor"
strokeWidth="10"
strokeDasharray={`${Math.min(drawdown.current_drawdown_pct / (drawdown.alert_threshold_pct || 20) * 100, 100) * 3.14} 314`}
strokeLinecap="round"
className={getGaugeBgColor(drawdown.current_drawdown_pct, drawdown.alert_threshold_pct)}
/>
</svg>
<div className="absolute inset-0 flex items-center justify-center">
<span className={`text-2xl font-bold ${getGaugeColor(drawdown.current_drawdown_pct, drawdown.alert_threshold_pct)}`}>
{drawdown.current_drawdown_pct.toFixed(1)}%
</span>
</div>
</div>
<div className="text-xs text-muted-foreground mt-2">
: {drawdown.alert_threshold_pct}%
</div>
</CardContent>
</Card>
{/* Summary Cards */}
<Card>
<CardContent className="pt-6">
<div className="text-sm text-muted-foreground"> (MDD)</div>
<div className="text-2xl font-bold text-red-600">
-{drawdown.max_drawdown_pct.toFixed(2)}%
</div>
{drawdown.max_drawdown_date && (
<div className="text-xs text-muted-foreground mt-1">
{formatDate(drawdown.max_drawdown_date)}
</div>
)}
</CardContent>
</Card>
<Card>
<CardContent className="pt-6">
<div className="text-sm text-muted-foreground"> (Peak)</div>
<div className="text-xl font-bold text-foreground">
{drawdown.peak_value !== null ? formatCurrency(drawdown.peak_value) : '-'}
</div>
{drawdown.peak_date && (
<div className="text-xs text-muted-foreground mt-1">
{formatDate(drawdown.peak_date)}
</div>
)}
</CardContent>
</Card>
<Card>
<CardContent className="pt-6">
<div className="text-sm text-muted-foreground"> (Trough)</div>
<div className="text-xl font-bold text-foreground">
{drawdown.trough_value !== null ? formatCurrency(drawdown.trough_value) : '-'}
</div>
{drawdown.trough_date && (
<div className="text-xs text-muted-foreground mt-1">
{formatDate(drawdown.trough_date)}
</div>
)}
</CardContent>
</Card>
</div>
{/* Drawdown Chart */}
{history && history.data.length > 0 && (
<Card className="mb-6">
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
<div className="h-80">
<ResponsiveContainer width="100%" height="100%">
<AreaChart data={history.data.map(d => ({
...d,
drawdown_pct: -d.drawdown_pct,
date: formatDate(d.date),
}))}>
<CartesianGrid strokeDasharray="3 3" />
<XAxis
dataKey="date"
tick={{ fontSize: 12 }}
interval="preserveStartEnd"
/>
<YAxis
tick={{ fontSize: 12 }}
tickFormatter={(v) => `${v}%`}
domain={['dataMin', 0]}
/>
<Tooltip
formatter={(value: number) => [`${value.toFixed(2)}%`, '낙폭']}
labelFormatter={(label) => `날짜: ${label}`}
/>
<ReferenceLine
y={-drawdown.alert_threshold_pct}
stroke="#ef4444"
strokeDasharray="5 5"
label={{ value: `한도 -${drawdown.alert_threshold_pct}%`, fill: '#ef4444', fontSize: 12 }}
/>
<Area
type="monotone"
dataKey="drawdown_pct"
stroke="#ef4444"
fill="#fecaca"
fillOpacity={0.5}
/>
</AreaChart>
</ResponsiveContainer>
</div>
</CardContent>
</Card>
)}
{/* Settings */}
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
<div className="flex items-end gap-4">
<div>
<label className="block text-sm text-muted-foreground mb-1">
(%)
</label>
<input
type="number"
min="1"
max="100"
step="0.1"
value={threshold}
onChange={(e) => setThreshold(e.target.value)}
className="w-32 px-3 py-2 border border-border rounded bg-background text-foreground"
/>
</div>
<Button onClick={handleSaveThreshold} disabled={saving}>
{saving ? '저장 중...' : '저장'}
</Button>
</div>
<p className="text-xs text-muted-foreground mt-2">
.
</p>
</CardContent>
</Card>
</>
)}
{!drawdown && !error && (
<Card>
<CardContent className="py-12 text-center text-muted-foreground">
. .
</CardContent>
</Card>
)}
</DashboardLayout>
);
}

View File

@ -0,0 +1,53 @@
'use client';
import Link from 'next/link';
import { usePathname } from 'next/navigation';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { cn } from '@/lib/utils';
import { Bell } from 'lucide-react';
const settingsNav = [
{ href: '/settings/notifications', label: '알림', icon: Bell },
];
export default function SettingsLayout({ children }: { children: React.ReactNode }) {
const pathname = usePathname();
return (
<DashboardLayout>
<div className="mb-6">
<h1 className="text-2xl font-bold text-foreground"></h1>
<p className="mt-1 text-muted-foreground"> </p>
</div>
<div className="flex gap-6">
<nav className="hidden md:block w-48 shrink-0">
<ul className="space-y-1">
{settingsNav.map((item) => {
const Icon = item.icon;
const isActive = pathname === item.href;
return (
<li key={item.href}>
<Link
href={item.href}
className={cn(
'flex items-center gap-2 rounded-lg px-3 py-2 text-sm transition-colors',
isActive
? 'bg-primary text-primary-foreground'
: 'text-muted-foreground hover:bg-accent hover:text-accent-foreground'
)}
>
<Icon className="h-4 w-4" />
{item.label}
</Link>
</li>
);
})}
</ul>
</nav>
<div className="flex-1 min-w-0">{children}</div>
</div>
</DashboardLayout>
);
}

View File

@ -0,0 +1,227 @@
'use client';
import { useEffect, useState } from 'react';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { Input } from '@/components/ui/input';
import { Label } from '@/components/ui/label';
import { Button } from '@/components/ui/button';
import { Switch } from '@/components/ui/switch';
import { Badge } from '@/components/ui/badge';
import { Skeleton } from '@/components/ui/skeleton';
import { api } from '@/lib/api';
import { toast } from 'sonner';
import { Save, MessageSquare, Send } from 'lucide-react';
interface NotificationSettings {
discord_webhook_url: string;
discord_enabled: boolean;
telegram_webhook_url: string;
telegram_enabled: boolean;
}
interface NotificationHistory {
id: number;
channel: string;
title: string;
message: string;
status: string;
error_message: string | null;
created_at: string;
}
export default function NotificationsSettingsPage() {
const [loading, setLoading] = useState(true);
const [saving, setSaving] = useState(false);
const [settings, setSettings] = useState<NotificationSettings>({
discord_webhook_url: '',
discord_enabled: false,
telegram_webhook_url: '',
telegram_enabled: false,
});
const [history, setHistory] = useState<NotificationHistory[]>([]);
useEffect(() => {
const fetchData = async () => {
try {
const [settingsData, historyData] = await Promise.all([
api.get<NotificationSettings>('/api/notifications/settings').catch(() => null),
api.get<NotificationHistory[]>('/api/notifications/history?limit=20').catch(() => []),
]);
if (settingsData) setSettings(settingsData);
if (historyData) setHistory(historyData);
} finally {
setLoading(false);
}
};
fetchData();
}, []);
const handleSave = async () => {
setSaving(true);
try {
await api.put('/api/notifications/settings', settings);
toast.success('알림 설정이 저장되었습니다.');
} catch (err) {
toast.error(err instanceof Error ? err.message : '설정 저장에 실패했습니다.');
} finally {
setSaving(false);
}
};
const formatDate = (dateStr: string) => {
const date = new Date(dateStr);
return date.toLocaleString('ko-KR', {
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
});
};
if (loading) {
return (
<div className="space-y-6">
<Skeleton className="h-[200px]" />
<Skeleton className="h-[200px]" />
<Skeleton className="h-[300px]" />
</div>
);
}
return (
<div className="space-y-6">
{/* Discord 설정 */}
<Card>
<CardHeader>
<div className="flex items-center justify-between">
<div className="flex items-center gap-2">
<MessageSquare className="h-5 w-5 text-[#5865F2]" />
<CardTitle>Discord</CardTitle>
</div>
<Switch
checked={settings.discord_enabled}
onCheckedChange={(checked) =>
setSettings((prev) => ({ ...prev, discord_enabled: checked }))
}
/>
</div>
<CardDescription>Discord </CardDescription>
</CardHeader>
<CardContent>
<div className="space-y-2">
<Label htmlFor="discord-webhook"> URL</Label>
<Input
id="discord-webhook"
type="url"
placeholder="https://discord.com/api/webhooks/..."
value={settings.discord_webhook_url}
onChange={(e) =>
setSettings((prev) => ({ ...prev, discord_webhook_url: e.target.value }))
}
disabled={!settings.discord_enabled}
/>
</div>
</CardContent>
</Card>
{/* Telegram 설정 */}
<Card>
<CardHeader>
<div className="flex items-center justify-between">
<div className="flex items-center gap-2">
<Send className="h-5 w-5 text-[#0088cc]" />
<CardTitle>Telegram</CardTitle>
</div>
<Switch
checked={settings.telegram_enabled}
onCheckedChange={(checked) =>
setSettings((prev) => ({ ...prev, telegram_enabled: checked }))
}
/>
</div>
<CardDescription>Telegram </CardDescription>
</CardHeader>
<CardContent>
<div className="space-y-2">
<Label htmlFor="telegram-webhook"> URL</Label>
<Input
id="telegram-webhook"
type="url"
placeholder="https://api.telegram.org/bot..."
value={settings.telegram_webhook_url}
onChange={(e) =>
setSettings((prev) => ({ ...prev, telegram_webhook_url: e.target.value }))
}
disabled={!settings.telegram_enabled}
/>
</div>
</CardContent>
</Card>
{/* 저장 버튼 */}
<div className="flex justify-end">
<Button onClick={handleSave} disabled={saving}>
<Save className="h-4 w-4 mr-2" />
{saving ? '저장 중...' : '설정 저장'}
</Button>
</div>
{/* 알림 이력 */}
<Card>
<CardHeader>
<CardTitle> </CardTitle>
<CardDescription> 20 </CardDescription>
</CardHeader>
<CardContent className="p-0">
<div className="overflow-x-auto">
<table className="w-full">
<thead className="bg-muted">
<tr>
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-center text-sm font-medium text-muted-foreground"></th>
</tr>
</thead>
<tbody className="divide-y divide-border">
{history.map((item) => (
<tr key={item.id} className="hover:bg-muted/50">
<td className="px-4 py-3 text-sm whitespace-nowrap">{formatDate(item.created_at)}</td>
<td className="px-4 py-3 text-sm">
<Badge variant="outline">
{item.channel === 'discord' ? 'Discord' : 'Telegram'}
</Badge>
</td>
<td className="px-4 py-3 text-sm font-medium">{item.title}</td>
<td className="px-4 py-3 text-sm max-w-xs truncate text-muted-foreground" title={item.message}>
{item.message}
</td>
<td className="px-4 py-3 text-center">
<Badge
className={
item.status === 'sent'
? 'bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200'
: 'bg-red-100 text-red-800 dark:bg-red-900 dark:text-red-200'
}
>
{item.status === 'sent' ? '성공' : '실패'}
</Badge>
</td>
</tr>
))}
{history.length === 0 && (
<tr>
<td colSpan={5} className="px-4 py-8 text-center text-muted-foreground">
.
</td>
</tr>
)}
</tbody>
</table>
</div>
</CardContent>
</Card>
</div>
);
}

View File

@ -0,0 +1,332 @@
"use client";
import { useState, useEffect } from "react";
import { api } from "@/lib/api";
interface ParamGrid {
[key: string]: number[];
}
interface Presets {
[strategyType: string]: ParamGrid;
}
interface ResultItem {
rank: number;
params: Record<string, number>;
total_return: number;
cagr: number;
mdd: number;
sharpe_ratio: number;
volatility: number;
benchmark_return: number;
excess_return: number;
}
interface OptimizeResponse {
strategy_type: string;
total_combinations: number;
results: ResultItem[];
best_params: Record<string, number>;
}
const STRATEGY_LABELS: Record<string, string> = {
kjb: "KJB (김종봉)",
multi_factor: "멀티팩터",
quality: "퀄리티",
value_momentum: "가치+모멘텀",
};
const RANK_BY_OPTIONS = [
{ value: "sharpe_ratio", label: "샤프 비율" },
{ value: "cagr", label: "CAGR" },
{ value: "total_return", label: "총 수익률" },
{ value: "mdd", label: "MDD" },
];
export default function OptimizerPage() {
const [strategyType, setStrategyType] = useState("kjb");
const [startDate, setStartDate] = useState("2024-01-01");
const [endDate, setEndDate] = useState("2024-12-31");
const [rankBy, setRankBy] = useState("sharpe_ratio");
const [presets, setPresets] = useState<Presets>({});
const [paramGrid, setParamGrid] = useState<ParamGrid>({});
const [useCustomGrid, setUseCustomGrid] = useState(false);
const [results, setResults] = useState<OptimizeResponse | null>(null);
const [loading, setLoading] = useState(false);
const [error, setError] = useState<string | null>(null);
useEffect(() => {
api
.get<{ presets: Presets }>("/api/optimizer/presets")
.then((data) => {
setPresets(data.presets);
if (data.presets[strategyType]) {
setParamGrid(data.presets[strategyType]);
}
})
.catch(() => {});
}, []);
useEffect(() => {
if (!useCustomGrid && presets[strategyType]) {
setParamGrid(presets[strategyType]);
}
}, [strategyType, useCustomGrid, presets]);
const totalCombinations = Object.values(paramGrid).reduce(
(acc, vals) => acc * vals.length,
1
);
const handleGridChange = (key: string, value: string) => {
const nums = value
.split(",")
.map((s) => parseFloat(s.trim()))
.filter((n) => !isNaN(n));
setParamGrid((prev) => ({ ...prev, [key]: nums }));
};
const handleSubmit = async () => {
setLoading(true);
setError(null);
setResults(null);
try {
const data = await api.post<OptimizeResponse>("/api/optimizer", {
strategy_type: strategyType,
param_grid: paramGrid,
start_date: startDate,
end_date: endDate,
rank_by: rankBy,
});
setResults(data);
} catch (err: unknown) {
const message = err instanceof Error ? err.message : "최적화 실행 중 오류 발생";
setError(message);
} finally {
setLoading(false);
}
};
return (
<div className="space-y-6">
<div>
<h1 className="text-2xl font-bold"> </h1>
<p className="text-muted-foreground mt-1">
.
</p>
</div>
{/* Settings */}
<div className="grid grid-cols-1 md:grid-cols-2 gap-4 rounded-lg border p-4">
<div>
<label className="block text-sm font-medium mb-1"> </label>
<select
className="w-full rounded-md border px-3 py-2 text-sm"
value={strategyType}
onChange={(e) => setStrategyType(e.target.value)}
>
{Object.entries(STRATEGY_LABELS).map(([key, label]) => (
<option key={key} value={key}>
{label}
</option>
))}
</select>
</div>
<div>
<label className="block text-sm font-medium mb-1"> </label>
<select
className="w-full rounded-md border px-3 py-2 text-sm"
value={rankBy}
onChange={(e) => setRankBy(e.target.value)}
>
{RANK_BY_OPTIONS.map((opt) => (
<option key={opt.value} value={opt.value}>
{opt.label}
</option>
))}
</select>
</div>
<div>
<label className="block text-sm font-medium mb-1"></label>
<input
type="date"
className="w-full rounded-md border px-3 py-2 text-sm"
value={startDate}
onChange={(e) => setStartDate(e.target.value)}
/>
</div>
<div>
<label className="block text-sm font-medium mb-1"></label>
<input
type="date"
className="w-full rounded-md border px-3 py-2 text-sm"
value={endDate}
onChange={(e) => setEndDate(e.target.value)}
/>
</div>
</div>
{/* Parameter Grid */}
<div className="rounded-lg border p-4">
<div className="flex items-center justify-between mb-3">
<h2 className="text-lg font-semibold"> </h2>
<label className="flex items-center gap-2 text-sm">
<input
type="checkbox"
checked={useCustomGrid}
onChange={(e) => setUseCustomGrid(e.target.checked)}
/>
</label>
</div>
<div className="space-y-3">
{Object.entries(paramGrid).map(([key, values]) => (
<div key={key} className="flex items-center gap-3">
<label className="w-40 text-sm font-mono">{key}</label>
<input
type="text"
className="flex-1 rounded-md border px-3 py-2 text-sm font-mono"
value={values.join(", ")}
onChange={(e) => handleGridChange(key, e.target.value)}
disabled={!useCustomGrid}
/>
</div>
))}
</div>
<p className="mt-3 text-sm text-muted-foreground">
<span className="font-semibold">{totalCombinations}</span>
</p>
</div>
{/* Run button */}
<button
className="rounded-md bg-primary px-6 py-2 text-sm font-medium text-primary-foreground hover:bg-primary/90 disabled:opacity-50"
onClick={handleSubmit}
disabled={loading || totalCombinations === 0}
>
{loading ? "최적화 실행 중..." : "최적화 실행"}
</button>
{/* Error */}
{error && (
<div className="rounded-md bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
{/* Results */}
{results && (
<div className="space-y-4">
<div className="flex items-center justify-between">
<h2 className="text-lg font-semibold"> </h2>
<span className="text-sm text-muted-foreground">
{results.results.length} / {results.total_combinations}
</span>
</div>
{/* Best params */}
{results.best_params &&
Object.keys(results.best_params).length > 0 && (
<div className="rounded-md bg-green-50 dark:bg-green-950 border border-green-200 dark:border-green-800 p-4">
<h3 className="text-sm font-semibold text-green-800 dark:text-green-200 mb-2">
</h3>
<div className="flex flex-wrap gap-3">
{Object.entries(results.best_params).map(([key, value]) => (
<span
key={key}
className="rounded-full bg-green-100 dark:bg-green-900 px-3 py-1 text-xs font-mono"
>
{key}: {value}
</span>
))}
</div>
</div>
)}
{/* Results table */}
<div className="overflow-x-auto rounded-lg border">
<table className="w-full text-sm">
<thead>
<tr className="border-b bg-muted/50">
<th className="px-3 py-2 text-left">#</th>
<th className="px-3 py-2 text-left"></th>
<th className="px-3 py-2 text-right"> </th>
<th className="px-3 py-2 text-right">CAGR</th>
<th className="px-3 py-2 text-right">MDD</th>
<th className="px-3 py-2 text-right"></th>
<th className="px-3 py-2 text-right"></th>
<th className="px-3 py-2 text-right"></th>
</tr>
</thead>
<tbody>
{results.results.map((item) => (
<tr
key={item.rank}
className={`border-b hover:bg-muted/30 ${
item.rank === 1 ? "bg-green-50 dark:bg-green-950/30" : ""
}`}
>
<td className="px-3 py-2 font-medium">{item.rank}</td>
<td className="px-3 py-2 font-mono text-xs">
{Object.entries(item.params)
.map(([k, v]) => `${k}=${v}`)
.join(", ")}
</td>
<td className="px-3 py-2 text-right">
<span
className={
item.total_return >= 0
? "text-green-600"
: "text-red-600"
}
>
{item.total_return.toFixed(2)}%
</span>
</td>
<td className="px-3 py-2 text-right">
<span
className={
item.cagr >= 0 ? "text-green-600" : "text-red-600"
}
>
{item.cagr.toFixed(2)}%
</span>
</td>
<td className="px-3 py-2 text-right text-red-600">
{item.mdd.toFixed(2)}%
</td>
<td className="px-3 py-2 text-right font-medium">
{item.sharpe_ratio.toFixed(2)}
</td>
<td className="px-3 py-2 text-right">
{item.volatility.toFixed(2)}%
</td>
<td className="px-3 py-2 text-right">
<span
className={
item.excess_return >= 0
? "text-green-600"
: "text-red-600"
}
>
{item.excess_return.toFixed(2)}%
</span>
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
)}
</div>
);
}

View File

@ -0,0 +1,366 @@
'use client';
import { useState } from 'react';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Button } from '@/components/ui/button';
import { Input } from '@/components/ui/input';
import { Label } from '@/components/ui/label';
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs';
import { api } from '@/lib/api';
import { toast } from 'sonner';
import { Calculator, HelpCircle, DollarSign, BarChart3, Activity } from 'lucide-react';
interface PositionSizeResult {
method: string;
position_size: number;
shares: number;
risk_amount: number;
notes: string;
}
const formatKRW = (value: number) =>
new Intl.NumberFormat('ko-KR').format(Math.round(value));
const methodHelp = {
fixed: {
title: 'Fixed Ratio (균등 배분)',
description:
'총 자본을 종목 수로 균등 분배하는 방식입니다. 현금 비중을 설정하여 일정 비율은 현금으로 보유합니다. 가장 단순하고 직관적인 방법으로, 초보 투자자에게 적합합니다.',
icon: BarChart3,
},
kelly: {
title: 'Kelly Criterion (켈리 공식)',
description:
'승률과 평균 손익비를 기반으로 수학적 최적 베팅 비율을 계산합니다. Kelly% = W - (1-W)/R (W=승률, R=평균이익/평균손실). 장기적으로 자본 성장률을 극대화하지만, 실전에서는 Half Kelly(절반)를 사용하는 것이 일반적입니다.',
icon: Calculator,
},
atr: {
title: 'ATR Based (변동성 기반)',
description:
'ATR(Average True Range)을 사용하여 변동성에 맞는 포지션 크기를 결정합니다. 위험 허용 비율 대비 ATR로 1주당 위험 금액을 산출하여 적정 수량을 계산합니다. 변동성이 큰 종목은 적게, 작은 종목은 많이 매수하게 됩니다.',
icon: Activity,
},
};
export default function PositionSizingPage() {
const [activeTab, setActiveTab] = useState('fixed');
const [loading, setLoading] = useState(false);
const [result, setResult] = useState<PositionSizeResult | null>(null);
// Common
const [capital, setCapital] = useState('');
// Fixed
const [numPositions, setNumPositions] = useState('10');
const [cashRatio, setCashRatio] = useState('0.3');
// Kelly
const [winRate, setWinRate] = useState('');
const [avgWin, setAvgWin] = useState('');
const [avgLoss, setAvgLoss] = useState('');
// ATR
const [atr, setAtr] = useState('');
const [riskPct, setRiskPct] = useState('0.02');
const handleCalculate = async () => {
if (!capital || parseFloat(capital) <= 0) {
toast.error('총 자본을 입력해주세요.');
return;
}
const body: Record<string, unknown> = {
capital: parseFloat(capital),
method: activeTab,
};
if (activeTab === 'fixed') {
body.num_positions = parseInt(numPositions);
body.cash_ratio = parseFloat(cashRatio);
} else if (activeTab === 'kelly') {
if (!winRate || !avgWin || !avgLoss) {
toast.error('모든 Kelly 파라미터를 입력해주세요.');
return;
}
body.win_rate = parseFloat(winRate);
body.avg_win = parseFloat(avgWin);
body.avg_loss = parseFloat(avgLoss);
} else if (activeTab === 'atr') {
if (!atr) {
toast.error('ATR 값을 입력해주세요.');
return;
}
body.atr = parseFloat(atr);
body.risk_pct = parseFloat(riskPct);
}
setLoading(true);
try {
const data = await api.post<PositionSizeResult>(
'/api/position-sizing/calculate',
body
);
setResult(data);
} catch (err) {
toast.error(
err instanceof Error ? err.message : '계산에 실패했습니다.'
);
} finally {
setLoading(false);
}
};
const HelpCard = ({ method }: { method: 'fixed' | 'kelly' | 'atr' }) => {
const info = methodHelp[method];
const Icon = info.icon;
return (
<div className="rounded-md border border-blue-200 bg-blue-50 dark:border-blue-800 dark:bg-blue-950 p-4 space-y-2">
<div className="flex items-center gap-2 text-blue-800 dark:text-blue-200">
<HelpCircle className="h-4 w-4" />
<span className="font-medium text-sm">{info.title}</span>
</div>
<div className="flex items-start gap-3">
<Icon className="h-5 w-5 text-blue-600 dark:text-blue-400 mt-0.5 shrink-0" />
<p className="text-sm text-blue-700 dark:text-blue-300">
{info.description}
</p>
</div>
</div>
);
};
return (
<DashboardLayout>
<div className="mb-6">
<h1 className="text-2xl font-bold text-foreground"> </h1>
<p className="mt-1 text-muted-foreground">
</p>
</div>
<div className="grid grid-cols-1 lg:grid-cols-3 gap-6">
{/* Input Section */}
<div className="lg:col-span-2 space-y-6">
{/* Common Input */}
<Card>
<CardContent className="p-4">
<div className="space-y-2">
<Label htmlFor="capital"> ()</Label>
<Input
id="capital"
type="number"
min="0"
value={capital}
onChange={(e) => setCapital(e.target.value)}
placeholder="예: 50000000"
/>
</div>
</CardContent>
</Card>
{/* Method Tabs */}
<Tabs value={activeTab} onValueChange={(v) => { setActiveTab(v); setResult(null); }}>
<TabsList className="grid w-full grid-cols-3">
<TabsTrigger value="fixed">Fixed Ratio</TabsTrigger>
<TabsTrigger value="kelly">Kelly Criterion</TabsTrigger>
<TabsTrigger value="atr">ATR Based</TabsTrigger>
</TabsList>
<TabsContent value="fixed">
<Card>
<CardHeader>
<CardTitle className="text-lg">Fixed Ratio </CardTitle>
</CardHeader>
<CardContent className="space-y-4">
<HelpCard method="fixed" />
<div className="grid grid-cols-1 sm:grid-cols-2 gap-4">
<div className="space-y-2">
<Label htmlFor="numPositions"> </Label>
<Input
id="numPositions"
type="number"
min="1"
max="50"
value={numPositions}
onChange={(e) => setNumPositions(e.target.value)}
/>
</div>
<div className="space-y-2">
<Label htmlFor="cashRatio"> (0~1)</Label>
<Input
id="cashRatio"
type="number"
min="0"
max="0.99"
step="0.05"
value={cashRatio}
onChange={(e) => setCashRatio(e.target.value)}
/>
</div>
</div>
<Button onClick={handleCalculate} disabled={loading} className="w-full">
<Calculator className="mr-2 h-4 w-4" />
{loading ? '계산 중...' : '계산하기'}
</Button>
</CardContent>
</Card>
</TabsContent>
<TabsContent value="kelly">
<Card>
<CardHeader>
<CardTitle className="text-lg">Kelly Criterion </CardTitle>
</CardHeader>
<CardContent className="space-y-4">
<HelpCard method="kelly" />
<div className="grid grid-cols-1 sm:grid-cols-3 gap-4">
<div className="space-y-2">
<Label htmlFor="winRate"> (0~1)</Label>
<Input
id="winRate"
type="number"
min="0.01"
max="1"
step="0.01"
value={winRate}
onChange={(e) => setWinRate(e.target.value)}
placeholder="예: 0.55"
/>
</div>
<div className="space-y-2">
<Label htmlFor="avgWin"> (%)</Label>
<Input
id="avgWin"
type="number"
min="0.01"
step="0.01"
value={avgWin}
onChange={(e) => setAvgWin(e.target.value)}
placeholder="예: 15"
/>
</div>
<div className="space-y-2">
<Label htmlFor="avgLoss"> (%)</Label>
<Input
id="avgLoss"
type="number"
min="0.01"
step="0.01"
value={avgLoss}
onChange={(e) => setAvgLoss(e.target.value)}
placeholder="예: 7"
/>
</div>
</div>
<Button onClick={handleCalculate} disabled={loading} className="w-full">
<Calculator className="mr-2 h-4 w-4" />
{loading ? '계산 중...' : '계산하기'}
</Button>
</CardContent>
</Card>
</TabsContent>
<TabsContent value="atr">
<Card>
<CardHeader>
<CardTitle className="text-lg">ATR Based </CardTitle>
</CardHeader>
<CardContent className="space-y-4">
<HelpCard method="atr" />
<div className="grid grid-cols-1 sm:grid-cols-2 gap-4">
<div className="space-y-2">
<Label htmlFor="atr">ATR ()</Label>
<Input
id="atr"
type="number"
min="0.01"
step="1"
value={atr}
onChange={(e) => setAtr(e.target.value)}
placeholder="예: 1500"
/>
</div>
<div className="space-y-2">
<Label htmlFor="riskPct"> (0~1)</Label>
<Input
id="riskPct"
type="number"
min="0.001"
max="1"
step="0.005"
value={riskPct}
onChange={(e) => setRiskPct(e.target.value)}
/>
</div>
</div>
<Button onClick={handleCalculate} disabled={loading} className="w-full">
<Calculator className="mr-2 h-4 w-4" />
{loading ? '계산 중...' : '계산하기'}
</Button>
</CardContent>
</Card>
</TabsContent>
</Tabs>
</div>
{/* Result Section */}
<div>
<Card>
<CardHeader>
<CardTitle className="text-lg flex items-center gap-2">
<DollarSign className="h-5 w-5" />
</CardTitle>
</CardHeader>
<CardContent>
{result ? (
<div className="space-y-4">
<div className="rounded-md bg-green-50 dark:bg-green-950 border border-green-200 dark:border-green-800 p-4 text-center">
<p className="text-xs text-green-600 dark:text-green-400 mb-1">
</p>
<p className="text-2xl font-bold text-green-800 dark:text-green-200">
{formatKRW(result.position_size)}
</p>
</div>
<div className="grid grid-cols-2 gap-3">
<div className="rounded-md border p-3 text-center">
<p className="text-xs text-muted-foreground mb-1"> </p>
<p className="text-xl font-bold">{result.shares.toLocaleString()}</p>
</div>
<div className="rounded-md border p-3 text-center">
<p className="text-xs text-muted-foreground mb-1"> </p>
<p className="text-xl font-bold text-red-600 dark:text-red-400">
{formatKRW(result.risk_amount)}
</p>
</div>
</div>
<div className="rounded-md border p-3">
<p className="text-xs text-muted-foreground mb-1"></p>
<p className="text-sm font-medium">
{methodHelp[result.method as keyof typeof methodHelp]?.title || result.method}
</p>
</div>
{result.notes && (
<div className="rounded-md bg-muted p-3">
<p className="text-xs text-muted-foreground mb-1"></p>
<p className="text-sm">{result.notes}</p>
</div>
)}
</div>
) : (
<div className="text-center py-8 text-muted-foreground">
<Calculator className="h-10 w-10 mx-auto mb-3 opacity-30" />
<p className="text-sm">
</p>
</div>
)}
</CardContent>
</Card>
</div>
</div>
</DashboardLayout>
);
}

View File

@ -11,6 +11,10 @@ import {
Database,
Search,
Radio,
BookOpen,
PiggyBank,
Ruler,
Settings,
ChevronLeft,
ChevronRight,
} from 'lucide-react';
@ -30,8 +34,12 @@ const navItems = [
{ href: '/strategy', label: '전략', icon: TrendingUp },
{ href: '/backtest', label: '백테스트', icon: FlaskConical },
{ href: '/signals', label: '매매 신호', icon: Radio },
{ href: '/journal', label: '트레이딩 저널', icon: BookOpen },
{ href: '/pension', label: '퇴직연금', icon: PiggyBank },
{ href: '/tools/position-sizing', label: '포지션 사이징', icon: Ruler },
{ href: '/admin/data', label: '데이터 수집', icon: Database },
{ href: '/admin/data/explorer', label: '데이터 탐색', icon: Search },
{ href: '/settings/notifications', label: '설정', icon: Settings },
];
interface NewSidebarProps {

View File

@ -0,0 +1,29 @@
"use client"
import * as React from "react"
import * as SwitchPrimitives from "@radix-ui/react-switch"
import { cn } from "@/lib/utils"
const Switch = React.forwardRef<
React.ElementRef<typeof SwitchPrimitives.Root>,
React.ComponentPropsWithoutRef<typeof SwitchPrimitives.Root>
>(({ className, ...props }, ref) => (
<SwitchPrimitives.Root
className={cn(
"peer inline-flex h-6 w-11 shrink-0 cursor-pointer items-center rounded-full border-2 border-transparent transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 focus-visible:ring-offset-background disabled:cursor-not-allowed disabled:opacity-50 data-[state=checked]:bg-primary data-[state=unchecked]:bg-input",
className
)}
{...props}
ref={ref}
>
<SwitchPrimitives.Thumb
className={cn(
"pointer-events-none block h-5 w-5 rounded-full bg-background shadow-lg ring-0 transition-transform data-[state=checked]:translate-x-5 data-[state=unchecked]:translate-x-0"
)}
/>
</SwitchPrimitives.Root>
))
Switch.displayName = SwitchPrimitives.Root.displayName
export { Switch }