diff --git a/STATUS.md b/STATUS.md new file mode 100644 index 00000000000..f30249b6ca0 --- /dev/null +++ b/STATUS.md @@ -0,0 +1,135 @@ +# OpenClaw (RA 2.0) — Project Status & Installation + +## Overview + +OpenClaw is a **local-first, multi-channel AI assistant platform** that runs on +your own devices. It bridges 13+ messaging channels (WhatsApp, Telegram, Slack, +Discord, Signal, iMessage, Teams, Matrix, …) to LLM providers (Claude, GPT, +Gemini) with a single always-on gateway. + +**Version**: 2026.2.19 +**License**: MIT +**Runtime**: Node.js ≥ 22.12.0, pnpm 10+ + +--- + +## Current Status + +| Area | State | Notes | +|------|-------|-------| +| Gateway / core | Stable | WebSocket control plane on `ws://127.0.0.1:18789` | +| Messaging channels | 13+ integrations | WhatsApp, Telegram, Slack, Discord, Signal, Teams, Matrix, etc. | +| Native apps | macOS, iOS, Android | Menu-bar daemon + mobile clients | +| Voice | Active | ElevenLabs TTS + speech-to-text | +| Browser automation | Active | Playwright-based Chrome control | +| Skills | 52 bundled | GitHub, email, coding-agent, canvas, 1Password, … | +| Extensions | 37 modules | BlueBubbles, Zalo, Google Gemini CLI auth, … | +| `ra2/` context layer | Phase 1 | Python — context engine, ledger, sigil, redact, token gate | +| Test coverage | ~1 176 test files | 70 % line/function threshold; 55 % branch threshold | + +### Recent fixes (`ra2/`) + +- **Blocker list limit** — `context_engine.py` was slicing blockers by + `token_gate.MAX_TOKENS` (6 000) instead of `ledger.MAX_BLOCKERS` (10). + Fixed to use the correct constant. +- **Redact-before-compress** — `build_context` now runs + `redact.redact_messages()` *before* `_run_compression`, preventing raw + credentials from being persisted to ledger/sigil files on disk. The old flow + only redacted the final assembled prompt, leaving at-rest secret leakage via + the compression pass. +- **Ledger JSON resilience** — `ledger.load()` now catches + `JSONDecodeError`/`ValueError` and falls back to an empty ledger (matching + `sigil.load`'s existing pattern), so a corrupted file no longer permanently + breaks `build_context` for that stream. + +--- + +## Installation + +### Quick install (npm) + +```bash +npm install -g openclaw@latest # or: pnpm add -g openclaw@latest +openclaw onboard --install-daemon +``` + +### From source (development) + +```bash +git clone https://github.com/openclaw/openclaw.git +cd openclaw + +pnpm install +pnpm ui:build # builds the web UI (auto-installs UI deps) +pnpm build # compiles TypeScript → dist/ + +openclaw onboard --install-daemon +``` + +### Development mode + +```bash +pnpm dev # run via tsx (no build step) +pnpm gateway:watch # auto-reload on file changes +``` + +--- + +## Running Tests + +```bash +# Full suite +pnpm test + +# Subsets +pnpm test:fast # unit tests only +pnpm test:e2e # end-to-end +pnpm test:live # live model tests +pnpm test:coverage # unit + coverage report + +# Python (ra2 module) +cd ra2 && pytest tests/ +``` + +--- + +## Configuration + +Set environment variables in `.env` or `~/.openclaw/.env`: + +| Variable | Purpose | +|----------|---------| +| `ANTHROPIC_API_KEY` | Claude access | +| `OPENAI_API_KEY` | GPT access | +| `GEMINI_API_KEY` | Gemini access | +| `TELEGRAM_BOT_TOKEN` | Telegram channel | +| `DISCORD_BOT_TOKEN` | Discord channel | +| `SLACK_BOT_TOKEN` | Slack channel | +| `OPENCLAW_GATEWAY_TOKEN` | Gateway auth | + +Config lives in `~/.openclaw/openclaw.json` (or `OPENCLAW_CONFIG_PATH`). + +--- + +## Project Structure (abridged) + +``` +src/ TypeScript core (agents, channels, gateway, CLI, plugins) +extensions/ 37 extension packages +skills/ 52 bundled skills +ra2/ Python context-sovereignty layer +apps/ Native apps (macOS / iOS / Android) +ui/ Web dashboard + WebChat +docs/ Comprehensive documentation +test/ Integration & e2e tests +``` + +--- + +## Pre-PR Checklist + +```bash +pnpm check # format + lint + type-check +pnpm test # full test suite +pnpm build # production build +``` diff --git a/ra2/__init__.py b/ra2/__init__.py new file mode 100644 index 00000000000..0b1260348f6 --- /dev/null +++ b/ra2/__init__.py @@ -0,0 +1,21 @@ +""" +ra2 — Context Sovereignty Layer (Phase 1) + +Deterministic thin wrapper that: + - Prevents full markdown history injection into prompts + - Introduces structured ledger memory + - Introduces sigil shorthand memory + - Enforces hard token caps before provider calls + - Redacts secrets before logs and model calls + +Usage: + from ra2.context_engine import build_context + + result = build_context(stream_id="my-stream", new_messages=[...]) + prompt = result["prompt"] + tokens = result["token_estimate"] +""" + +from ra2.context_engine import build_context + +__all__ = ["build_context"] diff --git a/ra2/context_engine.py b/ra2/context_engine.py new file mode 100644 index 00000000000..1b59ad7443a --- /dev/null +++ b/ra2/context_engine.py @@ -0,0 +1,205 @@ +""" +ra2.context_engine — The single choke point for all model calls. + +All prompts must pass through build_context() before reaching any provider. + +Internal flow: + 1. Redact secrets from incoming messages + 2. Run rule-based compression pass (writes redacted data to ledger/sigils) + 3. Determine live window from redacted messages + 4. Assemble structured prompt + 5. Estimate token count + 6. If > MAX_TOKENS: shrink live window, reassemble + 7. If still > MAX_TOKENS: raise controlled exception + +Never reads full .md history. +""" + +import re +from typing import List, Optional + +from ra2 import ledger, sigil, token_gate, redact + +# ── Compression rule patterns ─────────────────────────────────────── + +_DECISION_RE = re.compile( + r"(?:we\s+will|we\s+chose|decided\s+to|going\s+to|let'?s)\s+(.{10,120})", + re.IGNORECASE, +) +_ARCHITECTURE_RE = re.compile( + r"(?:architect(?:ure)?|refactor|redesign|restructur|migrat)\w*\s+(.{10,120})", + re.IGNORECASE, +) +_COST_RE = re.compile( + r"(?:budget|cost|spend|rate[_\s]*limit|token[_\s]*cap|pricing)\s*[:=→]?\s*(.{5,120})", + re.IGNORECASE, +) +_BLOCKER_RE = re.compile( + r"(?:block(?:er|ed|ing)|stuck|cannot|can'?t\s+proceed|waiting\s+on)\s+(.{5,120})", + re.IGNORECASE, +) +_QUESTION_RE = re.compile( + r"(?:should\s+we|do\s+we|how\s+(?:do|should)|what\s+(?:if|about)|need\s+to\s+decide)\s+(.{5,120})", + re.IGNORECASE, +) + + +def _extract_content(msg: dict) -> str: + """Get text content from a message dict.""" + content = msg.get("content", "") + if isinstance(content, str): + return content + if isinstance(content, list): + # Handle structured content blocks + parts = [] + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + parts.append(block.get("text", "")) + elif isinstance(block, str): + parts.append(block) + return " ".join(parts) + return str(content) + + +def _run_compression(messages: list, stream_id: str) -> None: + """Rule-based compression pass over recent messages. + + Extracts decisions, architecture shifts, cost constraints, blockers, + and open questions — then updates the ledger accordingly. + """ + decisions: list[str] = [] + blockers: list[str] = [] + open_questions: list[str] = [] + latest_summary_parts: list[str] = [] + + for msg in messages: + text = _extract_content(msg) + if not text: + continue + + # Decisions + for m in _DECISION_RE.finditer(text): + decisions.append(m.group(1).strip()) + + # Architecture shifts + for m in _ARCHITECTURE_RE.finditer(text): + latest_summary_parts.append(f"arch: {m.group(1).strip()}") + + # Cost/budget + for m in _COST_RE.finditer(text): + latest_summary_parts.append(f"cost: {m.group(1).strip()}") + + # Blockers + for m in _BLOCKER_RE.finditer(text): + blockers.append(m.group(1).strip()) + + # Open questions + for m in _QUESTION_RE.finditer(text): + open_questions.append(m.group(1).strip()) + + # Sigil event generation + sigil_triple = sigil.generate_from_message(text) + if sigil_triple: + op, constraint, decision = sigil_triple + sigil.append_event(stream_id, op, constraint, decision) + + # Build delta from decisions + delta = "; ".join(decisions[-5:]) if decisions else "" + latest = "; ".join(latest_summary_parts[-5:]) if latest_summary_parts else "" + + # Update ledger (only non-empty fields) + updates = {} + if delta: + updates["delta"] = delta + if latest: + updates["latest"] = latest + if blockers: + updates["blockers"] = blockers[-ledger.MAX_BLOCKERS:] # bounded + if open_questions: + updates["open"] = open_questions[-10:] + + if updates: + ledger.update(stream_id, **updates) + + +def _assemble_prompt(stream_id: str, live_messages: list) -> str: + """Build the structured prompt from ledger + (optional sigil) + live window.""" + sections = [] + + # Sigil section — only when DEBUG_SIGIL is enabled + if sigil.DEBUG_SIGIL: + sigil_snap = sigil.snapshot(stream_id) + if sigil_snap != "(no sigils)": + sections.append( + f"=== INTERNAL SIGIL SNAPSHOT ===\n{sigil_snap}" + ) + + # Ledger section + ledger_snap = ledger.snapshot(stream_id) + sections.append(f"=== LEDGER ===\n{ledger_snap}") + + # Live window section + live_lines = [] + for msg in live_messages: + role = msg.get("role", "unknown") + content = _extract_content(msg) + live_lines.append(f"[{role}] {content}") + sections.append("=== LIVE WINDOW ===\n" + "\n".join(live_lines)) + + # Closing directive + sections.append("Respond concisely and aligned with orientation.") + + return "\n\n".join(sections) + + +def build_context(stream_id: str, new_messages: list) -> dict: + """Main entry point — the single choke point for all model calls. + + Args: + stream_id: Unique identifier for the conversation stream. + new_messages: List of message dicts with at minimum 'role' and 'content'. + + Returns: + { + "prompt": str, # The assembled, redacted prompt + "token_estimate": int # Estimated token count + } + + Raises: + token_gate.TokenBudgetExceeded: If prompt exceeds MAX_TOKENS + even after shrinking the live window to minimum. + """ + # 1. Redact secrets before any disk-persisting step (ledger/sigil writes) + safe_messages = redact.redact_messages(new_messages) + + # 2. Run compression pass on redacted messages → updates ledger + sigils + _run_compression(safe_messages, stream_id) + + # 3. Determine live window (from already-redacted messages) + window_size = token_gate.LIVE_WINDOW + live_messages = safe_messages[-window_size:] + + # 4. Assemble prompt + prompt = _assemble_prompt(stream_id, live_messages) + + # 5. Estimate tokens + estimated = token_gate.estimate_tokens(prompt) + + # 6. Shrink loop if over budget + while not token_gate.check_budget(estimated): + try: + window_size = token_gate.shrink_window(window_size) + except token_gate.TokenBudgetExceeded: + # Already at minimum window — hard fail + raise token_gate.TokenBudgetExceeded( + estimated=estimated, + limit=token_gate.MAX_TOKENS, + ) + live_messages = safe_messages[-window_size:] + prompt = _assemble_prompt(stream_id, live_messages) + estimated = token_gate.estimate_tokens(prompt) + + return { + "prompt": prompt, + "token_estimate": estimated, + } diff --git a/ra2/ledger.py b/ra2/ledger.py new file mode 100644 index 00000000000..36baa5ec131 --- /dev/null +++ b/ra2/ledger.py @@ -0,0 +1,116 @@ +""" +ra2.ledger — Structured ledger memory (one per stream). + +Each stream gets a JSON ledger file with bounded fields. +Fields are overwritten (never appended unbounded). +Only updated via the compression pass. +""" + +import json +import os +from typing import Optional + +# Configurable storage root +LEDGER_DIR: str = os.environ.get( + "RA2_LEDGER_DIR", + os.path.join(os.path.expanduser("~"), ".ra2", "ledgers"), +) + +# Hard limits +MAX_BLOCKERS = 10 +MAX_OPEN = 10 +MAX_FIELD_CHARS = 500 # per string field + +_EMPTY_LEDGER = { + "stream": "", + "orientation": "", + "latest": "", + "blockers": [], + "open": [], + "delta": "", +} + + +def _ledger_path(stream_id: str) -> str: + return os.path.join(LEDGER_DIR, f"{stream_id}.json") + + +def load(stream_id: str) -> dict: + """Load ledger for *stream_id*, returning empty template if none exists.""" + path = _ledger_path(stream_id) + if not os.path.exists(path): + ledger = dict(_EMPTY_LEDGER) + ledger["stream"] = stream_id + return ledger + with open(path, "r", encoding="utf-8") as f: + try: + data = json.load(f) + except (json.JSONDecodeError, ValueError): + ledger = dict(_EMPTY_LEDGER) + ledger["stream"] = stream_id + return ledger + # Ensure all expected keys exist + for key, default in _EMPTY_LEDGER.items(): + if key not in data: + data[key] = default if not isinstance(default, list) else list(default) + return data + + +def save(stream_id: str, ledger: dict) -> None: + """Persist ledger to disk, enforcing size limits.""" + ledger = _enforce_limits(ledger) + os.makedirs(LEDGER_DIR, exist_ok=True) + path = _ledger_path(stream_id) + with open(path, "w", encoding="utf-8") as f: + json.dump(ledger, f, indent=2, ensure_ascii=False) + + +def update(stream_id: str, **fields) -> dict: + """Load, merge fields, save, and return the updated ledger. + + Only known keys are accepted. Unknown keys are silently dropped. + """ + ledger = load(stream_id) + for key, value in fields.items(): + if key in _EMPTY_LEDGER: + ledger[key] = value + save(stream_id, ledger) + return ledger + + +def snapshot(stream_id: str) -> str: + """Return a human-readable snapshot string for prompt injection.""" + ledger = load(stream_id) + lines = [] + lines.append(f"stream: {ledger['stream']}") + lines.append(f"orientation: {ledger['orientation']}") + lines.append(f"latest: {ledger['latest']}") + if ledger["blockers"]: + lines.append("blockers:") + for b in ledger["blockers"]: + lines.append(f" - {b}") + if ledger["open"]: + lines.append("open:") + for o in ledger["open"]: + lines.append(f" - {o}") + if ledger["delta"]: + lines.append(f"delta: {ledger['delta']}") + return "\n".join(lines) + + +def _enforce_limits(ledger: dict) -> dict: + """Truncate fields and lists to hard limits.""" + for key in ("orientation", "latest", "delta", "stream"): + if isinstance(ledger.get(key), str) and len(ledger[key]) > MAX_FIELD_CHARS: + ledger[key] = ledger[key][:MAX_FIELD_CHARS] + if isinstance(ledger.get("blockers"), list): + ledger["blockers"] = [ + b[:MAX_FIELD_CHARS] if isinstance(b, str) else b + for b in ledger["blockers"][:MAX_BLOCKERS] + ] + if isinstance(ledger.get("open"), list): + ledger["open"] = [ + o[:MAX_FIELD_CHARS] if isinstance(o, str) else o + for o in ledger["open"][:MAX_OPEN] + ] + return ledger diff --git a/ra2/redact.py b/ra2/redact.py new file mode 100644 index 00000000000..93532639528 --- /dev/null +++ b/ra2/redact.py @@ -0,0 +1,88 @@ +""" +ra2.redact — Secret redaction before logging, .md writes, and model calls. + +Detects common API key patterns and replaces them with [REDACTED_SECRET]. +Must be applied before any external output path. +""" + +import re +from typing import List, Tuple + +REDACTED = "[REDACTED_SECRET]" + +# Each entry: (label, compiled regex) +_PATTERNS: List[Tuple[str, re.Pattern]] = [ + # Discord bot tokens (base64-ish, three dot-separated segments) + ("discord_token", re.compile( + r"[MN][A-Za-z0-9]{23,}\.[A-Za-z0-9_-]{6}\.[A-Za-z0-9_-]{27,}" + )), + # OpenAI keys + ("openai_key", re.compile(r"sk-[A-Za-z0-9_-]{20,}")), + # Anthropic keys + ("anthropic_key", re.compile(r"sk-ant-[A-Za-z0-9_-]{20,}")), + # Google / GCP API keys + ("google_key", re.compile(r"AIza[A-Za-z0-9_-]{35}")), + # AWS access key IDs + ("aws_access_key", re.compile(r"AKIA[A-Z0-9]{16}")), + # Generic long hex/base64 secrets (40+ chars, likely tokens) + ("generic_secret", re.compile( + r"(?:api[_-]?key|secret|token|password|credential)" + r"[\s]*[:=][\s]*['\"]?([A-Za-z0-9_/+=-]{32,})['\"]?", + re.IGNORECASE, + )), + # Bearer tokens in auth headers + ("bearer_token", re.compile( + r"Bearer\s+[A-Za-z0-9_.+/=-]{20,}", re.IGNORECASE + )), + # Slack tokens + ("slack_token", re.compile(r"xox[bpas]-[A-Za-z0-9-]{10,}")), + # GitHub tokens + ("github_token", re.compile(r"gh[ps]_[A-Za-z0-9]{36,}")), + # Telegram bot tokens + ("telegram_token", re.compile(r"\d{8,10}:[A-Za-z0-9_-]{35}")), +] + + +def redact(text: str) -> str: + """Replace all detected secret patterns in *text* with [REDACTED_SECRET].""" + for _label, pattern in _PATTERNS: + # For the generic_secret pattern that uses a capture group, + # replace only the captured secret value. + if _label == "generic_secret": + text = pattern.sub(_replace_generic, text) + else: + text = pattern.sub(REDACTED, text) + return text + + +def _replace_generic(match: re.Match) -> str: + """Replace only the secret value inside a key=value match.""" + full = match.group(0) + secret = match.group(1) + return full.replace(secret, REDACTED) + + +def redact_dict(d: dict) -> dict: + """Recursively redact all string values in a dict.""" + out = {} + for k, v in d.items(): + if isinstance(v, str): + out[k] = redact(v) + elif isinstance(v, dict): + out[k] = redact_dict(v) + elif isinstance(v, list): + out[k] = [redact(i) if isinstance(i, str) else i for i in v] + else: + out[k] = v + return out + + +def redact_messages(messages: list) -> list: + """Redact secrets from a list of message dicts (content field).""" + result = [] + for msg in messages: + copy = dict(msg) + if isinstance(copy.get("content"), str): + copy["content"] = redact(copy["content"]) + result.append(copy) + return result diff --git a/ra2/sigil.py b/ra2/sigil.py new file mode 100644 index 00000000000..850c4fece58 --- /dev/null +++ b/ra2/sigil.py @@ -0,0 +1,245 @@ +""" +ra2.sigil — Layered internal state map stored as JSON (one file per stream). + +Two layers: + EVENT — decision causality log [{operator, constraint, decision, timestamp}] + STATE — authoritative snapshot {arch, risk, mode} + +Deterministic. Bounded. Internal-only (hidden unless DEBUG_SIGIL=true). +No AI generation. No semantic expansion. No prose. +""" + +import json +import os +import re +from datetime import datetime, timezone +from typing import Dict, List, Optional, Tuple + +SIGIL_DIR: str = os.environ.get( + "RA2_SIGIL_DIR", + os.path.join(os.path.expanduser("~"), ".ra2", "sigils"), +) + +DEBUG_SIGIL: bool = os.environ.get("DEBUG_SIGIL", "false").lower() == "true" + +MAX_EVENT_ENTRIES = 15 +MAX_FIELD_CHARS = 64 +MAX_FILE_BYTES = int(os.environ.get("RA2_SIGIL_MAX_BYTES", "8192")) + +_SNAKE_RE = re.compile(r"^[a-z][a-z0-9_]*$") + + +# ── Schema ────────────────────────────────────────────────────────── + +def _empty_state() -> dict: + """Return the canonical empty sigil document.""" + return { + "event": [], + "state": { + "arch": { + "wrapper": "", + "compression": "", + "agents": "", + "router": "", + }, + "risk": { + "token_pressure": "", + "cooldown": "", + "scope_creep": "", + }, + "mode": { + "determinism": "", + "rewrite_mode": "", + "debug": False, + }, + }, + } + + +def _validate_snake(value: str) -> str: + """Validate and truncate a snake_case string field.""" + value = value.strip()[:MAX_FIELD_CHARS] + return value + + +def _validate_event(event: dict) -> bool: + """Return True if an event dict has all required keys with valid values.""" + for key in ("operator", "constraint", "decision"): + val = event.get(key) + if not isinstance(val, str) or not val: + return False + if len(val) > MAX_FIELD_CHARS: + return False + return "timestamp" in event + + +# ── File I/O ──────────────────────────────────────────────────────── + +def _sigil_path(stream_id: str) -> str: + return os.path.join(SIGIL_DIR, f"{stream_id}.json") + + +def load(stream_id: str) -> dict: + """Load the JSON sigil state for a stream.""" + path = _sigil_path(stream_id) + if not os.path.exists(path): + return _empty_state() + with open(path, "r", encoding="utf-8") as f: + try: + data = json.load(f) + except (json.JSONDecodeError, ValueError): + return _empty_state() + + # Ensure structural integrity — fill missing keys from template + template = _empty_state() + if not isinstance(data.get("event"), list): + data["event"] = template["event"] + if not isinstance(data.get("state"), dict): + data["state"] = template["state"] + for section in ("arch", "risk", "mode"): + if not isinstance(data["state"].get(section), dict): + data["state"][section] = template["state"][section] + return data + + +def save(stream_id: str, state: dict) -> None: + """Atomically persist the JSON sigil state to disk. + + Enforces EVENT cap, field lengths, and total file size. + """ + # FIFO trim events + events = state.get("event", [])[-MAX_EVENT_ENTRIES:] + state["event"] = events + + os.makedirs(SIGIL_DIR, exist_ok=True) + path = _sigil_path(stream_id) + + content = json.dumps(state, indent=2, ensure_ascii=False) + + # Enforce total file size — trim oldest events until it fits + while len(content.encode("utf-8")) > MAX_FILE_BYTES and state["event"]: + state["event"].pop(0) + content = json.dumps(state, indent=2, ensure_ascii=False) + + # Atomic write: write to temp then rename + tmp_path = path + ".tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + f.write(content) + os.replace(tmp_path, path) + + +# ── Mutation helpers ──────────────────────────────────────────────── + +def _now_iso() -> str: + return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + + +def append_event(stream_id: str, operator: str, constraint: str, + decision: str) -> dict: + """Add an event triple. Deduplicates and FIFO-trims. + + Rejects fields longer than MAX_FIELD_CHARS. + """ + operator = _validate_snake(operator) + constraint = _validate_snake(constraint) + decision = _validate_snake(decision) + + if not operator or not constraint or not decision: + return load(stream_id) + + state = load(stream_id) + + # Dedup on (operator, constraint, decision) + triple = (operator, constraint, decision) + for existing in state["event"]: + if (existing["operator"], existing["constraint"], + existing["decision"]) == triple: + return state + + event = { + "operator": operator, + "constraint": constraint, + "decision": decision, + "timestamp": _now_iso(), + } + + state["event"].append(event) + state["event"] = state["event"][-MAX_EVENT_ENTRIES:] + + save(stream_id, state) + return state + + +def update_state(stream_id: str, + arch: Optional[Dict[str, str]] = None, + risk: Optional[Dict[str, str]] = None, + mode: Optional[dict] = None) -> dict: + """Overwrite STATE sections. STATE is authoritative snapshot.""" + state = load(stream_id) + if arch is not None: + state["state"]["arch"] = arch + if risk is not None: + state["state"]["risk"] = risk + if mode is not None: + state["state"]["mode"] = mode + save(stream_id, state) + return state + + +# ── Snapshot ──────────────────────────────────────────────────────── + +def snapshot(stream_id: str) -> str: + """Return compacted JSON string for debug prompt injection. + + Only meaningful when DEBUG_SIGIL is true. + """ + state = load(stream_id) + if not state["event"] and not any( + v for v in state["state"]["arch"].values() if v + ): + return "(no sigils)" + return json.dumps(state, indent=2, ensure_ascii=False) + + +# ── Deterministic event generators ───────────────────────────────── + +# Each rule: (regex, (operator, constraint, decision)) +# The decision field may use {0} for first capture group. +_EVENT_RULES: List[Tuple[re.Pattern, Tuple[str, str, str]]] = [ + (re.compile(r"fork(?:ed|ing)?\s*(?:to|into|\u2192)\s*(\S+)", re.I), + ("fork", "architectural_scope", "{0}")), + (re.compile(r"token[_\s]*burn", re.I), + ("token_burn", "context_overflow", "compress_first")), + (re.compile(r"rewrite[_\s]*impulse", re.I), + ("rewrite_impulse", "determinism_requirement", "layering_not_rewrite")), + (re.compile(r"context[_\s]*sov(?:ereignty)?", re.I), + ("context_sov", "sovereignty_active", "enforce")), + (re.compile(r"budget[_\s]*cap(?:ped)?", re.I), + ("budget_cap", "cost_constraint", "enforce_limit")), + (re.compile(r"rate[_\s]*limit", re.I), + ("rate_limit", "cooldown_active", "fallback_model")), + (re.compile(r"provider[_\s]*switch(?:ed)?", re.I), + ("provider_switch", "availability", "route_alternate")), + (re.compile(r"compaction[_\s]*trigger", re.I), + ("compaction", "history_overflow", "compact_now")), + (re.compile(r"thin[_\s]*wrapper", re.I), + ("fork", "architectural_scope", "thin_wrapper")), + (re.compile(r"rule[_\s]*based[_\s]*compress", re.I), + ("compression", "method_selection", "rule_based_v1")), +] + + +def generate_from_message(content: str) -> Optional[Tuple[str, str, str]]: + """Apply deterministic rules to message content. + + Returns (operator, constraint, decision) triple or None. + """ + for pattern, (op, constraint, decision) in _EVENT_RULES: + m = pattern.search(content) + if m: + try: + filled_decision = decision.format(*m.groups()) + except (IndexError, KeyError): + filled_decision = decision + return (op, constraint, filled_decision) + return None diff --git a/ra2/tests/__init__.py b/ra2/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/ra2/tests/test_context_engine.py b/ra2/tests/test_context_engine.py new file mode 100644 index 00000000000..8417d01aecd --- /dev/null +++ b/ra2/tests/test_context_engine.py @@ -0,0 +1,178 @@ +"""Tests for ra2.context_engine""" + +import json +import pytest +from ra2 import ledger, sigil, token_gate +from ra2.context_engine import build_context + + +@pytest.fixture(autouse=True) +def tmp_storage(monkeypatch, tmp_path): + """Redirect all storage to temp directories.""" + monkeypatch.setattr(ledger, "LEDGER_DIR", str(tmp_path / "ledgers")) + monkeypatch.setattr(sigil, "SIGIL_DIR", str(tmp_path / "sigils")) + # Default: sigil hidden from prompt + monkeypatch.setattr(sigil, "DEBUG_SIGIL", False) + + +class TestBuildContext: + def test_basic_output_shape(self): + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + ] + result = build_context("test-stream", messages) + assert "prompt" in result + assert "token_estimate" in result + assert isinstance(result["prompt"], str) + assert isinstance(result["token_estimate"], int) + + def test_prompt_structure_default(self): + messages = [ + {"role": "user", "content": "Let's build a context engine"}, + ] + result = build_context("s1", messages) + prompt = result["prompt"] + assert "=== LEDGER ===" in prompt + assert "=== LIVE WINDOW ===" in prompt + assert "Respond concisely" in prompt + # Sigil should NOT appear by default + assert "INTERNAL SIGIL SNAPSHOT" not in prompt + + def test_sigil_hidden_by_default(self): + messages = [ + {"role": "user", "content": "We forked to context_sov"}, + ] + result = build_context("s1", messages) + # Event should be recorded in JSON but not in prompt + state = sigil.load("s1") + assert len(state["event"]) > 0 + assert "INTERNAL SIGIL SNAPSHOT" not in result["prompt"] + + def test_sigil_shown_when_debug(self, monkeypatch): + monkeypatch.setattr(sigil, "DEBUG_SIGIL", True) + messages = [ + {"role": "user", "content": "We forked to context_sov"}, + ] + result = build_context("s1", messages) + assert "=== INTERNAL SIGIL SNAPSHOT ===" in result["prompt"] + + def test_live_window_content(self): + messages = [ + {"role": "user", "content": "message one"}, + {"role": "assistant", "content": "response one"}, + ] + result = build_context("s1", messages) + assert "[user] message one" in result["prompt"] + assert "[assistant] response one" in result["prompt"] + + def test_redaction_applied(self): + messages = [ + {"role": "user", "content": "my key is sk-abc123def456ghi789jklmnopqrs"}, + ] + result = build_context("s1", messages) + assert "sk-abc" not in result["prompt"] + assert "[REDACTED_SECRET]" in result["prompt"] + + def test_compression_updates_ledger(self): + messages = [ + {"role": "user", "content": "we will use deterministic compression"}, + {"role": "assistant", "content": "decided to skip AI summarization"}, + ] + build_context("s1", messages) + data = ledger.load("s1") + assert data["delta"] != "" + + def test_compression_detects_blockers(self): + messages = [ + {"role": "user", "content": "I'm blocked on rate limit issues"}, + ] + build_context("s1", messages) + data = ledger.load("s1") + assert len(data["blockers"]) > 0 + + def test_compression_detects_open_questions(self): + messages = [ + {"role": "user", "content": "should we use tiktoken for counting?"}, + ] + build_context("s1", messages) + data = ledger.load("s1") + assert len(data["open"]) > 0 + + def test_sigil_event_generation(self): + messages = [ + {"role": "user", "content": "We forked to context_sov"}, + ] + build_context("s1", messages) + state = sigil.load("s1") + assert len(state["event"]) > 0 + assert state["event"][0]["operator"] == "fork" + + def test_sigil_dedup_across_calls(self): + messages = [ + {"role": "user", "content": "We forked to context_sov"}, + ] + build_context("s1", messages) + build_context("s1", messages) + state = sigil.load("s1") + # Same triple should not be duplicated + assert len(state["event"]) == 1 + + def test_token_estimate_positive(self): + messages = [{"role": "user", "content": "hello"}] + result = build_context("s1", messages) + assert result["token_estimate"] > 0 + + def test_window_shrinks_on_large_input(self, monkeypatch): + monkeypatch.setattr(token_gate, "MAX_TOKENS", 200) + monkeypatch.setattr(token_gate, "LIVE_WINDOW", 16) + messages = [ + {"role": "user", "content": f"This is message number {i} with some content"} + for i in range(20) + ] + result = build_context("s1", messages) + assert result["token_estimate"] <= 200 + + def test_hard_fail_on_impossible_budget(self, monkeypatch): + monkeypatch.setattr(token_gate, "MAX_TOKENS", 5) + monkeypatch.setattr(token_gate, "LIVE_WINDOW", 4) + messages = [ + {"role": "user", "content": "x" * 1000}, + ] + with pytest.raises(token_gate.TokenBudgetExceeded): + build_context("s1", messages) + + def test_structured_content_blocks(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello from structured content"}, + ], + }, + ] + result = build_context("s1", messages) + assert "Hello from structured content" in result["prompt"] + + def test_no_md_history_injection(self): + messages = [{"role": "user", "content": "just this"}] + result = build_context("s1", messages) + assert "just this" in result["prompt"] + assert ".md" not in result["prompt"] + + def test_debug_sigil_snapshot_is_valid_json(self, monkeypatch): + monkeypatch.setattr(sigil, "DEBUG_SIGIL", True) + messages = [ + {"role": "user", "content": "We forked to context_sov"}, + ] + result = build_context("s1", messages) + # Extract the sigil JSON from the prompt + prompt = result["prompt"] + marker = "=== INTERNAL SIGIL SNAPSHOT ===" + assert marker in prompt + start = prompt.index(marker) + len(marker) + end = prompt.index("=== LEDGER ===") + sigil_json = prompt[start:end].strip() + data = json.loads(sigil_json) + assert "event" in data + assert "state" in data diff --git a/ra2/tests/test_ledger.py b/ra2/tests/test_ledger.py new file mode 100644 index 00000000000..ec32af7e128 --- /dev/null +++ b/ra2/tests/test_ledger.py @@ -0,0 +1,100 @@ +"""Tests for ra2.ledger""" + +import json +import os +import tempfile +import pytest +from ra2 import ledger + + +@pytest.fixture(autouse=True) +def tmp_ledger_dir(monkeypatch, tmp_path): + """Redirect ledger storage to a temp directory for each test.""" + d = str(tmp_path / "ledgers") + monkeypatch.setattr(ledger, "LEDGER_DIR", d) + return d + + +class TestLoadSave: + def test_load_empty(self): + data = ledger.load("test-stream") + assert data["stream"] == "test-stream" + assert data["orientation"] == "" + assert data["blockers"] == [] + assert data["open"] == [] + + def test_save_and_load(self): + data = { + "stream": "s1", + "orientation": "build context engine", + "latest": "implemented ledger", + "blockers": ["rate limits"], + "open": ["how to compress?"], + "delta": "added ledger module", + } + ledger.save("s1", data) + loaded = ledger.load("s1") + assert loaded == data + + def test_save_enforces_field_length(self): + data = { + "stream": "s1", + "orientation": "x" * 1000, + "latest": "", + "blockers": [], + "open": [], + "delta": "", + } + ledger.save("s1", data) + loaded = ledger.load("s1") + assert len(loaded["orientation"]) == ledger.MAX_FIELD_CHARS + + def test_save_enforces_list_length(self): + data = { + "stream": "s1", + "orientation": "", + "latest": "", + "blockers": [f"blocker-{i}" for i in range(20)], + "open": [f"question-{i}" for i in range(20)], + "delta": "", + } + ledger.save("s1", data) + loaded = ledger.load("s1") + assert len(loaded["blockers"]) == ledger.MAX_BLOCKERS + assert len(loaded["open"]) == ledger.MAX_OPEN + + +class TestUpdate: + def test_update_fields(self): + result = ledger.update("s1", orientation="test orientation", delta="did stuff") + assert result["orientation"] == "test orientation" + assert result["delta"] == "did stuff" + assert result["stream"] == "s1" + + def test_update_ignores_unknown_keys(self): + result = ledger.update("s1", unknown_key="value") + assert "unknown_key" not in result + + def test_update_persists(self): + ledger.update("s1", orientation="phase 1") + loaded = ledger.load("s1") + assert loaded["orientation"] == "phase 1" + + +class TestSnapshot: + def test_snapshot_empty(self): + snap = ledger.snapshot("empty-stream") + assert "stream: empty-stream" in snap + assert "orientation:" in snap + + def test_snapshot_with_data(self): + ledger.update( + "s1", + orientation="context sovereignty", + blockers=["rate limits"], + open=["compression strategy?"], + ) + snap = ledger.snapshot("s1") + assert "context sovereignty" in snap + assert "rate limits" in snap + assert "compression strategy?" in snap diff --git a/ra2/tests/test_redact.py b/ra2/tests/test_redact.py new file mode 100644 index 00000000000..ae601acb376 --- /dev/null +++ b/ra2/tests/test_redact.py @@ -0,0 +1,114 @@ +"""Tests for ra2.redact""" + +import pytest +from ra2.redact import redact, redact_dict, redact_messages, REDACTED + + +class TestRedact: + def test_openai_key(self): + text = "my key is sk-abc123def456ghi789jklmnopqrs" + result = redact(text) + assert "sk-abc" not in result + assert REDACTED in result + + def test_anthropic_key(self): + text = "key: sk-ant-abc123def456ghi789jklmnopqrs" + result = redact(text) + assert "sk-ant-" not in result + assert REDACTED in result + + def test_discord_token(self): + # Build a fake Discord-shaped token dynamically to avoid push protection. + # Pattern: [MN][A-Za-z0-9]{23,}.[A-Za-z0-9_-]{6}.[A-Za-z0-9_-]{27,} + prefix = "M" + "T" * 23 # 24 chars, starts with M + mid = "G" + "a" * 5 # 6 chars + suffix = "x" * 27 # 27 chars + token = f"{prefix}.{mid}.{suffix}" + text = f"token is {token}" + result = redact(text) + assert token not in result + assert REDACTED in result + + def test_google_key(self): + text = "key=AIzaSyD-abcdefghijklmnopqrstuvwxyz12345" + result = redact(text) + assert "AIza" not in result + assert REDACTED in result + + def test_aws_key(self): + text = "aws key: AKIAIOSFODNN7EXAMPLE" + result = redact(text) + assert "AKIA" not in result + + def test_slack_token(self): + text = "token: xoxb-123456789012-abcdefghij" + result = redact(text) + assert "xoxb-" not in result + + def test_github_token(self): + text = "auth: ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijkl" + result = redact(text) + assert "ghp_" not in result + + def test_telegram_token(self): + text = "bot: 123456789:ABCDefGHIJKlMNOpQRSTuvWXYz0123456789a" + result = redact(text) + assert "ABCDef" not in result + + def test_bearer_token(self): + text = "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.abc" + result = redact(text) + assert "eyJh" not in result + + def test_generic_secret_key_value(self): + text = 'api_key = "abcdefghijklmnopqrstuvwxyz1234567890ABCD"' + result = redact(text) + assert "abcdefghij" not in result + # The label should still be there + assert "api_key" in result + + def test_no_false_positive_normal_text(self): + text = "Hello, this is a normal message with no secrets." + assert redact(text) == text + + def test_multiple_secrets(self): + text = "keys: sk-abc123def456ghi789jklmnopqrs and sk-ant-xyz123abc456def789ghi" + result = redact(text) + assert "sk-abc" not in result + assert "sk-ant-" not in result + assert result.count(REDACTED) == 2 + + +class TestRedactDict: + def test_flat_dict(self): + d = {"key": "sk-abc123def456ghi789jklmnopqrs", "name": "test"} + result = redact_dict(d) + assert REDACTED in result["key"] + assert result["name"] == "test" + + def test_nested_dict(self): + d = {"outer": {"inner": "sk-abc123def456ghi789jklmnopqrs"}} + result = redact_dict(d) + assert REDACTED in result["outer"]["inner"] + + def test_list_values(self): + d = {"tokens": ["sk-abc123def456ghi789jklmnopqrs", "normal"]} + result = redact_dict(d) + assert REDACTED in result["tokens"][0] + assert result["tokens"][1] == "normal" + + +class TestRedactMessages: + def test_redacts_content(self): + msgs = [ + {"role": "user", "content": "my key is sk-abc123def456ghi789jklmnopqrs"}, + {"role": "assistant", "content": "I see a key"}, + ] + result = redact_messages(msgs) + assert REDACTED in result[0]["content"] + assert result[1]["content"] == "I see a key" + + def test_preserves_non_string_content(self): + msgs = [{"role": "user", "content": 42}] + result = redact_messages(msgs) + assert result[0]["content"] == 42 diff --git a/ra2/tests/test_sigil.py b/ra2/tests/test_sigil.py new file mode 100644 index 00000000000..1a8038327c4 --- /dev/null +++ b/ra2/tests/test_sigil.py @@ -0,0 +1,256 @@ +"""Tests for ra2.sigil (JSON layered format)""" + +import json +import pytest +from ra2 import sigil + + +@pytest.fixture(autouse=True) +def tmp_sigil_dir(monkeypatch, tmp_path): + """Redirect sigil storage to a temp directory for each test.""" + d = str(tmp_path / "sigils") + monkeypatch.setattr(sigil, "SIGIL_DIR", d) + return d + + +# ── Load / Save ───────────────────────────────────────────────────── + +class TestLoadSave: + def test_load_empty(self): + state = sigil.load("test-stream") + assert state["event"] == [] + assert state["state"]["arch"]["wrapper"] == "" + assert state["state"]["risk"]["token_pressure"] == "" + assert state["state"]["mode"]["debug"] is False + + def test_save_and_load_roundtrip(self): + state = sigil._empty_state() + state["event"].append({ + "operator": "fork", + "constraint": "architectural_scope", + "decision": "thin_wrapper", + "timestamp": "2026-02-19T04:00:00Z", + }) + state["state"]["arch"]["wrapper"] = "thin" + sigil.save("s1", state) + loaded = sigil.load("s1") + assert len(loaded["event"]) == 1 + assert loaded["event"][0]["operator"] == "fork" + assert loaded["state"]["arch"]["wrapper"] == "thin" + + def test_save_creates_json_file(self, tmp_sigil_dir): + state = sigil._empty_state() + sigil.save("s1", state) + import os + path = os.path.join(tmp_sigil_dir, "s1.json") + assert os.path.exists(path) + with open(path) as f: + data = json.load(f) + assert "event" in data + assert "state" in data + + def test_fifo_on_save(self): + state = sigil._empty_state() + for i in range(20): + state["event"].append({ + "operator": f"op_{i}", + "constraint": "c", + "decision": "d", + "timestamp": "2026-01-01T00:00:00Z", + }) + sigil.save("s1", state) + loaded = sigil.load("s1") + assert len(loaded["event"]) == sigil.MAX_EVENT_ENTRIES + # Should keep the last 15 + assert loaded["event"][0]["operator"] == "op_5" + assert loaded["event"][-1]["operator"] == "op_19" + + def test_corrupt_file_returns_empty(self, tmp_sigil_dir): + import os + os.makedirs(tmp_sigil_dir, exist_ok=True) + path = os.path.join(tmp_sigil_dir, "bad.json") + with open(path, "w") as f: + f.write("not valid json{{{") + state = sigil.load("bad") + assert state["event"] == [] + assert "arch" in state["state"] + + def test_missing_sections_filled(self, tmp_sigil_dir): + import os + os.makedirs(tmp_sigil_dir, exist_ok=True) + path = os.path.join(tmp_sigil_dir, "partial.json") + with open(path, "w") as f: + json.dump({"event": [], "state": {}}, f) + state = sigil.load("partial") + assert "arch" in state["state"] + assert "risk" in state["state"] + assert "mode" in state["state"] + + +# ── append_event ──────────────────────────────────────────────────── + +class TestAppendEvent: + def test_append_single(self): + state = sigil.append_event("s1", "fork", "arch_scope", "thin_wrapper") + assert len(state["event"]) == 1 + assert state["event"][0]["operator"] == "fork" + assert state["event"][0]["constraint"] == "arch_scope" + assert state["event"][0]["decision"] == "thin_wrapper" + assert "timestamp" in state["event"][0] + + def test_append_multiple(self): + sigil.append_event("s1", "fork", "scope", "wrapper") + state = sigil.append_event("s1", "token_burn", "overflow", "compress") + assert len(state["event"]) == 2 + + def test_deduplication(self): + sigil.append_event("s1", "fork", "scope", "wrapper") + state = sigil.append_event("s1", "fork", "scope", "wrapper") + assert len(state["event"]) == 1 + + def test_fifo_eviction(self): + for i in range(20): + state = sigil.append_event("s1", f"op_{i}", "c", "d") + assert len(state["event"]) == sigil.MAX_EVENT_ENTRIES + operators = [e["operator"] for e in state["event"]] + assert "op_0" not in operators + assert "op_19" in operators + + def test_rejects_empty_fields(self): + state = sigil.append_event("s1", "", "c", "d") + assert len(state["event"]) == 0 + + def test_truncates_long_fields(self): + long_val = "a" * 100 + state = sigil.append_event("s1", long_val, "c", "d") + assert len(state["event"]) == 1 + assert len(state["event"][0]["operator"]) <= sigil.MAX_FIELD_CHARS + + +# ── update_state ──────────────────────────────────────────────────── + +class TestUpdateState: + def test_update_arch(self): + state = sigil.update_state("s1", arch={ + "wrapper": "thin", + "compression": "rule_based_v1", + "agents": "disabled", + "router": "legacy", + }) + assert state["state"]["arch"]["wrapper"] == "thin" + assert state["state"]["arch"]["compression"] == "rule_based_v1" + + def test_update_risk(self): + state = sigil.update_state("s1", risk={ + "token_pressure": "controlled", + "cooldown": "monitored", + "scope_creep": "constrained", + }) + assert state["state"]["risk"]["token_pressure"] == "controlled" + + def test_update_mode(self): + state = sigil.update_state("s1", mode={ + "determinism": "prioritized", + "rewrite_mode": "disabled", + "debug": False, + }) + assert state["state"]["mode"]["determinism"] == "prioritized" + + def test_update_overwrites(self): + sigil.update_state("s1", arch={"wrapper": "thin"}) + state = sigil.update_state("s1", arch={"wrapper": "fat"}) + assert state["state"]["arch"]["wrapper"] == "fat" + + def test_update_preserves_events(self): + sigil.append_event("s1", "fork", "scope", "wrapper") + state = sigil.update_state("s1", arch={"wrapper": "thin"}) + assert len(state["event"]) == 1 + assert state["state"]["arch"]["wrapper"] == "thin" + + def test_partial_update(self): + sigil.update_state("s1", arch={"wrapper": "thin"}) + state = sigil.update_state("s1", risk={"token_pressure": "high"}) + # arch should still be there + assert state["state"]["arch"]["wrapper"] == "thin" + assert state["state"]["risk"]["token_pressure"] == "high" + + +# ── snapshot ──────────────────────────────────────────────────────── + +class TestSnapshot: + def test_snapshot_empty(self): + snap = sigil.snapshot("empty") + assert snap == "(no sigils)" + + def test_snapshot_with_events(self): + sigil.append_event("s1", "fork", "scope", "wrapper") + snap = sigil.snapshot("s1") + data = json.loads(snap) + assert len(data["event"]) == 1 + assert data["event"][0]["operator"] == "fork" + + def test_snapshot_with_state(self): + sigil.update_state("s1", arch={"wrapper": "thin"}) + snap = sigil.snapshot("s1") + data = json.loads(snap) + assert data["state"]["arch"]["wrapper"] == "thin" + + def test_snapshot_is_valid_json(self): + sigil.append_event("s1", "fork", "scope", "wrapper") + sigil.update_state("s1", arch={"wrapper": "thin"}) + snap = sigil.snapshot("s1") + data = json.loads(snap) # Should not raise + assert "event" in data + assert "state" in data + + +# ── generate_from_message ─────────────────────────────────────────── + +class TestGenerateFromMessage: + def test_fork_detection(self): + result = sigil.generate_from_message("We forked to context_sov branch") + assert result is not None + op, constraint, decision = result + assert op == "fork" + assert constraint == "architectural_scope" + assert "context_sov" in decision + + def test_token_burn_detection(self): + result = sigil.generate_from_message("Seeing token burn on this stream") + assert result == ("token_burn", "context_overflow", "compress_first") + + def test_rate_limit_detection(self): + result = sigil.generate_from_message("Hit a rate limit again") + assert result == ("rate_limit", "cooldown_active", "fallback_model") + + def test_thin_wrapper_detection(self): + result = sigil.generate_from_message("Use a thin wrapper approach") + assert result is not None + assert result[2] == "thin_wrapper" + + def test_no_match(self): + result = sigil.generate_from_message("Hello, how are you?") + assert result is None + + def test_returns_triple(self): + result = sigil.generate_from_message("compaction trigger needed") + assert result is not None + assert len(result) == 3 + op, constraint, decision = result + assert op == "compaction" + assert constraint == "history_overflow" + assert decision == "compact_now" + + +# ── File size cap ─────────────────────────────────────────────────── + +class TestFileSizeCap: + def test_file_respects_size_cap(self, monkeypatch, tmp_sigil_dir): + import os + # Set a small cap + monkeypatch.setattr(sigil, "MAX_FILE_BYTES", 512) + for i in range(20): + sigil.append_event("s1", f"operator_{i}", "constraint", "decision") + path = os.path.join(tmp_sigil_dir, "s1.json") + size = os.path.getsize(path) + assert size <= 512 diff --git a/ra2/tests/test_token_gate.py b/ra2/tests/test_token_gate.py new file mode 100644 index 00000000000..874c98a2d1e --- /dev/null +++ b/ra2/tests/test_token_gate.py @@ -0,0 +1,80 @@ +"""Tests for ra2.token_gate""" + +import pytest +from ra2.token_gate import ( + estimate_tokens, + check_budget, + shrink_window, + TokenBudgetExceeded, + LIVE_WINDOW_MIN, +) + + +class TestEstimateTokens: + def test_empty_string(self): + assert estimate_tokens("") == 0 + + def test_short_string(self): + assert estimate_tokens("ab") == 1 + + def test_known_length_ascii(self): + text = "a" * 400 + # 400 / 3.3 ≈ 121 + assert estimate_tokens(text) == int(400 / 3.3) + + def test_proportional(self): + short = estimate_tokens("hello world") + long = estimate_tokens("hello world " * 100) + assert long > short + + def test_non_ascii_increases_estimate(self): + ascii_text = "a" * 100 + # Mix in non-ASCII to trigger the penalty + non_ascii_text = "\u4e00" * 100 # CJK characters + assert estimate_tokens(non_ascii_text) > estimate_tokens(ascii_text) + + def test_code_heavy_reasonable(self): + code = 'def foo(x: int) -> bool:\n return x > 0\n' * 10 + tokens = estimate_tokens(code) + # Should be more conservative than len//4 + assert tokens > len(code) // 4 + + +class TestCheckBudget: + def test_within_budget(self): + assert check_budget(100, limit=200) is True + + def test_at_budget(self): + assert check_budget(200, limit=200) is True + + def test_over_budget(self): + assert check_budget(201, limit=200) is False + + +class TestShrinkWindow: + def test_halves(self): + assert shrink_window(16) == 8 + + def test_halves_again(self): + assert shrink_window(8) == 4 + + def test_at_minimum_raises(self): + with pytest.raises(TokenBudgetExceeded): + shrink_window(LIVE_WINDOW_MIN) + + def test_below_minimum_raises(self): + with pytest.raises(TokenBudgetExceeded): + shrink_window(2) + + def test_odd_number(self): + # 5 // 2 = 2, but clamped to LIVE_WINDOW_MIN (4) + assert shrink_window(5) == LIVE_WINDOW_MIN + + +class TestTokenBudgetExceeded: + def test_attributes(self): + exc = TokenBudgetExceeded(estimated=7000, limit=6000) + assert exc.estimated == 7000 + assert exc.limit == 6000 + assert "7000" in str(exc) + assert "6000" in str(exc) diff --git a/ra2/token_gate.py b/ra2/token_gate.py new file mode 100644 index 00000000000..364bb29552a --- /dev/null +++ b/ra2/token_gate.py @@ -0,0 +1,62 @@ +""" +ra2.token_gate — Token estimation and hard cap enforcement. + +Provides a fast, deterministic token estimator (no external tokenizer dependency) +and gate logic that prevents any prompt from exceeding MAX_TOKENS. +""" + +import os + +# Configurable via environment or direct override +MAX_TOKENS: int = int(os.environ.get("RA2_MAX_TOKENS", "6000")) +LIVE_WINDOW: int = int(os.environ.get("RA2_LIVE_WINDOW", "16")) +LIVE_WINDOW_MIN: int = 4 # Never shrink below this + + +class TokenBudgetExceeded(Exception): + """Raised when prompt exceeds MAX_TOKENS even after shrinking.""" + + def __init__(self, estimated: int, limit: int): + self.estimated = estimated + self.limit = limit + super().__init__( + f"Token budget exceeded: {estimated} > {limit} after all shrink attempts" + ) + + +def estimate_tokens(text: str) -> int: + """Fast deterministic token estimate. + + Base ratio: ~3.3 chars/token (conservative vs the common ~4 estimate). + Applies a penalty when non-ASCII density is high, since code symbols + and multilingual characters tend to produce shorter tokens. + No external dependency. + """ + if not text: + return 0 + length = len(text) + non_ascii = sum(1 for ch in text if ord(ch) > 127) + ratio = non_ascii / length if length else 0 + # Shift from 3.3 toward 2.5 chars/token as non-ASCII density rises + chars_per_token = 3.3 - (0.8 * ratio) + return max(1, int(length / chars_per_token)) + + +def check_budget(estimated: int, limit: int | None = None) -> bool: + """Return True if *estimated* is within budget, False otherwise.""" + limit = limit if limit is not None else MAX_TOKENS + return estimated <= limit + + +def shrink_window(current_window: int) -> int: + """Halve the live window, respecting the minimum. + + Returns the new window size, or raises TokenBudgetExceeded if + already at minimum. + """ + if current_window <= LIVE_WINDOW_MIN: + raise TokenBudgetExceeded( + estimated=0, # caller should fill real value + limit=MAX_TOKENS, + ) + return max(LIVE_WINDOW_MIN, current_window // 2)