vllm.v1.attention.backends.mamba1_attn

Mamba1AttentionBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/mamba1_attn.py
class Mamba1AttentionBackend(AttentionBackend):

    @staticmethod
    def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
        return Mamba1AttentionMetadataBuilder

get_builder_cls staticmethod

get_builder_cls() -> type[Mamba1AttentionMetadataBuilder]
Source code in vllm/v1/attention/backends/mamba1_attn.py
@staticmethod
def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
    return Mamba1AttentionMetadataBuilder
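
A minimal sketch of how a caller might obtain and construct the builder through this hook; kv_cache_spec, vllm_config, device, and layer_names are assumed to come from the engine's initialization and are not defined here:

builder_cls = Mamba1AttentionBackend.get_builder_cls()
# Instantiate the metadata builder for this backend (arguments assumed available).
builder = builder_cls(kv_cache_spec, vllm_config, device, layer_names)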

Mamba1AttentionMetadata dataclass

Source code in vllm/v1/attention/backends/mamba1_attn.py
@dataclass
class Mamba1AttentionMetadata:
    query_start_loc: torch.Tensor
    context_lens_tensor: torch.Tensor
    state_indices_tensor: torch.Tensor
    has_initial_states: torch.Tensor

context_lens_tensor instance-attribute

context_lens_tensor: Tensor

has_initial_states instance-attribute

has_initial_states: Tensor

query_start_loc instance-attribute

query_start_loc: Tensor

state_indices_tensor instance-attribute

state_indices_tensor: Tensor

__init__

__init__(
    query_start_loc: Tensor,
    context_lens_tensor: Tensor,
    state_indices_tensor: Tensor,
    has_initial_states: Tensor,
) -> None
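
A hypothetical instantiation for a two-request batch, with all tensor values invented for illustration: request 0 is a fresh 4-token prefill, request 1 is a single-token decode continuing from 16 already-computed tokens.

import torch

metadata = Mamba1AttentionMetadata(
    query_start_loc=torch.tensor([0, 4, 5]),         # cumulative query offsets (length = num_requests + 1)
    context_lens_tensor=torch.tensor([0, 16]),       # previously computed tokens per request
    state_indices_tensor=torch.tensor([7, 2]),       # per-request Mamba state slot (invented indices)
    has_initial_states=torch.tensor([False, True]),  # context_lens_tensor > 0
)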

Mamba1AttentionMetadataBuilder

Bases: AttentionMetadataBuilder[Mamba1AttentionMetadata]

Source code in vllm/v1/attention/backends/mamba1_attn.py
class Mamba1AttentionMetadataBuilder(
        AttentionMetadataBuilder[Mamba1AttentionMetadata]):

    reorder_batch_threshold: ClassVar[int] = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        vllm_config: VllmConfig,
        device: torch.device,
        layer_names: list[str],
    ):
        assert isinstance(kv_cache_spec, MambaSpec)
        self.kv_cache_spec = kv_cache_spec
        self.device = device
        self.vllm_config = vllm_config
        self.layer_names = layer_names

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> Mamba1AttentionMetadata:
        query_start_loc = common_attn_metadata.query_start_loc

        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
        context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
            query_start_loc.device)
        has_initial_states = (context_lens_tensor > 0)

        return Mamba1AttentionMetadata(
            query_start_loc=query_start_loc,
            context_lens_tensor=context_lens_tensor,
            has_initial_states=has_initial_states,
            state_indices_tensor=state_indices_tensor,
        )
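
To make the derivation in build concrete, a small worked example with invented inputs (shapes and values hypothetical, matching the two-request batch above): the state index is the first column of the block table, the context lengths come from num_computed_tokens_cpu, and has_initial_states marks requests with prior context.

import torch

# Invented inputs for two requests:
block_table_tensor = torch.tensor([[7, 0, 0],
                                   [2, 0, 0]])
num_computed_tokens_cpu = torch.tensor([0, 16])

state_indices_tensor = block_table_tensor[:, 0]  # tensor([7, 2])
context_lens_tensor = num_computed_tokens_cpu    # moved to query_start_loc.device in build()
has_initial_states = context_lens_tensor > 0     # tensor([False, True])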

device instance-attribute

device = device

kv_cache_spec instance-attribute

kv_cache_spec = kv_cache_spec

layer_names instance-attribute

layer_names = layer_names

reorder_batch_threshold class-attribute

reorder_batch_threshold: ClassVar[int] = 1

vllm_config instance-attribute

vllm_config = vllm_config

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    vllm_config: VllmConfig,
    device: torch.device,
    layer_names: list[str],
)
Source code in vllm/v1/attention/backends/mamba1_attn.py
def __init__(
    self,
    kv_cache_spec: AttentionSpec,
    vllm_config: VllmConfig,
    device: torch.device,
    layer_names: list[str],
):
    assert isinstance(kv_cache_spec, MambaSpec)
    self.kv_cache_spec = kv_cache_spec
    self.device = device
    self.vllm_config = vllm_config
    self.layer_names = layer_names

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> Mamba1AttentionMetadata
Source code in vllm/v1/attention/backends/mamba1_attn.py
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> Mamba1AttentionMetadata:
    query_start_loc = common_attn_metadata.query_start_loc

    state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
    context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
        query_start_loc.device)
    has_initial_states = (context_lens_tensor > 0)

    return Mamba1AttentionMetadata(
        query_start_loc=query_start_loc,
        context_lens_tensor=context_lens_tensor,
        has_initial_states=has_initial_states,
        state_indices_tensor=state_indices_tensor,
    )
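
A hedged sketch of invoking the builder; common_attn_metadata is assumed to be the CommonAttentionMetadata the model runner assembles for the current batch:

attn_metadata = builder.build(
    common_prefix_len=0,                        # unused by this builder, per the source above
    common_attn_metadata=common_attn_metadata,  # assumed to be prepared by the runner
)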