Recipe: flash_attention_2 graceful fallback
Pattern: transformer scorers that prefer flash_attention_2 for
speed must degrade gracefully when the GPU class doesn’t support it —
otherwise the same code that runs on H100 / A100 breaks on RTX A6000
or smaller GPUs in the runpod-deploy GPU-failover pool.
Why this is a recipe, not a schema feature
GPU-class detection and attention-implementation selection are consumer-domain concerns. They depend on the model architecture, the PyTorch/Transformers version pinned by your project, and what counts as an acceptable degraded mode for your evaluation. None of that is deployment metadata.
What runpod-deploy owns is the failover pool: it picks an available
GPU class from pod.gpu_order and provisions it. What attention
implementation your code uses on that GPU is yours to decide. Baking
the try/except into the orchestrator would force one fallback policy
on every consumer; consumers who genuinely need flash-attn-2 (e.g.,
for paper-grade timing comparisons) would have to opt out of the
fallback they didn’t ask for.
Pattern (Python)
import logging

import torch
from transformers import AutoModel

logger = logging.getLogger(__name__)

# model_id and revision come from your own configuration.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

try:
    encoder = AutoModel.from_pretrained(
        model_id,
        revision=revision,
        attn_implementation="flash_attention_2",
        torch_dtype=dtype,
    )
except (ValueError, ImportError) as exc:
    # flash-attention-2 not available on this GPU class; fall back to
    # the library default (stock SDPA where the model supports it).
    # Keep dtype + revision the same so determinism survives.
    logger.warning("flash_attention_2 unavailable (%s); using default attention", exc)
    encoder = AutoModel.from_pretrained(
        model_id,
        revision=revision,
        torch_dtype=dtype,
    )
The try/except costs nothing at runtime when flash-attn-2 is supported (the first from_pretrained call succeeds) and turns a hard ValueError or ImportError into a logged degraded mode on GPUs where it is not.
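The same try/except can be generalized so the caller also learns which implementation actually loaded, which is what the per-shard audit needs. The following is a minimal sketch, not part of runpod-deploy or Transformers: `load_with_attn_fallback` is a hypothetical helper, and `load_fn` is assumed to be a callable you supply that accepts an `attn_implementation` keyword.

```python
from typing import Any, Callable, Optional, Sequence, Tuple


# Hypothetical helper for illustration; not a runpod-deploy or Transformers API.
def load_with_attn_fallback(
    load_fn: Callable[..., Any],
    candidates: Sequence[Optional[str]] = ("flash_attention_2", "sdpa", None),
) -> Tuple[Any, str]:
    """Try each attention implementation in order; return (model, label).

    A candidate of None means "omit the keyword and let the library choose".
    """
    last_error: Optional[Exception] = None
    for impl in candidates:
        kwargs = {} if impl is None else {"attn_implementation": impl}
        try:
            return load_fn(**kwargs), impl or "default"
        except (ValueError, ImportError) as exc:
            last_error = exc  # remember why this candidate was rejected
    raise RuntimeError("no attention implementation could be loaded") from last_error
```

With Transformers, `load_fn` could be `functools.partial(AutoModel.from_pretrained, model_id, revision=revision, torch_dtype=dtype)`; the returned label is the value to log for the shard.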
What lives where
| Concern | Owner |
|---|---|
| Selecting an available GPU class from pod.gpu_order | runpod-deploy |
| Detecting GPU-class capabilities (FA2 support, SM compute capability) | Your model-loading code |
| Choosing the attention implementation | Your model-loading code |
| Logging which implementation was actually used (per-shard audit) | Your training code (emit a …) |
| Aggregating fallback frequency across shards | Your post-run analysis (…) |
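The capability-detection row can be sketched as a small pre-check in your model-loading code. This is an illustrative sketch under stated assumptions: the function names are hypothetical, and the compute-capability threshold of 8.0 (Ampere-class and newer) is an assumption you should verify against the flash-attn release you pin. The decision logic is kept separate from the hardware probe so it can be tested without a GPU.

```python
import importlib.util
from typing import Optional, Tuple

# Assumption: FlashAttention-2 kernels need CUDA compute capability 8.0+.
# Check the flash-attn release notes for the version your project pins.
MIN_FA2_CAPABILITY = (8, 0)


def pick_attn_implementation(
    capability: Optional[Tuple[int, int]],
    flash_attn_installed: bool,
) -> str:
    """Pure decision function: capability is None when no CUDA device is visible."""
    if (
        capability is not None
        and flash_attn_installed
        and capability >= MIN_FA2_CAPABILITY
    ):
        return "flash_attention_2"
    return "sdpa"


def detect_attn_implementation() -> str:
    """Probe the local machine; torch is imported lazily so the module loads without it."""
    import torch

    capability = (
        torch.cuda.get_device_capability() if torch.cuda.is_available() else None
    )
    installed = importlib.util.find_spec("flash_attn") is not None
    return pick_attn_implementation(capability, installed)
```

A check like this is only a pre-filter. The load can still fail for model-specific reasons, so it complements the try/except rather than replacing it.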
Anti-pattern to avoid
Do not let the model load fail hard when flash-attn-2 isn’t
available. pod.gpu_order typically lists several GPU classes
(failover for stock-outs), so a sweep may land on an H100 for one shard
and an A6000 for the next. Without the fallback, the second shard
fails at model load with a ValueError (flash_attention_2 is not
supported), the orchestrator pulls a stack trace, and the operator
gets a billed failure for what is really a code-portability bug.
Do not bake the fallback into runpod-deploy itself. The right
choice between flash-attn-2, SDPA, and eager attention depends on
your model + your tolerance for degraded performance. Consumers
running paper-grade timing comparisons may explicitly want the
hard-fail so they catch GPU-class drift early; consumers running
production evals want the graceful fallback. Both are legitimate.
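Because both policies are legitimate, the choice can be an explicit switch in consumer code. A minimal sketch, assuming a hypothetical `load_encoder` helper, a caller-supplied `load_fn` that accepts `attn_implementation`, and a hypothetical `REQUIRE_FLASH_ATTN` environment variable (none of these are defined by runpod-deploy):

```python
import os
from typing import Any, Callable, Mapping


def load_encoder(load_fn: Callable[..., Any], strict: bool = False) -> Any:
    """Prefer flash_attention_2; re-raise when strict, otherwise degrade."""
    try:
        return load_fn(attn_implementation="flash_attention_2")
    except (ValueError, ImportError):
        if strict:
            raise  # timing comparisons: surface GPU-class drift immediately
        return load_fn()  # production evals: fall back to the library default


def strict_from_env(env: Mapping[str, str] = os.environ) -> bool:
    """Read the hypothetical REQUIRE_FLASH_ATTN flag; anything but "1" is non-strict."""
    return env.get("REQUIRE_FLASH_ATTN", "0") == "1"
```

A timing run would set the flag and call `load_encoder(load_fn, strict=strict_from_env())`, while an eval run leaves it unset and gets the fallback.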
See also
- multi-config-sweep.md: sweeps that span GPU classes (the typical case for pod.gpu_order with multiple entries) hit this exact failure mode without the fallback.
- reproducibility.md: log which attention implementation was used per shard for audit purposes; pair with events.emit_event("attn_impl", ...) from your training code so events-query can later answer “which shards fell back?”.
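The per-shard audit and the "which shards fell back?" question can be sketched without depending on the events API, whose exact signature is not shown here. The example below uses plain JSON lines as a stand-in for emitted events; the record fields and both function names are assumptions for illustration only.

```python
import json
from collections import Counter
from typing import Iterable, List, Tuple


def attn_impl_record(shard_id: str, gpu_name: str, impl: str) -> str:
    """Serialize one audit record as a JSON line (stand-in for an emitted event)."""
    return json.dumps(
        {"event": "attn_impl", "shard": shard_id, "gpu": gpu_name, "impl": impl},
        sort_keys=True,
    )


def summarize_fallbacks(
    lines: Iterable[str],
    preferred: str = "flash_attention_2",
) -> Tuple[Counter, List[str]]:
    """Count implementations used and list the shards that did not get `preferred`."""
    records = [json.loads(line) for line in lines]
    counts = Counter(record["impl"] for record in records)
    fell_back = [r["shard"] for r in records if r["impl"] != preferred]
    return counts, fell_back
```

Recording the GPU name alongside the implementation makes it possible to tell a fallback caused by the GPU class apart from one caused by a missing package.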