galaxis-agent/agent/cost_guard.py

126 lines
4.4 KiB
Python
Raw Normal View History

"""API 비용 추적 및 제한.
Anthropic API 응답의 usage 필드에서 토큰 비용을 추적하고,
일일/작업당 한도를 초과하면 작업을 차단한다.
"""
from __future__ import annotations
import json
import logging
import os
from datetime import datetime, timezone, date
import aiosqlite
logger = logging.getLogger(__name__)

# DDL for the usage ledger: one row per recorded API call holding the task id,
# input/output token counts, the computed USD cost and a timestamp string.
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS cost_records (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_id TEXT NOT NULL,
tokens_input INTEGER NOT NULL,
tokens_output INTEGER NOT NULL,
cost_usd REAL NOT NULL,
recorded_at TEXT NOT NULL
)
"""

# Default prices expressed in USD per single token:
# $3 per million input tokens and $15 per million output tokens.
DEFAULT_INPUT_COST_PER_TOKEN = 3.0 / 1e6
DEFAULT_OUTPUT_COST_PER_TOKEN = 15.0 / 1e6
class CostGuard:
    """Track API spend and enforce daily and per-task cost limits.

    Each call to :meth:`record_usage` appends a row to the ``cost_records``
    SQLite table through an aiosqlite connection. Cost is computed from the
    input/output token counts at flat per-token USD prices.

    "Daily" aggregates cover rows recorded since midnight UTC. This matches
    the UTC timestamps that :meth:`record_usage` writes, so the day boundary
    does not depend on the host's local timezone.

    :meth:`initialize` must be awaited before any database-backed method.
    """

    def __init__(
        self,
        db_path: str = "/data/cost_guard.db",
        daily_limit: float = 10.0,
        per_task_limit: float = 3.0,
        input_cost_per_token: float = DEFAULT_INPUT_COST_PER_TOKEN,
        output_cost_per_token: float = DEFAULT_OUTPUT_COST_PER_TOKEN,
    ):
        """Configure the guard. No database connection is opened here.

        Args:
            db_path: Path of the SQLite database file.
            daily_limit: Maximum USD spend per UTC day.
            per_task_limit: Maximum USD spend per task id.
            input_cost_per_token: USD price of one input token.
            output_cost_per_token: USD price of one output token.
        """
        self._db_path = db_path
        self._daily_limit = daily_limit
        self._per_task_limit = per_task_limit
        self._input_cost = input_cost_per_token
        self._output_cost = output_cost_per_token
        self._db: aiosqlite.Connection | None = None

    async def initialize(self) -> None:
        """Open the database and create the ``cost_records`` table if missing.

        Repeated calls on an open guard are no-ops, so they do not leak
        connections. If table creation fails, the freshly opened connection
        is closed before the error propagates.
        """
        if self._db is not None:
            return
        db = await aiosqlite.connect(self._db_path)
        try:
            await db.execute(_CREATE_TABLE)
            await db.commit()
        except BaseException:
            # Don't leave the connection's worker running on failure/cancel.
            await db.close()
            raise
        self._db = db

    async def close(self) -> None:
        """Close the database connection if open. Safe to call repeatedly.

        The guard may be re-opened afterwards with :meth:`initialize`.
        """
        db, self._db = self._db, None
        if db is not None:
            await db.close()

    def _connection(self) -> aiosqlite.Connection:
        """Return the open connection.

        Raises:
            RuntimeError: if :meth:`initialize` has not been awaited, or the
                guard has been closed.
        """
        if self._db is None:
            raise RuntimeError("CostGuard is not initialized; await initialize() first")
        return self._db

    async def _fetch_one(self, sql: str, params: tuple) -> tuple:
        """Run a query expected to yield one row and return it, closing the cursor."""
        async with self._connection().execute(sql, params) as cursor:
            return await cursor.fetchone()

    @staticmethod
    def _utc_today() -> str:
        """Return the current UTC date as ``YYYY-MM-DD``.

        ``recorded_at`` values are UTC ISO-8601 strings that start with the
        date, so ``recorded_at >= <this>`` selects rows recorded since UTC
        midnight.
        """
        return datetime.now(timezone.utc).date().isoformat()

    def calculate_cost(self, tokens_input: int, tokens_output: int) -> float:
        """Return the USD cost of the given input/output token counts."""
        return tokens_input * self._input_cost + tokens_output * self._output_cost

    async def record_usage(self, task_id: str, tokens_input: int, tokens_output: int) -> float:
        """Persist one usage record for *task_id* and return its USD cost.

        The row is committed immediately and timestamped with the current
        UTC time in ISO-8601 format.

        Raises:
            RuntimeError: if the guard is not initialized.
        """
        cost = self.calculate_cost(tokens_input, tokens_output)
        now = datetime.now(timezone.utc).isoformat()
        db = self._connection()
        await db.execute(
            "INSERT INTO cost_records (task_id, tokens_input, tokens_output, cost_usd, recorded_at) VALUES (?, ?, ?, ?, ?)",
            (task_id, tokens_input, tokens_output, cost, now),
        )
        await db.commit()
        logger.info("Recorded cost $%.4f for task %s (in=%d, out=%d)", cost, task_id, tokens_input, tokens_output)
        return cost

    async def get_daily_cost(self) -> float:
        """Return the total USD recorded since midnight UTC today."""
        row = await self._fetch_one(
            "SELECT COALESCE(SUM(cost_usd), 0) FROM cost_records WHERE recorded_at >= ?",
            (self._utc_today(),),
        )
        # COALESCE yields the integer 0 when no rows match; normalize to float.
        return float(row[0])

    async def get_task_cost(self, task_id: str) -> float:
        """Return the total USD recorded for *task_id* across all time."""
        row = await self._fetch_one(
            "SELECT COALESCE(SUM(cost_usd), 0) FROM cost_records WHERE task_id = ?",
            (task_id,),
        )
        return float(row[0])

    async def check_daily_limit(self) -> bool:
        """Return True while today's (UTC) spend is strictly below the daily limit."""
        daily = await self.get_daily_cost()
        return daily < self._daily_limit

    async def check_task_limit(self, task_id: str) -> bool:
        """Return True while *task_id*'s spend is strictly below the per-task limit."""
        task_cost = await self.get_task_cost(task_id)
        return task_cost < self._per_task_limit

    async def get_daily_summary(self) -> dict:
        """Return today's (UTC) usage totals and remaining budget.

        Keys:
            ``record_count``: number of records today.
            ``total_cost_usd``: spend today, rounded to 4 decimal places.
            ``daily_limit_usd``: the configured daily limit.
            ``remaining_usd``: limit minus spend, clamped at 0, rounded to 4 places.
            ``total_tokens_input`` / ``total_tokens_output``: token sums today.
        """
        count, total_cost, tokens_in, tokens_out = await self._fetch_one(
            "SELECT COUNT(*), COALESCE(SUM(cost_usd), 0), COALESCE(SUM(tokens_input), 0), COALESCE(SUM(tokens_output), 0) FROM cost_records WHERE recorded_at >= ?",
            (self._utc_today(),),
        )
        total_cost = float(total_cost)
        return {
            "record_count": count,
            "total_cost_usd": round(total_cost, 4),
            "daily_limit_usd": self._daily_limit,
            "remaining_usd": round(max(0.0, self._daily_limit - total_cost), 4),
            "total_tokens_input": tokens_in,
            "total_tokens_output": tokens_out,
        }
