Add routing benchmark scripts; gitignore dataset and results
- run_benchmark.py: sends queries to /message, extracts tier= from docker logs, reports per-tier accuracy, saves results_latest.json
- run_voice_benchmark.py: voice path benchmark
- .gitignore: ignore benchmark.json (dataset) and results_latest.json (runtime output); benchmark scripts are tracked, data files are not

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,2 +1,6 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
logs/*.jsonl
|
||||
adolf_tuning_data/voice_audio/
|
||||
benchmark.json
|
||||
results_latest.json
|
||||
|
||||
318
run_benchmark.py
Normal file
318
run_benchmark.py
Normal file
@@ -0,0 +1,318 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Adolf routing benchmark.
|
||||
|
||||
Sends each query to Adolf's /message endpoint, waits briefly for the routing
|
||||
decision to appear in docker logs, then records the actual tier.
|
||||
|
||||
Usage:
|
||||
python3 run_benchmark.py [options]
|
||||
python3 run_benchmark.py --tier light|medium|complex
|
||||
python3 run_benchmark.py --category <name>
|
||||
python3 run_benchmark.py --ids 1,2,3
|
||||
python3 run_benchmark.py --list-categories
|
||||
python3 run_benchmark.py --dry-run # complex queries use medium model (no API cost)
|
||||
|
||||
IMPORTANT: Always check GPU is free before running. This script does it automatically.
|
||||
|
||||
Adolf must be running at http://localhost:8000.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
# Base URL of the Adolf service; /health, /message and /stream/<session> are
# requested relative to it.
ADOLF_URL = "http://localhost:8000"
# Base URL of the Ollama instance queried (via /api/ps) by the GPU pre-flight check.
OLLAMA_URL = "http://localhost:11436" # GPU Ollama
# Input dataset and output report, both resolved next to this script.
DATASET = Path(__file__).parent / "benchmark.json"
RESULTS = Path(__file__).parent / "results_latest.json"

# Max time to wait for each query to fully complete via SSE stream
QUERY_TIMEOUT = 300 # seconds — generous to handle GPU semaphore waits

# Memory thresholds
MIN_FREE_RAM_MB = 1500 # abort if less than this is free
# NOTE(review): MIN_FREE_VRAM_MB is not referenced by the code visible in this
# file; check_gpu() compares loaded VRAM against a literal 7000 MB instead.
MIN_FREE_VRAM_MB = 500 # warn if less than this is free on GPU
|
||||
|
||||
|
||||
# ── Pre-flight checks ──────────────────────────────────────────────────────────
|
||||
|
||||
def check_ram() -> tuple[bool, str]:
    """Report whether enough system RAM is available to run the benchmark.

    Parses ``/proc/meminfo`` and compares ``MemAvailable`` (converted with
    ``// 1024``) against ``MIN_FREE_RAM_MB``.

    Returns:
        ``(ok, message)``. ``ok`` is False only when the available amount is
        below the threshold; any failure to read or parse the file is treated
        as non-fatal and reported with ``ok=True``.
    """
    try:
        meminfo: dict[str, int] = {}
        with open("/proc/meminfo") as handle:
            for raw in handle:
                fields = raw.split()
                if len(fields) < 2:
                    continue
                # "MemTotal:" -> "MemTotal"; the second column is the number.
                meminfo[fields[0].rstrip(":")] = int(fields[1])

        free_mb = meminfo.get("MemAvailable", 0) // 1024
        total_mb = meminfo.get("MemTotal", 0) // 1024
        msg = f"RAM: {free_mb} MB free / {total_mb} MB total"
        if free_mb >= MIN_FREE_RAM_MB:
            return True, msg
        return False, f"CRITICAL: {msg} — need at least {MIN_FREE_RAM_MB} MB free"
    except Exception as exc:
        return True, f"RAM check failed (non-fatal): {exc}"
|
||||
|
||||
|
||||
def check_gpu() -> tuple[bool, str]:
    """Inspect the models Ollama currently has loaded via ``/api/ps``.

    Returns:
        ``(ok, message)``. ``ok`` is False only when the summed ``size_vram``
        of the loaded models exceeds 7000 MB. An unreachable Ollama or any
        other error is reported as non-fatal with ``ok=True``.
    """
    try:
        response = httpx.get(f"{OLLAMA_URL}/api/ps", timeout=5)
        response.raise_for_status()
        loaded_models = response.json().get("models", [])
        if not loaded_models:
            return True, "GPU: idle (no models loaded)"

        # (name, VRAM in MB) per loaded model; size_vram is divided by 1024*1024.
        usage = [
            (entry.get("name", "?"), entry.get("size_vram", 0) // (1024 * 1024))
            for entry in loaded_models
        ]
        loaded = ", ".join(f"{name} ({mb}MB)" for name, mb in usage)
        total_vram = sum(mb for _, mb in usage)

        if total_vram > 7000:
            return False, f"GPU BUSY: models loaded = {loaded} — total VRAM used {total_vram}MB. Wait for models to unload."
        return True, f"GPU: models loaded = {loaded} (total {total_vram}MB VRAM)"
    except httpx.ConnectError:
        return True, "GPU check skipped (Ollama not reachable at localhost:11436)"
    except Exception as exc:
        return True, f"GPU check failed (non-fatal): {exc}"
|
||||
|
||||
|
||||
def preflight_checks(skip_gpu_check: bool = False) -> bool:
    """Run the RAM check and, unless skipped, the GPU check, printing each result.

    Args:
        skip_gpu_check: when True the GPU check is not run at all.

    Returns:
        True when every executed check passed; False as soon as one fails
        (later checks are then not run).
    """
    print("\n── Pre-flight checks ──────────────────────────────────────────")

    # Each stage pairs a check with the message printed when it fails.
    stages = [
        (check_ram, "\nABORTING: not enough RAM. Free up memory before running benchmark."),
    ]
    if not skip_gpu_check:
        stages.append(
            (check_gpu, "\nABORTING: GPU is busy. Wait for current inference to finish, then retry.")
        )

    for probe, abort_message in stages:
        passed, detail = probe()
        symbol = "✓" if passed else "✗"
        print(f" {symbol} {detail}")
        if not passed:
            print(abort_message)
            return False

    print(" All checks passed.\n")
    return True
|
||||
|
||||
|
||||
# ── Log helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
def get_log_tail(n: int = 50) -> str:
    """Return the last ``n`` log lines of the ``deepagents`` container.

    Runs ``docker logs deepagents --tail n`` and returns the captured stdout
    followed by the captured stderr as one string.
    """
    command = ["docker", "logs", "deepagents", "--tail", str(n)]
    completed = subprocess.run(command, capture_output=True, text=True)
    return "".join((completed.stdout, completed.stderr))
|
||||
|
||||
|
||||
def extract_tier_from_logs(logs_before: str, logs_after: str) -> str | None:
|
||||
"""Find new tier= lines that appeared after we sent the query."""
|
||||
before_lines = set(logs_before.splitlines())
|
||||
new_lines = [l for l in logs_after.splitlines() if l not in before_lines]
|
||||
for line in reversed(new_lines):
|
||||
m = re.search(r"tier=(\w+(?:\s*\(dry-run\))?)", line)
|
||||
if m:
|
||||
tier_raw = m.group(1)
|
||||
# Normalise: "complex (dry-run)" → "complex"
|
||||
return tier_raw.split()[0]
|
||||
return None
|
||||
|
||||
|
||||
# ── Request helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
async def post_message(
    client: httpx.AsyncClient,
    query_id: int,
    query: str,
    dry_run: bool = False,
) -> bool:
    """POST one benchmark query to Adolf's ``/message`` endpoint.

    The request uses the session id ``benchmark-<query_id>`` and carries
    ``dry_run`` in the request metadata.

    Returns:
        True when the request succeeded with a non-error status; False
        otherwise (the error is printed inline, without a trailing newline).
    """
    body = {
        "text": query,
        "session_id": f"benchmark-{query_id}",
        "channel": "cli",
        "user_id": "benchmark",
        "metadata": {"dry_run": dry_run, "benchmark": True},
    }
    try:
        response = await client.post(f"{ADOLF_URL}/message", json=body, timeout=10)
        response.raise_for_status()
    except Exception as exc:
        print(f" POST_ERROR: {exc}", end="")
        return False
    return True
|
||||
|
||||
|
||||
# ── Dataset ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def load_dataset() -> list[dict]:
    """Load the benchmark queries from ``DATASET``.

    The file is read as UTF-8 explicitly so that non-ASCII query text does not
    depend on the platform's default locale encoding.

    Returns:
        The list stored under the top-level ``"queries"`` key.

    Raises:
        FileNotFoundError: if the dataset file does not exist.
        KeyError: if the JSON document has no ``"queries"`` key.
    """
    with open(DATASET, encoding="utf-8") as f:
        return json.load(f)["queries"]
|
||||
|
||||
|
||||
def filter_queries(queries, tier, category, ids):
    """Narrow ``queries`` by tier, category and/or id.

    Each filter is applied only when its argument is truthy; the filters are
    combined with AND.

    Args:
        queries: list of query dicts with ``"tier"``, ``"category"``, ``"id"`` keys.
        tier: keep only queries whose ``"tier"`` equals this value.
        category: keep only queries whose ``"category"`` equals this value.
        ids: collection of ids to keep.

    Returns:
        The filtered list (the input list itself when no filter is active).
    """
    if tier:
        queries = [q for q in queries if q["tier"] == tier]
    if category:
        queries = [q for q in queries if q["category"] == category]
    if ids:
        # Build the lookup once so each membership test is O(1).
        wanted = set(ids)
        queries = [q for q in queries if q["id"] in wanted]
    return queries
|
||||
|
||||
|
||||
# ── Main run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
async def run(queries: list[dict], dry_run: bool = False) -> list[dict]:
    """Send every query to Adolf, record the routed tier, print and save a report.

    For each query the function snapshots the container logs, POSTs the query,
    reads ``/stream/<session_id>`` until a ``data: [DONE]`` line arrives (or
    the stream fails), snapshots the logs again and extracts the new ``tier=``
    value. A tier of ``fast`` is accepted as a match for an expected
    ``medium``.

    Args:
        queries: query dicts with ``"id"``, ``"tier"``, ``"category"`` and
            ``"query"`` keys.
        dry_run: when True, queries whose expected tier is ``complex`` are
            sent with ``dry_run=True`` in their metadata.

    Returns:
        One result dict per query. Results are also written to ``RESULTS``.

    Exits the process with status 1 when the ``/health`` probe fails.
    """
    results = []
    total = len(queries)
    correct = 0

    async with httpx.AsyncClient() as client:
        try:
            r = await client.get(f"{ADOLF_URL}/health", timeout=5)
            r.raise_for_status()
        except Exception as e:
            print(f"ERROR: Adolf not reachable: {e}", file=sys.stderr)
            sys.exit(1)

        dry_label = " [DRY-RUN: complex→medium]" if dry_run else ""
        print(f"\nRunning {total} queries{dry_label}\n")
        print(f"{'ID':>3} {'EXPECTED':8} {'ACTUAL':8} {'OK':3} {'TIME':6} {'CATEGORY':22} QUERY")
        print("─" * 110)

        for q in queries:
            qid = q["id"]
            expected = q["tier"]
            category = q["category"]
            query_text = q["query"]

            # In dry-run, complex queries still use complex classification (logged), but medium infers
            send_dry = dry_run and expected == "complex"
            session_id = f"benchmark-{qid}"

            print(f"{qid:>3} {expected:8} ", end="", flush=True)

            logs_before = get_log_tail(80)
            t0 = time.monotonic()

            ok_post = await post_message(client, qid, query_text, dry_run=send_dry)
            if not ok_post:
                print(f"{'?':8} {'ERR':3} {'?':6} {category:22} {query_text[:40]}")
                # Carry the same keys as a normal result so the summary below
                # can print this entry without a KeyError.
                results.append({
                    "id": qid,
                    "expected": expected,
                    "actual": None,
                    "ok": False,
                    "elapsed": None,
                    "category": category,
                    "query": query_text,
                    "dry_run": send_dry,
                    "error": "post",
                })
                continue

            # Wait for query to complete via SSE stream (handles GPU semaphore waits)
            stream_error = None
            try:
                async with client.stream(
                    "GET", f"{ADOLF_URL}/stream/{session_id}", timeout=QUERY_TIMEOUT
                ) as sse:
                    async for line in sse.aiter_lines():
                        if "data: [DONE]" in line:
                            break
            except Exception as exc:
                # Best effort: a timeout or connection issue must not abort the
                # run, the logs are still checked. The error is kept in the
                # result instead of being dropped silently.
                stream_error = repr(exc)

            # Now the query is done — check logs for tier
            await asyncio.sleep(0.3)
            logs_after = get_log_tail(80)
            actual = extract_tier_from_logs(logs_before, logs_after)

            elapsed = time.monotonic() - t0
            match = actual == expected or (actual == "fast" and expected == "medium")
            if match:
                correct += 1

            mark = "✓" if match else "✗"
            actual_str = actual or "?"
            print(f"{actual_str:8} {mark:3} {elapsed:5.1f}s {category:22} {query_text[:40]}")

            results.append({
                "id": qid,
                "expected": expected,
                "actual": actual_str,
                "ok": match,
                "elapsed": round(elapsed, 1),
                "category": category,
                "query": query_text,
                "dry_run": send_dry,
                "stream_error": stream_error,
            })

    print("─" * 110)
    accuracy = correct / total * 100 if total else 0
    print(f"\nAccuracy: {correct}/{total} ({accuracy:.0f}%)")

    for tier_name in ["light", "medium", "complex"]:
        tier_qs = [r for r in results if r["expected"] == tier_name]
        if tier_qs:
            tier_ok = sum(1 for r in tier_qs if r["ok"])
            print(f" {tier_name:8}: {tier_ok}/{len(tier_qs)}")

    wrong = [r for r in results if not r["ok"]]
    if wrong:
        print(f"\nMisclassified ({len(wrong)}):")
        for r in wrong:
            # "actual" is None for failed POSTs; None cannot take the ":8" format spec.
            shown_actual = r["actual"] or "?"
            print(f" id={r['id']:3} expected={r['expected']:8} actual={shown_actual:8} {r['query'][:60]}")

    # ensure_ascii=False writes raw non-ASCII text, so pin the file encoding.
    with open(RESULTS, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"\nResults saved to {RESULTS}")

    return results
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: parse options, run pre-flight checks, run the benchmark.

    ``--list-categories`` prints dataset statistics and returns without
    running anything. A malformed ``--ids`` value is rejected through
    ``parser.error`` and a missing dataset file is reported with a
    one-line message; both replace a raw traceback.
    """
    parser = argparse.ArgumentParser(
        description="Adolf routing benchmark",
        epilog="IMPORTANT: Always check GPU is free before running. This is done automatically."
    )
    parser.add_argument("--tier", choices=["light", "medium", "complex"])
    parser.add_argument("--category")
    parser.add_argument("--ids", help="Comma-separated IDs")
    parser.add_argument("--list-categories", action="store_true")
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="For complex queries: route classification is tested but medium model is used for inference (no API cost)",
    )
    parser.add_argument(
        "--skip-gpu-check",
        action="store_true",
        help="Skip GPU availability check (use only if you know GPU is free)",
    )
    args = parser.parse_args()

    # Validate --ids up front so a typo fails fast with a usage message.
    ids = None
    if args.ids:
        try:
            ids = [int(token) for token in args.ids.split(",") if token.strip()]
        except ValueError:
            parser.error(f"--ids must be a comma-separated list of integers, got {args.ids!r}")
        if not ids:
            # An empty list would disable the id filter and run everything.
            parser.error("--ids did not contain any IDs")

    try:
        queries = load_dataset()
    except FileNotFoundError:
        print(f"ERROR: dataset not found: {DATASET}", file=sys.stderr)
        sys.exit(1)

    if args.list_categories:
        cats = sorted(set(q["category"] for q in queries))
        tiers = {t: sum(1 for q in queries if q["tier"] == t) for t in ["light", "medium", "complex"]}
        print(f"Total: {len(queries)} | Tiers: {tiers}")
        print(f"Categories: {cats}")
        return

    # ALWAYS check GPU and RAM before running
    if not preflight_checks(skip_gpu_check=args.skip_gpu_check):
        sys.exit(1)

    queries = filter_queries(queries, args.tier, args.category, ids)
    if not queries:
        print("No queries match filters.")
        sys.exit(1)

    asyncio.run(run(queries, dry_run=args.dry_run))


if __name__ == "__main__":
    main()
|
||||
426
run_voice_benchmark.py
Normal file
426
run_voice_benchmark.py
Normal file
@@ -0,0 +1,426 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Adolf voice benchmark.
|
||||
|
||||
Pipeline for each query:
|
||||
1. Synthesize query text → WAV via Silero TTS (localhost:8881)
|
||||
2. Transcribe WAV → text via faster-whisper STT (localhost:8880)
|
||||
3. Send transcription to Adolf → check routing tier
|
||||
4. Report: WER per query, routing accuracy vs text baseline
|
||||
|
||||
Usage:
|
||||
python3 run_voice_benchmark.py [options]
|
||||
python3 run_voice_benchmark.py --tier light|medium|complex
|
||||
python3 run_voice_benchmark.py --ids 1,2,3
|
||||
python3 run_voice_benchmark.py --dry-run # complex queries use medium model
|
||||
|
||||
IMPORTANT: Always check GPU is free before running. Done automatically.
|
||||
|
||||
Services required:
|
||||
- Adolf: http://localhost:8000
|
||||
- Silero TTS: http://localhost:8881 (openai/silero-tts container)
|
||||
- faster-whisper: http://localhost:8880 (faster-whisper container)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unicodedata
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
# Base URL of the Adolf service (/health and /message are requested relative to it).
ADOLF_URL = "http://localhost:8000"
# Base URL of the Ollama instance queried (via /api/ps) by the GPU pre-flight check.
OLLAMA_URL = "http://localhost:11436"
TTS_URL = "http://localhost:8881" # Silero TTS — OpenAI-compatible /v1/audio/speech
STT_URL = "http://localhost:8880" # faster-whisper — OpenAI-compatible /v1/audio/transcriptions

# Input dataset and the directory result files are written to, both next to this script.
DATASET = Path(__file__).parent / "benchmark.json"
RESULTS_DIR = Path(__file__).parent

TIER_WAIT = 15 # seconds to wait for tier= in docker logs
# Minimum MemAvailable (MB) required by check_ram() before the run is allowed.
MIN_FREE_RAM_MB = 1500
# NOTE(review): MIN_FREE_VRAM_MB is not referenced by the code visible in this
# file; check_gpu() compares loaded VRAM against a literal 7000 MB instead.
MIN_FREE_VRAM_MB = 500
|
||||
|
||||
|
||||
# ── Pre-flight ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def check_ram() -> tuple[bool, str]:
    """Report whether enough system RAM is available.

    Parses ``/proc/meminfo`` and compares ``MemAvailable`` (converted with
    ``// 1024``) against ``MIN_FREE_RAM_MB``.

    Returns:
        ``(ok, message)``. ``ok`` is False only when the available amount is
        below the threshold; a failure to read or parse the file is reported
        as non-fatal with ``ok=True``.
    """
    try:
        meminfo: dict[str, int] = {}
        with open("/proc/meminfo") as handle:
            for raw in handle:
                fields = raw.split()
                if len(fields) < 2:
                    continue
                # "MemTotal:" -> "MemTotal"; the second column is the number.
                meminfo[fields[0].rstrip(":")] = int(fields[1])

        free_mb = meminfo.get("MemAvailable", 0) // 1024
        total_mb = meminfo.get("MemTotal", 0) // 1024
        msg = f"RAM: {free_mb} MB free / {total_mb} MB total"
        if free_mb >= MIN_FREE_RAM_MB:
            return True, msg
        return False, f"CRITICAL: {msg} — need at least {MIN_FREE_RAM_MB} MB free"
    except Exception as exc:
        return True, f"RAM check failed (non-fatal): {exc}"
|
||||
|
||||
|
||||
def check_gpu() -> tuple[bool, str]:
    """Inspect the models Ollama currently has loaded via ``/api/ps``.

    Returns:
        ``(ok, message)``. ``ok`` is False only when the summed ``size_vram``
        of the loaded models exceeds 7000 MB. An unreachable Ollama or any
        other error is reported as non-fatal with ``ok=True``.
    """
    try:
        response = httpx.get(f"{OLLAMA_URL}/api/ps", timeout=5)
        response.raise_for_status()
        loaded_models = response.json().get("models", [])
        if not loaded_models:
            return True, "GPU: idle"

        # (name, VRAM in MB) per loaded model; size_vram is divided by 1024*1024.
        usage = [
            (entry.get("name", "?"), entry.get("size_vram", 0) // (1024 * 1024))
            for entry in loaded_models
        ]
        loaded = ", ".join(f"{name} ({mb}MB)" for name, mb in usage)
        total_vram = sum(mb for _, mb in usage)

        if total_vram > 7000:
            return False, f"GPU BUSY: {loaded} — {total_vram}MB VRAM used. Wait for models to unload."
        return True, f"GPU: {loaded} ({total_vram}MB VRAM)"
    except httpx.ConnectError:
        return True, "GPU check skipped (Ollama not reachable)"
    except Exception as exc:
        return True, f"GPU check failed (non-fatal): {exc}"
|
||||
|
||||
|
||||
def check_services() -> tuple[bool, str]:
    """Probe the TTS and STT base URLs with a GET on ``/``.

    Any HTTP response counts as reachable (the status code is only reported,
    not checked); only a raised exception marks a service as down.

    Returns:
        ``(ok, message)`` where ``message`` joins both per-service reports
        with `` | `` and ``ok`` is False if either probe raised.
    """
    reports = []
    all_up = True
    for label, base_url in (("TTS", TTS_URL), ("STT", STT_URL)):
        try:
            reply = httpx.get(base_url + "/", timeout=5)
        except Exception as exc:
            reports.append(f"{label}: NOT REACHABLE — {exc}")
            all_up = False
        else:
            reports.append(f"{label}: reachable (HTTP {reply.status_code})")
    return all_up, " | ".join(reports)
|
||||
|
||||
|
||||
def preflight_checks(skip_gpu_check: bool = False) -> bool:
    """Run the RAM, GPU (unless skipped) and voice-service checks in order.

    Each result is printed with a check mark or cross. The first failing
    check prints its abort message and stops the sequence.

    Args:
        skip_gpu_check: when True the GPU check is not run at all.

    Returns:
        True when every executed check passed, False otherwise.
    """
    print("\n── Pre-flight checks ──────────────────────────────────────────")

    # Each stage pairs a check with the lines printed when it fails.
    stages = [(check_ram, ("\nABORTING: not enough RAM.",))]
    if not skip_gpu_check:
        stages.append((check_gpu, ("\nABORTING: GPU is busy.",)))
    stages.append((
        check_services,
        (
            "\nABORTING: required voice services not running.",
            "Start them with: cd /home/alvis/agap_git/openai && docker compose up -d faster-whisper silero-tts",
        ),
    ))

    for probe, abort_lines in stages:
        passed, detail = probe()
        symbol = "✓" if passed else "✗"
        print(f" {symbol} {detail}")
        if not passed:
            for text in abort_lines:
                print(text)
            return False

    print(" All checks passed.\n")
    return True
|
||||
|
||||
|
||||
# ── TTS ────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def synthesize(client: httpx.AsyncClient, text: str) -> bytes | None:
    """Request speech for ``text`` from the TTS service's ``/v1/audio/speech`` endpoint.

    The request asks for the ``wav`` response format.

    Returns:
        The raw response body on success, or None when the request failed
        (the error is printed inline).
    """
    request_body = {
        "model": "tts-1",
        "input": text,
        "voice": "alloy",
        "response_format": "wav",
    }
    try:
        response = await client.post(
            f"{TTS_URL}/v1/audio/speech", json=request_body, timeout=30
        )
        response.raise_for_status()
        audio = response.content
    except Exception as exc:
        print(f"\n [TTS error: {exc}]", end="")
        return None
    return audio
|
||||
|
||||
|
||||
# ── STT ────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def transcribe(client: httpx.AsyncClient, wav_bytes: bytes) -> str | None:
    """Send WAV audio to the STT service's ``/v1/audio/transcriptions`` endpoint.

    The audio is uploaded as a multipart file named ``audio.wav`` with the
    language fixed to ``ru``.

    Returns:
        The stripped ``"text"`` field of the JSON response (empty string when
        the field is absent), or None when the request or decoding failed
        (the error is printed inline).
    """
    upload = {"file": ("audio.wav", wav_bytes, "audio/wav")}
    form_fields = {"model": "whisper-1", "language": "ru", "response_format": "json"}
    try:
        response = await client.post(
            f"{STT_URL}/v1/audio/transcriptions",
            files=upload,
            data=form_fields,
            timeout=60,
        )
        response.raise_for_status()
        decoded = response.json()
        return decoded.get("text", "").strip()
    except Exception as exc:
        print(f"\n [STT error: {exc}]", end="")
        return None
|
||||
|
||||
|
||||
# ── WER ────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def normalize(text: str) -> str:
    """Canonicalise text for word-level comparison.

    Lowercases, applies Unicode NFC normalisation, turns every character that
    is neither a word character nor whitespace into a space, then collapses
    whitespace runs to single spaces with no leading or trailing space.
    """
    lowered = unicodedata.normalize("NFC", text.lower())
    without_punctuation = re.sub(r"[^\w\s]", " ", lowered)
    # split() with no argument drops leading/trailing whitespace and treats
    # any run of whitespace as a single separator.
    return " ".join(without_punctuation.split())
|
||||
|
||||
|
||||
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Return the word error rate of ``hypothesis`` against ``reference``.

    Both strings are passed through ``normalize`` and split into words; the
    result is the word-level edit distance (insertions, deletions and
    substitutions each cost 1) divided by the number of reference words.

    An empty reference yields 0.0 when the hypothesis is also empty and 1.0
    otherwise.
    """
    ref_words = normalize(reference).split()
    hyp_words = normalize(hypothesis).split()
    if not ref_words:
        return 1.0 if hyp_words else 0.0

    # Row-by-row edit distance: `previous` is the finished row for the
    # preceding reference word, `current` the row being filled in.
    previous = list(range(len(hyp_words) + 1))
    for row, ref_word in enumerate(ref_words, start=1):
        current = [row]
        for col, hyp_word in enumerate(hyp_words, start=1):
            if ref_word == hyp_word:
                current.append(previous[col - 1])
            else:
                current.append(
                    1 + min(previous[col], current[col - 1], previous[col - 1])
                )
        previous = current

    return previous[-1] / len(ref_words)
|
||||
|
||||
|
||||
# ── Adolf interaction ──────────────────────────────────────────────────────────
|
||||
|
||||
def get_log_tail(n: int = 60) -> str:
    """Return the last ``n`` log lines of the ``deepagents`` container.

    Runs ``docker logs deepagents --tail n`` and returns the captured stdout
    followed by the captured stderr as one string.
    """
    command = ["docker", "logs", "deepagents", "--tail", str(n)]
    completed = subprocess.run(command, capture_output=True, text=True)
    return "".join((completed.stdout, completed.stderr))
|
||||
|
||||
|
||||
def extract_tier_from_logs(logs_before: str, logs_after: str) -> str | None:
|
||||
before_lines = set(logs_before.splitlines())
|
||||
new_lines = [l for l in logs_after.splitlines() if l not in before_lines]
|
||||
for line in reversed(new_lines):
|
||||
m = re.search(r"tier=(\w+(?:\s*\(dry-run\))?)", line)
|
||||
if m:
|
||||
return m.group(1).split()[0]
|
||||
return None
|
||||
|
||||
|
||||
async def post_to_adolf(
    client: httpx.AsyncClient,
    query_id: int,
    text: str,
    dry_run: bool = False,
) -> bool:
    """POST a transcribed query to Adolf's ``/message`` endpoint.

    The request uses the session id ``voice-bench-<query_id>`` and flags the
    message as a voice benchmark in its metadata.

    Returns:
        True when the request succeeded with a non-error status; False
        otherwise (the error is printed inline).
    """
    body = {
        "text": text,
        "session_id": f"voice-bench-{query_id}",
        "channel": "cli",
        "user_id": "benchmark",
        "metadata": {"dry_run": dry_run, "benchmark": True, "voice": True},
    }
    try:
        response = await client.post(f"{ADOLF_URL}/message", json=body, timeout=10)
        response.raise_for_status()
    except Exception as exc:
        print(f"\n [Adolf error: {exc}]", end="")
        return False
    return True
|
||||
|
||||
|
||||
# ── Dataset ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def load_dataset() -> list[dict]:
    """Load the benchmark queries from ``DATASET``.

    The file is read as UTF-8 explicitly so that non-ASCII query text does not
    depend on the platform's default locale encoding.

    Returns:
        The list stored under the top-level ``"queries"`` key.

    Raises:
        FileNotFoundError: if the dataset file does not exist.
        KeyError: if the JSON document has no ``"queries"`` key.
    """
    with open(DATASET, encoding="utf-8") as f:
        return json.load(f)["queries"]
|
||||
|
||||
|
||||
def filter_queries(queries, tier, category, ids):
    """Narrow ``queries`` by tier, category and/or id.

    Each filter is applied only when its argument is truthy; the filters are
    combined with AND.

    Args:
        queries: list of query dicts with ``"tier"``, ``"category"``, ``"id"`` keys.
        tier: keep only queries whose ``"tier"`` equals this value.
        category: keep only queries whose ``"category"`` equals this value.
        ids: collection of ids to keep.

    Returns:
        The filtered list (the input list itself when no filter is active).
    """
    if tier:
        queries = [q for q in queries if q["tier"] == tier]
    if category:
        queries = [q for q in queries if q["category"] == category]
    if ids:
        # Build the lookup once so each membership test is O(1).
        wanted = set(ids)
        queries = [q for q in queries if q["id"] in wanted]
    return queries
|
||||
|
||||
|
||||
# ── Main run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
async def run(queries: list[dict], dry_run: bool = False, save_audio: bool = False) -> None:
    """Run the TTS → STT → routing pipeline for every query and report the results.

    For each query the text is synthesized, transcribed back, scored with
    ``word_error_rate`` and sent to Adolf; the container logs are then polled
    every 0.5 s for up to ``TIER_WAIT`` seconds for a new ``tier=`` line. A
    tier of ``fast`` is accepted as a match for an expected ``medium``.

    Args:
        queries: query dicts with ``"id"``, ``"tier"`` and ``"query"`` keys.
        dry_run: when True, queries whose expected tier is ``complex`` are
            sent with ``dry_run=True`` in their metadata.
        save_audio: when True, each synthesized WAV is written to
            ``RESULTS_DIR / "voice_audio" / "<id>.wav"``.

    A summary is printed and written to a timestamped file and to
    ``voice_results_latest.json`` in ``RESULTS_DIR``. Exits the process with
    status 1 when the ``/health`` probe fails.
    """
    total = len(queries)
    results = []
    total_wer = 0.0
    wer_count = 0
    correct = 0

    async with httpx.AsyncClient() as client:
        # Check Adolf
        try:
            r = await client.get(f"{ADOLF_URL}/health", timeout=5)
            r.raise_for_status()
        except Exception as e:
            print(f"ERROR: Adolf not reachable: {e}", file=sys.stderr)
            sys.exit(1)

        dry_label = " [DRY-RUN]" if dry_run else ""
        print(f"Voice benchmark: {total} queries{dry_label}\n")
        print(f"{'ID':>3} {'EXP':8} {'ACT':8} {'OK':3} {'WER':5} {'TRANSCRIPT'}")
        print("─" * 100)

        for q in queries:
            qid = q["id"]
            expected = q["tier"]
            original = q["query"]
            print(f"{qid:>3} {expected:8} ", end="", flush=True)

            # Step 1: TTS
            wav = await synthesize(client, original)
            if wav is None:
                print(f"{'?':8} {'ERR':3} {'?':5} [TTS failed]")
                results.append({"id": qid, "expected": expected, "actual": None, "ok": False, "wer": None, "original": original, "error": "tts"})
                continue

            if save_audio:
                audio_path = RESULTS_DIR / "voice_audio" / f"{qid}.wav"
                audio_path.parent.mkdir(exist_ok=True)
                audio_path.write_bytes(wav)

            # Step 2: STT
            transcript = await transcribe(client, wav)
            if transcript is None:
                print(f"{'?':8} {'ERR':3} {'?':5} [STT failed]")
                results.append({"id": qid, "expected": expected, "actual": None, "ok": False, "wer": None, "original": original, "error": "stt"})
                continue

            # Calculate WER
            wer = word_error_rate(original, transcript)
            total_wer += wer
            wer_count += 1

            # Step 3: Send to Adolf
            send_dry = dry_run and expected == "complex"
            logs_before = get_log_tail(60)
            t0 = time.monotonic()

            ok_post = await post_to_adolf(client, qid, transcript, dry_run=send_dry)
            if not ok_post:
                print(f"{'?':8} {'ERR':3} {wer:4.2f} {transcript[:50]}")
                results.append({"id": qid, "expected": expected, "actual": None, "ok": False, "wer": wer, "original": original, "transcript": transcript, "error": "post"})
                continue

            # Step 4: Wait for routing decision
            actual = None
            for _ in range(TIER_WAIT * 2):
                await asyncio.sleep(0.5)
                logs_after = get_log_tail(60)
                actual = extract_tier_from_logs(logs_before, logs_after)
                if actual and actual in ("light", "medium", "complex", "fast"):
                    break

            elapsed = time.monotonic() - t0
            match = actual == expected or (actual == "fast" and expected == "medium")
            if match:
                correct += 1

            mark = "✓" if match else "✗"
            actual_str = actual or "?"
            print(f"{actual_str:8} {mark:3} {wer:4.2f} {transcript[:60]}")

            results.append({
                "id": qid,
                "expected": expected,
                "actual": actual_str,
                "ok": match,
                "wer": round(wer, 3),
                "original": original,
                "transcript": transcript,
                "elapsed": round(elapsed, 1),
                "dry_run": send_dry,
            })

            await asyncio.sleep(0.5)

    print("─" * 100)

    # Summary
    accuracy = correct / total * 100 if total else 0
    avg_wer = total_wer / wer_count * 100 if wer_count else 0
    print(f"\nRouting accuracy: {correct}/{total} ({accuracy:.0f}%)")
    print(f"Average WER: {avg_wer:.1f}% (lower is better; 0% = perfect transcription)")

    for tier_name in ["light", "medium", "complex"]:
        tier_qs = [r for r in results if r["expected"] == tier_name]
        if tier_qs:
            tier_ok = sum(1 for r in tier_qs if r["ok"])
            tier_wers = [r["wer"] for r in tier_qs if r.get("wer") is not None]
            avg = sum(tier_wers) / len(tier_wers) * 100 if tier_wers else 0
            print(f" {tier_name:8}: routing {tier_ok}/{len(tier_qs)} avg WER {avg:.1f}%")

    wrong = [r for r in results if not r["ok"]]
    if wrong:
        print(f"\nMisclassified after voice ({len(wrong)}):")
        for r in wrong:
            # Failed pipeline steps store actual=None explicitly, so a .get()
            # default never applies; None cannot take the ":8" format spec.
            shown_actual = r.get("actual") or "?"
            print(f" id={r['id']:3} expected={r.get('expected','?'):8} actual={shown_actual:8} transcript={r.get('transcript','')[:50]}")

    high_wer = [r for r in results if r.get("wer") and r["wer"] > 0.3]
    if high_wer:
        print("\nHigh WER queries (>30%) — transcription quality issues:")
        for r in high_wer:
            print(f" id={r['id']:3} WER={r['wer']*100:.0f}% original: {r.get('original','')[:50]}")
            print(f" transcript: {r.get('transcript','')[:50]}")

    # Save results
    ts = int(time.time())
    out_path = RESULTS_DIR / f"voice_results_{ts}.json"
    latest_path = RESULTS_DIR / "voice_results_latest.json"
    report = {
        "summary": {"accuracy": accuracy, "avg_wer": avg_wer, "total": total},
        "results": results,
    }
    # ensure_ascii=False writes raw non-ASCII text, so pin the file encoding.
    for path in (out_path, latest_path):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2, ensure_ascii=False)
    print(f"\nResults saved to {latest_path}")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: parse options, run pre-flight checks, run the voice benchmark.

    A malformed ``--ids`` value is rejected through ``parser.error`` and a
    missing dataset file is reported with a one-line message; both replace a
    raw traceback.
    """
    parser = argparse.ArgumentParser(
        description="Adolf voice benchmark — TTS→STT→routing pipeline",
        epilog="Requires: Silero TTS (port 8881) and faster-whisper (port 8880) running."
    )
    parser.add_argument("--tier", choices=["light", "medium", "complex"])
    parser.add_argument("--category")
    parser.add_argument("--ids", help="Comma-separated IDs")
    parser.add_argument("--dry-run", action="store_true",
                        help="Complex queries use medium model for inference (no API cost)")
    parser.add_argument("--save-audio", action="store_true",
                        help="Save synthesized WAV files to voice_audio/ directory")
    parser.add_argument("--skip-gpu-check", action="store_true")
    args = parser.parse_args()

    # Validate --ids before the pre-flight checks so a typo fails fast.
    ids = None
    if args.ids:
        try:
            ids = [int(token) for token in args.ids.split(",") if token.strip()]
        except ValueError:
            parser.error(f"--ids must be a comma-separated list of integers, got {args.ids!r}")
        if not ids:
            # An empty list would disable the id filter and run everything.
            parser.error("--ids did not contain any IDs")

    if not preflight_checks(skip_gpu_check=args.skip_gpu_check):
        sys.exit(1)

    try:
        queries = load_dataset()
    except FileNotFoundError:
        print(f"ERROR: dataset not found: {DATASET}", file=sys.stderr)
        sys.exit(1)

    queries = filter_queries(queries, args.tier, args.category, ids)
    if not queries:
        print("No queries match filters.")
        sys.exit(1)

    asyncio.run(run(queries, dry_run=args.dry_run, save_audio=args.save_audio))


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user