vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize

FlashInferCutlassMoEPrepareAndFinalize

Bases: FusedMoEPrepareAndFinalize

Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

    def __init__(
        self,
        quant_dtype: Optional[torch.dtype] = None,
        per_channel_quant: bool = False,
        block_shape: Optional[list[int]] = None,
        num_dispatchers: int = 1,
    ):
        super().__init__()
        self.per_channel_quant = per_channel_quant
        self.block_shape = block_shape
        self.quant_dtype = quant_dtype
        self.num_dispatchers_ = num_dispatchers

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.Standard

    def max_num_tokens_per_rank(self) -> Optional[int]:
        return None

    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        return None

    def num_dispatchers(self) -> int:
        return self.num_dispatchers_

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],  # Not used
        a2_scale: Optional[torch.Tensor],  # Not used
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        extra_prepare_args: Optional[dict[str, Any]]
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
               Optional[torch.Tensor], Optional[torch.Tensor]]:

        assert not apply_router_weight_on_input

        (a1_gscale, use_dp, local_tokens) = extract_required_args(
            extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens'])

        a1q, a1q_scale = moe_kernel_quantize_input(
            a1,
            a1_gscale,
            quant_config.quant_dtype,
            self.per_channel_quant,
            self.block_shape,
            is_fp4_scale_swizzled=not use_dp,  # Swizzling after communication
        )
        if use_dp:
            topk_weights, topk_ids, a1q, a1q_scale = \
                get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], # noqa: E501
                                           dim=0,
                                           sizes=get_local_sizes())
            a1_m, a1_n = a1q.shape
            a1q_scale = nvfp4_block_scale_interleave(a1q_scale)

        return a1q, a1q_scale, None, topk_ids, topk_weights

    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 apply_router_weight_on_input: bool,
                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
                 extra_finalize_args: Optional[dict[str, Any]]) -> None:

        (use_dp,
         local_tokens) = extract_required_args(extra_finalize_args,
                                               ['use_dp', 'local_tokens'])
        if use_dp:
            fused_expert_output = get_dp_group().reduce_scatterv(
                fused_expert_output, dim=0, sizes=get_local_sizes())
        output.copy_(fused_expert_output)

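FlashInferCutlassMoEPrepareAndFinalize wires FlashInfer's CUTLASS fused-MoE kernels into vLLM's modular-kernel framework: prepare quantizes the activations (deferring the NVFP4 block-scale swizzle when data parallelism is used) and all-gathers tokens plus top-k metadata across DP ranks, while finalize reduce-scatters the fused expert output so each rank keeps only its own tokens. The sketch below is illustrative only; it simulates that gather/scatter symmetry in a single process with plain torch, and every tensor name and size in it is hypothetical.

# Illustrative sketch (not vLLM code): single-process simulation of the
# token flow that prepare()/finalize() implement with all_gatherv /
# reduce_scatterv across data-parallel ranks.
import torch

sizes = [3, 5]                                   # tokens held by DP rank 0 and rank 1
hidden = 8
per_rank_tokens = [torch.randn(n, hidden) for n in sizes]

# prepare(): each rank ends up with the concatenation of all ranks' tokens
# (all_gatherv with per-rank sizes), so the fused experts see the full batch.
gathered = torch.cat(per_rank_tokens, dim=0)     # shape [sum(sizes), hidden]

# ... fused experts run on `gathered` ...
fused_expert_output = gathered * 2.0             # stand-in for the expert math

# finalize(): reduce_scatterv hands each rank back only the slice it owns;
# with a single simulated copy the reduction is a no-op and the scatter is a
# split by the same per-rank sizes.
chunks = torch.split(fused_expert_output, sizes, dim=0)
for rank, chunk in enumerate(chunks):
    assert chunk.shape[0] == sizes[rank]
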
activation_format property

activation_format: FusedMoEActivationFormat

block_shape instance-attribute

block_shape = block_shape

num_dispatchers_ instance-attribute

num_dispatchers_ = num_dispatchers

per_channel_quant instance-attribute

per_channel_quant = per_channel_quant

quant_dtype instance-attribute

quant_dtype = quant_dtype

__init__

__init__(
    quant_dtype: Optional[dtype] = None,
    per_channel_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    num_dispatchers: int = 1,
)
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
def __init__(
    self,
    quant_dtype: Optional[torch.dtype] = None,
    per_channel_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    num_dispatchers: int = 1,
):
    super().__init__()
    self.per_channel_quant = per_channel_quant
    self.block_shape = block_shape
    self.quant_dtype = quant_dtype
    self.num_dispatchers_ = num_dispatchers

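A minimal construction sketch, assuming the class is importable from the module path shown in the source header; the argument values are illustrative and not prescribed by this page.

# Construction sketch; the dtype and dispatcher count are hypothetical values.
import torch
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
    FlashInferCutlassMoEPrepareAndFinalize)

pf = FlashInferCutlassMoEPrepareAndFinalize(
    quant_dtype=torch.uint8,   # assumption: packed storage dtype for NVFP4
    per_channel_quant=False,
    block_shape=None,
    num_dispatchers=1,
)
assert pf.num_dispatchers() == 1
assert pf.max_num_tokens_per_rank() is None
assert pf.topk_indices_dtype() is None
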
finalize

finalize(
    output: Tensor,
    fused_expert_output: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    apply_router_weight_on_input: bool,
    weight_and_reduce_impl: TopKWeightAndReduce,
    extra_finalize_args: Optional[dict[str, Any]],
) -> None
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
             topk_weights: torch.Tensor, topk_ids: torch.Tensor,
             apply_router_weight_on_input: bool,
             weight_and_reduce_impl: mk.TopKWeightAndReduce,
             extra_finalize_args: Optional[dict[str, Any]]) -> None:

    (use_dp,
     local_tokens) = extract_required_args(extra_finalize_args,
                                           ['use_dp', 'local_tokens'])
    if use_dp:
        fused_expert_output = get_dp_group().reduce_scatterv(
            fused_expert_output, dim=0, sizes=get_local_sizes())
    output.copy_(fused_expert_output)

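The non-DP branch of finalize reduces to a plain copy into the caller-provided output buffer; reduce_scatterv is only issued when use_dp is true. The direct-call sketch below is hedged: it assumes an environment where this module and its FlashInfer dependency import cleanly, and the shapes, the local_tokens value, and passing None for the unused weight_and_reduce_impl are illustrative choices, not part of the documented contract.

# Hypothetical direct call on the non-DP path (use_dp=False); shapes and
# argument values are illustrative only.
import torch
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
    FlashInferCutlassMoEPrepareAndFinalize)

pf = FlashInferCutlassMoEPrepareAndFinalize(num_dispatchers=1)

num_tokens, hidden, topk = 4, 8, 2
output = torch.empty(num_tokens, hidden)
fused_expert_output = torch.randn(num_tokens, hidden)

pf.finalize(
    output=output,
    fused_expert_output=fused_expert_output,
    topk_weights=torch.rand(num_tokens, topk),
    topk_ids=torch.zeros(num_tokens, topk, dtype=torch.int32),
    apply_router_weight_on_input=False,
    weight_and_reduce_impl=None,   # not consulted by this implementation
    extra_finalize_args={"use_dp": False, "local_tokens": num_tokens},
)
# With use_dp=False the output buffer is a straight copy of the expert output.
assert torch.equal(output, fused_expert_output)
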
max_num_tokens_per_rank

max_num_tokens_per_rank() -> Optional[int]
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
def max_num_tokens_per_rank(self) -> Optional[int]:
    return None

num_dispatchers

num_dispatchers() -> int
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
def num_dispatchers(self) -> int:
    return self.num_dispatchers_

prepare

prepare(
    a1: Tensor,
    a1_scale: Optional[Tensor],
    a2_scale: Optional[Tensor],
    topk_weights: Tensor,
    topk_ids: Tensor,
    num_experts: int,
    expert_map: Optional[Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
    extra_prepare_args: Optional[dict[str, Any]],
) -> tuple[
    Tensor,
    Optional[Tensor],
    Optional[Tensor],
    Optional[Tensor],
    Optional[Tensor],
]
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
def prepare(
    self,
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],  # Not used
    a2_scale: Optional[torch.Tensor],  # Not used
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
    extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
           Optional[torch.Tensor], Optional[torch.Tensor]]:

    assert not apply_router_weight_on_input

    (a1_gscale, use_dp, local_tokens) = extract_required_args(
        extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens'])

    a1q, a1q_scale = moe_kernel_quantize_input(
        a1,
        a1_gscale,
        quant_config.quant_dtype,
        self.per_channel_quant,
        self.block_shape,
        is_fp4_scale_swizzled=not use_dp,  # Swizzling after communication
    )
    if use_dp:
        topk_weights, topk_ids, a1q, a1q_scale = \
            get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], # noqa: E501
                                       dim=0,
                                       sizes=get_local_sizes())
        a1_m, a1_n = a1q.shape
        a1q_scale = nvfp4_block_scale_interleave(a1q_scale)

    return a1q, a1q_scale, None, topk_ids, topk_weights

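Under data parallelism, prepare gathers ragged per-rank chunks along dim 0, so the per-rank token counts from get_local_sizes() are what make the variable-length all_gatherv line up; only after that gather is the NVFP4 block-scale interleave applied (hence is_fp4_scale_swizzled=not use_dp in the quantization call). The shape bookkeeping below is an illustration with made-up sizes, not vLLM code.

# Illustrative shape bookkeeping for the use_dp branch; all sizes are made up.
import torch

local_sizes = [2, 5, 3]      # hypothetical tokens per DP rank (get_local_sizes())
n_packed = 16                # hypothetical packed width of the quantized activations
per_rank_a1q = [torch.zeros(m, n_packed, dtype=torch.uint8) for m in local_sizes]

# all_gatherv concatenates the ragged chunks along dim 0 on every rank.
a1q = torch.cat(per_rank_a1q, dim=0)
assert a1q.shape == (sum(local_sizes), n_packed)

# topk_weights, topk_ids and a1q_scale are gathered the same way; the NVFP4
# block-scale interleave then runs once on the full, gathered a1q_scale.
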
topk_indices_dtype

topk_indices_dtype() -> Optional[dtype]
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
def topk_indices_dtype(self) -> Optional[torch.dtype]:
    return None

get_local_sizes

get_local_sizes()
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
def get_local_sizes():
    return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()