# Process-wide CostGuard returned by get_cost_guard(); stays None until the
# first successful initialization.
_guard: CostGuard | None = None


async def get_cost_guard() -> CostGuard:
    """Return the process-wide :class:`CostGuard`, creating it on first use.

    Configuration is read from the environment when the guard is created:

    - ``COST_GUARD_DB``: SQLite path (default ``/data/cost_guard.db``)
    - ``DAILY_COST_LIMIT_USD``: daily limit in USD (default ``10.0``)
    - ``PER_TASK_COST_LIMIT_USD``: per-task limit in USD (default ``3.0``)

    The singleton is published only after ``initialize()`` succeeds. As a
    result, concurrent callers never receive a guard whose connection is
    still being opened. A failed initialization is also retried on the next
    call instead of caching an unusable guard.

    Raises:
        ValueError: if a limit environment variable is not a valid float.
        Exception: whatever ``CostGuard.initialize`` raises, presumably a
            sqlite error when the database path cannot be opened (not
            verified here).
    """
    global _guard
    if _guard is not None:
        return _guard

    db_path = os.environ.get("COST_GUARD_DB", "/data/cost_guard.db")
    daily_limit = float(os.environ.get("DAILY_COST_LIMIT_USD", "10.0"))
    per_task_limit = float(os.environ.get("PER_TASK_COST_LIMIT_USD", "3.0"))
    guard = CostGuard(db_path=db_path, daily_limit=daily_limit, per_task_limit=per_task_limit)
    await guard.initialize()

    # Another coroutine may have published a guard while we were awaiting
    # initialize(); keep that one and release our redundant connection.
    if _guard is not None:
        await guard.close()
        return _guard
    _guard = guard
    return guard