Compare commits
8 Commits
140fbd17ff
...
f63499a1c3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f63499a1c3 | ||
|
|
b2c52abf06 | ||
|
|
c0cb4b7499 | ||
|
|
e82dfe18f9 | ||
|
|
3f0d021b02 | ||
|
|
edeb336cb8 | ||
|
|
0c4c22be5a | ||
|
|
db6e9b4a41 |
34
HANDOFF.md
34
HANDOFF.md
@ -72,12 +72,21 @@ galaxis-agent 리포에 16개 커밋, 40개 테스트 전부 통과, Gitea에 pu
|
|||||||
| Task 7: Health Check | ✅ COMPLETE | `d35efae` | /health, /health/gitea, /health/discord, /health/queue, 8 테스트 |
|
| Task 7: Health Check | ✅ COMPLETE | `d35efae` | /health, /health/gitea, /health/discord, /health/queue, 8 테스트 |
|
||||||
| Task 8: 전체 검증 | ✅ COMPLETE | `a58bbca` | 통합 테스트 (webhook→queue→dispatcher), 107 테스트 통과, import 확인 |
|
| Task 8: 전체 검증 | ✅ COMPLETE | `a58bbca` | 통합 테스트 (webhook→queue→dispatcher), 107 테스트 통과, import 확인 |
|
||||||
|
|
||||||
### Phase 4: 미작성
|
### Phase 4: 안정화 & 자율 모드 — COMPLETE
|
||||||
|
|
||||||
- CostGuard (일일/작업당 API 비용 제한)
|
**실행 방식**: Subagent-Driven Development (5개 독립 Task 병렬 → 2개 순차)
|
||||||
- 복구 메커니즘 (서버 재시작 시 미완료 작업 복구, 좀비 컨테이너 정리)
|
|
||||||
- 자동 머지 모드 (autonomous 설정 + E2E 통과 조건)
|
7개 커밋, **139개 테스트 통과** (107 Phase1-3 + 32 Phase4), Gitea에 push 완료.
|
||||||
- 구조화 로깅, 작업 이력 DB, 스모크 테스트
|
|
||||||
|
| Task | 상태 | 커밋 | 설명 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| Task 1: CostGuard | ✅ COMPLETE | `edeb336` | API 비용 추적/제한, 일일/작업당 한도, 8 테스트 |
|
||||||
|
| Task 2: TaskHistory | ✅ COMPLETE | `db6e9b4` | 완료 작업 이력 DB (SQLite), 4 테스트 |
|
||||||
|
| Task 3: Dispatcher 연동 | ✅ COMPLETE | `c0cb4b7` | CostGuard+TaskHistory를 Dispatcher에 통합, 2 테스트 |
|
||||||
|
| Task 4: JSON 로깅 | ✅ COMPLETE | `3f0d021` | 구조화 JSON 로깅, LOG_FORMAT 설정, 5 테스트 |
|
||||||
|
| Task 5: Recovery | ✅ COMPLETE | `e82dfe1` | 서버 시작 시 복구, ContainerCleaner (30분 주기), 4 테스트 |
|
||||||
|
| Task 6: AutoMerge | ✅ COMPLETE | `0c4c22b` | E2E 조건부 자동 머지, blocked_paths 보호, 7 테스트 |
|
||||||
|
| Task 7: webapp 통합 | ✅ COMPLETE | `b2c52ab` | Lifespan에 전 컴포넌트 통합, /health/costs 엔드포인트, 2 테스트 |
|
||||||
|
|
||||||
## What Worked
|
## What Worked
|
||||||
|
|
||||||
@ -121,18 +130,19 @@ result = await loop.run_in_executor(None, sandbox_backend.execute, cmd)
|
|||||||
|
|
||||||
## Next Steps
|
## Next Steps
|
||||||
|
|
||||||
### Phase 4 플랜 작성 필요
|
### Phase 5: 배포 & 모니터링
|
||||||
|
|
||||||
Phase 3이 완료되었으므로, Phase 4 플랜을 작성해야 한다:
|
Phase 4 완료로 프로덕션 안정성 확보. Phase 5에서는:
|
||||||
- CostGuard (일일/작업당 API 비용 제한)
|
- Oracle VM 배포 자동화 (Ansible/Docker Compose)
|
||||||
- 복구 메커니즘 (서버 재시작 시 미완료 작업 복구, 좀비 컨테이너 정리)
|
- 모니터링 대시보드 (Grafana + SQLite → metrics)
|
||||||
- 자동 머지 모드 (autonomous 설정 + E2E 통과 조건)
|
- 알림 고도화 (Gitea PR 코멘트에 비용/소요시간 포함)
|
||||||
- 구조화 로깅 (JSON 포맷), 작업 이력 DB, 스모크 테스트
|
- 멀티 리포 지원 (galaxis-po 외 다른 리포)
|
||||||
|
- 1주일 conservative 운영 후 autonomous 전환
|
||||||
|
|
||||||
### 실행 방법
|
### 실행 방법
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd ~/workspace/quant/galaxis-agent
|
cd ~/workspace/quant/galaxis-agent
|
||||||
git log --oneline
|
git log --oneline
|
||||||
uv run pytest tests/ -v # 107 테스트 통과 확인
|
uv run pytest tests/ -v # 139 테스트 통과 확인
|
||||||
```
|
```
|
||||||
|
|||||||
61
agent/auto_merge.py
Normal file
61
agent/auto_merge.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
"""E2E 조건부 자동 머지 로직.
|
||||||
|
|
||||||
|
autonomous 모드에서 조건을 검증하여 PR을 자동 머지한다.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def should_auto_merge(
    auto_merge: bool, require_e2e: bool, max_files_changed: int,
    blocked_paths: list[str], changed_files: list[str],
    tests_passed: bool, e2e_passed: bool,
) -> bool:
    """Decide whether a PR qualifies for automatic merging.

    All of the following must hold: the feature flag is on, unit tests
    passed, E2E passed (when required), the diff is small enough, and no
    changed file matches a blocked path (exact name, or as the last path
    component, e.g. ".env" also blocks "config/.env").

    Returns:
        True only when every merge condition is satisfied.
    """
    # Feature flag and unit tests are hard gates.
    if not auto_merge or not tests_passed:
        return False
    # E2E is only a gate when the policy asks for it.
    if require_e2e and not e2e_passed:
        return False
    # Large diffs are never auto-merged.
    if len(changed_files) > max_files_changed:
        return False
    # Any touched file matching a protected entry vetoes the merge.
    touched_protected = any(
        path == entry or path.endswith("/" + entry)
        for path in changed_files
        for entry in blocked_paths
    )
    return not touched_protected
|
||||||
|
|
||||||
|
|
||||||
|
class AutoMergeChecker:
    """Holds the auto-merge policy and performs the merge through Gitea.

    Delegates the eligibility decision to ``should_auto_merge`` and, when
    allowed, calls the project's Gitea client to merge the pull request.
    """

    def __init__(
        self, auto_merge: bool = False, require_e2e: bool = False,
        max_files_changed: int = 10, blocked_paths: list[str] | None = None,
    ):
        # Policy knobs mirror the should_auto_merge() parameters.
        self._auto_merge = auto_merge
        self._require_e2e = require_e2e
        self._max_files_changed = max_files_changed
        self._blocked_paths = blocked_paths or []

    async def try_merge(
        self, owner: str, repo: str, pr_number: int,
        changed_files: list[str], tests_passed: bool, e2e_passed: bool,
    ) -> dict:
        """Merge the PR if the configured policy allows it.

        Returns:
            ``{"merged": bool, "reason": str}`` — never raises; merge
            failures are caught, logged, and reported in ``reason``.
        """
        can_merge = should_auto_merge(
            auto_merge=self._auto_merge, require_e2e=self._require_e2e,
            max_files_changed=self._max_files_changed, blocked_paths=self._blocked_paths,
            changed_files=changed_files, tests_passed=tests_passed, e2e_passed=e2e_passed,
        )
        if not can_merge:
            return {"merged": False, "reason": "conditions not met"}
        try:
            # Local import — presumably avoids an import cycle at module load; confirm.
            from agent.utils.gitea_client import get_gitea_client
            client = get_gitea_client()
            await client.merge_pull_request(owner=owner, repo=repo, pr_number=pr_number)
            logger.info("Auto-merged PR #%d on %s/%s", pr_number, owner, repo)
            return {"merged": True, "reason": "all conditions met"}
        except Exception as e:
            logger.exception("Auto-merge failed for PR #%d", pr_number)
            return {"merged": False, "reason": f"merge failed: {e}"}
|
||||||
@ -62,4 +62,7 @@ class Settings(BaseSettings):
|
|||||||
DAILY_COST_LIMIT_USD: float = 10.0
|
DAILY_COST_LIMIT_USD: float = 10.0
|
||||||
PER_TASK_COST_LIMIT_USD: float = 3.0
|
PER_TASK_COST_LIMIT_USD: float = 3.0
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
LOG_FORMAT: str = "json"
|
||||||
|
|
||||||
model_config = {"env_file": ".env", "extra": "ignore"}
|
model_config = {"env_file": ".env", "extra": "ignore"}
|
||||||
|
|||||||
125
agent/cost_guard.py
Normal file
125
agent/cost_guard.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
"""API 비용 추적 및 제한.
|
||||||
|
|
||||||
|
Anthropic API 응답의 usage 필드에서 토큰 비용을 추적하고,
|
||||||
|
일일/작업당 한도를 초과하면 작업을 차단한다.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timezone, date
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_CREATE_TABLE = """
|
||||||
|
CREATE TABLE IF NOT EXISTS cost_records (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
task_id TEXT NOT NULL,
|
||||||
|
tokens_input INTEGER NOT NULL,
|
||||||
|
tokens_output INTEGER NOT NULL,
|
||||||
|
cost_usd REAL NOT NULL,
|
||||||
|
recorded_at TEXT NOT NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_INPUT_COST_PER_TOKEN = 3.0 / 1_000_000
|
||||||
|
DEFAULT_OUTPUT_COST_PER_TOKEN = 15.0 / 1_000_000
|
||||||
|
|
||||||
|
|
||||||
|
class CostGuard:
    """Track Anthropic API token spend in SQLite and enforce cost limits.

    Each API call's token counts are converted to USD and inserted into
    the ``cost_records`` table; daily and per-task totals are compared
    against configured limits.
    """

    def __init__(
        self,
        db_path: str = "/data/cost_guard.db",
        daily_limit: float = 10.0,
        per_task_limit: float = 3.0,
        input_cost_per_token: float = DEFAULT_INPUT_COST_PER_TOKEN,
        output_cost_per_token: float = DEFAULT_OUTPUT_COST_PER_TOKEN,
    ):
        # The connection is opened lazily in initialize(); the query methods
        # below assume initialize() was awaited first (no None check on _db).
        self._db_path = db_path
        self._daily_limit = daily_limit
        self._per_task_limit = per_task_limit
        self._input_cost = input_cost_per_token
        self._output_cost = output_cost_per_token
        self._db: aiosqlite.Connection | None = None

    async def initialize(self) -> None:
        """Open the SQLite database and create ``cost_records`` if missing."""
        self._db = await aiosqlite.connect(self._db_path)
        await self._db.execute(_CREATE_TABLE)
        await self._db.commit()

    async def close(self) -> None:
        """Close the database connection if one was opened."""
        if self._db:
            await self._db.close()

    def calculate_cost(self, tokens_input: int, tokens_output: int) -> float:
        """Return the USD cost for the given input/output token counts."""
        return tokens_input * self._input_cost + tokens_output * self._output_cost

    async def record_usage(self, task_id: str, tokens_input: int, tokens_output: int) -> float:
        """Insert one usage row for *task_id* and return its computed USD cost."""
        cost = self.calculate_cost(tokens_input, tokens_output)
        # Stored as a full UTC ISO-8601 timestamp; daily queries rely on
        # its lexicographic ordering.
        now = datetime.now(timezone.utc).isoformat()
        await self._db.execute(
            "INSERT INTO cost_records (task_id, tokens_input, tokens_output, cost_usd, recorded_at) VALUES (?, ?, ?, ?, ?)",
            (task_id, tokens_input, tokens_output, cost, now),
        )
        await self._db.commit()
        logger.info("Recorded cost $%.4f for task %s (in=%d, out=%d)", cost, task_id, tokens_input, tokens_output)
        return cost

    async def get_daily_cost(self) -> float:
        """Return the total USD cost of records dated today or later.

        NOTE(review): ``recorded_at >= 'YYYY-MM-DD'`` works lexically against
        the full ISO timestamp, but ``date.today()`` is local time while rows
        are stored in UTC — presumably the host runs UTC; confirm, otherwise
        the day boundary drifts.
        """
        today = date.today().isoformat()
        cursor = await self._db.execute(
            "SELECT COALESCE(SUM(cost_usd), 0) FROM cost_records WHERE recorded_at >= ?", (today,),
        )
        row = await cursor.fetchone()
        return row[0]

    async def get_task_cost(self, task_id: str) -> float:
        """Return the total USD cost recorded for a single task."""
        cursor = await self._db.execute(
            "SELECT COALESCE(SUM(cost_usd), 0) FROM cost_records WHERE task_id = ?", (task_id,),
        )
        row = await cursor.fetchone()
        return row[0]

    async def check_daily_limit(self) -> bool:
        """Return True while today's spend is still under the daily limit."""
        daily = await self.get_daily_cost()
        return daily < self._daily_limit

    async def check_task_limit(self, task_id: str) -> bool:
        """Return True while *task_id*'s spend is still under the per-task limit."""
        task_cost = await self.get_task_cost(task_id)
        return task_cost < self._per_task_limit

    async def get_daily_summary(self) -> dict:
        """Return today's aggregate usage.

        Returns:
            dict with record count, total cost, configured limit, remaining
            budget (floored at 0), and total input/output token counts.
        """
        today = date.today().isoformat()
        cursor = await self._db.execute(
            "SELECT COUNT(*), COALESCE(SUM(cost_usd), 0), COALESCE(SUM(tokens_input), 0), COALESCE(SUM(tokens_output), 0) FROM cost_records WHERE recorded_at >= ?",
            (today,),
        )
        row = await cursor.fetchone()
        total_cost = row[1]
        return {
            "record_count": row[0],
            "total_cost_usd": round(total_cost, 4),
            "daily_limit_usd": self._daily_limit,
            "remaining_usd": round(max(0, self._daily_limit - total_cost), 4),
            "total_tokens_input": row[2],
            "total_tokens_output": row[3],
        }
