Investigating Flaky `test_eagle_dp` — Batch Invariance Failure on L4 GPUs
These are my personal notes from investigating and fixing one of the flaky CI tests in vLLM. The goal is to document this work so I can reference it later without having to re-derive everything from scratch.
A few things before we proceed:
- These notes were generated with the help of Claude, given all the context and my understanding of the problem. If you spot any reasoning errors, please let me know in the comments.
Issue: vllm-project/vllm#31913
Related: pytorch/pytorch#170563, vllm PR #38566, vllm PR #31915
Fix PR: vllm-project/vllm#38938
Follow-up Issue: vllm-project/vllm#39096 — Batch invariance breaks with torch.compile and/or CUDA graphs on SM<90
1. Problem Statement
The test tests/v1/distributed/test_eagle_dp.py::test_run_eagle_dp[FLASH_ATTN] is flaky on CI. It runs two engines sequentially on the same prompt with temperature=0 (greedy sampling):
- Engine A (EAGLE): Uses speculative decoding with a draft model. The target model verifies draft tokens in batches of `1 + num_draft_tokens = 4`.
- Engine B (no EAGLE): Normal autoregressive decoding. The target model processes 1 token per decode step.
The test asserts output_A == output_B. With VLLM_BATCH_INVARIANT=1 set, this equality should hold — speculative decoding must not change the final output. The test fails intermittently on CI (L4 GPUs) but passes on other hardware.
2. Initial Hypotheses
From the Slack discussion and issue tracker, three hypotheses existed:
- Async scheduling bug: Nicolò (vLLM maintainer) observed he “couldn’t reproduce when async correction is disabled.” This pointed at the `AsyncScheduler` code path.
- Batch invariance failure: zou3519 (PyTorch) hypothesized “something is either wrong with batch invariance or something else wrong with how our spec decoding is implemented.”
- Zero-bubble spec decode interaction: Matt Bonanni noted that timing changes from zero-bubble spec decode could surface an existing issue.
We investigated all three through static code analysis and targeted instrumentation.
3. Static Code Analysis
3.1 Async Scheduling Path
We traced the full lifecycle of EAGLE + async scheduling across these files:
| File | Role |
|---|---|
| vllm/v1/engine/core.py | Engine core loop: step() → post_step() |
| vllm/v1/core/sched/scheduler.py | Base scheduler: schedule(), update_from_output() |
| vllm/v1/core/sched/async_scheduler.py | Async scheduler: placeholder-based draft token tracking |
| vllm/v1/worker/gpu/model_runner.py | GPU model runner: execute_model(), sample_tokens() |
| vllm/v1/worker/gpu/spec_decode/eagle/speculator.py | EAGLE draft model: propose() |
| vllm/v1/worker/gpu/spec_decode/rejection_sampler.py | Rejection sampling: accept/reject draft tokens |
| vllm/v1/worker/gpu/input_batch.py | Input construction: combine_sampled_and_draft_tokens() |
| vllm/v1/worker/gpu/async_utils.py | Async D2H copy for output |
| vllm/v1/request.py | Request state: num_output_placeholders, spec_token_ids |
| vllm/config/scheduler.py | Scheduler config: async_scheduling resolution |
Key finding: When async_scheduling=None (default), it resolves to True for EAGLE models. Both the EAGLE and non-EAGLE engines run with async scheduling. The AsyncScheduler uses placeholder draft tokens [-1, -1, -1] and relies on the worker to fill in real draft tokens from GPU state.
We traced the full state transitions for prefill → first decode → subsequent steps and found the logic traced correctly on paper. No obvious bugs in the async scheduler state tracking, D2H copy synchronization, or draft token handling.
3.2 Batch Invariance Implementation
We examined vllm/model_executor/layers/batch_invariant.py and all 50+ locations where VLLM_BATCH_INVARIANT is checked.
What batch invariance does:
When enabled, it replaces standard GPU operations with deterministic alternatives:
- Matrix multiply (`torch.mm`, `torch.bmm`, `torch.matmul`, `torch.linear`): Replaced with custom Triton persistent kernels that use a fixed K-reduction order per output tile, independent of batch size.
- FlashAttention: Forces `num_splits=1`, ensuring each query’s attention is computed by exactly one worker in one fixed order.
- Softmax, log_softmax, mean: Replaced with Triton kernels that process one row per thread block.
- RMS norm: Replaced with a Triton kernel with fixed reduction order.
- cuBLAS settings: Disables reduced precision, forces the `cublaslt` backend.
- NCCL settings: Forces deterministic all-reduce.
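Why fixing the reduction order matters at all: floating-point addition is not associative, so the same values summed in a different order can round differently. A minimal standalone illustration (not vLLM code):

```python
import numpy as np

# float32 spacing between representable values near 1e8 is 8,
# so adding 1.0 to 1e8 is absorbed by rounding.
a = np.float32(1e8)
b = np.float32(1.0)
c = np.float32(-1e8)

left_to_right = (a + b) + c  # (1e8 + 1) rounds back to 1e8, then - 1e8 -> 0.0
reordered = (a + c) + b      # 1e8 - 1e8 = 0.0, then + 1 -> 1.0

print(left_to_right, reordered)  # 0.0 1.0
```

This is exactly the failure mode a batch-size-dependent K-reduction or split-attention schedule introduces: same inputs, different summation order, different bits out.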
How these overrides are registered (in enable_batch_invariant_mode()):
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
# etc.
torch.library.Library("aten", "IMPL") creates a library that overrides the implementation of existing PyTorch aten operators. When you call .impl("aten::mean.dim", mean_batch_invariant, "CUDA"), it means: whenever aten::mean.dim is dispatched on CUDA, use the custom mean_batch_invariant function instead of PyTorch’s default. This works in eager mode because every op goes through the dispatch table.
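A minimal standalone demonstration of this mechanism (using the CPU dispatch key so it runs anywhere; vLLM registers against CUDA). We override `aten::mean.dim` and confirm the override is hit by a plain `tensor.mean(dim=...)` call in eager mode:

```python
import torch

calls = []

def my_mean_dim(self, dim, keepdim=False, *, dtype=None):
    # Record that our override was dispatched, then compute the mean via sum
    # (calling torch.mean here would recurse back into this override).
    calls.append(list(dim))
    n = 1
    for d in dim:
        n *= self.shape[d]
    out = torch.sum(self, dim=dim, keepdim=keepdim) / n
    return out.to(dtype) if dtype is not None else out

lib = torch.library.Library("aten", "IMPL")
lib.impl("aten::mean.dim", my_mean_dim, "CPU")

x = torch.arange(8, dtype=torch.float32).reshape(2, 4)
y = x.mean(dim=-1)  # dispatches to my_mean_dim on CPU
print(calls)        # non-empty: proves the override was hit
print(y)            # tensor([1.5000, 5.5000])
```

Every eager-mode `mean` over CPU tensors in this process now routes through `my_mean_dim` — which is precisely why the mechanism breaks down when Inductor compiles the op away instead of dispatching it.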
Hardware-specific code path in enable_batch_invariant_mode():
if (
current_platform.is_device_capability_family(100) # Blackwell
    or current_platform.is_device_capability_family(80)  # SM8x family: A100 (SM80), L4/RTX 4090 (SM89)
):
# Full Triton persistent matmul replacements
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
else:
# H100 (SM90): only cuBLAS workspace workaround
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
Both A100 (SM80) and L4 (SM89) take the same code path — full Triton persistent matmul replacements. FlashAttention num_splits=1 is enforced on both.
3.3 FlashAttention num_splits Enforcement
Three locations in vllm/v1/attention/backends/flash_attn.py set num_splits:
| Line | Context | Setting |
|---|---|---|
| 1019 | Encoder attention (causal=False) | num_splits=1 if batch_invariant else 0 |
| 1165 | Cascade attention (prefix) | num_splits=1 if BATCH_INVARIANT else max_num_splits |
| 1190 | Cascade attention (suffix) | num_splits=1 if BATCH_INVARIANT else max_num_splits |
For this test (Llama decoder-only, VLLM_BATCH_INVARIANT=1 disables cascade attention), only the main decode path is used, which correctly enforces num_splits=1 via attn_metadata.max_num_splits set at line 426-427.
4. Instrumentation Approach
Since static analysis couldn’t pinpoint the root cause, we added targeted logging at 6 points across 3 source files, controlled by a single tag [spec_decode_debug] for easy filtering.
4.1 Scheduler Instrumentation (Files 2 & 3)
Added to async_scheduler.py and scheduler.py:
| Point | Location | What it captures |
|---|---|---|
| async_sched.after_schedule | End of _update_after_schedule() | num_computed_tokens, num_output_placeholders, spec_token_ids, is_prefill_chunk |
| async_sched.update_output | End of _update_request_with_output() | New token IDs, placeholder count, cache position |
| sched.update_from_output.spec | After rejection handling in update_from_output() | Generated tokens, accepted/rejected counts |
| sched.update_from_output.tokens | After _update_request_with_output() | Final new tokens, output token history |
4.2 Model Runner Instrumentation (File 1)
Added to vllm/v1/worker/gpu/model_runner.py:
| Point | Location | What it captures |
|---|---|---|
| logits | After self.model.compute_logits() in sample() | Top-5 logit values and IDs, argmax |
| prepare_inputs | After combine_sampled_and_draft_tokens() | Input IDs, positions, draft token count |
| sample_result | After self.sample() in sample_tokens() | Sampled token IDs, num_sampled, num_rejected |
| post_propose | After speculator.propose() | Last sampled tokens, new draft tokens |
Note: The model runner instrumentation ran in the worker subprocess, whose stdout/stderr is not captured in CI logs. Only the scheduler instrumentation (running in the EngineCore process) was captured.
4.3 Test Modification (File 4)
Modified test_eagle_dp.py to print detailed divergence info on failure: exact index, surrounding tokens from both engines, and full output sequences.
5. Experimental Results
5.1 A100 Runs (2× A100 80GB, vast.ai)
20 consecutive runs, all passed.
Key observations:
- 100% deterministic: Identical token sequences, identical acceptance patterns, identical draft token behavior across all 20 runs.
- Acceptance pattern: Exactly 53 spec decode steps, 29 full rejections, 24 with some acceptance — identical every run.
- Token output: EAGLE and non-EAGLE engines produced identical output every time.
- The EAGLE engine’s output (`[323, 10344, 13, 578, 1296, 374, ...]`) matches the “wrong” output seen in CI failures, meaning the EAGLE engine is consistent across hardware — it’s the non-EAGLE engine that differs.
5.2 L4 CI Runs
Failure 1 — Divergence at token 81:
Context: ...4188, 271, 2
EAGLE (batch=4): argmax → 20400
No-EAGLE (batch=1): argmax → 4324
EAGLE: generated=[271, 2, 20400] accepted=2 rejected=1
No-EAGLE: final_new_tokens=[271], [2], [4324] (one at a time)
Failure 2 — Divergence at token 33:
Context: ...16572, 389, 220
EAGLE (batch=4): argmax → 1490
No-EAGLE (batch=1): argmax → 4728
EAGLE: generated=[1490] accepted=0 rejected=3
No-EAGLE: final_new_tokens=[4728]
Failure 3 (after lm_head fix) — Divergence at token 32:
Context: ...16572, 389, 220
EAGLE (batch=4): argmax → 4728
No-EAGLE (batch=1): argmax → 1490
Note: the tokens swapped compared to Failure 2 — the lm_head fix changed which path produced which token, but the upstream divergence in the transformer layers remained.
In all cases: identical model, identical weights, identical input context, different batch size → different greedy output. This directly violates batch invariance.
6. Root Cause Analysis
6.1 Phase 1: Finding the lm_head Gap
We traced the Llama forward pass operation by operation, cross-referencing each against the batch invariance overrides.
The Llama forward pass chain:
input_ids
→ embed_tokens (embedding lookup, not a matmul)
→ 32× LlamaDecoderLayer:
→ input_layernorm (RMSNorm)
→ self_attn:
→ qkv_proj (linear) → RoPE → FlashAttention → o_proj (linear)
→ post_attention_layernorm (RMSNorm)
→ mlp:
→ gate_up_proj (linear) → SiluAndMul → down_proj (linear)
→ norm (final RMSNorm)
→ compute_logits → lm_head.quant_method.apply → matmul → argmax
Key finding: Two different apply methods for unquantized layers:
UnquantizedLinearMethod.apply (all model linear layers — qkv, o_proj, gate_up, down):
def apply(self, layer, x, bias=None):
if envs.VLLM_BATCH_INVARIANT and current_platform.is_cuda_alike():
return linear_batch_invariant(x, layer.weight, bias) # ✅ COVERED
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
UnquantizedEmbeddingMethod.apply (lm_head — the final logits projection):
def apply(self, layer, x, bias=None):
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) # ❌ NOT COVERED
The lm_head’s apply method had no VLLM_BATCH_INVARIANT check. It always used dispatch_unquantized_gemm() (which calls F.linear → cuBLAS), never the batch-invariant Triton persistent kernel. This is the operation that computes hidden_states @ lm_head_weight.T → logits, whose output is directly argmax’d to pick the next token.
Fix 1: Added the batch invariant check to UnquantizedEmbeddingMethod.apply in vllm/model_executor/layers/vocab_parallel_embedding.py.
6.2 Phase 2: Finding the RMSNorm Gap
After Fix 1, CI still failed (Failure 3 above), proving there was at least one more uncovered operation upstream. The token swap between Failures 2 and 3 confirmed the lm_head fix changed behavior but didn’t resolve the root cause.
We traced the RMSNorm dispatch path and discovered the interaction with torch.compile:
The CustomOp dispatch system:
vLLM’s CustomOp base class has two forward paths:
- `forward_cuda`: Used in eager mode. Contains explicit batch invariant checks (calls `rms_norm_batch_invariant()` directly).
- `forward_native`: Used under torch.compile. Contains pure PyTorch ops (`x.pow(2).mean(dim=-1)`) for Inductor to compile.
Which path is used is determined by CustomOp.dispatch_forward():
def dispatch_forward(self, compile_native):
enabled = self._enforce_enable or self.enabled()
if not enabled:
return self.maybe_compile(self.forward_native, enable=compile_native)
# ... platform checks ...
return self.forward_cuda
self.enabled() checks custom_ops in the compilation config:
@classmethod
def enabled(cls):
compilation_config = get_cached_compilation_config()
custom_ops = compilation_config.custom_ops
enabled = f"+{cls.name}" in custom_ops
disabled = f"-{cls.name}" in custom_ops
return (CustomOp.default_on() or enabled) and not disabled
default_on() returns True if "all" is in custom_ops, False if "none" is in custom_ops.
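To make the resolution concrete, here is a small standalone mock of the `default_on`/`enabled` logic described above (illustrative only, not the vLLM source; `default_on` is simplified to "no `none` entry means on-by-default", which matches the post-init states where exactly one of `all`/`none` is present):

```python
def op_enabled(name: str, custom_ops: list[str]) -> bool:
    # "none" disables all custom ops by default; "+name"/"-name" override per op.
    default_on = "none" not in custom_ops
    enabled = f"+{name}" in custom_ops
    disabled = f"-{name}" in custom_ops
    return (default_on or enabled) and not disabled

# Inductor path: custom_ops resolves to ["none"] -> RMSNorm falls to forward_native
print(op_enabled("rms_norm", ["none"]))               # False
# The Fix 2 config: "none" plus an explicit opt-in for rms_norm
print(op_enabled("rms_norm", ["none", "+rms_norm"]))  # True
# Eager path: custom_ops resolves to ["all"] -> forward_cuda everywhere
print(op_enabled("rms_norm", ["all"]))                # True
```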
The critical chain for this test:
- Test sets `enforce_eager=False` → torch.compile is active
- In `vllm/config/vllm.py`, when Inductor is the backend with compilation enabled:

  ```python
  if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
      if self.compilation_config.backend == "inductor" and self.compilation_config.mode != CompilationMode.NONE:
          self.compilation_config.custom_ops.append("none")  # ← THIS
      else:
          self.compilation_config.custom_ops.append("all")
  ```

- `custom_ops = ["none"]` → `default_on()` returns `False`
- `RMSNorm.enabled()` returns `False` (no `+rms_norm` in `custom_ops`)
- `dispatch_forward()` → `forward_native` (not `forward_cuda`)
- `forward_native` → `ir.ops.rms_norm` → `x.pow(2).mean(dim=-1, keepdim=True)`
- Inductor lowers these ops to its own IR and generates its own Triton reduction kernels
- These Inductor-generated kernels bypass the `aten::mean.dim` batch invariant override — Inductor has its own lowerings that don’t go through the aten dispatch table
- The Inductor-generated reduction kernel produces batch-size-dependent results on L4
- This propagates through 32 transformer layers (64 RMSNorm calls), accumulates, and flips the argmax
Why override_envs_for_invariance didn’t help: This function (called by init_batch_invariance) sets NCCL, cuBLAS, and AOT compile env vars, but never touches custom_ops. There is no code anywhere in vLLM that forces custom ops on when VLLM_BATCH_INVARIANT=1.
Why init_batch_invariance can’t fix this: It runs in the worker subprocess during gpu_worker.py, but CustomOp.dispatch_forward() is called during RMSNorm.__init__() when the model is constructed. By the time the worker calls init_batch_invariance, the RMSNorm instances have already bound _forward_method = forward_native.
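The ordering problem — a dispatch decision taken in `__init__` cannot be undone by flipping a flag later — can be shown with a tiny standalone sketch (hypothetical names, not vLLM code):

```python
import os

os.environ.pop("USE_CUSTOM", None)  # start with the flag unset

class TinyOp:
    def __init__(self):
        # The dispatch decision happens once, at construction time,
        # mirroring CustomOp binding _forward_method in __init__.
        if os.environ.get("USE_CUSTOM") == "1":
            self._forward_method = self.forward_custom
        else:
            self._forward_method = self.forward_native

    def forward_native(self):
        return "native"

    def forward_custom(self):
        return "custom"

    def __call__(self):
        return self._forward_method()

op = TinyOp()                   # constructed before the flag is set
os.environ["USE_CUSTOM"] = "1"  # too late: the method is already bound
print(op())                     # still "native"
```

This is why `init_batch_invariance` running later in the worker cannot retroactively switch already-constructed RMSNorm instances to `forward_cuda`.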
Fix 2: Added compilation_config={"custom_ops": ["none", "+rms_norm"]} to the test’s AsyncEngineArgs, forcing RMSNorm to dispatch to forward_cuda.
6.3 Why the Existing Batch Invariant Tests Didn’t Catch This
The e2e batch invariant tests (tests/v1/determinism/test_batch_invariance.py) set:
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
On L4 (SM89), IS_DEVICE_CAPABILITY_BELOW_90 evaluates to True, so they run in eager mode → custom_ops=["all"] → forward_cuda is used → batch invariant RMSNorm works correctly. These tests never hit the Inductor forward_native path on L4.
The EAGLE DP test is the only test that uses enforce_eager=False unconditionally on L4, which is why it’s the only test that exposed this gap.
6.4 Why L4 But Not A100
Both GPUs take the same batch invariance code path (Triton persistent matmul, num_splits=1, etc.). The uncovered operations (RMSNorm via Inductor, and previously lm_head) produce slightly different results depending on batch size due to GPU thread scheduling, memory access patterns, or internal reduction order differences.
On A100, either:
- The uncovered operations happen to be more numerically stable, or
- The logit gaps between top candidates are large enough that the tiny differences don’t flip any argmax.
On L4, at specific positions in the sequence, the top-2 logit candidates are so close (within ~1e-6) that the uncovered operation noise flips the argmax.
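The argmax-flip mechanism is easy to reproduce in isolation: when the top-2 logits sit closer together than the noise from a non-invariant reduction, a perturbation on the order of 1e-6 changes the greedy token. A sketch with made-up logit values (float64 so the tiny gap is exactly representable):

```python
import torch

# Two logit vectors for the same position: identical except for tiny
# batch-size-dependent noise on the top-2 candidates (values are made up).
logits_batch1 = torch.tensor([10.0, 12.3456789, 12.3456795, 3.0], dtype=torch.float64)
noise = torch.tensor([0.0, 1e-6, -1e-6, 0.0], dtype=torch.float64)
logits_batch4 = logits_batch1 + noise

print(int(logits_batch1.argmax()))  # 2
print(int(logits_batch4.argmax()))  # 1  -- the greedy token flips
```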
6.5 Why It’s Flaky
The test doesn’t always fail because:
- Whether the sequence reaches a “tie-breaker” position (where top-2 logits are within the noise margin) depends on the specific token sequence.
- With 100 tokens, there are 100 chances to hit such a position.
- On some runs, the sequence never encounters a sufficiently close tie.
- Increasing the token count makes failures more frequent (more chances for a tie).
- The divergence position varies between runs (token 32, 33, 80 observed).
This matches zou3519’s observation: “For pytorch 2.9 I had to adjust to 600; for PyTorch 2.10 adjusting to 50 will cause a failure.”
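The flakiness rate follows directly from this model: if each generated position independently has some small probability p of landing on a near-tie, an N-token run fails with probability 1 − (1 − p)^N, which grows quickly with N. A quick illustration with an assumed p (the true per-position tie probability is unknown):

```python
# Assumed per-position probability of a near-tie; purely illustrative.
p = 0.01

for n_tokens in (50, 100, 600):
    p_fail = 1 - (1 - p) ** n_tokens
    print(f"{n_tokens} tokens -> P(at least one near-tie) ~ {p_fail:.2f}")
```

This is consistent with the observation that raising the token count from 50 to 600 moves the test from "sometimes fails" to "almost always fails".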
7. Fixes Applied
Fix 1: UnquantizedEmbeddingMethod.apply
File: vllm/model_executor/layers/vocab_parallel_embedding.py
Added the missing batch invariant check to the lm_head’s apply method:
def apply(self, layer, x, bias=None):
if envs.VLLM_BATCH_INVARIANT and current_platform.is_cuda_alike():
return linear_batch_invariant(x, layer.weight, bias)
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
Fix 2: RMSNorm Custom Op Under torch.compile
File: tests/v1/distributed/test_eagle_dp.py
Added compilation_config to force RMSNorm’s forward_cuda path:
engine_args = AsyncEngineArgs(
model=target_model,
enforce_eager=False,
compilation_config={"custom_ops": ["none", "+rms_norm"]},
...
)
This is scoped to the test as a workaround. The broader issue of batch invariant overrides being bypassed under torch.compile is tracked in #39096.
Long-term Fix (Proposed)
When VLLM_BATCH_INVARIANT=1, either:
- Automatically add `+rms_norm` (and any other custom ops with batch invariant logic in `forward_cuda`) to `custom_ops` in `vllm/config/vllm.py`
- Ensure the `aten` `IMPL` overrides are respected by Inductor during lowering, so `forward_native` also produces batch-invariant results
8. Slack Thread Discussion Summary
Key exchanges from the vLLM Slack thread on this issue:
Q: “Are we running this test with VLLM_BATCH_INVARIANT=1? And you’re saying that torch.compile’d reductions are not guaranteed to be batch invariant?”
A: Yes, the test sets VLLM_BATCH_INVARIANT=1 (via monkeypatch.setenv). The issue is that vLLM’s batch invariance overrides PyTorch’s aten operators via torch.library.Library("aten", "IMPL"). When RMSNorm uses forward_native under torch.compile, Inductor lowers the ops to its own IR and generates its own Triton reduction kernels — these bypass the aten dispatch table entirely, so the aten::mean.dim override never gets called.
Q: “Can you check what our e2e batch invariant tests do with RMS norm? Do they use --enforce-eager?”
A: The e2e tests set enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, so on L4 (SM89) they run in eager mode → custom_ops=["all"] → forward_cuda is used. The EAGLE DP test uses enforce_eager=False unconditionally, which is the only test that exposes this gap on L4.
Decision: Scope the RMSNorm fix to the test via compilation_config, keep the UnquantizedEmbeddingMethod.apply fix as a code-level fix, and file a separate issue (#39096) for the broader torch.compile + batch invariance interaction.
9. Evidence Summary
| Observation | Implication |
|---|---|
| 20/20 passes on A100 with identical output | Batch invariance works on SM80 in eager mode |
| Divergence at varying token positions on L4 (32, 33, 80) | Not a systematic logic bug; position-dependent |
| Identical context through divergence point, different argmax | Same input → different output depending on batch size |
| EAGLE engine output is consistent across A100 and L4 | The EAGLE path itself is correct |
| Token swap after lm_head fix (Failure 2 vs 3) | lm_head was one gap, but upstream divergence remained |
| custom_ops=["none"] under torch.compile | RMSNorm’s forward_cuda (batch invariant) is bypassed |
| E2e batch invariant tests use enforce_eager=True on L4 | Tests never exercise the Inductor path on SM89 |
| CI passes consistently after both fixes | Both gaps were the root cause |
10. Files Modified
| File | Changes | Purpose |
|---|---|---|
| vllm/model_executor/layers/vocab_parallel_embedding.py | Added VLLM_BATCH_INVARIANT check to UnquantizedEmbeddingMethod.apply | Fix 1: lm_head batch invariance |
| tests/v1/distributed/test_eagle_dp.py | Added compilation_config={"custom_ops": ["none", "+rms_norm"]}, divergence diagnostics | Fix 2: force RMSNorm forward_cuda path |
| vllm/v1/core/sched/async_scheduler.py | Added logger.info at 2 points | Debugging instrumentation (to be removed) |
| vllm/v1/core/sched/scheduler.py | Added logger.info at 2 points | Debugging instrumentation (to be removed) |
| vllm/v1/worker/gpu/model_runner.py | Added print at 4 points | Debugging instrumentation (to be removed) |
11. Appendix: Investigation Command Reference
This section documents the grep and sed patterns used throughout this investigation. These are general-purpose techniques for tracing execution paths in large codebases.
11.1 Finding Where Things Are Defined
Find class or function definitions:
# Find class definitions across the codebase
grep -rn "class LlamaForCausalLM\|class LlamaModel\|class LlamaMLP" vllm/model_executor/models/llama.py
# Find all files that define a specific class
grep -rn "class UnquantizedLinearMethod" vllm/model_executor/layers/ --include="*.py"
# Find function definitions
grep -rn "def compute_logits" vllm/model_executor/layers/logits_processor.py
Flags explained:
- `-r`: Recursive search through directories
- `-n`: Show line numbers (critical for follow-up with `sed`)
- `-l`: Show only filenames (useful for broad searches)
- `--include="*.py"`: Restrict to Python files only
- `\|`: OR operator in grep (escape the pipe)
11.2 Tracing a Feature Across the Codebase
Find every file that references a feature flag:
# Find all files referencing BATCH_INVARIANT
grep -rn "BATCH_INVARIANT" vllm/model_executor/layers/ --include="*.py" -l
# Find specific usage patterns
grep -rn "BATCH_INVARIANT" vllm/model_executor/layers/ --include="*.py"
Find callers of a specific function:
# Who calls dispatch_unquantized_gemm?
grep -rn "dispatch_unquantized_gemm" vllm/ --include="*.py"
# Who calls init_batch_invariance?
grep -rn "init_batch_invariance" vllm/ --include="*.py"
11.3 Reading Specific Code Sections
Use sed to view exact line ranges (after grep -n gives you line numbers):
# View lines 81-170 of a file
sed -n '81,170p' vllm/model_executor/models/llama.py
# View the UnquantizedLinearMethod class (after grep told us it starts at line 182)
sed -n '182,230p' vllm/model_executor/layers/linear.py
# View RMSNorm forward_cuda method
sed -n '262,323p' vllm/model_executor/layers/layernorm.py
The grep → sed workflow:
# Step 1: Find where something is defined
grep -n "def forward_cuda" vllm/model_executor/layers/layernorm.py
# Output: 262: def forward_cuda(
# Step 2: Read the full method
sed -n '262,323p' vllm/model_executor/layers/layernorm.py
11.4 Filtering Search Results
Combine grep with grep to filter:
# Find custom_ops mentions that relate to defaults
grep -rn '"none"\|"all"' vllm/config/compilation.py | grep -i "custom_ops\|append\|default"
# Find enforce_eager OR BATCH_INVARIANT across test files
grep -rn "enforce_eager\|BATCH_INVARIANT" tests/ --include="*.py" | grep -i "batch_invariant\|enforce_eager"
11.5 Finding Files in Directory Trees
Use find for complex searches:
# Find all Python files that define a specific class
find vllm/ -name "*.py" -exec grep -ln "class CompilationConfig" {} \;
# Find all files with "forward" defined in a directory tree
find vllm/model_executor/layers/rotary_embedding/ -name "*.py" -exec grep -ln "def forward" {} \;
11.6 Practical Patterns Used in This Investigation
Pattern 1: Trace a forward pass chain
# Start from the top-level model
grep -rn "class LlamaForCausalLM" vllm/model_executor/models/llama.py
# Find compute_logits
grep -rn "def compute_logits" vllm/model_executor/models/llama.py
# Read the method
sed -n '580,610p' vllm/model_executor/models/llama.py
# It calls self.logits_processor → find that
grep -rn "class LogitsProcessor" vllm/model_executor/layers/logits_processor.py
# Read its forward
sed -n '54,100p' vllm/model_executor/layers/logits_processor.py
# It calls lm_head.quant_method.apply → find the apply method
grep -rn "class UnquantizedEmbeddingMethod" vllm/model_executor/layers/vocab_parallel_embedding.py
sed -n '31,75p' vllm/model_executor/layers/vocab_parallel_embedding.py
Pattern 2: Compare two similar code paths for differences
# Side-by-side comparison: does LinearMethod have a check that EmbeddingMethod lacks?
grep -A5 "def apply" vllm/model_executor/layers/linear.py | head -10
grep -A5 "def apply" vllm/model_executor/layers/vocab_parallel_embedding.py | head -10
Pattern 3: Trace a config value from CLI to runtime
# Where does custom_ops get populated?
grep -rn 'custom_ops.*append' vllm/config/ --include="*.py"
# Where does "none" get added?
grep -rn 'custom_ops.*none' vllm/config/vllm.py
# Where is it consumed?
grep -rn "custom_ops" vllm/model_executor/custom_op.py
Pattern 4: Verify a fix hasn’t been applied elsewhere
# Are there other callers that might have the same bug?
grep -rn "dispatch_unquantized_gemm" vllm/ --include="*.py"
# Check each caller for the missing BATCH_INVARIANT check
11.7 Tips
- Always start with `grep -rn` to get line numbers, then use `sed -n` to read context. This is faster than opening files in an editor.
- Use `-l` (files only) first for broad searches, then narrow down with specific patterns.
- Pipe grep into grep to filter: the first grep finds all matches, the second filters to the relevant ones.
- Use `| head -N` to avoid being overwhelmed by large result sets. Start with `head -10`, increase if needed.
- `--include="*.py"` is essential when searching in repos with compiled objects, node_modules, or binary files.
- The `\|` OR operator in grep is powerful for finding related concepts together (e.g., `"class Foo\|def bar\|SOME_FLAG"`).
- When tracing a bug: start from the symptoms (the divergent output), trace backward to the operation that produces it, then compare that operation’s code path between the two cases (batch_size=1 vs batch_size=4).
12. Phase 3: Empirical Verification (Luka’s Request)
After Phase 2 concluded with the RMSNorm hypothesis, Luka Govedič (RH) pushed back on Slack, asking for a repro showing that Inductor actually generates a kernel that is not batch invariant.
He also pointed out that PR #27660 (“Batch invariant torch.compile”) had already enabled and tested batch invariance under torch.compile on DeepSeek, so RMSNorm under torch.compile should work. This raised the possibility that our hypothesis was wrong.
12.1 Discovery of the IR Op Priority System
While searching for proof, we discovered a second dispatch system we had missed: IrOpPriorityConfig.
The engine config in CI logs (line 116229) revealed:
'custom_ops': ['none']
ir_op_priority=IrOpPriorityConfig(rms_norm=['native'])
This is separate from the CustomOp.dispatch_forward() system we analyzed in Phase 2. It controls which IR implementation (native, vllm_c, oink, aiter) is used inside vllm.ir.ops.rms_norm.
Where it’s set — vllm/platforms/cuda.py:get_default_ir_op_priority():
default = ["native"] if using_inductor else ["vllm_c", "native"]
When Inductor is the backend, the IR priority is ["native"] — meaning use the pure PyTorch reference implementation (x.pow(2).mean(dim=-1)) so Inductor can compile it. When not using Inductor, it falls back to the hand-written vllm_c CUDA kernel.
Both systems happen to push toward the same forward_native path under torch.compile, but they’re independent mechanisms. The fix in PR #38938 (+rms_norm) works by overriding CustomOp dispatch to bind forward_cuda, which has an explicit rms_norm_batch_invariant() call that bypasses both the IR op priority system and the aten dispatch table.
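The priority mechanism can be pictured as a first-available lookup. A simplified standalone mock (hypothetical structure, not the vLLM source):

```python
# Available IR implementations for rms_norm, keyed by backend name.
# Which ones exist depends on the build (vllm_c = compiled CUDA extension, etc.).
available = {
    "native": "pure-PyTorch reference (compilable by Inductor)",
    "vllm_c": "hand-written CUDA kernel",
}

def resolve(priority: list[str], available: dict[str, str]) -> str:
    # Pick the first implementation in the priority list that is available.
    for name in priority:
        if name in available:
            return name
    raise RuntimeError("no implementation available")

# Inductor backend: priority ["native"], so the compilable reference wins.
print(resolve(["native"], available))            # native
# Non-Inductor: the CUDA kernel is preferred, with native as fallback.
print(resolve(["vllm_c", "native"], available))  # vllm_c
```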
12.2 The Empirical Test in CI
To prove the hypothesis (or disprove it), we added an inline check at the top of test_run_eagle_dp that runs on the same L4 hardware before any vLLM state is set up:
import torch

# hidden_size matches the 4096-element diff counts reported below; the weight
# here is an arbitrary test vector (its exact values don't matter for the check).
hidden_size = 4096
weight = torch.randn(hidden_size, dtype=torch.bfloat16, device="cuda")

def rms_norm_native(x, weight, eps=1e-5):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x.to(orig_dtype)
    return x * weight

compiled = torch.compile(rms_norm_native, dynamic=False)

# Same row of data, presented as batch=1 and as the first row of batch=4
shared_row = torch.randn(hidden_size, dtype=torch.bfloat16, device="cuda")
x_b1 = shared_row.unsqueeze(0).clone()
x_b4 = torch.randn(4, hidden_size, dtype=torch.bfloat16, device="cuda")
x_b4[0] = shared_row

out_b1 = compiled(x_b1, weight)
out_b4 = compiled(x_b4, weight)
print(f"bitwise_equal: {torch.equal(out_b1[0], out_b4[0])}")
print(f"max_abs_diff: {(out_b1[0].to(torch.float32) - out_b4[0].to(torch.float32)).abs().max().item()}")
12.3 Surprising Result
CI output on L4 (NVIDIA L4, capability 8.9):
[inductor_bi_check] device: NVIDIA L4
[inductor_bi_check] capability: (8, 9)
[inductor_bi_check] bitwise_equal: True
[inductor_bi_check] max_abs_diff: 0.0
[inductor_bi_check] n_differing_elements: 0 / 4096
A second variant (v2) explicitly called enable_batch_invariant_mode() first to register the aten overrides. Same result:
[inductor_bi_check_v2] WITH enable_batch_invariant_mode()
[inductor_bi_check_v2] bitwise_equal: True
[inductor_bi_check_v2] max_abs_diff: 0.0
[inductor_bi_check_v2] n_differing_elements: 0 / 4096
Inductor’s compiled RMSNorm IS batch invariant on L4 — even for the simple isolated case, with or without enable_batch_invariant_mode() registered. This contradicts our Phase 2 hypothesis.
12.4 What This Means
Our Phase 2 explanation was wrong, or at least incomplete. The simple case shows that:
- Inductor’s `mean.dim` codegen is batch invariant on L4 (at least for this shape and dtype)
- The aten override registration doesn’t affect this — Inductor produces the same invariant kernel either way
- So the divergence in the actual EAGLE DP test is not caused by Inductor’s basic RMSNorm codegen being non-deterministic
12.5 Open Questions
The empirical result raises several possibilities for what’s actually causing the EAGLE DP test failures:
- The PR #38938 fix may have worked for a different reason than we documented. The token-level divergence we observed in CI may have been resolved by the fix, but our explanation of why may be incorrect.
- The issue may be in the fused-add RMSNorm path (`fused_add_rms_norm`), which is what’s actually used in Llama for all layers except the input layernorm. This is a different code path than the standalone RMSNorm we tested.
- The issue may be in op fusion — Inductor fuses RMSNorm with surrounding ops (residual add, next linear’s input prep). The fused kernel may behave differently than standalone RMSNorm.
- The issue may be in a different op entirely — RoPE, attention, or activation. RMSNorm being fine doesn’t rule out something else being broken.
- There may be an interaction with the batch invariance machinery that only manifests in the full graph context (with `enable_batch_invariant_mode()` having modified cuBLAS, NCCL, and reduction precision settings).
enable_batch_invariant_mode()having modified cuBLAS, NCCL, and reduction precision settings).
12.6 v3 (Failed): vLLM RMSNorm Module Under Compile
We tried to test the actual vllm.model_executor.layers.layernorm.RMSNorm module under torch.compile:
from vllm.model_executor.layers.layernorm import RMSNorm
real_norm = RMSNorm(hidden_size, eps=1e-5).cuda().to(torch.bfloat16)
compiled_norm = torch.compile(real_norm, dynamic=False)
This crashed with:
AssertionError: Current vLLM config is not set. This typically means
get_current_vllm_config() was called outside of a set_current_vllm_config()
context, or a CustomOp was instantiated at module import time or model
forward time when config is not set.
RMSNorm.__init__() (a CustomOp subclass) requires an active vLLM config context to instantiate. The fix is to use the default_vllm_config pytest fixture, but this means the test failure result we got was from this assertion, not from a token divergence — so we don’t yet know whether the rest of the test would have passed or failed at the actual greedy comparison.
12.7 Next Steps
- Test the fused-add RMSNorm variant — write a standalone `fused_add_rms_norm_native(x, residual, weight)` function and run the same batch=1 vs batch=4 invariance check.
- Test in graph context — instead of compiling RMSNorm in isolation, compile a small chunk that includes RMSNorm + a residual add + a linear layer, to let Inductor fuse them.
- Use the `default_vllm_config` fixture — to enable testing the actual `vllm.model_executor.layers.layernorm.RMSNorm` module under compile.
- Bisect by removing fixes — temporarily revert the lm_head fix and the RMSNorm fix one at a time to identify which one is actually doing the work, and compare CI outputs to see which divergences appear with which fix removed.
- Test other ops in the forward path — RoPE, SiluAndMul, attention output projection — using the same shape-invariance pattern.
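The first item above can be sketched directly. A hedged numpy sketch of `fused_add_rms_norm_native` (the name comes from the list above; float32 on CPU, whereas the real question concerns bfloat16 CUDA kernels under Inductor):

```python
import numpy as np

def fused_add_rms_norm_native(x, residual, weight, eps=1e-5):
    """Reference fused-add RMSNorm: add the residual first, then RMS-normalize.

    Returns (normalized output, updated residual), mirroring the two-output
    shape of vLLM's fused_add_rms_norm op.
    """
    h = x + residual  # fused residual add
    rms = np.sqrt(np.mean(h * h, axis=-1, keepdims=True) + eps)
    return (h / rms) * weight, h

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 64), dtype=np.float32)
res = rng.standard_normal((4, 64), dtype=np.float32)
w = rng.standard_normal(64, dtype=np.float32)

# Batch=1 vs batch=4 invariance check on the shared row 0.
out1, res1 = fused_add_rms_norm_native(x[:1], res[:1], w)
out4, res4 = fused_add_rms_norm_native(x, res, w)
assert np.array_equal(out1[0], out4[0])
assert np.array_equal(res1[0], res4[0])
```

In pure numpy this is trivially invariant because every reduction is per-row; the open question is whether the compiled GPU kernel keeps that property when the batch dimension changes.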
12.8 Updated Status
The fix in PR #38938 still passes CI consistently, so it’s still the right short-term fix even if our explanation needs refinement. The follow-up issue #39096 should be updated to reflect that the simple RMSNorm-under-Inductor case is actually batch invariant on L4, and the real source of the divergence is still under investigation.
13. Resolution
13.1 Pivot: Move the Test to H100
After confirming that `enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90` made the test pass consistently on L4, Benjamin Chislett pushed back on Slack: he would really not like to lose coverage of CUDA graphs with this feature, and suggested migrating the test to 2xH100 to get over this issue.
The concern: the enforce_eager workaround disables torch.compile AND CUDA graphs on L4. If we kept the test only on L4, we’d lose coverage of the compiled + graph-captured path entirely for this test. Since the SM<90 limitation is well-documented and unlikely to be fixed soon, the better approach is to run the test on hardware where batch invariance + torch.compile + CUDA graphs actually works — H100 (SM90+).
Luka agreed and asked to move the test.
13.2 Additional Experiment: Isolating torch.compile vs CUDA Graphs
Before moving the test, Luka also asked whether we could determine if the SM<90 issue is caused specifically by torch.compile or specifically by CUDA graphs — since `enforce_eager=True` disables both. The goal was to see whether `--cudagraph_mode=NONE` alone (keeping torch.compile active) would suffice instead.
We tested enforce_eager=False, compilation_config={"cudagraph_mode": "NONE"} on L4. The test still failed with divergence at token 80 (the same position as the original Phase 1 failure, 20400 != 4324).
Result: torch.compile alone is sufficient to break batch invariance on SM89, even without CUDA graphs. The comment in tests/v1/determinism/utils.py (“For devices with SM < 90, batch invariance does not support CUDA Graphs”) is incomplete — both torch.compile and CUDA graphs contribute to the issue on SM<90, and disabling just one isn’t enough. Also, disabling torch.compile while keeping CUDA graphs active isn’t a valid configuration in vLLM (graph capture happens during the compiled forward pass), so we can’t test the reverse direction.
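For the record, the two configurations compared in this experiment, written as plain engine kwargs (a sketch; the exact plumbing into the engine constructor depends on the vLLM version):

```python
# The enforce_eager workaround disables BOTH torch.compile and CUDA graphs.
baseline_workaround = {"enforce_eager": True}

# The isolation experiment: keep torch.compile on, disable only CUDA graphs.
isolation_experiment = {
    "enforce_eager": False,
    "compilation_config": {"cudagraph_mode": "NONE"},
}

# On L4 (SM89) the isolation experiment still diverged at token 80,
# implicating torch.compile itself, not just CUDA graph capture.
```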
13.3 Final Fix
Two code changes:
- `vllm/model_executor/layers/vocab_parallel_embedding.py` — Added a `VLLM_BATCH_INVARIANT` check to `UnquantizedEmbeddingMethod.apply`. This is a genuine bug in the lm_head projection path and applies to all GPUs, independent of the SM<90 issue. This was Phase 1's finding and it's a real fix.
- `.buildkite/test_areas/distributed.yaml` — Moved `test_eagle_dp` out of the L4 jobs (`Distributed DP Tests (2 GPUs)` and `Distributed DP Tests (4 GPUs)`) and added two new dedicated H100 jobs (`Distributed EAGLE DP Tests (2 GPUs) (H100)` and `Distributed EAGLE DP Tests (4 GPUs) (H100)`). The H100 jobs have focused `source_file_dependencies` covering the Llama model, batch invariance machinery, spec decode, and the test file itself.
Kept as defensive guard:
The enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90 change in tests/v1/distributed/test_eagle_dp.py stays. Even though the test now runs on H100 in CI, the workaround is kept so the test still works correctly if anyone runs it locally on an L4 machine, or if it ends up in another L4-based CI job (e.g., .buildkite/test_areas/model_runner_v2.yaml which also exercises this test on L4).
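The guard pattern behind the first change is worth recording. This is an illustrative sketch only: `vllm_is_batch_invariant` and `lm_head_apply` are hypothetical stand-ins (vLLM reads the flag through its `envs` module, and the real kernels are GPU GEMMs), but the shape of the fix is a flag-gated dispatch in the projection path:

```python
import os
import numpy as np

def vllm_is_batch_invariant() -> bool:
    # vLLM exposes this via its envs module; sketched here with os.environ.
    return os.environ.get("VLLM_BATCH_INVARIANT", "0") == "1"

def lm_head_apply(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the lm_head projection in
    UnquantizedEmbeddingMethod.apply (names and kernels are sketches)."""
    if vllm_is_batch_invariant():
        # Batch-invariant path: a fixed, row-wise computation so that row i
        # of the logits is identical whether x has 1 row or many.
        return np.stack([weight @ row for row in x])
    # Default path: a single GEMM, where a real GPU library may pick
    # different kernels/tilings per batch size (the source of divergence).
    return x @ weight.T

os.environ["VLLM_BATCH_INVARIANT"] = "1"
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8), dtype=np.float32)    # hidden states
w = rng.standard_normal((16, 8), dtype=np.float32)   # vocab x hidden
logits1 = lm_head_apply(x[:1], w)
logits4 = lm_head_apply(x, w)
assert np.array_equal(logits1[0], logits4[0])
```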
13.4 PR Bookkeeping
The PR had accumulated ~35 commits during the investigation (many were scratch debugging commits + merges from main). Before merging:
- Squashed all commits into a single clean commit using `git reset --soft origin/main` followed by one `git commit`
- Force-pushed to overwrite the noisy history
- Retriggered CI once with `git commit --allow-empty -m "Retrigger CI"` due to a transient safetensors download race on a fresh H100 node (not related to the fix)
13.5 Summary of the Full Journey
| Phase | Hypothesis | Outcome |
|---|---|---|
| 1 | Async scheduling bug | Ruled out — scheduler logic traced correctly, A100 runs were 100% deterministic |
| 1 | lm_head missing batch invariance | Confirmed — UnquantizedEmbeddingMethod.apply was missing the check. Fixed. |
| 2 | RMSNorm forward_native bypasses aten override under Inductor | Disproven empirically — inline CI repro showed torch.compile(rms_norm) IS batch invariant on L4 |
| 3 | CUDA graphs on SM<90 are not batch invariant (per utils.py:14) | Confirmed as contributing factor, but incomplete — torch.compile alone is also implicated |
| 3 | Move test to H100 where batch invariance + compile + graphs works | Accepted by maintainers and implemented |
13.6 What Generalizes
A few lessons that generalize beyond this specific bug:
- Empirical verification beats plausible explanation. The Phase 2 RMSNorm hypothesis was internally consistent and looked correct on paper. It took a direct measurement (the inline CI repro) to disprove it. Always test hypotheses with the cheapest possible experiment before committing to an explanation.
- Git history encodes institutional knowledge. The single most valuable step in Phase 3 was finding PR #30018 via `git blame` on `IS_DEVICE_CAPABILITY_BELOW_90`. The comment in `utils.py:14` had the answer we were looking for — we just hadn't looked there. When investigating a bug that feels like it should already be known, check whether someone else has encoded a workaround in similar tests.
- Don't conflate similar-sounding dispatch systems. `CustomOp.dispatch_forward()` and `IrOpPriorityConfig` look similar (both decide which implementation to use) but are separate mechanisms. Phase 2's investigation got tangled up in both because we found one and assumed it was the whole picture. Always map out all dispatch layers before drawing conclusions.
- CI logs are the primary evidence, not hypotheses. Every conclusion in this investigation ultimately came from a log line or a direct measurement. When the maintainer (Luka) pushed back and asked for empirical proof, that pressure forced a real repro, which disproved the hypothesis we'd committed to. Maintainers asking for proof is a feature, not an obstacle.
- Sometimes the right fix is moving the test, not fixing the code. The SM<90 + batch invariance + torch.compile interaction is a known limitation that isn't worth fixing given the available alternatives (run on newer hardware). Accepting that workaround over a deeper fix was the right tradeoff.
13.7 Related Links
- PR #38938 — The fix (merged)
- Issue #39096 — Open follow-up: batch invariance + torch.compile on SM<90
- Issue #31913 — Original flaky test issue (closed by #38938)
- PR #30018 — Where `IS_DEVICE_CAPABILITY_BELOW_90` was introduced
- PR #27660 — Earlier batch invariant torch.compile work (DeepSeek)