Compare commits

...

21 Commits

Author SHA1 Message Date
머니페니
2b1e7cefbe docs: update CLAUDE.md with galaxis-agent reference and strengthened rules
Some checks are pending
Deploy to Production / deploy (push) Waiting to run
2026-03-20 19:03:35 +09:00
머니페니
149560c083 docs: add Phase 1 implementation plan for galaxis-agent
17-task plan covering: open-swe fork setup, code cleanup,
git_utils extraction from github.py, config module, Docker
sandbox/compose setup, ARM64 compatibility validation.
Two review iterations applied.
2026-03-20 14:28:14 +09:00
머니페니
43ff569aa3 docs: add galaxis-agent autonomous SWE agent design spec
Design for an autonomous development agent that forks open-swe
(LangGraph + Deep Agents) with Gitea webhook + Discord bot triggers,
Docker sandbox execution on Oracle VM A1 (ARM64), and Claude API.
Includes phased implementation roadmap from conservative (PR-only)
to autonomous (auto-merge with E2E gates).
2026-03-20 14:06:53 +09:00
머니페니
f6db08c9bd feat: improve security, performance, and add missing features
- Remove hardcoded database_url/jwt_secret defaults, require env vars
- Add DB indexes for stocks.market, market_cap, backtests.user_id
- Optimize backtest engine: preload all prices, move stock_names out of loop
- Fix backtest API auth: filter by user_id at query level (6 endpoints)
- Add manual transaction entry modal on portfolio detail page
- Replace console.error with toast.error in signals, backtest, data explorer
- Add backtest delete button with confirmation dialog
- Replace simulated sine chart with real snapshot data
- Add strategy-to-portfolio apply flow with dialog
- Add DC pension risk asset ratio >70% warning on rebalance page
- Add backtest comparison page with metrics table and overlay chart
2026-03-20 12:27:05 +09:00
머니페니
49bd0d8745 chore: add frontend/test-results to .gitignore
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-19 23:19:41 +09:00
머니페니
2ad2f56d31 docs: add project config docs, analysis report, and e2e signal cancel test
Add CLAUDE.md and AGENTS.md for AI-assisted development guidance,
analysis report with screenshots, and Playwright-based e2e test for
signal cancellation flow.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-19 23:18:16 +09:00
머니페니
f818bd3290 feat: add walk-forward analysis for backtests
- Add WalkForwardResult model with train/test window metrics
- Create WalkForwardEngine that reuses existing BacktestEngine
  with rolling train/test window splits
- Add POST/GET /api/backtest/{id}/walkforward endpoints
- Add Walk-forward tab to backtest detail page with parameter
  controls, cumulative return chart, and window results table
- Add Alembic migration for walkforward_results table
- Add 8 unit tests for window generation logic (100 total passed)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 22:33:41 +09:00
머니페니
741b7fa7dd feat: add skip/limit pagination to prices, snapshots, and transactions APIs
Add paginated responses (items/total/skip/limit) to:
- GET /api/data/stocks/{ticker}/prices (default limit=365)
- GET /api/data/etfs/{ticker}/prices (default limit=365)
- GET /api/portfolios/{id}/snapshots (default limit=100)
- GET /api/portfolios/{id}/transactions (default limit=50)

Frontend: update snapshot/transaction consumers to handle new response
shape, add "Load more" button to transaction table.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 22:32:34 +09:00
머니페니
98a161574e security: migrate JWT from localStorage to httpOnly cookie
Eliminates XSS token theft by storing JWT in httpOnly Secure cookie
instead of localStorage. Backend sets cookie on login and clears on
logout. Token extraction uses cookie-first with Authorization header
fallback for backward compatibility with existing tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 22:30:47 +09:00
머니페니
60d2221edc feat: add manual transaction entry UI with modal dialog
Add "거래 추가" button to the transactions tab with a modal dialog for
manually entering buy/sell transactions (ticker, type, quantity, price, memo).
Refreshes portfolio and transaction list after successful submission.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 22:22:56 +09:00
머니페니
4ea744ce62 feat: replace simulated sine wave chart with real snapshot data
Portfolio value chart now uses actual snapshot API data instead of
generated simulation. Shows empty state message when no snapshots exist.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 22:22:44 +09:00
머니페니
ee0de0504c feat: add portfolio edit/delete UI with confirmation dialogs
Add hover-visible edit (rename) and delete buttons to portfolio cards
on the list page, with modal dialogs for name editing and delete
confirmation. Uses existing PUT/DELETE API endpoints.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 22:22:37 +09:00
머니페니
815f255ff5 feat: add React ErrorBoundary with retry and reload UI
Add class-based ErrorBoundary component that catches rendering errors
and shows a user-friendly fallback with retry/reload buttons. Wrap in
root layout to protect against full-page crashes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 22:22:29 +09:00
머니페니
f12709ea79 fix: add toast error feedback for API failures in signals and data explorer pages
Replace silent console.error-only handling with user-visible toast notifications
using sonner, while keeping console.error for debugging.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 22:22:22 +09:00
머니페니
b80feb7176 perf: add DB performance indexes and fix N+1 query in backtest listing
Add 10 indexes across prices, etf_prices, financials, valuations,
holdings, transactions, signals, portfolio_snapshots, and etfs tables.
Fix N+1 query in list_backtests by eager-loading backtest results
with joinedload.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 22:20:29 +09:00
머니페니
4483f6e4ba feat: add DC pension ETF-only filter to strategy API
Add dc_only parameter to all strategy endpoints. When true, filters
results to include only tickers present in the ETF table, supporting
DC pension investment constraints where only ETFs are allowed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 20:57:47 +09:00
머니페니
62ac92eaaf feat: add minimum trade amount filter to rebalancing calculation
Add min_trade_amount parameter (default 10,000 KRW) to rebalance/calculate
endpoint. Trades below this threshold are converted to hold actions to avoid
inefficient micro-trades during rebalancing.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 20:57:43 +09:00
머니페니
01f86298c4 feat: add strategy comparison UI for side-by-side multi-strategy analysis
Add a compare page at /strategy/compare that runs MultiFactor, Quality,
and ValueMomentum strategies simultaneously and displays results side-by-side
with common ticker highlighting and factor score comparison table.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 20:57:38 +09:00
머니페니
9249821a25 feat: add realized/unrealized PnL tracking and position sizing guide
- Add realized_pnl column to transactions table with alembic migration
- Calculate realized PnL on sell transactions: (sell_price - avg_price) * quantity
- Show total realized/unrealized PnL in portfolio detail summary cards
- Show per-transaction realized PnL in transaction history table
- Add position sizing API endpoint (GET /portfolios/{id}/position-size)
- Show position sizing guide in signal execution modal for buy signals
- 8 new E2E tests, all 88 tests passing

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 19:04:36 +09:00
머니페니
65618cd957 feat: add signal execution cancel with transaction rollback and holding restore 2026-03-18 18:56:29 +09:00
머니페니
213f03a8e5 fix: replace deprecated datetime.utcnow() and SQLAlchemy Query.get() 2026-03-18 18:53:29 +09:00
66 changed files with 6655 additions and 353 deletions

View File

@ -2,9 +2,7 @@
# Copy this file to .env and fill in the values
# Database
DB_USER=galaxy
DB_PASSWORD=your_secure_password_here
DB_NAME=galaxy_po
DATABASE_URL=postgresql://galaxy:your_secure_password_here@localhost:5432/galaxy_po
# JWT Authentication
JWT_SECRET=your_jwt_secret_key_here_at_least_32_characters

1
.gitignore vendored
View File

@ -61,6 +61,7 @@ data/
.coverage
htmlcov/
.pytest_cache/
frontend/test-results/
# Worktrees
.worktrees/

67
AGENTS.md Normal file
View File

@ -0,0 +1,67 @@
# AGENTS.md - galaxis-po 개발 에이전트 가이드
## 프로젝트 개요
퀀트 & 퇴직연금 포트폴리오 관리 앱.
김종봉 전략 기반 백테스팅, 신호 생성, 포트폴리오 관리 기능 제공.
## 기술 스택
- **Backend**: FastAPI, Python 3.12, SQLAlchemy, PostgreSQL, uv
- **Frontend**: Next.js 15, React 19, TypeScript, Tailwind CSS
- **Infra**: Docker, Docker Compose
## 디렉토리 구조
```
galaxis-po/
├── backend/
│ ├── app/ # FastAPI 앱 (main.py, api/, core/, models/)
│ ├── jobs/ # 스케줄러, 데이터 수집 잡
│ ├── alembic/ # DB 마이그레이션
│ └── requirements.txt / pyproject.toml
├── frontend/ # Next.js 앱
├── docs/plans/ # 설계 문서 (구현 전 반드시 확인)
└── quant.md # 김종봉 전략 상세 가이드
```
## 개발 원칙
### 코드 작성 시
1. `docs/plans/` 의 관련 설계 문서를 먼저 확인할 것
2. `quant.md` 에 전략 로직이 정의되어 있음 — 임의 변경 금지
3. 기존 코드 스타일 유지 (Python: snake_case, TS: camelCase)
4. 모든 API 엔드포인트는 `backend/app/api/` 하위에 router로 추가
5. DB 스키마 변경 시 alembic migration 파일 함께 생성
### 금지 사항
- `.env` 파일 수정 금지 (`.env.example` 참고만 가능)
- `docker-compose.prod.yml` 임의 수정 금지
- 테스트 없는 비즈니스 로직 추가 금지
### 작업 완료 조건
- [ ] 기능 구현
- [ ] 관련 테스트 추가 또는 기존 테스트 통과 확인
- [ ] 타입 에러 없음 (Python: mypy / TS: tsc --noEmit)
- [ ] 작업 내용 요약 보고
## 자주 쓰는 명령
```bash
# 백엔드 개발 서버
cd backend && uv run uvicorn app.main:app --reload
# 프론트엔드 개발 서버
cd frontend && npm run dev
# DB 마이그레이션
cd backend && uv run alembic upgrade head
# 테스트 실행
cd backend && uv run pytest
```
## 보고 형식
작업 완료 시:
```
완료: [작업명]
변경 파일: [파일 목록]
주요 내용: [한 줄 요약]
주의사항: [있을 경우만]
```

99
CLAUDE.md Normal file
View File

@ -0,0 +1,99 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
Galaxis-Po is a quant portfolio management application for DC pension (퇴직연금) investing. It implements the Kim Jong-bong (김종봉) strategy for backtesting, signal generation, and portfolio management. The strategy logic is defined in `quant.md` — do not modify it without explicit instruction.
## Tech Stack
- **Backend:** FastAPI, Python 3.12, SQLAlchemy, PostgreSQL, uv (package manager)
- **Frontend:** Next.js 15 (App Router), React 19, TypeScript, Tailwind CSS v4, shadcn/ui (Radix primitives)
- **Infrastructure:** Docker Compose, PostgreSQL 18
## Common Commands
```bash
# Backend dev server (from repo root)
cd backend && uv run uvicorn app.main:app --reload
# Frontend dev server
cd frontend && npm run dev
# Run all backend tests
cd backend && uv run pytest
# Run a single test file
cd backend && uv run pytest tests/unit/test_kjb_signal.py -v
# Run e2e tests
cd backend && uv run pytest tests/e2e/ -v
# DB migration
cd backend && uv run alembic upgrade head
# Create new migration
cd backend && uv run alembic revision --autogenerate -m "description"
# Frontend lint
cd frontend && npm run lint
# Frontend type check
cd frontend && npx tsc --noEmit
# Start all services via Docker
docker-compose up -d
```
## Architecture
### Backend (`backend/`)
- `app/main.py` — FastAPI app with lifespan manager (seeds admin user, starts APScheduler)
- `app/api/` — Route handlers (routers): auth, admin, portfolio, strategy, market, backtest, snapshot, data_explorer, signal
- `app/models/` — SQLAlchemy ORM models: user, stock, portfolio, signal, backtest
- `app/schemas/` — Pydantic request/response schemas
- `app/services/` — Business logic layer:
- `collectors/` — Market data collectors (pykrx for Korean stock data, DART API for financials)
- `strategy/` — Kim Jong-bong strategy implementation (signal generation, factor calculation)
- `backtest/` — Backtesting engine
- `rebalance.py` — Portfolio rebalancing logic
- `price_service.py`, `factor_calculator.py`, `returns_calculator.py` — Quant utilities
- `app/core/` — Config (pydantic-settings from `.env`), database (SQLAlchemy), security (JWT/bcrypt)
- `jobs/` — APScheduler background jobs: data collection, signal generation, portfolio snapshots
- `alembic/` — Database migrations
- `tests/``unit/` and `e2e/` test directories
### Frontend (`frontend/`)
- Next.js App Router at `src/app/` with pages: portfolio, strategy, signals, backtest, admin, login
- `src/components/` — UI components organized by domain (portfolio, strategy, charts, layout, ui)
- `src/lib/api.ts` — Backend API client
- Uses lightweight-charts for financial charts, recharts for other visualizations
### galaxis-agent (`~/workspace/quant/galaxis-agent/`)
galaxis-po를 자율적으로 개발하는 SWE 에이전트 (별도 Gitea 리포: `quant/galaxis-agent`).
- `agent/` — 핵심 모듈: dispatcher, task_queue, cost_guard, task_history, recovery, auto_merge
- `agent/integrations/` — Discord bot, sandbox backends
- `agent/tools/` — 에이전트 도구 (gitea_comment, discord_reply)
- `agent/utils/` — 유틸리티 (gitea_client, discord_client, git_utils)
- `tests/` — 테스트 (139개, Phase 1-4)
- 설계 스펙: `docs/superpowers/specs/`, 구현 플랜: `docs/superpowers/plans/`
## Development Rules
1. Check `docs/plans/` and `docs/superpowers/plans/` for relevant design documents before implementing features
2. All API endpoints go under `backend/app/api/` as routers
3. DB schema changes require an alembic migration — autogenerate 후 반드시 리뷰하고 즉시 `alembic upgrade head`
4. Do not modify `.env` or `docker-compose.prod.yml` (`.env` 설정 안내는 허용, 자동 수정은 금지)
5. Python: snake_case; TypeScript: camelCase
6. External APIs: pykrx (한국 거래소 데이터, 백테스트/시그널 주력), KIS (실시간 매매), DART (재무제표)
7. 커밋은 논리 단위별로 개별 생성. 커밋 전 관련 테스트 실행 필수
8. Frontend 변경 후 `cd frontend && npx tsc --noEmit` 필수
## Environment
Backend config is loaded via pydantic-settings from environment variables / `.env` file. Key variables: `DATABASE_URL`, `JWT_SECRET`, `KIS_APP_KEY`, `KIS_APP_SECRET`, `KIS_ACCOUNT_NO`, `DART_API_KEY`. See `.env.example` for reference.

View File