|
||||||
|
|
||||||
|
|
||||||
|
# Lazily-created process-wide singleton.
_guard: CostGuard | None = None


async def get_cost_guard() -> CostGuard:
    """Return the shared CostGuard, creating and initializing it on first use.

    The DB path and both cost limits are read from the environment
    (COST_GUARD_DB, DAILY_COST_LIMIT_USD, PER_TASK_COST_LIMIT_USD) with
    the same defaults as the CostGuard constructor.
    """
    global _guard
    if _guard is not None:
        return _guard
    _guard = CostGuard(
        db_path=os.environ.get("COST_GUARD_DB", "/data/cost_guard.db"),
        daily_limit=float(os.environ.get("DAILY_COST_LIMIT_USD", "10.0")),
        per_task_limit=float(os.environ.get("PER_TASK_COST_LIMIT_USD", "3.0")),
    )
    await _guard.initialize()
    return _guard
|
||||||
@ -8,6 +8,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from agent.task_queue import PersistentTaskQueue
|
from agent.task_queue import PersistentTaskQueue
|
||||||
@ -22,9 +23,13 @@ class Dispatcher:
|
|||||||
self,
|
self,
|
||||||
task_queue: PersistentTaskQueue,
|
task_queue: PersistentTaskQueue,
|
||||||
poll_interval: float = 2.0,
|
poll_interval: float = 2.0,
|
||||||
|
cost_guard: "CostGuard | None" = None,
|
||||||
|
task_history: "TaskHistory | None" = None,
|
||||||
):
|
):
|
||||||
self._queue = task_queue
|
self._queue = task_queue
|
||||||
self._poll_interval = poll_interval
|
self._poll_interval = poll_interval
|
||||||
|
self._cost_guard = cost_guard
|
||||||
|
self._task_history = task_history
|
||||||
self._running = False
|
self._running = False
|
||||||
self._task: asyncio.Task | None = None
|
self._task: asyncio.Task | None = None
|
||||||
|
|
||||||
@ -56,21 +61,83 @@ class Dispatcher:
|
|||||||
|
|
||||||
    async def _poll_once(self) -> None:
        """Dequeue one task and process it end to end.

        Flow: refuse to dequeue while the daily cost budget is exhausted;
        otherwise run the agent, mark the task completed/failed in the
        queue, record token cost via CostGuard, and append a row to
        TaskHistory. CostGuard and TaskHistory are both optional
        dependencies (None means the step is skipped).
        """
        # Check daily limit before dequeuing — a budget overrun pauses
        # intake rather than failing tasks.
        if self._cost_guard:
            if not await self._cost_guard.check_daily_limit():
                logger.warning("Daily cost limit exceeded, skipping task dequeue")
                return

        task = await self._queue.dequeue()
        if not task:
            return

        logger.info("Processing task %s (thread %s)", task["id"], task["thread_id"])

        start_time = datetime.now(timezone.utc)
        try:
            result = await self._run_agent_for_task(task)
            await self._queue.mark_completed(task["id"], result=result)
            logger.info("Task %s completed successfully", task["id"])

            # Record cost and history after successful completion.
            # Token counts default to 0 when the agent result omits them.
            tokens_input = result.get("tokens_input", 0)
            tokens_output = result.get("tokens_output", 0)

            if self._cost_guard:
                await self._cost_guard.record_usage(
                    task["id"],
                    tokens_input=tokens_input,
                    tokens_output=tokens_output,
                )

            if self._task_history:
                end_time = datetime.now(timezone.utc)
                duration_seconds = (end_time - start_time).total_seconds()
                # Cost is recomputed (not read back from CostGuard) so the
                # history row is written even without a CostGuard.
                cost_usd = self._cost_guard.calculate_cost(tokens_input, tokens_output) if self._cost_guard else 0.0
                payload = task["payload"]

                await self._task_history.record(
                    task_id=task["id"],
                    thread_id=task["thread_id"],
                    issue_number=payload.get("issue_number", 0),
                    repo_name=payload.get("repo_name", ""),
                    source=task["source"],
                    status="completed",
                    created_at=task["created_at"],
                    completed_at=end_time.isoformat(),
                    duration_seconds=duration_seconds,
                    tokens_input=tokens_input,
                    tokens_output=tokens_output,
                    cost_usd=cost_usd,
                )
        except Exception as e:
            logger.exception("Task %s failed", task["id"])
            await self._queue.mark_failed(task["id"], error=str(e))
            await self._notify_failure(task, str(e))

            # Record history after failure — token/cost fields are zeroed
            # because a failed run reports no usable usage data.
            if self._task_history:
                end_time = datetime.now(timezone.utc)
                duration_seconds = (end_time - start_time).total_seconds()
                payload = task["payload"]

                await self._task_history.record(
                    task_id=task["id"],
                    thread_id=task["thread_id"],
                    issue_number=payload.get("issue_number", 0),
                    repo_name=payload.get("repo_name", ""),
                    source=task["source"],
                    status="failed",
                    created_at=task["created_at"],
                    completed_at=end_time.isoformat(),
                    duration_seconds=duration_seconds,
                    tokens_input=0,
                    tokens_output=0,
                    cost_usd=0.0,
                    error_message=str(e),
                )
