Compare commits
6 Commits
bb2a47157e
...
94edb45c86
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
94edb45c86 | ||
|
|
816415dd24 | ||
|
|
af7bd2cdc3 | ||
|
|
e8983d8534 | ||
|
|
b2ad726fc4 | ||
|
|
5d44c2e7e2 |
@ -1,15 +1,176 @@
|
||||
"""Docker container-based sandbox backend. Phase 2 implementation."""
|
||||
"""Docker container sandbox backend.
|
||||
execute() is synchronous. server.py calls it via loop.run_in_executor().
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import docker
|
||||
from deepagents.backends.protocol import ExecuteResponse, FileDownloadResponse, FileUploadResponse
|
||||
from deepagents.backends.sandbox import BaseSandbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DockerSandbox:
|
||||
async def execute(self, command: str, timeout: int = 300):
|
||||
raise NotImplementedError("Phase 2")
|
||||
class DockerSandbox(BaseSandbox):
|
||||
"""Docker container-based sandbox implementation.
|
||||
|
||||
async def read_file(self, path: str) -> str:
|
||||
raise NotImplementedError("Phase 2")
|
||||
Extends BaseSandbox, which auto-implements file I/O (read/write/ls/grep)
|
||||
by delegating to execute(). Only need to implement: id property, execute(),
|
||||
upload_files(), download_files(), and container lifecycle.
|
||||
"""
|
||||
|
||||
async def write_file(self, path: str, content: str) -> None:
|
||||
raise NotImplementedError("Phase 2")
|
||||
def __init__(
|
||||
self,
|
||||
container_id: str | None = None,
|
||||
*,
|
||||
image: str | None = None,
|
||||
network: str = "galaxis-net",
|
||||
mem_limit: str = "4g",
|
||||
cpu_count: int = 2,
|
||||
pids_limit: int = 256,
|
||||
environment: dict | None = None,
|
||||
default_timeout: int = 300,
|
||||
):
|
||||
self._docker = docker.DockerClient(
|
||||
base_url=os.environ.get("DOCKER_HOST", "unix:///var/run/docker.sock")
|
||||
)
|
||||
self._default_timeout = default_timeout
|
||||
|
||||
async def close(self) -> None:
|
||||
raise NotImplementedError("Phase 2")
|
||||
if container_id:
|
||||
# Connect to existing container
|
||||
self._container = self._docker.containers.get(container_id)
|
||||
else:
|
||||
# Create new container
|
||||
resolved_image = image or os.environ.get("SANDBOX_IMAGE", "galaxis-sandbox:latest")
|
||||
self._container = self._docker.containers.run(
|
||||
image=resolved_image,
|
||||
detach=True,
|
||||
network=network,
|
||||
mem_limit=mem_limit,
|
||||
cpu_count=cpu_count,
|
||||
pids_limit=pids_limit,
|
||||
environment=environment or {},
|
||||
labels={"galaxis-agent-sandbox": "true"},
|
||||
working_dir="/workspace",
|
||||
)
|
||||
self._id = self._container.id
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._id
|
||||
|
||||
def execute(
|
||||
self,
|
||||
command: str,
|
||||
*,
|
||||
timeout: int | None = None,
|
||||
) -> ExecuteResponse:
|
||||
"""Execute a shell command in the container.
|
||||
|
||||
Synchronous method - server.py calls via loop.run_in_executor().
|
||||
"""
|
||||
effective_timeout = timeout if timeout is not None else self._default_timeout
|
||||
|
||||
# Wrap command with timeout if specified
|
||||
if effective_timeout and effective_timeout > 0:
|
||||
cmd = ["timeout", str(effective_timeout), "sh", "-c", command]
|
||||
else:
|
||||
cmd = ["sh", "-c", command]
|
||||
|
||||
# Execute command in container with demux=True to separate stdout/stderr
|
||||
result = self._container.exec_run(cmd=cmd, demux=True, workdir="/workspace")
|
||||
|
||||
# Decode output
|
||||
stdout = (result.output[0] or b"").decode("utf-8", errors="replace")
|
||||
stderr = (result.output[1] or b"").decode("utf-8", errors="replace")
|
||||
output = stdout + stderr
|
||||
|
||||
exit_code = result.exit_code
|
||||
|
||||
# Handle timeout exit code
|
||||
if exit_code == 124:
|
||||
output += f"\n[TIMEOUT] Command timed out after {effective_timeout}s"
|
||||
|
||||
return ExecuteResponse(output=output, exit_code=exit_code, truncated=False)
|
||||
|
||||
def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
|
||||
"""Upload multiple files to the sandbox.
|
||||
|
||||
Supports partial success - returns errors per-file rather than raising.
|
||||
"""
|
||||
responses = []
|
||||
for path, content in files:
|
||||
try:
|
||||
# Use tar to upload file to container
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
# Create tar archive in memory
|
||||
tar_stream = io.BytesIO()
|
||||
tar = tarfile.open(fileobj=tar_stream, mode="w")
|
||||
|
||||
# Add file to archive
|
||||
tarinfo = tarfile.TarInfo(name=os.path.basename(path))
|
||||
tarinfo.size = len(content)
|
||||
tar.addfile(tarinfo, io.BytesIO(content))
|
||||
tar.close()
|
||||
|
||||
# Upload to container
|
||||
tar_stream.seek(0)
|
||||
self._container.put_archive(os.path.dirname(path) or "/workspace", tar_stream)
|
||||
|
||||
responses.append(FileUploadResponse(path=path, error=None))
|
||||
except Exception as e:
|
||||
logger.exception("Failed to upload file: %s", path)
|
||||
responses.append(FileUploadResponse(path=path, error=str(e)))
|
||||
|
||||
return responses
|
||||
|
||||
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
"""Download multiple files from the sandbox.
|
||||
|
||||
Supports partial success - returns errors per-file rather than raising.
|
||||
"""
|
||||
responses = []
|
||||
for path in paths:
|
||||
try:
|
||||
# Get file from container as tar archive
|
||||
bits, stat = self._container.get_archive(path)
|
||||
|
||||
# Extract content from tar
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
tar_stream = io.BytesIO(b"".join(bits))
|
||||
tar = tarfile.open(fileobj=tar_stream)
|
||||
|
||||
# Get first file from archive
|
||||
member = tar.next()
|
||||
if member:
|
||||
f = tar.extractfile(member)
|
||||
if f:
|
||||
content = f.read()
|
||||
responses.append(FileDownloadResponse(path=path, content=content, error=None))
|
||||
else:
|
||||
responses.append(FileDownloadResponse(path=path, content=None, error="Not a file"))
|
||||
else:
|
||||
responses.append(FileDownloadResponse(path=path, content=None, error="Empty archive"))
|
||||
|
||||
tar.close()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to download file: %s", path)
|
||||
responses.append(FileDownloadResponse(path=path, content=None, error=str(e)))
|
||||
|
||||
return responses
|
||||
|
||||
def close(self):
|
||||
"""Stop and remove the container."""
|
||||
try:
|
||||
self._container.stop(timeout=10)
|
||||
self._container.remove(force=True)
|
||||
except docker.errors.NotFound:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("Failed to remove container: %s", self._id)
|
||||
|
||||
@ -16,6 +16,7 @@ from langchain.agents.middleware import AgentState, after_agent
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from ..utils.gitea_client import get_gitea_client
|
||||
from ..utils.git_utils import (
|
||||
git_add_all,
|
||||
git_checkout_branch,
|
||||
@ -135,8 +136,21 @@ async def open_pr_if_needed(
|
||||
git_push, sandbox_backend, repo_dir, target_branch, gitea_token
|
||||
)
|
||||
|
||||
# TODO: Phase 2 - use GiteaClient to create PR via Gitea API
|
||||
logger.info("Pushed to branch %s, PR creation pending Gitea integration", target_branch)
|
||||
# --- PR 생성 (GiteaClient) ---
|
||||
default_branch = os.environ.get("DEFAULT_BRANCH", "main")
|
||||
client = get_gitea_client()
|
||||
try:
|
||||
pr_result = await client.create_pull_request(
|
||||
owner=repo_owner,
|
||||
repo=repo_name,
|
||||
title=pr_title,
|
||||
head=target_branch,
|
||||
base=default_branch,
|
||||
body=pr_body,
|
||||
)
|
||||
logger.info("Safety net PR created: %s", pr_result.get("html_url"))
|
||||
except Exception:
|
||||
logger.exception("Safety net PR creation failed (changes were pushed)")
|
||||
|
||||
logger.info("After-agent middleware completed successfully")
|
||||
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from langgraph.config import get_config
|
||||
|
||||
from agent.utils.gitea_client import get_gitea_client
|
||||
|
||||
from ..utils.git_utils import (
|
||||
git_add_all,
|
||||
git_checkout_branch,
|
||||
@ -95,7 +98,6 @@ def commit_and_open_pr(
|
||||
"pr_url": None,
|
||||
}
|
||||
|
||||
import os
|
||||
gitea_token = os.environ.get("GITEA_TOKEN", "")
|
||||
if not gitea_token:
|
||||
logger.error("commit_and_open_pr missing Gitea token for thread %s", thread_id)
|
||||
@ -113,8 +115,39 @@ def commit_and_open_pr(
|
||||
"pr_url": None,
|
||||
}
|
||||
|
||||
# TODO: Phase 2 - use GiteaClient to create PR
|
||||
return {"success": True, "pr_url": "pending-gitea-implementation"}
|
||||
# --- PR 생성 (GiteaClient) ---
|
||||
gitea_external_url = os.environ.get("GITEA_EXTERNAL_URL", "")
|
||||
gitea_internal_url = os.environ.get("GITEA_URL", "http://gitea:3000")
|
||||
default_branch = os.environ.get("DEFAULT_BRANCH", "main")
|
||||
client = get_gitea_client()
|
||||
|
||||
try:
|
||||
pr_result = asyncio.run(
|
||||
client.create_pull_request(
|
||||
owner=repo_owner,
|
||||
repo=repo_name,
|
||||
title=title,
|
||||
head=target_branch,
|
||||
base=default_branch,
|
||||
body=body,
|
||||
)
|
||||
)
|
||||
pr_url = pr_result.get("html_url", "")
|
||||
if gitea_external_url and pr_url:
|
||||
pr_url = pr_url.replace(gitea_internal_url, gitea_external_url)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"pr_url": pr_url,
|
||||
"pr_number": pr_result.get("number"),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create PR (push succeeded)")
|
||||
return {
|
||||
"success": True,
|
||||
"pr_url": "",
|
||||
"error": f"Push succeeded but PR creation failed: {e}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("commit_and_open_pr failed")
|
||||
|
||||
@ -1,5 +1,32 @@
|
||||
"""Discord message tool. Phase 2 implementation."""
|
||||
"""Discord 채널/스레드 메시지 전송 도구."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from agent.utils.discord_client import get_discord_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def discord_reply(message: str) -> dict:
|
||||
raise NotImplementedError("Phase 2")
|
||||
def discord_reply(message: str) -> dict[str, Any]:
|
||||
if not message.strip():
|
||||
return {"success": False, "error": "빈 메시지는 전송할 수 없습니다."}
|
||||
|
||||
channel_id = os.environ.get("DISCORD_CHANNEL_ID", "")
|
||||
if not channel_id:
|
||||
return {"success": False, "error": "DISCORD_CHANNEL_ID가 설정되지 않았습니다."}
|
||||
|
||||
client = get_discord_client()
|
||||
|
||||
try:
|
||||
result = asyncio.run(
|
||||
client.send_message(channel_id=channel_id, content=message)
|
||||
)
|
||||
logger.info("Sent Discord message to channel %s", channel_id)
|
||||
return {"success": True, "message_id": result.get("id")}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to send Discord message")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@ -1,5 +1,40 @@
|
||||
"""Gitea issue/PR comment tool. Phase 2 implementation."""
|
||||
"""Gitea 이슈/PR 코멘트 작성 도구."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from agent.utils.gitea_client import get_gitea_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def gitea_comment(message: str, issue_number: int) -> dict:
|
||||
raise NotImplementedError("Phase 2")
|
||||
def _get_repo_info() -> tuple[str, str]:
|
||||
owner = os.environ.get("DEFAULT_REPO_OWNER", "quant")
|
||||
repo = os.environ.get("DEFAULT_REPO_NAME", "galaxis-po")
|
||||
return owner, repo
|
||||
|
||||
|
||||
def gitea_comment(message: str, issue_number: int) -> dict[str, Any]:
|
||||
if not issue_number or issue_number <= 0:
|
||||
return {"success": False, "error": "유효한 issue_number가 필요합니다."}
|
||||
|
||||
if not message.strip():
|
||||
return {"success": False, "error": "빈 메시지는 작성할 수 없습니다."}
|
||||
|
||||
owner, repo = _get_repo_info()
|
||||
client = get_gitea_client()
|
||||
|
||||
try:
|
||||
result = asyncio.run(
|
||||
client.create_issue_comment(
|
||||
owner=owner, repo=repo, issue_number=issue_number, body=message
|
||||
)
|
||||
)
|
||||
logger.info("Posted comment on %s/%s#%d", owner, repo, issue_number)
|
||||
return {"success": True, "comment_id": result.get("id")}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to post comment on %s/%s#%d", owner, repo, issue_number)
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@ -32,3 +32,41 @@ async def read_agents_md_in_sandbox(
|
||||
content = result.output or ""
|
||||
content = content.strip()
|
||||
return content or None
|
||||
|
||||
|
||||
async def read_repo_instructions(
|
||||
sandbox_backend: SandboxBackendProtocol,
|
||||
repo_dir: str,
|
||||
) -> str:
|
||||
"""AGENTS.md와 CLAUDE.md를 모두 읽어서 프롬프트에 주입할 문자열을 반환한다."""
|
||||
sections = []
|
||||
|
||||
agents_md = await _read_file_if_exists(sandbox_backend, f"{repo_dir}/AGENTS.md")
|
||||
if agents_md:
|
||||
sections.append(f"## Repository Agent Rules\n{agents_md}")
|
||||
|
||||
claude_md = await _read_file_if_exists(sandbox_backend, f"{repo_dir}/CLAUDE.md")
|
||||
if claude_md:
|
||||
sections.append(f"## Project Conventions\n{claude_md}")
|
||||
|
||||
return "\n\n".join(sections)
|
||||
|
||||
|
||||
async def _read_file_if_exists(
|
||||
sandbox_backend: SandboxBackendProtocol,
|
||||
file_path: str,
|
||||
) -> str | None:
|
||||
"""파일이 존재하면 내용을 읽고, 없으면 None을 반환한다."""
|
||||
safe_path = shlex.quote(file_path)
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
sandbox_backend.execute,
|
||||
f"test -f {safe_path} && cat {safe_path}",
|
||||
)
|
||||
if result.exit_code == 0 and result.output.strip():
|
||||
return result.output.strip()
|
||||
except Exception:
|
||||
logger.debug("Failed to read %s", file_path)
|
||||
return None
|
||||
|
||||
@ -1,9 +1,45 @@
|
||||
"""Discord bot integration. Phase 2 implementation."""
|
||||
"""Discord REST API 클라이언트."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||
|
||||
|
||||
class DiscordClient:
|
||||
async def send_message(self, channel_id: str, content: str) -> dict:
|
||||
raise NotImplementedError("Phase 2")
|
||||
def __init__(self, token: str):
|
||||
self.token = token
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=DISCORD_API_BASE,
|
||||
headers={"Authorization": f"Bot {self.token}"},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
async def send_thread_reply(self, channel_id, thread_id, content) -> dict:
|
||||
raise NotImplementedError("Phase 2")
|
||||
async def send_message(self, channel_id: str, content: str) -> dict:
|
||||
resp = await self._client.post(
|
||||
f"/channels/{channel_id}/messages",
|
||||
json={"content": content},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def send_thread_reply(self, channel_id: str, thread_id: str, content: str) -> dict:
|
||||
return await self.send_message(channel_id=thread_id, content=content)
|
||||
|
||||
async def close(self):
|
||||
await self._client.aclose()
|
||||
|
||||
|
||||
_client: DiscordClient | None = None
|
||||
|
||||
|
||||
def get_discord_client() -> DiscordClient:
|
||||
global _client
|
||||
if _client is None:
|
||||
_client = DiscordClient(token=os.environ.get("DISCORD_TOKEN", ""))
|
||||
return _client
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Gitea REST API v1 client. Phase 2 implementation."""
|
||||
|
||||
import os
|
||||
import httpx
|
||||
|
||||
|
||||
@ -13,22 +14,133 @@ class GiteaClient:
|
||||
)
|
||||
|
||||
async def create_pull_request(self, owner, repo, title, head, base, body) -> dict:
|
||||
raise NotImplementedError("Phase 2")
|
||||
"""Create a pull request.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
title: PR title
|
||||
head: Head branch name
|
||||
base: Base branch name
|
||||
body: PR body/description
|
||||
|
||||
Returns:
|
||||
dict: Created PR data (number, html_url, etc.)
|
||||
"""
|
||||
resp = await self._client.post(
|
||||
f"/repos/{owner}/{repo}/pulls",
|
||||
json={"title": title, "head": head, "base": base, "body": body},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def merge_pull_request(self, owner, repo, pr_number, merge_type="merge") -> dict:
|
||||
raise NotImplementedError("Phase 2")
|
||||
"""Merge a pull request.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
pr_number: PR number
|
||||
merge_type: Merge type ("merge", "rebase", "squash")
|
||||
|
||||
Returns:
|
||||
dict: Merge result
|
||||
"""
|
||||
resp = await self._client.post(
|
||||
f"/repos/{owner}/{repo}/pulls/{pr_number}/merge",
|
||||
json={"Do": merge_type},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def create_issue_comment(self, owner, repo, issue_number, body) -> dict:
|
||||
raise NotImplementedError("Phase 2")
|
||||
"""Create a comment on an issue or PR.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
issue_number: Issue or PR number
|
||||
body: Comment body
|
||||
|
||||
Returns:
|
||||
dict: Created comment data (id, body, etc.)
|
||||
"""
|
||||
resp = await self._client.post(
|
||||
f"/repos/{owner}/{repo}/issues/{issue_number}/comments",
|
||||
json={"body": body},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def get_issue(self, owner, repo, issue_number) -> dict:
|
||||
raise NotImplementedError("Phase 2")
|
||||
"""Get issue or PR details.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
issue_number: Issue or PR number
|
||||
|
||||
Returns:
|
||||
dict: Issue/PR data (number, title, body, etc.)
|
||||
"""
|
||||
resp = await self._client.get(f"/repos/{owner}/{repo}/issues/{issue_number}")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def get_issue_comments(self, owner, repo, issue_number) -> list:
|
||||
raise NotImplementedError("Phase 2")
|
||||
"""Get all comments on an issue or PR.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
issue_number: Issue or PR number
|
||||
|
||||
Returns:
|
||||
list: List of comment dicts
|
||||
"""
|
||||
resp = await self._client.get(
|
||||
f"/repos/{owner}/{repo}/issues/{issue_number}/comments"
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def create_branch(self, owner, repo, branch_name, old_branch) -> dict:
|
||||
raise NotImplementedError("Phase 2")
|
||||
"""Create a new branch.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
branch_name: New branch name
|
||||
old_branch: Source branch name
|
||||
|
||||
Returns:
|
||||
dict: Created branch data
|
||||
"""
|
||||
resp = await self._client.post(
|
||||
f"/repos/{owner}/{repo}/branches",
|
||||
json={"new_branch_name": branch_name, "old_branch_name": old_branch},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def close(self):
|
||||
await self._client.aclose()
|
||||
|
||||
|
||||
# Lazy singleton
|
||||
_client: GiteaClient | None = None
|
||||
|
||||
|
||||
def get_gitea_client() -> GiteaClient:
|
||||
"""Get or create the singleton GiteaClient instance.
|
||||
|
||||
Returns:
|
||||
GiteaClient: The singleton instance
|
||||
"""
|
||||
global _client
|
||||
if _client is None:
|
||||
_client = GiteaClient(
|
||||
base_url=os.environ.get("GITEA_URL", "http://gitea:3000"),
|
||||
token=os.environ.get("GITEA_TOKEN", ""),
|
||||
)
|
||||
return _client
|
||||
|
||||
40
agent/utils/path_validator.py
Normal file
40
agent/utils/path_validator.py
Normal file
@ -0,0 +1,40 @@
|
||||
"""파일 경로 접근 제어."""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def validate_paths(
|
||||
changed_paths: list[str],
|
||||
writable: list[str],
|
||||
blocked: list[str],
|
||||
) -> list[str]:
|
||||
"""변경된 파일 경로들을 검증한다.
|
||||
|
||||
Args:
|
||||
changed_paths: 변경된 파일 경로 목록.
|
||||
writable: 쓰기 허용 경로 접두사 목록.
|
||||
blocked: 차단 경로 목록 (정확 일치 또는 접두사).
|
||||
|
||||
Returns:
|
||||
위반 사항 메시지 목록. 빈 리스트이면 모든 경로 통과.
|
||||
"""
|
||||
errors: list[str] = []
|
||||
|
||||
for path in changed_paths:
|
||||
normalized = path.lstrip("./")
|
||||
|
||||
blocked_match = False
|
||||
for b in blocked:
|
||||
b_normalized = b.lstrip("./")
|
||||
if normalized == b_normalized or normalized.startswith(b_normalized + "/"):
|
||||
errors.append(f"BLOCKED: '{path}' is in blocked_paths ({b})")
|
||||
blocked_match = True
|
||||
break
|
||||
|
||||
if blocked_match:
|
||||
continue
|
||||
|
||||
is_writable = any(normalized.startswith(w.lstrip("./")) for w in writable)
|
||||
if not is_writable:
|
||||
errors.append(f"NOT_WRITABLE: '{path}' is not in writable_paths")
|
||||
|
||||
return errors
|
||||
@ -1,5 +1,35 @@
|
||||
import os
|
||||
|
||||
from agent.integrations.docker_sandbox import DockerSandbox
|
||||
|
||||
|
||||
def create_sandbox(sandbox_id: str | None = None) -> DockerSandbox:
|
||||
return DockerSandbox() # Phase 2 implementation
|
||||
"""Factory function for creating DockerSandbox instances.
|
||||
|
||||
Args:
|
||||
sandbox_id: Optional container ID to connect to existing container.
|
||||
If None, creates a new container.
|
||||
|
||||
Returns:
|
||||
DockerSandbox instance configured from environment variables.
|
||||
"""
|
||||
# Build environment variables for the container
|
||||
env = {}
|
||||
test_db_url = os.environ.get("TEST_DATABASE_URL", "")
|
||||
if test_db_url:
|
||||
env["DATABASE_URL"] = test_db_url
|
||||
|
||||
# Connect to existing container if ID provided
|
||||
if sandbox_id:
|
||||
return DockerSandbox(container_id=sandbox_id)
|
||||
|
||||
# Create new container with environment configuration
|
||||
return DockerSandbox(
|
||||
image=os.environ.get("SANDBOX_IMAGE", "galaxis-sandbox:latest"),
|
||||
network=os.environ.get("SANDBOX_NETWORK", "galaxis-net"),
|
||||
mem_limit=os.environ.get("SANDBOX_MEM_LIMIT", "4g"),
|
||||
cpu_count=int(os.environ.get("SANDBOX_CPU_COUNT", "2")),
|
||||
pids_limit=int(os.environ.get("SANDBOX_PIDS_LIMIT", "256")),
|
||||
environment=env,
|
||||
default_timeout=int(os.environ.get("SANDBOX_TIMEOUT", "300")),
|
||||
)
|
||||
|
||||
@ -34,6 +34,21 @@ services:
|
||||
depends_on:
|
||||
- docker-socket-proxy
|
||||
|
||||
langgraph-server:
|
||||
image: langchain/langgraph-api:3.11
|
||||
environment:
|
||||
- DATABASE_URI=sqlite:///data/langgraph.db
|
||||
volumes:
|
||||
- langgraph-data:/data
|
||||
networks:
|
||||
- galaxis-net
|
||||
restart: unless-stopped
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: "0.5"
|
||||
memory: 1G
|
||||
|
||||
networks:
|
||||
galaxis-net:
|
||||
external: true
|
||||
@ -42,3 +57,4 @@ volumes:
|
||||
uv-cache:
|
||||
npm-cache:
|
||||
agent-data:
|
||||
langgraph-data:
|
||||
|
||||
129
tests/test_commit_and_open_pr.py
Normal file
129
tests/test_commit_and_open_pr.py
Normal file
@ -0,0 +1,129 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
def test_pr_creation_after_push():
|
||||
"""push 성공 후 GiteaClient로 PR을 생성한다."""
|
||||
mock_gitea = MagicMock()
|
||||
mock_gitea.create_pull_request = AsyncMock(
|
||||
return_value={
|
||||
"number": 1,
|
||||
"html_url": "http://gitea:3000/quant/galaxis-po/pulls/1",
|
||||
}
|
||||
)
|
||||
mock_sandbox = MagicMock()
|
||||
mock_result = MagicMock(exit_code=0, output="")
|
||||
|
||||
with patch(
|
||||
"agent.tools.commit_and_open_pr.get_gitea_client", return_value=mock_gitea
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.get_sandbox_backend_sync",
|
||||
return_value=mock_sandbox,
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.get_config",
|
||||
return_value={
|
||||
"configurable": {
|
||||
"thread_id": "test-thread",
|
||||
"repo": {"owner": "quant", "name": "galaxis-po"},
|
||||
}
|
||||
},
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.resolve_repo_dir",
|
||||
return_value="/workspace/galaxis-po",
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_has_uncommitted_changes",
|
||||
return_value=True,
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_fetch_origin",
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_has_unpushed_commits",
|
||||
return_value=False,
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_current_branch",
|
||||
return_value="galaxis-agent/test-thread",
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_checkout_branch",
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_config_user",
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_add_all",
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_commit",
|
||||
return_value=mock_result,
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_push",
|
||||
return_value=mock_result,
|
||||
), patch.dict(
|
||||
"os.environ", {"GITEA_TOKEN": "test-token"},
|
||||
):
|
||||
from agent.tools.commit_and_open_pr import commit_and_open_pr
|
||||
result = commit_and_open_pr(title="feat: add feature", body="PR description")
|
||||
assert result["success"] is True
|
||||
assert "pulls/1" in result["pr_url"]
|
||||
mock_gitea.create_pull_request.assert_called_once()
|
||||
|
||||
|
||||
def test_pr_creation_converts_internal_to_external_url():
|
||||
"""PR URL이 내부 URL에서 외부 URL로 변환된다."""
|
||||
mock_gitea = MagicMock()
|
||||
mock_gitea.create_pull_request = AsyncMock(
|
||||
return_value={
|
||||
"number": 5,
|
||||
"html_url": "http://gitea:3000/quant/galaxis-po/pulls/5",
|
||||
}
|
||||
)
|
||||
mock_sandbox = MagicMock()
|
||||
mock_result = MagicMock(exit_code=0, output="")
|
||||
|
||||
with patch(
|
||||
"agent.tools.commit_and_open_pr.get_gitea_client", return_value=mock_gitea
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.get_sandbox_backend_sync",
|
||||
return_value=mock_sandbox,
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.get_config",
|
||||
return_value={
|
||||
"configurable": {
|
||||
"thread_id": "test-thread",
|
||||
"repo": {"owner": "quant", "name": "galaxis-po"},
|
||||
}
|
||||
},
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.resolve_repo_dir",
|
||||
return_value="/workspace/galaxis-po",
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_has_uncommitted_changes",
|
||||
return_value=True,
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_fetch_origin",
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_has_unpushed_commits",
|
||||
return_value=False,
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_current_branch",
|
||||
return_value="galaxis-agent/test-thread",
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_checkout_branch",
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_config_user",
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_add_all",
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_commit",
|
||||
return_value=mock_result,
|
||||
), patch(
|
||||
"agent.tools.commit_and_open_pr.git_push",
|
||||
return_value=mock_result,
|
||||
), patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"GITEA_TOKEN": "test-token",
|
||||
"GITEA_EXTERNAL_URL": "https://ayuriel.duckdns.org",
|
||||
"GITEA_URL": "http://gitea:3000",
|
||||
},
|
||||
):
|
||||
from agent.tools.commit_and_open_pr import commit_and_open_pr
|
||||
result = commit_and_open_pr(title="feat: test", body="body")
|
||||
assert result["success"] is True
|
||||
assert "ayuriel.duckdns.org" in result["pr_url"]
|
||||
assert "gitea:3000" not in result["pr_url"]
|
||||
48
tests/test_discord_reply.py
Normal file
48
tests/test_discord_reply.py
Normal file
@ -0,0 +1,48 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
def test_discord_reply_success():
|
||||
mock_client = MagicMock()
|
||||
mock_client.send_message = AsyncMock(
|
||||
return_value={"id": "123456", "content": "test message"}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agent.tools.discord_reply.get_discord_client", return_value=mock_client
|
||||
), patch.dict("os.environ", {"DISCORD_CHANNEL_ID": "999"}):
|
||||
from agent.tools.discord_reply import discord_reply
|
||||
result = discord_reply(message="test message")
|
||||
assert result["success"] is True
|
||||
mock_client.send_message.assert_called_once()
|
||||
|
||||
|
||||
def test_discord_reply_empty_message():
|
||||
from agent.tools.discord_reply import discord_reply
|
||||
result = discord_reply(message="")
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
def test_discord_reply_no_channel_configured():
|
||||
with patch.dict("os.environ", {"DISCORD_CHANNEL_ID": ""}, clear=False):
|
||||
from agent.tools.discord_reply import discord_reply
|
||||
result = discord_reply(message="test")
|
||||
assert result["success"] is False
|
||||
assert "DISCORD" in result.get("error", "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_client_send_message():
|
||||
import httpx
|
||||
mock_resp = MagicMock(spec=httpx.Response)
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {"id": "msg123", "content": "hello"}
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
|
||||
from agent.utils.discord_client import DiscordClient
|
||||
client = DiscordClient(token="test-token")
|
||||
client._client.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
result = await client.send_message(channel_id="999", content="hello")
|
||||
assert result["id"] == "msg123"
|
||||
client._client.post.assert_called_once()
|
||||
81
tests/test_docker_sandbox.py
Normal file
81
tests/test_docker_sandbox.py
Normal file
@ -0,0 +1,81 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_docker_client():
|
||||
client = MagicMock()
|
||||
container = MagicMock()
|
||||
container.id = "abc123def456"
|
||||
container.status = "running"
|
||||
client.containers.run.return_value = container
|
||||
client.containers.get.return_value = container
|
||||
return client, container
|
||||
|
||||
|
||||
def test_sandbox_create_new_container(mock_docker_client):
|
||||
client, container = mock_docker_client
|
||||
with patch("agent.integrations.docker_sandbox.docker.DockerClient", return_value=client):
|
||||
from agent.integrations.docker_sandbox import DockerSandbox
|
||||
sandbox = DockerSandbox()
|
||||
assert sandbox.id == "abc123def456"
|
||||
client.containers.run.assert_called_once()
|
||||
call_kwargs = client.containers.run.call_args.kwargs
|
||||
assert call_kwargs["detach"] is True
|
||||
assert call_kwargs["labels"] == {"galaxis-agent-sandbox": "true"}
|
||||
|
||||
|
||||
def test_sandbox_connect_existing_container(mock_docker_client):
|
||||
client, container = mock_docker_client
|
||||
with patch("agent.integrations.docker_sandbox.docker.DockerClient", return_value=client):
|
||||
from agent.integrations.docker_sandbox import DockerSandbox
|
||||
sandbox = DockerSandbox(container_id="abc123def456")
|
||||
assert sandbox.id == "abc123def456"
|
||||
client.containers.get.assert_called_once_with("abc123def456")
|
||||
client.containers.run.assert_not_called()
|
||||
|
||||
|
||||
def test_sandbox_execute_success(mock_docker_client):
|
||||
client, container = mock_docker_client
|
||||
container.exec_run.return_value = MagicMock(exit_code=0, output=(b"hello world\n", None))
|
||||
with patch("agent.integrations.docker_sandbox.docker.DockerClient", return_value=client):
|
||||
from agent.integrations.docker_sandbox import DockerSandbox
|
||||
sandbox = DockerSandbox(container_id="abc123def456")
|
||||
result = sandbox.execute("echo hello world")
|
||||
assert "hello world" in result.output
|
||||
assert result.exit_code == 0
|
||||
assert result.truncated is False
|
||||
|
||||
|
||||
def test_sandbox_execute_with_stderr(mock_docker_client):
|
||||
client, container = mock_docker_client
|
||||
container.exec_run.return_value = MagicMock(exit_code=1, output=(b"", b"error: not found\n"))
|
||||
with patch("agent.integrations.docker_sandbox.docker.DockerClient", return_value=client):
|
||||
from agent.integrations.docker_sandbox import DockerSandbox
|
||||
sandbox = DockerSandbox(container_id="abc123def456")
|
||||
result = sandbox.execute("cat missing.txt")
|
||||
assert "error: not found" in result.output
|
||||
assert result.exit_code == 1
|
||||
|
||||
|
||||
def test_sandbox_execute_with_timeout(mock_docker_client):
|
||||
client, container = mock_docker_client
|
||||
container.exec_run.return_value = MagicMock(exit_code=0, output=(b"ok\n", None))
|
||||
with patch("agent.integrations.docker_sandbox.docker.DockerClient", return_value=client):
|
||||
from agent.integrations.docker_sandbox import DockerSandbox
|
||||
sandbox = DockerSandbox(container_id="abc123def456")
|
||||
sandbox.execute("sleep 1", timeout=30)
|
||||
call_args = container.exec_run.call_args
|
||||
cmd = call_args.kwargs.get("cmd", call_args.args[0] if call_args.args else [])
|
||||
assert cmd[0] == "timeout"
|
||||
assert cmd[1] == "30"
|
||||
|
||||
|
||||
def test_sandbox_close(mock_docker_client):
|
||||
client, container = mock_docker_client
|
||||
with patch("agent.integrations.docker_sandbox.docker.DockerClient", return_value=client):
|
||||
from agent.integrations.docker_sandbox import DockerSandbox
|
||||
sandbox = DockerSandbox(container_id="abc123def456")
|
||||
sandbox.close()
|
||||
container.stop.assert_called_once()
|
||||
container.remove.assert_called_once()
|
||||
87
tests/test_gitea_client.py
Normal file
87
tests/test_gitea_client.py
Normal file
@ -0,0 +1,87 @@
|
||||
import pytest
|
||||
import httpx
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from agent.utils.gitea_client import GiteaClient
|
||||
|
||||
@pytest.fixture
|
||||
def gitea_client():
|
||||
return GiteaClient(base_url="http://gitea:3000", token="test-token")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response():
|
||||
def _make(status_code=200, json_data=None):
|
||||
resp = MagicMock(spec=httpx.Response)
|
||||
resp.status_code = status_code
|
||||
resp.json.return_value = json_data or {}
|
||||
resp.raise_for_status = MagicMock()
|
||||
if status_code >= 400:
|
||||
resp.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"error", request=MagicMock(), response=resp
|
||||
)
|
||||
return resp
|
||||
return _make
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_pull_request(gitea_client, mock_response):
|
||||
pr_data = {"number": 1, "html_url": "http://gitea:3000/quant/galaxis-po/pulls/1"}
|
||||
gitea_client._client.post = AsyncMock(return_value=mock_response(201, pr_data))
|
||||
result = await gitea_client.create_pull_request(
|
||||
owner="quant", repo="galaxis-po", title="feat: add feature",
|
||||
head="galaxis-agent/abc123", base="main", body="PR body",
|
||||
)
|
||||
assert result["number"] == 1
|
||||
call_url = gitea_client._client.post.call_args[0][0]
|
||||
assert "/repos/quant/galaxis-po/pulls" in call_url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_issue_comment(gitea_client, mock_response):
|
||||
comment_data = {"id": 42, "body": "작업을 시작합니다."}
|
||||
gitea_client._client.post = AsyncMock(return_value=mock_response(201, comment_data))
|
||||
result = await gitea_client.create_issue_comment(
|
||||
owner="quant", repo="galaxis-po", issue_number=1, body="작업을 시작합니다."
|
||||
)
|
||||
assert result["id"] == 42
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_issue(gitea_client, mock_response):
|
||||
issue_data = {"number": 1, "title": "Fix bug", "body": "Bug description"}
|
||||
gitea_client._client.get = AsyncMock(return_value=mock_response(200, issue_data))
|
||||
result = await gitea_client.get_issue(owner="quant", repo="galaxis-po", issue_number=1)
|
||||
assert result["title"] == "Fix bug"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_issue_comments(gitea_client, mock_response):
|
||||
comments = [{"id": 1, "body": "comment1"}, {"id": 2, "body": "comment2"}]
|
||||
gitea_client._client.get = AsyncMock(return_value=mock_response(200, comments))
|
||||
result = await gitea_client.get_issue_comments(owner="quant", repo="galaxis-po", issue_number=1)
|
||||
assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_pull_request(gitea_client, mock_response):
|
||||
gitea_client._client.post = AsyncMock(return_value=mock_response(200, {}))
|
||||
await gitea_client.merge_pull_request(owner="quant", repo="galaxis-po", pr_number=1, merge_type="merge")
|
||||
call_url = gitea_client._client.post.call_args[0][0]
|
||||
assert "/pulls/1/merge" in call_url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_branch(gitea_client, mock_response):
|
||||
branch_data = {"name": "galaxis-agent/abc123"}
|
||||
gitea_client._client.post = AsyncMock(return_value=mock_response(201, branch_data))
|
||||
result = await gitea_client.create_branch(
|
||||
owner="quant", repo="galaxis-po", branch_name="galaxis-agent/abc123", old_branch="main",
|
||||
)
|
||||
assert result["name"] == "galaxis-agent/abc123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_error_raises_exception(gitea_client, mock_response):
|
||||
gitea_client._client.post = AsyncMock(return_value=mock_response(404))
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await gitea_client.create_pull_request(
|
||||
owner="quant", repo="galaxis-po", title="t", head="h", base="b", body=""
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_close(gitea_client):
|
||||
gitea_client._client.aclose = AsyncMock()
|
||||
await gitea_client.close()
|
||||
gitea_client._client.aclose.assert_called_once()
|
||||
48
tests/test_gitea_comment.py
Normal file
48
tests/test_gitea_comment.py
Normal file
@ -0,0 +1,48 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
|
||||
def test_gitea_comment_success():
|
||||
mock_client = MagicMock()
|
||||
mock_client.create_issue_comment = AsyncMock(
|
||||
return_value={"id": 42, "body": "test comment"}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agent.tools.gitea_comment.get_gitea_client", return_value=mock_client
|
||||
), patch(
|
||||
"agent.tools.gitea_comment._get_repo_info",
|
||||
return_value=("quant", "galaxis-po"),
|
||||
):
|
||||
from agent.tools.gitea_comment import gitea_comment
|
||||
result = gitea_comment(message="test comment", issue_number=1)
|
||||
assert result["success"] is True
|
||||
assert result["comment_id"] == 42
|
||||
|
||||
|
||||
def test_gitea_comment_missing_issue_number():
|
||||
from agent.tools.gitea_comment import gitea_comment
|
||||
result = gitea_comment(message="test", issue_number=0)
|
||||
assert result["success"] is False
|
||||
assert "issue_number" in result["error"]
|
||||
|
||||
|
||||
def test_gitea_comment_api_error():
|
||||
import httpx
|
||||
mock_client = MagicMock()
|
||||
mock_client.create_issue_comment = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError(
|
||||
"404", request=MagicMock(), response=MagicMock(status_code=404)
|
||||
)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agent.tools.gitea_comment.get_gitea_client", return_value=mock_client
|
||||
), patch(
|
||||
"agent.tools.gitea_comment._get_repo_info",
|
||||
return_value=("quant", "galaxis-po"),
|
||||
):
|
||||
from agent.tools.gitea_comment import gitea_comment
|
||||
result = gitea_comment(message="test", issue_number=999)
|
||||
assert result["success"] is False
|
||||
assert "error" in result
|
||||
68
tests/test_path_validator.py
Normal file
68
tests/test_path_validator.py
Normal file
@ -0,0 +1,68 @@
|
||||
import pytest
|
||||
from agent.utils.path_validator import validate_paths
|
||||
|
||||
|
||||
def test_valid_backend_path():
|
||||
errors = validate_paths(
|
||||
["backend/app/services/rebalance.py"],
|
||||
writable=["backend/app/", "backend/tests/"],
|
||||
blocked=[".env", "quant.md"],
|
||||
)
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_valid_multiple_paths():
|
||||
errors = validate_paths(
|
||||
[
|
||||
"backend/app/api/signal.py",
|
||||
"backend/tests/unit/test_signal.py",
|
||||
"frontend/src/app/page.tsx",
|
||||
"docs/README.md",
|
||||
],
|
||||
writable=["backend/app/", "backend/tests/", "frontend/src/", "docs/"],
|
||||
blocked=[".env"],
|
||||
)
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_blocked_path_rejected():
|
||||
errors = validate_paths(
|
||||
[".env", "backend/app/main.py"],
|
||||
writable=["backend/app/"],
|
||||
blocked=[".env", "quant.md"],
|
||||
)
|
||||
assert len(errors) == 1
|
||||
assert ".env" in errors[0]
|
||||
|
||||
|
||||
def test_non_writable_path_rejected():
|
||||
errors = validate_paths(
|
||||
["docker-compose.prod.yml"],
|
||||
writable=["backend/app/"],
|
||||
blocked=[],
|
||||
)
|
||||
assert len(errors) == 1
|
||||
|
||||
|
||||
def test_quant_md_blocked():
|
||||
errors = validate_paths(
|
||||
["quant.md"],
|
||||
writable=["backend/app/", "docs/"],
|
||||
blocked=["quant.md"],
|
||||
)
|
||||
assert len(errors) == 1
|
||||
assert "quant.md" in errors[0]
|
||||
|
||||
|
||||
def test_empty_paths_ok():
|
||||
errors = validate_paths([], writable=["backend/app/"], blocked=[".env"])
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_alembic_versions_writable():
|
||||
errors = validate_paths(
|
||||
["backend/alembic/versions/001_add_table.py"],
|
||||
writable=["backend/alembic/versions/"],
|
||||
blocked=[],
|
||||
)
|
||||
assert errors == []
|
||||
58
tests/test_prompt_loading.py
Normal file
58
tests/test_prompt_loading.py
Normal file
@ -0,0 +1,58 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FakeExecuteResponse:
|
||||
output: str
|
||||
exit_code: int = 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sandbox():
|
||||
sandbox = MagicMock()
|
||||
return sandbox
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reads_agents_md_and_claude_md(mock_sandbox):
|
||||
def fake_execute(cmd, **kwargs):
|
||||
if "AGENTS.md" in cmd:
|
||||
return FakeExecuteResponse(output="# AGENTS.md\n## Rules\n- rule 1", exit_code=0)
|
||||
if "CLAUDE.md" in cmd:
|
||||
return FakeExecuteResponse(output="# CLAUDE.md\n## Overview\n- info 1", exit_code=0)
|
||||
return FakeExecuteResponse(output="", exit_code=1)
|
||||
|
||||
mock_sandbox.execute = MagicMock(side_effect=fake_execute)
|
||||
|
||||
from agent.utils.agents_md import read_repo_instructions
|
||||
result = await read_repo_instructions(mock_sandbox, "/workspace/galaxis-po")
|
||||
assert "Rules" in result
|
||||
assert "Overview" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agents_md_only(mock_sandbox):
|
||||
def fake_execute(cmd, **kwargs):
|
||||
if "AGENTS.md" in cmd:
|
||||
return FakeExecuteResponse(output="# AGENTS rules", exit_code=0)
|
||||
return FakeExecuteResponse(output="", exit_code=1)
|
||||
|
||||
mock_sandbox.execute = MagicMock(side_effect=fake_execute)
|
||||
|
||||
from agent.utils.agents_md import read_repo_instructions
|
||||
result = await read_repo_instructions(mock_sandbox, "/workspace/galaxis-po")
|
||||
assert result is not None
|
||||
assert "AGENTS" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_instruction_files(mock_sandbox):
|
||||
mock_sandbox.execute = MagicMock(
|
||||
return_value=FakeExecuteResponse(output="", exit_code=1)
|
||||
)
|
||||
|
||||
from agent.utils.agents_md import read_repo_instructions
|
||||
result = await read_repo_instructions(mock_sandbox, "/workspace/galaxis-po")
|
||||
assert result == ""
|
||||
Loading…
x
Reference in New Issue
Block a user