"""
recovery_test.py — Argus vs Blind
=================================

Proves Argus saves training.

Same model. Same seed. Same data. Same failure injection.

    Run A — WITH Argus:    detects failure, restores checkpoint, continues
    Run B — WITHOUT Argus: blind, no intervention, eats the failure

Checkpoint saved at inject_step - 2. Fair comparison.

Start server first:
    PYTHONPATH=. python api/server.py
Then:
    PYTHONPATH=. python recovery_test.py
"""
import json, time, math, copy, urllib.request, urllib.error
import torch, torch.nn as nn, torch.optim as optim
import torchvision, torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset

# Base URL of the Argus server and the secret sent as ``x-internal-secret``.
# NOTE(review): both are empty in this copy — fill in before running; with an
# empty BASE every request is reported as "unreachable" (status 0).
BASE = ""
SECRET = ""
HEADERS = {"Content-Type": "application/json", "x-internal-secret": SECRET}
PASS, FAIL = "✅", "❌"

# (label, ok) pairs appended by check() and summarised at the end of main().
results = []


# ─── HTTP ─────────────────────────────────────────────────────────────────────

def _decode_json(raw):
    """Decode a response body as JSON, returning {} when it is empty or not JSON."""
    try:
        return json.loads(raw)
    except (ValueError, TypeError):
        return {}


def req(method, path, body=None):
    """Send a JSON request to the Argus server.

    Args:
        method: HTTP verb, e.g. "GET" or "POST".
        path: path appended to BASE, e.g. "/v2/health".
        body: JSON-serialisable payload, or None to send no body.

    Returns:
        (status, payload). ``status`` is the HTTP status code, or 0 when the
        server could not be reached at all (malformed URL, connection refused,
        timeout). ``payload`` is the decoded JSON body, {} when the body is
        empty/non-JSON, or {"error": "..."} when no response was received.
    """
    url = BASE + path
    # `is not None` so an explicit empty payload ({}) is still sent.
    data = json.dumps(body).encode() if body is not None else None
    try:
        r = urllib.request.Request(url, data=data, headers=HEADERS, method=method)
        with urllib.request.urlopen(r, timeout=30) as resp:
            return resp.status, _decode_json(resp.read())
    except urllib.error.HTTPError as e:
        # The server answered with an error status; its body may not be JSON.
        return e.code, _decode_json(e.read())
    except (OSError, ValueError) as e:
        # URLError/timeouts are OSError; an unparseable URL is ValueError.
        # Report as status 0 so callers (e.g. the health check) can react.
        return 0, {"error": str(e)}


def send_step(run_id, step, loss, loss_delta, grad_norm, grad_sim):
    """Report one training step to /v2/detect and unpack Argus' verdict.

    Args:
        run_id: identifier grouping the steps of one run on the server.
        step: global step index.
        loss, loss_delta, grad_norm: raw training metrics; non-finite values
            are sent as 1e6 and finite ones as their magnitude capped at 1e6.
        grad_sim: gradient cosine similarity, clamped to [0, 1] before sending.

    Returns:
        (harm_pressure, intervention, anchor_point, situation, pattern). Any
        field missing from the response (including after a failed request)
        defaults to 0 / "NONE" / None / "UNKNOWN" / "none".
    """
    def _s(v):
        # json.dumps would otherwise emit the non-standard tokens NaN/Infinity.
        # NOTE(review): abs() drops the sign, so a falling loss_delta is sent
        # like a rising one — confirm the server expects magnitudes.
        return 1e6 if (math.isnan(v) or math.isinf(v)) else min(abs(v), 1e6)

    _, b = req("POST", "/v2/detect", {
        "run_id": run_id,
        "step": step,
        "epoch": 0,
        "training": {
            "loss": _s(loss),
            "loss_delta": _s(loss_delta),
            "grad_norm": _s(grad_norm),
            "gradient_similarity": float(max(0.0, min(1.0, grad_sim))),
        },
        "histogram": {"bins": 4, "counts": [4, 8, 8, 4]},
        "control": {"mode": "AUTO"},
    })
    hp = b.get("harm_pressure", 0)
    if hp is None:
        # A JSON null would break the numeric comparisons callers make.
        hp = 0
    return hp, b.get("intervention", "NONE"), b.get("anchor_point"), b.get("situation", "UNKNOWN"), b.get("pattern", "none")


def check(label, ok, note=""):
    """Record a pass/fail result in ``results`` and print it immediately."""
    results.append((label, ok))
    print(f" {PASS if ok else FAIL} {label} {note}")