|
||||||
|
|
||||||
async def _run_agent_for_task(self, task: dict) -> dict[str, Any]:
|
async def _run_agent_for_task(self, task: dict) -> dict[str, Any]:
|
||||||
"""작업에 대해 에이전트를 실행한다."""
|
"""작업에 대해 에이전트를 실행한다."""
|
||||||
from agent.server import get_agent
|
from agent.server import get_agent
|
||||||
|
|||||||
55
agent/json_logging.py
Normal file
55
agent/json_logging.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
"""구조화 JSON 로깅.
|
||||||
|
|
||||||
|
LOG_FORMAT 환경변수로 json | text 선택 가능.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
# LogRecord attributes that belong to the logging machinery itself; anything
# else on the record is treated as a caller-supplied "extra" field.
_BUILTIN_ATTRS = {
    "args", "created", "exc_info", "exc_text", "filename", "funcName",
    "levelname", "levelno", "lineno", "module", "msecs", "message", "msg",
    "name", "pathname", "process", "processName", "relativeCreated",
    "stack_info", "thread", "threadName", "taskName",
}


class JsonFormatter(logging.Formatter):
    """Render each LogRecord as one JSON object per line.

    Emits timestamp (UTC ISO-8601), level, logger name, and the formatted
    message, plus any caller-supplied ``extra`` fields. Values that are not
    JSON-serializable are stringified. Exception info, when present, is
    included as a formatted traceback under ``"exception"``.
    """

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        # Copy extras; private (underscore) attributes are skipped.
        for attr, val in record.__dict__.items():
            if attr in _BUILTIN_ATTRS or attr.startswith("_"):
                continue
            try:
                json.dumps(val)
            except (TypeError, ValueError):
                payload[attr] = str(val)
            else:
                payload[attr] = val
        if record.exc_info and record.exc_info[0] is not None:
            payload["exception"] = "".join(traceback.format_exception(*record.exc_info))
        return json.dumps(payload, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(
|
||||||
|
log_format: str = "json",
|
||||||
|
level: int = logging.INFO,
|
||||||
|
logger: logging.Logger | None = None,
|
||||||
|
) -> None:
|
||||||
|
target = logger or logging.getLogger()
|
||||||
|
if log_format == "json":
|
||||||
|
formatter = JsonFormatter()
|
||||||
|
else:
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
"%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
datefmt="%Y-%m-%dT%H:%M:%S",
|
||||||
|
)
|
||||||
|
for handler in target.handlers:
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
target.setLevel(level)
|
||||||
90
agent/recovery.py
Normal file
90
agent/recovery.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
"""서버 시작 시 복구 + 좀비 컨테이너 정리."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from agent.task_queue import PersistentTaskQueue
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def recover_on_startup(task_queue: PersistentTaskQueue) -> None:
    """Run crash-recovery at server start.

    Requeues tasks left in 'running' state by a previous process (back to
    'pending'), then removes any leftover sandbox containers.
    """
    reset_count = await task_queue.reset_running_to_pending()
    if reset_count:
        logger.info("Recovery: reset %d running task(s) to pending", reset_count)
    await _cleanup_zombie_containers()