@ -0,0 +1,45 @@
"""add walkforward_results table
Revision ID: 59807c4e84ee
Revises: b7c8d9e0f1a2
Create Date: 2026-03-18 22:28:53.955519
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '59807c4e84ee'
down_revision: Union[str, None] = 'b7c8d9e0f1a2'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('walkforward_results',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('backtest_id', sa.Integer(), nullable=False),
sa.Column('window_index', sa.Integer(), nullable=False),
sa.Column('train_start', sa.Date(), nullable=False),
sa.Column('train_end', sa.Date(), nullable=False),
sa.Column('test_start', sa.Date(), nullable=False),
sa.Column('test_end', sa.Date(), nullable=False),
sa.Column('test_return', sa.Numeric(precision=10, scale=4), nullable=True),
sa.Column('test_sharpe', sa.Numeric(precision=10, scale=4), nullable=True),
sa.Column('test_mdd', sa.Numeric(precision=10, scale=4), nullable=True),
sa.ForeignKeyConstraint(['backtest_id'], ['backtests.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_walkforward_results_id'), 'walkforward_results', ['id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_walkforward_results_id'), table_name='walkforward_results')
op.drop_table('walkforward_results')
# ### end Alembic commands ###

View File

@ -0,0 +1,30 @@
"""add realized_pnl to transactions
Revision ID: 606a5011f84f
Revises: a1b2c3d4e5f6
Create Date: 2026-03-18 19:00:02.245720
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '606a5011f84f'
down_revision: Union[str, None] = 'a1b2c3d4e5f6'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('transactions', sa.Column('realized_pnl', sa.Numeric(precision=15, scale=2), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('transactions', 'realized_pnl')
# ### end Alembic commands ###

View File

@ -0,0 +1,42 @@
"""add missing performance indexes
Revision ID: c3d4e5f6a7b8
Revises: b7c8d9e0f1a2
Create Date: 2026-03-19 10:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "c3d4e5f6a7b8"
down_revision: Union[str, None] = "59807c4e84ee"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Stock universe filtering (strategy engine uses market + market_cap frequently)
op.create_index("idx_stocks_market", "stocks", ["market"])
op.create_index(
"idx_stocks_market_cap", "stocks", [sa.text("market_cap DESC NULLS LAST")]
)
# Backtest listing by user (always filtered by user_id + ordered by created_at)
op.create_index(
"idx_backtests_user_created",
"backtests",
["user_id", sa.text("created_at DESC")],
)
op.create_index("idx_backtests_status", "backtests", ["status"])
def downgrade() -> None:
op.drop_index("idx_backtests_status", table_name="backtests")
op.drop_index("idx_backtests_user_created", table_name="backtests")
op.drop_index("idx_stocks_market_cap", table_name="stocks")
op.drop_index("idx_stocks_market", table_name="stocks")

View File

@ -0,0 +1,49 @@
"""add performance indexes
Revision ID: b7c8d9e0f1a2
Revises: 606a5011f84f
Create Date: 2026-03-18 22:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'b7c8d9e0f1a2'
down_revision: Union[str, None] = '606a5011f84f'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Tier 1: backtest/strategy performance
op.create_index('idx_prices_ticker_date', 'prices', ['ticker', sa.text('date DESC')])
op.create_index('idx_etf_prices_ticker_date', 'etf_prices', ['ticker', sa.text('date DESC')])
op.create_index('idx_financials_ticker_base_date', 'financials', ['ticker', sa.text('base_date DESC')])
op.create_index('idx_valuations_ticker_base_date', 'valuations', ['ticker', sa.text('base_date DESC')])
# Tier 2: portfolio queries
op.create_index('idx_holdings_portfolio_id', 'holdings', ['portfolio_id'])
op.create_index('idx_transactions_portfolio_id_executed_at', 'transactions', ['portfolio_id', sa.text('executed_at DESC')])
op.create_index('idx_signals_date_status', 'signals', [sa.text('date DESC'), 'status'])
op.create_index('idx_snapshots_portfolio_date', 'portfolio_snapshots', ['portfolio_id', sa.text('snapshot_date DESC')])
# Tier 3: ETF filters
op.create_index('idx_etf_asset_class', 'etfs', ['asset_class'])
op.create_index('idx_etf_price_date', 'etf_prices', [sa.text('date DESC')])
def downgrade() -> None:
op.drop_index('idx_etf_price_date', table_name='etf_prices')
op.drop_index('idx_etf_asset_class', table_name='etfs')
op.drop_index('idx_snapshots_portfolio_date', table_name='portfolio_snapshots')
op.drop_index('idx_signals_date_status', table_name='signals')
op.drop_index('idx_transactions_portfolio_id_executed_at', table_name='transactions')
op.drop_index('idx_holdings_portfolio_id', table_name='holdings')
op.drop_index('idx_valuations_ticker_base_date', table_name='valuations')
op.drop_index('idx_financials_ticker_base_date', table_name='financials')
op.drop_index('idx_etf_prices_ticker_date', table_name='etf_prices')
op.drop_index('idx_prices_ticker_date', table_name='prices')

View File

@ -5,6 +5,7 @@ from datetime import timedelta
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from app.core.database import get_db
@ -22,7 +23,7 @@ router = APIRouter(prefix="/api/auth", tags=["auth"])
settings = get_settings()
@router.post("/login", response_model=Token)
@router.post("/login")
async def login(
login_data: LoginRequest,
db: Annotated[Session, Depends(get_db)],
@ -42,7 +43,19 @@ async def login(
expires_delta=timedelta(minutes=settings.access_token_expire_minutes),
)
return Token(access_token=access_token)
response = JSONResponse(
content={"access_token": access_token, "token_type": "bearer"},
)
response.set_cookie(
key="access_token",
value=access_token,
httponly=True,
samesite="lax",
secure=False, # Set True in production behind HTTPS
path="/",
max_age=settings.access_token_expire_minutes * 60,
)
return response
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@ -65,5 +78,13 @@ async def get_current_user_info(current_user: CurrentUser):
@router.post("/logout")
async def logout():
"""Logout (client should discard token)."""
return {"message": "Successfully logged out"}
"""Logout by clearing the access_token cookie."""
response = JSONResponse(content={"message": "Successfully logged out"})
response.delete_cookie(
key="access_token",
httponly=True,
samesite="lax",
secure=False,
path="/",
)
return response

View File

@ -1,22 +1,38 @@
"""
Backtest API endpoints.
"""
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, joinedload
from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.backtest import (
Backtest, BacktestResult, BacktestEquityCurve,
BacktestHolding, BacktestTransaction, BacktestStatus,
Backtest,
BacktestResult,
BacktestEquityCurve,
BacktestHolding,
BacktestTransaction,
BacktestStatus,
WalkForwardResult,
)
from app.schemas.backtest import (
BacktestCreate, BacktestResponse, BacktestListItem, BacktestMetrics,
EquityCurvePoint, RebalanceHoldings, HoldingItem, TransactionItem,
BacktestCreate,
BacktestResponse,
BacktestListItem,
BacktestMetrics,
EquityCurvePoint,
RebalanceHoldings,
HoldingItem,
TransactionItem,
WalkForwardRequest,
WalkForwardWindowResult,
WalkForwardResponse,
)
from app.services.backtest import submit_backtest
from app.services.backtest.walkforward_engine import WalkForwardEngine
from app.services.rebalance import RebalanceService
router = APIRouter(prefix="/api/backtest", tags=["backtest"])
@ -62,6 +78,7 @@ async def list_backtests(
"""List all backtests for current user."""
backtests = (
db.query(Backtest)
.options(joinedload(Backtest.result))
.filter(Backtest.user_id == current_user.id)
.order_by(Backtest.created_at.desc())
.all()
@ -93,14 +110,15 @@ async def get_backtest(
db: Session = Depends(get_db),
):
"""Get backtest details and results."""
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
backtest = (
db.query(Backtest)
.filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
.first()
)
if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found")
if backtest.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
result_metrics = None
if backtest.result:
result_metrics = BacktestMetrics(
@ -140,14 +158,15 @@ async def get_equity_curve(
db: Session = Depends(get_db),
):
"""Get equity curve data for chart."""
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
backtest = (
db.query(Backtest)
.filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
.first()
)
if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found")
if backtest.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
curve = (
db.query(BacktestEquityCurve)
.filter(BacktestEquityCurve.backtest_id == backtest_id)
@ -173,14 +192,15 @@ async def get_holdings(
db: Session = Depends(get_db),
):
"""Get holdings at each rebalance date."""
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
backtest = (
db.query(Backtest)
.filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
.first()
)
if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found")
if backtest.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
holdings = (
db.query(BacktestHolding)
.filter(BacktestHolding.backtest_id == backtest_id)
@ -193,13 +213,15 @@ async def get_holdings(
for h in holdings:
if h.rebalance_date not in grouped:
grouped[h.rebalance_date] = []
grouped[h.rebalance_date].append(HoldingItem(
ticker=h.ticker,
name=h.name,
weight=h.weight,
shares=h.shares,
price=h.price,
))
grouped[h.rebalance_date].append(
HoldingItem(
ticker=h.ticker,
name=h.name,
weight=h.weight,
shares=h.shares,
price=h.price,
)
)
return [
RebalanceHoldings(rebalance_date=date, holdings=items)
@ -214,14 +236,15 @@ async def get_transactions(
db: Session = Depends(get_db),
):
"""Get all transactions."""
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
backtest = (
db.query(Backtest)
.filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
.first()
)
if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found")
if backtest.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
transactions = (
db.query(BacktestTransaction)
.filter(BacktestTransaction.backtest_id == backtest_id)
@ -249,6 +272,84 @@ async def get_transactions(
]
@router.post("/{backtest_id}/walkforward", response_model=dict)
async def run_walkforward(
backtest_id: int,
request: WalkForwardRequest,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Run walk-forward analysis on a completed backtest."""
backtest = (
db.query(Backtest)
.filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
.first()
)
if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found")
if backtest.status != BacktestStatus.COMPLETED:
raise HTTPException(
status_code=400,
detail="백테스트가 완료된 상태에서만 walk-forward 분석이 가능합니다",
)
engine = WalkForwardEngine(db)
try:
engine.run(
backtest_id=backtest_id,
train_months=request.train_months,
test_months=request.test_months,
step_months=request.step_months,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return {"status": "completed", "backtest_id": backtest_id}
@router.get("/{backtest_id}/walkforward", response_model=WalkForwardResponse)
async def get_walkforward(
backtest_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Get walk-forward analysis results."""
backtest = (
db.query(Backtest)
.filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
.first()
)
if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found")
results = (
db.query(WalkForwardResult)
.filter(WalkForwardResult.backtest_id == backtest_id)
.order_by(WalkForwardResult.window_index)
.all()
)
return WalkForwardResponse(
backtest_id=backtest_id,
windows=[
WalkForwardWindowResult(
window_index=r.window_index,
train_start=r.train_start,
train_end=r.train_end,
test_start=r.test_start,
test_end=r.test_end,
test_return=r.test_return,
test_sharpe=r.test_sharpe,
test_mdd=r.test_mdd,
)
for r in results
],
)
@router.delete("/{backtest_id}")
async def delete_backtest(
backtest_id: int,
@ -256,15 +357,19 @@ async def delete_backtest(
db: Session = Depends(get_db),
):
"""Delete a backtest and all its data."""
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
backtest = (
db.query(Backtest)
.filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
.first()
)
if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found")
if backtest.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
# Delete related data
db.query(WalkForwardResult).filter(
WalkForwardResult.backtest_id == backtest_id
).delete()
db.query(BacktestTransaction).filter(
BacktestTransaction.backtest_id == backtest_id
).delete()
@ -274,9 +379,7 @@ async def delete_backtest(
db.query(BacktestEquityCurve).filter(
BacktestEquityCurve.backtest_id == backtest_id
).delete()
db.query(BacktestResult).filter(
BacktestResult.backtest_id == backtest_id
).delete()
db.query(BacktestResult).filter(BacktestResult.backtest_id == backtest_id).delete()
db.delete(backtest)
db.commit()

View File

@ -129,15 +129,25 @@ async def get_stock_prices(
ticker: str,
current_user: CurrentUser,
db: Session = Depends(get_db),
skip: int = Query(0, ge=0),
limit: int = Query(365, ge=1, le=3000),
):
"""Get daily prices for a stock."""
"""Get daily prices for a stock with pagination."""
base_query = db.query(Price).filter(Price.ticker == ticker)
total = base_query.count()
prices = (
db.query(Price)
.filter(Price.ticker == ticker)
.order_by(Price.date.asc())
base_query
.order_by(Price.date.desc())
.offset(skip)
.limit(limit)
.all()
)
return [PriceItem.model_validate(p) for p in prices]
return {
"items": [PriceItem.model_validate(p) for p in prices],
"total": total,
"skip": skip,
"limit": limit,
}
@router.get("/etfs")
@ -171,15 +181,25 @@ async def get_etf_prices(
ticker: str,
current_user: CurrentUser,
db: Session = Depends(get_db),
skip: int = Query(0, ge=0),
limit: int = Query(365, ge=1, le=3000),
):
"""Get daily prices for an ETF."""
"""Get daily prices for an ETF with pagination."""
base_query = db.query(ETFPrice).filter(ETFPrice.ticker == ticker)
total = base_query.count()
prices = (
db.query(ETFPrice)
.filter(ETFPrice.ticker == ticker)
.order_by(ETFPrice.date.asc())
base_query
.order_by(ETFPrice.date.desc())
.offset(skip)
.limit(limit)
.all()
)
return [ETFPriceItem.model_validate(p) for p in prices]
return {
"items": [ETFPriceItem.model_validate(p) for p in prices],
"total": total,
"skip": skip,
"limit": limit,
}
@router.get("/sectors")

View File

@ -1,9 +1,9 @@
"""
API dependencies.
"""
from typing import Annotated
from typing import Annotated, Optional
from fastapi import Depends, HTTPException, status
from fastapi import Cookie, Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
@ -11,20 +11,29 @@ from app.core.database import get_db
from app.core.security import decode_access_token
from app.models.user import User
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login", auto_error=False)
async def get_current_user(
request: Request,
db: Annotated[Session, Depends(get_db)],
token: Annotated[str, Depends(oauth2_scheme)],
bearer_token: Annotated[Optional[str], Depends(oauth2_scheme)] = None,
) -> User:
"""Get the current authenticated user."""
"""Get the current authenticated user.
Token extraction order: httpOnly cookie first, then Authorization header fallback.
"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
# Cookie first, then Authorization header fallback
token = request.cookies.get("access_token") or bearer_token
if token is None:
raise credentials_exception
payload = decode_access_token(token)
if payload is None:
raise credentials_exception

View File

@ -4,7 +4,7 @@ Portfolio management API endpoints.
from decimal import Decimal
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from app.core.database import get_db
@ -19,6 +19,7 @@ from app.schemas.portfolio import (
RebalanceResponse, RebalanceSimulationRequest, RebalanceSimulationResponse,
RebalanceCalculateRequest, RebalanceCalculateResponse,
RebalanceApplyRequest, RebalanceApplyResponse,
PositionSizeResponse,
)
from app.services.rebalance import RebalanceService
@ -217,19 +218,26 @@ async def set_holdings(
return new_holdings
@router.get("/{portfolio_id}/transactions", response_model=List[TransactionResponse])
@router.get("/{portfolio_id}/transactions")
async def get_transactions(
portfolio_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
limit: int = 50,
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=500),
):
"""Get transaction history for a portfolio."""
"""Get transaction history for a portfolio with pagination."""
_get_portfolio(db, portfolio_id, current_user.id)
transactions = (
base_query = (
db.query(Transaction)
.filter(Transaction.portfolio_id == portfolio_id)
)
total = base_query.count()
transactions = (
base_query
.order_by(Transaction.executed_at.desc())
.offset(skip)
.limit(limit)
.all()
)
@ -239,19 +247,25 @@ async def get_transactions(
service = RebalanceService(db)
names = service.get_stock_names(tickers)
return [
TransactionResponse(
id=tx.id,
ticker=tx.ticker,
name=names.get(tx.ticker),
tx_type=tx.tx_type.value,
quantity=tx.quantity,
price=tx.price,
executed_at=tx.executed_at,
memo=tx.memo,
)
for tx in transactions
]
return {
"items": [
TransactionResponse(
id=tx.id,
ticker=tx.ticker,
name=names.get(tx.ticker),
tx_type=tx.tx_type.value,
quantity=tx.quantity,
price=tx.price,
executed_at=tx.executed_at,
memo=tx.memo,
realized_pnl=tx.realized_pnl,
)
for tx in transactions
],
"total": total,
"skip": skip,
"limit": limit,
}
@router.post("/{portfolio_id}/transactions", response_model=TransactionResponse, status_code=status.HTTP_201_CREATED)
@ -306,6 +320,8 @@ async def add_transaction(
status_code=400,
detail=f"Insufficient quantity for {data.ticker}"
)
# Calculate realized PnL: (sell_price - avg_price) * quantity
transaction.realized_pnl = (data.price - holding.avg_price) * data.quantity
holding.quantity -= data.quantity
if holding.quantity == 0:
db.delete(holding)
@ -362,6 +378,7 @@ async def calculate_rebalance_manual(
strategy=data.strategy,
manual_prices=data.prices,
additional_amount=data.additional_amount,
min_trade_amount=data.min_trade_amount,
)
@ -373,7 +390,7 @@ async def apply_rebalance(
db: Session = Depends(get_db),
):
"""리밸런싱 결과를 적용하여 거래를 일괄 생성한다."""
from datetime import datetime
from datetime import datetime, timezone
portfolio = _get_portfolio(db, portfolio_id, current_user.id)
transactions = []
@ -387,7 +404,7 @@ async def apply_rebalance(
tx_type=tx_type,
quantity=item.quantity,
price=item.price,
executed_at=datetime.utcnow(),
executed_at=datetime.now(timezone.utc),
memo="리밸런싱 적용",
)
db.add(transaction)
@ -415,6 +432,7 @@ async def apply_rebalance(
elif tx_type == TransactionType.SELL:
if not holding or holding.quantity < item.quantity:
raise HTTPException(status_code=400, detail=f"Insufficient quantity for {item.ticker}")
transaction.realized_pnl = (item.price - holding.avg_price) * item.quantity
holding.quantity -= item.quantity
if holding.quantity == 0:
db.delete(holding)
@ -439,6 +457,7 @@ async def apply_rebalance(
price=tx.price,
executed_at=tx.executed_at,
memo=tx.memo,
realized_pnl=tx.realized_pnl,
)
for tx in transactions
]
@ -496,6 +515,19 @@ async def get_portfolio_detail(
if total_value > 0:
h.current_ratio = (h.value / total_value * 100).quantize(Decimal("0.01"))
# Calculate realized PnL (sum of all sell transactions with realized_pnl)
from sqlalchemy import func
total_realized_pnl_result = (
db.query(func.coalesce(func.sum(Transaction.realized_pnl), 0))
.filter(
Transaction.portfolio_id == portfolio_id,
Transaction.realized_pnl.isnot(None),
)
.scalar()
)
total_realized_pnl = Decimal(str(total_realized_pnl_result))
total_unrealized_pnl = (total_value - total_invested)
# Calculate risk asset ratio for pension portfolios
risk_asset_ratio = None
if portfolio.portfolio_type == PortfolioType.PENSION and total_value > 0:
@ -525,5 +557,75 @@ async def get_portfolio_detail(
total_value=total_value,
total_invested=total_invested,
total_profit_loss=total_value - total_invested,
total_realized_pnl=total_realized_pnl,
total_unrealized_pnl=total_unrealized_pnl,
risk_asset_ratio=risk_asset_ratio,
)
@router.get("/{portfolio_id}/position-size", response_model=PositionSizeResponse)
async def get_position_size(
portfolio_id: int,
ticker: str,
price: Decimal,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""포지션 사이징 가이드: 추천 수량과 최대 수량을 계산한다."""
portfolio = _get_portfolio(db, portfolio_id, current_user.id)
service = RebalanceService(db)
# Calculate total portfolio value
holding_tickers = [h.ticker for h in portfolio.holdings]
prices = service.get_current_prices(holding_tickers)
total_value = Decimal("0")
for holding in portfolio.holdings:
cp = prices.get(holding.ticker, Decimal("0"))
total_value += cp * holding.quantity
# Current holding for this ticker
current_holding = db.query(Holding).filter(
Holding.portfolio_id == portfolio_id,
Holding.ticker == ticker,
).first()
current_qty = current_holding.quantity if current_holding else 0
current_value = price * current_qty
# Current ratio
current_ratio = (current_value / total_value * 100) if total_value > 0 else Decimal("0")
# Target ratio from portfolio targets
target = db.query(Target).filter(
Target.portfolio_id == portfolio_id,
Target.ticker == ticker,
).first()
target_ratio = Decimal(str(target.target_ratio)) if target else None
# Max position: 20% of portfolio (or target ratio if set)
max_ratio = target_ratio if target_ratio else Decimal("20")
max_value = total_value * max_ratio / 100
max_additional_value = max(max_value - current_value, Decimal("0"))
max_quantity = int(max_additional_value / price) if price > 0 else 0
# Recommended: equal-weight across targets, or 10% if no targets
num_targets = len(portfolio.targets) or 1
equal_ratio = Decimal("100") / num_targets
rec_ratio = target_ratio if target_ratio else min(equal_ratio, Decimal("10"))
rec_value = total_value * rec_ratio / 100
rec_additional_value = max(rec_value - current_value, Decimal("0"))
recommended_quantity = int(rec_additional_value / price) if price > 0 else 0
return PositionSizeResponse(
ticker=ticker,
price=price,
total_portfolio_value=total_value,
current_holding_quantity=current_qty,
current_holding_value=current_value,
current_ratio=current_ratio.quantize(Decimal("0.01")) if isinstance(current_ratio, Decimal) else current_ratio,
target_ratio=target_ratio,
recommended_quantity=recommended_quantity,
max_quantity=max_quantity,
recommended_value=rec_additional_value,
max_value=max_additional_value,
)

View File

@ -1,7 +1,7 @@
"""
KJB Signal API endpoints.
"""
from datetime import date, datetime
from datetime import date, datetime, timezone
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
@ -91,7 +91,7 @@ async def execute_signal(
tx_type=tx_type,
quantity=data.quantity,
price=data.price,
executed_at=datetime.utcnow(),
executed_at=datetime.now(timezone.utc),
memo=f"KJB signal #{signal.id}: {signal.signal_type.value}",
)
db.add(transaction)
@ -122,6 +122,8 @@ async def execute_signal(
status_code=400,
detail=f"Insufficient quantity for {signal.ticker}"
)
# Calculate realized PnL: (sell_price - avg_price) * quantity
transaction.realized_pnl = (data.price - holding.avg_price) * data.quantity
holding.quantity -= data.quantity
if holding.quantity == 0:
db.delete(holding)
@ -130,7 +132,7 @@ async def execute_signal(
signal.status = SignalStatus.EXECUTED
signal.executed_price = data.price
signal.executed_quantity = data.quantity
signal.executed_at = datetime.utcnow()
signal.executed_at = datetime.now(timezone.utc)
db.commit()
db.refresh(transaction)
@ -150,3 +152,85 @@ async def execute_signal(
"status": signal.status.value,
},
}
@router.delete("/{signal_id}/cancel", response_model=dict)
async def cancel_signal(
signal_id: int,
portfolio_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""실행된 신호를 취소한다. 연결된 거래를 삭제하고 보유량을 복원하며 신호를 ACTIVE로 되돌린다."""
from app.api.portfolio import _get_portfolio
from decimal import Decimal
# 1. 신호 조회 및 상태 확인
signal = db.query(Signal).filter(Signal.id == signal_id).first()
if not signal:
raise HTTPException(status_code=404, detail="Signal not found")
if signal.status != SignalStatus.EXECUTED:
raise HTTPException(status_code=400, detail="Signal is not in EXECUTED status")
# 2. 포트폴리오 소유권 확인
portfolio = _get_portfolio(db, portfolio_id, current_user.id)
# 3. 연결된 거래 조회 (신호 메모 기준)
memo_prefix = f"KJB signal #{signal_id}:"
transaction = (
db.query(Transaction)
.filter(
Transaction.portfolio_id == portfolio_id,
Transaction.ticker == signal.ticker,
Transaction.memo.like(f"{memo_prefix}%"),
)
.order_by(Transaction.executed_at.desc())
.first()
)
if not transaction:
raise HTTPException(status_code=404, detail="Related transaction not found")
# 4. 보유량 복원 (거래 역방향)
holding = db.query(Holding).filter(
Holding.portfolio_id == portfolio_id,
Holding.ticker == signal.ticker,
).first()
if transaction.tx_type == TransactionType.BUY:
# 매수 취소 → 보유량 감소
if holding:
holding.quantity -= transaction.quantity
if holding.quantity <= 0:
db.delete(holding)
elif transaction.tx_type == TransactionType.SELL:
# 매도 취소 → 보유량 복원
if holding:
# 평균단가 재계산 (역산 불가이므로 수량만 복원)
holding.quantity += transaction.quantity
else:
holding = Holding(
portfolio_id=portfolio_id,
ticker=signal.ticker,
quantity=transaction.quantity,
avg_price=transaction.price,
)
db.add(holding)
# 5. 거래 삭제
db.delete(transaction)
# 6. 신호 상태 복원
signal.status = SignalStatus.ACTIVE
signal.executed_price = None
signal.executed_quantity = None
signal.executed_at = None
db.commit()
return {
"signal_id": signal_id,
"signal_status": signal.status.value,
"transaction_deleted": True,
"ticker": signal.ticker,
}

View File

@ -5,7 +5,7 @@ from datetime import date
from decimal import Decimal
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from app.core.database import get_db
@ -33,23 +33,36 @@ def _get_portfolio(db: Session, portfolio_id: int, user_id: int) -> Portfolio:
return portfolio
@router.get("/{portfolio_id}/snapshots", response_model=List[SnapshotListItem])
@router.get("/{portfolio_id}/snapshots")
async def list_snapshots(
portfolio_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
):
"""Get all snapshots for a portfolio."""
"""Get snapshots for a portfolio with pagination."""
_get_portfolio(db, portfolio_id, current_user.id)
snapshots = (
base_query = (
db.query(PortfolioSnapshot)
.filter(PortfolioSnapshot.portfolio_id == portfolio_id)
)
total = base_query.count()
snapshots = (
base_query
.order_by(PortfolioSnapshot.snapshot_date.desc())
.offset(skip)
.limit(limit)
.all()
)
return snapshots
return {
"items": [SnapshotListItem.model_validate(s) for s in snapshots],
"total": total,
"skip": skip,
"limit": limit,
}
@router.post("/{portfolio_id}/snapshots", response_model=SnapshotResponse, status_code=status.HTTP_201_CREATED)

View File

@ -1,11 +1,14 @@
"""
Quant strategy API endpoints.
"""
from typing import Set
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.stock import ETF
from app.schemas.strategy import (
MultiFactorRequest, QualityRequest, ValueMomentumRequest, KJBRequest, StrategyResult,
)
@ -14,6 +17,27 @@ from app.services.strategy import MultiFactorStrategy, QualityStrategy, ValueMom
router = APIRouter(prefix="/api/strategy", tags=["strategy"])
def _filter_dc_only(result: StrategyResult, db: Session) -> StrategyResult:
"""Filter strategy result to include only ETFs (DC pension investable)."""
tickers = [s.ticker for s in result.stocks]
etf_tickers: Set[str] = set(
row[0] for row in db.query(ETF.ticker).filter(ETF.ticker.in_(tickers)).all()
) if tickers else set()
filtered = [s for s in result.stocks if s.ticker in etf_tickers]
# Re-rank
for i, stock in enumerate(filtered, 1):
stock.rank = i
return StrategyResult(
strategy_name=result.strategy_name,
base_date=result.base_date,
universe_count=result.universe_count,
result_count=len(filtered),
stocks=filtered,
)
@router.post("/multi-factor", response_model=StrategyResult)
async def run_multi_factor(
request: MultiFactorRequest,
@ -22,12 +46,13 @@ async def run_multi_factor(
):
"""Run multi-factor strategy."""
strategy = MultiFactorStrategy(db)
return strategy.run(
result = strategy.run(
universe_filter=request.universe,
top_n=request.top_n,
base_date=request.base_date,
weights=request.weights,
)
return _filter_dc_only(result, db) if request.dc_only else result
@router.post("/quality", response_model=StrategyResult)
@ -38,12 +63,13 @@ async def run_quality(
):
"""Run super quality strategy."""
strategy = QualityStrategy(db)
return strategy.run(
result = strategy.run(
universe_filter=request.universe,
top_n=request.top_n,
base_date=request.base_date,
min_fscore=request.min_fscore,
)
return _filter_dc_only(result, db) if request.dc_only else result
@router.post("/value-momentum", response_model=StrategyResult)
@ -54,13 +80,14 @@ async def run_value_momentum(
):
"""Run value-momentum strategy."""
strategy = ValueMomentumStrategy(db)
return strategy.run(
result = strategy.run(
universe_filter=request.universe,
top_n=request.top_n,
base_date=request.base_date,
value_weight=request.value_weight,
momentum_weight=request.momentum_weight,
)
return _filter_dc_only(result, db) if request.dc_only else result
@router.post("/kjb", response_model=StrategyResult)
@ -71,8 +98,9 @@ async def run_kjb(
):
"""Run KJB strategy."""
strategy = KJBStrategy(db)
return strategy.run(
result = strategy.run(
universe_filter=request.universe,
top_n=request.top_n,
base_date=request.base_date,
)
return _filter_dc_only(result, db) if request.dc_only else result

View File

@ -1,6 +1,7 @@
"""
Application configuration using Pydantic Settings.
"""
from pydantic_settings import BaseSettings
from functools import lru_cache
@ -11,10 +12,10 @@ class Settings(BaseSettings):
debug: bool = False
# Database
database_url: str = "postgresql://galaxy:devpassword@localhost:5432/galaxy_po"
database_url: str
# JWT
jwt_secret: str = "dev-jwt-secret-change-in-production"
jwt_secret: str
jwt_algorithm: str = "HS256"
access_token_expire_minutes: int = 60 * 24 # 24 hours

View File

@ -1,18 +1,18 @@
"""
Database connection and session management.
"""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from app.core.config import get_settings
settings = get_settings()
engine = create_engine(
settings.database_url,
pool_pre_ping=True,
pool_size=10,
max_overflow=20,
)
_engine_kwargs = {"pool_pre_ping": True}
if settings.database_url.startswith("postgresql"):
_engine_kwargs.update(pool_size=10, max_overflow=20)
engine = create_engine(settings.database_url, **_engine_kwargs)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

View File

@ -51,6 +51,7 @@ class Backtest(Base):
equity_curve = relationship("BacktestEquityCurve", back_populates="backtest")
holdings = relationship("BacktestHolding", back_populates="backtest")
transactions = relationship("BacktestTransaction", back_populates="backtest")
walkforward_results = relationship("WalkForwardResult", back_populates="backtest")
class BacktestResult(Base):
@ -107,3 +108,20 @@ class BacktestTransaction(Base):
commission = Column(Numeric(12, 2), nullable=False)
backtest = relationship("Backtest", back_populates="transactions")
class WalkForwardResult(Base):
__tablename__ = "walkforward_results"
id = Column(Integer, primary_key=True, index=True)
backtest_id = Column(Integer, ForeignKey("backtests.id"), nullable=False)
window_index = Column(Integer, nullable=False)
train_start = Column(Date, nullable=False)
train_end = Column(Date, nullable=False)
test_start = Column(Date, nullable=False)
test_end = Column(Date, nullable=False)
test_return = Column(Numeric(10, 4), nullable=True)
test_sharpe = Column(Numeric(10, 4), nullable=True)
test_mdd = Column(Numeric(10, 4), nullable=True)
backtest = relationship("Backtest", back_populates="walkforward_results")

View File

@ -70,6 +70,7 @@ class Transaction(Base):
price = Column(Numeric(12, 2), nullable=False)
executed_at = Column(DateTime, nullable=False)
memo = Column(Text, nullable=True)
realized_pnl = Column(Numeric(15, 2), nullable=True)
portfolio = relationship("Portfolio", back_populates="transactions")

View File

@ -129,3 +129,31 @@ class TransactionItem(BaseModel):
class Config:
from_attributes = True
class WalkForwardRequest(BaseModel):
"""Request to run walk-forward analysis."""
train_months: int = Field(default=12, ge=3, le=60)
test_months: int = Field(default=3, ge=1, le=24)
step_months: int = Field(default=3, ge=1, le=24)
class WalkForwardWindowResult(BaseModel):
"""Single walk-forward window result."""
window_index: int
train_start: date
train_end: date
test_start: date
test_end: date
test_return: FloatDecimal | None = None
test_sharpe: FloatDecimal | None = None
test_mdd: FloatDecimal | None = None
class Config:
from_attributes = True
class WalkForwardResponse(BaseModel):
"""Walk-forward analysis results."""
backtest_id: int
windows: List[WalkForwardWindowResult]

View File

@ -72,6 +72,7 @@ class TransactionCreate(TransactionBase):
class TransactionResponse(TransactionBase):
id: int
name: str | None = None
realized_pnl: FloatDecimal | None = None
class Config:
from_attributes = True
@ -109,6 +110,8 @@ class PortfolioDetail(PortfolioResponse):
total_value: FloatDecimal | None = None
total_invested: FloatDecimal | None = None
total_profit_loss: FloatDecimal | None = None
total_realized_pnl: FloatDecimal | None = None
total_unrealized_pnl: FloatDecimal | None = None
risk_asset_ratio: FloatDecimal | None = None
@ -205,6 +208,7 @@ class RebalanceCalculateRequest(BaseModel):
strategy: str = Field(..., pattern="^(full_rebalance|additional_buy)$")
prices: Optional[dict[str, Decimal]] = None
additional_amount: Optional[Decimal] = Field(None, ge=0)
min_trade_amount: Optional[Decimal] = Field(default=Decimal("10000"), ge=0)
class RebalanceCalculateItem(BaseModel):
@ -247,3 +251,17 @@ class RebalanceApplyRequest(BaseModel):
class RebalanceApplyResponse(BaseModel):
transactions: List[TransactionResponse]
holdings_updated: int
class PositionSizeResponse(BaseModel):
ticker: str
price: FloatDecimal
total_portfolio_value: FloatDecimal
current_holding_quantity: int
current_holding_value: FloatDecimal
current_ratio: FloatDecimal
target_ratio: FloatDecimal | None = None
recommended_quantity: int
max_quantity: int
recommended_value: FloatDecimal
max_value: FloatDecimal

View File

@ -32,6 +32,7 @@ class StrategyRequest(BaseModel):
universe: UniverseFilter = UniverseFilter()
top_n: int = Field(default=30, ge=1, le=100)
base_date: Optional[date] = None
dc_only: bool = False
class MultiFactorRequest(StrategyRequest):

View File

@ -4,6 +4,7 @@ from app.services.backtest.metrics import MetricsCalculator, BacktestMetrics
from app.services.backtest.worker import submit_backtest, get_executor_status
from app.services.backtest.daily_engine import DailyBacktestEngine
from app.services.backtest.trading_portfolio import TradingPortfolio, TradingTransaction
from app.services.backtest.walkforward_engine import WalkForwardEngine
__all__ = [
"BacktestEngine",
@ -18,4 +19,5 @@ __all__ = [
"DailyBacktestEngine",
"TradingPortfolio",
"TradingTransaction",
"WalkForwardEngine",
]

View File

@ -1,6 +1,7 @@
"""
Main backtest engine.
"""
import logging
from dataclasses import dataclass, field
from datetime import date, timedelta
@ -12,13 +13,21 @@ from sqlalchemy import func
from sqlalchemy.orm import Session
from app.models.backtest import (
Backtest, BacktestResult, BacktestEquityCurve,
BacktestHolding, BacktestTransaction, RebalancePeriod,
Backtest,
BacktestResult,
BacktestEquityCurve,
BacktestHolding,
BacktestTransaction,
RebalancePeriod,
)
from app.models.stock import Stock, Price
from app.services.backtest.portfolio import VirtualPortfolio, Transaction
from app.services.backtest.metrics import MetricsCalculator
from app.services.strategy import MultiFactorStrategy, QualityStrategy, ValueMomentumStrategy
from app.services.strategy import (
MultiFactorStrategy,
QualityStrategy,
ValueMomentumStrategy,
)
from app.schemas.strategy import UniverseFilter, FactorWeights
logger = logging.getLogger(__name__)
@ -27,6 +36,7 @@ logger = logging.getLogger(__name__)
@dataclass
class DataValidationResult:
"""Result of pre-backtest data validation."""
is_valid: bool = True
errors: List[str] = field(default_factory=list)
warnings: List[str] = field(default_factory=list)
@ -85,9 +95,7 @@ class BacktestEngine:
logger.warning(f"Backtest {backtest_id}: {warning}")
if not validation.is_valid:
raise ValueError(
"데이터 검증 실패:\n" + "\n".join(validation.errors)
)
raise ValueError("데이터 검증 실패:\n" + "\n".join(validation.errors))
# Create strategy instance
strategy = self._create_strategy(
@ -105,20 +113,24 @@ class BacktestEngine:
if initial_benchmark == 0:
initial_benchmark = Decimal("1")
names = self._get_stock_names()
all_date_prices = self._load_all_prices_by_date(
backtest.start_date,
backtest.end_date,
)
for trading_date in trading_days:
# Get prices for this date
prices = self._get_prices_for_date(trading_date)
names = self._get_stock_names()
prices = all_date_prices.get(trading_date, {})
# Warn about holdings with missing prices
missing = [
t for t in portfolio.holdings
t
for t in portfolio.holdings
if portfolio.holdings[t] > 0 and t not in prices
]
if missing:
logger.warning(
f"{trading_date}: 보유 종목 가격 누락 {missing} "
f"(0원으로 처리됨)"
f"{trading_date}: 보유 종목 가격 누락 {missing} (0원으로 처리됨)"
)
# Rebalance if needed
@ -141,16 +153,16 @@ class BacktestEngine:
slippage_rate=backtest.slippage_rate,
)
all_transactions.extend([
(trading_date, txn) for txn in transactions
])
all_transactions.extend([(trading_date, txn) for txn in transactions])
# Record holdings
holdings = portfolio.get_holdings_with_weights(prices, names)
holdings_history.append({
'date': trading_date,
'holdings': holdings,
})
holdings_history.append(
{
"date": trading_date,
"holdings": holdings,
}
)
# Record daily value
portfolio_value = portfolio.get_value(prices)
@ -161,15 +173,21 @@ class BacktestEngine:
benchmark_value / initial_benchmark * backtest.initial_capital
)
equity_curve_data.append({
'date': trading_date,
'portfolio_value': portfolio_value,
'benchmark_value': normalized_benchmark,
})
equity_curve_data.append(
{
"date": trading_date,
"portfolio_value": portfolio_value,
"benchmark_value": normalized_benchmark,
}
)
# Calculate metrics
portfolio_values = [Decimal(str(e['portfolio_value'])) for e in equity_curve_data]
benchmark_values = [Decimal(str(e['benchmark_value'])) for e in equity_curve_data]
portfolio_values = [
Decimal(str(e["portfolio_value"])) for e in equity_curve_data
]
benchmark_values = [
Decimal(str(e["benchmark_value"])) for e in equity_curve_data
]
metrics = MetricsCalculator.calculate_all(portfolio_values, benchmark_values)
drawdowns = MetricsCalculator.calculate_drawdown_series(portfolio_values)
@ -221,18 +239,13 @@ class BacktestEngine:
# 2. Benchmark data coverage
benchmark_ticker = "069500" if benchmark == "KOSPI" else "069500"
benchmark_coverage = sum(
1 for d in total_days if d in benchmark_prices
)
benchmark_coverage = sum(1 for d in total_days if d in benchmark_prices)
benchmark_pct = (
benchmark_coverage / num_trading_days * 100
if num_trading_days > 0 else 0
benchmark_coverage / num_trading_days * 100 if num_trading_days > 0 else 0
)
if benchmark_coverage == 0:
result.errors.append(
f"벤치마크({benchmark_ticker}) 가격 데이터 없음"
)
result.errors.append(f"벤치마크({benchmark_ticker}) 가격 데이터 없음")
result.is_valid = False
elif benchmark_pct < 90:
result.warnings.append(
@ -254,22 +267,17 @@ class BacktestEngine:
.scalar()
)
if ticker_count == 0:
result.errors.append(
f"{sample_date} 가격 데이터 없음 (종목 0개)"
)
result.errors.append(f"{sample_date} 가격 데이터 없음 (종목 0개)")
result.is_valid = False
elif ticker_count < 100:
result.warnings.append(
f"{sample_date} 종목 수 적음: {ticker_count}"
)
result.warnings.append(f"{sample_date} 종목 수 적음: {ticker_count}")
# 4. Large gaps in trading days (> 7 calendar days excluding normal weekends)
for i in range(1, num_trading_days):
gap = (total_days[i] - total_days[i - 1]).days
if gap > 7:
result.warnings.append(
f"거래일 갭 발견: {total_days[i-1]} ~ {total_days[i]} "
f"({gap}일)"
f"거래일 갭 발견: {total_days[i - 1]} ~ {total_days[i]} ({gap}일)"
)
if result.is_valid and not result.warnings:
@ -338,18 +346,25 @@ class BacktestEngine:
return {p.date: p.close for p in prices}
def _get_prices_for_date(self, trading_date: date) -> Dict[str, Decimal]:
"""Get all stock prices for a specific date."""
def _load_all_prices_by_date(
self,
start_date: date,
end_date: date,
) -> Dict[date, Dict[str, Decimal]]:
prices = (
self.db.query(Price)
.filter(Price.date == trading_date)
.filter(Price.date >= start_date, Price.date <= end_date)
.all()
)
return {p.ticker: p.close for p in prices}
result: Dict[date, Dict[str, Decimal]] = {}
for p in prices:
if p.date not in result:
result[p.date] = {}
result[p.date][p.ticker] = p.close
return result
def _get_stock_names(self) -> Dict[str, str]:
"""Get all stock names."""
stocks = self.db.query(Stock).all()
stocks = self.db.query(Stock.ticker, Stock.name).all()
return {s.ticker: s.name for s in stocks}
def _create_strategy(
@ -367,8 +382,12 @@ class BacktestEngine:
strategy._min_fscore = strategy_params.get("min_fscore", 7)
elif strategy_type == "value_momentum":
strategy = ValueMomentumStrategy(self.db)
strategy._value_weight = Decimal(str(strategy_params.get("value_weight", 0.5)))
strategy._momentum_weight = Decimal(str(strategy_params.get("momentum_weight", 0.5)))
strategy._value_weight = Decimal(
str(strategy_params.get("value_weight", 0.5))
)
strategy._momentum_weight = Decimal(
str(strategy_params.get("momentum_weight", 0.5))
)
else:
raise ValueError(f"Unknown strategy type: {strategy_type}")
@ -401,19 +420,19 @@ class BacktestEngine:
for i, point in enumerate(equity_curve_data):
curve_point = BacktestEquityCurve(
backtest_id=backtest_id,
date=point['date'],
portfolio_value=point['portfolio_value'],
benchmark_value=point['benchmark_value'],
date=point["date"],
portfolio_value=point["portfolio_value"],
benchmark_value=point["benchmark_value"],
drawdown=drawdowns[i] if i < len(drawdowns) else Decimal("0"),
)
self.db.add(curve_point)
# Save holdings history
for record in holdings_history:
for holding in record['holdings']:
for holding in record["holdings"]:
h = BacktestHolding(
backtest_id=backtest_id,
rebalance_date=record['date'],
rebalance_date=record["date"],
ticker=holding.ticker,
name=holding.name,
weight=holding.weight,

View File

@ -0,0 +1,235 @@
"""
Walk-forward analysis engine.
Splits backtest period into rolling train/test windows and runs
the existing BacktestEngine (or DailyBacktestEngine) on each test window.
Train window results are used for validation only (no parameter optimisation yet).
"""
import logging
from dataclasses import dataclass
from datetime import date
from decimal import Decimal
from typing import List
from dateutil.relativedelta import relativedelta
from sqlalchemy.orm import Session
from app.models.backtest import (
Backtest,
BacktestResult,
BacktestEquityCurve,
WalkForwardResult,
)
from app.services.backtest.engine import BacktestEngine
from app.services.backtest.metrics import MetricsCalculator
logger = logging.getLogger(__name__)
@dataclass
class Window:
index: int
train_start: date
train_end: date
test_start: date
test_end: date
class WalkForwardEngine:
"""
Walk-forward analysis using existing BacktestEngine.
Parameters
----------
train_months : int length of in-sample (training) window
test_months : int length of out-of-sample (test) window
step_months : int how far the window slides each iteration
"""
def __init__(self, db: Session):
self.db = db
# ------------------------------------------------------------------
# public
# ------------------------------------------------------------------
def run(
self,
backtest_id: int,
train_months: int = 12,
test_months: int = 3,
step_months: int = 3,
) -> List[WalkForwardResult]:
backtest = self.db.query(Backtest).get(backtest_id)
if not backtest:
raise ValueError(f"Backtest {backtest_id} not found")
windows = self._generate_windows(
start=backtest.start_date,
end=backtest.end_date,
train_months=train_months,
test_months=test_months,
step_months=step_months,
)
if not windows:
raise ValueError(
"기간이 너무 짧아 walk-forward 윈도우를 생성할 수 없습니다. "
f"최소 {train_months + test_months}개월 필요"
)
# Delete previous walk-forward results for this backtest
self.db.query(WalkForwardResult).filter(
WalkForwardResult.backtest_id == backtest_id
).delete()
self.db.flush()
engine = BacktestEngine(self.db)
results: List[WalkForwardResult] = []
for win in windows:
logger.info(
f"Walk-forward window {win.index}: "
f"test {win.test_start} ~ {win.test_end}"
)
test_return, test_sharpe, test_mdd = self._run_window(
engine, backtest, win
)
wf = WalkForwardResult(
backtest_id=backtest_id,
window_index=win.index,
train_start=win.train_start,
train_end=win.train_end,
test_start=win.test_start,
test_end=win.test_end,
test_return=test_return,
test_sharpe=test_sharpe,
test_mdd=test_mdd,
)
self.db.add(wf)
results.append(wf)
self.db.commit()
return results
# ------------------------------------------------------------------
# window generation
# ------------------------------------------------------------------
@staticmethod
def _generate_windows(
start: date,
end: date,
train_months: int,
test_months: int,
step_months: int,
) -> List[Window]:
windows: List[Window] = []
idx = 0
cursor = start
while True:
train_start = cursor
train_end = train_start + relativedelta(months=train_months) - relativedelta(days=1)
test_start = train_end + relativedelta(days=1)
test_end = test_start + relativedelta(months=test_months) - relativedelta(days=1)
if test_end > end:
# Allow partial last window if test_start is before end
if test_start <= end:
test_end = end
else:
break
windows.append(Window(
index=idx,
train_start=train_start,
train_end=train_end,
test_start=test_start,
test_end=test_end,
))
idx += 1
cursor += relativedelta(months=step_months)
return windows
# ------------------------------------------------------------------
# single window execution
# ------------------------------------------------------------------
def _run_window(
self,
engine: BacktestEngine,
backtest: Backtest,
win: Window,
) -> tuple:
"""Run backtest on the test window and return (return, sharpe, mdd)."""
try:
trading_days = engine._get_trading_days(win.test_start, win.test_end)
if not trading_days:
return (Decimal("0"), Decimal("0"), Decimal("0"))
benchmark_prices = engine._load_benchmark_prices(
backtest.benchmark, win.test_start, win.test_end
)
strategy = engine._create_strategy(
backtest.strategy_type,
backtest.strategy_params or {},
backtest.top_n,
)
from app.services.backtest.portfolio import VirtualPortfolio
from app.schemas.strategy import UniverseFilter
portfolio = VirtualPortfolio(backtest.initial_capital)
rebalance_dates = engine._generate_rebalance_dates(
win.test_start, win.test_end, backtest.rebalance_period,
)
initial_benchmark = benchmark_prices.get(trading_days[0], Decimal("1"))
if initial_benchmark == 0:
initial_benchmark = Decimal("1")
equity_curve: List[Decimal] = []
benchmark_curve: List[Decimal] = []
for trading_date in trading_days:
prices = engine._get_prices_for_date(trading_date)
names = engine._get_stock_names()
if trading_date in rebalance_dates:
target_stocks = strategy.run(
universe_filter=UniverseFilter(),
top_n=backtest.top_n,
base_date=trading_date,
)
target_tickers = [s.ticker for s in target_stocks.stocks]
portfolio.rebalance(
target_tickers=target_tickers,
prices=prices,
names=names,
commission_rate=backtest.commission_rate,
slippage_rate=backtest.slippage_rate,
)
portfolio_value = portfolio.get_value(prices)
benchmark_value = benchmark_prices.get(trading_date, initial_benchmark)
normalized_benchmark = (
benchmark_value / initial_benchmark * backtest.initial_capital
)
equity_curve.append(Decimal(str(portfolio_value)))
benchmark_curve.append(Decimal(str(normalized_benchmark)))
if len(equity_curve) < 2:
return (Decimal("0"), Decimal("0"), Decimal("0"))
metrics = MetricsCalculator.calculate_all(equity_curve, benchmark_curve)
return (metrics.total_return, metrics.sharpe_ratio, metrics.mdd)
except Exception as e:
logger.warning(f"Walk-forward window {win.index} failed: {e}")
return (Decimal("0"), Decimal("0"), Decimal("0"))

View File

@ -1,7 +1,7 @@
"""
Background worker for backtest execution.
"""
from datetime import datetime
from datetime import datetime, timezone
from concurrent.futures import ThreadPoolExecutor
import logging
@ -35,7 +35,7 @@ def _run_backtest_job(backtest_id: int) -> None:
try:
# Update status to running
backtest = db.query(Backtest).get(backtest_id)
backtest = db.get(Backtest, backtest_id)
if not backtest:
logger.error(f"Backtest {backtest_id} not found")
return
@ -54,7 +54,7 @@ def _run_backtest_job(backtest_id: int) -> None:
# Update status to completed
backtest.status = BacktestStatus.COMPLETED
backtest.completed_at = datetime.utcnow()
backtest.completed_at = datetime.now(timezone.utc)
db.commit()
logger.info(f"Backtest {backtest_id} completed successfully")
@ -63,11 +63,11 @@ def _run_backtest_job(backtest_id: int) -> None:
# Update status to failed
try:
backtest = db.query(Backtest).get(backtest_id)
backtest = db.get(Backtest, backtest_id)
if backtest:
backtest.status = BacktestStatus.FAILED
backtest.error_message = str(e)[:1000] # Limit error message length
backtest.completed_at = datetime.utcnow()
backtest.completed_at = datetime.now(timezone.utc)
db.commit()
except Exception as commit_error:
logger.exception(f"Failed to update backtest status: {commit_error}")

View File

@ -4,7 +4,7 @@ Base collector class for data collection jobs.
import logging
import re
from abc import ABC, abstractmethod
from datetime import datetime
from datetime import datetime, timezone
from typing import Optional
import requests
@ -44,7 +44,7 @@ class BaseCollector(ABC):
self.job_log = JobLog(
job_name=self.job_name,
status="running",
started_at=datetime.utcnow(),
started_at=datetime.now(timezone.utc),
)
self.db.add(self.job_log)
self.db.commit()
@ -55,7 +55,7 @@ class BaseCollector(ABC):
if self.job_log:
try:
self.job_log.status = "success"
self.job_log.finished_at = datetime.utcnow()
self.job_log.finished_at = datetime.now(timezone.utc)
self.job_log.records_count = records_count
self.db.commit()
except Exception:
@ -67,7 +67,7 @@ class BaseCollector(ABC):
if self.job_log:
try:
self.job_log.status = "failed"
self.job_log.finished_at = datetime.utcnow()
self.job_log.finished_at = datetime.now(timezone.utc)
self.job_log.error_msg = error_msg
self.db.commit()
except Exception:

View File

@ -193,6 +193,7 @@ class RebalanceService:
strategy: str,
manual_prices: Optional[Dict[str, Decimal]] = None,
additional_amount: Optional[Decimal] = None,
min_trade_amount: Optional[Decimal] = None,
):
"""Calculate rebalance with optional manual prices and strategy selection."""
from app.schemas.portfolio import RebalanceCalculateItem, RebalanceCalculateResponse
@ -228,17 +229,29 @@ class RebalanceService:
current_values, total_assets, stock_names,
prev_prices, start_prices,
)
return RebalanceCalculateResponse(
portfolio_id=portfolio.id,
total_assets=total_assets,
items=items,
)
else: # additional_buy
items = self._calc_additional_buy(
all_tickers, targets, holdings, current_prices,
current_values, total_assets, additional_amount,
stock_names, prev_prices, start_prices,
)
# Filter out trades below min_trade_amount
if min_trade_amount and min_trade_amount > 0:
for item in items:
if item.action != "hold":
trade_value = abs(item.diff_quantity) * item.current_price
if trade_value < min_trade_amount:
item.diff_quantity = 0
item.action = "hold"
if strategy == "full_rebalance":
return RebalanceCalculateResponse(
portfolio_id=portfolio.id,
total_assets=total_assets,
items=items,
)
else:
return RebalanceCalculateResponse(
portfolio_id=portfolio.id,
total_assets=total_assets,

View File

@ -1,10 +1,14 @@
"""
Pytest configuration and fixtures for E2E tests.
"""
import os
import pytest
from typing import Generator
os.environ.setdefault("DATABASE_URL", "sqlite:///:memory:")
os.environ.setdefault("JWT_SECRET", "test-secret-key-for-pytest-only")
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session

View File

@ -62,7 +62,8 @@ def test_stock_prices(client: TestClient, auth_headers, db: Session):
resp = client.get("/api/data/stocks/005930/prices", headers=auth_headers)
assert resp.status_code == 200
data = resp.json()
assert len(data) == 2
assert len(data["items"]) == 2
assert data["total"] == 2
def test_list_etfs(client: TestClient, auth_headers, db: Session):
@ -76,7 +77,9 @@ def test_etf_prices(client: TestClient, auth_headers, db: Session):
_seed_etf(db)
resp = client.get("/api/data/etfs/069500/prices", headers=auth_headers)
assert resp.status_code == 200
assert len(resp.json()) == 2
data = resp.json()
assert len(data["items"]) == 2
assert data["total"] == 2
def test_list_sectors(client: TestClient, auth_headers, db: Session):

View File

@ -72,3 +72,56 @@ def test_signal_requires_auth(client: TestClient):
"""Test that signal endpoints require authentication."""
response = client.get("/api/signal/kjb/today")
assert response.status_code == 401
def test_cancel_executed_signal(client: TestClient, auth_headers):
"""실행된 신호를 취소하면 거래가 삭제되고 보유량이 복원된다."""
# 1. 포트폴리오 생성
resp = client.post(
"/api/portfolios",
json={"name": "Signal Cancel Test", "portfolio_type": "general"},
headers=auth_headers,
)
assert resp.status_code == 201
portfolio_id = resp.json()["id"]
# 2. 신호 생성 (직접 DB 경유 없이 API로 생성 불가 → 신호 실행 취소는 EXECUTED 신호에만 작동)
# 오늘 날짜로 신호 조회해서 없으면 스킵
today_resp = client.get("/api/signal/kjb/today", headers=auth_headers)
signals = today_resp.json()
if not signals:
# 신호가 없으면 엔드포인트 존재만 검증 (portfolio_id 포함)
resp = client.delete("/api/signal/9999/cancel", params={"portfolio_id": 9999}, headers=auth_headers)
assert resp.status_code in [404, 400]
return
# 3. ACTIVE 신호에 보유 종목 세팅 후 신호 실행
signal = signals[0]
ticker = signal["ticker"]
client.put(
f"/api/portfolios/{portfolio_id}/holdings",
json=[{"ticker": ticker, "quantity": 100, "avg_price": 10000}],
headers=auth_headers,
)
exec_resp = client.post(
f"/api/signal/{signal['id']}/execute",
json={"portfolio_id": portfolio_id, "quantity": 10, "price": 10000},
headers=auth_headers,
)
# 신호 타입에 따라 실패할 수 있음
if exec_resp.status_code != 200:
return
# 4. 취소 요청
cancel_resp = client.delete(
f"/api/signal/{signal['id']}/cancel",
params={"portfolio_id": portfolio_id},
headers=auth_headers,
)
assert cancel_resp.status_code == 200
data = cancel_resp.json()
assert data["signal_status"] == "active"
assert data["transaction_deleted"] is True

View File

@ -198,7 +198,7 @@ def test_transaction_flow(client: TestClient, auth_headers):
headers=auth_headers,
)
assert response.status_code == 200
txs = response.json()
txs = response.json()["items"]
assert len(txs) == 2

View File

@ -0,0 +1,265 @@
"""
E2E tests for realized/unrealized PnL tracking and position sizing.
"""
import pytest
from fastapi.testclient import TestClient
def _create_portfolio(client: TestClient, auth_headers: dict, name: str = "PnL Test") -> int:
"""Helper to create a portfolio and return its ID."""
resp = client.post(
"/api/portfolios",
json={"name": name, "portfolio_type": "general"},
headers=auth_headers,
)
assert resp.status_code == 201
return resp.json()["id"]
def test_sell_transaction_records_realized_pnl(client: TestClient, auth_headers):
"""매도 거래 시 realized_pnl이 계산되어 저장된다."""
pid = _create_portfolio(client, auth_headers)
# Buy 10 shares at 70,000
client.post(
f"/api/portfolios/{pid}/transactions",
json={
"ticker": "005930",
"tx_type": "buy",
"quantity": 10,
"price": 70000,
"executed_at": "2024-01-15T10:00:00",
},
headers=auth_headers,
)
# Sell 5 shares at 80,000 → realized_pnl = (80000 - 70000) * 5 = 50,000
resp = client.post(
f"/api/portfolios/{pid}/transactions",
json={
"ticker": "005930",
"tx_type": "sell",
"quantity": 5,
"price": 80000,
"executed_at": "2024-01-20T10:00:00",
},
headers=auth_headers,
)
assert resp.status_code == 201
tx = resp.json()
assert tx["realized_pnl"] == 50000.0
def test_sell_transaction_loss_realized_pnl(client: TestClient, auth_headers):
"""매도 손실 시 음수 realized_pnl이 기록된다."""
pid = _create_portfolio(client, auth_headers)
# Buy 10 shares at 70,000
client.post(
f"/api/portfolios/{pid}/transactions",
json={
"ticker": "005930",
"tx_type": "buy",
"quantity": 10,
"price": 70000,
"executed_at": "2024-01-15T10:00:00",
},
headers=auth_headers,
)
# Sell 5 shares at 60,000 → realized_pnl = (60000 - 70000) * 5 = -50,000
resp = client.post(
f"/api/portfolios/{pid}/transactions",
json={
"ticker": "005930",
"tx_type": "sell",
"quantity": 5,
"price": 60000,
"executed_at": "2024-01-20T10:00:00",
},
headers=auth_headers,
)
assert resp.status_code == 201
tx = resp.json()
assert tx["realized_pnl"] == -50000.0
def test_buy_transaction_no_realized_pnl(client: TestClient, auth_headers):
"""매수 거래에는 realized_pnl이 없다."""
pid = _create_portfolio(client, auth_headers)
resp = client.post(
f"/api/portfolios/{pid}/transactions",
json={
"ticker": "005930",
"tx_type": "buy",
"quantity": 10,
"price": 70000,
"executed_at": "2024-01-15T10:00:00",
},
headers=auth_headers,
)
assert resp.status_code == 201
tx = resp.json()
assert tx["realized_pnl"] is None
def test_transaction_list_includes_realized_pnl(client: TestClient, auth_headers):
"""거래 목록 조회 시 realized_pnl이 포함된다."""
pid = _create_portfolio(client, auth_headers)
# Buy then sell
client.post(
f"/api/portfolios/{pid}/transactions",
json={
"ticker": "005930",
"tx_type": "buy",
"quantity": 10,
"price": 70000,
"executed_at": "2024-01-15T10:00:00",
},
headers=auth_headers,
)
client.post(
f"/api/portfolios/{pid}/transactions",
json={
"ticker": "005930",
"tx_type": "sell",
"quantity": 5,
"price": 75000,
"executed_at": "2024-01-20T10:00:00",
},
headers=auth_headers,
)
resp = client.get(f"/api/portfolios/{pid}/transactions", headers=auth_headers)
assert resp.status_code == 200
txs = resp.json()["items"]
assert len(txs) == 2
# Most recent first (sell)
sell_tx = next(t for t in txs if t["tx_type"] == "sell")
buy_tx = next(t for t in txs if t["tx_type"] == "buy")
assert sell_tx["realized_pnl"] == 25000.0
assert buy_tx["realized_pnl"] is None
def test_portfolio_detail_includes_realized_unrealized_pnl(client: TestClient, auth_headers):
"""포트폴리오 상세에 실현/미실현 수익이 포함된다."""
pid = _create_portfolio(client, auth_headers)
# Buy 10 shares at 70,000
client.post(
f"/api/portfolios/{pid}/transactions",
json={
"ticker": "005930",
"tx_type": "buy",
"quantity": 10,
"price": 70000,
"executed_at": "2024-01-15T10:00:00",
},
headers=auth_headers,
)
# Sell 5 shares at 80,000
client.post(
f"/api/portfolios/{pid}/transactions",
json={
"ticker": "005930",
"tx_type": "sell",
"quantity": 5,
"price": 80000,
"executed_at": "2024-01-20T10:00:00",
},
headers=auth_headers,
)
resp = client.get(f"/api/portfolios/{pid}/detail", headers=auth_headers)
assert resp.status_code == 200
detail = resp.json()
assert detail["total_realized_pnl"] == 50000.0
# unrealized_pnl depends on current prices but should be present
assert "total_unrealized_pnl" in detail
def test_rebalance_apply_records_realized_pnl(client: TestClient, auth_headers):
"""리밸런싱 적용 시 매도 거래에 realized_pnl이 기록된다."""
pid = _create_portfolio(client, auth_headers)
# Setup initial holdings
client.put(
f"/api/portfolios/{pid}/holdings",
json=[{"ticker": "005930", "quantity": 10, "avg_price": 70000}],
headers=auth_headers,
)
# Apply rebalance with a sell
resp = client.post(
f"/api/portfolios/{pid}/rebalance/apply",
json={
"items": [
{"ticker": "005930", "action": "sell", "quantity": 3, "price": 75000},
]
},
headers=auth_headers,
)
assert resp.status_code == 201
data = resp.json()
sell_tx = data["transactions"][0]
assert sell_tx["realized_pnl"] == 15000.0 # (75000 - 70000) * 3
def test_position_size_endpoint(client: TestClient, auth_headers):
"""포지션 사이징 가이드 API가 올바르게 동작한다."""
pid = _create_portfolio(client, auth_headers)
# Set holdings and targets
client.put(
f"/api/portfolios/{pid}/holdings",
json=[
{"ticker": "005930", "quantity": 10, "avg_price": 70000},
],
headers=auth_headers,
)
client.put(
f"/api/portfolios/{pid}/targets",
json=[
{"ticker": "005930", "target_ratio": 50},
{"ticker": "000660", "target_ratio": 50},
],
headers=auth_headers,
)
# Get position size for a new ticker
resp = client.get(
f"/api/portfolios/{pid}/position-size?ticker=000660&price=150000",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["ticker"] == "000660"
assert data["current_holding_quantity"] == 0
assert data["target_ratio"] == 50.0
assert data["recommended_quantity"] >= 0
assert data["max_quantity"] >= 0
def test_position_size_no_targets(client: TestClient, auth_headers):
"""목표 비중 없을 때 포지션 사이징이 기본값으로 동작한다."""
pid = _create_portfolio(client, auth_headers)
client.put(
f"/api/portfolios/{pid}/holdings",
json=[{"ticker": "005930", "quantity": 10, "avg_price": 70000}],
headers=auth_headers,
)
resp = client.get(
f"/api/portfolios/{pid}/position-size?ticker=000660&price=150000",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["ticker"] == "000660"
assert data["target_ratio"] is None
# Without targets, max should use 20% default
assert data["max_quantity"] >= 0

View File

@ -158,3 +158,48 @@ def test_apply_rebalance_insufficient_quantity(client: TestClient, auth_headers)
headers=auth_headers,
)
assert response.status_code == 400
def test_min_trade_amount_filters_small_trades(client: TestClient, auth_headers):
"""min_trade_amount 미만 거래는 hold로 변경된다."""
pid = _setup_portfolio_with_holdings(client, auth_headers)
# With very high min_trade_amount, all trades should become hold
response = client.post(
f"/api/portfolios/{pid}/rebalance/calculate",
json={
"strategy": "full_rebalance",
"prices": {"069500": 50000, "148070": 110000},
"min_trade_amount": 99999999,
},
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
for item in data["items"]:
assert item["action"] == "hold"
assert item["diff_quantity"] == 0
def test_min_trade_amount_allows_large_trades(client: TestClient, auth_headers):
"""min_trade_amount 이상 거래는 정상 처리된다."""
pid = _setup_portfolio_with_holdings(client, auth_headers)
# Use skewed prices to create a meaningful rebalancing diff
# 069500: 10 * 30000 = 300,000 / 148070: 5 * 200000 = 1,000,000
# total = 1,300,000, target each 50% = 650,000
# 069500 needs buy: (650000-300000)/30000 = 11 shares => 330,000 trade value
response = client.post(
f"/api/portfolios/{pid}/rebalance/calculate",
json={
"strategy": "full_rebalance",
"prices": {"069500": 30000, "148070": 200000},
"min_trade_amount": 1,
},
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
# At least one item should have buy or sell action
actions = [item["action"] for item in data["items"]]
assert "buy" in actions or "sell" in actions

View File

@ -40,7 +40,9 @@ def test_snapshot_list_empty(client: TestClient, auth_headers):
headers=auth_headers,
)
assert response.status_code == 200
assert response.json() == []
data = response.json()
assert data["items"] == []
assert data["total"] == 0
def test_returns_empty(client: TestClient, auth_headers):

View File

@ -79,6 +79,50 @@ def test_value_momentum_strategy(client: TestClient, auth_headers):
assert data["strategy_name"] == "value_momentum"
def test_dc_only_filter(client: TestClient, auth_headers):
"""Test dc_only parameter filters to ETFs only."""
response = client.post(
"/api/strategy/multi-factor",
json={
"universe": {
"markets": ["KOSPI"],
},
"top_n": 20,
"dc_only": True,
"weights": {
"value": 0.3,
"quality": 0.3,
"momentum": 0.2,
"low_vol": 0.2,
},
},
headers=auth_headers,
)
# May fail if no data, just check it accepts the parameter
assert response.status_code in [200, 400, 500]
if response.status_code == 200:
data = response.json()
assert "stocks" in data
# All returned stocks should be ETFs (or empty if no ETFs in universe)
def test_dc_only_false_returns_all(client: TestClient, auth_headers):
"""Test dc_only=false returns all stocks (default behavior)."""
response = client.post(
"/api/strategy/multi-factor",
json={
"universe": {
"markets": ["KOSPI"],
},
"top_n": 20,
"dc_only": False,
},
headers=auth_headers,
)
assert response.status_code in [200, 400, 500]
def test_strategy_requires_auth(client: TestClient):
"""Test that strategy endpoints require authentication."""
response = client.post(

View File

@ -0,0 +1,115 @@
"""
Unit tests for WalkForwardEngine window generation logic.
"""
from datetime import date
import pytest
from app.services.backtest.walkforward_engine import WalkForwardEngine, Window
class TestGenerateWindows:
"""Test _generate_windows static method."""
def test_basic_windows(self):
"""2-year period with 12m train, 3m test, 3m step -> 4 windows."""
windows = WalkForwardEngine._generate_windows(
start=date(2020, 1, 1),
end=date(2021, 12, 31),
train_months=12,
test_months=3,
step_months=3,
)
assert len(windows) == 4
assert windows[0].index == 0
assert windows[0].train_start == date(2020, 1, 1)
assert windows[0].train_end == date(2020, 12, 31)
assert windows[0].test_start == date(2021, 1, 1)
assert windows[0].test_end == date(2021, 3, 31)
def test_single_window(self):
"""Exactly 15 months -> 1 window with 12m train + 3m test."""
windows = WalkForwardEngine._generate_windows(
start=date(2020, 1, 1),
end=date(2021, 3, 31),
train_months=12,
test_months=3,
step_months=3,
)
assert len(windows) == 1
assert windows[0].train_start == date(2020, 1, 1)
assert windows[0].test_end == date(2021, 3, 31)
def test_no_windows_period_too_short(self):
"""Period shorter than train + test -> 0 windows."""
windows = WalkForwardEngine._generate_windows(
start=date(2020, 1, 1),
end=date(2020, 12, 31),
train_months=12,
test_months=3,
step_months=3,
)
assert len(windows) == 0
def test_partial_last_window(self):
"""Last window with partial test period is included."""
windows = WalkForwardEngine._generate_windows(
start=date(2020, 1, 1),
end=date(2021, 2, 15),
train_months=12,
test_months=3,
step_months=3,
)
assert len(windows) == 1
assert windows[0].test_end == date(2021, 2, 15)
def test_step_larger_than_test(self):
"""step_months > test_months creates non-overlapping test windows."""
windows = WalkForwardEngine._generate_windows(
start=date(2019, 1, 1),
end=date(2022, 12, 31),
train_months=12,
test_months=3,
step_months=6,
)
assert len(windows) >= 2
# test windows should not overlap
for i in range(1, len(windows)):
assert windows[i].test_start > windows[i - 1].test_end
def test_monthly_step(self):
"""step_months=1 creates many overlapping windows."""
windows = WalkForwardEngine._generate_windows(
start=date(2020, 1, 1),
end=date(2021, 6, 30),
train_months=6,
test_months=3,
step_months=1,
)
assert len(windows) >= 9
def test_window_indices_sequential(self):
"""Window indices should be sequential starting from 0."""
windows = WalkForwardEngine._generate_windows(
start=date(2019, 1, 1),
end=date(2022, 12, 31),
train_months=12,
test_months=3,
step_months=3,
)
for i, w in enumerate(windows):
assert w.index == i
def test_window_dates_consistent(self):
"""train_end < test_start and test_start <= test_end for all windows."""
windows = WalkForwardEngine._generate_windows(
start=date(2019, 1, 1),
end=date(2023, 12, 31),
train_months=12,
test_months=6,
step_months=3,
)
for w in windows:
assert w.train_start < w.train_end
assert w.train_end < w.test_start
assert w.test_start <= w.test_end

View File

@ -0,0 +1,378 @@
# Galaxis-Po 코드 품질 및 완성도 심층 분석 보고서
**분석일:** 2026-03-18
**분석 범위:** Backend (FastAPI), Frontend (Next.js), DB 모델, 설계 문서
---
## 목차
1. [미구현/불완전 기능 분석](#1-미구현불완전-기능-분석)
2. [코드 품질 이슈](#2-코드-품질-이슈)
3. [보안/안정성 이슈](#3-보안안정성-이슈)
4. [성능 이슈](#4-성능-이슈)
5. [Walk-forward 분석 구현 가능성 평가](#5-walk-forward-분석-구현-가능성-평가)
6. [종합 평가 및 권고사항](#6-종합-평가-및-권고사항)
---
## 1. 미구현/불완전 기능 분석
### 1.1 Backend API 엔드포인트 전수 현황
**9개 라우터**, **44개 엔드포인트** + 헬스체크 1개
| 라우터 | 경로 접두사 | 엔드포인트 수 | 인증 필요 |
|--------|------------|:------------:|:---------:|
| auth | `/api/auth` | 4 | 부분적 |
| admin | `/api/admin` | 8 | 전체 |
| portfolio | `/api/portfolios` | 16 | 전체 |
| strategy | `/api/strategy` | 4 | 전체 |
| market | `/api/market` | 3 | 전체 |
| backtest | `/api/backtest` | 7 | 전체 |
| snapshot | `/api/portfolios/{id}` | 5 | 전체 |
| data_explorer | `/api/data` | 6 | 전체 |
| signal | `/api/signal` | 4 | 전체 |
### 1.2 Frontend 페이지 현황
**15개 페이지**, **32개 API 호출**
| 페이지 | 경로 | 주요 기능 |
|--------|------|----------|
| 로그인 | `/login` | SHA-256 해시 비밀번호 로그인 |
| 대시보드 | `/` | 총자산, 수익률, 자산배분 차트 |
| 포트폴리오 목록 | `/portfolio` | 포트폴리오 카드 그리드 |
| 포트폴리오 생성 | `/portfolio/new` | 이름, 유형 선택 |
| 포트폴리오 상세 | `/portfolio/[id]` | 보유종목, 거래내역, 분석 탭 |
| 리밸런싱 | `/portfolio/[id]/rebalance` | 전략 선택, 수동 가격 입력, 적용 |
| 포트폴리오 이력 | `/portfolio/[id]/history` | 스냅샷, 수익률 추이 |
| 전략 목록 | `/strategy` | 4개 전략 카드 |
| KJB 전략 | `/strategy/kjb` | 김종봉 전략 실행 |
| 멀티팩터 전략 | `/strategy/multi-factor` | 가중치 설정, 실행 |
| 퀄리티 전략 | `/strategy/quality` | F-Score 기반 |
| 가치모멘텀 전략 | `/strategy/value-momentum` | 가치+모멘텀 가중치 |
| 전략 비교 | `/strategy/compare` | 3개 전략 병렬 비교 |
| 백테스트 | `/backtest` | 생성, 결과 조회 |
| 백테스트 상세 | `/backtest/[id]` | 수익곡선, 드로다운, 거래내역 |
| 관리자 데이터 | `/admin/data` | 6개 수집기 실행/상태 |
| 데이터 탐색기 | `/admin/data/explorer` | 종목/ETF/섹터/밸류에이션 조회 |
### 1.3 API-UI 매핑 갭 분석
#### API 있으나 UI 없는 항목
| API 엔드포인트 | 상태 | 설명 |
|---------------|------|------|
| `PUT /api/portfolios/{id}` | UI 없음 | 포트폴리오 이름/유형 수정 기능 |
| `DELETE /api/portfolios/{id}` | UI 없음 | 포트폴리오 삭제 기능 |
| `PUT /api/portfolios/{id}/targets` | UI 없음 | 목표 비중 설정 (독립 UI 없음, 리밸런싱에서 간접 사용) |
| `PUT /api/portfolios/{id}/holdings` | UI 없음 | 보유종목 직접 설정 |
| `POST /api/portfolios/{id}/transactions` | UI 없음 | 수동 거래 추가 (신호 실행 외) |
| `GET /api/portfolios/{id}/rebalance` | UI 없음 | 자동 리밸런싱 계산 (수동 계산만 UI에 있음) |
| `POST /api/portfolios/{id}/rebalance/simulate` | UI 없음 | 추가 투자 시뮬레이션 |
| `GET /api/market/stocks/{ticker}` | UI 없음 | 개별 종목 상세 정보 |
| `GET /api/market/stocks/{ticker}/prices` | UI 없음 | 개별 종목 가격 차트 (data_explorer와 중복) |
| `GET /api/market/search` | UI 없음 | 종목 검색 (독립 UI 없음) |
| `DELETE /api/backtest/{id}` | UI 없음 | 백테스트 삭제 |
| `POST /api/auth/register` | 비활성 | 코드에서 403 반환으로 비활성화됨 |
| `POST /api/admin/collect/backfill` | UI 없음 | 과거 데이터 백필 기능 |
#### UI 있으나 실제 데이터가 아닌 항목
| UI 요소 | 위치 | 설명 |
|---------|------|------|
| 포트폴리오 가치 차트 | `/portfolio/[id]` | 90일 시뮬레이션 사인파 데이터 사용 (실제 스냅샷 기반 아님) |
### 1.4 설계 문서 vs 구현 대조
| 설계 문서 | 구현 상태 | 미구현 항목 |
|----------|:---------:|-----------|
| Phase 1: Foundation | 완료 | - |
| Phase 2: Data Collection | 완료 | Financial collector 설계는 있으나 UI에서 직접 트리거 불가 |
| Phase 3: Portfolio Management | 대부분 완료 | 포트폴리오 수정/삭제 UI 없음, 수동 거래 입력 UI 없음 |
| Phase 4: Quant Strategy | 완료 | - |
| Phase 5: Backtest Engine | 완료 | 백테스트 삭제/비교 UI 없음 |
| Phase 6: Finishing | 부분 완료 | 종합 테스트 미흡, 프로덕션 배포 문서화 미완 |
| KJB 전략 설계 | 완료 | DailyBacktestEngine, TradingPortfolio 모두 구현 |
| 리밸런싱 설계 | 완료 | 추가 투자 시뮬레이션 UI 없음 |
| DC 시나리오 갭 분석 | 문서만 존재 | 6개 시나리오 검증 미구현 |
| 프로덕션 배포 설계 | 문서만 존재 | Nginx, SSL, 모니터링 등 미구현 |
---
## 2. 코드 품질 이슈
### 2.1 TODO/FIXME/HACK 검색 결과
**결과: 0건** - Backend/Frontend 모두 TODO, FIXME, HACK, XXX, NotImplementedError 없음.
### 2.2 에러 핸들링 이슈
#### Frontend - console.error만 있고 사용자 피드백 없는 경우
| 파일 | 위치 | 설명 |
|------|------|------|
| `frontend/src/app/signals/page.tsx` | L161, L176, L185 | `fetchTodaySignals`, `fetchHistorySignals`, `fetchPortfolios` 에러 시 console만 |
| `frontend/src/app/signals/page.tsx` | L246-254 | 포지션 사이징 API 에러 무시 |
| `frontend/src/app/admin/data/explorer/page.tsx` | L121-135 | ETF 가격 조회 에러 시 빈 배열로 대체 |
| `frontend/src/app/backtest/page.tsx` | 중첩 fetch | 내부 상세 조회 실패 시 UI 피드백 없음 |
#### Frontend - Error Boundary 부재
- 전체 프론트엔드에 React Error Boundary가 없음
- 컴포넌트 렌더링 에러 시 전체 페이지 크래시 가능
### 2.3 타입 안전성 이슈
| 위치 | 설명 |
|------|------|
| `frontend/src/app/portfolio/[id]/rebalance/page.tsx` L313 | `as { data: Portfolio[] }` 강제 캐스팅 |
| `frontend/src/app/admin/data/explorer/page.tsx` 다수 | 데이터 아이템 `as` 타입 단언 다수 사용 |
---
## 3. 보안/안정성 이슈
### 3.1 심각도별 분류
#### CRITICAL - 하드코딩된 시크릿
| 파일 | 라인 | 내용 |
|------|------|------|
| `backend/app/core/config.py` | L14 | `database_url` 기본값에 비밀번호 포함: `postgresql://galaxy:devpassword@localhost:5432/galaxy_po` |
| `backend/app/core/config.py` | L17 | `jwt_secret` 기본값: `"dev-jwt-secret-change-in-production"` |
> pydantic-settings로 환경변수 오버라이드 가능하나, 소스코드에 기본값이 남아 있음
#### HIGH - JWT 토큰 관리
| 이슈 | 설명 |
|------|------|
| 토큰 무효화 불가 | `POST /api/auth/logout`이 서버 측 토큰 무효화 없이 클라이언트에 의존 |
| localStorage 저장 | `frontend/src/lib/api.ts` L18,25,32 - XSS 공격 시 토큰 탈취 가능 |
#### MEDIUM - 인증 패턴
| 파일 | 설명 |
|------|------|
| `backend/app/api/backtest.py` L96,143,176,217,259 | 백테스트 조회 시 `user_id` 필터 없이 전체 조회 후 소유권 확인 - 비효율적이며 정보 노출 가능 |
### 3.2 SQL 인젝션
**위험도: 없음** - 모든 DB 쿼리가 SQLAlchemy ORM 사용. raw SQL 쿼리 없음.
### 3.3 민감 정보 로깅
**위험도: 없음** - 비밀번호, 토큰, API 키 로깅 없음 확인.
### 3.4 인증 없이 접근 가능한 엔드포인트
| 엔드포인트 | 인증 | 비고 |
|-----------|:----:|------|
| `GET /health` | 불필요 | 정상 (헬스체크) |
| `POST /api/auth/login` | 불필요 | 정상 (로그인) |
| `POST /api/auth/register` | 불필요 | 비활성화 (403 반환) |
| `POST /api/auth/logout` | 불필요 | 클라이언트 측 처리 |
> 그 외 모든 엔드포인트는 `CurrentUser` 의존성으로 인증 필수
---
## 4. 성능 이슈
### 4.1 N+1 쿼리 패턴
| 파일 | 위치 | 심각도 | 설명 |
|------|------|:------:|------|
| `backend/app/api/backtest.py` | L71-82 | HIGH | `list_backtests`에서 모든 백테스트 순회하며 `bt.result` lazy 로딩 → N+1 |
| `frontend/src/app/portfolio/page.tsx` | L55-70 | MEDIUM | 포트폴리오 목록에서 각 포트폴리오별 detail API 개별 호출 |
### 4.2 인덱스 누락 (DB)
현재 인덱스: `users(email)`, `users(id)`, `job_logs(id)`, `signals(date, ticker)`, 각 테이블 PK만 존재
#### Tier 1 - 백테스트 성능 (긴급)
```sql
-- 가격 데이터 (백테스트에서 가장 빈번하게 조회)
CREATE INDEX idx_price_date ON prices(date);
-- 종목 유니버스 필터링
CREATE INDEX idx_stock_market ON stocks(market);
CREATE INDEX idx_stock_market_cap ON stocks(market_cap DESC);
-- 밸류에이션 스크리닝
CREATE INDEX idx_valuation_base_date ON valuations(base_date);
```
#### Tier 2 - 포트폴리오/신호 (중요)
```sql
-- 신호 조회
CREATE INDEX idx_signal_status_date ON signals(status, date);
-- 거래 내역
CREATE INDEX idx_transaction_portfolio_executed ON transactions(portfolio_id, executed_at);
-- 백테스트 목록
CREATE INDEX idx_backtest_user_created ON backtests(user_id, created_at);
CREATE INDEX idx_backtest_status ON backtests(status);
```
#### Tier 3 - 최적화
```sql
-- 재무 데이터
CREATE INDEX idx_financial_base_date ON financials(base_date);
-- ETF 필터링
CREATE INDEX idx_etf_asset_class ON etfs(asset_class);
CREATE INDEX idx_etf_price_date ON etf_prices(date);
```
### 4.3 대용량 데이터 처리 이슈
| 파일 | 위치 | 설명 |
|------|------|------|
| `backend/app/services/backtest/engine.py` | L352 | `Stock.query.all()` - 전체 종목 메모리 로드 |
| `backend/app/services/backtest/daily_engine.py` | L147-187 | 다수의 `.all()` 호출로 대용량 데이터 메모리 로드 |
| `backend/app/services/price_service.py` | L92,114,180 | 전체 가격 데이터 제한 없이 로드 |
| `backend/app/api/data_explorer.py` | L138,180 | 종목 가격 이력 페이지네이션 없이 전체 반환 |
| `backend/app/api/snapshot.py` | L49,230 | 스냅샷 전체 조회 제한 없음 |
| `backend/app/api/portfolio.py` | L39 | 포트폴리오 전체 조회 제한 없음 |
### 4.4 비동기 작업 관리
| 이슈 | 설명 |
|------|------|
| 데몬 스레드 사용 | 데이터 수집기가 daemon thread로 실행 → 앱 종료 시 작업 유실 가능 |
| 작업 큐 미사용 | Celery/RQ 등 없이 in-process thread 실행 → 재시작 시 상태 복구 불가 |
---
## 5. Walk-forward 분석 구현 가능성 평가
### 5.1 현재 백테스트 엔진 구조
```
BacktestWorker.submit_backtest()
├── strategy_type == "kjb"
│ └── DailyBacktestEngine (신호 기반 일일 매매)
│ ├── TradingPortfolio (개별 포지션 관리)
│ └── KJBSignalGenerator (매수/매도 조건 판단)
└── strategy_type != "kjb"
└── BacktestEngine (정기 리밸런싱)
├── VirtualPortfolio (균등 가중 포트폴리오)
└── MetricsCalculator (수익률, MDD, 샤프 등)
```
**핵심 클래스:**
- `BacktestEngine` (engine.py) - 주기적 리밸런싱 시뮬레이션
- `DailyBacktestEngine` (daily_engine.py) - 일일 신호 기반 매매
- `VirtualPortfolio` - 단순 포트폴리오 (리밸런싱용)
- `TradingPortfolio` - 포지션 기반 포트폴리오 (손절/익절 관리)
- `MetricsCalculator` - 독립적 성과 지표 계산
### 5.2 Walk-forward 구현 가능성: **높음 (HIGHLY FEASIBLE)**
#### 유리한 구조적 요소
| 요소 | 설명 |
|------|------|
| 데이터 분할 용이 | `_get_trading_days()`가 날짜 범위 필터링 지원 → 학습/검증 윈도우 분할 가능 |
| 전략 재사용성 | 전략 인스턴스화가 분리됨 → 동일 전략을 다른 기간에 실행 가능 |
| 메트릭스 독립성 | `MetricsCalculator`가 값 배열만 받음 → 백테스트 인스턴스와 무관 |
| 글로벌 상태 없음 | 각 `BacktestEngine` 인스턴스가 독립적 → 여러 기간 동시 실행 가능 |
| 결과 분리 저장 | metrics, equity curve, holdings, transactions 별도 테이블 → 학습/검증 비교 가능 |
#### 구현 방안
```
1. WalkForwardEngine 클래스 추가
- train_window, test_window, step_size 파라미터
- 롤링 윈도우 생성기
2. 각 윈도우에서:
- 학습 기간: 기존 BacktestEngine으로 전략 실행 → 최적 파라미터 도출
- 검증 기간: 학습된 파라미터로 별도 BacktestEngine 실행
3. 전체 검증 기간 결과 합산 → 최종 성과 지표 계산
```
#### 예상 작업량
| 작업 | 예상 규모 | 설명 |
|------|:---------:|------|
| `WalkForwardEngine` 클래스 | 중 | 윈도우 분할 + 순차 실행 로직 |
| 파라미터 최적화 모듈 | 중~대 | 전략별 파라미터 그리드 서치 또는 최적화 |
| DB 모델 추가 | 소 | `WalkForwardResult` 테이블 (윈도우별 결과 저장) |
| API 엔드포인트 | 소 | 생성/조회/결과 반환 |
| Frontend UI | 중 | 설정 폼 + 윈도우별 결과 시각화 |
| **총합** | **중~대 규모** | 약 5-8개 파일 신규/수정 |
#### 주의사항
- 백테스트 실행 시간이 윈도우 수만큼 배수로 증가 → 백그라운드 실행 필수
- 파라미터 최적화 시 과적합(overfitting) 방지 로직 필요
- 현재 daemon thread 방식으로는 장시간 실행에 부적합 → Celery 도입 검토
---
## 6. 종합 평가 및 권고사항
### 6.1 종합 스코어카드
| 항목 | 점수 | 평가 |
|------|:----:|------|
| 기능 완성도 | 8/10 | 핵심 기능 구현 완료, CRUD UI 일부 누락 |
| 코드 품질 | 8/10 | TODO/FIXME 없음, 일관된 코드 스타일 |
| 보안 | 6/10 | SQL 인젝션 안전, 하드코딩 시크릿/토큰 관리 취약 |
| 성능 | 5/10 | 인덱스 부재, N+1 쿼리, 무제한 데이터 로드 |
| 테스트 커버리지 | 미측정 | unit/e2e 구조 존재, 커버리지 분석 필요 |
| 프로덕션 준비도 | 4/10 | 작업 큐 없음, 모니터링/로깅 미흡 |
### 6.2 우선순위별 권고사항
#### 즉시 조치 (P0)
1. **하드코딩 시크릿 제거** - `config.py`의 database_url, jwt_secret 기본값을 환경변수 필수로 변경
2. **DB 인덱스 추가** - Tier 1 인덱스 마이그레이션 생성 (백테스트 성능 직결)
3. **N+1 쿼리 수정** - `backtest.py` list_backtests에 `joinedload(Backtest.result)` 추가
#### 단기 개선 (P1)
4. **포트폴리오 CRUD UI 완성** - 수정/삭제/수동 거래 입력 화면 추가
5. **에러 핸들링 통일** - Frontend Error Boundary 추가, console.error를 사용자 피드백으로 전환
6. **페이지네이션 적용** - 가격 이력, 스냅샷, 거래 내역 등 무제한 로드 수정
7. **백테스트 삭제 UI** - 불필요한 백테스트 정리 기능
#### 중기 개선 (P2)
8. **JWT httpOnly 쿠키 전환** - localStorage에서 secure cookie로 변경
9. **작업 큐 도입** - Celery/RQ로 데이터 수집 및 백테스트 실행 안정화
10. **포트폴리오 가치 차트 실데이터화** - 시뮬레이션 사인파 → 스냅샷 기반 실제 데이터
11. **React Error Boundary** 추가
#### 장기 개선 (P3)
12. **Walk-forward 분석 구현** - 위 평가 참고
13. **백테스트 비교 기능** - 전략 간 성과 비교 UI
14. **DC 시나리오 검증** - 6개 시나리오 자동 검증 파이프라인
15. **프로덕션 인프라** - Nginx, SSL, 모니터링, 로깅 체계
---
### 6.3 파일별 이슈 요약
| 파일 | 이슈 유형 | 심각도 |
|------|----------|:------:|
| `backend/app/core/config.py` | 하드코딩 시크릿 | CRITICAL |
| `backend/app/api/backtest.py` | N+1 쿼리, 비효율적 인증 패턴 | HIGH |
| `backend/app/services/backtest/engine.py` | 전체 종목 메모리 로드 | MEDIUM |
| `backend/app/services/backtest/daily_engine.py` | 다수 unbounded `.all()` | MEDIUM |
| `backend/app/services/price_service.py` | 무제한 가격 데이터 로드 | MEDIUM |
| `backend/app/api/data_explorer.py` | 가격 이력 페이지네이션 없음 | MEDIUM |
| `frontend/src/lib/api.ts` | localStorage 토큰 저장 | MEDIUM |
| `frontend/src/app/signals/page.tsx` | 에러 핸들링 누락 (console만) | LOW |
| `frontend/src/app/portfolio/[id]/page.tsx` | 시뮬레이션 차트 데이터 | LOW |
| DB migrations | 성능 인덱스 대부분 누락 | HIGH |

Binary file not shown.

After

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 35 KiB

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,152 @@
import { test, expect, Page } from "@playwright/test";
const TEST_USER = { username: "testuser", password: "testpass123" };
async function login(page: Page) {
await page.goto("/login");
await page.locator("#username").fill(TEST_USER.username);
await page.locator("#password").fill(TEST_USER.password);
await page.locator('button[type="submit"]').click();
await page.waitForURL("**/", { timeout: 10000 });
}
test.describe("Signal Cancel & Related Pages", () => {
test.beforeEach(async ({ page }) => {
await login(page);
});
test("should access signals page and see signal table", async ({ page }) => {
await page.goto("/signals");
// Wait for page title to appear
await expect(page.getByText("KJB 매매 신호")).toBeVisible({
timeout: 10000,
});
// Signal table should be visible
await expect(page.locator("table").first()).toBeVisible();
// Summary cards should be visible
await expect(page.getByText("매수 신호")).toBeVisible();
await expect(page.getByText("매도 신호", { exact: true })).toBeVisible();
await page.screenshot({
path: "../docs/screenshots/signals-page.png",
fullPage: true,
});
});
test("should show cancel button for EXECUTED signals", async ({ page }) => {
await page.goto("/signals");
await expect(page.getByText("KJB 매매 신호")).toBeVisible({
timeout: 10000,
});
// Switch to history view for executed signals
const historyButton = page.getByText("신호 이력");
if (await historyButton.isVisible()) {
await historyButton.click();
await page.waitForTimeout(2000);
}
// Check if any executed signal rows exist
const executedBadges = page.locator('text="실행됨"');
const count = await executedBadges.count();
if (count > 0) {
// There should be a cancel button near the executed signal
const cancelButtons = page.locator('button:has-text("취소")');
const cancelCount = await cancelButtons.count();
expect(cancelCount).toBeGreaterThan(0);
await page.screenshot({
path: "../docs/screenshots/signals-executed-with-cancel.png",
fullPage: true,
});
} else {
// No executed signals - verify the page structure is correct
console.log(
"No EXECUTED signals found - cancel button test skipped (no data)"
);
await expect(page.locator("table").first()).toBeVisible();
await page.screenshot({
path: "../docs/screenshots/signals-history.png",
fullPage: true,
});
}
});
test("should show realized/unrealized PnL cards on portfolio detail page", async ({
page,
}) => {
// First check if any portfolio exists
await page.goto("/portfolio");
await page.waitForTimeout(2000);
// Try to find a portfolio link
const portfolioLinks = page.locator('a[href^="/portfolio/"]');
const linkCount = await portfolioLinks.count();
if (linkCount > 0) {
await portfolioLinks.first().click();
await page.waitForTimeout(3000);
} else {
await page.goto("/portfolio/1");
await page.waitForTimeout(3000);
}
// Check for realized/unrealized PnL cards
const realizedPnlLabel = page.getByText("실현 수익");
const unrealizedPnlLabel = page.getByText("미실현 수익");
const hasRealizedCard = await realizedPnlLabel
.isVisible()
.catch(() => false);
const hasUnrealizedCard = await unrealizedPnlLabel
.isVisible()
.catch(() => false);
if (hasRealizedCard && hasUnrealizedCard) {
await expect(realizedPnlLabel).toBeVisible();
await expect(unrealizedPnlLabel).toBeVisible();
await expect(page.getByText("매도 확정 손익")).toBeVisible();
await expect(page.getByText("보유 중 평가 손익")).toBeVisible();
} else {
console.log(
"Portfolio detail page may not have data - PnL cards not visible"
);
}
await page.screenshot({
path: "../docs/screenshots/portfolio-detail.png",
fullPage: true,
});
});
test("should render strategy compare page", async ({ page }) => {
await page.goto("/strategy/compare");
// Wait for page title
await expect(
page.getByRole("heading", { name: "전략 비교" })
).toBeVisible({
timeout: 10000,
});
// Check description text
await expect(
page.getByText("멀티팩터, 퀄리티, 밸류모멘텀")
).toBeVisible();
// Check the compare execution button
const runButton = page.getByText("전략 비교 실행");
await expect(runButton).toBeVisible();
await page.screenshot({
path: "../docs/screenshots/strategy-compare.png",
fullPage: true,
});
});
});

View File

@ -28,6 +28,7 @@
"tailwind-merge": "^3.4.0"
},
"devDependencies": {
"@playwright/test": "^1.58.2",
"@tailwindcss/postcss": "^4",
"@types/node": "^22",
"@types/react": "^19",
@ -1279,6 +1280,22 @@
"node": ">=12.4.0"
}
},
"node_modules/@playwright/test": {
"version": "1.58.2",
"resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.58.2.tgz",
"integrity": "sha512-akea+6bHYBBfA9uQqSYmlJXn61cTa+jbO87xVLCWbTqbWadRVmhxlXATaOjOgcBaWU4ePo0wB41KMFv3o35IXA==",
"devOptional": true,
"license": "Apache-2.0",
"dependencies": {
"playwright": "1.58.2"
},
"bin": {
"playwright": "cli.js"
},
"engines": {
"node": ">=18"
}
},
"node_modules/@radix-ui/number": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/@radix-ui/number/-/number-1.1.1.tgz",
@ -4817,6 +4834,21 @@
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/fsevents": {
"version": "2.3.2",
"resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz",
"integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==",
"dev": true,
"hasInstallScript": true,
"license": "MIT",
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": "^8.16.0 || ^10.6.0 || >=11.0.0"
}
},
"node_modules/function-bind": {
"version": "1.1.2",
"resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz",
@ -6574,6 +6606,38 @@
"url": "https://github.com/sponsors/jonschlinkert"
}
},
"node_modules/playwright": {
"version": "1.58.2",
"resolved": "https://registry.npmjs.org/playwright/-/playwright-1.58.2.tgz",
"integrity": "sha512-vA30H8Nvkq/cPBnNw4Q8TWz1EJyqgpuinBcHET0YVJVFldr8JDNiU9LaWAE1KqSkRYazuaBhTpB5ZzShOezQ6A==",
"devOptional": true,
"license": "Apache-2.0",
"dependencies": {
"playwright-core": "1.58.2"
},
"bin": {
"playwright": "cli.js"
},
"engines": {
"node": ">=18"
},
"optionalDependencies": {
"fsevents": "2.3.2"
}
},
"node_modules/playwright-core": {
"version": "1.58.2",
"resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.58.2.tgz",
"integrity": "sha512-yZkEtftgwS8CsfYo7nm0KE8jsvm6i/PTgVtB8DL726wNf6H2IMsDuxCpJj59KDaxCtSnrWan2AeDqM7JBaultg==",
"devOptional": true,
"license": "Apache-2.0",
"bin": {
"playwright-core": "cli.js"
},
"engines": {
"node": ">=18"
}
},
"node_modules/possible-typed-array-names": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/possible-typed-array-names/-/possible-typed-array-names-1.1.0.tgz",

View File

@ -29,6 +29,7 @@
"tailwind-merge": "^3.4.0"
},
"devDependencies": {
"@playwright/test": "^1.58.2",
"@tailwindcss/postcss": "^4",
"@types/node": "^22",
"@types/react": "^19",

View File

@ -7,6 +7,7 @@ import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Button } from '@/components/ui/button';
import { Input } from '@/components/ui/input';
import { api } from '@/lib/api';
import { toast } from 'sonner';
type Tab = 'stocks' | 'etfs' | 'sectors' | 'valuations';
@ -97,6 +98,7 @@ export default function DataExplorerPage() {
const result = await api.get<PaginatedResponse<unknown>>(endpoint);
setData(result);
} catch {
toast.error('데이터를 불러오는데 실패했습니다.');
setData(null);
} finally {
setFetching(false);
@ -129,6 +131,7 @@ export default function DataExplorerPage() {
const result = await api.get<PricePoint[]>(endpoint);
setPrices(result);
} catch {
toast.error('가격 데이터를 불러오는데 실패했습니다.');
setPrices([]);
} finally {
setPriceLoading(false);

View File

@ -67,6 +67,22 @@ interface TransactionItem {
commission: number;
}
interface WalkForwardWindow {
window_index: number;
train_start: string;
train_end: string;
test_start: string;
test_end: string;
test_return: number | null;
test_sharpe: number | null;
test_mdd: number | null;
}
interface WalkForwardResponse {
backtest_id: number;
windows: WalkForwardWindow[];
}
const strategyLabels: Record<string, string> = {
multi_factor: '멀티 팩터',
quality: '슈퍼 퀄리티',
@ -90,8 +106,13 @@ export default function BacktestDetailPage() {
const [equityCurve, setEquityCurve] = useState<EquityCurvePoint[]>([]);
const [holdings, setHoldings] = useState<RebalanceHoldings[]>([]);
const [transactions, setTransactions] = useState<TransactionItem[]>([]);
const [activeTab, setActiveTab] = useState<'holdings' | 'transactions'>('holdings');
const [activeTab, setActiveTab] = useState<'holdings' | 'transactions' | 'walkforward'>('holdings');
const [selectedRebalance, setSelectedRebalance] = useState<string | null>(null);
const [wfWindows, setWfWindows] = useState<WalkForwardWindow[]>([]);
const [wfLoading, setWfLoading] = useState(false);
const [wfTrainMonths, setWfTrainMonths] = useState(12);
const [wfTestMonths, setWfTestMonths] = useState(3);
const [wfStepMonths, setWfStepMonths] = useState(3);
const fetchBacktest = useCallback(async () => {
try {
@ -141,6 +162,37 @@ export default function BacktestDetailPage() {
}
}, [backtest, fetchBacktest]);
const fetchWalkForward = useCallback(async () => {
try {
const data = await api.get<WalkForwardResponse>(`/api/backtest/${backtestId}/walkforward`);
setWfWindows(data.windows);
} catch {
setWfWindows([]);
}
}, [backtestId]);
useEffect(() => {
if (activeTab === 'walkforward' && backtest?.status === 'completed') {
fetchWalkForward();
}
}, [activeTab, backtest?.status, fetchWalkForward]);
const runWalkForward = async () => {
setWfLoading(true);
try {
await api.post(`/api/backtest/${backtestId}/walkforward`, {
train_months: wfTrainMonths,
test_months: wfTestMonths,
step_months: wfStepMonths,
});
await fetchWalkForward();
} catch (err) {
console.error('Walk-forward failed:', err);
} finally {
setWfLoading(false);
}
};
const formatNumber = (value: number | null | undefined, decimals: number = 2) => {
if (value === null || value === undefined) return '-';
return value.toFixed(decimals);
@ -356,6 +408,16 @@ export default function BacktestDetailPage() {
>
</button>
<button
onClick={() => setActiveTab('walkforward')}
className={`px-6 py-3 text-sm font-medium ${
activeTab === 'walkforward'
? 'border-b-2 border-primary text-primary'
: 'text-muted-foreground hover:text-foreground'
}`}
>
Walk-forward
</button>
</nav>
</div>
@ -448,6 +510,116 @@ export default function BacktestDetailPage() {
</table>
</div>
)}
{/* Walk-forward Tab */}
{activeTab === 'walkforward' && (
<CardContent className="p-4">
<div className="flex flex-wrap gap-4 items-end mb-6">
<div className="space-y-1">
<Label htmlFor="wf-train"> ()</Label>
<input
id="wf-train"
type="number"
min={3}
max={60}
value={wfTrainMonths}
onChange={(e) => setWfTrainMonths(Number(e.target.value))}
className="flex h-10 w-24 rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
<div className="space-y-1">
<Label htmlFor="wf-test"> ()</Label>
<input
id="wf-test"
type="number"
min={1}
max={24}
value={wfTestMonths}
onChange={(e) => setWfTestMonths(Number(e.target.value))}
className="flex h-10 w-24 rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
<div className="space-y-1">
<Label htmlFor="wf-step"> ()</Label>
<input
id="wf-step"
type="number"
min={1}
max={24}
value={wfStepMonths}
onChange={(e) => setWfStepMonths(Number(e.target.value))}
className="flex h-10 w-24 rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
<button
onClick={runWalkForward}
disabled={wfLoading}
className="h-10 px-4 rounded-md bg-primary text-primary-foreground text-sm font-medium hover:bg-primary/90 disabled:opacity-50"
>
{wfLoading ? '분석 중...' : '분석 실행'}
</button>
</div>
{wfWindows.length > 0 && (
<>
<div className="mb-6">
<h3 className="text-sm font-medium text-muted-foreground mb-2"> </h3>
<AreaChart
data={wfWindows.map((w) => {
const cumReturn = wfWindows
.filter((x) => x.window_index <= w.window_index)
.reduce((acc, x) => acc * (1 + (x.test_return ?? 0) / 100), 1);
return {
date: w.test_end,
value: (cumReturn - 1) * 100,
};
})}
height={200}
color="#3b82f6"
showLegend={false}
formatValue={(v) => `${v.toFixed(2)}%`}
formatXAxis={(v) => v.slice(5)}
/>
</div>
<div className="overflow-x-auto">
<table className="w-full">
<thead className="bg-muted">
<tr>
<th scope="col" className="px-4 py-3 text-center text-sm font-medium text-muted-foreground">#</th>
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground"> </th>
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground"> </th>
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground">MDD</th>
</tr>
</thead>
<tbody className="divide-y divide-border">
{wfWindows.map((w) => (
<tr key={w.window_index} className="hover:bg-muted/50">
<td className="px-4 py-3 text-sm text-center">{w.window_index + 1}</td>
<td className="px-4 py-3 text-sm">{w.train_start} ~ {w.train_end}</td>
<td className="px-4 py-3 text-sm">{w.test_start} ~ {w.test_end}</td>
<td className={`px-4 py-3 text-sm text-right font-medium ${(w.test_return ?? 0) >= 0 ? 'text-green-600' : 'text-red-600'}`}>
{formatNumber(w.test_return)}%
</td>
<td className="px-4 py-3 text-sm text-right">{formatNumber(w.test_sharpe)}</td>
<td className="px-4 py-3 text-sm text-right text-red-600">{formatNumber(w.test_mdd)}%</td>
</tr>
))}
</tbody>
</table>
</div>
</>
)}
{wfWindows.length === 0 && !wfLoading && (
<div className="py-8 text-center text-muted-foreground">
.
</div>
)}
</CardContent>
)}
</Card>
</>
)}

View File

@ -0,0 +1,419 @@
'use client';
import { useEffect, useState } from 'react';
import { useRouter } from 'next/navigation';
import Link from 'next/link';
import {
Area,
AreaChart as RechartsAreaChart,
CartesianGrid,
Legend,
ResponsiveContainer,
Tooltip,
XAxis,
YAxis,
} from 'recharts';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Button } from '@/components/ui/button';
import { Skeleton } from '@/components/ui/skeleton';
import { api } from '@/lib/api';
import { toast } from 'sonner';
interface BacktestListItem {
id: number;
strategy_type: string;
start_date: string;
end_date: string;
rebalance_period: string;
status: string;
created_at: string;
total_return: number | null;
cagr: number | null;
mdd: number | null;
}
interface BacktestDetail {
id: number;
strategy_type: string;
start_date: string;
end_date: string;
status: string;
result: {
total_return: number;
cagr: number;
mdd: number;
sharpe_ratio: number;
volatility: number;
benchmark_return: number;
excess_return: number;
} | null;
}
interface EquityCurvePoint {
date: string;
portfolio_value: number;
benchmark_value: number;
drawdown: number;
}
interface CompareData {
detail: BacktestDetail;
equityCurve: EquityCurvePoint[];
}
const STRATEGY_LABELS: Record<string, string> = {
multi_factor: '멀티 팩터',
quality: '슈퍼 퀄리티',
value_momentum: '밸류 모멘텀',
kjb: '김종봉 단기매매',
};
const COMPARE_COLORS = ['#3b82f6', '#ef4444', '#22c55e'];
const METRICS = [
{ key: 'total_return', label: '총 수익률', suffix: '%' },
{ key: 'cagr', label: 'CAGR', suffix: '%' },
{ key: 'mdd', label: 'MDD', suffix: '%' },
{ key: 'sharpe_ratio', label: '샤프 비율', suffix: '' },
{ key: 'volatility', label: '변동성', suffix: '%' },
{ key: 'benchmark_return', label: '벤치마크 수익률', suffix: '%' },
{ key: 'excess_return', label: '초과 수익률', suffix: '%' },
] as const;
export default function BacktestComparePage() {
const router = useRouter();
const [loading, setLoading] = useState(true);
const [backtests, setBacktests] = useState<BacktestListItem[]>([]);
const [selectedIds, setSelectedIds] = useState<Set<number>>(new Set());
const [compareData, setCompareData] = useState<CompareData[]>([]);
const [comparing, setComparing] = useState(false);
useEffect(() => {
const init = async () => {
try {
await api.getCurrentUser();
const data = await api.get<BacktestListItem[]>('/api/backtest');
setBacktests(data.filter((bt) => bt.status === 'completed'));
} catch {
router.push('/login');
} finally {
setLoading(false);
}
};
init();
}, [router]);
const toggleSelection = (id: number) => {
setSelectedIds((prev) => {
const next = new Set(prev);
if (next.has(id)) {
next.delete(id);
} else if (next.size < 3) {
next.add(id);
} else {
toast.error('최대 3개까지 선택할 수 있습니다.');
}
return next;
});
};
const handleCompare = async () => {
if (selectedIds.size < 2) {
toast.error('비교할 백테스트를 2개 이상 선택하세요.');
return;
}
setComparing(true);
try {
const ids = Array.from(selectedIds);
const results = await Promise.all(
ids.map(async (id) => {
const [detail, equityCurve] = await Promise.all([
api.get<BacktestDetail>(`/api/backtest/${id}`),
api.get<EquityCurvePoint[]>(`/api/backtest/${id}/equity-curve`),
]);
return { detail, equityCurve };
})
);
setCompareData(results);
} catch (err) {
toast.error(err instanceof Error ? err.message : '비교 데이터를 불러오는데 실패했습니다.');
} finally {
setComparing(false);
}
};
const getStrategyLabel = (type: string) => STRATEGY_LABELS[type] || type;
const formatNumber = (value: number | null | undefined, decimals: number = 2) => {
if (value === null || value === undefined) return '-';
return value.toFixed(decimals);
};
const formatCurrency = (value: number) => {
return new Intl.NumberFormat('ko-KR', {
style: 'currency',
currency: 'KRW',
maximumFractionDigits: 0,
}).format(value);
};
const buildChartData = () => {
if (compareData.length === 0) return [];
const dateMap = new Map<string, Record<string, number>>();
compareData.forEach((cd, idx) => {
for (const point of cd.equityCurve) {
const existing = dateMap.get(point.date) || {};
existing[`value_${idx}`] = point.portfolio_value;
dateMap.set(point.date, existing);
}
});
return Array.from(dateMap.entries())
.sort(([a], [b]) => a.localeCompare(b))
.map(([date, values]) => ({ date, ...values }));
};
const getCompareLabel = (idx: number) => {
const cd = compareData[idx];
return `${getStrategyLabel(cd.detail.strategy_type)} (${cd.detail.start_date.slice(0, 4)}~${cd.detail.end_date.slice(0, 4)})`;
};
if (loading) {
return (
<DashboardLayout>
<Skeleton className="h-8 w-48 mb-6" />
<Skeleton className="h-96 rounded-xl" />
</DashboardLayout>
);
}
const chartData = buildChartData();
return (
<DashboardLayout>
<div className="flex items-center justify-between mb-6">
<div>
<h1 className="text-2xl font-bold text-foreground"> </h1>
<p className="mt-1 text-muted-foreground">
( 3)
</p>
</div>
<Button variant="outline" asChild>
<Link href="/backtest"> </Link>
</Button>
</div>
<Card className="mb-6">
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
{backtests.length === 0 ? (
<p className="text-muted-foreground text-center py-8">
.
</p>
) : (
<>
<div className="space-y-2 max-h-64 overflow-y-auto mb-4">
{backtests.map((bt) => (
<label
key={bt.id}
className={`flex items-center gap-3 p-3 border rounded cursor-pointer transition-colors ${
selectedIds.has(bt.id)
? 'border-primary bg-primary/5'
: 'border-border hover:bg-muted/50'
}`}
>
<input
type="checkbox"
checked={selectedIds.has(bt.id)}
onChange={() => toggleSelection(bt.id)}
className="h-4 w-4 rounded border-input"
/>
<div className="flex-1 min-w-0">
<span className="font-medium text-sm">
{getStrategyLabel(bt.strategy_type)}
</span>
<span className="text-xs text-muted-foreground ml-2">
{bt.start_date} ~ {bt.end_date}
</span>
</div>
<div className="flex gap-4 text-xs text-muted-foreground">
<span>: <span className={(bt.total_return ?? 0) >= 0 ? 'text-green-600' : 'text-red-600'}>{formatNumber(bt.total_return)}%</span></span>
<span>CAGR: <span className={(bt.cagr ?? 0) >= 0 ? 'text-green-600' : 'text-red-600'}>{formatNumber(bt.cagr)}%</span></span>
<span>MDD: <span className="text-red-600">{formatNumber(bt.mdd)}%</span></span>
</div>
</label>
))}
</div>
<Button
onClick={handleCompare}
disabled={selectedIds.size < 2 || comparing}
>
{comparing ? '비교 중...' : `비교하기 (${selectedIds.size}개 선택)`}
</Button>
</>
)}
</CardContent>
</Card>
{compareData.length >= 2 && (
<>
<Card className="mb-6">
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent className="p-0">
<div className="overflow-x-auto">
<table className="w-full">
<thead className="bg-muted">
<tr>
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground">
</th>
{compareData.map((cd, idx) => (
<th
key={cd.detail.id}
scope="col"
className="px-4 py-3 text-right text-sm font-medium"
style={{ color: COMPARE_COLORS[idx] }}
>
{getCompareLabel(idx)}
</th>
))}
</tr>
</thead>
<tbody className="divide-y divide-border">
{METRICS.map((metric) => (
<tr key={metric.key}>
<td className="px-4 py-3 text-sm font-medium text-muted-foreground">
{metric.label}
</td>
{compareData.map((cd) => {
const value = cd.detail.result
? cd.detail.result[metric.key as keyof typeof cd.detail.result]
: null;
const numValue = typeof value === 'number' ? value : null;
const isNegativeMetric = metric.key === 'mdd';
const colorClass = numValue !== null
? isNegativeMetric
? 'text-red-600'
: numValue >= 0
? 'text-green-600'
: 'text-red-600'
: '';
return (
<td
key={cd.detail.id}
className={`px-4 py-3 text-sm text-right font-medium ${colorClass}`}
>
{formatNumber(numValue)}{numValue !== null ? metric.suffix : ''}
</td>
);
})}
</tr>
))}
</tbody>
</table>
</div>
</CardContent>
</Card>
<Card>
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
{chartData.length > 0 ? (
<div style={{ height: 400 }}>
<ResponsiveContainer width="100%" height="100%">
<RechartsAreaChart
data={chartData}
margin={{ top: 10, right: 30, left: 0, bottom: 0 }}
>
<defs>
{compareData.map((_, idx) => (
<linearGradient
key={idx}
id={`compareGradient_${idx}`}
x1="0"
y1="0"
x2="0"
y2="1"
>
<stop offset="5%" stopColor={COMPARE_COLORS[idx]} stopOpacity={0.15} />
<stop offset="95%" stopColor={COMPARE_COLORS[idx]} stopOpacity={0} />
</linearGradient>
))}
</defs>
<CartesianGrid
strokeDasharray="3 3"
stroke="hsl(var(--border))"
vertical={false}
/>
<XAxis
dataKey="date"
tickFormatter={(v: string) => v.slice(5)}
stroke="hsl(var(--muted-foreground))"
fontSize={12}
tickLine={false}
axisLine={false}
/>
<YAxis
tickFormatter={(v: number) => formatCurrency(v)}
stroke="hsl(var(--muted-foreground))"
fontSize={12}
tickLine={false}
axisLine={false}
width={100}
/>
<Tooltip
contentStyle={{
backgroundColor: 'hsl(var(--popover))',
border: '1px solid hsl(var(--border))',
borderRadius: '8px',
color: 'hsl(var(--popover-foreground))',
}}
labelStyle={{ color: 'hsl(var(--popover-foreground))' }}
formatter={(value, name) => {
const idx = parseInt(String(name).replace('value_', ''));
return [formatCurrency(Number(value)), getCompareLabel(idx)];
}}
/>
<Legend
wrapperStyle={{ color: 'hsl(var(--foreground))' }}
formatter={(value: string) => {
const idx = parseInt(value.replace('value_', ''));
return getCompareLabel(idx);
}}
/>
{compareData.map((_, idx) => (
<Area
key={idx}
type="monotone"
dataKey={`value_${idx}`}
stroke={COMPARE_COLORS[idx]}
strokeWidth={2}
fill={`url(#compareGradient_${idx})`}
name={`value_${idx}`}
connectNulls
/>
))}
</RechartsAreaChart>
</ResponsiveContainer>
</div>
) : (
<div className="flex items-center justify-center h-64 bg-muted/50 rounded-lg">
<p className="text-muted-foreground"> </p>
</div>
)}
</CardContent>
</Card>
</>
)}
</DashboardLayout>
);
}

View File

@ -17,9 +17,18 @@ import {
SelectTrigger,
SelectValue,
} from '@/components/ui/select';
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogDescription,
DialogFooter,
} from '@/components/ui/dialog';
import { AreaChart } from '@/components/charts/area-chart';
import { api } from '@/lib/api';
import { TrendingUp, TrendingDown, Activity, Target, Calendar, Settings } from 'lucide-react';
import { toast } from 'sonner';
import { TrendingUp, TrendingDown, Activity, Target, Calendar, Settings, Trash2, GitCompareArrows } from 'lucide-react';
interface BacktestResult {
id: number;
@ -148,7 +157,7 @@ export default function BacktestPage() {
}
}
} catch (err) {
console.error('Failed to fetch backtests:', err);
toast.error(err instanceof Error ? err.message : '백테스트 목록을 불러오는데 실패했습니다.');
}
};
@ -238,6 +247,34 @@ export default function BacktestPage() {
}).format(value);
};
const [deleteConfirmOpen, setDeleteConfirmOpen] = useState(false);
const [deleteTargetId, setDeleteTargetId] = useState<number | null>(null);
const [deleting, setDeleting] = useState(false);
const handleDeleteClick = (id: number) => {
setDeleteTargetId(id);
setDeleteConfirmOpen(true);
};
const handleConfirmDelete = async () => {
if (deleteTargetId === null) return;
setDeleting(true);
try {
await api.delete(`/api/backtest/${deleteTargetId}`);
setBacktests((prev) => prev.filter((bt) => bt.id !== deleteTargetId));
if (currentResult?.id === deleteTargetId) {
setCurrentResult(null);
}
setDeleteConfirmOpen(false);
setDeleteTargetId(null);
toast.success('백테스트가 삭제되었습니다.');
} catch (err) {
toast.error(err instanceof Error ? err.message : '백테스트 삭제에 실패했습니다.');
} finally {
setDeleting(false);
}
};
const displayResult = currentResult;
if (loading) {
@ -263,13 +300,21 @@ export default function BacktestPage() {
</p>
</div>
<Button
variant="outline"
onClick={() => setShowHistory(!showHistory)}
>
<Calendar className="mr-2 h-4 w-4" />
{showHistory ? '새 백테스트' : '이전 기록'}
</Button>
<div className="flex gap-2">
<Button variant="outline" asChild>
<Link href="/backtest/compare">
<GitCompareArrows className="mr-2 h-4 w-4" />
</Link>
</Button>
<Button
variant="outline"
onClick={() => setShowHistory(!showHistory)}
>
<Calendar className="mr-2 h-4 w-4" />
{showHistory ? '새 백테스트' : '이전 기록'}
</Button>
</div>
</div>
{error && (
@ -297,6 +342,7 @@ export default function BacktestPage() {
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground">MDD</th>
<th scope="col" className="px-4 py-3 text-center text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-center text-sm font-medium text-muted-foreground"></th>
</tr>
</thead>
<tbody className="divide-y divide-border">
@ -328,11 +374,21 @@ export default function BacktestPage() {
<td className="px-4 py-3 text-sm text-muted-foreground">
{new Date(bt.created_at).toLocaleDateString('ko-KR')}
</td>
<td className="px-4 py-3 text-center">
<Button
variant="ghost"
size="icon"
onClick={() => handleDeleteClick(bt.id)}
className="h-7 w-7 text-muted-foreground hover:text-destructive"
>
<Trash2 className="h-4 w-4" />
</Button>
</td>
</tr>
))}
{backtests.length === 0 && (
<tr>
<td colSpan={8} className="px-4 py-8 text-center text-muted-foreground">
<td colSpan={9} className="px-4 py-8 text-center text-muted-foreground">
.
</td>
</tr>
@ -771,6 +827,24 @@ export default function BacktestPage() {
</div>
</div>
)}
<Dialog open={deleteConfirmOpen} onOpenChange={setDeleteConfirmOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle> </DialogTitle>
<DialogDescription>
. ?
</DialogDescription>
</DialogHeader>
<DialogFooter>
<Button variant="outline" onClick={() => setDeleteConfirmOpen(false)} disabled={deleting}>
</Button>
<Button variant="destructive" onClick={handleConfirmDelete} disabled={deleting}>
{deleting ? '삭제 중...' : '삭제'}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</DashboardLayout>
);
}

View File

@ -3,6 +3,7 @@ import { Inter, Noto_Sans_KR } from 'next/font/google';
import './globals.css';
import { ThemeProvider } from '@/components/providers/theme-provider';
import { Toaster } from '@/components/ui/sonner';
import { ErrorBoundary } from '@/components/error-boundary';
const inter = Inter({
subsets: ['latin'],
@ -34,7 +35,9 @@ export default function RootLayout({
enableSystem
disableTransitionOnChange
>
{children}
<ErrorBoundary>
{children}
</ErrorBoundary>
<Toaster />
</ThemeProvider>
</body>

View File

@ -63,12 +63,12 @@ export default function PortfolioHistoryPage() {
const fetchData = async () => {
try {
const [snapshotsData, returnsData] = await Promise.all([
api.get<SnapshotItem[]>(`/api/portfolios/${portfolioId}/snapshots`),
const [snapshotsRes, returnsData] = await Promise.all([
api.get<{ items: SnapshotItem[]; total: number }>(`/api/portfolios/${portfolioId}/snapshots`),
api.get<ReturnsData>(`/api/portfolios/${portfolioId}/returns`),
]);
setSnapshots(snapshotsData);
setSnapshots(snapshotsRes.items);
setReturns(returnsData);
} catch (err) {
if (err instanceof Error && err.message === 'API request failed') {

View File

@ -10,6 +10,23 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs';
import { TradingViewChart } from '@/components/charts/trading-view-chart';
import { DonutChart } from '@/components/charts/donut-chart';
import { Skeleton } from '@/components/ui/skeleton';
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogDescription,
DialogFooter,
} from '@/components/ui/dialog';
import { Input } from '@/components/ui/input';
import { Label } from '@/components/ui/label';
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select';
import { api } from '@/lib/api';
import type { AreaData, Time } from 'lightweight-charts';
@ -38,6 +55,21 @@ interface Transaction {
quantity: number;
price: number;
executed_at: string;
realized_pnl: number | null;
}
interface SnapshotListItem {
id: number;
portfolio_id: number;
total_value: number;
snapshot_date: string;
}
interface PaginatedResponse<T> {
items: T[];
total: number;
skip: number;
limit: number;
}
interface PortfolioDetail {
@ -51,6 +83,8 @@ interface PortfolioDetail {
total_value: number | null;
total_invested: number | null;
total_profit_loss: number | null;
total_realized_pnl: number | null;
total_unrealized_pnl: number | null;
risk_asset_ratio: number | null;
}
@ -63,31 +97,13 @@ const CHART_COLORS = [
'hsl(199.4, 95.5%, 53.8%)',
];
// Generate sample chart data for portfolio value over time
function generateChartData(totalValue: number | null): AreaData<Time>[] {
if (totalValue === null || totalValue === 0) return [];
const data: AreaData<Time>[] = [];
const now = new Date();
const baseValue = totalValue * 0.85;
for (let i = 90; i >= 0; i--) {
const date = new Date(now);
date.setDate(date.getDate() - i);
const dateStr = date.toISOString().split('T')[0];
// Simulate value fluctuation
const progress = (90 - i) / 90;
const fluctuation = Math.sin(i * 0.1) * 0.05;
const value = baseValue + (totalValue - baseValue) * progress * (1 + fluctuation);
data.push({
time: dateStr as Time,
value: Math.round(value),
});
}
return data;
function snapshotsToChartData(snapshots: SnapshotListItem[]): AreaData<Time>[] {
return snapshots
.sort((a, b) => a.snapshot_date.localeCompare(b.snapshot_date))
.map((s) => ({
time: s.snapshot_date as Time,
value: Math.round(s.total_value),
}));
}
export default function PortfolioDetailPage() {
@ -98,8 +114,23 @@ export default function PortfolioDetailPage() {
const [loading, setLoading] = useState(true);
const [portfolio, setPortfolio] = useState<PortfolioDetail | null>(null);
const [transactions, setTransactions] = useState<Transaction[]>([]);
const [txTotal, setTxTotal] = useState(0);
const [txLoadingMore, setTxLoadingMore] = useState(false);
const [snapshots, setSnapshots] = useState<SnapshotListItem[]>([]);
const [error, setError] = useState<string | null>(null);
// Transaction modal state
const [txModalOpen, setTxModalOpen] = useState(false);
const [txSubmitting, setTxSubmitting] = useState(false);
const [txForm, setTxForm] = useState({
ticker: '',
tx_type: 'buy',
quantity: '',
price: '',
executed_at: '',
memo: '',
});
const fetchPortfolio = useCallback(async () => {
try {
setError(null);
@ -113,19 +144,44 @@ export default function PortfolioDetailPage() {
const fetchTransactions = useCallback(async () => {
try {
const data = await api.get<Transaction[]>(`/api/portfolios/${portfolioId}/transactions`);
setTransactions(data);
const data = await api.get<PaginatedResponse<Transaction>>(
`/api/portfolios/${portfolioId}/transactions?skip=0&limit=50`
);
setTransactions(data.items);
setTxTotal(data.total);
} catch {
// Transactions may not exist yet
setTransactions([]);
}
}, [portfolioId]);
const fetchMoreTransactions = useCallback(async (currentCount: number) => {
try {
const data = await api.get<PaginatedResponse<Transaction>>(
`/api/portfolios/${portfolioId}/transactions?skip=${currentCount}&limit=50`
);
setTransactions((prev) => [...prev, ...data.items]);
setTxTotal(data.total);
} catch {
// ignore load-more errors
}
}, [portfolioId]);
const fetchSnapshots = useCallback(async () => {
try {
const data = await api.get<PaginatedResponse<SnapshotListItem>>(
`/api/portfolios/${portfolioId}/snapshots`
);
setSnapshots(data.items);
} catch {
setSnapshots([]);
}
}, [portfolioId]);
useEffect(() => {
const init = async () => {
try {
await api.getCurrentUser();
await Promise.all([fetchPortfolio(), fetchTransactions()]);
await Promise.all([fetchPortfolio(), fetchTransactions(), fetchSnapshots()]);
} catch {
router.push('/login');
} finally {
@ -133,7 +189,7 @@ export default function PortfolioDetailPage() {
}
};
init();
}, [router, fetchPortfolio, fetchTransactions]);
}, [router, fetchPortfolio, fetchTransactions, fetchSnapshots]);
const formatCurrency = (value: number | null) => {
if (value === null) return '-';
@ -172,6 +228,38 @@ export default function PortfolioDetailPage() {
}));
};
const handleLoadMoreTransactions = async () => {
setTxLoadingMore(true);
try {
await fetchMoreTransactions(transactions.length);
} finally {
setTxLoadingMore(false);
}
};
const handleAddTransaction = async () => {
if (!txForm.ticker || !txForm.quantity || !txForm.price || !txForm.executed_at) return;
setTxSubmitting(true);
try {
await api.post(`/api/portfolios/${portfolioId}/transactions`, {
ticker: txForm.ticker,
tx_type: txForm.tx_type,
quantity: parseInt(txForm.quantity, 10),
price: parseFloat(txForm.price),
executed_at: new Date(txForm.executed_at).toISOString(),
memo: txForm.memo || null,
});
setTxModalOpen(false);
setTxForm({ ticker: '', tx_type: 'buy', quantity: '', price: '', executed_at: '', memo: '' });
await Promise.all([fetchPortfolio(), fetchTransactions()]);
} catch (err) {
const message = err instanceof Error ? err.message : '거래 추가 실패';
setError(message);
} finally {
setTxSubmitting(false);
}
};
if (loading) {
return (
<DashboardLayout>
@ -189,7 +277,7 @@ export default function PortfolioDetailPage() {
);
}
const chartData = portfolio ? generateChartData(portfolio.total_value) : [];
const chartData = snapshotsToChartData(snapshots);
const returnPercent = calculateReturnPercent();
return (
@ -237,7 +325,7 @@ export default function PortfolioDetailPage() {
)}
{/* Summary Cards */}
<div className="grid grid-cols-1 md:grid-cols-4 gap-4 mb-6">
<div className="grid grid-cols-1 md:grid-cols-3 gap-4 mb-4">
<Card>
<CardContent className="pt-6">
<p className="text-sm text-muted-foreground mb-1"> </p>
@ -254,6 +342,20 @@ export default function PortfolioDetailPage() {
</p>
</CardContent>
</Card>
<Card>
<CardContent className="pt-6">
<p className="text-sm text-muted-foreground mb-1"></p>
<p
className={`text-2xl font-bold ${
(returnPercent ?? 0) >= 0 ? 'text-green-600' : 'text-red-600'
}`}
>
{formatPercent(returnPercent)}
</p>
</CardContent>
</Card>
</div>
<div className="grid grid-cols-1 md:grid-cols-3 gap-4 mb-6">
<Card>
<CardContent className="pt-6">
<p className="text-sm text-muted-foreground mb-1"> </p>
@ -268,29 +370,47 @@ export default function PortfolioDetailPage() {
</Card>
<Card>
<CardContent className="pt-6">
<p className="text-sm text-muted-foreground mb-1"></p>
<p className="text-sm text-muted-foreground mb-1"> </p>
<p
className={`text-2xl font-bold ${
(returnPercent ?? 0) >= 0 ? 'text-green-600' : 'text-red-600'
(portfolio.total_realized_pnl ?? 0) >= 0 ? 'text-green-600' : 'text-red-600'
}`}
>
{formatPercent(returnPercent)}
{formatCurrency(portfolio.total_realized_pnl)}
</p>
<p className="text-xs text-muted-foreground mt-1"> </p>
</CardContent>
</Card>
<Card>
<CardContent className="pt-6">
<p className="text-sm text-muted-foreground mb-1"> </p>
<p
className={`text-2xl font-bold ${
(portfolio.total_unrealized_pnl ?? 0) >= 0 ? 'text-green-600' : 'text-red-600'
}`}
>
{formatCurrency(portfolio.total_unrealized_pnl)}
</p>
<p className="text-xs text-muted-foreground mt-1"> </p>
</CardContent>
</Card>
</div>
{/* Chart Section */}
{chartData.length > 0 && (
<Card className="mb-6">
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
<Card className="mb-6">
<CardHeader>
<CardTitle> </CardTitle>
</CardHeader>
<CardContent>
{chartData.length > 0 ? (
<TradingViewChart data={chartData} height={300} />
</CardContent>
</Card>
)}
) : (
<div className="flex items-center justify-center h-[300px] text-muted-foreground">
. .
</div>
)}
</CardContent>
</Card>
{/* Tabs Section */}
<Tabs defaultValue="holdings" className="space-y-4">
@ -417,8 +537,11 @@ export default function PortfolioDetailPage() {
{/* Transactions Tab */}
<TabsContent value="transactions">
<Card>
<CardHeader>
<CardHeader className="flex flex-row items-center justify-between">
<CardTitle> </CardTitle>
<Button size="sm" onClick={() => setTxModalOpen(true)}>
</Button>
</CardHeader>
<CardContent className="p-0">
<div className="overflow-x-auto">
@ -461,6 +584,12 @@ export default function PortfolioDetailPage() {
>
</th>
<th
scope="col"
className="px-4 py-3 text-right text-sm font-medium text-muted-foreground"
>
</th>
</tr>
</thead>
<tbody className="divide-y divide-border">
@ -490,11 +619,22 @@ export default function PortfolioDetailPage() {
<td className="px-4 py-3 text-sm text-right">
{formatCurrency(tx.quantity * tx.price)}
</td>
<td
className={`px-4 py-3 text-sm text-right ${
tx.realized_pnl !== null
? tx.realized_pnl >= 0
? 'text-green-600'
: 'text-red-600'
: ''
}`}
>
{tx.realized_pnl !== null ? formatCurrency(tx.realized_pnl) : '-'}
</td>
</tr>
))}
{transactions.length === 0 && (
<tr>
<td colSpan={6} className="px-4 py-8 text-center text-muted-foreground">
<td colSpan={7} className="px-4 py-8 text-center text-muted-foreground">
.
</td>
</tr>
@ -502,6 +642,18 @@ export default function PortfolioDetailPage() {
</tbody>
</table>
</div>
{transactions.length < txTotal && (
<div className="flex justify-center py-4 border-t border-border">
<Button
variant="outline"
size="sm"
onClick={handleLoadMoreTransactions}
disabled={txLoadingMore}
>
{txLoadingMore ? '불러오는 중...' : `더 보기 (${transactions.length}/${txTotal})`}
</Button>
</div>
)}
</CardContent>
</Card>
</TabsContent>
@ -582,6 +734,95 @@ export default function PortfolioDetailPage() {
</Tabs>
</>
)}
{/* Transaction Add Modal */}
<Dialog open={txModalOpen} onOpenChange={setTxModalOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle> </DialogTitle>
<DialogDescription> / .</DialogDescription>
</DialogHeader>
<div className="space-y-4">
<div className="space-y-2">
<Label htmlFor="tx-ticker"></Label>
<Input
id="tx-ticker"
placeholder="예: 069500"
value={txForm.ticker}
onChange={(e) => setTxForm({ ...txForm, ticker: e.target.value })}
/>
</div>
<div className="space-y-2">
<Label> </Label>
<Select
value={txForm.tx_type}
onValueChange={(v) => setTxForm({ ...txForm, tx_type: v })}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="buy"></SelectItem>
<SelectItem value="sell"></SelectItem>
</SelectContent>
</Select>
</div>
<div className="grid grid-cols-2 gap-4">
<div className="space-y-2">
<Label htmlFor="tx-quantity"></Label>
<Input
id="tx-quantity"
type="number"
min="1"
placeholder="0"
value={txForm.quantity}
onChange={(e) => setTxForm({ ...txForm, quantity: e.target.value })}
/>
</div>
<div className="space-y-2">
<Label htmlFor="tx-price"></Label>
<Input
id="tx-price"
type="number"
min="1"
placeholder="0"
value={txForm.price}
onChange={(e) => setTxForm({ ...txForm, price: e.target.value })}
/>
</div>
</div>
<div className="space-y-2">
<Label htmlFor="tx-executed-at"> </Label>
<Input
id="tx-executed-at"
type="datetime-local"
value={txForm.executed_at}
onChange={(e) => setTxForm({ ...txForm, executed_at: e.target.value })}
/>
</div>
<div className="space-y-2">
<Label htmlFor="tx-memo"> ()</Label>
<Input
id="tx-memo"
placeholder="메모를 입력하세요"
value={txForm.memo}
onChange={(e) => setTxForm({ ...txForm, memo: e.target.value })}
/>
</div>
</div>
<DialogFooter>
<Button variant="outline" onClick={() => setTxModalOpen(false)}>
</Button>
<Button
onClick={handleAddTransaction}
disabled={txSubmitting || !txForm.ticker || !txForm.quantity || !txForm.price || !txForm.executed_at}
>
{txSubmitting ? '저장 중...' : '저장'}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</DashboardLayout>
);
}

View File

@ -64,6 +64,8 @@ export default function RebalancePage() {
const [applyPrices, setApplyPrices] = useState<Record<string, string>>({});
const [applying, setApplying] = useState(false);
const [applyError, setApplyError] = useState<string | null>(null);
const [portfolioType, setPortfolioType] = useState<string>('general');
const [currentRiskRatio, setCurrentRiskRatio] = useState<number | null>(null);
useEffect(() => {
const init = async () => {
@ -87,16 +89,21 @@ export default function RebalancePage() {
});
setPrices(initialPrices);
// Fetch stock names from portfolio detail
try {
const detail = await api.get<{ holdings: { ticker: string; name: string | null }[] }>(`/api/portfolios/${portfolioId}/detail`);
const detail = await api.get<{
portfolio_type: string;
risk_asset_ratio: number | null;
holdings: { ticker: string; name: string | null }[];
}>(`/api/portfolios/${portfolioId}/detail`);
const names: Record<string, string> = {};
for (const h of detail.holdings) {
if (h.name) names[h.ticker] = h.name;
}
setNameMap(names);
setPortfolioType(detail.portfolio_type);
setCurrentRiskRatio(detail.risk_asset_ratio);
} catch {
// Names are optional, continue without
// ignore
}
} catch {
router.push('/login');
@ -316,7 +323,19 @@ export default function RebalancePage() {
</CardContent>
</Card>
{/* Results */}
{result && portfolioType === 'pension' && currentRiskRatio !== null && currentRiskRatio > 70 && (
<div className="bg-amber-50 border border-amber-300 text-amber-800 dark:bg-amber-950 dark:border-amber-700 dark:text-amber-200 px-4 py-3 rounded mb-4 flex items-start gap-2">
<span className="text-lg">&#9888;</span>
<div>
<p className="font-medium">DC형 </p>
<p className="text-sm mt-1">
: <strong>{currentRiskRatio.toFixed(1)}%</strong> ( 한도: 70%).
/ ETF .
</p>
</div>
</div>
)}
{result && (
<>
<Card className="mb-6">

View File

@ -5,9 +5,21 @@ import { useRouter } from 'next/navigation';
import Link from 'next/link';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { Button } from '@/components/ui/button';
import { Input } from '@/components/ui/input';
import { Label } from '@/components/ui/label';
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from '@/components/ui/dialog';
import { PortfolioCard } from '@/components/portfolio/portfolio-card';
import { Skeleton } from '@/components/ui/skeleton';
import { api } from '@/lib/api';
import { toast } from 'sonner';
import { Pencil, Trash2 } from 'lucide-react';
interface HoldingWithValue {
ticker: string;
@ -33,6 +45,17 @@ export default function PortfolioListPage() {
const [portfolios, setPortfolios] = useState<Portfolio[]>([]);
const [error, setError] = useState<string | null>(null);
// Edit modal state
const [editModalOpen, setEditModalOpen] = useState(false);
const [editTarget, setEditTarget] = useState<Portfolio | null>(null);
const [editName, setEditName] = useState('');
const [editSaving, setEditSaving] = useState(false);
// Delete modal state
const [deleteModalOpen, setDeleteModalOpen] = useState(false);
const [deleteTarget, setDeleteTarget] = useState<Portfolio | null>(null);
const [deleting, setDeleting] = useState(false);
useEffect(() => {
const init = async () => {
try {
@ -88,6 +111,53 @@ export default function PortfolioListPage() {
return (portfolio.total_profit_loss / portfolio.total_invested) * 100;
};
const handleOpenEdit = (e: React.MouseEvent, portfolio: Portfolio) => {
e.preventDefault();
e.stopPropagation();
setEditTarget(portfolio);
setEditName(portfolio.name);
setEditModalOpen(true);
};
const handleSaveEdit = async () => {
if (!editTarget || !editName.trim()) return;
setEditSaving(true);
try {
await api.put(`/api/portfolios/${editTarget.id}`, { name: editName.trim() });
toast.success('포트폴리오 이름이 변경되었습니다.');
setEditModalOpen(false);
await fetchPortfolios();
} catch (err) {
console.error('Failed to update portfolio:', err);
toast.error('포트폴리오 수정에 실패했습니다.');
} finally {
setEditSaving(false);
}
};
const handleOpenDelete = (e: React.MouseEvent, portfolio: Portfolio) => {
e.preventDefault();
e.stopPropagation();
setDeleteTarget(portfolio);
setDeleteModalOpen(true);
};
const handleConfirmDelete = async () => {
if (!deleteTarget) return;
setDeleting(true);
try {
await api.delete(`/api/portfolios/${deleteTarget.id}`);
toast.success('포트폴리오가 삭제되었습니다.');
setDeleteModalOpen(false);
setPortfolios((prev) => prev.filter((p) => p.id !== deleteTarget.id));
} catch (err) {
console.error('Failed to delete portfolio:', err);
toast.error('포트폴리오 삭제에 실패했습니다.');
} finally {
setDeleting(false);
}
};
if (loading) {
return (
<DashboardLayout>
@ -121,15 +191,32 @@ export default function PortfolioListPage() {
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{portfolios.map((portfolio) => (
<PortfolioCard
key={portfolio.id}
id={portfolio.id}
name={portfolio.name}
portfolioType={portfolio.portfolio_type}
totalValue={portfolio.total_value ?? null}
returnPercent={calculateReturnPercent(portfolio)}
holdings={portfolio.holdings ?? []}
/>
<div key={portfolio.id} className="relative group">
<PortfolioCard
id={portfolio.id}
name={portfolio.name}
portfolioType={portfolio.portfolio_type}
totalValue={portfolio.total_value ?? null}
returnPercent={calculateReturnPercent(portfolio)}
holdings={portfolio.holdings ?? []}
/>
<div className="absolute top-3 right-14 flex gap-1 opacity-0 group-hover:opacity-100 transition-opacity z-10">
<button
onClick={(e) => handleOpenEdit(e, portfolio)}
className="p-1.5 rounded-md bg-background/80 border border-border hover:bg-muted text-muted-foreground hover:text-foreground"
title="이름 변경"
>
<Pencil className="h-3.5 w-3.5" />
</button>
<button
onClick={(e) => handleOpenDelete(e, portfolio)}
className="p-1.5 rounded-md bg-background/80 border border-border hover:bg-destructive/10 text-muted-foreground hover:text-destructive"
title="삭제"
>
<Trash2 className="h-3.5 w-3.5" />
</button>
</div>
</div>
))}
</div>
@ -161,6 +248,57 @@ export default function PortfolioListPage() {
</Button>
</div>
)}
{/* Edit Modal */}
<Dialog open={editModalOpen} onOpenChange={setEditModalOpen}>
<DialogContent className="sm:max-w-md">
<DialogHeader>
<DialogTitle> </DialogTitle>
<DialogDescription>
.
</DialogDescription>
</DialogHeader>
<div className="space-y-4">
<div className="space-y-2">
<Label htmlFor="edit-name"></Label>
<Input
id="edit-name"
value={editName}
onChange={(e) => setEditName(e.target.value)}
placeholder="포트폴리오 이름"
onKeyDown={(e) => e.key === 'Enter' && handleSaveEdit()}
/>
</div>
</div>
<DialogFooter>
<Button variant="outline" onClick={() => setEditModalOpen(false)} disabled={editSaving}>
</Button>
<Button onClick={handleSaveEdit} disabled={editSaving || !editName.trim()}>
{editSaving ? '저장 중...' : '저장'}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
{/* Delete Confirmation Modal */}
<Dialog open={deleteModalOpen} onOpenChange={setDeleteModalOpen}>
<DialogContent className="sm:max-w-md">
<DialogHeader>
<DialogTitle> </DialogTitle>
<DialogDescription>
&quot;{deleteTarget?.name}&quot; ? .
</DialogDescription>
</DialogHeader>
<DialogFooter>
<Button variant="outline" onClick={() => setDeleteModalOpen(false)} disabled={deleting}>
</Button>
<Button variant="destructive" onClick={handleConfirmDelete} disabled={deleting}>
{deleting ? '삭제 중...' : '삭제'}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</DashboardLayout>
);
}

View File

@ -25,6 +25,7 @@ import {
SelectValue,
} from '@/components/ui/select';
import { api } from '@/lib/api';
import { toast } from 'sonner';
import { Radio, History, RefreshCw, ArrowUpCircle, ArrowDownCircle, MinusCircle, Play } from 'lucide-react';
interface Signal {
@ -55,6 +56,20 @@ interface Holding {
avg_price: number;
}
interface PositionSize {
ticker: string;
price: number;
total_portfolio_value: number;
current_holding_quantity: number;
current_holding_value: number;
current_ratio: number;
target_ratio: number | null;
recommended_quantity: number;
max_quantity: number;
recommended_value: number;
max_value: number;
}
const signalTypeConfig: Record<string, { label: string; style: string; icon: typeof ArrowUpCircle }> = {
buy: {
label: '매수',
@ -116,6 +131,14 @@ export default function SignalsPage() {
const [executing, setExecuting] = useState(false);
const [executeError, setExecuteError] = useState('');
const [currentHoldings, setCurrentHoldings] = useState<Holding[]>([]);
const [positionSize, setPositionSize] = useState<PositionSize | null>(null);
// Cancel modal state
const [cancelModalOpen, setCancelModalOpen] = useState(false);
const [cancelSignal, setCancelSignal] = useState<Signal | null>(null);
const [cancelPortfolioId, setCancelPortfolioId] = useState('');
const [cancelling, setCancelling] = useState(false);
const [cancelError, setCancelError] = useState('');
useEffect(() => {
const init = async () => {
@ -135,8 +158,8 @@ export default function SignalsPage() {
try {
const data = await api.get<Signal[]>('/api/signal/kjb/today');
setTodaySignals(data);
} catch (err) {
console.error('Failed to fetch today signals:', err);
} catch {
toast.error('오늘의 신호를 불러오는데 실패했습니다.');
}
};
@ -150,8 +173,8 @@ export default function SignalsPage() {
const url = `/api/signal/kjb/history${query ? `?${query}` : ''}`;
const data = await api.get<Signal[]>(url);
setHistorySignals(data);
} catch (err) {
console.error('Failed to fetch signal history:', err);
} catch {
toast.error('신호 이력을 불러오는데 실패했습니다.');
}
};
@ -159,8 +182,8 @@ export default function SignalsPage() {
try {
const data = await api.get<Portfolio[]>('/api/portfolios');
setPortfolios(data);
} catch (err) {
console.error('Failed to fetch portfolios:', err);
} catch {
toast.error('포트폴리오 목록을 불러오는데 실패했습니다.');
}
};
@ -196,12 +219,14 @@ export default function SignalsPage() {
setSelectedPortfolioId('');
setExecuteError('');
setCurrentHoldings([]);
setPositionSize(null);
await fetchPortfolios();
setExecuteModalOpen(true);
};
const handlePortfolioChange = async (portfolioId: string) => {
setSelectedPortfolioId(portfolioId);
setPositionSize(null);
if (portfolioId) {
try {
const holdings = await api.get<Holding[]>(`/api/portfolios/${portfolioId}/holdings`);
@ -215,6 +240,21 @@ export default function SignalsPage() {
: Math.floor(holding.quantity / 2);
setExecuteQuantity(String(defaultQty));
}
// 매수 신호일 때 포지션 사이징 가이드 조회
if (executeSignal.signal_type === 'buy' && executeSignal.entry_price) {
try {
const ps = await api.get<PositionSize>(
`/api/portfolios/${portfolioId}/position-size?ticker=${executeSignal.ticker}&price=${executeSignal.entry_price}`
);
setPositionSize(ps);
if (ps.recommended_quantity > 0) {
setExecuteQuantity(String(ps.recommended_quantity));
}
} catch {
toast.error('포지션 사이징 정보를 불러오는데 실패했습니다.');
}
}
}
} catch {
setCurrentHoldings([]);
@ -265,6 +305,40 @@ export default function SignalsPage() {
}
};
const handleOpenCancelModal = async (signal: Signal) => {
setCancelSignal(signal);
setCancelError('');
setCancelPortfolioId('');
// 포트폴리오 목록 로드 (이미 있으면 재사용)
if (portfolios.length === 0) {
const pResp = await api.get('/api/portfolios') as { data: Portfolio[] };
setPortfolios(pResp.data);
}
setCancelModalOpen(true);
};
const handleSubmitCancel = async () => {
if (!cancelSignal || !cancelPortfolioId) {
setCancelError('포트폴리오를 선택해주세요.');
return;
}
setCancelling(true);
setCancelError('');
try {
await api.delete(`/api/signal/${cancelSignal.id}/cancel?portfolio_id=${cancelPortfolioId}`);
setCancelModalOpen(false);
if (showHistory) {
await fetchHistorySignals();
} else {
await fetchTodaySignals();
}
} catch (err) {
setCancelError(err instanceof Error ? err.message : '취소에 실패했습니다.');
} finally {
setCancelling(false);
}
};
const renderSignalTable = (signals: Signal[]) => (
<div className="overflow-x-auto">
<table className="w-full">
@ -334,6 +408,16 @@ export default function SignalsPage() {
</Button>
)}
{signal.status === 'executed' && (
<Button
variant="outline"
size="sm"
className="text-destructive border-destructive hover:bg-destructive hover:text-destructive-foreground"
onClick={() => handleOpenCancelModal(signal)}
>
</Button>
)}
</td>
</tr>
);
@ -574,6 +658,29 @@ export default function SignalsPage() {
);
})()}
{/* Position sizing guide */}
{positionSize && executeSignal?.signal_type === 'buy' && (
<div className="rounded-md border border-blue-200 bg-blue-50 dark:border-blue-800 dark:bg-blue-950 p-3 space-y-2 text-sm">
<p className="font-medium text-blue-800 dark:text-blue-200"> </p>
<div className="grid grid-cols-2 gap-x-4 gap-y-1 text-blue-700 dark:text-blue-300">
<span> </span>
<span className="text-right font-mono">{formatPrice(positionSize.total_portfolio_value)}</span>
<span> </span>
<span className="text-right font-mono">{positionSize.current_ratio.toFixed(1)}%</span>
{positionSize.target_ratio !== null && (
<>
<span> </span>
<span className="text-right font-mono">{positionSize.target_ratio.toFixed(1)}%</span>
</>
)}
<span> </span>
<span className="text-right font-mono font-medium">{positionSize.recommended_quantity.toLocaleString()}</span>
<span> </span>
<span className="text-right font-mono">{positionSize.max_quantity.toLocaleString()}</span>
</div>
</div>
)}
{/* Quantity */}
<div className="space-y-2">
<Label htmlFor="exec-quantity"> ()</Label>
@ -617,6 +724,69 @@ export default function SignalsPage() {
</DialogFooter>
</DialogContent>
</Dialog>
{/* 신호 취소 확인 모달 */}
<Dialog open={cancelModalOpen} onOpenChange={setCancelModalOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle> </DialogTitle>
<DialogDescription>
. .
</DialogDescription>
</DialogHeader>
{cancelSignal && (
<div className="space-y-4">
<div className="bg-muted/50 rounded-lg p-3 space-y-1">
<div className="flex justify-between text-sm">
<span className="text-muted-foreground"></span>
<span className="font-medium">{cancelSignal.name || cancelSignal.ticker} ({cancelSignal.ticker})</span>
</div>
<div className="flex justify-between text-sm">
<span className="text-muted-foreground"></span>
<span className="font-mono">{cancelSignal.executed_price ? cancelSignal.executed_price.toLocaleString() : '-'}</span>
</div>
<div className="flex justify-between text-sm">
<span className="text-muted-foreground"> </span>
<span className="font-mono">{cancelSignal.executed_quantity?.toLocaleString() || '-'}</span>
</div>
</div>
<div className="space-y-2">
<Label htmlFor="cancel-portfolio"> </Label>
<select
id="cancel-portfolio"
className="w-full border rounded-md px-3 py-2 text-sm bg-background"
value={cancelPortfolioId}
onChange={(e) => setCancelPortfolioId(e.target.value)}
>
<option value=""> ...</option>
{portfolios.map((p) => (
<option key={p.id} value={p.id}>{p.name}</option>
))}
</select>
</div>
{cancelError && (
<p className="text-sm text-red-600">{cancelError}</p>
)}
</div>
)}
<DialogFooter>
<Button variant="outline" onClick={() => setCancelModalOpen(false)} disabled={cancelling}>
</Button>
<Button
variant="destructive"
onClick={handleSubmitCancel}
disabled={cancelling || !cancelPortfolioId}
>
{cancelling ? '취소 중...' : '실행 취소'}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</DashboardLayout>
);
}

View File

@ -0,0 +1,387 @@
'use client';
import React, { useEffect, useState } from 'react';
import { useRouter } from 'next/navigation';
import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Button } from '@/components/ui/button';
import { Skeleton } from '@/components/ui/skeleton';
import { Badge } from '@/components/ui/badge';
import { api } from '@/lib/api';
interface StockFactor {
ticker: string;
name: string;
market: string;
sector_name: string | null;
market_cap: number | null;
close_price: number | null;
per: number | null;
pbr: number | null;
value_score: number | null;
quality_score: number | null;
momentum_score: number | null;
total_score: number | null;
rank: number | null;
}
interface StrategyResult {
strategy_name: string;
base_date: string;
universe_count: number;
result_count: number;
stocks: StockFactor[];
}
const STRATEGIES = [
{
key: 'multi-factor',
label: '멀티팩터',
payload: {
universe: { markets: ['KOSPI', 'KOSDAQ'], exclude_stock_types: ['spac', 'preferred', 'reit'] },
top_n: 30,
weights: { value: 0.3, quality: 0.3, momentum: 0.2, low_vol: 0.2 },
},
},
{
key: 'quality',
label: '퀄리티',
payload: {
universe: { markets: ['KOSPI', 'KOSDAQ'], exclude_stock_types: ['spac', 'preferred', 'reit'] },
top_n: 30,
min_fscore: 7,
},
},
{
key: 'value-momentum',
label: '밸류모멘텀',
payload: {
universe: { markets: ['KOSPI', 'KOSDAQ'], exclude_stock_types: ['spac', 'preferred', 'reit'] },
top_n: 30,
value_weight: 0.5,
momentum_weight: 0.5,
},
},
] as const;
type StrategyKey = (typeof STRATEGIES)[number]['key'];
export default function StrategyComparePage() {
const router = useRouter();
const [initialLoading, setInitialLoading] = useState(true);
const [loading, setLoading] = useState(false);
const [results, setResults] = useState<Record<string, StrategyResult>>({});
const [error, setError] = useState<string | null>(null);
useEffect(() => {
const init = async () => {
try {
await api.getCurrentUser();
} catch {
router.push('/login');
} finally {
setInitialLoading(false);
}
};
init();
}, [router]);
const runAll = async () => {
setLoading(true);
setError(null);
try {
const promises = STRATEGIES.map((s) =>
api.post<StrategyResult>(`/api/strategy/${s.key}`, s.payload)
);
const responses = await Promise.all(promises);
const map: Record<string, StrategyResult> = {};
STRATEGIES.forEach((s, i) => {
map[s.key] = responses[i];
});
setResults(map);
} catch (err) {
setError(err instanceof Error ? err.message : '전략 실행 실패');
} finally {
setLoading(false);
}
};
const formatNumber = (value: number | null, decimals = 2) => {
if (value === null) return '-';
return value.toFixed(decimals);
};
const formatCurrency = (value: number | null) => {
if (value === null) return '-';
return new Intl.NumberFormat('ko-KR').format(value);
};
// Find common tickers across all results
const getCommonTickers = (): Set<string> => {
const resultKeys = Object.keys(results);
if (resultKeys.length < 2) return new Set();
const tickerSets = resultKeys.map(
(key) => new Set(results[key].stocks.map((s) => s.ticker))
);
const common = new Set<string>();
tickerSets[0].forEach((ticker) => {
if (tickerSets.every((set) => set.has(ticker))) {
common.add(ticker);
}
});
return common;
};
// Find tickers that appear in at least 2 strategies
const getOverlapTickers = (): Set<string> => {
const resultKeys = Object.keys(results);
if (resultKeys.length < 2) return new Set();
const tickerCount: Record<string, number> = {};
resultKeys.forEach((key) => {
results[key].stocks.forEach((s) => {
tickerCount[s.ticker] = (tickerCount[s.ticker] || 0) + 1;
});
});
return new Set(
Object.entries(tickerCount)
.filter(([, count]) => count >= 2)
.map(([ticker]) => ticker)
);
};
const hasResults = Object.keys(results).length === STRATEGIES.length;
const commonTickers = hasResults ? getCommonTickers() : new Set<string>();
const overlapTickers = hasResults ? getOverlapTickers() : new Set<string>();
if (initialLoading) {
return (
<DashboardLayout>
<Skeleton className="h-8 w-48 mb-6" />
<Skeleton className="h-48 rounded-xl" />
</DashboardLayout>
);
}
return (
<DashboardLayout>
<div className="mb-6">
<h1 className="text-2xl font-bold text-foreground"> </h1>
<p className="mt-1 text-muted-foreground">
, , 3
</p>
</div>
{error && (
<div className="bg-destructive/10 border border-destructive text-destructive px-4 py-3 rounded mb-4">
{error}
</div>
)}
<div className="mb-6">
<Button onClick={runAll} disabled={loading} size="lg">
{loading ? '3개 전략 실행 중...' : '전략 비교 실행'}
</Button>
</div>
{hasResults && (
<>
{/* Summary */}
<div className="grid grid-cols-1 md:grid-cols-3 gap-4 mb-6">
{STRATEGIES.map((s) => {
const r = results[s.key];
return (
<Card key={s.key}>
<CardHeader className="pb-2">
<CardTitle className="text-lg">{s.label}</CardTitle>
</CardHeader>
<CardContent>
<p className="text-sm text-muted-foreground">
: {r.base_date}
</p>
<p className="text-sm text-muted-foreground">
: {r.universe_count} / : {r.result_count}
</p>
</CardContent>
</Card>
);
})}
</div>
{/* Common stocks highlight */}
{commonTickers.size > 0 && (
<Card className="mb-6">
<CardHeader>
<CardTitle className="text-lg">
({commonTickers.size})
<span className="text-sm font-normal text-muted-foreground ml-2">
3
</span>
</CardTitle>
</CardHeader>
<CardContent>
<div className="flex flex-wrap gap-2">
{Array.from(commonTickers).map((ticker) => {
const stock = results[STRATEGIES[0].key].stocks.find(
(s) => s.ticker === ticker
);
return (
<Badge key={ticker} variant="default">
{stock?.name || ticker}
</Badge>
);
})}
</div>
</CardContent>
</Card>
)}
{/* Side-by-side tables */}
<div className="grid grid-cols-1 lg:grid-cols-3 gap-4">
{STRATEGIES.map((s) => {
const r = results[s.key];
return (
<Card key={s.key}>
<CardHeader className="pb-2">
<CardTitle className="text-base">{s.label} </CardTitle>
</CardHeader>
<CardContent className="p-0">
<div className="overflow-x-auto">
<table className="w-full text-sm">
<thead className="bg-muted">
<tr>
<th scope="col" className="px-3 py-2 text-left font-medium text-muted-foreground">#</th>
<th scope="col" className="px-3 py-2 text-left font-medium text-muted-foreground"></th>
<th scope="col" className="px-3 py-2 text-right font-medium text-muted-foreground"></th>
<th scope="col" className="px-3 py-2 text-right font-medium text-muted-foreground"></th>
</tr>
</thead>
<tbody className="divide-y divide-border">
{r.stocks.map((stock) => {
const isCommon = commonTickers.has(stock.ticker);
const isOverlap = overlapTickers.has(stock.ticker);
return (
<tr
key={stock.ticker}
className={
isCommon
? 'bg-primary/10'
: isOverlap
? 'bg-accent/50'
: 'hover:bg-muted/50'
}
>
<td className="px-3 py-2 font-medium">{stock.rank}</td>
<td className="px-3 py-2">
<span className="font-medium" title={stock.ticker}>
{stock.name || stock.ticker}
</span>
{isCommon && (
<Badge variant="default" className="ml-1 text-[10px] px-1 py-0">
</Badge>
)}
</td>
<td className="px-3 py-2 text-right tabular-nums">
{formatCurrency(stock.close_price)}
</td>
<td className="px-3 py-2 text-right font-medium tabular-nums">
{formatNumber(stock.total_score)}
</td>
</tr>
);
})}
</tbody>
</table>
</div>
</CardContent>
</Card>
);
})}
</div>
{/* Detailed comparison table */}
<Card className="mt-6">
<CardHeader>
<CardTitle> </CardTitle>
<p className="text-sm text-muted-foreground">
2
</p>
</CardHeader>
<CardContent className="p-0">
<div className="overflow-x-auto">
<table className="w-full text-sm">
<thead className="bg-muted">
<tr>
<th scope="col" className="px-4 py-3 text-left font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-right font-medium text-muted-foreground">()</th>
{STRATEGIES.map((s) => (
<th key={s.key} scope="col" className="px-4 py-3 text-center font-medium text-muted-foreground" colSpan={2}>
{s.label}
</th>
))}
</tr>
<tr className="border-t">
<th scope="col" className="px-4 py-1"></th>
<th scope="col" className="px-4 py-1"></th>
{STRATEGIES.map((s) => (
<React.Fragment key={s.key}>
<th scope="col" className="px-2 py-1 text-right text-xs text-muted-foreground"></th>
<th scope="col" className="px-2 py-1 text-right text-xs text-muted-foreground"></th>
</React.Fragment>
))}
</tr>
</thead>
<tbody className="divide-y divide-border">
{Array.from(overlapTickers).map((ticker) => {
const stockData = STRATEGIES.map((s) => {
return results[s.key].stocks.find((st) => st.ticker === ticker) || null;
});
const anyStock = stockData.find((s) => s !== null);
const isCommon = commonTickers.has(ticker);
return (
<tr key={ticker} className={isCommon ? 'bg-primary/10' : 'hover:bg-muted/50'}>
<td className="px-4 py-2 font-medium">
{anyStock?.name || ticker}
{isCommon && (
<Badge variant="default" className="ml-1 text-[10px] px-1 py-0">
</Badge>
)}
</td>
<td className="px-4 py-2 text-right tabular-nums">
{formatCurrency(anyStock?.market_cap ?? null)}
</td>
{stockData.map((stock, i) => (
<React.Fragment key={STRATEGIES[i].key}>
<td className="px-2 py-2 text-right tabular-nums">
{stock ? stock.rank : '-'}
</td>
<td className="px-2 py-2 text-right tabular-nums">
{stock ? formatNumber(stock.total_score) : '-'}
</td>
</React.Fragment>
))}
</tr>
);
})}
{overlapTickers.size === 0 && (
<tr>
<td colSpan={2 + STRATEGIES.length * 2} className="px-4 py-8 text-center text-muted-foreground">
2
</td>
</tr>
)}
</tbody>
</table>
</div>
</CardContent>
</Card>
</>
)}
</DashboardLayout>
);
}

View File

@ -9,6 +9,7 @@ import { Input } from '@/components/ui/input';
import { Label } from '@/components/ui/label';
import { Skeleton } from '@/components/ui/skeleton';
import { api } from '@/lib/api';
import { ApplyToPortfolio } from '@/components/strategy/apply-to-portfolio';
interface StockFactor {
ticker: string;
@ -191,6 +192,9 @@ export default function KJBStrategyPage() {
</table>
</div>
</CardContent>
<div className="px-4 pb-4">
<ApplyToPortfolio stocks={result.stocks.map((s) => ({ ticker: s.ticker, name: s.name }))} />
</div>
</Card>
)}
</DashboardLayout>

View File

@ -6,7 +6,9 @@ import { DashboardLayout } from '@/components/layout/dashboard-layout';
import { StrategyCard } from '@/components/strategy/strategy-card';
import { api } from '@/lib/api';
import { Skeleton } from '@/components/ui/skeleton';
import { BarChart3, Star, TrendingUp, Zap } from 'lucide-react';
import { BarChart3, Star, TrendingUp, Zap, GitCompareArrows } from 'lucide-react';
import Link from 'next/link';
import { Button } from '@/components/ui/button';
const strategies = [
{
@ -87,6 +89,12 @@ export default function StrategyListPage() {
<p className="mt-1 text-muted-foreground">
퀀
</p>
<Link href="/strategy/compare" className="inline-block mt-3">
<Button variant="outline" size="sm">
<GitCompareArrows className="h-4 w-4 mr-2" />
</Button>
</Link>
</div>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-6">

View File

@ -0,0 +1,73 @@
'use client';
import React from 'react';
import { Button } from '@/components/ui/button';
interface ErrorBoundaryProps {
children: React.ReactNode;
}
interface ErrorBoundaryState {
hasError: boolean;
}
export class ErrorBoundary extends React.Component<ErrorBoundaryProps, ErrorBoundaryState> {
constructor(props: ErrorBoundaryProps) {
super(props);
this.state = { hasError: false };
}
static getDerivedStateFromError(): ErrorBoundaryState {
return { hasError: true };
}
componentDidCatch(error: Error, errorInfo: React.ErrorInfo) {
console.error('ErrorBoundary caught an error:', error, errorInfo);
}
handleRetry = () => {
this.setState({ hasError: false });
};
render() {
if (this.state.hasError) {
return (
<div className="flex min-h-screen items-center justify-center bg-background">
<div className="text-center space-y-4 p-8">
<div className="inline-flex items-center justify-center w-16 h-16 rounded-full bg-destructive/10 mb-2">
<svg
className="w-8 h-8 text-destructive"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-2.5L13.732 4c-.77-.833-1.964-.833-2.732 0L4.082 16.5c-.77.833.192 2.5 1.732 2.5z"
/>
</svg>
</div>
<h2 className="text-xl font-semibold text-foreground">
</h2>
<p className="text-muted-foreground">
.
</p>
<div className="flex gap-3 justify-center">
<Button onClick={this.handleRetry} variant="outline">
</Button>
<Button onClick={() => window.location.reload()}>
</Button>
</div>
</div>
</div>
);
}
return this.props.children;
}
}

View File

@ -45,8 +45,8 @@ export function NewHeader({ username, onMenuClick, showMenuButton = false }: New
const router = useRouter();
const pageTitle = getPageTitle(pathname);
const handleLogout = () => {
api.logout();
const handleLogout = async () => {
await api.logout();
router.push('/login');
};

View File

@ -1,9 +1,26 @@
'use client';
import { useEffect, useState } from 'react';
import { useRouter } from 'next/navigation';
import { Button } from '@/components/ui/button';
import { Label } from '@/components/ui/label';
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogDescription,
DialogFooter,
} from '@/components/ui/dialog';
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select';
import { api } from '@/lib/api';
import { toast } from 'sonner';
interface Portfolio {
id: number;
@ -21,19 +38,18 @@ interface ApplyToPortfolioProps {
}
export function ApplyToPortfolio({ stocks }: ApplyToPortfolioProps) {
const router = useRouter();
const [portfolios, setPortfolios] = useState<Portfolio[]>([]);
const [selectedId, setSelectedId] = useState<number | null>(null);
const [showConfirm, setShowConfirm] = useState(false);
const [selectedId, setSelectedId] = useState<string>('');
const [dialogOpen, setDialogOpen] = useState(false);
const [applying, setApplying] = useState(false);
const [error, setError] = useState<string | null>(null);
const [success, setSuccess] = useState(false);
useEffect(() => {
const load = async () => {
try {
const data = await api.get<Portfolio[]>('/api/portfolios');
setPortfolios(data);
if (data.length > 0) setSelectedId(data[0].id);
if (data.length > 0) setSelectedId(String(data[0].id));
} catch {
// ignore
}
@ -41,10 +57,10 @@ export function ApplyToPortfolio({ stocks }: ApplyToPortfolioProps) {
load();
}, []);
const apply = async () => {
if (!selectedId || stocks.length === 0) return;
const handleApply = async () => {
const portfolioId = Number(selectedId);
if (!portfolioId || stocks.length === 0) return;
setApplying(true);
setError(null);
try {
const ratio = parseFloat((100 / stocks.length).toFixed(2));
const targets: TargetItem[] = stocks.map((s, i) => ({
@ -54,12 +70,16 @@ export function ApplyToPortfolio({ stocks }: ApplyToPortfolioProps) {
: ratio,
}));
await api.put(`/api/portfolios/${selectedId}/targets`, targets);
setShowConfirm(false);
setSuccess(true);
setTimeout(() => setSuccess(false), 3000);
await api.put(`/api/portfolios/${portfolioId}/targets`, targets);
setDialogOpen(false);
toast.success('목표 배분이 적용되었습니다.', {
action: {
label: '리밸런싱으로 이동',
onClick: () => router.push(`/portfolio/${portfolioId}/rebalance`),
},
});
} catch (err) {
setError(err instanceof Error ? err.message : '적용 실패');
toast.error(err instanceof Error ? err.message : '목표 배분 적용 실패했습니다.');
} finally {
setApplying(false);
}
@ -68,71 +88,58 @@ export function ApplyToPortfolio({ stocks }: ApplyToPortfolioProps) {
if (portfolios.length === 0) return null;
return (
<div className="mt-4">
<div className="flex items-end gap-3">
<div>
<Label htmlFor="portfolio-select"> </Label>
<select
id="portfolio-select"
className="mt-1 block w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
value={selectedId ?? ''}
onChange={(e) => setSelectedId(Number(e.target.value))}
>
{portfolios.map((p) => (
<option key={p.id} value={p.id}>
{p.name} ({p.portfolio_type === 'pension' ? '퇴직연금' : '일반'})
</option>
))}
</select>
</div>
<Button onClick={() => setShowConfirm(true)}>
</Button>
</div>
<>
<Button className="mt-4" onClick={() => setDialogOpen(true)}>
</Button>
{success && (
<div className="mt-2 text-sm text-green-600 dark:text-green-400">
.
</div>
)}
<Dialog open={dialogOpen} onOpenChange={setDialogOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle> </DialogTitle>
<DialogDescription>
.
{stocks.length} ({(100 / stocks.length).toFixed(2)}%) .
</DialogDescription>
</DialogHeader>
{showConfirm && (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/50">
<div className="bg-background rounded-lg shadow-lg max-w-md w-full mx-4 max-h-[80vh] overflow-y-auto">
<div className="p-6">
<h2 className="text-lg font-bold mb-2"> </h2>
<p className="text-sm text-muted-foreground mb-4">
.
{stocks.length} ({(100 / stocks.length).toFixed(2)}%) .
</p>
<div className="space-y-4">
<div className="space-y-2">
<Label> </Label>
<Select value={selectedId} onValueChange={setSelectedId}>
<SelectTrigger>
<SelectValue placeholder="포트폴리오를 선택하세요" />
</SelectTrigger>
<SelectContent>
{portfolios.map((p) => (
<SelectItem key={p.id} value={String(p.id)}>
{p.name} ({p.portfolio_type === 'pension' ? '퇴직연금' : '일반'})
</SelectItem>
))}
</SelectContent>
</Select>
</div>
{error && (
<div className="bg-destructive/10 border border-destructive text-destructive px-3 py-2 rounded mb-4 text-sm">
{error}
<div className="max-h-48 overflow-y-auto border rounded p-2">
{stocks.map((s) => (
<div key={s.ticker} className="text-sm py-1 flex justify-between">
<span>{s.name || s.ticker}</span>
<span className="text-muted-foreground">{(100 / stocks.length).toFixed(2)}%</span>
</div>
)}
<div className="max-h-48 overflow-y-auto mb-4 border rounded p-2">
{stocks.map((s) => (
<div key={s.ticker} className="text-sm py-1 flex justify-between">
<span>{s.name || s.ticker}</span>
<span className="text-muted-foreground">{(100 / stocks.length).toFixed(2)}%</span>
</div>
))}
</div>
<div className="flex justify-end gap-2">
<Button variant="outline" onClick={() => setShowConfirm(false)}>
</Button>
<Button onClick={apply} disabled={applying}>
{applying ? '적용 중...' : '적용'}
</Button>
</div>
))}
</div>
</div>
</div>
)}
</div>
<DialogFooter>
<Button variant="outline" onClick={() => setDialogOpen(false)}>
</Button>
<Button onClick={handleApply} disabled={applying || !selectedId}>
{applying ? '적용 중...' : '적용'}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</>
);
}

View File

@ -10,27 +10,9 @@ async function hashPassword(password: string): Promise<string> {
class ApiClient {
private baseUrl: string;
private token: string | null = null;
constructor(baseUrl: string) {
this.baseUrl = baseUrl;
if (typeof window !== 'undefined') {
this.token = localStorage.getItem('token');
}
}
setToken(token: string) {
this.token = token;
if (typeof window !== 'undefined') {
localStorage.setItem('token', token);
}
}
clearToken() {
this.token = null;
if (typeof window !== 'undefined') {
localStorage.removeItem('token');
}
}
private async request<T>(
@ -42,13 +24,10 @@ class ApiClient {
...options.headers,
};
if (this.token) {
(headers as Record<string, string>)['Authorization'] = `Bearer ${this.token}`;
}
const response = await fetch(`${this.baseUrl}${endpoint}`, {
...options,
headers,
credentials: 'include',
});
if (!response.ok) {
@ -89,6 +68,7 @@ class ApiClient {
headers: {
'Content-Type': 'application/json',
},
credentials: 'include',
body: JSON.stringify({ username, password: hashedPassword }),
});
@ -97,13 +77,11 @@ class ApiClient {
throw new Error(error.detail || 'Login failed');
}
const data = await response.json();
this.setToken(data.access_token);
return data;
return response.json();
}
logout() {
this.clearToken();
async logout() {
await this.post('/api/auth/logout');
}
async getCurrentUser() {