# ─── MODEL + DATA ─────────────────────────────────────────────────────────────
def get_loader(seed=42, n_samples=4096):
    """Build a fixed CIFAR-10 training subset and wrap it in a shuffling loader.

    The first ``n_samples`` images of a seeded shuffle of the CIFAR-10 train
    split (read from ./data with download=False, so the dataset must already
    be present) are materialised into an in-memory TensorDataset.

    Args:
        seed: seeds the subset selection and the returned loader's generator.
        n_samples: subset size (default 4096, the original hard-coded value).

    Returns:
        DataLoader over the subset with batch_size=64, shuffle=True and a
        dedicated torch.Generator seeded with ``seed``.
    """
    torch.manual_seed(seed)
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    ds = torchvision.datasets.CIFAR10(root="data", train=True, download=False, transform=tf)
    loader = DataLoader(ds, batch_size=64, shuffle=True,
                        generator=torch.Generator().manual_seed(seed))
    xs, ys = [], []
    collected = 0
    for x, y in loader:
        xs.append(x)
        ys.append(y)
        collected += len(x)
        if collected >= n_samples:
            break
    X = torch.cat(xs)[:n_samples]
    Y = torch.cat(ys)[:n_samples]
    return DataLoader(TensorDataset(X, Y), batch_size=64, shuffle=True,
                      generator=torch.Generator().manual_seed(seed))


def fresh_model(seed=42):
    """Return a newly initialised small CNN (seeded) for 3x32x32 inputs, 10 classes.

    Two conv+pool stages halve 32x32 twice to 8x8, hence Linear(64*8*8, ...).
    """
    torch.manual_seed(seed)
    return nn.Sequential(
        nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Flatten(),
        nn.Linear(64 * 8 * 8, 256), nn.ReLU(),
        nn.Linear(256, 10),
    )


def grad_norm_of(model):
    """Global L2 norm of all parameter gradients currently stored on ``model``."""
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.data.norm(2).item() ** 2
    return math.sqrt(total)


def grad_sim_real(model, prev):
    """Cosine similarity between current and previous flattened gradients, in [0, 1].

    Returns 1.0 when there is no previous gradient (first step) or no current
    gradient. Returns 0.0 when the cosine is non-finite (NaN/inf gradients):
    ``min(1.0, nan)`` would otherwise report such gradients as perfectly aligned.
    """
    cur = [p.grad.data.view(-1) for p in model.parameters() if p.grad is not None]
    if not cur or prev is None:
        return 1.0
    c = torch.cat(cur)
    p = torch.cat(prev)
    cos = (c * p).sum().item() / (c.norm().item() * p.norm().item() + 1e-8)
    if not math.isfinite(cos):
        return 0.0
    return max(0.0, min(1.0, cos))


def get_grads(model):
    """Snapshot every parameter gradient as a detached, flattened clone."""
    return [p.grad.data.clone().view(-1) for p in model.parameters() if p.grad is not None]


# ─── CORE LOOP ────────────────────────────────────────────────────────────────

def _reset_data_order(loader, seed):
    """Re-seed the loader's own shuffle generator so each run replays the same batches.

    torch.manual_seed() does not touch a DataLoader's dedicated ``generator``;
    without this its state would carry over between runs and scenarios.
    """
    gen = getattr(loader, "generator", None)
    if gen is not None:
        gen.manual_seed(seed)


def _next_batch(data_iter, loader):
    """Return (iterator, batch), starting a new pass over ``loader`` when exhausted."""
    try:
        return data_iter, next(data_iter)
    except StopIteration:
        data_iter = iter(loader)
        return data_iter, next(data_iter)


def _clamp_loss(lv, cap=50.0):
    """Cap a loss for averaging; NaN counts as ``cap`` (min() alone would keep NaN)."""
    return cap if math.isnan(lv) else min(lv, cap)


def _tail_mean(values, n=10):
    """Mean of the last ``n`` values (or of all of them if fewer); 0.0 if empty."""
    tail = values[-n:]
    return sum(tail) / len(tail) if tail else 0.0