|
||||||
|
|
||||||
|
|
||||||
|
async def _cleanup_zombie_containers() -> int:
    """Stop and remove every container labelled ``galaxis-agent-sandbox``.

    Returns:
        Number of containers removed; 0 (without raising) when the Docker
        SDK is missing or the daemon is unreachable.

    NOTE(review): the docker-py calls are blocking and run directly on the
    event loop — acceptable at startup, but confirm this is never called
    on a hot path.
    """
    try:
        import docker
        client = docker.from_env()
        # all=True includes stopped containers so exited zombies are removed too.
        containers = client.containers.list(
            filters={"label": "galaxis-agent-sandbox"}, all=True,
        )
        cleaned = 0
        for container in containers:
            try:
                container.stop(timeout=10)
                container.remove(force=True)
                cleaned += 1
                logger.info("Recovery: removed zombie container %s", container.name)
            except Exception:
                # Best-effort: a single stubborn container should not abort recovery.
                logger.warning("Recovery: failed to remove container %s", container.name)
        return cleaned
    except Exception:
        logger.debug("Recovery: Docker not available, skipping container cleanup")
        return 0
|
||||||
|
|
||||||
|
|
||||||
|
class ContainerCleaner:
    """Background loop that reaps sandbox containers past a maximum age.

    Every ``interval_seconds`` it lists containers labelled
    ``galaxis-agent-sandbox`` (including stopped ones) and force-removes
    those older than ``max_age_seconds``. With no docker client configured
    the cleaner is a no-op.
    """

    def __init__(self, docker_client=None, max_age_seconds: int = 1200, interval_seconds: int = 1800):
        self._docker = docker_client
        self._max_age = max_age_seconds
        self._interval = interval_seconds
        self._running = False
        self._task: asyncio.Task | None = None

    async def start(self) -> None:
        """Launch the periodic cleanup loop as a background task."""
        self._running = True
        self._task = asyncio.create_task(self._loop())

    async def stop(self) -> None:
        """Cancel the background loop and wait for it to wind down."""
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    async def _loop(self) -> None:
        # Errors from a single sweep are logged and the loop keeps going.
        while self._running:
            try:
                await self.cleanup_once()
            except Exception:
                logger.exception("ContainerCleaner error")
            await asyncio.sleep(self._interval)

    async def cleanup_once(self) -> int:
        """Run one sweep; return the number of containers removed.

        NOTE(review): Docker's ``Created`` timestamps may carry nanosecond
        precision that ``datetime.fromisoformat`` rejects on some Python
        versions — such containers are skipped via the except branch; confirm.
        """
        if not self._docker:
            return 0
        now = datetime.now(timezone.utc)
        candidates = self._docker.containers.list(
            filters={"label": "galaxis-agent-sandbox"}, all=True,
        )
        reaped = 0
        for box in candidates:
            stamp = box.attrs.get("Created", "")
            try:
                born = datetime.fromisoformat(stamp.replace("Z", "+00:00"))
                if (now - born).total_seconds() > self._max_age:
                    box.stop(timeout=10)
                    box.remove(force=True)
                    reaped += 1
            except Exception:
                logger.debug("Failed to check/remove container %s", getattr(box, "name", "unknown"))
        return reaped
