"""
Trading journal API endpoints.
"""
from datetime import date
from decimal import Decimal
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from sqlalchemy import func

from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.journal import TradeJournal, JournalStatus
from app.schemas.journal import (
    TradeJournalCreate,
    TradeJournalUpdate,
    TradeJournalResponse,
    TradeJournalStats,
)

router = APIRouter(prefix="/api/journal", tags=["journal"])

# NOTE: the endpoints below are plain ``def`` (not ``async def``) because they
# make blocking calls on a synchronous SQLAlchemy ``Session``. FastAPI runs
# ``def`` endpoints in its threadpool; inside ``async def`` these calls would
# block the event loop for every other in-flight request.

# Update fields whose change can invalidate an already-computed profit/loss.
_PNL_INPUT_FIELDS = frozenset({"entry_price", "exit_price", "quantity", "trade_type"})


def _get_owned_journal(db: Session, journal_id: int, user_id) -> TradeJournal:
    """Return journal ``journal_id`` if it belongs to ``user_id``, else raise 404.

    Filtering on both id and owner means another user's journal is reported as
    "not found" instead of revealing that it exists.

    Raises:
        HTTPException: 404 when no matching journal exists for this user.
    """
    journal = (
        db.query(TradeJournal)
        .filter(TradeJournal.id == journal_id, TradeJournal.user_id == user_id)
        .first()
    )
    if journal is None:
        raise HTTPException(status_code=404, detail="Journal not found")
    return journal


def _calculate_pnl(trade_type, entry_price, exit_price, quantity):
    """Return ``(profit_loss, profit_loss_pct)`` for a single trade.

    A ``"buy"`` trade gains when exit is above entry; any other trade type is
    treated as the opposite side. A missing (or zero) quantity is counted as
    one unit. The percentage is ``None`` when ``entry_price`` is zero, because
    it is undefined there (previously this raised and produced a 500).
    """
    # trade_type is normally an enum member; tolerate a raw string as well.
    side = getattr(trade_type, "value", trade_type)
    if side == "buy":
        per_unit = exit_price - entry_price
    else:
        per_unit = entry_price - exit_price
    profit_loss = per_unit * (quantity or 1)
    profit_loss_pct = per_unit / entry_price * 100 if entry_price else None
    return profit_loss, profit_loss_pct


