Worked example: RecallAtLowFPR loss training

What this shows. Optimizes a small set of logits with RecallAtLowFPR, the Meta Prompt Guard 2 (PG2) training recipe, and demonstrates that the loss descends under SGD on a toy problem (the headline correctness property from the acceptance criteria). To keep the example tiny, the logits themselves are the trainable parameters rather than the output of a model.

Runtime: <2 s. Requires pip install eval-toolkit[losses] (torch only — no transformers). Closes eval-toolkit#50.

Setup

This page is rendered statically because the docs CI environment doesn’t include [losses]. Real usage looks like:

pip install "eval-toolkit[losses]"

import torch

from eval_toolkit.losses import RecallAtLowFPR

torch.manual_seed(42)

Construct the loss

loss = RecallAtLowFPR(
    fpr_target=0.01,           # operating point: keep FPR ≤ 1%
    fpr_smoothing_beta=10.0,   # soft-indicator temperature
    pos_weight=1.0,             # unweighted (use >1 for imbalanced sets)
    reduction="mean",
)

print(f"loss.fpr_target = {loss.fpr_target}")
print(f"loss.fpr_smoothing_beta = {loss.fpr_smoothing_beta}")

Single training step

# 32 negatives + 8 positives, both drawn from N(-0.5, 1), so most positives start BELOW the low-FPR threshold
logits = torch.cat([torch.randn(32) - 0.5, torch.randn(8) - 0.5])
logits = logits.detach().clone().requires_grad_(True)
labels = torch.cat([torch.zeros(32), torch.ones(8)]).int()

out = loss(logits, labels)
out.backward()

print(f"loss = {out.item():.4f}")
print(f"gradient norm = {logits.grad.norm().item():.4f}")

The gradient pushes positive logits upward (raising a true positive's score reduces 1 - Recall@FPR) and pulls the scores of false positives downward at the threshold.
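The direction of those updates can be checked on a hand-rolled stand-in for the loss. The sketch below is not the eval_toolkit implementation: it assumes the threshold is the score of the first negative that would exceed the allowed false-positive count, and that recall is softened with sigmoid(beta * (score - threshold)). The names soft_recall_loss, toy_scores and toy_labels are illustrative only.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def soft_recall_loss(scores, labels, fpr_target=0.25, beta=10.0):
    """Hypothetical stand-in: 1 - soft recall at the threshold that meets fpr_target."""
    negatives = sorted((s for s, y in zip(scores, labels) if y == 0), reverse=True)
    positives = [s for s, y in zip(scores, labels) if y == 1]
    # Number of negatives allowed to score strictly above the threshold.
    allowed = min(int(fpr_target * len(negatives)), len(negatives) - 1)
    threshold = negatives[allowed]
    # Soft indicator: sigmoid(beta * (score - threshold)) approximates score > threshold.
    soft_recall = sum(sigmoid(beta * (s - threshold)) for s in positives) / len(positives)
    return 1.0 - soft_recall


# Four negatives followed by two positives (assumed toy values).
toy_labels = [0, 0, 0, 0, 1, 1]
toy_scores = [-1.0, -0.5, 0.0, 0.2, 0.1, 0.4]
base = soft_recall_loss(toy_scores, toy_labels)

# Raise one positive score: 0.1 -> 0.3.
raised_positive = soft_recall_loss([-1.0, -0.5, 0.0, 0.2, 0.3, 0.4], toy_labels)
# Lower the negative that sets the threshold: 0.0 -> -0.2.
lowered_negative = soft_recall_loss([-1.0, -0.5, -0.2, 0.2, 0.1, 0.4], toy_labels)

print(f"base loss            = {base:.4f}")
print(f"after raising a TP   = {raised_positive:.4f}")
print(f"after lowering an FP = {lowered_negative:.4f}")
```

Both perturbations lower the stand-in loss, which matches the gradient directions described above. A larger beta makes the soft indicator closer to a hard step, at the cost of smaller gradients away from the threshold.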

Multi-step convergence

optimizer = torch.optim.SGD([logits], lr=1.0)

losses_over_time = []
for step in range(50):
    optimizer.zero_grad()
    out = loss(logits, labels)
    out.backward()
    optimizer.step()
    losses_over_time.append(out.item())

print(f"Initial loss: {losses_over_time[0]:.4f}")
print(f"Final loss:   {losses_over_time[-1]:.4f}")
print(f"Reduction:    {(1 - losses_over_time[-1] / losses_over_time[0]) * 100:.1f}%")
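To turn the printed numbers into a check, a small helper can assert that the loss actually went down. The function names below (relative_reduction, loss_descended) are illustrative and not part of eval_toolkit, and the history values are made up for the example rather than output from RecallAtLowFPR.

```python
def relative_reduction(history):
    """Fraction by which the loss fell between the first and last step."""
    first, last = history[0], history[-1]
    if first == 0:
        # Nothing left to reduce; this also avoids dividing by zero.
        return 0.0
    return 1.0 - last / first


def loss_descended(history, min_reduction=0.0):
    """True if the loss fell from first to last step by more than min_reduction."""
    return relative_reduction(history) > min_reduction


# Made-up loss values standing in for losses_over_time.
example_history = [0.80, 0.55, 0.30, 0.20]
print(f"Reduction: {relative_reduction(example_history) * 100:.1f}%")
print(f"Descended: {loss_descended(example_history)}")
```

In a test, passing losses_over_time to loss_descended would express the "loss descends under SGD" property as a single assertion.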

Comparison vs. fpr_target

# Stricter FPR constraint → higher loss (fewer positives clear the threshold)
for fpr in [0.5, 0.1, 0.05, 0.01]:
    l = RecallAtLowFPR(fpr_target=fpr)
    val = l(logits.detach(), labels).item()
    print(f"  fpr_target={fpr:.2f}  →  loss={val:.4f}")

As fpr_target shrinks, the bar for “TP above threshold” rises and the loss increases — exactly what you want for a detector that must operate at a strict low-FPR point.
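The effect of a stricter operating point can be seen without the loss at all, using a hard threshold and hard recall. This is a sketch under stated assumptions: the threshold is taken as the score of the first negative beyond the allowed false-positive count, and the scores are evenly spaced toy values. threshold_at_fpr and recall_above are hypothetical helpers, not eval_toolkit functions.

```python
def threshold_at_fpr(negative_scores, fpr_target):
    """Lowest threshold such that at most fpr_target of negatives score strictly above it."""
    ranked = sorted(negative_scores, reverse=True)
    allowed = min(int(fpr_target * len(ranked)), len(ranked) - 1)
    return ranked[allowed]


def recall_above(positive_scores, threshold):
    """Fraction of positives scoring strictly above the threshold."""
    return sum(s > threshold for s in positive_scores) / len(positive_scores)


# Toy scores: 100 negatives spread over [-5.0, 4.9], 20 positives over [0.0, 7.6].
negatives = [i / 10 for i in range(-50, 50)]
positives = [i / 10 for i in range(0, 80, 4)]

results = []
for fpr in [0.5, 0.1, 0.05, 0.01]:
    threshold = threshold_at_fpr(negatives, fpr)
    recall = recall_above(positives, threshold)
    results.append((fpr, threshold, recall))
    print(f"  fpr_target={fpr:.2f}  threshold={threshold:+.1f}  recall={recall:.2f}")
```

As fpr_target shrinks, fewer negatives may sit above the threshold, so the threshold moves up and fewer positives clear it. A loss of the form 1 - Recall@FPR rises accordingly.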

Reduction modes

logits_d = logits.detach().clone().requires_grad_(True)

for reduction in ("mean", "sum", "none"):
    l = RecallAtLowFPR(fpr_target=0.1, reduction=reduction)
    out = l(logits_d, labels)
    print(f"  reduction={reduction:6s}  shape={tuple(out.shape) if out.dim() else '()'}")

Production usage

Drop into a standard PyTorch training loop:

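# MyDetector and dataloader are placeholders for your own nn.Module and DataLoader.
model = MyDetector()
loss_fn = RecallAtLowFPR(fpr_target=0.01)
optimizer = torch.optim.AdamW(model.parameters())

# Each batch is assumed to be a dict with "input_ids" and "labels" tensors.
for batch in dataloader: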
    optimizer.zero_grad()
    logits = model(batch["input_ids"])
    loss = loss_fn(logits, batch["labels"])
    loss.backward()
    optimizer.step()

This is the standard training-loop shape: RecallAtLowFPR is a drop-in replacement for nn.BCEWithLogitsLoss when you need FPR-constrained training.