|
||||||
84
agent/task_history.py
Normal file
84
agent/task_history.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
"""완료 작업 이력 DB.
|
||||||
|
|
||||||
|
작업의 비용, 소요시간, 토큰 사용량을 SQLite에 기록한다.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_CREATE_TABLE = """
|
||||||
|
CREATE TABLE IF NOT EXISTS task_history (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
task_id TEXT UNIQUE NOT NULL,
|
||||||
|
thread_id TEXT NOT NULL,
|
||||||
|
issue_number INTEGER NOT NULL DEFAULT 0,
|
||||||
|
repo_name TEXT NOT NULL DEFAULT '',
|
||||||
|
source TEXT NOT NULL DEFAULT '',
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
completed_at TEXT NOT NULL,
|
||||||
|
duration_seconds REAL NOT NULL DEFAULT 0,
|
||||||
|
tokens_input INTEGER NOT NULL DEFAULT 0,
|
||||||
|
tokens_output INTEGER NOT NULL DEFAULT 0,
|
||||||
|
cost_usd REAL NOT NULL DEFAULT 0,
|
||||||
|
error_message TEXT NOT NULL DEFAULT ''
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TaskHistory:
    """SQLite-backed log of finished tasks (status, cost, duration, tokens)."""

    def __init__(self, db_path: str = "/data/task_history.db"):
        # Connection opens lazily in initialize(); methods below assume it ran.
        self._db_path = db_path
        self._db: aiosqlite.Connection | None = None

    async def initialize(self) -> None:
        """Open the database, enable mapping-style rows, and create the table."""
        self._db = await aiosqlite.connect(self._db_path)
        self._db.row_factory = aiosqlite.Row
        await self._db.execute(_CREATE_TABLE)
        await self._db.commit()

    async def close(self) -> None:
        """Close the connection if one was opened."""
        if self._db:
            await self._db.close()

    async def record(
        self, task_id: str, thread_id: str, issue_number: int, repo_name: str,
        source: str, status: str, created_at: str, completed_at: str,
        duration_seconds: float, tokens_input: int, tokens_output: int,
        cost_usd: float, error_message: str = "",
    ) -> None:
        """Insert one completed-task row.

        Uses INSERT OR REPLACE keyed on the UNIQUE ``task_id`` column, so
        re-recording the same task overwrites its previous row.
        """
        await self._db.execute(
            "INSERT OR REPLACE INTO task_history "
            "(task_id, thread_id, issue_number, repo_name, source, status, "
            "created_at, completed_at, duration_seconds, tokens_input, tokens_output, "
            "cost_usd, error_message) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (task_id, thread_id, issue_number, repo_name, source, status,
            created_at, completed_at, duration_seconds, tokens_input, tokens_output,
            cost_usd, error_message),
        )
        await self._db.commit()
        logger.info("Recorded history: task=%s status=%s cost=$%.4f", task_id, status, cost_usd)

    async def get_recent(self, limit: int = 20) -> list[dict]:
        """Return up to *limit* most recent rows (newest ``completed_at`` first)."""
        cursor = await self._db.execute(
            "SELECT * FROM task_history ORDER BY completed_at DESC LIMIT ?", (limit,),
        )
        rows = await cursor.fetchall()
        return [dict(row) for row in rows]
|
||||||
|
|
||||||
|
|
||||||
|
# Lazily-created process-wide singleton.
_history: TaskHistory | None = None


async def get_task_history() -> TaskHistory:
    """Return the shared TaskHistory, creating and initializing it on first use.

    The database path is read from TASK_HISTORY_DB, defaulting to the
    TaskHistory constructor's own default.
    """
    global _history
    if _history is not None:
        return _history
    _history = TaskHistory(
        db_path=os.environ.get("TASK_HISTORY_DB", "/data/task_history.db"),
    )
    await _history.initialize()
    return _history
|
||||||
@ -133,6 +133,24 @@ class PersistentTaskQueue:
|
|||||||
row = await cursor.fetchone()
|
row = await cursor.fetchone()
|
||||||
return row["cnt"] > 0
|
return row["cnt"] > 0
|
||||||
|
|
||||||
|
    async def reset_running_to_pending(self) -> int:
        """Reset tasks stuck in 'running' back to 'pending' (crash recovery).

        Called at startup: any task still marked running belonged to a
        previous process and will never finish, so it is requeued and its
        ``started_at`` cleared. Commits only when something changed.

        Returns:
            Number of tasks reset.
        """
        cursor = await self._db.execute(
            "SELECT COUNT(*) as cnt FROM tasks WHERE status = 'running'"
        )
        row = await cursor.fetchone()
        count = row["cnt"]
        if count:
            await self._db.execute(
                "UPDATE tasks SET status = 'pending', started_at = NULL WHERE status = 'running'"
            )
            await self._db.commit()
        return count
|
||||||
|
|
||||||
|
|
||||||
# 지연 초기화 싱글턴
|
# 지연 초기화 싱글턴
|
||||||
_queue: PersistentTaskQueue | None = None
|
_queue: PersistentTaskQueue | None = None
|
||||||
|
|||||||
@ -24,14 +24,43 @@ async def lifespan(app: FastAPI):
|
|||||||
from agent.message_store import get_message_store
|
from agent.message_store import get_message_store
|
||||||
from agent.dispatcher import Dispatcher
|
from agent.dispatcher import Dispatcher
|
||||||
from agent.integrations.discord_handler import DiscordHandler
|
from agent.integrations.discord_handler import DiscordHandler
|
||||||
|
from agent.json_logging import setup_logging
|
||||||
|
from agent.recovery import recover_on_startup, ContainerCleaner
|
||||||
|
from agent.cost_guard import get_cost_guard
|
||||||
|
from agent.task_history import get_task_history
|
||||||
|
|
||||||
|
# 구조화 로깅 설정
|
||||||
|
setup_logging(log_format=os.environ.get("LOG_FORMAT", "json"))
|
||||||
|
|
||||||
task_queue = await get_task_queue()
|
task_queue = await get_task_queue()
|
||||||
message_store = await get_message_store()
|
message_store = await get_message_store()
|
||||||
|
|
||||||
dispatcher = Dispatcher(task_queue=task_queue)
|
# 서버 시작 시 복구
|
||||||
|
await recover_on_startup(task_queue)
|
||||||
|
|
||||||
|
# CostGuard + TaskHistory 초기화
|
||||||
|
cost_guard = await get_cost_guard()
|
||||||
|
task_history = await get_task_history()
|
||||||
|
|
||||||
|
# Dispatcher에 CostGuard + TaskHistory 주입
|
||||||
|
dispatcher = Dispatcher(task_queue=task_queue, cost_guard=cost_guard, task_history=task_history)
|
||||||
await dispatcher.start()
|
await dispatcher.start()
|
||||||
app.state.dispatcher = dispatcher
|
app.state.dispatcher = dispatcher
|
||||||
|
|
||||||
|
# ContainerCleaner 시작
|
||||||
|
container_cleaner = None
|
||||||
|
try:
|
||||||
|
import docker
|
||||||
|
docker_client = docker.from_env()
|
||||||
|
sandbox_timeout = int(os.environ.get("SANDBOX_TIMEOUT", "600"))
|
||||||
|
container_cleaner = ContainerCleaner(
|
||||||
|
docker_client=docker_client,
|
||||||
|
max_age_seconds=sandbox_timeout * 2,
|
||||||
|
)
|
||||||
|
await container_cleaner.start()
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Docker not available, container cleanup disabled")
|
||||||
|
|
||||||
discord_token = os.environ.get("DISCORD_TOKEN", "")
|
discord_token = os.environ.get("DISCORD_TOKEN", "")
|
||||||
discord_handler = None
|
discord_handler = None
|
||||||
if discord_token:
|
if discord_token:
|
||||||
@ -43,8 +72,12 @@ async def lifespan(app: FastAPI):
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
await dispatcher.stop()
|
await dispatcher.stop()
|
||||||
|
if container_cleaner:
|
||||||
|
await container_cleaner.stop()
|
||||||
if discord_handler:
|
if discord_handler:
|
||||||
await discord_handler.close()
|
await discord_handler.close()
|
||||||
|
await cost_guard.close()
|
||||||
|
await task_history.close()
|
||||||
await task_queue.close()
|
await task_queue.close()
|
||||||
await message_store.close()
|
await message_store.close()
|
||||||
logger.info("Application shutdown complete")
|
logger.info("Application shutdown complete")
|
||||||
@ -170,6 +203,14 @@ async def health_queue():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health/costs")
async def health_costs():
    """Return today's API cost summary (totals, limit, remaining budget)."""
    # Local import — presumably avoids an import cycle at module load; confirm.
    from agent.cost_guard import get_cost_guard
    guard = await get_cost_guard()
    return await guard.get_daily_summary()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/webhooks/gitea")
