feat(skills): add local-db and rag-store community skills

Add two new community skills for OpenClaw:

**local-db** — SQLite database management CLI

- Create, query, migrate, and backup local SQLite databases
- Destructive SQL guard (`--allow-destructive` flag required for DROP/DELETE/TRUNCATE)
- Robust SQL parser that respects string literals, quoted identifiers (double-quote, bracket, backtick), and comments
- Comment-obfuscation-proof destructive check via `_strip_sql_for_pattern_check()`
- CREATE TRIGGER support: BEGIN...END blocks preserve inner semicolons
- Transaction BEGIN (BEGIN/BEGIN TRANSACTION) correctly split as separate statements
- Migration tracking with `_migrations` table
- Database name validation (alphanumeric, `_`, `-`)

**rag-store** — ChromaDB-based RAG vector store CLI

- Ingest documents (txt, md, pdf, epub) into vector collections
- Semantic search with configurable top-k results
- Source tracking with `source_path` metadata for unambiguous file identification
- Stale chunk cleanup on re-ingestion
- Collection name validation
- Guards against negative overlap and trailing duplicate chunks

Tests: 48 total (36 local-db + 12 rag-store)

> This PR was developed with AI assistance (GitHub Copilot).
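A minimal way to reproduce the test run locally (a sketch; the rag-store tests import only the pure helpers `chunk_text` and `validate_collection_name`, so ChromaDB is not required for them):

```bash
python3 skills/local-db/scripts/test_localdb.py -v
python3 skills/rag-store/scripts/test_ragstore.py -v
```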
skills/local-db/SKILL.md (new file, 143 lines)
@@ -0,0 +1,143 @@
---
name: local-db
description: "Manage local SQLite databases: create DBs, tables with relationships, execute queries, apply safe migrations (never loses data), and backup. Use when the user asks to store, query, or manage structured data locally. NOT for: unstructured document search (use rag-store), cloud databases, or key-value config storage."
metadata: { "openclaw": { "emoji": "🗄️", "requires": { "bins": ["python3"] } } }
---

# Local Database (SQLite)

Manage local SQLite databases via the bundled `localdb.py` script. Databases are stored in `~/.openclaw/databases/`.

## When to use

✅ **USE this skill when:**

- User asks to "create a database", "store data", "create a table"
- User wants to query structured local data
- User needs relationships between data (foreign keys)
- User asks to track records, patients, inventory, contacts, etc.
- User wants a persistent local data store

## When NOT to use

❌ **DON'T use this skill when:**

- User wants to search document content semantically → use rag-store
- User needs cloud/remote database access → use appropriate cloud tools
- User wants key-value or config storage → use files/JSON directly
- Data is unstructured (free text, PDFs) → use rag-store

## Key safety rules

1. **NEVER drop tables** without explicit user confirmation
2. **ALWAYS use migrations** for schema changes (not raw ALTER/DROP)
3. **ALWAYS backup** before destructive operations
4. Foreign keys are enforced (`PRAGMA foreign_keys = ON`)
5. All schema changes are recorded in the `_migrations` table
6. Destructive SQL (DROP, DELETE, TRUNCATE) requires the `--allow-destructive` flag
7. **ALWAYS ask the user for confirmation** before using `--allow-destructive`

## Commands

### Create a database

```bash
python3 {baseDir}/scripts/localdb.py create mydata
```

### List databases

```bash
python3 {baseDir}/scripts/localdb.py list
```

### Apply a migration (safe schema change)

Always use migrations to create or alter tables:

```bash
# Create tables
python3 {baseDir}/scripts/localdb.py migrate mydata -d "Create patients table" -s "CREATE TABLE IF NOT EXISTS patients (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, phone TEXT, active INTEGER DEFAULT 1, created_at TEXT DEFAULT (datetime('now')))"

# Add a column (safe — IF NOT EXISTS is not available for ALTER, so check first)
python3 {baseDir}/scripts/localdb.py migrate mydata -d "Add email to patients" -s "ALTER TABLE patients ADD COLUMN email TEXT"

# Create a related table with a foreign key
python3 {baseDir}/scripts/localdb.py migrate mydata -d "Create appointments table" -s "CREATE TABLE IF NOT EXISTS appointments (id INTEGER PRIMARY KEY AUTOINCREMENT, patient_id INTEGER NOT NULL REFERENCES patients(id), date TEXT NOT NULL, notes TEXT, created_at TEXT DEFAULT (datetime('now')))"
```

### Execute queries

```bash
# Insert data
python3 {baseDir}/scripts/localdb.py exec mydata "INSERT INTO patients (name, phone) VALUES ('João Silva', '+351912345678')"

# Select data
python3 {baseDir}/scripts/localdb.py exec mydata "SELECT * FROM patients WHERE active = 1"

# JSON output (great for piping)
python3 {baseDir}/scripts/localdb.py exec mydata "SELECT * FROM patients" --json

# Join queries
python3 {baseDir}/scripts/localdb.py exec mydata "SELECT p.name, a.date, a.notes FROM patients p JOIN appointments a ON p.id = a.patient_id ORDER BY a.date DESC"

# Multi-statement (semicolon-separated)
python3 {baseDir}/scripts/localdb.py exec mydata "INSERT INTO patients (name) VALUES ('Alice'); INSERT INTO patients (name) VALUES ('Bob')"

# Destructive operations (requires explicit flag + user confirmation)
python3 {baseDir}/scripts/localdb.py exec mydata "DELETE FROM patients WHERE active = 0" --allow-destructive
```

### View schema

```bash
# All tables
python3 {baseDir}/scripts/localdb.py schema mydata

# Specific table
python3 {baseDir}/scripts/localdb.py schema mydata -t patients
```

### List tables

```bash
python3 {baseDir}/scripts/localdb.py tables mydata
```

### View migration history

```bash
python3 {baseDir}/scripts/localdb.py migrations mydata
```

### Backup a database

```bash
python3 {baseDir}/scripts/localdb.py backup mydata
```
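There is no dedicated restore command; since backups are plain SQLite files written next to the database (see `cmd_backup` in `localdb.py`, which names them `<db>_backup_<timestamp>.db`), a restore is plausibly just a copy back over the live file. A sketch, with an illustrative backup filename:

```bash
# Restore by copying a backup over the live database (filename is illustrative)
cp ~/.openclaw/databases/mydata_backup_20260101_120000.db ~/.openclaw/databases/mydata.db
```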

## Patterns

### Safe column addition

Before adding a column, check whether it already exists:

```bash
python3 {baseDir}/scripts/localdb.py exec mydata "PRAGMA table_info(patients)" --json
```

Then migrate only if the column doesn't exist, as in the sketch below.
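A minimal check-then-migrate sketch; it assumes `jq` is installed (not something this skill requires) and reuses the `email` column from the migration example above:

```bash
# Add the column only when PRAGMA table_info does not already list it.
# jq -e exits non-zero when the expression evaluates to false.
if ! python3 {baseDir}/scripts/localdb.py exec mydata "PRAGMA table_info(patients)" --json \
     | jq -e 'any(.[]; .name == "email")' > /dev/null; then
  python3 {baseDir}/scripts/localdb.py migrate mydata -d "Add email to patients" \
    -s "ALTER TABLE patients ADD COLUMN email TEXT"
fi
```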

### Bulk inserts

Use multiple INSERT statements separated by semicolons in a single exec call:

```bash
python3 {baseDir}/scripts/localdb.py exec mydata "INSERT INTO patients (name) VALUES ('Alice'); INSERT INTO patients (name) VALUES ('Bob')"
```

### Data export

```bash
python3 {baseDir}/scripts/localdb.py exec mydata "SELECT * FROM patients" --json > patients_export.json
```
skills/local-db/scripts/localdb.py (new file, 545 lines)
@@ -0,0 +1,545 @@
```python
#!/usr/bin/env python3
"""
localdb — CLI for managing local SQLite databases.
OpenClaw skill: lets the agent create databases, tables, relationships,
query and safely migrate data without losing existing records.
"""
import argparse
import json
import os
import re
import sqlite3
import sys
from datetime import datetime

DB_DIR = os.path.expanduser("~/.openclaw/databases")


def get_db_path(name):
    os.makedirs(DB_DIR, exist_ok=True)
    if not re.match(r'^[a-zA-Z0-9_-]+$', name):
        print(f"Error: invalid database name '{name}'. Use only alphanumeric, _ and -.", file=sys.stderr)
        sys.exit(1)
    return os.path.join(DB_DIR, f"{name}.db")


def cmd_list_dbs(args):
    os.makedirs(DB_DIR, exist_ok=True)
    dbs = [f[:-3] for f in os.listdir(DB_DIR) if f.endswith(".db")]
    if not dbs:
        print("No databases found.")
        return
    for db in sorted(dbs):
        path = os.path.join(DB_DIR, f"{db}.db")
        size = os.path.getsize(path)
        print(f" {db} ({size} bytes)")


def cmd_create_db(args):
    path = get_db_path(args.name)
    if os.path.exists(path) and not args.force:
        print(f"Database '{args.name}' already exists. Use --force to recreate.")
        return
    if os.path.exists(path) and args.force:
        os.remove(path)
    conn = sqlite3.connect(path)
    conn.execute("""
        CREATE TABLE IF NOT EXISTS _migrations (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            description TEXT NOT NULL,
            sql_up TEXT NOT NULL,
            applied_at TEXT NOT NULL DEFAULT (datetime('now'))
        )
    """)
    conn.commit()
    conn.close()
    print(f"Database '{args.name}' created at {path}")


def cmd_tables(args):
    path = get_db_path(args.db)
    if not os.path.exists(path):
        print(f"Database '{args.db}' not found.", file=sys.stderr)
        sys.exit(1)
    conn = sqlite3.connect(path)
    cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT IN ('_migrations', 'sqlite_sequence') ORDER BY name")
    tables = [row[0] for row in cursor]
    conn.close()
    if not tables:
        print("No tables (excluding internal).")
        return
    for t in tables:
        print(f" {t}")


def cmd_schema(args):
    path = get_db_path(args.db)
    if not os.path.exists(path):
        print(f"Database '{args.db}' not found.", file=sys.stderr)
        sys.exit(1)
    conn = sqlite3.connect(path)
    try:
        if args.table:
            cursor = conn.execute("SELECT sql FROM sqlite_master WHERE type='table' AND name=?", (args.table,))
            row = cursor.fetchone()
            if row:
                print(row[0])
            else:
                print(f"Table '{args.table}' not found.", file=sys.stderr)
                sys.exit(1)
        else:
            cursor = conn.execute("SELECT sql FROM sqlite_master WHERE type='table' AND name NOT IN ('_migrations', 'sqlite_sequence') ORDER BY name")
            for row in cursor:
                if row[0]:
                    print(row[0] + ";\n")
    finally:
        conn.close()


# SQL patterns that require --allow-destructive flag.
# ALTER TABLE uses .*?\bDROP\b (lazy dot-all) so it matches any identifier style
# ("my table", `my table`, [my table], plain name) without depending on \S+.
_DESTRUCTIVE_PATTERNS = re.compile(
    r'\b(DROP\s+TABLE|DROP\s+INDEX|DROP\s+VIEW|DROP\s+TRIGGER|'
    r'DELETE\s+FROM|TRUNCATE)\b'
    r'|ALTER\s+TABLE\b.*?\bDROP\b',
    re.IGNORECASE | re.DOTALL
)


def _strip_sql_for_pattern_check(sql):
    """Strip string literals and comments, replacing with placeholder tokens.

    This allows _DESTRUCTIVE_PATTERNS to:
    - Avoid false positives: SELECT 'TRUNCATE' -> SELECT _STR_ (no match)
    - Catch comment obfuscation: DROP/**/TABLE -> DROP TABLE (matches \\s+)
    - Handle spaced identifiers: ALTER TABLE "my table" DROP -> ... _ID_ DROP
    """
    result = []
    i = 0
    length = len(sql)
    while i < length:
        ch = sql[i]
        if ch == "'":
            result.append('_STR_')
            i += 1
            while i < length:
                if sql[i] == "'":
                    i += 1
                    if i < length and sql[i] == "'":  # escaped ''
                        i += 1
                        continue
                    break
                i += 1
        elif ch == '"':
            result.append('_ID_')
            i += 1
            while i < length:
                if sql[i] == '"':
                    i += 1
                    if i < length and sql[i] == '"':  # escaped ""
                        i += 1
                        continue
                    break
                i += 1
        elif ch == '[':
            result.append('_ID_')
            i += 1
            while i < length and sql[i] != ']':
                i += 1
            if i < length:
                i += 1
        elif ch == '`':
            result.append('_ID_')
            i += 1
            while i < length and sql[i] != '`':
                i += 1
            if i < length:
                i += 1
        elif ch == '-' and i + 1 < length and sql[i + 1] == '-':
            result.append(' ')
            while i < length and sql[i] != '\n':
                i += 1
        elif ch == '/' and i + 1 < length and sql[i + 1] == '*':
            result.append(' ')
            i += 2
            while i < length:
                if sql[i] == '*' and i + 1 < length and sql[i + 1] == '/':
                    i += 2
                    break
                i += 1
        else:
            result.append(ch)
            i += 1
    return ''.join(result)


def _split_sql(sql):
    """Split SQL on semicolons, respecting string literals, identifiers, comments,
    and BEGIN...END blocks (preserves semicolons inside CREATE TRIGGER bodies)."""
    statements = []
    current = []
    in_single = False
    in_double = False
    in_bracket = False
    in_backtick = False
    begin_depth = 0  # tracks nesting for CREATE TRIGGER ... BEGIN ... END
    case_depth = 0  # tracks CASE...END nesting inside trigger bodies
    i = 0
    length = len(sql)

    def _word_char(c):
        return c.isalnum() or c == '_'

    while i < length:
        ch = sql[i]

        # Inside single-quoted string
        if in_single:
            current.append(ch)
            if ch == "'":
                if i + 1 < length and sql[i + 1] == "'":
                    current.append(sql[i + 1])  # escaped ''
                    i += 2
                    continue
                in_single = False
            i += 1
            continue

        # Inside double-quoted identifier
        if in_double:
            current.append(ch)
            if ch == '"':
                if i + 1 < length and sql[i + 1] == '"':
                    current.append(sql[i + 1])  # escaped ""
                    i += 2
                    continue
                in_double = False
            i += 1
            continue

        # Inside bracket-quoted identifier [...]
        if in_bracket:
            current.append(ch)
            if ch == ']':
                in_bracket = False
            i += 1
            continue

        # Inside backtick-quoted identifier
        if in_backtick:
            current.append(ch)
            if ch == '`':
                in_backtick = False
            i += 1
            continue

        # Line comment: -- until end of line
        if ch == '-' and i + 1 < length and sql[i + 1] == '-':
            while i < length and sql[i] != '\n':
                current.append(sql[i])
                i += 1
            continue

        # Block comment: /* ... */
        if ch == '/' and i + 1 < length and sql[i + 1] == '*':
            current.append(ch)
            i += 1
            current.append(sql[i])
            i += 1
            while i < length:
                if sql[i] == '*' and i + 1 < length and sql[i + 1] == '/':
                    current.append(sql[i])
                    i += 1
                    current.append(sql[i])
                    i += 1
                    break
                current.append(sql[i])
                i += 1
            continue

        # Track BEGIN/END depth to preserve semicolons inside trigger bodies.
        # Only increment when BEGIN follows a partial CREATE [TEMP] TRIGGER
        # declaration — NOT for transaction control (BEGIN TRANSACTION) or
        # when `begin` is used as a plain identifier/table name.
        if ch in ('B', 'b') and sql[i:i + 5].upper() == 'BEGIN':
            prev_ok = i == 0 or not _word_char(sql[i - 1])
            next_ok = i + 5 >= length or not _word_char(sql[i + 5])
            if prev_ok and next_ok:
                current_text = ''.join(current).strip()
                # Strip comments so a comment-only prefix like '/* header */'
                # does not falsely count as trigger context.
                text_no_comments = re.sub(r'/\*.*?\*/', ' ', current_text, flags=re.DOTALL)
                text_no_comments = re.sub(r'--[^\n]*', ' ', text_no_comments).strip()
                if text_no_comments and re.search(
                    r'\bCREATE\s+(?:TEMP\s+|TEMPORARY\s+)?TRIGGER\b',
                    text_no_comments, re.IGNORECASE
                ):
                    begin_depth += 1
        elif ch in ('C', 'c') and begin_depth > 0 and sql[i:i + 4].upper() == 'CASE':
            prev_ok = i == 0 or not _word_char(sql[i - 1])
            next_ok = i + 4 >= length or not _word_char(sql[i + 4])
            if prev_ok and next_ok:
                case_depth += 1
        elif ch in ('E', 'e') and sql[i:i + 3].upper() == 'END' and (begin_depth > 0 or case_depth > 0):
            prev_ok = i == 0 or not _word_char(sql[i - 1])
            next_ok = i + 3 >= length or not _word_char(sql[i + 3])
            if prev_ok and next_ok:
                if case_depth > 0:
                    case_depth -= 1  # CASE...END consumed
                elif begin_depth > 0:
                    begin_depth -= 1  # trigger END consumed
                    case_depth = 0  # reset CASE depth on trigger boundary

        # Start of quoted context
        if ch == "'":
            in_single = True
        elif ch == '"':
            in_double = True
        elif ch == '[':
            in_bracket = True
        elif ch == '`':
            in_backtick = True
        elif ch == ';':
            if begin_depth == 0:
                stmt = ''.join(current).strip()
                if stmt:
                    statements.append(stmt)
                current = []
                i += 1
                continue
            # inside BEGIN...END block: keep semicolon as part of statement

        current.append(ch)
        i += 1

    stmt = ''.join(current).strip()
    if stmt:
        statements.append(stmt)
    return statements


def cmd_execute(args):
    path = get_db_path(args.db)
    if not os.path.exists(path):
        print(f"Database '{args.db}' not found.", file=sys.stderr)
        sys.exit(1)

    sql = args.sql

    # Guard against destructive operations without explicit flag.
    # Run against stripped SQL to catch comment obfuscation (DROP/**/TABLE)
    # and avoid false positives from keywords inside string literals.
    if _DESTRUCTIVE_PATTERNS.search(_strip_sql_for_pattern_check(sql)) and not args.allow_destructive:
        print(
            "Error: destructive SQL detected (DROP/DELETE/TRUNCATE).\n"
            "Use --allow-destructive to confirm, or use migrations instead.",
            file=sys.stderr
        )
        sys.exit(1)

    conn = sqlite3.connect(path)
    conn.execute("PRAGMA foreign_keys = ON")

    try:
        # Support multi-statement SQL (respects quoted semicolons)
        statements = _split_sql(sql)
        last_cursor = None
        total_affected = 0

        for stmt in statements:
            cursor = conn.execute(stmt)
            last_cursor = cursor
            if cursor.rowcount > 0:
                total_affected += cursor.rowcount

        # Fetch result rows BEFORE committing: SQLite keeps INSERT...RETURNING
        # statements open until all rows are consumed, so calling conn.commit()
        # before fetchall() raises OperationalError "cannot commit transaction -
        # SQL statements in progress" and rolls back the write.
        result_cols = None
        result_rows = None
        if last_cursor and last_cursor.description:
            result_cols = [d[0] for d in last_cursor.description]
            result_rows = last_cursor.fetchall()

        # Commit after rows are fetched so writes from DML (including RETURNING)
        # are persisted regardless of what the last statement type was.
        conn.commit()

        # Detect result-returning queries via cursor.description (handles
        # SELECT, PRAGMA, WITH ... SELECT, INSERT ... RETURNING, etc.)
        if result_rows is not None:
            cols = result_cols
            rows = result_rows
            if args.json_output:
                result = [dict(zip(cols, row)) for row in rows]
                print(json.dumps(result, indent=2, default=str))
            else:
                if cols:
                    print(" | ".join(cols))
                    print("-" * (sum(len(c) for c in cols) + 3 * (len(cols) - 1)))
                for row in rows:
                    print(" | ".join(str(v) for v in row))
                print(f"\n({len(rows)} rows)")
        else:
            print(f"OK. Rows affected: {total_affected}")
    except sqlite3.Error as e:
        conn.rollback()
        print(f"SQL Error: {e}", file=sys.stderr)
        sys.exit(1)
    finally:
        conn.close()


def cmd_migrate(args):
    """Apply a migration: records the SQL and description, then executes."""
    path = get_db_path(args.db)
    if not os.path.exists(path):
        print(f"Database '{args.db}' not found. Create it first.", file=sys.stderr)
        sys.exit(1)

    if _DESTRUCTIVE_PATTERNS.search(_strip_sql_for_pattern_check(args.sql)):
        print(
            "Error: destructive SQL detected (DROP/DELETE/TRUNCATE) in migration.\n"
            "Migrations should only add or modify structure. "
            "Use exec --allow-destructive for intentional destructive ops.",
            file=sys.stderr
        )
        sys.exit(1)

    conn = sqlite3.connect(path)
    conn.execute("PRAGMA foreign_keys = ON")

    # Ensure migrations table exists
    conn.execute("""
        CREATE TABLE IF NOT EXISTS _migrations (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            description TEXT NOT NULL,
            sql_up TEXT NOT NULL,
            applied_at TEXT NOT NULL DEFAULT (datetime('now'))
        )
    """)

    try:
        # Execute the migration SQL
        for statement in _split_sql(args.sql):
            conn.execute(statement)

        # Record the migration
        conn.execute(
            "INSERT INTO _migrations (description, sql_up) VALUES (?, ?)",
            (args.description, args.sql)
        )
        conn.commit()
        print(f"Migration applied: {args.description}")
    except sqlite3.Error as e:
        conn.rollback()
        print(f"Migration failed: {e}", file=sys.stderr)
        sys.exit(1)
    finally:
        conn.close()


def cmd_migrations(args):
    """List applied migrations."""
    path = get_db_path(args.db)
    if not os.path.exists(path):
        print(f"Database '{args.db}' not found.", file=sys.stderr)
        sys.exit(1)
    conn = sqlite3.connect(path)
    try:
        cursor = conn.execute("SELECT id, description, applied_at FROM _migrations ORDER BY id")
        rows = cursor.fetchall()
        if not rows:
            print("No migrations applied yet.")
            return
        for row in rows:
            print(f" #{row[0]} {row[2]} {row[1]}")
    except sqlite3.OperationalError:
        print("No migrations table found.")
    finally:
        conn.close()


def cmd_backup(args):
    path = get_db_path(args.db)
    if not os.path.exists(path):
        print(f"Database '{args.db}' not found.", file=sys.stderr)
        sys.exit(1)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = os.path.join(DB_DIR, f"{args.db}_backup_{timestamp}.db")
    src = sqlite3.connect(path)
    dst = sqlite3.connect(backup_path)
    src.backup(dst)
    src.close()
    dst.close()
    print(f"Backup saved: {backup_path}")


def main():
    parser = argparse.ArgumentParser(
        prog="localdb",
        description="Manage local SQLite databases for OpenClaw"
    )
    sub = parser.add_subparsers(dest="command")

    # list
    sub.add_parser("list", help="List all databases")

    # create
    p = sub.add_parser("create", help="Create a new database")
    p.add_argument("name", help="Database name")
    p.add_argument("--force", action="store_true", help="Overwrite if exists")

    # tables
    p = sub.add_parser("tables", help="List tables in a database")
    p.add_argument("db", help="Database name")

    # schema
    p = sub.add_parser("schema", help="Show table schema(s)")
    p.add_argument("db", help="Database name")
    p.add_argument("--table", "-t", help="Specific table name")

    # execute
    p = sub.add_parser("exec", help="Execute SQL query")
    p.add_argument("db", help="Database name")
    p.add_argument("sql", help="SQL statement(s) to execute (semicolon-separated)")
    p.add_argument("--json", dest="json_output", action="store_true", help="Output as JSON")
    p.add_argument("--allow-destructive", action="store_true",
                   help="Allow DROP/DELETE/TRUNCATE statements (requires explicit opt-in)")

    # migrate
    p = sub.add_parser("migrate", help="Apply a named migration")
    p.add_argument("db", help="Database name")
    p.add_argument("--description", "-d", required=True, help="Migration description")
    p.add_argument("--sql", "-s", required=True, help="SQL to execute")

    # migrations
    p = sub.add_parser("migrations", help="List applied migrations")
    p.add_argument("db", help="Database name")

    # backup
    p = sub.add_parser("backup", help="Create a backup of a database")
    p.add_argument("db", help="Database name")

    args = parser.parse_args()
    if not args.command:
        parser.print_help()
        sys.exit(0)

    cmds = {
        "list": cmd_list_dbs,
        "create": cmd_create_db,
        "tables": cmd_tables,
        "schema": cmd_schema,
        "exec": cmd_execute,
        "migrate": cmd_migrate,
        "migrations": cmd_migrations,
        "backup": cmd_backup,
    }
    cmds[args.command](args)


if __name__ == "__main__":
    main()
```
skills/local-db/scripts/test_localdb.py (new file, 385 lines)
@@ -0,0 +1,385 @@
```python
#!/usr/bin/env python3
"""
Tests for localdb helpers.
"""

import os
import sqlite3
import tempfile
from unittest import TestCase, main
from unittest.mock import patch

# Patch DB_DIR before importing
_tmpdir = tempfile.mkdtemp()

with patch.dict(os.environ, {}):
    import localdb

localdb.DB_DIR = _tmpdir


class TestGetDbPath(TestCase):
    def test_valid_names(self):
        for name in ["mydb", "test-db", "db_123", "A"]:
            path = localdb.get_db_path(name)
            self.assertTrue(path.endswith(f"{name}.db"))

    def test_invalid_names_exit(self):
        for name in ["../etc", "my db", "db;drop", "a/b", ""]:
            with self.assertRaises(SystemExit):
                localdb.get_db_path(name)


class TestDestructiveGuard(TestCase):
    def test_detects_drop_table(self):
        self.assertIsNotNone(localdb._DESTRUCTIVE_PATTERNS.search("DROP TABLE users"))

    def test_detects_delete_from(self):
        self.assertIsNotNone(localdb._DESTRUCTIVE_PATTERNS.search("DELETE FROM users WHERE id=1"))

    def test_detects_truncate(self):
        self.assertIsNotNone(localdb._DESTRUCTIVE_PATTERNS.search("TRUNCATE users"))

    def test_allows_select(self):
        self.assertIsNone(localdb._DESTRUCTIVE_PATTERNS.search("SELECT * FROM users"))

    def test_allows_insert(self):
        self.assertIsNone(localdb._DESTRUCTIVE_PATTERNS.search("INSERT INTO users (name) VALUES ('a')"))

    def test_allows_create_table(self):
        self.assertIsNone(localdb._DESTRUCTIVE_PATTERNS.search("CREATE TABLE users (id INTEGER)"))

    def test_detects_quoted_alter_table_drop(self):
        self.assertIsNotNone(localdb._DESTRUCTIVE_PATTERNS.search('ALTER TABLE "u-ser" DROP COLUMN c'))

    def test_detects_raw_double_quoted_spaced_alter_drop(self):
        # .*?\bDROP\b: raw SQL (no strip) matches regardless of identifier style
        self.assertIsNotNone(localdb._DESTRUCTIVE_PATTERNS.search('ALTER TABLE "my table" DROP COLUMN c'))

    def test_detects_raw_backtick_spaced_alter_drop(self):
        self.assertIsNotNone(localdb._DESTRUCTIVE_PATTERNS.search('ALTER TABLE `my table` DROP COLUMN c'))

    def test_no_false_positive_alter_table_add(self):
        self.assertIsNone(localdb._DESTRUCTIVE_PATTERNS.search('ALTER TABLE users ADD COLUMN created_at TEXT'))

    def test_no_false_positive_alter_table_column_named_dropped(self):
        self.assertIsNone(localdb._DESTRUCTIVE_PATTERNS.search('ALTER TABLE t ADD COLUMN dropped_at TEXT'))


class TestStripSqlForPatternCheck(TestCase):
    """Ensure destructive-pattern check catches obfuscation and avoids false positives."""

    def test_rejects_comment_obfuscated_drop(self):
        stripped = localdb._strip_sql_for_pattern_check("DROP/**/TABLE users")
        self.assertIsNotNone(localdb._DESTRUCTIVE_PATTERNS.search(stripped))

    def test_rejects_line_comment_obfuscated_delete(self):
        stripped = localdb._strip_sql_for_pattern_check("DELETE-- comment\nFROM t")
        self.assertIsNotNone(localdb._DESTRUCTIVE_PATTERNS.search(stripped))

    def test_no_false_positive_keyword_in_string(self):
        stripped = localdb._strip_sql_for_pattern_check("INSERT INTO t VALUES ('TRUNCATE')")
        self.assertIsNone(localdb._DESTRUCTIVE_PATTERNS.search(stripped))

    def test_no_false_positive_drop_in_string(self):
        stripped = localdb._strip_sql_for_pattern_check("INSERT INTO t VALUES ('DROP TABLE x')")
        self.assertIsNone(localdb._DESTRUCTIVE_PATTERNS.search(stripped))

    def test_detects_spaced_identifier_alter_drop(self):
        stripped = localdb._strip_sql_for_pattern_check('ALTER TABLE "my table" DROP COLUMN c')
        self.assertIsNotNone(localdb._DESTRUCTIVE_PATTERNS.search(stripped))

    def test_detects_backtick_spaced_identifier_alter_drop(self):
        stripped = localdb._strip_sql_for_pattern_check('ALTER TABLE `my table` DROP COLUMN c')
        self.assertIsNotNone(localdb._DESTRUCTIVE_PATTERNS.search(stripped))


class TestSplitSql(TestCase):
    def test_simple_split(self):
        stmts = localdb._split_sql("SELECT 1; SELECT 2")
        self.assertEqual(stmts, ["SELECT 1", "SELECT 2"])

    def test_preserves_semicolon_in_string(self):
        sql = "INSERT INTO t VALUES('a;b'); SELECT * FROM t"
        stmts = localdb._split_sql(sql)
        self.assertEqual(len(stmts), 2)
        self.assertIn("a;b", stmts[0])

    def test_preserves_semicolon_in_double_quotes(self):
        sql = 'SELECT "col;name" FROM t; SELECT 1'
        stmts = localdb._split_sql(sql)
        self.assertEqual(len(stmts), 2)
        self.assertIn('"col;name"', stmts[0])

    def test_single_statement_no_semicolon(self):
        stmts = localdb._split_sql("SELECT 1")
        self.assertEqual(stmts, ["SELECT 1"])

    def test_escaped_single_quotes(self):
        sql = "INSERT INTO t VALUES('it''s ok'); SELECT 1"
        stmts = localdb._split_sql(sql)
        self.assertEqual(len(stmts), 2)
        self.assertIn("it''s ok", stmts[0])

    def test_escaped_double_quotes(self):
        sql = 'SELECT "col""name" FROM t; SELECT 1'
        stmts = localdb._split_sql(sql)
        self.assertEqual(len(stmts), 2)
        self.assertIn('"col""name"', stmts[0])

    def test_line_comment_with_semicolon(self):
        sql = "SELECT 1 -- comment; not a split\n; SELECT 2"
        stmts = localdb._split_sql(sql)
        self.assertEqual(len(stmts), 2)
        self.assertIn("-- comment; not a split", stmts[0])

    def test_block_comment_with_semicolon(self):
        sql = "SELECT /* ignore; this */ 1; SELECT 2"
        stmts = localdb._split_sql(sql)
        self.assertEqual(len(stmts), 2)
        self.assertIn("/* ignore; this */", stmts[0])

    def test_bracket_quoted_identifier(self):
        sql = "SELECT [col;name] FROM t; SELECT 1"
        stmts = localdb._split_sql(sql)
        self.assertEqual(len(stmts), 2)
        self.assertIn("[col;name]", stmts[0])

    def test_trigger_begin_end_preserves_inner_semicolons(self):
        sql = (
            "CREATE TRIGGER trg AFTER INSERT ON t BEGIN "
            "INSERT INTO log VALUES (1); "
            "INSERT INTO log VALUES (2); "
            "END"
        )
        stmts = localdb._split_sql(sql)
        self.assertEqual(len(stmts), 1)
        self.assertIn("BEGIN", stmts[0])
        self.assertIn("END", stmts[0])

    def test_trigger_followed_by_another_statement(self):
        sql = (
            "CREATE TRIGGER trg AFTER INSERT ON t BEGIN "
            "INSERT INTO log VALUES (1); END; "
            "SELECT 1"
        )
        stmts = localdb._split_sql(sql)
        self.assertEqual(len(stmts), 2)

    def test_begin_transaction_splits_normally(self):
        sql = "BEGIN TRANSACTION; INSERT INTO t VALUES (1); COMMIT"
        stmts = localdb._split_sql(sql)
        self.assertEqual(len(stmts), 3)
        self.assertEqual(stmts[0], "BEGIN TRANSACTION")
        self.assertEqual(stmts[2], "COMMIT")

    def test_begin_immediate_splits_normally(self):
        sql = "BEGIN IMMEDIATE; INSERT INTO t VALUES (1); COMMIT"
        stmts = localdb._split_sql(sql)
        self.assertEqual(len(stmts), 3)

    def test_backtick_quoted_identifier(self):
        sql = "SELECT `col;name` FROM t; SELECT 1"
        stmts = localdb._split_sql(sql)
        self.assertEqual(len(stmts), 2)
        self.assertIn("`col;name`", stmts[0])

    def test_comment_prefix_before_begin_transaction_splits_normally(self):
        # /* comment */ before BEGIN TRANSACTION must NOT be treated as a trigger body
        sql = "/* header */ BEGIN TRANSACTION; INSERT INTO t VALUES (1); COMMIT"
        stmts = localdb._split_sql(sql)
        self.assertEqual(len(stmts), 3, f"Expected 3 stmts, got {stmts}")

    def test_begin_as_identifier_splits_normally(self):
        # 'begin' used as a table name must NOT be treated as a trigger body
        sql = "CREATE TABLE begin (id INTEGER); INSERT INTO begin VALUES (1)"
        stmts = localdb._split_sql(sql)
        self.assertEqual(len(stmts), 2, f"Expected 2 stmts, got {stmts}")

    def test_trigger_with_case_end_preserved(self):
        # CASE...END inside a trigger body must NOT prematurely close begin_depth
        sql = (
            "CREATE TRIGGER trg AFTER INSERT ON t BEGIN "
            "INSERT INTO x VALUES (CASE WHEN 1 THEN 1 ELSE 0 END); "
            "END"
        )
        stmts = localdb._split_sql(sql)
        self.assertEqual(len(stmts), 1, f"Expected 1 stmt (full trigger), got {stmts}")


class TestCreateAndMigrate(TestCase):
    def setUp(self):
        self.db_name = f"test_{os.getpid()}"
        self.db_path = os.path.join(_tmpdir, f"{self.db_name}.db")

    def tearDown(self):
        if os.path.exists(self.db_path):
            os.remove(self.db_path)

    def test_create_db(self):
        localdb.cmd_create_db(type("Args", (), {"name": self.db_name, "force": False})())
        self.assertTrue(os.path.exists(self.db_path))
        conn = sqlite3.connect(self.db_path)
        cursor = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='_migrations'"
        )
        self.assertIsNotNone(cursor.fetchone())
        conn.close()

    def test_create_force_recreates(self):
        # Create first
        localdb.cmd_create_db(type("Args", (), {"name": self.db_name, "force": False})())
        conn = sqlite3.connect(self.db_path)
        conn.execute("CREATE TABLE dummy (id INTEGER)")
        conn.commit()
        conn.close()

        # Force recreate
        localdb.cmd_create_db(type("Args", (), {"name": self.db_name, "force": True})())
        conn = sqlite3.connect(self.db_path)
        cursor = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='dummy'"
        )
        self.assertIsNone(cursor.fetchone())
        conn.close()


class TestExecCommitBeforeSelect(TestCase):
    """Ensure INSERT + SELECT in same command persists data (bug fix)."""

    def setUp(self):
        self.db_name = f"test_exec_{os.getpid()}"
        self.db_path = os.path.join(_tmpdir, f"{self.db_name}.db")
        localdb.cmd_create_db(type("Args", (), {"name": self.db_name, "force": False})())
        # Create a table first
        conn = sqlite3.connect(self.db_path)
        conn.execute("CREATE TABLE items (name TEXT)")
        conn.commit()
        conn.close()

    def tearDown(self):
        if os.path.exists(self.db_path):
            os.remove(self.db_path)

    def test_insert_then_select_persists(self):
        args = type("Args", (), {
            "db": self.db_name,
            "sql": "INSERT INTO items VALUES ('apple'); SELECT * FROM items",
            "allow_destructive": False,
            "json_output": False,
        })()
        localdb.cmd_execute(args)
        # Verify data was actually committed
        conn = sqlite3.connect(self.db_path)
        rows = conn.execute("SELECT * FROM items").fetchall()
        conn.close()
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0][0], "apple")

    def test_insert_returning_persists_and_returns_rows(self):
        """INSERT...RETURNING must commit the insert AND return the new row."""
        args = type("Args", (), {
            "db": self.db_name,
            "sql": "INSERT INTO items VALUES ('pear') RETURNING name",
            "allow_destructive": False,
            "json_output": False,
        })()
        localdb.cmd_execute(args)
        # Row must be persisted after the RETURNING fetch
        conn = sqlite3.connect(self.db_path)
        rows = conn.execute("SELECT * FROM items").fetchall()
        conn.close()
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0][0], "pear")

    def test_ddl_does_not_report_negative_rows(self):
        """DDL statements should not cause negative row counts."""
        args = type("Args", (), {
            "db": self.db_name,
            "sql": "CREATE TABLE extra (id INTEGER); INSERT INTO items VALUES ('x')",
            "allow_destructive": False,
            "json_output": False,
        })()
        localdb.cmd_execute(args)
        conn = sqlite3.connect(self.db_path)
        rows = conn.execute("SELECT * FROM items").fetchall()
        conn.close()
        self.assertEqual(len(rows), 1)


class TestMigrateDestructiveGuard(TestCase):
    """Ensure cmd_migrate blocks destructive SQL."""

    def setUp(self):
        self.db_name = f"test_mig_{os.getpid()}"
        self.db_path = os.path.join(_tmpdir, f"{self.db_name}.db")
        localdb.cmd_create_db(type("Args", (), {"name": self.db_name, "force": False})())

    def tearDown(self):
        if os.path.exists(self.db_path):
            os.remove(self.db_path)

    def test_migrate_blocks_drop_table(self):
        args = type("Args", (), {
            "db": self.db_name,
            "sql": "DROP TABLE _migrations",
            "description": "bad migration",
        })()
        with self.assertRaises(SystemExit):
            localdb.cmd_migrate(args)

    def test_migrate_blocks_delete_from(self):
        args = type("Args", (), {
            "db": self.db_name,
            "sql": "DELETE FROM _migrations",
            "description": "bad migration",
        })()
        with self.assertRaises(SystemExit):
            localdb.cmd_migrate(args)

    def test_migrate_allows_create_table(self):
        args = type("Args", (), {
            "db": self.db_name,
            "sql": "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)",
            "description": "add users table",
        })()
        localdb.cmd_migrate(args)
        conn = sqlite3.connect(self.db_path)
        cursor = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='users'"
        )
        self.assertIsNotNone(cursor.fetchone())
        conn.close()


class TestCmdSchema(TestCase):
    def setUp(self):
        self.db_name = "schematest"
        self.db_path = localdb.get_db_path(self.db_name)
        conn = sqlite3.connect(self.db_path)
        conn.execute("CREATE TABLE widgets (id INTEGER PRIMARY KEY, name TEXT)")
        conn.commit()
        conn.close()

    def tearDown(self):
        if os.path.exists(self.db_path):
            os.remove(self.db_path)

    def test_schema_missing_table_exits_nonzero(self):
        """schema --table <missing> must exit with a non-zero status code."""
        args = type("Args", (), {"db": self.db_name, "table": "no_such_table"})()
        with self.assertRaises(SystemExit) as cm:
            localdb.cmd_schema(args)
        self.assertNotEqual(cm.exception.code, 0)

    def test_schema_existing_table_succeeds(self):
        """schema --table <existing> must print DDL without raising SystemExit."""
        args = type("Args", (), {"db": self.db_name, "table": "widgets"})()
        try:
            localdb.cmd_schema(args)
        except SystemExit:
            self.fail("cmd_schema raised SystemExit for an existing table")


if __name__ == "__main__":
    main()
```
skills/rag-store/SKILL.md (new file, 140 lines)
@@ -0,0 +1,140 @@
---
name: rag-store
description: "Local RAG (Retrieval-Augmented Generation) with ChromaDB vector store. Ingest books, documents, PDFs, and text files, then answer questions using semantic search. Use when user wants to add knowledge or ask questions about their documents. NOT for: structured data with tables (use local-db), real-time web search, or editing documents."
metadata:
  {
    "openclaw":
      {
        "emoji": "📚",
        "requires": { "bins": ["python3"] },
        "install":
          [
            {
              "id": "pip-chromadb",
              "kind": "pip",
              "package": "chromadb",
              "label": "Install ChromaDB (pip)",
            },
          ],
      },
  }
---

# RAG Store (ChromaDB)

Manage a local RAG vector store via the bundled `ragstore.py` script. Uses ChromaDB with built-in embeddings (no API key needed). Data is stored in `~/.openclaw/ragstore/`.

## When to use

✅ **USE this skill when:**

- User says "add this book/document to my knowledge base"
- User asks "what does [book/document] say about X?"
- User wants to search across documents semantically
- User says "I want to ask questions about this PDF/file"
- User mentions RAG, vector search, or knowledge base

## When NOT to use

❌ **DON'T use this skill when:**

- User wants structured data with tables/relationships → use local-db
- User wants to store key-value pairs or config → use files/JSON
- User needs real-time web search → use browser or search tools
- User wants to edit or modify the original document → use file tools
- Document is very small (< 100 words) → just read the file directly

## Commands

### Add a document to a collection

```bash
# Add a text/markdown file
python3 {baseDir}/scripts/ragstore.py add ~/Documents/my-book.txt -c books

# Add a PDF (requires poppler-utils: apt install poppler-utils)
python3 {baseDir}/scripts/ragstore.py add ~/Documents/paper.pdf -c research

# Custom chunk settings
python3 {baseDir}/scripts/ragstore.py add ~/Documents/large-book.txt -c books --chunk-size 300 --overlap 30
```

### Query the knowledge base

```bash
# Ask a question (returns top 5 relevant chunks)
python3 {baseDir}/scripts/ragstore.py query "What are the main principles of stoicism?" -c books

# More results
python3 {baseDir}/scripts/ragstore.py query "How to treat lower back pain?" -c medical -k 10

# JSON output (for programmatic use)
python3 {baseDir}/scripts/ragstore.py query "What is the treatment protocol?" -c medical --json
```

### List collections

```bash
python3 {baseDir}/scripts/ragstore.py collections
```

### List ingested documents in a collection

```bash
python3 {baseDir}/scripts/ragstore.py sources -c books
```

### Remove a document from a collection

```bash
python3 {baseDir}/scripts/ragstore.py remove-source "my-book.txt" -c books
```

### Delete a collection entirely

```bash
python3 {baseDir}/scripts/ragstore.py delete-collection old-collection
```

## Supported file formats

| Format     | Extension                   | Requirement                                 |
| ---------- | --------------------------- | ------------------------------------------- |
| Plain text | .txt, .md, .rst, .csv, .log | Built-in                                    |
| PDF        | .pdf                        | `poppler-utils` (apt install poppler-utils) |
| EPUB       | .epub                       | `calibre` (apt install calibre)             |
| JSON       | .json                       | Built-in                                    |

## How it works

1. **Ingest**: The document is split into overlapping chunks (default: 500 words per chunk, 50 words of overlap; see the arithmetic below)
2. **Embed**: ChromaDB generates embeddings using its default model (all-MiniLM-L6-v2, runs locally)
3. **Store**: Chunks + metadata are stored in a persistent ChromaDB at `~/.openclaw/ragstore/`
4. **Query**: The question is embedded and compared against stored chunks using cosine similarity
5. **Return**: The top-K most relevant chunks are returned with source info and a distance score
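Concretely, `chunk_text` advances by `chunk_size - overlap` words per chunk, so with the defaults the stride is 500 - 50 = 450 words. A 10,000-word document therefore yields 1 + ceil((10000 - 500) / 450) = 23 chunks, each sharing its first 50 words with the tail of the previous chunk.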

## Workflow for answering user questions

When the user asks about document content:

1. Use `python3 {baseDir}/scripts/ragstore.py query "<question>" -c <collection> --json` to get relevant chunks
2. Read the returned chunks
3. Synthesize an answer using the chunk content as context
4. Cite the source document and chunk index (the JSON shape is sketched below)
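For reference, the `--json` output is a ranked array whose keys (`rank`, `source`, `chunk_index`, `distance`, `text`) come straight from `cmd_query` in `ragstore.py`; the values below are illustrative only:

```json
[
  {
    "rank": 1,
    "source": "my-book.txt",
    "chunk_index": 12,
    "distance": 0.3142,
    "text": "...chunk text..."
  }
]
```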

## Collections pattern

Organize knowledge by topic:

- `books` — General reading
- `medical` — Medical literature
- `work` — Work-related documents
- `personal` — Personal notes/documents
- `research` — Academic papers

## Tips

- Smaller chunk sizes (200-300 words) work better for precise Q&A
- Larger chunk sizes (500-800 words) work better for summaries
- The first query in a new session may be slow (model loading)
- ChromaDB uses ~200MB for the default embedding model (downloaded on first use)
skills/rag-store/scripts/ragstore.py (new file, 365 lines)
@@ -0,0 +1,365 @@
```python
#!/usr/bin/env python3
"""
ragstore — CLI for managing a local RAG (Retrieval-Augmented Generation) store.
OpenClaw skill: lets the agent ingest documents (books, PDFs, text files)
into a ChromaDB vector store and query them semantically.
"""
import argparse
import hashlib
import json
import os
import re
import sys

STORE_DIR = os.path.expanduser("~/.openclaw/ragstore")


def get_client(collection_name="default"):
    try:
        import chromadb
    except ImportError:
        print("Error: chromadb not installed. Run: pip3 install chromadb", file=sys.stderr)
        sys.exit(1)

    os.makedirs(STORE_DIR, exist_ok=True)
    client = chromadb.PersistentClient(path=STORE_DIR)
    return client


def chunk_text(text, chunk_size=500, overlap=50):
    """Split text into overlapping chunks by words."""
    if overlap < 0:
        print(f"Error: overlap ({overlap}) must be non-negative.", file=sys.stderr)
        sys.exit(1)
    if overlap >= chunk_size:
        print(f"Error: overlap ({overlap}) must be less than chunk_size ({chunk_size}).", file=sys.stderr)
        sys.exit(1)
    words = text.split()
    if not words:
        return []
    chunks = []
    start = 0
    while start < len(words):
        end = start + chunk_size
        chunk = " ".join(words[start:end])
        chunks.append(chunk)
        if end >= len(words):
            break
        start = end - overlap
    return chunks


def extract_text_from_file(filepath):
    """Extract text from various file formats."""
    ext = os.path.splitext(filepath)[1].lower()

    if ext == ".pdf":
        try:
            import subprocess  # nosec B404
            result = subprocess.run(  # nosec B603 B607
                ["pdftotext", filepath, "-"],
                capture_output=True, text=True, timeout=60
            )
            if result.returncode == 0:
                return result.stdout
        except (FileNotFoundError, subprocess.TimeoutExpired):
            pass
        print("Warning: pdftotext not available. Install poppler-utils.", file=sys.stderr)
        return None

    elif ext in (".txt", ".md", ".rst", ".csv", ".log", ".json"):
        with open(filepath, "r", encoding="utf-8", errors="replace") as f:
            return f.read()

    elif ext in (".epub",):
        try:
            import subprocess  # nosec B404
            import tempfile
            with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp:
                tmp_path = tmp.name
            result = subprocess.run(  # nosec B603 B607
                ["ebook-convert", filepath, tmp_path],
                capture_output=True, text=True, timeout=120
            )
            if result.returncode == 0:
                with open(tmp_path, "r", encoding="utf-8", errors="replace") as f:
                    text = f.read()
                os.unlink(tmp_path)
                return text
            os.unlink(tmp_path)
        except (FileNotFoundError, subprocess.TimeoutExpired):
            if 'tmp_path' in locals():
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass
        print("Warning: ebook-convert not available. Install calibre.", file=sys.stderr)
        return None

    else:
        # Try as plain text
        try:
            with open(filepath, "r", encoding="utf-8", errors="replace") as f:
                return f.read()
        except Exception:
            print(f"Cannot read file: {filepath}", file=sys.stderr)
            return None


def validate_collection_name(name):
    """Validate collection name to prevent path traversal or injection."""
    if not re.match(r'^[a-zA-Z0-9_-]+$', name):
        print(f"Error: invalid collection name '{name}'. Use only alphanumeric, _ and -.", file=sys.stderr)
        sys.exit(1)
    return name


def cmd_collections(args):
    client = get_client()
    cols = client.list_collections()
    if not cols:
        print("No collections. Create one with: ragstore add -c <name> <file>")
        return
    for c in cols:
        col = client.get_collection(c.name)
        count = col.count()
        print(f" {c.name} ({count} chunks)")


def cmd_add(args):
    validate_collection_name(args.collection)
    client = get_client()
    collection = client.get_or_create_collection(
        name=args.collection,
        metadata={"hnsw:space": "cosine"}
    )

    filepath = os.path.expanduser(args.file)
    if not os.path.exists(filepath):
        print(f"File not found: {filepath}", file=sys.stderr)
        sys.exit(1)

    text = extract_text_from_file(filepath)
    if not text or len(text.strip()) == 0:
        print(f"No text extracted from: {filepath}", file=sys.stderr)
        sys.exit(1)

    filename = os.path.basename(filepath)
    chunks = chunk_text(text, chunk_size=args.chunk_size, overlap=args.overlap)

    print(f"Ingesting '{filename}' into collection '{args.collection}'...")
    print(f" Text length: {len(text)} chars")
    print(f" Chunks: {len(chunks)}")

    ids = []
    documents = []
    metadatas = []

    resolved_path = os.path.realpath(filepath)
    for i, chunk in enumerate(chunks):
        doc_id = hashlib.md5(f"{resolved_path}:{i}".encode(), usedforsecurity=False).hexdigest()  # nosec B324
        ids.append(doc_id)
        documents.append(chunk)
        metadatas.append({
            "source": filename,
            "source_path": resolved_path,
            "chunk_index": i,
            "total_chunks": len(chunks),
        })

    # Add in batches of 100
    batch_size = 100
    for start in range(0, len(ids), batch_size):
        end = min(start + batch_size, len(ids))
        collection.upsert(
            ids=ids[start:end],
            documents=documents[start:end],
            metadatas=metadatas[start:end],
        )

    # Remove stale chunks from previous ingestion of the same file
    # (e.g. file was shortened and now has fewer chunks)
    stale_ids = []
    idx = len(chunks)
    while True:
        old_id = hashlib.md5(f"{resolved_path}:{idx}".encode(), usedforsecurity=False).hexdigest()  # nosec B324
        existing = collection.get(ids=[old_id])
        if not existing["ids"]:
            break
        stale_ids.append(old_id)
        idx += 1
    if stale_ids:
        collection.delete(ids=stale_ids)
        print(f" Removed {len(stale_ids)} stale chunks from previous ingestion.")

    print(f"Done. Total chunks in collection: {collection.count()}")


def cmd_query(args):
    validate_collection_name(args.collection)
    client = get_client()
    try:
        collection = client.get_collection(args.collection)
    except Exception:
        print(f"Collection '{args.collection}' not found.", file=sys.stderr)
        sys.exit(1)

    results = collection.query(
        query_texts=[args.question],
        n_results=args.top_k,
    )

    if not results["documents"] or not results["documents"][0]:
        print("No results found.")
        return

    if args.json_output:
        output = []
        for i, doc in enumerate(results["documents"][0]):
            meta = results["metadatas"][0][i] if results["metadatas"] else {}
            dist = results["distances"][0][i] if results["distances"] else None
            output.append({
                "rank": i + 1,
                "source": meta.get("source", "unknown"),
                "chunk_index": meta.get("chunk_index"),
                "distance": dist,
                "text": doc,
            })
        print(json.dumps(output, indent=2))
    else:
        for i, doc in enumerate(results["documents"][0]):
            meta = results["metadatas"][0][i] if results["metadatas"] else {}
            dist = results["distances"][0][i] if results["distances"] else None
            source = meta.get("source", "unknown")
            chunk_idx = meta.get("chunk_index", "?")
            score = f" (distance: {dist:.4f})" if dist is not None else ""
            print(f"\n--- Result {i+1} [{source} chunk {chunk_idx}]{score} ---")
            print(doc[:500])
            if len(doc) > 500:
                print(f" ... ({len(doc)} chars total)")


def cmd_sources(args):
    validate_collection_name(args.collection)
    client = get_client()
    try:
        collection = client.get_collection(args.collection)
    except Exception:
        print(f"Collection '{args.collection}' not found.", file=sys.stderr)
        sys.exit(1)

    all_data = collection.get(include=["metadatas"])
    sources = {}
    for meta in all_data["metadatas"]:
        key = meta.get("source_path") or meta.get("source", "unknown")
        total = meta.get("total_chunks", 1)
        sources[key] = total

    if not sources:
        print("No documents in this collection.")
        return

    for src, chunks in sorted(sources.items()):
        print(f" {src} ({chunks} chunks)")


def cmd_delete_collection(args):
    validate_collection_name(args.collection)
    client = get_client()
    try:
        client.delete_collection(args.collection)
        print(f"Collection '{args.collection}' deleted.")
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


def cmd_remove_source(args):
    validate_collection_name(args.collection)
    client = get_client()
    try:
        collection = client.get_collection(args.collection)
    except Exception:
        print(f"Collection '{args.collection}' not found.", file=sys.stderr)
        sys.exit(1)

    all_data = collection.get(include=["metadatas"])
    source_basename = os.path.basename(args.source)
    resolved_source = os.path.realpath(args.source)
    # Basename fallback only when user passed a bare filename (no directory part),
    # to avoid removing chunks from different files with the same basename.
    is_bare_name = os.sep not in args.source and '/' not in args.source
    ids_to_remove = []
    for i, meta in enumerate(all_data["metadatas"]):
        stored_path = meta.get("source_path", "")
        stored_name = meta.get("source", "")
        if (
            stored_path == resolved_source
            or stored_name == args.source
            or (is_bare_name and stored_name == source_basename)
        ):
            ids_to_remove.append(all_data["ids"][i])

    if not ids_to_remove:
        print(f"No chunks from source '{args.source}' found.")
        return

    collection.delete(ids=ids_to_remove)
    print(f"Removed {len(ids_to_remove)} chunks from source '{args.source}'.")


def main():
    parser = argparse.ArgumentParser(
        prog="ragstore",
        description="Local RAG vector store for OpenClaw (ChromaDB)"
    )
    sub = parser.add_subparsers(dest="command")

    # collections
    sub.add_parser("collections", help="List all collections")

    # add
    p = sub.add_parser("add", help="Ingest a document into a collection")
    p.add_argument("file", help="Path to document (txt, md, pdf, epub, etc.)")
    p.add_argument("-c", "--collection", default="default", help="Collection name (default: 'default')")
    p.add_argument("--chunk-size", type=int, default=500, help="Words per chunk (default: 500)")
    p.add_argument("--overlap", type=int, default=50, help="Overlap words between chunks (default: 50)")

    # query
    p = sub.add_parser("query", help="Semantic search in a collection")
    p.add_argument("question", help="Natural language question")
    p.add_argument("-c", "--collection", default="default", help="Collection name")
    p.add_argument("-k", "--top-k", type=int, default=5, help="Number of results (default: 5)")
    p.add_argument("--json", dest="json_output", action="store_true", help="Output as JSON")

    # sources
    p = sub.add_parser("sources", help="List ingested documents in a collection")
    p.add_argument("-c", "--collection", default="default", help="Collection name")

    # delete-collection
    p = sub.add_parser("delete-collection", help="Delete a collection")
    p.add_argument("collection", help="Collection name to delete")

    # remove-source
    p = sub.add_parser("remove-source", help="Remove a document source from a collection")
    p.add_argument("source", help="Source filename to remove")
    p.add_argument("-c", "--collection", default="default", help="Collection name")

    args = parser.parse_args()
    if not args.command:
        parser.print_help()
        sys.exit(0)

    cmds = {
        "collections": cmd_collections,
        "add": cmd_add,
        "query": cmd_query,
        "sources": cmd_sources,
        "delete-collection": cmd_delete_collection,
        "remove-source": cmd_remove_source,
    }
    cmds[args.command](args)


if __name__ == "__main__":
    main()
```
skills/rag-store/scripts/test_ragstore.py (new file, 83 lines)
@@ -0,0 +1,83 @@
```python
#!/usr/bin/env python3
"""
Tests for ragstore helpers.
"""

import hashlib
import os
from unittest import TestCase, main

from ragstore import chunk_text, validate_collection_name


class TestChunkText(TestCase):
    def test_basic_chunking(self):
        text = " ".join(f"word{i}" for i in range(20))
        chunks = chunk_text(text, chunk_size=10, overlap=2)
        self.assertEqual(len(chunks), 3)

    def test_empty_text(self):
        self.assertEqual(chunk_text("", chunk_size=10, overlap=2), [])

    def test_single_chunk(self):
        text = "hello world"
        chunks = chunk_text(text, chunk_size=100, overlap=10)
        self.assertEqual(len(chunks), 1)
        self.assertEqual(chunks[0], "hello world")

    def test_overlap_too_large_exits(self):
        with self.assertRaises(SystemExit):
            chunk_text("some text here", chunk_size=5, overlap=5)

    def test_overlap_larger_than_chunk_exits(self):
        with self.assertRaises(SystemExit):
            chunk_text("some text here", chunk_size=5, overlap=10)

    def test_negative_overlap_exits(self):
        with self.assertRaises(SystemExit):
            chunk_text("some text here", chunk_size=5, overlap=-1)

    def test_exact_chunk_size_no_trailing_duplicate(self):
        """Text with exactly chunk_size words must produce exactly 1 chunk."""
        text = " ".join(f"w{i}" for i in range(10))
        chunks = chunk_text(text, chunk_size=10, overlap=2)
        self.assertEqual(len(chunks), 1)

    def test_stride_boundary_no_trailing_duplicate(self):
        """Text whose length falls exactly on a stride boundary must not emit a duplicate tail chunk."""
        # chunk_size=5, overlap=1 -> stride=4; 8 words -> 2 full chunks exactly
        text = " ".join(f"w{i}" for i in range(8))
        chunks = chunk_text(text, chunk_size=5, overlap=1)
        self.assertEqual(len(chunks), 2)


class TestValidateCollectionName(TestCase):
    def test_valid_names(self):
        for name in ["books", "my-docs", "work_notes", "A123"]:
            self.assertEqual(validate_collection_name(name), name)

    def test_invalid_names_exit(self):
        for name in ["../etc", "my docs", "col;drop", "a/b"]:
            with self.assertRaises(SystemExit):
                validate_collection_name(name)


class TestChunkIdUniqueness(TestCase):
    """Ensure files with same basename but different paths produce different IDs."""

    def test_different_paths_produce_different_ids(self):
        path_a = os.path.realpath("/home/user/books/notes.txt")
        path_b = os.path.realpath("/home/user/work/notes.txt")
        id_a = hashlib.md5(f"{path_a}:0".encode()).hexdigest()
        id_b = hashlib.md5(f"{path_b}:0".encode()).hexdigest()
        self.assertNotEqual(id_a, id_b)

    def test_same_path_produces_same_id(self):
        path = os.path.realpath("/home/user/docs/readme.md")
        id1 = hashlib.md5(f"{path}:0".encode()).hexdigest()
        id2 = hashlib.md5(f"{path}:0".encode()).hexdigest()
        self.assertEqual(id1, id2)


if __name__ == "__main__":
    main()
```