def deepseek_v4_flashinfer_sparse_mla_warmup(worker: "Worker") -> None:
    """Warm the DSV4 FlashInfer sparse-index builder variants.

    CUDA graph capture exercises mixed batches, but Triton can still see the
    first real prefill wave as a separate specialization for the per-layer C4A
    and C128A index shapes. Compile those tiny index-builder launches during
    engine warmup so they do not appear as inference-time bubbles.

    Two groups of launches are issued on ``worker.model_runner.device``:

    * ``_compute_prefill_metadata_kernel`` for ``num_prefills`` in the powers
      of two below ``max_num_seqs``, plus ``max_num_seqs`` itself;
    * ``build_flashinfer_mixed_sparse_indices`` once for each supported
      compression ratio (4 and 128) listed in ``hf_config.compress_ratios``.

    The function returns immediately, without importing the kernel modules,
    when the HF config has no ``compress_ratios`` or no positive
    ``sliding_window``. An attribute explicitly set to ``None`` is treated the
    same as a missing one.

    Args:
        worker: Worker whose ``vllm_config``, ``scheduler_config`` and
            ``model_runner.device`` describe the model being warmed.
    """
    hf_config = worker.vllm_config.model_config.hf_config

    # getattr's default only covers a *missing* attribute; a config may also
    # carry an explicit None, which iteration / int() would reject.
    raw_ratios = getattr(hf_config, "compress_ratios", None)
    if raw_ratios is None:
        raw_ratios = ()
    compress_ratios = {int(ratio) for ratio in raw_ratios}
    if not compress_ratios:
        return
    window_size = int(getattr(hf_config, "sliding_window", None) or 0)
    if window_size <= 0:
        return

    # Imported only after the applicability checks so configs that skip the
    # warmup never load these kernel modules.
    from vllm.v1.attention.backends.mla.sparse_swa import (
        _compute_prefill_metadata_kernel,
    )
    from vllm.v1.attention.ops.deepseek_v4_ops.cache_utils import (
        build_flashinfer_mixed_sparse_indices,
    )

    logger.info("Warming up DeepSeek V4 FlashInfer sparse MLA index kernels.")
    device = worker.model_runner.device
    index_topk = int(getattr(hf_config, "index_topk", None) or 0)
    max_model_len = worker.vllm_config.model_config.max_model_len
    max_num_seqs = max(1, worker.scheduler_config.max_num_seqs)

    # num_prefills values to warm: 1, 2, 4, ... (all < max_num_seqs), then
    # max_num_seqs itself -- which is therefore also the largest value.
    # NOTE(review): only these counts are launched; if the kernel specializes
    # on num_prefills beyond BLOCK_SIZE (e.g. Triton integer-argument
    # specialization), other counts may still compile lazily -- confirm.
    prefill_batch_sizes: list[int] = []
    size = 1
    while size < max_num_seqs:
        prefill_batch_sizes.append(size)
        size *= 2
    prefill_batch_sizes.append(max_num_seqs)
    max_prefill_reqs = max_num_seqs

    # ---- Phase 1: prefill metadata kernel --------------------------------
    # Synthetic inputs: every request has seq_len 1 and the source
    # query_start_loc is [0, 1, ..., max_prefill_reqs]. Buffers are sized for
    # the largest batch once and sliced per launch.
    prefill_seq_lens = torch.ones(
        (max_prefill_reqs,), device=device, dtype=torch.int32
    )
    prefill_src_query_start_loc = torch.arange(
        max_prefill_reqs + 1, device=device, dtype=torch.int32
    )
    prefill_query_start_loc = torch.empty(
        max_prefill_reqs + 1, device=device, dtype=torch.int32
    )
    prefill_gather_lens = torch.empty(
        max_prefill_reqs, device=device, dtype=torch.int32
    )
    for num_prefills in prefill_batch_sizes:
        _compute_prefill_metadata_kernel[(1,)](
            prefill_query_start_loc[: num_prefills + 1],
            prefill_gather_lens[:num_prefills],
            prefill_seq_lens[:num_prefills],
            prefill_src_query_start_loc[: num_prefills + 1],
            num_prefills,
            0,
            window_size,
            # Smallest power of two strictly greater than num_prefills;
            # presumably sized to cover the num_prefills + 1 query_start_loc
            # entries -- keep in sync with the runtime call site.
            BLOCK_SIZE=1 << num_prefills.bit_length(),
        )

    # ---- Phase 2: mixed decode/prefill sparse-index builder ---------------
    for compress_ratio in sorted(compress_ratios):
        if compress_ratio == 4:
            topk = index_topk
            decode_compressed_indices_are_local = True
            has_decode_compressed_lens = False
        elif compress_ratio == 128:
            # ceil(max_model_len / compress_ratio), rounded up to a multiple
            # of 128.
            topk = (max_model_len + compress_ratio - 1) // compress_ratio
            topk = ((topk + 127) // 128) * 128
            decode_compressed_indices_are_local = False
            has_decode_compressed_lens = True
        else:
            # Only the ratio-4 and ratio-128 index shapes are warmed.
            continue
        if topk <= 0:
            continue

        # Tiny two-request batch (query_start_loc [0, 1, 2], seq_lens [1, 2])
        # with one decode row and one prefill row of indices. Tensors are
        # rebuilt per variant in case the builder writes into its inputs.
        decode_swa_indices = torch.zeros(
            (1, window_size), device=device, dtype=torch.int32
        )
        decode_compressed_indices = torch.zeros(
            (1, topk), device=device, dtype=torch.int32
        )
        prefill_topk_indices = torch.zeros(
            (1, topk), device=device, dtype=torch.int32
        )
        mixed_query_start_loc = torch.tensor(
            [0, 1, 2], device=device, dtype=torch.int32
        )
        mixed_seq_lens = torch.tensor([1, 2], device=device, dtype=torch.int32)
        token_to_req_indices = torch.tensor(
            [0, 1], device=device, dtype=torch.int32
        )
        swa_block_table = torch.zeros((2, 1), device=device, dtype=torch.int32)
        compressed_block_table = torch.zeros(
            (2, 1), device=device, dtype=torch.int32
        )
        # Optional inputs are passed only for the variant that uses them, so
        # each variant compiles with the same None/tensor pattern as at runtime.
        decode_compressed_topk_lens = (
            torch.ones((1,), device=device, dtype=torch.int32)
            if has_decode_compressed_lens
            else None
        )
        decode_is_valid_token = (
            torch.ones((1,), device=device, dtype=torch.bool)
            if decode_compressed_indices_are_local
            else None
        )
        build_flashinfer_mixed_sparse_indices(
            decode_swa_indices,
            decode_compressed_indices,
            decode_compressed_topk_lens,
            prefill_topk_indices,
            mixed_query_start_loc,
            mixed_seq_lens,
            token_to_req_indices,
            swa_block_table,
            # NOTE(review): presumably the SWA block size paired with
            # swa_block_table, and below the compressed block size
            # (256 // ratio) -- confirm against the builder's signature.
            256,
            compressed_block_table,
            max(1, 256 // compress_ratio),
            window_size,
            compress_ratio,
            topk,
            decode_compressed_indices_are_local=decode_compressed_indices_are_local,
            decode_is_valid_token=decode_is_valid_token,
        )

    torch.accelerator.synchronize()