|
@app.post("/webhooks/gitea")
|
||||||
@limiter.limit("10/minute")
|
@limiter.limit("10/minute")
|
||||||
async def gitea_webhook(request: Request):
|
async def gitea_webhook(request: Request):
|
||||||
|
|||||||
70
tests/test_auto_merge.py
Normal file
70
tests/test_auto_merge.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from agent.auto_merge import should_auto_merge, AutoMergeChecker
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_not_merge_when_disabled():
    """auto_merge=False blocks merging even when every other condition passes."""
    result = should_auto_merge(
        auto_merge=False, require_e2e=False, max_files_changed=10,
        blocked_paths=[".env"], changed_files=["backend/app/main.py"],
        tests_passed=True, e2e_passed=True,
    )
    assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_merge_when_all_conditions_met():
    """Merging is allowed when flag, tests, E2E, size, and paths all pass."""
    result = should_auto_merge(
        auto_merge=True, require_e2e=True, max_files_changed=10,
        blocked_paths=[".env", "quant.md"],
        changed_files=["backend/app/main.py", "backend/tests/test_main.py"],
        tests_passed=True, e2e_passed=True,
    )
    assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_not_merge_when_tests_fail():
    """Failing unit tests always veto the merge."""
    result = should_auto_merge(
        auto_merge=True, require_e2e=False, max_files_changed=10,
        blocked_paths=[], changed_files=["a.py"],
        tests_passed=False, e2e_passed=False,
    )
    assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_not_merge_when_e2e_required_but_failed():
    """When require_e2e is set, a failed E2E run blocks the merge."""
    result = should_auto_merge(
        auto_merge=True, require_e2e=True, max_files_changed=10,
        blocked_paths=[], changed_files=["a.py"],
        tests_passed=True, e2e_passed=False,
    )
    assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_not_merge_when_too_many_files():
|
||||||
|
files = [f"file{i}.py" for i in range(15)]
|
||||||
|
result = should_auto_merge(
|
||||||
|
auto_merge=True, require_e2e=False, max_files_changed=10,
|
||||||
|
blocked_paths=[], changed_files=files,
|
||||||
|
tests_passed=True, e2e_passed=True,
|
||||||
|
)
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_not_merge_when_blocked_path_modified():
|
||||||
|
result = should_auto_merge(
|
||||||
|
auto_merge=True, require_e2e=False, max_files_changed=10,
|
||||||
|
blocked_paths=[".env", "quant.md"],
|
||||||
|
changed_files=["backend/app/main.py", ".env"],
|
||||||
|
tests_passed=True, e2e_passed=True,
|
||||||
|
)
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_merge_without_e2e_when_not_required():
|
||||||
|
result = should_auto_merge(
|
||||||
|
auto_merge=True, require_e2e=False, max_files_changed=10,
|
||||||
|
blocked_paths=[], changed_files=["a.py"],
|
||||||
|
tests_passed=True, e2e_passed=False,
|
||||||
|
)
|
||||||
|
assert result is True
|
||||||
91
tests/test_cost_guard.py
Normal file
91
tests/test_cost_guard.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
"""Tests for CostGuard: usage recording, cost math, and spending limits."""

import pytest
import os
import tempfile
from datetime import datetime, timezone

from agent.cost_guard import CostGuard


@pytest.fixture
async def cost_guard():
    """Yield a CostGuard backed by a throwaway SQLite database file."""
    handle, path = tempfile.mkstemp(suffix=".db")
    os.close(handle)
    guard = CostGuard(db_path=path, daily_limit=10.0, per_task_limit=3.0)
    await guard.initialize()
    yield guard
    await guard.close()
    os.unlink(path)


@pytest.mark.asyncio
async def test_record_usage(cost_guard):
    """Recorded API usage makes the daily total positive."""
    await cost_guard.record_usage(
        task_id="task-1", tokens_input=1000, tokens_output=500
    )
    assert await cost_guard.get_daily_cost() > 0


@pytest.mark.asyncio
async def test_calculate_cost(cost_guard):
    """One MTok in plus one MTok out costs $3 + $15 = $18."""
    cost = cost_guard.calculate_cost(tokens_input=1_000_000, tokens_output=1_000_000)
    assert cost == pytest.approx(18.0, rel=0.01)


@pytest.mark.asyncio
async def test_check_daily_limit_ok(cost_guard):
    """With no usage recorded the daily limit check passes."""
    assert await cost_guard.check_daily_limit() is True


@pytest.mark.asyncio
async def test_check_daily_limit_exceeded(cost_guard):
    """Once accumulated cost passes the daily cap, the check fails."""
    for i in range(5):
        await cost_guard.record_usage(
            f"task-{i}", tokens_input=1_000_000, tokens_output=200_000
        )
    assert await cost_guard.check_daily_limit() is False


@pytest.mark.asyncio
async def test_check_task_limit_ok(cost_guard):
    """A tiny amount of usage stays under the per-task cap."""
    await cost_guard.record_usage("task-1", tokens_input=100, tokens_output=50)
    assert await cost_guard.check_task_limit("task-1") is True


@pytest.mark.asyncio
async def test_check_task_limit_exceeded(cost_guard):
    """Heavy usage on a single task trips the per-task cap."""
    await cost_guard.record_usage(
        "task-1", tokens_input=1_000_000, tokens_output=100_000
    )
    assert await cost_guard.check_task_limit("task-1") is False


@pytest.mark.asyncio
async def test_get_task_cost(cost_guard):
    """Per-task cost accumulates across records; unknown tasks cost 0.0."""
    await cost_guard.record_usage("task-1", tokens_input=1000, tokens_output=500)
    await cost_guard.record_usage("task-1", tokens_input=2000, tokens_output=1000)
    assert await cost_guard.get_task_cost("task-1") > 0
    assert await cost_guard.get_task_cost("task-2") == 0.0


@pytest.mark.asyncio
async def test_get_daily_summary(cost_guard):
    """The daily summary exposes totals, the limit, and a record count."""
    await cost_guard.record_usage("task-1", tokens_input=1000, tokens_output=500)
    summary = await cost_guard.get_daily_summary()
    for key in ("total_cost_usd", "daily_limit_usd", "remaining_usd"):
        assert key in summary
    assert summary["record_count"] == 1
