mstar.utils.flashinfer_utils
FlashInfer utility wrappers for batched paged attention.
Provides:
- run_rms_norm / run_attention: simple single-request helpers
- FlashInferPrefillWrapper: batched prefill with paged KV cache, optional CUDA graph mode
- FlashInferDecodeWrapper: batched decode with paged KV cache, optional CUDA graph mode
CUDA graph mode requires:
- Static buffer pointers passed at construction (qo_indptr_buf, paged_kv_indptr_buf, etc.)
- plan() updates values via .copy_() without reallocating
- The same wrapper object must be used during both capture and replay
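The "update in place, never reallocate" requirement can be illustrated without a GPU. The sketch below is only a standard-library analogy for the .copy_() pattern, not the wrapper's code: StaticIndexBuffer and its methods are hypothetical names, and array("i") stands in for a pre-allocated device tensor.

```python
from array import array


class StaticIndexBuffer:
    """Fixed-capacity integer buffer whose storage is allocated once.

    Hypothetical stand-in for a static buffer such as qo_indptr_buf.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.data = array("i", [0] * capacity)

    def copy_in(self, values):
        """Overwrite the first len(values) entries in place; return the count."""
        n = len(values)
        if n > self.capacity:
            raise ValueError(f"{n} values exceed static capacity {self.capacity}")
        # Equal-length slice assignment overwrites elements without resizing.
        self.data[:n] = array("i", values)
        return n

    def address(self):
        """Memory address of the underlying storage."""
        return self.data.buffer_info()[0]


buf = StaticIndexBuffer(capacity=8)
address_at_capture = buf.address()
buf.copy_in([0, 3, 5])       # first "plan"
buf.copy_in([0, 2, 4, 9])    # later "plan" with a different batch
address_at_replay = buf.address()
```

The capacity check mirrors why the maximum sizes must be known at construction: a batch that does not fit cannot be handled by growing the buffer, because the captured graph still refers to the old address.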
Adapted from VoxServe’s flashinfer_utils.py for our KV cache layout:
[num_layers, max_pages, 2, page_size, num_kv_heads, head_dim]
(VoxServe uses [n_pages, 2, page_size, n_heads, head_dim] without layer dim.)
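To make the layout concrete, here is a minimal sketch of how a flat element offset follows from the six-dimensional shape, assuming a contiguous row-major buffer. The helper names and the small sizes are hypothetical; only the axis order comes from the docstring above. In FlashInfer's combined-cache convention, index 0 on the size-2 axis holds keys and index 1 holds values.

```python
from math import prod


def row_major_strides(shape):
    """Element strides of a contiguous row-major array with this shape."""
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


def kv_offset(shape, layer, page, kv, slot, head, dim=0):
    """Flat element offset of cache[layer, page, kv, slot, head, dim]."""
    index = (layer, page, kv, slot, head, dim)
    for axis, (i, n) in enumerate(zip(index, shape)):
        if not 0 <= i < n:
            raise IndexError(f"index {i} out of range for axis {axis} (size {n})")
    return sum(i * s for i, s in zip(index, row_major_strides(shape)))


# Hypothetical sizes:
# [num_layers, max_pages, 2, page_size, num_kv_heads, head_dim]
shape = (2, 4, 2, 16, 8, 64)
strides = row_major_strides(shape)

# Elements in one layer's slice, which has the five-dimensional VoxServe shape.
per_layer_elems = prod(shape[1:])
```

Because the layer axis comes first, fixing a layer index yields one contiguous block of per_layer_elems elements with the layer-free shape.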
Classes
FlashInferDecodeWrapper | Batched decode attention with paged KV cache. |
FlashInferPrefillWrapper | Batched prefill attention with paged KV cache. |
- class mstar.utils.flashinfer_utils.FlashInferDecodeWrapper(workspace_buffer, num_qo_heads, num_kv_heads, head_dim, page_size, batch_size=None, max_num_pages=None, device=torch.device('cuda'), use_cuda_graph=False, enable_nvtx=False)
Bases: object
Batched decode attention with paged KV cache.
Optimized for the common decode case where each request appends exactly 1 new token. Uses BatchDecodeWithPagedKVCacheWrapper.
- Parameters:
workspace_buffer (Tensor) – FlashInfer workspace
num_qo_heads (int) – number of query/output heads
num_kv_heads (int) – number of key/value heads
head_dim (int) – dimension per head
page_size (int) – KV cache page size
batch_size (int | None) – required for CUDA graph mode (max requests in batch)
max_num_pages (int | None) – required for CUDA graph mode (max pages across all requests)
device (device) – torch device
use_cuda_graph (bool) – if True, pre-allocate static buffers for graph capture
enable_nvtx (bool)
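The batch_size and max_num_pages limits are easier to reason about next to the page-table arrays that plan() takes. The sketch below builds those arrays from per-request sequence lengths in plain Python. build_page_table and free_pages are hypothetical names, and the exact way the wrapper enforces its limits is not shown in this reference, so treat the comparison at the end as an assumption.

```python
def build_page_table(seq_lens, page_size, free_pages):
    """Return (paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len).

    seq_lens:   KV length of each request (must be >= 1)
    free_pages: page ids available for allocation, consumed in order
    """
    indptr = [0]
    indices = []
    last_page_len = []
    cursor = 0
    for n in seq_lens:
        if n < 1:
            raise ValueError("each request needs at least one token")
        num_pages = -(-n // page_size)  # ceiling division
        if cursor + num_pages > len(free_pages):
            raise ValueError("not enough free pages")
        indices.extend(free_pages[cursor:cursor + num_pages])
        cursor += num_pages
        indptr.append(len(indices))
        # Tokens in the final page; always in 1..page_size.
        last_page_len.append(n - (num_pages - 1) * page_size)
    return indptr, indices, last_page_len


indptr, indices, last_page_len = build_page_table(
    seq_lens=[5, 16, 17], page_size=16, free_pages=[7, 3, 9, 1, 4]
)

# Assumed sizing rule for CUDA graph mode: the static buffers must be able
# to hold the largest batch and the largest page list that will be planned.
needed_batch_size = len(indptr) - 1
needed_num_pages = len(indices)
```

Here the three requests need 1, 1, and 2 pages, so a wrapper built for this workload would need batch_size of at least 3 and max_num_pages of at least 4.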
- plan(paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len, kv_cache_locations=None, dtype=torch.bfloat16)
Plan decode attention and compute KV write locations.
For decode, each request appends exactly 1 token. The write location is in the request's last page, at the zero-based slot equal to that page's token count before the append, which is one less than last_page_len after the append.
Inputs may be on CPU; see prefill wrapper’s plan docstring.
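A minimal sketch of the write-location rule, assuming the paged_kv_last_page_len passed in already counts the token being appended (the text above does not pin this down, so adjust the slot if your lengths are pre-append). decode_write_locations is a hypothetical helper, and the flat location page * page_size + slot is one possible encoding, not necessarily what kv_cache_locations holds.

```python
def decode_write_locations(paged_kv_indptr, paged_kv_indices,
                           paged_kv_last_page_len, page_size):
    """Flat (page * page_size + slot) write location for each request."""
    locations = []
    for i, last_len in enumerate(paged_kv_last_page_len):
        if not 1 <= last_len <= page_size:
            raise ValueError(f"last_page_len {last_len} outside 1..{page_size}")
        # The last page of request i is the final entry of its index range.
        last_page = paged_kv_indices[paged_kv_indptr[i + 1] - 1]
        # Assumption: last_len includes the new token, so it lands at last_len - 1.
        slot = last_len - 1
        locations.append(last_page * page_size + slot)
    return locations


locations = decode_write_locations(
    paged_kv_indptr=[0, 1, 2, 4],
    paged_kv_indices=[7, 3, 9, 1],
    paged_kv_last_page_len=[5, 16, 1],
    page_size=16,
)
```

The third request shows the page-boundary case: its new token is the first entry of a freshly allocated page (page 1, slot 0).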
- class mstar.utils.flashinfer_utils.FlashInferPrefillWrapper(workspace_buffer, num_qo_heads, num_kv_heads, head_dim, page_size, batch_size=None, max_total_tokens=None, max_num_pages=None, device=torch.device('cuda'), use_cuda_graph=False, enable_nvtx=False)
Bases: object
Batched prefill attention with paged KV cache.
Wraps flashinfer.BatchPrefillWithPagedKVCacheWrapper with:
- Pre-computed token_to_page / token_to_cache for vectorized KV writes
- Optional CUDA graph mode with static buffers
- Parameters:
workspace_buffer (Tensor) – FlashInfer workspace (256MB+ recommended)
num_qo_heads (int) – number of query/output heads
num_kv_heads (int) – number of key/value heads
head_dim (int) – dimension per head
page_size (int) – KV cache page size
batch_size (int | None) – required for CUDA graph mode (max requests in batch)
max_total_tokens (int | None) – required for CUDA graph mode (max total new tokens across batch)
max_num_pages (int | None) – required for CUDA graph mode (max pages across all requests)
device (device) – torch device
use_cuda_graph (bool) – if True, pre-allocate static buffers for graph capture
enable_nvtx (bool)
- plan(qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len, causal=True, dtype=torch.bfloat16)
Plan attention and compute KV write indices.
In CUDA graph mode, updates static buffers via .copy_() so that the same GPU addresses are used during graph replay.
Inputs may be on CPU, and that is preferred: FlashInfer’s BatchPrefillWithPagedKVCacheWrapper.plan calls indptr.to("cpu") / last_page_len.to("cpu") internally, so passing GPU tensors triggers a blocking default-stream synchronization that drains the speculatively queued next decode step. The inner plan consumes the inputs as CPU tensors, and we copy them to the device asynchronously (H2D) for our own per-token bookkeeping.
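The per-token bookkeeping can be sketched in plain Python. This is an illustration under assumptions, not the wrapper's implementation: it assumes the page table already covers the new tokens, the function name is hypothetical, and token_to_slot is a hypothetical stand-in because this reference does not spell out what token_to_cache contains.

```python
def prefill_token_maps(qo_indptr, paged_kv_indptr, paged_kv_indices,
                       paged_kv_last_page_len, page_size):
    """For every new token in the batch, return its page id and in-page slot."""
    token_to_page = []
    token_to_slot = []
    for i in range(len(qo_indptr) - 1):
        num_new = qo_indptr[i + 1] - qo_indptr[i]
        start, end = paged_kv_indptr[i], paged_kv_indptr[i + 1]
        # Total KV length after this step: full pages plus the last page.
        kv_len = (end - start - 1) * page_size + paged_kv_last_page_len[i]
        if num_new > kv_len:
            raise ValueError("more new tokens than KV slots for request %d" % i)
        # New tokens occupy the final num_new positions of the sequence.
        for pos in range(kv_len - num_new, kv_len):
            token_to_page.append(paged_kv_indices[start + pos // page_size])
            token_to_slot.append(pos % page_size)
    return token_to_page, token_to_slot


# Request 0: 3 cached + 3 new tokens on pages [5, 2].
# Request 1: 2 new tokens on page [8].
token_to_page, token_to_slot = prefill_token_maps(
    qo_indptr=[0, 3, 5],
    paged_kv_indptr=[0, 2, 3],
    paged_kv_indices=[5, 2, 8],
    paged_kv_last_page_len=[2, 2],
    page_size=4,
)
```

With both maps computed once per plan() call, the KV write for a whole batch can be a single indexed assignment instead of a per-request loop, which is the point of the "vectorized KV writes" mentioned in the class description.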