feat: add run_routing_benchmark.py — routing-only benchmark #19
217
benchmarks/run_routing_benchmark.py
Normal file
217
benchmarks/run_routing_benchmark.py
Normal file
@@ -0,0 +1,217 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Adolf routing benchmark — tests routing decisions only, no LLM inference.
|
||||
|
||||
Sends each query with no_inference=True, waits for the routing decision to
|
||||
appear in docker logs, and records whether the correct tier was selected.
|
||||
|
||||
Usage:
|
||||
python3 run_routing_benchmark.py [options]
|
||||
python3 run_routing_benchmark.py --tier light|medium|complex
|
||||
python3 run_routing_benchmark.py --category <name>
|
||||
python3 run_routing_benchmark.py --ids 1,2,3
|
||||
python3 run_routing_benchmark.py --list-categories
|
||||
|
||||
No GPU check needed — inference is disabled for all queries.
|
||||
Adolf must be running at http://localhost:8000.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
# Base URL of the running Adolf instance; the benchmark aborts if /health fails.
ADOLF_URL = "http://localhost:8000"
# Query dataset shipped alongside this script.
DATASET = Path(__file__).parent / "benchmark.json"
# Per-query results of the most recent run are written here as JSON.
RESULTS = Path(__file__).parent / "routing_results_latest.json"
QUERY_TIMEOUT = 30  # seconds — routing is fast, no LLM wait
|
||||
|
||||
|
||||
# ── Log helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
def get_log_tail(n: int = 50) -> str:
    """Return the last *n* lines of the `deepagents` container's docker logs.

    Docker writes log output to both stdout and stderr, so the two streams
    are concatenated; errors from the docker CLI itself end up in the
    returned text rather than raising.
    """
    proc = subprocess.run(
        ["docker", "logs", "deepagents", "--tail", str(n)],
        capture_output=True,
        text=True,
    )
    return f"{proc.stdout}{proc.stderr}"
|
||||
|
||||
|
||||
def extract_tier_from_logs(logs_before: str, logs_after: str) -> str | None:
|
||||
"""Find new tier= lines that appeared after we sent the query."""
|
||||
before_lines = set(logs_before.splitlines())
|
||||
new_lines = [line for line in logs_after.splitlines() if line not in before_lines]
|
||||
for line in new_lines:
|
||||
m = re.search(r"tier=(\w+(?:\s*\(no-inference\))?)", line)
|
||||
if m:
|
||||
tier_raw = m.group(1)
|
||||
return tier_raw.split()[0]
|
||||
return None
|
||||
|
||||
|
||||
# ── Request helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
async def post_message(client: httpx.AsyncClient, query_id: int, query: str) -> bool:
    """POST one benchmark query to Adolf's /message endpoint.

    The metadata flags request routing-only handling (no LLM inference).
    Returns True when the server accepted the request (2xx); on any failure
    the error is printed inline (no newline, to keep the results row intact)
    and False is returned.
    """
    body = {
        "text": query,
        "session_id": f"routing-bench-{query_id}",
        "channel": "cli",
        "user_id": "benchmark",
        "metadata": {"no_inference": True, "benchmark": True},
    }
    try:
        response = await client.post(f"{ADOLF_URL}/message", json=body, timeout=10)
        response.raise_for_status()
    except Exception as e:
        print(f" POST_ERROR: {e}", end="")
        return False
    return True
|
||||
|
||||
|
||||
# ── Dataset ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def load_dataset() -> list[dict]:
    """Read the benchmark dataset and return its list of query records."""
    data = json.loads(DATASET.read_text())
    return data["queries"]
|
||||
|
||||
|
||||
def filter_queries(queries, tier, category, ids):
    """Narrow *queries* by tier, category, and/or an explicit ID collection.

    Each filter is applied only when its argument is truthy; with no filters
    the input list is returned unchanged.
    """
    selected = queries
    if tier:
        selected = [q for q in selected if q["tier"] == tier]
    if category:
        selected = [q for q in selected if q["category"] == category]
    if ids:
        selected = [q for q in selected if q["id"] in ids]
    return selected