|
||||||
80
tests/test_dispatcher_cost.py
Normal file
80
tests/test_dispatcher_cost.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
"""Integration tests: Dispatcher wires CostGuard and TaskHistory together."""

import pytest
import os
import tempfile
from unittest.mock import AsyncMock, patch

from agent.task_queue import PersistentTaskQueue
from agent.cost_guard import CostGuard
from agent.task_history import TaskHistory
from agent.dispatcher import Dispatcher


@pytest.fixture
async def resources():
    """Yield (queue, guard, history), each on its own throwaway SQLite file."""
    db_paths = []
    for _ in range(3):
        handle, path = tempfile.mkstemp(suffix=".db")
        os.close(handle)
        db_paths.append(path)

    queue = PersistentTaskQueue(db_path=db_paths[0])
    await queue.initialize()
    guard = CostGuard(db_path=db_paths[1], daily_limit=10.0, per_task_limit=3.0)
    await guard.initialize()
    history = TaskHistory(db_path=db_paths[2])
    await history.initialize()

    yield queue, guard, history

    await queue.close()
    await guard.close()
    await history.close()
    for path in db_paths:
        os.unlink(path)


@pytest.mark.asyncio
async def test_dispatcher_records_cost(resources):
    """A completed task's token usage lands in CostGuard and TaskHistory."""
    queue, guard, history = resources

    payload = {
        "issue_number": 42,
        "repo_owner": "quant",
        "repo_name": "galaxis-po",
        "message": "Fix",
    }
    await queue.enqueue("thread-1", "gitea", payload)

    agent_result = {
        "status": "completed", "tokens_input": 5000, "tokens_output": 2000,
    }
    dispatcher = Dispatcher(task_queue=queue, cost_guard=guard, task_history=history)
    dispatcher._run_agent_for_task = AsyncMock(return_value=agent_result)

    await dispatcher._poll_once()

    assert await guard.get_daily_cost() > 0

    records = await history.get_recent()
    assert len(records) == 1
    assert records[0]["status"] == "completed"


@pytest.mark.asyncio
async def test_dispatcher_blocks_when_daily_limit_exceeded(resources):
    """When the daily budget is spent, pending tasks are not dispatched."""
    queue, guard, history = resources

    # Burn through the daily limit with prior usage records.
    for i in range(5):
        await guard.record_usage(
            f"prev-{i}", tokens_input=1_000_000, tokens_output=200_000
        )

    await queue.enqueue("thread-1", "gitea", {"message": "Should be blocked"})

    runner = AsyncMock()
    dispatcher = Dispatcher(task_queue=queue, cost_guard=guard, task_history=history)
    dispatcher._run_agent_for_task = runner

    await dispatcher._poll_once()

    # The agent never ran and the task stays queued for a later poll.
    runner.assert_not_called()
    assert len(await queue.get_pending()) == 1
|
||||||
80
tests/test_json_logging.py
Normal file
80
tests/test_json_logging.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
"""Tests for structured JSON logging (JsonFormatter / setup_logging)."""

import json
import logging
import io

from agent.json_logging import JsonFormatter, setup_logging


def _record(level, msg, exc_info=None):
    """Build a bare LogRecord for formatter tests."""
    return logging.LogRecord(
        name="test", level=level, pathname="test.py",
        lineno=1, msg=msg, args=(), exc_info=exc_info,
    )


def _fresh_logger(name):
    """Return (logger, stream): an isolated logger writing into a StringIO."""
    logger = logging.getLogger(name)
    logger.handlers.clear()
    logger.setLevel(logging.DEBUG)
    stream = io.StringIO()
    logger.addHandler(logging.StreamHandler(stream))
    return logger, stream


def test_json_formatter_basic():
    """Message, level and timestamp appear in the JSON output."""
    output = JsonFormatter().format(_record(logging.INFO, "테스트 메시지"))
    parsed = json.loads(output)
    assert parsed["message"] == "테스트 메시지"
    assert parsed["level"] == "INFO"
    assert "timestamp" in parsed


def test_json_formatter_with_extra():
    """Extra attributes attached to the record surface as JSON fields."""
    record = _record(logging.INFO, "작업 시작")
    record.thread_id = "uuid-123"
    record.issue = 42
    parsed = json.loads(JsonFormatter().format(record))
    assert parsed["thread_id"] == "uuid-123"
    assert parsed["issue"] == 42


def test_json_formatter_with_exception():
    """exc_info is rendered into an 'exception' field naming the error type."""
    import sys
    try:
        raise ValueError("test error")
    except ValueError:
        record = _record(logging.ERROR, "에러 발생", exc_info=sys.exc_info())
    parsed = json.loads(JsonFormatter().format(record))
    assert "exception" in parsed
    assert "ValueError" in parsed["exception"]


def test_setup_logging_json():
    """With log_format='json' the emitted line parses as JSON."""
    test_logger, stream = _fresh_logger("test_json_setup")
    setup_logging(log_format="json", logger=test_logger)
    test_logger.info("hello")
    parsed = json.loads(stream.getvalue().strip())
    assert parsed["message"] == "hello"


def test_setup_logging_text():
    """With log_format='text' the emitted line is plain text, not JSON."""
    test_logger, stream = _fresh_logger("test_text_setup")
    setup_logging(log_format="text", logger=test_logger)
    test_logger.info("hello")
    output = stream.getvalue().strip()
    assert "hello" in output
    parses_as_json = True
    try:
        json.loads(output)
    except json.JSONDecodeError:
        parses_as_json = False
    assert not parses_as_json, "Should not be valid JSON"
|
||||||
79
tests/test_recovery.py
Normal file
79
tests/test_recovery.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
"""Tests for startup recovery (requeueing interrupted work) and the
zombie sandbox-container cleaner."""

import pytest
import os
import tempfile
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch

from agent.task_queue import PersistentTaskQueue
from agent.recovery import recover_on_startup, ContainerCleaner


@pytest.fixture
async def task_queue():
    """Yield a PersistentTaskQueue backed by a throwaway SQLite file."""
    fd, db_path = tempfile.mkstemp(suffix=".db")
    os.close(fd)
    queue = PersistentTaskQueue(db_path=db_path)
    await queue.initialize()
    yield queue
    await queue.close()
    os.unlink(db_path)


@pytest.mark.asyncio
async def test_recover_resets_running_to_pending(task_queue):
    """A task left in 'running' when the server restarts is reset to pending."""
    await task_queue.enqueue("thread-1", "gitea", {"msg": "interrupted"})
    await task_queue.dequeue()  # moves the task into the running state
    assert await task_queue.has_running_task("thread-1") is True

    # Container cleanup is mocked out; this test covers only the queue reset.
    with patch("agent.recovery._cleanup_zombie_containers", new_callable=AsyncMock):
        await recover_on_startup(task_queue)

    assert await task_queue.has_running_task("thread-1") is False
    pending = await task_queue.get_pending()
    assert len(pending) == 1


