From edeb336cb854ae571a2f77963a5efe67a370b560 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=EB=A8=B8=EB=8B=88=ED=8E=98=EB=8B=88?=
Date: Fri, 20 Mar 2026 18:40:53 +0900
Subject: [PATCH] feat: add CostGuard for API cost tracking and limiting

---
 agent/cost_guard.py      | 178 +++++++++++++++++++++++++++++++++++++++
 tests/test_cost_guard.py | 115 +++++++++++++++++++++++++
 2 files changed, 293 insertions(+)
 create mode 100644 agent/cost_guard.py
 create mode 100644 tests/test_cost_guard.py

diff --git a/agent/cost_guard.py b/agent/cost_guard.py
new file mode 100644
index 0000000..ea0e030
--- /dev/null
+++ b/agent/cost_guard.py
@@ -0,0 +1,178 @@
+"""Track and cap API spend.
+
+Token usage reported in the ``usage`` field of Anthropic API responses is
+converted to a dollar cost and persisted in SQLite. Callers consult the daily
+and per-task checks to decide whether further work should be blocked.
+"""
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+from datetime import datetime, timezone
+
+import aiosqlite
+
+logger = logging.getLogger(__name__)
+
+_CREATE_TABLE = """
+CREATE TABLE IF NOT EXISTS cost_records (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    task_id TEXT NOT NULL,
+    tokens_input INTEGER NOT NULL,
+    tokens_output INTEGER NOT NULL,
+    cost_usd REAL NOT NULL,
+    recorded_at TEXT NOT NULL
+)
+"""
+
+# Default prices in USD per token ($3 / $15 per million input / output tokens).
+DEFAULT_INPUT_COST_PER_TOKEN = 3.0 / 1_000_000
+DEFAULT_OUTPUT_COST_PER_TOKEN = 15.0 / 1_000_000
+
+
+def _utc_day_start() -> str:
+    """Return today's UTC date as ``YYYY-MM-DD``.
+
+    ``recorded_at`` holds UTC ISO-8601 timestamps, so the daily window must be
+    computed in UTC as well; the local date would drop or add records whenever
+    the host timezone is on a different calendar day than UTC.
+    """
+    return datetime.now(timezone.utc).date().isoformat()
+
+
+class CostGuard:
+    """Record per-task API cost in SQLite and enforce spending limits."""
+
+    def __init__(
+        self,
+        db_path: str = "/data/cost_guard.db",
+        daily_limit: float = 10.0,
+        per_task_limit: float = 3.0,
+        input_cost_per_token: float = DEFAULT_INPUT_COST_PER_TOKEN,
+        output_cost_per_token: float = DEFAULT_OUTPUT_COST_PER_TOKEN,
+    ):
+        self._db_path = db_path
+        self._daily_limit = daily_limit
+        self._per_task_limit = per_task_limit
+        self._input_cost = input_cost_per_token
+        self._output_cost = output_cost_per_token
+        self._db: aiosqlite.Connection | None = None
+
+    def _conn(self) -> aiosqlite.Connection:
+        """Return the open connection, or raise if ``initialize()`` has not run."""
+        if self._db is None:
+            raise RuntimeError("CostGuard is not initialized; call initialize() first")
+        return self._db
+
+    async def initialize(self) -> None:
+        """Open the database and create the table; a repeated call is a no-op."""
+        if self._db is not None:
+            return
+        db = await aiosqlite.connect(self._db_path)
+        try:
+            await db.execute(_CREATE_TABLE)
+            await db.commit()
+        except BaseException:
+            # Do not leak the connection when schema setup fails or is cancelled.
+            await db.close()
+            raise
+        self._db = db
+
+    async def close(self) -> None:
+        """Close the connection if one is open; safe to call more than once."""
+        db, self._db = self._db, None
+        if db is not None:
+            await db.close()
+
+    def calculate_cost(self, tokens_input: int, tokens_output: int) -> float:
+        """Return the USD cost of the given input and output token counts."""
+        return tokens_input * self._input_cost + tokens_output * self._output_cost
+
+    async def record_usage(self, task_id: str, tokens_input: int, tokens_output: int) -> float:
+        """Persist one usage record for ``task_id`` and return its USD cost."""
+        db = self._conn()
+        cost = self.calculate_cost(tokens_input, tokens_output)
+        now = datetime.now(timezone.utc).isoformat()
+        await db.execute(
+            "INSERT INTO cost_records (task_id, tokens_input, tokens_output, cost_usd, recorded_at) VALUES (?, ?, ?, ?, ?)",
+            (task_id, tokens_input, tokens_output, cost, now),
+        )
+        await db.commit()
+        logger.info("Recorded cost $%.4f for task %s (in=%d, out=%d)", cost, task_id, tokens_input, tokens_output)
+        return cost
+
+    async def get_daily_cost(self) -> float:
+        """Return the total USD cost recorded since 00:00 UTC today."""
+        cursor = await self._conn().execute(
+            "SELECT COALESCE(SUM(cost_usd), 0) FROM cost_records WHERE recorded_at >= ?",
+            (_utc_day_start(),),
+        )
+        row = await cursor.fetchone()
+        return float(row[0])
+
+    async def get_task_cost(self, task_id: str) -> float:
+        """Return the total USD cost recorded for ``task_id`` across all days."""
+        cursor = await self._conn().execute(
+            "SELECT COALESCE(SUM(cost_usd), 0) FROM cost_records WHERE task_id = ?",
+            (task_id,),
+        )
+        row = await cursor.fetchone()
+        return float(row[0])
+
+    async def check_daily_limit(self) -> bool:
+        """Return True while today's cost is still below the daily limit."""
+        daily = await self.get_daily_cost()
+        return daily < self._daily_limit
+
+    async def check_task_limit(self, task_id: str) -> bool:
+        """Return True while the task's cost is still below the per-task limit."""
+        task_cost = await self.get_task_cost(task_id)
+        return task_cost < self._per_task_limit
+
+    async def get_daily_summary(self) -> dict:
+        """Return today's (UTC) record count, token totals, cost and remaining budget."""
+        cursor = await self._conn().execute(
+            "SELECT COUNT(*), COALESCE(SUM(cost_usd), 0), COALESCE(SUM(tokens_input), 0), COALESCE(SUM(tokens_output), 0) FROM cost_records WHERE recorded_at >= ?",
+            (_utc_day_start(),),
+        )
+        record_count, total_cost, tokens_in, tokens_out = await cursor.fetchone()
+        return {
+            "record_count": record_count,
+            "total_cost_usd": round(total_cost, 4),
+            "daily_limit_usd": self._daily_limit,
+            # Clamp at zero so an overspent day never reports a negative budget.
+            "remaining_usd": round(max(0, self._daily_limit - total_cost), 4),
+            "total_tokens_input": tokens_in,
+            "total_tokens_output": tokens_out,
+        }
+
+
+_guard: CostGuard | None = None
+_guard_lock: asyncio.Lock | None = None
+
+
+async def get_cost_guard() -> CostGuard:
+    """Return the process-wide CostGuard, creating it from the environment once.
+
+    Reads ``COST_GUARD_DB``, ``DAILY_COST_LIMIT_USD`` and
+    ``PER_TASK_COST_LIMIT_USD``. The instance is published only after
+    ``initialize()`` succeeds, and creation is serialized by a lock, so
+    concurrent first callers never receive a guard without a connection.
+    """
+    global _guard, _guard_lock
+    if _guard is not None:
+        return _guard
+    if _guard_lock is None:
+        # Created lazily so the lock is made inside a running event loop.
+        _guard_lock = asyncio.Lock()
+    async with _guard_lock:
+        if _guard is None:
+            guard = CostGuard(
+                db_path=os.environ.get("COST_GUARD_DB", "/data/cost_guard.db"),
+                daily_limit=float(os.environ.get("DAILY_COST_LIMIT_USD", "10.0")),
+                per_task_limit=float(os.environ.get("PER_TASK_COST_LIMIT_USD", "3.0")),
+            )
+            await guard.initialize()
+            _guard = guard
+    return _guard
diff --git a/tests/test_cost_guard.py b/tests/test_cost_guard.py
new file mode 100644
index 0000000..633b01a
--- /dev/null
+++ b/tests/test_cost_guard.py
@@ -0,0 +1,115 @@
+import os
+import tempfile
+from datetime import datetime, timedelta, timezone
+
+import pytest
+
+from agent.cost_guard import CostGuard
+
+
+@pytest.fixture
+async def cost_guard():
+    """Yield a CostGuard backed by a throwaway SQLite file."""
+    fd, db_path = tempfile.mkstemp(suffix=".db")
+    os.close(fd)
+    guard = CostGuard(db_path=db_path, daily_limit=10.0, per_task_limit=3.0)
+    await guard.initialize()
+    yield guard
+    await guard.close()
+    os.unlink(db_path)
+
+
+@pytest.mark.asyncio
+async def test_record_usage(cost_guard):
+    """Recording usage makes it show up in the daily cost."""
+    await cost_guard.record_usage(
+        task_id="task-1",
+        tokens_input=1000,
+        tokens_output=500,
+    )
+    daily = await cost_guard.get_daily_cost()
+    assert daily > 0
+
+
+@pytest.mark.asyncio
+async def test_calculate_cost(cost_guard):
+    """Cost is derived from the token counts."""
+    cost = cost_guard.calculate_cost(tokens_input=1_000_000, tokens_output=1_000_000)
+    # input: $3/MTok + output: $15/MTok = $18
+    assert cost == pytest.approx(18.0, rel=0.01)
+
+
+@pytest.mark.asyncio
+async def test_check_daily_limit_ok(cost_guard):
+    """Returns True while under the daily limit."""
+    result = await cost_guard.check_daily_limit()
+    assert result is True
+
+
+@pytest.mark.asyncio
+async def test_check_daily_limit_exceeded(cost_guard):
+    """Returns False once the daily limit is exceeded."""
+    for i in range(5):
+        await cost_guard.record_usage(f"task-{i}", tokens_input=1_000_000, tokens_output=200_000)
+    result = await cost_guard.check_daily_limit()
+    assert result is False
+
+
+@pytest.mark.asyncio
+async def test_check_task_limit_ok(cost_guard):
+    """Returns True while under the per-task limit."""
+    await cost_guard.record_usage("task-1", tokens_input=100, tokens_output=50)
+    result = await cost_guard.check_task_limit("task-1")
+    assert result is True
+
+
+@pytest.mark.asyncio
+async def test_check_task_limit_exceeded(cost_guard):
+    """Returns False once the per-task limit is exceeded."""
+    await cost_guard.record_usage("task-1", tokens_input=1_000_000, tokens_output=100_000)
+    result = await cost_guard.check_task_limit("task-1")
+    assert result is False
+
+
+@pytest.mark.asyncio
+async def test_get_task_cost(cost_guard):
+    """Returns the accumulated cost of one specific task."""
+    await cost_guard.record_usage("task-1", tokens_input=1000, tokens_output=500)
+    await cost_guard.record_usage("task-1", tokens_input=2000, tokens_output=1000)
+    cost = await cost_guard.get_task_cost("task-1")
+    assert cost > 0
+    other_cost = await cost_guard.get_task_cost("task-2")
+    assert other_cost == 0.0
+
+
+@pytest.mark.asyncio
+async def test_get_daily_summary(cost_guard):
+    """Returns the daily summary fields."""
+    await cost_guard.record_usage("task-1", tokens_input=1000, tokens_output=500)
+    summary = await cost_guard.get_daily_summary()
+    assert "total_cost_usd" in summary
+    assert "daily_limit_usd" in summary
+    assert "remaining_usd" in summary
+    assert summary["record_count"] == 1
+
+
+@pytest.mark.asyncio
+async def test_daily_cost_excludes_previous_utc_day(cost_guard):
+    """Only records stamped with today's UTC date count toward the daily cost."""
+    yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).isoformat()
+    await cost_guard._db.execute(
+        "INSERT INTO cost_records (task_id, tokens_input, tokens_output, cost_usd, recorded_at) VALUES (?, ?, ?, ?, ?)",
+        ("old-task", 1, 1, 5.0, yesterday),
+    )
+    await cost_guard._db.commit()
+    expected = await cost_guard.record_usage("task-1", tokens_input=1000, tokens_output=500)
+    daily = await cost_guard.get_daily_cost()
+    assert daily == pytest.approx(expected)
+
+
+@pytest.mark.asyncio
+async def test_use_before_initialize_raises():
+    """Querying before initialize() fails with a clear RuntimeError."""
+    guard = CostGuard(db_path=":memory:")
+    with pytest.raises(RuntimeError):
+        await guard.get_daily_cost()