feat: add prompt loading pipeline and path validator
- Add read_repo_instructions() to read both AGENTS.md and CLAUDE.md - Add path_validator.validate_paths() for writable/blocked path enforcement - Add 10 passing tests (test_prompt_loading.py, test_path_validator.py) - All 71 tests pass
This commit is contained in:
parent
e8983d8534
commit
af7bd2cdc3
@ -32,3 +32,41 @@ async def read_agents_md_in_sandbox(
|
|||||||
content = result.output or ""
|
content = result.output or ""
|
||||||
content = content.strip()
|
content = content.strip()
|
||||||
return content or None
|
return content or None
|
||||||
|
|
||||||
|
|
||||||
|
async def read_repo_instructions(
|
||||||
|
sandbox_backend: SandboxBackendProtocol,
|
||||||
|
repo_dir: str,
|
||||||
|
) -> str:
|
||||||
|
"""AGENTS.md와 CLAUDE.md를 모두 읽어서 프롬프트에 주입할 문자열을 반환한다."""
|
||||||
|
sections = []
|
||||||
|
|
||||||
|
agents_md = await _read_file_if_exists(sandbox_backend, f"{repo_dir}/AGENTS.md")
|
||||||
|
if agents_md:
|
||||||
|
sections.append(f"## Repository Agent Rules\n{agents_md}")
|
||||||
|
|
||||||
|
claude_md = await _read_file_if_exists(sandbox_backend, f"{repo_dir}/CLAUDE.md")
|
||||||
|
if claude_md:
|
||||||
|
sections.append(f"## Project Conventions\n{claude_md}")
|
||||||
|
|
||||||
|
return "\n\n".join(sections)
|
||||||
|
|
||||||
|
|
||||||
|
async def _read_file_if_exists(
|
||||||
|
sandbox_backend: SandboxBackendProtocol,
|
||||||
|
file_path: str,
|
||||||
|
) -> str | None:
|
||||||
|
"""파일이 존재하면 내용을 읽고, 없으면 None을 반환한다."""
|
||||||
|
safe_path = shlex.quote(file_path)
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
result = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
sandbox_backend.execute,
|
||||||
|
f"test -f {safe_path} && cat {safe_path}",
|
||||||
|
)
|
||||||
|
if result.exit_code == 0 and result.output.strip():
|
||||||
|
return result.output.strip()
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Failed to read %s", file_path)
|
||||||
|
return None
|
||||||
|
|||||||
40
agent/utils/path_validator.py
Normal file
40
agent/utils/path_validator.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
"""파일 경로 접근 제어."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
def validate_paths(
|
||||||
|
changed_paths: list[str],
|
||||||
|
writable: list[str],
|
||||||
|
blocked: list[str],
|
||||||
|
) -> list[str]:
|
||||||
|
"""변경된 파일 경로들을 검증한다.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
changed_paths: 변경된 파일 경로 목록.
|
||||||
|
writable: 쓰기 허용 경로 접두사 목록.
|
||||||
|
blocked: 차단 경로 목록 (정확 일치 또는 접두사).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
위반 사항 메시지 목록. 빈 리스트이면 모든 경로 통과.
|
||||||
|
"""
|
||||||
|
errors: list[str] = []
|
||||||
|
|
||||||
|
for path in changed_paths:
|
||||||
|
normalized = path.lstrip("./")
|
||||||
|
|
||||||
|
blocked_match = False
|
||||||
|
for b in blocked:
|
||||||
|
b_normalized = b.lstrip("./")
|
||||||
|
if normalized == b_normalized or normalized.startswith(b_normalized + "/"):
|
||||||
|
errors.append(f"BLOCKED: '{path}' is in blocked_paths ({b})")
|
||||||
|
blocked_match = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if blocked_match:
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_writable = any(normalized.startswith(w.lstrip("./")) for w in writable)
|
||||||
|
if not is_writable:
|
||||||
|
errors.append(f"NOT_WRITABLE: '{path}' is not in writable_paths")
|
||||||
|
|
||||||
|
return errors
|
||||||
68
tests/test_path_validator.py
Normal file
68
tests/test_path_validator.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
import pytest
|
||||||
|
from agent.utils.path_validator import validate_paths
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_backend_path():
|
||||||
|
errors = validate_paths(
|
||||||
|
["backend/app/services/rebalance.py"],
|
||||||
|
writable=["backend/app/", "backend/tests/"],
|
||||||
|
blocked=[".env", "quant.md"],
|
||||||
|
)
|
||||||
|
assert errors == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_multiple_paths():
|
||||||
|
errors = validate_paths(
|
||||||
|
[
|
||||||
|
"backend/app/api/signal.py",
|
||||||
|
"backend/tests/unit/test_signal.py",
|
||||||
|
"frontend/src/app/page.tsx",
|
||||||
|
"docs/README.md",
|
||||||
|
],
|
||||||
|
writable=["backend/app/", "backend/tests/", "frontend/src/", "docs/"],
|
||||||
|
blocked=[".env"],
|
||||||
|
)
|
||||||
|
assert errors == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_blocked_path_rejected():
|
||||||
|
errors = validate_paths(
|
||||||
|
[".env", "backend/app/main.py"],
|
||||||
|
writable=["backend/app/"],
|
||||||
|
blocked=[".env", "quant.md"],
|
||||||
|
)
|
||||||
|
assert len(errors) == 1
|
||||||
|
assert ".env" in errors[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_writable_path_rejected():
|
||||||
|
errors = validate_paths(
|
||||||
|
["docker-compose.prod.yml"],
|
||||||
|
writable=["backend/app/"],
|
||||||
|
blocked=[],
|
||||||
|
)
|
||||||
|
assert len(errors) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_quant_md_blocked():
|
||||||
|
errors = validate_paths(
|
||||||
|
["quant.md"],
|
||||||
|
writable=["backend/app/", "docs/"],
|
||||||
|
blocked=["quant.md"],
|
||||||
|
)
|
||||||
|
assert len(errors) == 1
|
||||||
|
assert "quant.md" in errors[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_paths_ok():
|
||||||
|
errors = validate_paths([], writable=["backend/app/"], blocked=[".env"])
|
||||||
|
assert errors == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_alembic_versions_writable():
|
||||||
|
errors = validate_paths(
|
||||||
|
["backend/alembic/versions/001_add_table.py"],
|
||||||
|
writable=["backend/alembic/versions/"],
|
||||||
|
blocked=[],
|
||||||
|
)
|
||||||
|
assert errors == []
|
||||||
58
tests/test_prompt_loading.py
Normal file
58
tests/test_prompt_loading.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FakeExecuteResponse:
|
||||||
|
output: str
|
||||||
|
exit_code: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_sandbox():
|
||||||
|
sandbox = MagicMock()
|
||||||
|
return sandbox
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reads_agents_md_and_claude_md(mock_sandbox):
|
||||||
|
def fake_execute(cmd, **kwargs):
|
||||||
|
if "AGENTS.md" in cmd:
|
||||||
|
return FakeExecuteResponse(output="# AGENTS.md\n## Rules\n- rule 1", exit_code=0)
|
||||||
|
if "CLAUDE.md" in cmd:
|
||||||
|
return FakeExecuteResponse(output="# CLAUDE.md\n## Overview\n- info 1", exit_code=0)
|
||||||
|
return FakeExecuteResponse(output="", exit_code=1)
|
||||||
|
|
||||||
|
mock_sandbox.execute = MagicMock(side_effect=fake_execute)
|
||||||
|
|
||||||
|
from agent.utils.agents_md import read_repo_instructions
|
||||||
|
result = await read_repo_instructions(mock_sandbox, "/workspace/galaxis-po")
|
||||||
|
assert "Rules" in result
|
||||||
|
assert "Overview" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agents_md_only(mock_sandbox):
|
||||||
|
def fake_execute(cmd, **kwargs):
|
||||||
|
if "AGENTS.md" in cmd:
|
||||||
|
return FakeExecuteResponse(output="# AGENTS rules", exit_code=0)
|
||||||
|
return FakeExecuteResponse(output="", exit_code=1)
|
||||||
|
|
||||||
|
mock_sandbox.execute = MagicMock(side_effect=fake_execute)
|
||||||
|
|
||||||
|
from agent.utils.agents_md import read_repo_instructions
|
||||||
|
result = await read_repo_instructions(mock_sandbox, "/workspace/galaxis-po")
|
||||||
|
assert result is not None
|
||||||
|
assert "AGENTS" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_instruction_files(mock_sandbox):
|
||||||
|
mock_sandbox.execute = MagicMock(
|
||||||
|
return_value=FakeExecuteResponse(output="", exit_code=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
from agent.utils.agents_md import read_repo_instructions
|
||||||
|
result = await read_repo_instructions(mock_sandbox, "/workspace/galaxis-po")
|
||||||
|
assert result == ""
|
||||||
Loading…
x
Reference in New Issue
Block a user