|
||||
|
||||
|
||||
# ── Main run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
async def run(queries: list[dict]) -> list[dict]:
    """Run the routing benchmark over *queries* and return per-query results.

    For each query: snapshot the docker log tail, POST the query with
    no_inference set, drain the session's SSE stream until [DONE] (or
    timeout), then diff a fresh log snapshot to find which tier the router
    chose. Prints a results table and summary, and writes all results to
    RESULTS as JSON. Exits the process if Adolf's /health check fails.
    """
    results = []

    async with httpx.AsyncClient() as client:
        # Fail fast if Adolf isn't reachable — nothing below can work without it.
        try:
            r = await client.get(f"{ADOLF_URL}/health", timeout=5)
            r.raise_for_status()
        except Exception as e:
            print(f"ERROR: Adolf not reachable: {e}", file=sys.stderr)
            sys.exit(1)

        total = len(queries)
        correct = 0

        print(f"\nRunning {total} queries [NO-INFERENCE: routing only]\n")
        print(f"{'ID':>3} {'EXPECTED':8} {'ACTUAL':8} {'OK':3} {'TIME':6} {'CATEGORY':22} QUERY")
        print("─" * 110)

        for q in queries:
            qid = q["id"]
            expected = q["tier"]
            category = q["category"]
            query_text = q["query"]
            # Must match the session_id that post_message() builds, so the
            # SSE stream below attaches to the right session.
            session_id = f"routing-bench-{qid}"

            print(f"{qid:>3} {expected:8} ", end="", flush=True)

            # Snapshot logs BEFORE sending, so only lines produced by this
            # query count as "new" when extracting the tier afterwards.
            logs_before = get_log_tail(300)
            t0 = time.monotonic()

            ok_post = await post_message(client, qid, query_text)
            if not ok_post:
                print(f"{'?':8} {'ERR':3} {'?':6} {category:22} {query_text[:40]}")
                results.append({"id": qid, "expected": expected, "actual": None, "ok": False})
                continue

            # Drain the SSE stream until the server signals completion; the
            # routing decision is logged server-side, so the stream content
            # itself is discarded.
            try:
                async with client.stream(
                    "GET", f"{ADOLF_URL}/stream/{session_id}", timeout=QUERY_TIMEOUT
                ) as sse:
                    async for line in sse.aiter_lines():
                        if "data: [DONE]" in line:
                            break
            except Exception:
                pass  # timeout or connection issue — check logs anyway

            # Brief grace period so the tier= line has time to hit docker logs.
            await asyncio.sleep(0.3)
            logs_after = get_log_tail(300)
            actual = extract_tier_from_logs(logs_before, logs_after)

            elapsed = time.monotonic() - t0
            # NOTE(review): "fast" is counted as a correct answer for
            # "medium" — presumably an alias/renamed tier; confirm against
            # the router's tier names.
            match = actual == expected or (actual == "fast" and expected == "medium")
            if match:
                correct += 1

            mark = "✓" if match else "✗"
            actual_str = actual or "?"
            print(f"{actual_str:8} {mark:3} {elapsed:5.1f}s {category:22} {query_text[:40]}")

            results.append({
                "id": qid,
                "expected": expected,
                "actual": actual_str,
                "ok": match,
                "elapsed": round(elapsed, 1),
                "category": category,
                "query": query_text,
            })

    # ── Summary ──
    print("─" * 110)
    accuracy = correct / total * 100 if total else 0
    print(f"\nAccuracy: {correct}/{total} ({accuracy:.0f}%)")

    # Per-tier breakdown (only tiers that actually appear in this run).
    for tier_name in ["light", "medium", "complex"]:
        tier_qs = [r for r in results if r["expected"] == tier_name]
        if tier_qs:
            tier_ok = sum(1 for r in tier_qs if r["ok"])
            print(f" {tier_name:8}: {tier_ok}/{len(tier_qs)}")

    wrong = [r for r in results if not r["ok"]]
    if wrong:
        print(f"\nMisclassified ({len(wrong)}):")
        for r in wrong:
            print(f" id={r['id']:3} expected={r['expected']:8} actual={r['actual']:8} {r['query'][:60]}")

    with open(RESULTS, "w") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"\nResults saved to {RESULTS}")

    return results
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse filter options, optionally list categories, run."""
    arg_parser = argparse.ArgumentParser(
        description="Adolf routing benchmark — routing decisions only, no LLM inference",
    )
    arg_parser.add_argument("--tier", choices=["light", "medium", "complex"])
    arg_parser.add_argument("--category")
    arg_parser.add_argument("--ids", help="Comma-separated IDs")
    arg_parser.add_argument("--list-categories", action="store_true")
    args = arg_parser.parse_args()

    queries = load_dataset()

    # Listing mode: print dataset stats and exit without running anything.
    if args.list_categories:
        categories = sorted({q["category"] for q in queries})
        tier_counts = {t: sum(1 for q in queries if q["tier"] == t) for t in ["light", "medium", "complex"]}
        print(f"Total: {len(queries)} | Tiers: {tier_counts}")
        print(f"Categories: {categories}")
        return

    wanted_ids = None
    if args.ids:
        wanted_ids = [int(i) for i in args.ids.split(",")]

    queries = filter_queries(queries, args.tier, args.category, wanted_ids)
    if not queries:
        print("No queries match filters.")
        sys.exit(1)

    asyncio.run(run(queries))
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user