vllm.v1.attention.ops.deepseek_v4_ops
Modules:
| Name | Description |
|---|---|
| cache_utils | Triton kernels for DeepseekV4 paged K-cache management and sparse-attention index |
| fused_compress_quant_cache | Fused compressor + FP8/MXFP4 UE8M0 quantization + KV cache insert kernels. |
| fused_indexer_q | |
| fused_indexer_q_cutedsl | |
| fused_inv_rope_fp8_quant | Fused inverse RoPE + block-scaled FP8 quantization kernel for DeepseekV4 attention. |
build_flashinfer_mixed_sparse_indices
build_flashinfer_mixed_sparse_indices(
decode_swa_indices: Tensor,
decode_compressed_indices: Tensor | None,
decode_compressed_topk_lens: Tensor | None,
prefill_topk_indices: Tensor,
query_start_loc: Tensor,
seq_lens: Tensor,
token_to_req_indices: Tensor,
swa_block_table: Tensor,
swa_block_size: int,
compressed_block_table: Tensor | None,
compressed_block_size: int,
window_size: int,
compress_ratio: int,
topk: int,
decode_compressed_indices_are_local: bool = False,
decode_is_valid_token: Tensor | None = None,
) -> tuple[Tensor, Tensor]
Build FlashInfer DSV4 sparse indices for decode-first mixed batches.
Source code in vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py
compute_global_topk_indices_and_lens
compute_global_topk_indices_and_lens(
topk_indices: Tensor,
token_to_req_indices: Tensor,
block_table: Tensor,
block_size: int,
is_valid_token: Tensor,
) -> tuple[Tensor, Tensor]
Map local topk indices to global KV cache slots and count valid entries.
Fuses three operations into a single kernel:
1. Block-table lookup (local index → global slot id)
2. Valid-entry counting (topk_lens per token)
3. Masking padding tokens to length 0
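The three fused steps can be mirrored by a slow pure-Python reference, which is handy for checking semantics on toy inputs. This is a sketch, not the Triton kernel, and it rests on assumptions not stated above: the standard paged-cache mapping `global_slot = block_table[req][local // block_size] * block_size + local % block_size`, `-1` marking an unused top-k entry, and padding tokens getting all `-1` indices in addition to length 0. The name `compute_global_topk_indices_and_lens_ref` is hypothetical, and nested lists stand in for tensors.

```python
from typing import List, Sequence, Tuple


def compute_global_topk_indices_and_lens_ref(
    topk_indices: Sequence[Sequence[int]],
    token_to_req_indices: Sequence[int],
    block_table: Sequence[Sequence[int]],
    block_size: int,
    is_valid_token: Sequence[bool],
) -> Tuple[List[List[int]], List[int]]:
    """Hypothetical reference: local top-k indices -> global slots + per-token lengths."""
    global_indices: List[List[int]] = []
    topk_lens: List[int] = []
    for t, row in enumerate(topk_indices):
        if not is_valid_token[t]:
            # Step 3: padding tokens are masked to length 0
            # (assumption: their indices are also filled with -1).
            global_indices.append([-1] * len(row))
            topk_lens.append(0)
            continue
        req = token_to_req_indices[t]
        out_row: List[int] = []
        for local in row:
            if local < 0:
                # Assumption: -1 marks an unused top-k entry.
                out_row.append(-1)
                continue
            # Step 1: block-table lookup (local index -> global slot id).
            physical_block = block_table[req][local // block_size]
            out_row.append(physical_block * block_size + local % block_size)
        # Step 2: count the valid (non-negative) entries for this token.
        topk_lens.append(sum(1 for idx in out_row if idx >= 0))
        global_indices.append(out_row)
    return global_indices, topk_lens


if __name__ == "__main__":
    idx, lens = compute_global_topk_indices_and_lens_ref(
        topk_indices=[[0, 5, -1], [3, 4, -1], [1, 2, 3]],
        token_to_req_indices=[0, 1, 1],
        block_table=[[5, 2], [7, 1]],
        block_size=4,
        is_valid_token=[True, True, False],
    )
    print(idx, lens)  # [[20, 9, -1], [31, 4, -1], [-1, -1, -1]] [2, 2, 0]
```

Fusing these steps means the slot ids and lengths come out of one pass over `topk_indices`, instead of three kernels each re-reading the index tensor.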
Source code in vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py
fused_indexer_q_rope_quant
fused_indexer_q_rope_quant(
positions: Tensor,
index_q: Tensor,
index_q_cos_sin_cache: Tensor,
index_weights: Tensor,
index_weights_softmax_scale: float,
index_weights_head_scale: float,
use_fp4: bool = False,
) -> tuple[Tensor | tuple[Tensor, Tensor], Tensor]
Fused RoPE + quantize Q for the sparse indexer.
Weight-fold semantics (important: the two paths differ):

FP8 path (use_fp4=False, default):
- q_fp8: (T, H, HEAD_DIM) float8_e4m3fn, with a per-token-per-head scalar scale that is NOT stored; it is folded into the weights below.
- weights_out = weights * q_scale * softmax_scale * head_scale
- Rationale: the q_scale for each (token, head) pair is a single scalar that the downstream FP8 logits kernel would otherwise multiply in. Folding it into weights avoids emitting a separate tensor and costs the logits kernel nothing.

MXFP4 path (use_fp4=True):
- q_packed: (T, H, HEAD_DIM // 2) uint8 (two E2M1 nibbles per byte)
- q_scale: (T, H, HEAD_DIM // MXFP4_BLOCK_SIZE) uint8 UE8M0 bytes
- weights_out = weights * softmax_scale * head_scale
- Rationale: MXFP4 has PER-BLOCK (32-element) scales that live with the Q values. They cannot be folded into a per-token weight scalar, so weights carries only the softmax and head scales.
Returns (q_quant, weights_out) where q_quant is either a Tensor (FP8) or a (values, scales) tuple (MXFP4). This matches the union type accepted by SparseAttnIndexer.forward_*.
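Why the FP8 fold is free can be checked with a toy model: if each head's contribution to the indexer logit is linear in its q vector, then multiplying the dequantized Q by `q_scale` inside the dot product equals multiplying `q_scale` into that head's weight. The sketch below assumes a simplified logit `sum_h w_h * <q_h, k>`, models only absmax scaling into the E4M3 range (no actual FP8 rounding), and uses hypothetical helper names (`scale_to_fp8_range`, `fold_weight`, `logit`); it is not the kernel's code.

```python
import math
from typing import List, Optional, Sequence, Tuple

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn magnitude


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def scale_to_fp8_range(q: Sequence[float]) -> Tuple[List[float], float]:
    """Absmax-scale one (token, head) vector; FP8 rounding itself is omitted."""
    amax = max(abs(x) for x in q)
    q_scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
    return [x / q_scale for x in q], q_scale


def fold_weight(weight: float, softmax_scale: float, head_scale: float,
                q_scale: Optional[float] = None) -> float:
    """FP8 path passes q_scale; the MXFP4 path leaves it None (block scales stay with Q)."""
    w = weight * softmax_scale * head_scale
    return w * q_scale if q_scale is not None else w


def logit(weights: Sequence[float], q_heads: Sequence[Sequence[float]],
          k: Sequence[float]) -> float:
    """Toy indexer logit for one (query token, key) pair: sum_h w_h * <q_h, k>."""
    return sum(w * dot(q, k) for w, q in zip(weights, q_heads))


if __name__ == "__main__":
    q_heads = [[0.5, -1.25, 2.0], [3.0, 0.25, -0.75]]
    raw_weights = [0.8, -0.3]
    k = [1.0, 2.0, -0.5]
    softmax_scale, head_scale = 0.125, 0.5

    # Reference: unscaled Q; weights carry only the softmax and head scales.
    ref_w = [fold_weight(w, softmax_scale, head_scale) for w in raw_weights]
    ref = logit(ref_w, q_heads, k)

    # FP8-style path: Q is rescaled and each head's q_scale moves into its weight.
    scaled = [scale_to_fp8_range(q) for q in q_heads]
    folded_w = [fold_weight(w, softmax_scale, head_scale, s)
                for w, (_, s) in zip(raw_weights, scaled)]
    fp8_style = logit(folded_w, [q for q, _ in scaled], k)
    print(math.isclose(ref, fp8_style, rel_tol=1e-9))  # True
```

The same argument fails for MXFP4: a head's 32-element blocks each carry their own scale, so there is no single scalar per (token, head) to move into the weight.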
Source code in vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py
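The MXFP4 output layout (two E2M1 nibbles per byte, one UE8M0 exponent byte per 32-element block) can be illustrated with a scalar sketch. Several details here are assumptions rather than facts about this kernel: the shared exponent follows the OCP Microscaling (MX) recipe `floor(log2(amax)) - 2` (the real kernel may pick its exponent differently, e.g. rounding up to avoid saturation), the low nibble holds the even-indexed element, rounding is to the nearest E2M1 value with ties toward the smaller magnitude, and `HEAD_DIM = 128` is only a hypothetical value for the shape check. Function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

MXFP4_BLOCK_SIZE = 32
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # codes 0..7
E2M1_EMAX = 2       # 6.0 = 1.5 * 2**2
UE8M0_BIAS = 127    # UE8M0 byte b encodes the scale 2**(b - 127)


def encode_e2m1(x: float) -> int:
    """Nearest E2M1 code (ties toward smaller magnitude); bit 3 is the sign."""
    mag = min(abs(x), 6.0)  # saturate at the largest E2M1 magnitude
    code = min(range(8), key=lambda c: abs(E2M1_MAGNITUDES[c] - mag))
    return code | (0x8 if x < 0 and code != 0 else 0)


def quantize_block(block: Sequence[float]) -> Tuple[List[int], int]:
    """One 32-element block -> 16 packed bytes + 1 UE8M0 scale byte."""
    assert len(block) == MXFP4_BLOCK_SIZE
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        shared_exp = -UE8M0_BIAS
    else:
        _, e = math.frexp(amax)             # amax = m * 2**e with 0.5 <= m < 1
        shared_exp = (e - 1) - E2M1_EMAX    # floor(log2(amax)) - emax
        shared_exp = max(-UE8M0_BIAS, min(UE8M0_BIAS, shared_exp))
    scale = 2.0 ** shared_exp
    codes = [encode_e2m1(x / scale) for x in block]
    # Assumption: element 2i goes in the low nibble, element 2i+1 in the high nibble.
    packed = [codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2)]
    return packed, shared_exp + UE8M0_BIAS


def dequantize_block(packed: Sequence[int], scale_byte: int) -> List[float]:
    scale = 2.0 ** (scale_byte - UE8M0_BIAS)
    out: List[float] = []
    for byte in packed:
        for code in (byte & 0xF, byte >> 4):
            mag = E2M1_MAGNITUDES[code & 0x7] * scale
            out.append(-mag if code & 0x8 else mag)
    return out


if __name__ == "__main__":
    head_dim = 128  # hypothetical HEAD_DIM, for the shape check only
    q_head = [((i % 7) - 3) * 0.4 for i in range(head_dim)]
    packed, scales = [], []
    for start in range(0, head_dim, MXFP4_BLOCK_SIZE):
        p, s = quantize_block(q_head[start:start + MXFP4_BLOCK_SIZE])
        packed.extend(p)
        scales.append(s)
    print(len(packed), len(scales))  # 64 4, i.e. HEAD_DIM // 2 and HEAD_DIM // 32
```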
quantize_and_insert_k_cache
quantize_and_insert_k_cache(
k: Tensor,
k_cache: Tensor,
slot_mapping: Tensor,
block_size: int = 64,
is_ue8m0: bool = True,
)
Quantize K tensor and insert into paged K cache.
K cache block layout (block_size=64 tokens):
- First 64 * 576 = 36864 bytes: token data
  - Each token: 448 bytes (fp8) + 128 bytes (bf16)
- Next 64 * 8 = 512 bytes: scales
  - Each token: 8 bytes (uint8 scales, 7 real + 1 padding)
- The block is padded to a multiple of 576 bytes
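The layout arithmetic can be sketched as byte-offset helpers over a flat uint8 view of the cache. Beyond the numbers listed above, this assumes that blocks are stored back to back with a stride of the padded block size, that a slot maps to `(slot // block_size, slot % block_size)` as in vLLM's usual slot mapping, that each token's fp8 bytes precede its bf16 bytes, and that token `t`'s scales sit at `8 * t` within the scale region. The helper names are hypothetical.

```python
from typing import Tuple

FP8_BYTES_PER_TOKEN = 448
BF16_BYTES_PER_TOKEN = 128
TOKEN_DATA_BYTES = FP8_BYTES_PER_TOKEN + BF16_BYTES_PER_TOKEN  # 576
SCALE_BYTES_PER_TOKEN = 8  # 7 real uint8 scales + 1 padding byte


def block_bytes(block_size: int = 64) -> int:
    """Token data + scales, rounded up to a multiple of 576 bytes."""
    unpadded = block_size * (TOKEN_DATA_BYTES + SCALE_BYTES_PER_TOKEN)
    return -(-unpadded // TOKEN_DATA_BYTES) * TOKEN_DATA_BYTES  # ceil to 576


def slot_byte_offsets(slot: int, block_size: int = 64) -> Tuple[int, int, int]:
    """(fp8 data, bf16 data, scales) byte offsets of one slot; layout order assumed."""
    block_idx, token = divmod(slot, block_size)
    base = block_idx * block_bytes(block_size)
    fp8_off = base + token * TOKEN_DATA_BYTES
    bf16_off = fp8_off + FP8_BYTES_PER_TOKEN  # assumption: fp8 part comes first
    scale_off = base + block_size * TOKEN_DATA_BYTES + token * SCALE_BYTES_PER_TOKEN
    return fp8_off, bf16_off, scale_off


if __name__ == "__main__":
    print(block_bytes())            # 37440 = 65 * 576 (36864 + 512 = 37376, padded)
    print(slot_byte_offsets(1))     # (576, 1024, 36872)
    print(slot_byte_offsets(64))    # (37440, 37888, 74304): first token of block 1
```

The padding keeps every block a whole number of 576-byte token rows, so a block can be viewed as a `(65, 576)` byte array without a ragged tail.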