feat: add CostGuard for API cost tracking and limiting
This commit is contained in:
parent
0c4c22be5a
commit
edeb336cb8
125
agent/cost_guard.py
Normal file
125
agent/cost_guard.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
"""API 비용 추적 및 제한.
|
||||||
|
|
||||||
|
Anthropic API 응답의 usage 필드에서 토큰 비용을 추적하고,
|
||||||
|
일일/작업당 한도를 초과하면 작업을 차단한다.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timezone, date
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_CREATE_TABLE = """
|
||||||
|
CREATE TABLE IF NOT EXISTS cost_records (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
task_id TEXT NOT NULL,
|
||||||
|
tokens_input INTEGER NOT NULL,
|
||||||
|
tokens_output INTEGER NOT NULL,
|
||||||
|
cost_usd REAL NOT NULL,
|
||||||
|
recorded_at TEXT NOT NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_INPUT_COST_PER_TOKEN = 3.0 / 1_000_000
|
||||||
|
DEFAULT_OUTPUT_COST_PER_TOKEN = 15.0 / 1_000_000
|
||||||
|
|
||||||
|
|
||||||
|
class CostGuard:
|
||||||
|
"""API 비용 추적 및 제한."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_path: str = "/data/cost_guard.db",
|
||||||
|
daily_limit: float = 10.0,
|
||||||
|
per_task_limit: float = 3.0,
|
||||||
|
input_cost_per_token: float = DEFAULT_INPUT_COST_PER_TOKEN,
|
||||||
|
output_cost_per_token: float = DEFAULT_OUTPUT_COST_PER_TOKEN,
|
||||||
|
):
|
||||||
|
self._db_path = db_path
|
||||||
|
self._daily_limit = daily_limit
|
||||||
|
self._per_task_limit = per_task_limit
|
||||||
|
self._input_cost = input_cost_per_token
|
||||||
|
self._output_cost = output_cost_per_token
|
||||||
|
self._db: aiosqlite.Connection | None = None
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
self._db = await aiosqlite.connect(self._db_path)
|
||||||
|
await self._db.execute(_CREATE_TABLE)
|
||||||
|
await self._db.commit()
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self._db:
|
||||||
|
await self._db.close()
|
||||||
|
|
||||||
|
def calculate_cost(self, tokens_input: int, tokens_output: int) -> float:
|
||||||
|
return tokens_input * self._input_cost + tokens_output * self._output_cost
|
||||||
|
|
||||||
|
async def record_usage(self, task_id: str, tokens_input: int, tokens_output: int) -> float:
|
||||||
|
cost = self.calculate_cost(tokens_input, tokens_output)
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
await self._db.execute(
|
||||||
|
"INSERT INTO cost_records (task_id, tokens_input, tokens_output, cost_usd, recorded_at) VALUES (?, ?, ?, ?, ?)",
|
||||||
|
(task_id, tokens_input, tokens_output, cost, now),
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info("Recorded cost $%.4f for task %s (in=%d, out=%d)", cost, task_id, tokens_input, tokens_output)
|
||||||
|
return cost
|
||||||
|
|
||||||
|
async def get_daily_cost(self) -> float:
|
||||||
|
today = date.today().isoformat()
|
||||||
|
cursor = await self._db.execute(
|
||||||
|
"SELECT COALESCE(SUM(cost_usd), 0) FROM cost_records WHERE recorded_at >= ?", (today,),
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
return row[0]
|
||||||
|
|
||||||
|
async def get_task_cost(self, task_id: str) -> float:
|
||||||
|
cursor = await self._db.execute(
|
||||||
|
"SELECT COALESCE(SUM(cost_usd), 0) FROM cost_records WHERE task_id = ?", (task_id,),
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
return row[0]
|
||||||
|
|
||||||
|
async def check_daily_limit(self) -> bool:
|
||||||
|
daily = await self.get_daily_cost()
|
||||||
|
return daily < self._daily_limit
|
||||||
|
|
||||||
|
async def check_task_limit(self, task_id: str) -> bool:
|
||||||
|
task_cost = await self.get_task_cost(task_id)
|
||||||
|
return task_cost < self._per_task_limit
|
||||||
|
|
||||||
|
async def get_daily_summary(self) -> dict:
|
||||||
|
today = date.today().isoformat()
|
||||||
|
cursor = await self._db.execute(
|
||||||
|
"SELECT COUNT(*), COALESCE(SUM(cost_usd), 0), COALESCE(SUM(tokens_input), 0), COALESCE(SUM(tokens_output), 0) FROM cost_records WHERE recorded_at >= ?",
|
||||||
|
(today,),
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
total_cost = row[1]
|
||||||
|
return {
|
||||||
|
"record_count": row[0],
|
||||||
|
"total_cost_usd": round(total_cost, 4),
|
||||||
|
"daily_limit_usd": self._daily_limit,
|
||||||
|
"remaining_usd": round(max(0, self._daily_limit - total_cost), 4),
|
||||||
|
"total_tokens_input": row[2],
|
||||||
|
"total_tokens_output": row[3],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
_guard: CostGuard | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_cost_guard() -> CostGuard:
|
||||||
|
global _guard
|
||||||
|
if _guard is None:
|
||||||
|
db_path = os.environ.get("COST_GUARD_DB", "/data/cost_guard.db")
|
||||||
|
daily_limit = float(os.environ.get("DAILY_COST_LIMIT_USD", "10.0"))
|
||||||
|
per_task_limit = float(os.environ.get("PER_TASK_COST_LIMIT_USD", "3.0"))
|
||||||
|
_guard = CostGuard(db_path=db_path, daily_limit=daily_limit, per_task_limit=per_task_limit)
|
||||||
|
await _guard.initialize()
|
||||||
|
return _guard
|
||||||
91
tests/test_cost_guard.py
Normal file
91
tests/test_cost_guard.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from agent.cost_guard import CostGuard
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def cost_guard():
|
||||||
|
fd, db_path = tempfile.mkstemp(suffix=".db")
|
||||||
|
os.close(fd)
|
||||||
|
guard = CostGuard(db_path=db_path, daily_limit=10.0, per_task_limit=3.0)
|
||||||
|
await guard.initialize()
|
||||||
|
yield guard
|
||||||
|
await guard.close()
|
||||||
|
os.unlink(db_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_usage(cost_guard):
|
||||||
|
"""API 사용량을 기록한다."""
|
||||||
|
await cost_guard.record_usage(
|
||||||
|
task_id="task-1",
|
||||||
|
tokens_input=1000,
|
||||||
|
tokens_output=500,
|
||||||
|
)
|
||||||
|
daily = await cost_guard.get_daily_cost()
|
||||||
|
assert daily > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_calculate_cost(cost_guard):
|
||||||
|
"""토큰에서 비용을 계산한다."""
|
||||||
|
cost = cost_guard.calculate_cost(tokens_input=1_000_000, tokens_output=1_000_000)
|
||||||
|
# input: $3/MTok + output: $15/MTok = $18
|
||||||
|
assert cost == pytest.approx(18.0, rel=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_daily_limit_ok(cost_guard):
|
||||||
|
"""일일 한도 내에서 True를 반환한다."""
|
||||||
|
result = await cost_guard.check_daily_limit()
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_daily_limit_exceeded(cost_guard):
|
||||||
|
"""일일 한도 초과 시 False를 반환한다."""
|
||||||
|
for i in range(5):
|
||||||
|
await cost_guard.record_usage(f"task-{i}", tokens_input=1_000_000, tokens_output=200_000)
|
||||||
|
result = await cost_guard.check_daily_limit()
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_task_limit_ok(cost_guard):
|
||||||
|
"""작업당 한도 내에서 True를 반환한다."""
|
||||||
|
await cost_guard.record_usage("task-1", tokens_input=100, tokens_output=50)
|
||||||
|
result = await cost_guard.check_task_limit("task-1")
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_task_limit_exceeded(cost_guard):
|
||||||
|
"""작업당 한도 초과 시 False를 반환한다."""
|
||||||
|
await cost_guard.record_usage("task-1", tokens_input=1_000_000, tokens_output=100_000)
|
||||||
|
result = await cost_guard.check_task_limit("task-1")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_task_cost(cost_guard):
|
||||||
|
"""특정 작업의 누적 비용을 반환한다."""
|
||||||
|
await cost_guard.record_usage("task-1", tokens_input=1000, tokens_output=500)
|
||||||
|
await cost_guard.record_usage("task-1", tokens_input=2000, tokens_output=1000)
|
||||||
|
cost = await cost_guard.get_task_cost("task-1")
|
||||||
|
assert cost > 0
|
||||||
|
other_cost = await cost_guard.get_task_cost("task-2")
|
||||||
|
assert other_cost == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_daily_summary(cost_guard):
|
||||||
|
"""일일 요약 정보를 반환한다."""
|
||||||
|
await cost_guard.record_usage("task-1", tokens_input=1000, tokens_output=500)
|
||||||
|
summary = await cost_guard.get_daily_summary()
|
||||||
|
assert "total_cost_usd" in summary
|
||||||
|
assert "daily_limit_usd" in summary
|
||||||
|
assert "remaining_usd" in summary
|
||||||
|
assert summary["record_count"] == 1
|
||||||
Loading…
x
Reference in New Issue
Block a user