def _run_once(loader, seed, total_steps, inject_step, inject_fn, post_inject_fn, run_id=None):
    """Train one fresh model through the scenario.

    With ``run_id`` set, every step is reported to Argus via send_step(); a
    checkpoint is taken at ``inject_step - 2`` (clamped to 0). On the first
    alarm (harm_pressure >= 2) after that, the weights are restored and the
    optimizer is replaced by a fresh Adam. With ``run_id=None`` the run is blind.

    Returns:
        (final_loss, detected_step, restored). final_loss is the mean of the
        last 10 clamped losses.
    """
    argus = run_id is not None
    save_at = max(0, inject_step - 2)
    criterion = nn.CrossEntropyLoss()

    torch.manual_seed(seed)
    model = fresh_model(seed)
    opt = optim.Adam(model.parameters(), lr=1e-3)
    model.train()

    _reset_data_order(loader, seed)
    data_iter = iter(loader)
    prev_loss = None
    prev_grads = None
    saved_ckpt = None
    restored = False
    detected_step = None
    losses = []

    for step in range(total_steps):
        data_iter, (x, y) = _next_batch(data_iter, loader)
        if argus and step == save_at:
            saved_ckpt = copy.deepcopy(model.state_dict())
        if step == inject_step:
            opt = inject_fn(model, opt)

        opt.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        if post_inject_fn and step > inject_step:
            post_inject_fn(model, step, inject_step)

        # Metrics use the raw (pre-clipping) gradients.
        lv = loss.item()
        gn = grad_norm_of(model)
        gsv = grad_sim_real(model, prev_grads)
        ld = (lv - prev_loss) if prev_loss is not None else 0.0
        prev_grads = get_grads(model)
        prev_loss = lv
        torch.nn.utils.clip_grad_norm_(model.parameters(), 500.0)
        opt.step()

        if argus:
            hp, _intv, _anchor, situation, pattern = send_step(run_id, step, lv, ld, gn, gsv)
            if hp > 0 or step % 10 == 0:
                print(f" step={step} hp={hp} situation={situation} pattern={pattern} loss={lv:.3f} gn={gn:.2f}")
            alarm = hp >= 2
            if alarm and detected_step is None:
                detected_step = step
                print(f" ⚡ detected step={step} hp={hp} situation={situation} pattern={pattern}")
            # Restore at most once, and only when a checkpoint exists.
            if alarm and not restored and saved_ckpt is not None:
                model.load_state_dict(saved_ckpt)
                opt = optim.Adam(model.parameters(), lr=1e-3)
                restored = True
                print(f" ⚡ restored — situation={situation} pattern={pattern}")

        losses.append(_clamp_loss(lv))

    final_loss = _tail_mean(losses)
    print(f" final loss (avg last 10 steps): {final_loss:.3f}")
    return final_loss, detected_step, restored


def run_paired(scenario_name, loader, seed, total_steps, inject_step, inject_fn, post_inject_fn=None):
    """
    Runs scenario twice — WITH and WITHOUT Argus.

    Same model init (seed). Same data: the loader's generator is re-seeded
    with ``seed`` before each run. Same injection. In the Argus run a
    checkpoint is saved at inject_step - 2 (clamped to 0) and restored on the
    first alarm (harm_pressure >= 2) after it.

    Args:
        scenario_name: short tag embedded in the Argus run_id.
        loader: DataLoader yielding (inputs, labels) batches.
        seed: model-init and data-order seed shared by both runs.
        total_steps: optimisation steps per run.
        inject_step: step at which ``inject_fn`` is applied.
        inject_fn: ``(model, optimizer) -> optimizer``; may mutate the model
            and/or return a replacement optimizer.
        post_inject_fn: optional ``(model, step, inject_step)`` hook called
            after backward() on every step after ``inject_step``.

    Returns:
        (loss_with, loss_without, detected_step, restored). The losses are
        averages over the last 10 steps, each capped at 50 with NaN counted
        as 50. detected_step is the first step Argus flagged, or None.
        restored says whether a restore happened.
    """
    print(f" [WITH Argus]")
    rid_w = f"w_{scenario_name}_{int(time.time())%100000}"
    loss_with, detected_step, restored = _run_once(
        loader, seed, total_steps, inject_step, inject_fn, post_inject_fn, run_id=rid_w,
    )

    print(f" [WITHOUT Argus]")
    loss_without, _, _ = _run_once(
        loader, seed, total_steps, inject_step, inject_fn, post_inject_fn, run_id=None,
    )
    return loss_with, loss_without, detected_step, restored


# ─── SCENARIOS ────────────────────────────────────────────────────────────────

def _noise_injector(scale):
    """Return an inject_fn adding randn * mean(|p|) * ``scale`` to every parameter in place."""
    def inject(model, opt):
        with torch.no_grad():
            for p in model.parameters():
                noise = torch.randn_like(p) * p.abs().mean() * scale
                p.add_(noise)
        return opt
    return inject