@pytest.mark.asyncio
async def test_recover_no_running_tasks(task_queue):
    """Recovery on an empty queue is a harmless no-op."""
    with patch("agent.recovery._cleanup_zombie_containers", new_callable=AsyncMock):
        await recover_on_startup(task_queue)
    pending = await task_queue.get_pending()
    assert len(pending) == 0


@pytest.mark.asyncio
async def test_container_cleaner_removes_old():
    """Sandbox containers older than max_age_seconds are stopped and removed."""
    # FIX: the original hard-coded "Created": "2026-03-19T00:00:00Z", which
    # makes the test wall-clock-dependent — it fails whenever the clock is
    # before that date (negative age) and is meaningless long after it.
    # Compute a timestamp one hour in the past, in the same "...Z" format.
    old_created = (datetime.now(timezone.utc) - timedelta(hours=1)).strftime(
        "%Y-%m-%dT%H:%M:%SZ"
    )
    mock_container = MagicMock()
    mock_container.name = "galaxis-sandbox-old"
    mock_container.labels = {"galaxis-agent-sandbox": "true"}
    mock_container.attrs = {"Created": old_created}
    mock_container.stop = MagicMock()
    mock_container.remove = MagicMock()

    mock_docker = MagicMock()
    mock_docker.containers.list.return_value = [mock_container]

    cleaner = ContainerCleaner(docker_client=mock_docker, max_age_seconds=600)
    removed = await cleaner.cleanup_once()

    assert removed == 1
    mock_container.stop.assert_called_once()
    mock_container.remove.assert_called_once()


@pytest.mark.asyncio
async def test_container_cleaner_keeps_recent():
    """A container created just now survives the cleanup pass."""
    now = datetime.now(timezone.utc).isoformat()

    mock_container = MagicMock()
    mock_container.labels = {"galaxis-agent-sandbox": "true"}
    mock_container.attrs = {"Created": now}

    mock_docker = MagicMock()
    mock_docker.containers.list.return_value = [mock_container]

    cleaner = ContainerCleaner(docker_client=mock_docker, max_age_seconds=3600)
    removed = await cleaner.cleanup_once()

    assert removed == 0
    mock_container.stop.assert_not_called()
|
||||||
44
tests/test_smoke.py
Normal file
44
tests/test_smoke.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
"""Smoke tests: the FastAPI app answers /health and /health/costs."""

import pytest
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch

import httpx


@asynccontextmanager
async def mock_lifespan(app):
    """No-op lifespan: skip starting real components during smoke tests."""
    yield


def _asgi_client(app):
    """HTTP client speaking directly to the ASGI app (no real network)."""
    return httpx.AsyncClient(
        transport=httpx.ASGITransport(app=app), base_url="http://test"
    )


@pytest.mark.asyncio
async def test_smoke_health():
    """GET /health returns 200 with status 'ok'."""
    from agent.webapp import app
    app.router.lifespan_context = mock_lifespan
    async with _asgi_client(app) as client:
        resp = await client.get("/health")
    assert resp.status_code == 200
    assert resp.json()["status"] == "ok"


@pytest.mark.asyncio
async def test_smoke_health_costs():
    """GET /health/costs surfaces the CostGuard daily summary."""
    from agent.webapp import app
    app.router.lifespan_context = mock_lifespan

    summary = {
        "total_cost_usd": 1.5,
        "daily_limit_usd": 10.0,
        "remaining_usd": 8.5,
        "record_count": 3,
        "total_tokens_input": 50000,
        "total_tokens_output": 20000,
    }
    mock_guard = MagicMock()
    mock_guard.get_daily_summary = AsyncMock(return_value=summary)

    with patch("agent.cost_guard.get_cost_guard", new_callable=AsyncMock, return_value=mock_guard):
        async with _asgi_client(app) as client:
            resp = await client.get("/health/costs")

    assert resp.status_code == 200
    data = resp.json()
    assert data["total_cost_usd"] == 1.5
    assert data["daily_limit_usd"] == 10.0
|
||||||
69
tests/test_task_history.py
Normal file
69
tests/test_task_history.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
"""Tests for TaskHistory: recording finished tasks and reading them back."""

import pytest
import os
import tempfile

from agent.task_history import TaskHistory


@pytest.fixture
async def history():
    """Yield a TaskHistory backed by a throwaway SQLite file."""
    handle, path = tempfile.mkstemp(suffix=".db")
    os.close(handle)
    store = TaskHistory(db_path=path)
    await store.initialize()
    yield store
    await store.close()
    os.unlink(path)


@pytest.mark.asyncio
async def test_record_completed(history):
    """A completed task round-trips through record/get_recent."""
    await history.record(
        task_id="task-1",
        thread_id="thread-1",
        issue_number=42,
        repo_name="galaxis-po",
        source="gitea",
        status="completed",
        created_at="2026-03-20T10:00:00Z",
        completed_at="2026-03-20T10:05:00Z",
        duration_seconds=300.0,
        tokens_input=5000,
        tokens_output=2000,
        cost_usd=0.045,
    )
    records = await history.get_recent(limit=10)
    assert len(records) == 1
    assert records[0]["task_id"] == "task-1"
    assert records[0]["status"] == "completed"


@pytest.mark.asyncio
async def test_record_failed(history):
    """A failed task keeps its error message."""
    await history.record(
        task_id="task-2",
        thread_id="thread-2",
        issue_number=10,
        repo_name="galaxis-po",
        source="discord",
        status="failed",
        created_at="2026-03-20T11:00:00Z",
        completed_at="2026-03-20T11:01:00Z",
        duration_seconds=60.0,
        tokens_input=1000,
        tokens_output=500,
        cost_usd=0.01,
        error_message="Agent crashed",
    )
    records = await history.get_recent(limit=10)
    assert len(records) == 1
    assert records[0]["error_message"] == "Agent crashed"


@pytest.mark.asyncio
async def test_get_recent_ordered(history):
    """get_recent returns the most recently created record first."""
    await history.record(
        task_id="task-1", thread_id="t1", issue_number=1, repo_name="r",
        source="gitea", status="completed",
        created_at="2026-03-20T10:00:00Z", completed_at="2026-03-20T10:05:00Z",
        duration_seconds=300, tokens_input=100, tokens_output=50, cost_usd=0.001,
    )
    await history.record(
        task_id="task-2", thread_id="t2", issue_number=2, repo_name="r",
        source="gitea", status="completed",
        created_at="2026-03-20T11:00:00Z", completed_at="2026-03-20T11:05:00Z",
        duration_seconds=300, tokens_input=200, tokens_output=100, cost_usd=0.002,
    )
    records = await history.get_recent(limit=10)
    assert len(records) == 2
    assert records[0]["task_id"] == "task-2"


@pytest.mark.asyncio
async def test_empty_history(history):
    """With nothing recorded, get_recent returns an empty list."""
    records = await history.get_recent(limit=10)
    assert records == []
|
||||||
Loading…
x
Reference in New Issue
Block a user