@router.post("", response_model=TradeJournalResponse, status_code=201)
def create_journal(
    data: TradeJournalCreate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Create a journal entry owned by the current user from the request body."""
    journal = TradeJournal(
        user_id=current_user.id,
        **data.model_dump(),
    )
    db.add(journal)
    db.commit()
    db.refresh(journal)
    return journal


@router.get("", response_model=List[TradeJournalResponse])
def list_journals(
    current_user: CurrentUser,
    db: Session = Depends(get_db),
    status: Optional[str] = Query(None),
    stock_code: Optional[str] = Query(None),
    start_date: Optional[date] = Query(None),
    end_date: Optional[date] = Query(None),
    skip: int = Query(0, ge=0),
    limit: int = Query(50, ge=1, le=200),
):
    """List the current user's journals, newest entry date first.

    Optional filters: exact ``status`` and ``stock_code``, and an inclusive
    ``entry_date`` range (``start_date`` .. ``end_date``). Ties on entry date
    are ordered by newest ``created_at``. Paginated with ``skip`` / ``limit``.
    """
    query = db.query(TradeJournal).filter(TradeJournal.user_id == current_user.id)
    if status:
        # NOTE(review): raw string compared against the status column; presumably
        # it must match the stored enum representation — confirm with the model.
        query = query.filter(TradeJournal.status == status)
    if stock_code:
        query = query.filter(TradeJournal.stock_code == stock_code)
    if start_date:
        query = query.filter(TradeJournal.entry_date >= start_date)
    if end_date:
        query = query.filter(TradeJournal.entry_date <= end_date)

    return (
        query.order_by(TradeJournal.entry_date.desc(), TradeJournal.created_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )


# Must be registered before "/{journal_id}" so "stats" is not parsed as an id.
@router.get("/stats", response_model=TradeJournalStats)
def get_journal_stats(
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Summarize the current user's trading performance.

    Win/loss figures consider only CLOSED trades with a recorded
    ``profit_loss_pct``; a trade at exactly 0 % counts as a loss. Ratio and
    extreme fields are ``None`` when no such trade exists, and
    ``total_profit_loss`` is ``None`` when no closed trade has ``profit_loss``.
    """
    # Only three columns are needed, so avoid materializing full ORM objects.
    rows = (
        db.query(TradeJournal.status, TradeJournal.profit_loss, TradeJournal.profit_loss_pct)
        .filter(TradeJournal.user_id == current_user.id)
        .all()
    )

    total = len(rows)
    open_trades = sum(1 for row in rows if row.status == JournalStatus.OPEN)
    closed = [row for row in rows if row.status == JournalStatus.CLOSED]

    pcts = [row.profit_loss_pct for row in closed if row.profit_loss_pct is not None]
    amounts = [row.profit_loss for row in closed if row.profit_loss is not None]

    win_count = sum(1 for pct in pcts if pct > 0)
    loss_count = sum(1 for pct in pcts if pct <= 0)

    win_rate = avg_pnl_pct = max_profit = max_loss = None
    if pcts:
        win_rate = Decimal(win_count) / Decimal(len(pcts)) * 100
        avg_pnl_pct = sum(pcts) / len(pcts)
        max_profit = max(pcts)
        max_loss = min(pcts)

    total_pnl = sum(amounts) if amounts else None

    return TradeJournalStats(
        total_trades=total,
        open_trades=open_trades,
        closed_trades=len(closed),
        win_count=win_count,
        loss_count=loss_count,
        win_rate=win_rate,
        avg_profit_loss_pct=avg_pnl_pct,
        max_profit_pct=max_profit,
        max_loss_pct=max_loss,
        total_profit_loss=total_pnl,
    )


@router.get("/{journal_id}", response_model=TradeJournalResponse)
def get_journal(
    journal_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return one of the current user's journals; 404 if missing or not owned."""
    return _get_owned_journal(db, journal_id, current_user.id)


@router.put("/{journal_id}", response_model=TradeJournalResponse)
def update_journal(
    journal_id: int,
    data: TradeJournalUpdate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Partially update one of the current user's journals.

    Only fields the client actually sent are written. Afterwards:

    * profit/loss is recomputed from the entry's *current* values whenever
      an input to it (entry/exit price, quantity, trade type) changed, so an
      edit to e.g. the entry price of a closed trade no longer leaves it stale;
    * an OPEN entry is switched to CLOSED once it has both an exit price and
      an exit date, even when those arrived in separate requests.

    Raises:
        HTTPException: 404 when the journal is missing or not owned.
    """
    journal = _get_owned_journal(db, journal_id, current_user.id)

    update_data = data.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(journal, field, value)

    # Auto-calculate profit/loss from the merged (stored + updated) values.
    if (
        not _PNL_INPUT_FIELDS.isdisjoint(update_data)
        and journal.exit_price is not None
        and journal.entry_price is not None
        and journal.trade_type is not None
    ):
        journal.profit_loss, journal.profit_loss_pct = _calculate_pnl(
            journal.trade_type,
            journal.entry_price,
            journal.exit_price,
            journal.quantity,
        )

    # Auto-close when this request completed the exit information.
    exit_touched = "exit_price" in update_data or "exit_date" in update_data
    if (
        exit_touched
        and journal.status == JournalStatus.OPEN
        and journal.exit_price is not None
        and journal.exit_date is not None
    ):
        journal.status = JournalStatus.CLOSED

    db.commit()
    db.refresh(journal)
    return journal