def _print_comparison(loss_w, loss_n):
    """Print the WITH/WITHOUT final losses and return the rounded difference."""
    saved = round(loss_n - loss_w, 3)
    print(f"\n With Argus: {loss_w:.3f} | Without: {loss_n:.3f} | Saved: {saved:.3f}")
    return saved


def _check_recovery(detect_label, loss_w, loss_n, detected, restored):
    """Record the three standard checks for a scenario Argus is expected to fix."""
    saved = _print_comparison(loss_w, loss_n)
    check(detect_label, detected is not None, f"step={detected}")
    check("Checkpoint restored", restored)
    check("Argus wins (lower final loss)", loss_w < loss_n, f"delta={saved}")


def scenario_explosion(loader):
    """Swap in SGD(lr=8.0) at step 40 of 300; Argus should detect and recover."""
    print("\n── 1. EXPLOSION — SGD lr=8 at step 40 ──────────────────────────")

    def inject(model, opt):
        return optim.SGD(model.parameters(), lr=8.0)

    loss_w, loss_n, detected, restored = run_paired(
        scenario_name="expl", loader=loader, seed=42,
        total_steps=300, inject_step=40, inject_fn=inject,
    )
    _check_recovery("Explosion detected", loss_w, loss_n, detected, restored)


def scenario_slow_diverge(loader):
    """Weight poison: noise at 10x mean weight magnitude at step 40 of 120."""
    print("\n── 2. WEIGHT POISON — noise ×10 weight magnitude at step 40 ────")
    loss_w, loss_n, detected, restored = run_paired(
        scenario_name="poison", loader=loader, seed=42,
        total_steps=120, inject_step=40, inject_fn=_noise_injector(10.0),
    )
    _check_recovery("Weight poison detected", loss_w, loss_n, detected, restored)


def scenario_grad_corruption(loader):
    """Weight drift: noise at 5x mean weight magnitude on all layers at step 40 of 120."""
    # Banner corrected from "×3": the injected scale is 5.0.
    print("\n── 3. WEIGHT DRIFT — noise ×5 all layers at step 40 ────────────")
    loss_w, loss_n, detected, restored = run_paired(
        scenario_name="drift", loader=loader, seed=42,
        total_steps=120, inject_step=40, inject_fn=_noise_injector(5.0),
    )
    _check_recovery("Weight drift detected", loss_w, loss_n, detected, restored)


def scenario_false_panic(loader):
    """Soft noise (0.5x) at step 40 of 120; Argus must NOT restore."""
    print("\n── 4. FALSE PANIC — soft noise that recovers naturally ──────────")
    print(" Argus should NOT restore. If it does, that is a false positive.\n")
    loss_w, loss_n, detected, restored = run_paired(
        scenario_name="falsepanic", loader=loader, seed=42,
        total_steps=120, inject_step=40, inject_fn=_noise_injector(0.5),
    )
    saved = _print_comparison(loss_w, loss_n)
    check("False panic — Argus did NOT restore", not restored, "FAIL=over-eager" if restored else "correct")
    check("False panic — no final loss penalty", loss_w <= loss_n + 0.1, f"delta={saved}")


# ─── MAIN ─────────────────────────────────────────────────────────────────────

def main():
    """Check server health, run all four paired scenarios, print a summary."""
    print("\n" + "="*66)
    print(" Argus vs Blind — Paired Recovery Test")
    print(f" {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(" Same seed. Same model. Same data. Checkpoint at inject-2.")
    print("="*66)

    status, _ = req("GET", "/v2/health")
    if status != 200:
        print(f" {FAIL} Server not reachable at {BASE}")
        print(" Run: PYTHONPATH=. python api/server.py")
        return
    print(f" {PASS} Server healthy\n")

    loader = get_loader(seed=42)
    scenario_explosion(loader)
    scenario_slow_diverge(loader)
    scenario_grad_corruption(loader)
    scenario_false_panic(loader)

    print("\n" + "="*66)
    print(" SUMMARY")
    print("="*66)
    for name, ok in results:
        print(f" {PASS if ok else FAIL} {name}")
    wins = sum(ok for _, ok in results)
    total = len(results)
    print(f"\n {wins}/{total} passing")
    if wins == total:
        print(" 🛡 Argus demonstrably saves training in all scenarios.")
    print("="*66 + "\n")


if __name__ == "__main__":
    main()