Skip to content

vllm.model_executor.layers.quantization.modelopt

KV_CACHE_QUANT_ALGOS module-attribute

KV_CACHE_QUANT_ALGOS = ['FP8']

QUANT_ALGOS module-attribute

QUANT_ALGOS = ['FP8', 'NVFP4']

logger module-attribute

logger = init_logger(__name__)

FlashinferMoeBackend

Bases: Enum

Source code in vllm/model_executor/layers/quantization/modelopt.py
class FlashinferMoeBackend(Enum):
    """Enumerates FlashInfer MoE backends by their string names."""
    # Each member's value is the backend's human-readable name.
    TENSORRT_LLM = "TensorRT-LLM"
    CUTLASS = "CUTLASS"

CUTLASS class-attribute instance-attribute

CUTLASS = 'CUTLASS'

TENSORRT_LLM class-attribute instance-attribute

TENSORRT_LLM = 'TensorRT-LLM'

ModelOptFp8Config

Bases: QuantizationConfig

Config class for ModelOpt FP8.

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        kv_cache_quant_method: Optional[str] = None,
        exclude_modules: Optional[list[str]] = None,
    ) -> None:
        """Store the ModelOpt FP8 quantization settings.

        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint already
                stores FP8 weights. A warning is logged when it does.
            kv_cache_quant_method: Value of ``kv_cache_quant_algo`` from the
                checkpoint config, if any.
            exclude_modules: Module-name fragments that must stay
                unquantized (see ``is_layer_excluded``).
        """
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        self.exclude_modules = exclude_modules
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
                           " the format is experimental and could change.")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the quantization method name, ``"modelopt"``."""
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the supported activation dtypes (bfloat16 and half)."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required (89)."""
        return 89

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the config file names to look for."""
        return ["hf_quant_config.json"]

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Args:
            hf_quant_cfg: The quantization config mapping, or ``None``.
            user_quant: Unused by this implementation.

        Returns:
            ``"modelopt"`` when ``quant_method`` is ``"modelopt"``
            (case-insensitive) and the ``quant_algo`` string contains
            ``"FP8"``; otherwise ``None``. Malformed values (for example a
            JSON ``null``) yield ``None`` instead of raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. Guard the type so a
        # null / non-string value means "not ours" rather than a crash.
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str):
            return None

        # Only proceed if the method is explicitly "modelopt"
        if quant_method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # ModelOpt-specific structure: {"quantization": {...}}
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
            quant_algo = quant_config.get("quant_algo", "")
        else:
            # Compressed-tensors style config with a top-level quant_algo
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        # Apply the same string check to both layouts.
        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"
        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
        """Build the config from a parsed quantization config dict.

        Accepts both the ModelOpt layout
        (``{"quantization": {"quant_algo": ...}}``) and the
        compressed-tensors style layout (``{"quant_algo": ...}``).

        Raises:
            ValueError: If ``quantization`` is not a dict, if its
                ``quant_algo`` is missing/empty, or if the algorithm is not
                listed in ``QUANT_ALGOS``.
        """
        if "quantization" in config:
            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError(
                    "Expected 'quantization' to be a dictionary in config")
            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            exclude_modules = quant_config.get("exclude_modules")
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            exclude_modules = config.get("exclude_modules")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration.")
        is_checkpoint_fp8_serialized = ("FP8" in quant_method)

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
                   exclude_modules)

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        This method handles both regular models and multimodal models that use
        the language_model prefix. For multimodal models, it checks if the
        module name (without the language_model prefix) is in the exclude list.

        Matching is by substring: a layer is excluded when any entry of
        ``exclude_modules`` occurs inside ``prefix``.
        """
        if self.exclude_modules is None:
            return False

        # Strip the multimodal prefix once instead of once per entry.
        # removeprefix() returns the string unchanged when it does not start
        # with the prefix, so the result matches the per-entry check.
        stripped = prefix.removeprefix("language_model.")
        return any(module in prefix or module in stripped
                   for module in self.exclude_modules)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Select the quantization method object for ``layer``.

        Linear layers get ``UnquantizedLinearMethod`` when excluded and
        ``ModelOptFp8LinearMethod`` otherwise; attention layers get
        ``ModelOptFp8KVCacheMethod``; fused-MoE layers get
        ``ModelOptFp8MoEMethod``. Any other layer type returns ``None``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import
        if isinstance(layer, LinearBase):
            if self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptFp8MoEMethod(self)
        return None

exclude_modules instance-attribute

exclude_modules = exclude_modules

is_checkpoint_fp8_serialized instance-attribute

is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

kv_cache_quant_method instance-attribute

kv_cache_quant_method = kv_cache_quant_method

__init__

__init__(
    is_checkpoint_fp8_serialized: bool = False,
    kv_cache_quant_method: Optional[str] = None,
    exclude_modules: Optional[list[str]] = None,
) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(
    self,
    is_checkpoint_fp8_serialized: bool = False,
    kv_cache_quant_method: Optional[str] = None,
    exclude_modules: Optional[list[str]] = None,
) -> None:
    """Record the ModelOpt FP8 settings on the instance.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already holds
            FP8 weights; a warning is logged when it does.
        kv_cache_quant_method: KV-cache quantization algorithm name, if any.
        exclude_modules: Module-name fragments to leave unquantized.
    """
    super().__init__()
    self.exclude_modules = exclude_modules
    self.kv_cache_quant_method = kv_cache_quant_method
    self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
    if not is_checkpoint_fp8_serialized:
        return
    # Serialized FP8 checkpoints are flagged as experimental.
    logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
                   " the format is experimental and could change.")

from_config classmethod

from_config(config: dict[str, Any]) -> ModelOptFp8Config
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
    """Create a config object from a parsed quantization config.

    Two layouts are understood: the ModelOpt one, where the settings live
    under a ``"quantization"`` key, and the compressed-tensors style one,
    where they sit at the top level.

    Raises:
        ValueError: If ``"quantization"`` is not a dict, if its
            ``quant_algo`` is missing/empty, or if the algorithm is not in
            ``QUANT_ALGOS``.
    """
    if "quantization" not in config:
        # Compressed-tensors style format:
        # {"quant_algo": "...", "quant_method": "modelopt"}
        section = config
        quant_method = section.get("quant_algo", "")
    else:
        # ModelOpt format: {"quantization": {"quant_algo": "..."}}
        section = cls.get_from_keys(config, ["quantization"])
        if not isinstance(section, dict):
            raise ValueError(
                "Expected 'quantization' to be a dictionary in config")
        quant_method = section.get("quant_algo", "")
        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

    # Both layouts use the same key names for the remaining settings.
    kv_cache_quant_method = section.get("kv_cache_quant_algo")
    exclude_modules = section.get("exclude_modules")

    if quant_method not in QUANT_ALGOS:
        raise ValueError(
            f"ModelOpt currently only supports: {QUANT_ALGOS} "
            "quantizations in vLLM. Please check the "
            "`hf_quant_config.json` file for your model's "
            "quant configuration.")

    return cls("FP8" in quant_method, kv_cache_quant_method,
               exclude_modules)

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    """Return the names of the quantization config files to look for."""
    filenames = ["hf_quant_config.json"]
    return filenames

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_min_capability(cls) -> int:
    """Return the minimum capability value this method requires."""
    min_capability = 89
    return min_capability

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    """Return the name of this quantization method."""
    method_name = "modelopt"
    return method_name

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    """Pick the quantization method object matching the layer type.

    Linear layers get ``UnquantizedLinearMethod`` when ``prefix`` is
    excluded and ``ModelOptFp8LinearMethod`` otherwise. Attention layers get
    ``ModelOptFp8KVCacheMethod`` and fused-MoE layers get
    ``ModelOptFp8MoEMethod``. Other layers return ``None``.
    """
    from vllm.attention.layer import Attention  # Avoid circular import

    if isinstance(layer, LinearBase):
        excluded = self.is_layer_excluded(prefix)
        return (UnquantizedLinearMethod()
                if excluded else ModelOptFp8LinearMethod(self))
    if isinstance(layer, Attention):
        return ModelOptFp8KVCacheMethod(self)
    if isinstance(layer, FusedMoE):
        return ModelOptFp8MoEMethod(self)
    return None

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    """Return the activation dtypes supported by this method."""
    supported = [torch.bfloat16, torch.half]
    return supported

is_layer_excluded

is_layer_excluded(prefix: str) -> bool

Check if a layer should be excluded from quantization.

This method handles both regular models and multimodal models that use the language_model prefix. For multimodal models, it checks if the module name (without the language_model prefix) is in the exclude list.

Source code in vllm/model_executor/layers/quantization/modelopt.py
def is_layer_excluded(self, prefix: str) -> bool:
    """
    Tell whether the layer named ``prefix`` must stay unquantized.

    A layer is excluded when any entry of ``self.exclude_modules`` occurs as
    a substring of ``prefix``, or of ``prefix`` with a leading
    ``"language_model."`` removed (the multimodal-model case). Returns
    ``False`` when no exclude list is configured.
    """
    excluded_names = self.exclude_modules
    if excluded_names is None:
        return False

    # removeprefix() leaves the string untouched when it does not start with
    # the prefix, so no separate startswith() test is needed.
    unprefixed = prefix.removeprefix("language_model.")
    return any(name in prefix or name in unprefixed
               for name in excluded_names)

override_quantization_method classmethod

override_quantization_method(
    hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]

Detect if this ModelOpt config should be used based on quantization config.

Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def override_quantization_method(
        cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
    """Detect if this ModelOpt config should be used based on
    quantization config.

    Args:
        hf_quant_cfg: The quantization config mapping, or ``None``.
        user_quant: Unused by this implementation.

    Returns:
        ``"modelopt"`` when ``quant_method`` is ``"modelopt"``
        (case-insensitive) and the ``quant_algo`` string contains ``"FP8"``;
        otherwise ``None``. Malformed values (for example a JSON ``null``)
        yield ``None`` instead of raising.
    """
    if hf_quant_cfg is None:
        return None

    # Use the community standard 'quant_method'. Guard the type so a null /
    # non-string value means "not ours" rather than an AttributeError.
    quant_method = hf_quant_cfg.get("quant_method", "")
    if not isinstance(quant_method, str):
        return None

    # Only proceed if the method is explicitly "modelopt"
    if quant_method.lower() != "modelopt":
        return None

    if "quantization" in hf_quant_cfg:
        # ModelOpt-specific structure: {"quantization": {...}}
        quant_config = hf_quant_cfg["quantization"]
        if not isinstance(quant_config, dict):
            return None
        quant_algo = quant_config.get("quant_algo", "")
    else:
        # Compressed-tensors style config with a top-level quant_algo
        quant_algo = hf_quant_cfg.get("quant_algo", "")

    # Apply the same string check to both layouts; previously the nested
    # layout raised TypeError on a null quant_algo.
    if isinstance(quant_algo, str) and "FP8" in quant_algo:
        return "modelopt"
    return None

ModelOptFp8KVCacheMethod

Bases: BaseKVCacheMethod

Supports loading kv-cache scaling factors from FP8 checkpoints.

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    KV-cache method that supports loading kv-cache scaling factors from
    FP8 checkpoints.
    """

    def __init__(
        self,
        quant_config: Union[ModelOptFp8Config, ModelOptNvFp4Config],
    ):
        """Forward ``quant_config`` unchanged to the base class."""
        super().__init__(quant_config)

__init__

__init__(
    quant_config: Union[
        ModelOptFp8Config, ModelOptNvFp4Config
    ],
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(
    self,
    quant_config: Union[ModelOptFp8Config, ModelOptNvFp4Config],
):
    """Forward ``quant_config`` unchanged to the base class."""
    super().__init__(quant_config)

ModelOptFp8LinearMethod

Bases: LinearMethodBase

Linear method for Model Optimizer static quantization. Supports loading FP8 checkpoints with static weight scale and activation scale. Future support might be added for dynamic scales.

Limitations: (1) Only per-tensor quantization is supported, due to torch._scaled_mm support. (2) Only the float8_e4m3fn datatype is supported.

Args: quant_config: The ModelOpt quantization config.

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        """Keep the config and build the static per-tensor FP8 linear op."""
        self.quant_config = quant_config
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True,
            act_quant_group_shape=GroupShape.PER_TENSOR,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, scales) on ``layer``.

        The weight has shape ``(sum(output_partition_sizes),
        input_size_per_partition)``. When the checkpoint is FP8-serialized,
        ``weight_scale`` and ``input_scale`` are also registered, one entry
        per output partition, pre-filled with the float32 minimum.
        ``input_size`` and ``output_size`` are not used.
        """
        del input_size, output_size

        serialized = self.quant_config.is_checkpoint_fp8_serialized
        num_shards = len(output_partition_sizes)
        total_out = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Bookkeeping read back in process_weights_after_loading.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_out

        weight_storage = torch.empty(
            total_out,
            input_size_per_partition,
            dtype=torch.float8_e4m3fn if serialized else params_dtype,
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(data=weight_storage,
                                 input_dim=1,
                                 output_dim=0,
                                 weight_loader=weight_loader),
        )

        if not serialized:
            return

        def _new_per_tensor_scale():
            # One float32 slot per output partition, initialised to the
            # smallest representable float32 value.
            param = PerTensorScaleParameter(
                data=torch.empty(num_shards, dtype=torch.float32),
                weight_loader=weight_loader)
            param[:] = torch.finfo(torch.float32).min
            return param

        # WEIGHT SCALE
        layer.register_parameter("weight_scale", _new_per_tensor_scale())
        # INPUT SCALE
        layer.register_parameter("input_scale", _new_per_tensor_scale())

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-partition scales to one scale and transpose the weight.

        If the weight scales differ across partitions, the weight is passed
        through ``requantize_with_max_scale``. The input scale is reduced
        to its maximum.
        """
        scales = layer.weight_scale
        if (scales == scales[0]).all():
            # All partitions already share one scale.
            fused_weight, fused_scale = layer.weight, scales.max()
        else:
            fused_scale, fused_weight = requantize_with_max_scale(
                layer.weight, scales, layer.logical_widths)

        layer.weight = Parameter(fused_weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(fused_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(),
                                      requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the FP8 linear op with the layer's weight and scales."""
        fp8_op = self.fp8_linear
        return fp8_op.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )

fp8_linear instance-attribute

fp8_linear = Fp8LinearOp(
    act_quant_static=True, act_quant_group_shape=PER_TENSOR
)

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(quant_config: ModelOptFp8Config) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(self, quant_config: ModelOptFp8Config) -> None:
    """Keep the config and build the static per-tensor FP8 linear op."""
    self.quant_config = quant_config
    self.fp8_linear = Fp8LinearOp(
        act_quant_static=True,
        act_quant_group_shape=GroupShape.PER_TENSOR,
    )

apply

apply(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the FP8 linear op on ``x`` with the layer's weight and scales.

    Delegates to ``self.fp8_linear.apply`` and returns its result.
    """
    fp8_op = self.fp8_linear
    return fp8_op.apply(
        input=x,
        weight=layer.weight,
        weight_scale=layer.weight_scale,
        input_scale=layer.input_scale,
        bias=bias,
    )

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Register the weight (and, for FP8 checkpoints, scales) on ``layer``.

    The weight has shape ``(sum(output_partition_sizes),
    input_size_per_partition)``. When the checkpoint is FP8-serialized,
    ``weight_scale`` and ``input_scale`` are also registered, one entry per
    output partition, pre-filled with the float32 minimum. ``input_size``
    and ``output_size`` are not used.
    """
    del input_size, output_size

    serialized = self.quant_config.is_checkpoint_fp8_serialized
    num_shards = len(output_partition_sizes)
    total_out = sum(output_partition_sizes)
    weight_loader = extra_weight_attrs.get("weight_loader")

    # Bookkeeping attributes stored on the layer for later processing.
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = total_out

    weight_storage = torch.empty(
        total_out,
        input_size_per_partition,
        dtype=torch.float8_e4m3fn if serialized else params_dtype,
    )
    layer.register_parameter(
        "weight",
        ModelWeightParameter(data=weight_storage,
                             input_dim=1,
                             output_dim=0,
                             weight_loader=weight_loader),
    )

    if not serialized:
        return

    def _new_per_tensor_scale():
        # One float32 slot per output partition, initialised to the
        # smallest representable float32 value.
        param = PerTensorScaleParameter(
            data=torch.empty(num_shards, dtype=torch.float32),
            weight_loader=weight_loader)
        param[:] = torch.finfo(torch.float32).min
        return param

    # WEIGHT SCALE
    layer.register_parameter("weight_scale", _new_per_tensor_scale())
    # INPUT SCALE
    layer.register_parameter("input_scale", _new_per_tensor_scale())

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: Module) -> None:
    """Collapse per-partition scales to one scale and transpose the weight.

    If the weight scales differ across partitions, the weight is passed
    through ``requantize_with_max_scale``; otherwise it is kept as loaded.
    The input scale is reduced to its maximum. All three attributes are
    replaced by non-trainable ``Parameter`` objects.
    """
    scales = layer.weight_scale
    if (scales == scales[0]).all():
        # All partitions already share one scale.
        fused_weight, fused_scale = layer.weight, scales.max()
    else:
        fused_scale, fused_weight = requantize_with_max_scale(
            layer.weight, scales, layer.logical_widths)

    layer.weight = Parameter(fused_weight.t(), requires_grad=False)
    layer.weight_scale = Parameter(fused_scale, requires_grad=False)
    layer.input_scale = Parameter(layer.input_scale.max(),
                                  requires_grad=False)

ModelOptFp8MoEMethod

Bases: FusedMoEMethodBase

MoE method for ModelOpt FP8. Supports loading FP8 checkpoints with static weight scale and activation scale. Args: quant_config: The ModelOpt quantization config.

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.
    Supports loading FP8 checkpoints with static weight scale and
    activation scale.
    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        """Store the config and decide whether FlashInfer MoE kernels are used.

        FlashInfer is enabled only when ``envs.VLLM_USE_FLASHINFER_MOE_FP8``
        is truthy and ``has_flashinfer_moe()`` returns a truthy value.
        """
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported)
        # NOTE(review): this flag is stored but not read anywhere in this
        # class as shown here; confirm whether other code depends on it.
        self.cutlass_fp8_supported = cutlass_fp8_supported()
        self.flashinfer_moe_enabled = False
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            logger.info_once(
                "Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.")
            self.flashinfer_moe_enabled = True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the expert weights (and, for FP8 checkpoints, scales).

        Registers ``w13_weight`` with shape
        ``(num_experts, 2 * intermediate_size_per_partition, hidden_size)``
        and ``w2_weight`` with shape
        ``(num_experts, hidden_size, intermediate_size_per_partition)``.
        When the checkpoint is FP8-serialized, weight and input scale
        parameters are also registered, all initialised to 1.0.
        """

        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight_loader = extra_weight_attrs.get("weight_loader")

        w13_weight = ModelWeightParameter(
            data=torch.empty(num_experts,
                             2 * intermediate_size_per_partition,
                             hidden_size,
                             dtype=weight_dtype),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(num_experts,
                             hidden_size,
                             intermediate_size_per_partition,
                             dtype=weight_dtype),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    (num_experts, 2),
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales
            # NOTE(review): extra_weight_attrs is updated here but not read
            # again within this method; confirm the update has an effect.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.
        Only supports pre-quantized checkpoints with FP8 weights and scales.

        Steps, in order: re-wrap the weights as non-trainable parameters;
        merge the two per-expert w13 weight scales into one (requantizing
        the weights in place); re-wrap the w2 weight scale; reduce each
        input scale to its maximum; and, when FlashInfer is enabled, swap
        and rotate the weights for the FlashInfer kernels.
        """

        layer.w13_weight = Parameter(layer.w13_weight.data,
                                     requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize)

        # Handle scale parameters
        if hasattr(layer,
                   "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
            if layer.w13_weight_scale.dim() == 2:

                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    # Row offset of the current shard within this expert.
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][start:start +
                                                        intermediate_size, :],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale

                        # The quantized result is written back into the same
                        # slice of w13_weight; the returned scale is dropped.
                        (
                            layer.w13_weight[expert_id][start:start +
                                                        intermediate_size, :],
                            _,
                        ) = scaled_fp8_quant(dq_weight,
                                             max_w13_scales[expert_id])

                        start += intermediate_size

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales,
                                                   requires_grad=False)
            else:
                # Scale is not 2-D: keep it as is, only re-wrap it.
                layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data,
                                                   requires_grad=False)

        if hasattr(layer,
                   "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data,
                                              requires_grad=False)
        # Input scales must be equal for each expert in fp8 MoE layers.
        # The per-expert values are reduced to a single maximum.
        if hasattr(layer,
                   "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(layer.w13_input_scale.max(),
                                              requires_grad=False)
        if hasattr(layer,
                   "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
                                             requires_grad=False)

        if self.flashinfer_moe_enabled:
            # FlashInfer path: reorder w13 and rotate both weights after all
            # scale handling above is done.
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
                                              layer.w2_weight)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the FP8 MoE layer on ``x``.

        With FlashInfer enabled, delegates to
        ``apply_flashinfer_per_tensor_scale_fp8``. Otherwise selects experts
        with ``FusedMoE.select_experts`` and runs ``fused_experts`` with the
        layer's FP8 weights and scales.

        ``expert_load_view``, ``logical_to_physical_map`` and
        ``logical_replica_count`` are accepted but not used by this method.

        Raises:
            NotImplementedError: If ``enable_eplb`` is true.
            AssertionError: On the FlashInfer path, if ``activation`` is not
                ``'silu'`` or ``renormalize`` is true.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptFp8MoEMethod` yet.")

        if self.flashinfer_moe_enabled:
            # The FlashInfer path is restricted to silu without
            # renormalization.
            assert activation == 'silu'
            assert not renormalize
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input)

        # Expert selection
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )
        from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_experts)
        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            use_fp8_w8a8=True,
            per_channel_quant=False,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

cutlass_fp8_supported instance-attribute

cutlass_fp8_supported = cutlass_fp8_supported()

flashinfer_moe_enabled instance-attribute

flashinfer_moe_enabled = False

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(quant_config: ModelOptFp8Config) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(self, quant_config: ModelOptFp8Config) -> None:
    """Keep the quantization config and probe which FP8 kernels to use.

    ``flashinfer_moe_enabled`` ends up true only when the
    ``VLLM_USE_FLASHINFER_MOE_FP8`` setting is truthy and
    ``has_flashinfer_moe()`` returns a truthy value as well.
    """
    self.quant_config = quant_config

    # Resolved at call time rather than at module import time.
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        cutlass_fp8_supported)
    self.cutlass_fp8_supported = cutlass_fp8_supported()

    # Default to the non-FlashInfer path; flip only if both checks pass.
    self.flashinfer_moe_enabled = False
    flashinfer_requested = envs.VLLM_USE_FLASHINFER_MOE_FP8
    if flashinfer_requested and has_flashinfer_moe():
        self.flashinfer_moe_enabled = True
        logger.info_once(
            "Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.")

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the FP8 MoE computation for ``x`` on ``layer``.

    With ``self.flashinfer_moe_enabled`` set, the call is forwarded to
    ``apply_flashinfer_per_tensor_scale_fp8``. Otherwise experts are picked
    by ``FusedMoE.select_experts`` and evaluated by ``fused_experts`` with
    the weight and input scales stored on ``layer``.

    ``expert_load_view``, ``logical_to_physical_map`` and
    ``logical_replica_count`` are accepted but not used by this method.

    Raises:
        NotImplementedError: if ``enable_eplb`` is true.
    """
    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for `ModelOptFp8MoEMethod` yet.")

    if not self.flashinfer_moe_enabled:
        # Default path: route tokens first, then run the fused experts.
        routed_weights, routed_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )

        from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_experts)
        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=routed_weights,
            topk_ids=routed_ids,
            inplace=True,
            activation=activation,
            use_fp8_w8a8=True,
            per_channel_quant=False,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

    # FlashInfer path: only silu without renormalization is accepted here.
    assert activation == 'silu'
    assert not renormalize
    return apply_flashinfer_per_tensor_scale_fp8(
        layer=layer,
        hidden_states=x,
        router_logits=router_logits,
        routing_bias=e_score_correction_bias,
        global_num_experts=global_num_experts,
        top_k=top_k,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        apply_router_weight_on_input=apply_router_weight_on_input)

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(
    self,
    layer: torch.nn.Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Register the expert weights (and FP8 scales) on ``layer``.

    ``w13_weight`` and ``w2_weight`` are always registered. When the
    checkpoint is FP8-serialized they use ``torch.float8_e4m3fn`` and four
    scale parameters initialised to 1.0 are registered as well; otherwise
    the weights use ``params_dtype`` and no scales are created.
    """
    fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
    expert_dtype = torch.float8_e4m3fn if fp8_serialized else params_dtype
    weight_loader = extra_weight_attrs.get("weight_loader")

    def unit_scale(*shape: int) -> PerTensorScaleParameter:
        # float32 scale tensor of the given shape, filled with 1.0.
        return PerTensorScaleParameter(
            data=torch.full(shape, 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )

    # Fused w1/w3 weight: (experts, 2 * intermediate, hidden).
    layer.register_parameter(
        "w13_weight",
        ModelWeightParameter(
            data=torch.empty(num_experts,
                             2 * intermediate_size_per_partition,
                             hidden_size,
                             dtype=expert_dtype),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        ))

    # w2 weight: (experts, hidden, intermediate).
    layer.register_parameter(
        "w2_weight",
        ModelWeightParameter(
            data=torch.empty(num_experts,
                             hidden_size,
                             intermediate_size_per_partition,
                             dtype=expert_dtype),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        ))

    if not fp8_serialized:
        return

    # Weight scales: two per expert for w13 (one each for w1 and w3, merged
    # into a single scale after loading) and one per expert for w2.
    layer.register_parameter("w13_weight_scale", unit_scale(num_experts, 2))
    layer.register_parameter("w2_weight_scale", unit_scale(num_experts))

    # NOTE(review): this only mutates the local kwargs dict and nothing
    # below reads it - confirm whether it was meant to be applied to the
    # scale parameters.
    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

    # Input scales: one per expert for each projection.
    layer.register_parameter("w13_input_scale", unit_scale(num_experts))
    layer.register_parameter("w2_input_scale", unit_scale(num_experts))

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None

Process FP8 MoE weights after loading from serialized checkpoint. Only supports pre-quantized checkpoints with FP8 weights and scales.

Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Finalize the FP8 MoE parameters on ``layer`` once loading is done.

    Re-wraps the weights as plain non-trainable ``Parameter`` objects,
    merges the two per-expert w13 weight scales into one (requantizing the
    weights accordingly), reduces each input scale to its maximum, and
    applies the FlashInfer weight layout when that backend is enabled.
    """
    layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
    layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

    from vllm._custom_ops import scaled_fp8_quant
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        per_tensor_dequantize)

    def has_value(attr: str) -> bool:
        # True when the attribute exists on the layer and is not None.
        return hasattr(layer, attr) and getattr(layer, attr) is not None

    if has_value("w13_weight_scale"):
        if layer.w13_weight_scale.dim() != 2:
            # Nothing to merge; just re-wrap.
            layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data,
                                               requires_grad=False)
        else:
            # The kernel takes a single w13 scale per expert, so use the
            # larger of the w1/w3 scales and requantize both halves to it.
            merged_scales = layer.w13_weight_scale.max(dim=1).values

            # w13_weight is (experts, 2 * intermediate, hidden): the first
            # half of the rows is w1, the second half is w3.
            rows_per_shard = layer.w13_weight.shape[1] // 2
            expert_count = layer.w13_weight.shape[0]
            for expert_id in range(expert_count):
                for shard_id in range(2):  # w1 and w3
                    row_lo = shard_id * rows_per_shard
                    row_hi = row_lo + rows_per_shard
                    # Undo the shard's own scale ...
                    dequantized = per_tensor_dequantize(
                        layer.w13_weight[expert_id][row_lo:row_hi, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    # ... and quantize again with the merged one, writing
                    # the result back into the same rows.
                    requantized, _ = scaled_fp8_quant(
                        dequantized, merged_scales[expert_id])
                    layer.w13_weight[expert_id][
                        row_lo:row_hi, :] = requantized

            # One scale per expert from here on.
            layer.w13_weight_scale = Parameter(merged_scales,
                                               requires_grad=False)

    if has_value("w2_weight_scale"):
        layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data,
                                          requires_grad=False)

    # Input scales must be equal for each expert in fp8 MoE layers, so each
    # one is replaced by its maximum.
    for scale_name in ("w13_input_scale", "w2_input_scale"):
        if has_value(scale_name):
            setattr(
                layer, scale_name,
                Parameter(getattr(layer, scale_name).max(),
                          requires_grad=False))

    if self.flashinfer_moe_enabled:
        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)

ModelOptNvFp4Config

Bases: QuantizationConfig

Config class for ModelOpt FP4.

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptNvFp4Config(QuantizationConfig):
    """Config class for ModelOpt FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: Optional[str],
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """Store the parsed NVFP4 quantization settings.

        Args:
            is_checkpoint_nvfp4_serialized: whether the checkpoint's
                quant algo contains "NVFP4".
            kv_cache_quant_algo: KV-cache quantization algo, or None.
            exclude_modules: module-name patterns left unquantized.
            group_size: quantization group size.
        """
        super().__init__()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future.")

        # Set unconditionally: get_quant_method reads self.exclude_modules
        # regardless of the serialized flag, so leaving these unset for a
        # non-serialized config would raise AttributeError later.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registered name of this quantization method."""
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this config accepts."""
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this method."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the config file names searched for this method."""
        return ["hf_quant_config.json"]

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Returns "modelopt_fp4" when ``hf_quant_cfg`` declares
        ``quant_method == "modelopt"`` and names an FP4 algorithm,
        otherwise None. Non-string values never match.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "")

        # Only proceed if the method is explicitly "modelopt"
        if (not isinstance(quant_method, str)
                or quant_method.lower() != "modelopt"):
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                # Guard the type: `"NVFP4" in None` would raise TypeError.
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @staticmethod
    def _parse_common_fields(
            section: dict[str, Any]
    ) -> tuple[Optional[str], int, list[str]]:
        """Validate and return (kv_cache_quant_algo, group_size,
        exclude_modules) read from one config section.

        Raises:
            ValueError: if a field has an unsupported type.
        """
        # kv_cache_quant_algo: absent/None means no KV cache quantization.
        kv_cache_quant_algo_raw = section.get("kv_cache_quant_algo")
        if kv_cache_quant_algo_raw is None:
            kv_cache_quant_algo = None
        elif isinstance(kv_cache_quant_algo_raw, str):
            kv_cache_quant_algo = kv_cache_quant_algo_raw
        else:
            raise ValueError(f"kv_cache_quant_algo must be a string, got "
                             f"{type(kv_cache_quant_algo_raw)}")

        # group_size: defaults to 16, otherwise must be int-convertible.
        group_size_raw = section.get("group_size")
        if group_size_raw is None:
            group_size = 16  # Default value
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(f"group_size must be an integer, got "
                                 f"{type(group_size_raw)}") from None

        exclude_modules = section.get("exclude_modules", [])
        if not isinstance(exclude_modules, list):
            raise ValueError(f"exclude_modules must be a list, got "
                             f"{type(exclude_modules)}")

        return kv_cache_quant_algo, group_size, exclude_modules

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
        """Build a config from either supported dictionary layout.

        Two layouts are handled: the traditional ModelOpt format
        (``{"quantization": {"quant_algo": ...}}``) and the flat
        compressed-tensors style format (``{"quant_algo": ...}``).

        Raises:
            ValueError: on malformed fields, an unsupported quant algo, or
                missing required NVFP4 fields in the traditional format.
        """
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError(
                    "Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")

            fields_source = quant_config
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            fields_source = config

        kv_cache_quant_algo, group_size, exclude_modules = (
            cls._parse_common_fields(fields_source))

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration.")
        is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in config:
            # Check if required fields are present in the quantization config
            quant_config = config["quantization"]
            required_fields = [
                "group_size", "kv_cache_quant_algo", "exclude_modules"
            ]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}")

        return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
                   exclude_modules, group_size)

    def is_layer_excluded(self, prefix: str,
                          exclude_modules: list[str]) -> bool:
        """Return True if ``prefix`` fully matches any exclude pattern.

        In each pattern ``.`` is matched literally and ``*`` matches any
        run of characters.
        """
        import regex as re
        for pattern in exclude_modules:
            regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
            if re.fullmatch(regex_str, prefix):
                return True
        return False

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``, or None if the layer
        type is not handled here."""
        from vllm.attention.layer import Attention  # Avoid circular import
        if isinstance(layer, LinearBase):
            if (is_layer_skipped(prefix, self.exclude_modules)
                    or self.is_layer_excluded(prefix, self.exclude_modules)):
                return UnquantizedLinearMethod()
            return ModelOptNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptNvFp4FusedMoE(self)
        return None

exclude_modules instance-attribute

exclude_modules = exclude_modules

group_size instance-attribute

group_size = group_size

is_checkpoint_nvfp4_serialized instance-attribute

is_checkpoint_nvfp4_serialized = (
    is_checkpoint_nvfp4_serialized
)

kv_cache_quant_algo instance-attribute

kv_cache_quant_algo = kv_cache_quant_algo

__init__

__init__(
    is_checkpoint_nvfp4_serialized: bool,
    kv_cache_quant_algo: Optional[str],
    exclude_modules: list[str],
    group_size: int = 16,
) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(
    self,
    is_checkpoint_nvfp4_serialized: bool,
    kv_cache_quant_algo: Optional[str],
    exclude_modules: list[str],
    group_size: int = 16,
) -> None:
    """Store the parsed NVFP4 quantization settings.

    Args:
        is_checkpoint_nvfp4_serialized: whether the checkpoint is an NVFP4
            one; a warning is logged when true.
        kv_cache_quant_algo: KV-cache quantization algo, or None.
        exclude_modules: module-name patterns left unquantized.
        group_size: quantization group size.
    """
    super().__init__()
    self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
    if is_checkpoint_nvfp4_serialized:
        logger.warning(
            "Detected ModelOpt NVFP4 checkpoint. Please note that"
            " the format is experimental and could change in future.")

    # Assigned outside the `if`: these attributes must exist on every
    # instance, otherwise reading e.g. `self.exclude_modules` on a
    # non-serialized config raises AttributeError.
    self.group_size = group_size
    self.kv_cache_quant_algo = kv_cache_quant_algo
    self.exclude_modules = exclude_modules

from_config classmethod

from_config(config: dict[str, Any]) -> ModelOptNvFp4Config
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
    """Build a config from either supported dictionary layout.

    Handles the traditional ModelOpt format
    (``{"quantization": {"quant_algo": ...}}``) and the flat
    compressed-tensors style format (``{"quant_algo": ...}``).

    Raises:
        ValueError: on malformed fields, an unsupported quant algo, or
            missing required NVFP4 fields in the traditional format.
    """

    def read_fields(section):
        # Validate kv_cache_quant_algo, group_size and exclude_modules
        # from one section, in that order.
        raw_kv = section.get("kv_cache_quant_algo")
        if raw_kv is not None and not isinstance(raw_kv, str):
            raise ValueError(f"kv_cache_quant_algo must be a string, got "
                             f"{type(raw_kv)}")
        # None means no KV cache quantization.

        raw_group = section.get("group_size")
        if raw_group is None:
            parsed_group = 16  # Default value
        elif isinstance(raw_group, int):
            parsed_group = raw_group
        else:
            try:
                parsed_group = int(raw_group)
            except (ValueError, TypeError):
                raise ValueError(f"group_size must be an integer, got "
                                 f"{type(raw_group)}") from None

        excluded = section.get("exclude_modules", [])
        if not isinstance(excluded, list):
            raise ValueError(f"exclude_modules must be a list, got "
                             f"{type(excluded)}")
        return raw_kv, parsed_group, excluded

    nested_layout = "quantization" in config
    if nested_layout:
        # Traditional ModelOpt format:
        # {"quantization": {"quant_algo": "..."}}
        section = cls.get_from_keys(config, ["quantization"])
        if not isinstance(section, dict):
            raise ValueError(
                "Expected 'quantization' to be a dictionary in config")

        quant_method = section.get("quant_algo", "")
        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")
    else:
        # Compressed-tensors style format:
        # {"quant_algo": "...", "quant_method": "modelopt"}
        section = config
        quant_method = config.get("quant_algo", "")

    kv_cache_quant_algo, group_size, exclude_modules = read_fields(section)

    if quant_method not in QUANT_ALGOS:
        raise ValueError(
            f"ModelOpt currently only supports: {QUANT_ALGOS} "
            "quantizations in vLLM. Please check the "
            "`hf_quant_config.json` file for your model's "
            "quant configuration.")
    is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)

    # In the nested layout, NVFP4 checkpoints must spell out these fields.
    if is_checkpoint_nvfp4_serialized and "quantization" in config:
        nested = config["quantization"]
        missing_fields = [
            name for name in ("group_size", "kv_cache_quant_algo",
                              "exclude_modules") if name not in nested
        ]
        if missing_fields:
            raise ValueError(
                f"NVFP4 quantization requires the following fields in "
                f"hf_quant_config.json: {missing_fields}")

    return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
               exclude_modules, group_size)

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    """Return the config file names searched for this method."""
    # A new list is built on every call.
    filenames = ["hf_quant_config.json"]
    return filenames

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_min_capability(cls) -> int:
    """Return the minimum capability value required by this method."""
    min_capability = 80
    return min_capability

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    """Return the registered name of this quantization method."""
    method_name = "modelopt_fp4"
    return method_name

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    """Pick the quantize method for ``layer``.

    Linear layers get the NVFP4 linear method unless ``prefix`` is skipped
    or excluded via ``self.exclude_modules``; attention layers get the FP8
    KV-cache method; fused MoE layers get the NVFP4 MoE method. Any other
    layer type yields None.
    """
    from vllm.attention.layer import Attention  # Avoid circular import

    if isinstance(layer, LinearBase):
        leave_unquantized = (
            is_layer_skipped(prefix, self.exclude_modules)
            or self.is_layer_excluded(prefix, self.exclude_modules))
        if leave_unquantized:
            return UnquantizedLinearMethod()
        return ModelOptNvFp4LinearMethod(self)

    if isinstance(layer, Attention):
        return ModelOptFp8KVCacheMethod(self)

    if isinstance(layer, FusedMoE):
        return ModelOptNvFp4FusedMoE(self)

    return None

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

is_layer_excluded

is_layer_excluded(
    prefix: str, exclude_modules: list[str]
) -> bool
Source code in vllm/model_executor/layers/quantization/modelopt.py
def is_layer_excluded(self, prefix: str,
                      exclude_modules: list[str]) -> bool:
    """Return True if ``prefix`` fully matches any exclude pattern.

    In each pattern ``.`` is matched literally and ``*`` matches any run
    of characters; everything else is passed to the regex engine as-is.
    """
    import regex as re

    def to_regex(glob: str) -> str:
        # Escape dots first, then widen '*' into '.*'.
        return glob.replace('.', r'\.').replace('*', r'.*')

    return any(
        re.fullmatch(to_regex(glob), prefix) for glob in exclude_modules)

override_quantization_method classmethod

override_quantization_method(
    hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]

Detect if this ModelOpt FP4 config should be used based on quantization config.

Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def override_quantization_method(
        cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
    """Detect if this ModelOpt FP4 config should be used based on
    quantization config.

    Returns "modelopt_fp4" when ``hf_quant_cfg`` declares
    ``quant_method == "modelopt"`` (case-insensitive) and names an FP4
    algorithm; otherwise None. Non-string ``quant_method`` / ``quant_algo``
    values never match instead of raising. ``user_quant`` is not consulted.
    """
    if hf_quant_cfg is None:
        return None

    # Use the community standard 'quant_method'
    quant_method = hf_quant_cfg.get("quant_method", "")

    # Only proceed if the method is explicitly "modelopt"; a null or
    # non-string value cannot be lowered and is treated as "not ours".
    if (not isinstance(quant_method, str)
            or quant_method.lower() != "modelopt"):
        return None

    # Look for ModelOpt-specific config structure
    if "quantization" in hf_quant_cfg:
        quant_config = hf_quant_cfg["quantization"]
        if isinstance(quant_config, dict):
            quant_algo = quant_config.get("quant_algo", "")
            # Same type guard as the flat branch below: without it a null
            # quant_algo raises TypeError on the `in` test.
            if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                return "modelopt_fp4"
    else:
        # Check for compressed-tensors style config with specific
        # quant_algo field
        quant_algo = hf_quant_cfg.get("quant_algo", "")
        if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
            return "modelopt_fp4"

    return None

ModelOptNvFp4FusedMoE

Bases: FusedMoEMethodBase

MoE method for FP4 quantization.

Args:
    quant_config: the NVFP4 quantization config.

Source code in vllm/model_executor/layers/quantization/modelopt.py
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.

    Based on what ``detect_nvfp4_moe_support`` reports at construction time,
    ``apply`` runs the experts through one of four paths: the FlashInfer
    TensorRT-LLM kernel, the Marlin kernel, ``cutlass_moe_fp4``, or the
    FlashInfer CUTLASS modular kernel stored in ``self.fused_experts``.

    Args:
        quant_config: NVFP4 Quant Config
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        """Probe NVFP4 MoE kernel support and pick the FlashInfer backend.

        Raises:
            ValueError: if FlashInfer is allowed and
                ``envs.VLLM_FLASHINFER_MOE_BACKEND`` is neither
                ``"throughput"`` nor ``"latency"``.
        """
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
            detect_nvfp4_moe_support)
        _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
        self.allow_flashinfer = _nvfp4.allow_flashinfer
        self.use_marlin = _nvfp4.use_marlin
        # Stays None unless FlashInfer is allowed.
        self.flashinfer_moe_backend = None

        if self.allow_flashinfer:
            # "throughput" -> CUTLASS kernels, "latency" -> TensorRT-LLM
            # kernels; any other value is rejected.
            flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
            if flashinfer_moe_backend == "throughput":
                self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
                logger.info_once("Using FlashInfer CUTLASS kernels for "
                                 "ModelOptNvFp4FusedMoE.")
            elif flashinfer_moe_backend == "latency":
                self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
                logger.info_once("Using FlashInfer TensorRT-LLM kernels for "
                                 "ModelOptNvFp4FusedMoE.")
            else:
                allowed_backends = ["throughput", "latency"]
                raise ValueError(
                    f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
                    f" expected one of {allowed_backends}")

        # Modular kernel used by apply(); None means apply() falls back to
        # cutlass_moe_fp4 (or Marlin / TRT-LLM, which do not read it).
        self.fused_experts: Optional[
            mk.FusedMoEModularKernel] = None  # type: ignore[assignment]

    def maybe_swap_experts_impl(
        self,
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> None:
        """Install the FlashInfer FP4 CUTLASS modular kernel if allowed.

        Does nothing when FlashInfer is not allowed.

        NOTE(review): the backend is not checked here, so the CUTLASS kernel
        is also built when the TensorRT-LLM backend was selected; apply()
        returns from its TRT-LLM branch before reading self.fused_experts.
        """
        if not self.allow_flashinfer:
            return
        self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
            moe_parallel_config)

    # This method update self.fused_experts
    # only prepare_finalize is not None call select_gemm_impl
    # so when native cutlass fp4, fused_expert is in fuse_moe.py fused_expert
    # when it's not called(TP case), we still have 2 kernels to use.
    def select_gemm_impl(self, prepare_finalize,
                         moe) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Return the experts implementation from select_nvfp4_gemm_impl.

        Both ``moe`` and ``prepare_finalize`` must be non-None (asserted);
        only ``moe`` is forwarded to the selector.
        """

        assert moe is not None and prepare_finalize is not None
        from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (  # noqa: E501
            select_nvfp4_gemm_impl)

        return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs) -> None:
        """Register the NVFP4 expert parameters on ``layer``.

        Parameters registered (all created with ``torch.empty``):
          * ``w13_weight`` / ``w2_weight``: uint8, two FP4 values per byte.
          * ``w13_weight_scale`` / ``w2_weight_scale``: float8_e4m3fn block
            scales, one per ``quant_config.group_size`` input elements.
          * ``w13_weight_scale_2`` / ``w2_weight_scale_2``: float32
            per-tensor weight scales.
          * ``w13_input_scale`` / ``w2_input_scale``: float32 per-tensor
            input scales.

        The ``w13_*`` per-tensor scales carry an extra dimension of size 2.

        Args:
            layer: module that receives the parameters; also gets
                ``num_experts``, ``params_dtype`` and ``quant_config`` set.
            num_experts: size of the leading dimension of every parameter.
            hidden_size: model hidden size.
            intermediate_size_per_partition: intermediate size of this
                partition; ``w13`` tensors use twice this value.
            params_dtype: stored on ``layer``; not used for the tensors here.
            **extra_weight_attrs: only ``weight_loader`` is read.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized.
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError("NVFP4 quantization was selected, "
                             " dynamic quantization is not supported.")

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        # NOTE(review): every ModelWeightParameter below passes input_dim=1 /
        # output_dim=2 although the packed (input) dimension is the last one
        # (index 2) — confirm these indices are what the loader expects.
        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w2_weight", w2_weight)

        # Block scales for GEMM 1
        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # one block scale per group_size elements of the input dim
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        # Block scales for GEMM 2
        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # one block scale per group_size elements of the input dim
                intermediate_size_per_partition //
                self.quant_config.group_size,
                dtype=weight_scale_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # NOTE(review): extra_weight_attrs is updated here and again below,
        # but nothing in this method attaches it to a parameter — confirm
        # whether the "quant_method" hint is consumed anywhere.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})

        # Per-tensor weight scales; w13 keeps two values per expert, which
        # process_weights_after_loading later collapses to one.
        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader)
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader)
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

        # Per-tensor input (activation) scales, same layout as the
        # per-tensor weight scales above.
        w13_input_scale = PerTensorScaleParameter(data=torch.empty(
            num_experts, 2, dtype=torch.float32),
                                                  weight_loader=weight_loader)
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(data=torch.empty(
            num_experts, dtype=torch.float32),
                                                 weight_loader=weight_loader)
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def prepare_static_weight_layouts_for_trtllm_moe(
        self,
        gemm1_weights: torch.Tensor,
        gemm2_weights: torch.Tensor,
        gemm1_scales_linear_fp4_bytes: torch.Tensor,
        gemm2_scales_linear_fp4_bytes: torch.Tensor,
        hidden_size: int,
        intermediate_size: int,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Prepare quantized weights for kernel (done offline with weights).

        For every expert, the GEMM-1 weights and block scales are first
        passed through ``reorder_rows_for_gated_act_gemm``; then all four
        tensors are passed through ``shuffle_matrix_a`` /
        ``shuffle_matrix_sf_a`` with ``epilogue_tile_m`` and stacked again.

        Args:
            gemm1_weights: packed FP4 weights, reshaped here to
                ``(num_experts, 2 * intermediate_size, hidden_size // 2)``.
            gemm2_weights: packed FP4 weights, reshaped here to
                ``(num_experts, hidden_size, intermediate_size // 2)``.
            gemm1_scales_linear_fp4_bytes: block scales for GEMM 1, reshaped
                to ``(num_experts, 2 * intermediate_size, hidden_size // 16)``.
            gemm2_scales_linear_fp4_bytes: block scales for GEMM 2, reshaped
                to ``(num_experts, hidden_size, intermediate_size // 16)``.
            hidden_size: hidden size used for the reshapes above.
            intermediate_size: intermediate size used for the reshapes above.
            num_experts: number of experts (leading dimension).

        Returns:
            ``(gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales)``
            after shuffling. The scales are viewed as float8_e4m3fn and
            reshaped to their pre-shuffle shapes; the weights are the
            stacked outputs of ``shuffle_matrix_a``.
        """
        from flashinfer import (reorder_rows_for_gated_act_gemm,
                                shuffle_matrix_a, shuffle_matrix_sf_a)
        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals

        # NOTE(review): the divisor 16 below is hard-coded; presumably it is
        # the NVFP4 block size and should equal quant_config.group_size —
        # confirm.
        # Convert quantized weights to proper formats
        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, 2 * intermediate_size, hidden_size // 2)  # packed fp4
        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn).reshape(num_experts, 2 * intermediate_size,
                                         hidden_size //
                                         16)  # fp8 scaling factors

        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, intermediate_size // 2)  # packed fp4
        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                         intermediate_size //
                                         16)  # fp8 scaling factors

        # Reorder rows of W1 and scales for fused gated activation
        # (only GEMM 1 is reordered; each expert slice is cloned first).
        gemm1_weights_fp4_interleaved = []
        gemm1_scales_fp4_interleaved = []
        for i in range(num_experts):
            gemm1_weights_fp4_interleaved.append(
                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
            gemm1_scales_fp4_interleaved.append(
                reorder_rows_for_gated_act_gemm(
                    gemm1_scales_linear_fp4[i].clone()))

        # Stack weights and scales for all experts
        gemm1_weights_fp4_interleaved = torch.stack(
            gemm1_weights_fp4_interleaved).reshape(num_experts,
                                                   2 * intermediate_size,
                                                   hidden_size // 2)
        gemm1_scales_fp4_interleaved = torch.stack(
            gemm1_scales_fp4_interleaved).reshape(num_experts,
                                                  2 * intermediate_size,
                                                  hidden_size // 16)

        # Shuffle weights and scaling factors for transposed mma output
        # (the shuffle helpers are fed uint8 views of each expert slice).
        gemm1_weights_fp4_shuffled = []
        gemm1_scales_fp4_shuffled = []
        gemm2_weights_fp4_shuffled = []
        gemm2_scales_fp4_shuffled = []
        for i in range(num_experts):
            gemm1_weights_fp4_shuffled.append(
                shuffle_matrix_a(
                    gemm1_weights_fp4_interleaved[i].view(torch.uint8),
                    epilogue_tile_m))
            gemm1_scales_fp4_shuffled.append(
                shuffle_matrix_sf_a(
                    gemm1_scales_fp4_interleaved[i].view(torch.uint8),
                    epilogue_tile_m))

            gemm2_weights_fp4_shuffled.append(
                shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
                                 epilogue_tile_m))
            gemm2_scales_fp4_shuffled.append(
                shuffle_matrix_sf_a(
                    gemm2_scales_linear_fp4[i].view(torch.uint8),
                    epilogue_tile_m))

        # Stack weights for all experts
        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
        gemm1_scales_fp4_shuffled = (
            torch.stack(gemm1_scales_fp4_shuffled).view(
                torch.float8_e4m3fn).reshape(num_experts,
                                             2 * intermediate_size,
                                             hidden_size // 16))

        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
        gemm2_scales_fp4_shuffled = (
            torch.stack(gemm2_scales_fp4_shuffled).view(
                torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                             intermediate_size // 16))
        return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded parameters into the form the kernels consume.

        Steps, in order:
          1. With FlashInfer, pass ``w13_weight`` and its block scales
             through ``reorder_w1w3_to_w3w1``.
          2. Collapse ``w13_weight_scale_2`` to its first column, warning
             once if the two columns differ.
          3. Derive ``g1_alphas`` / ``g2_alphas`` (input scale times
             per-tensor weight scale) and the inverted input scales
             ``w13_input_scale_quant`` / ``w2_input_scale_quant``.
          4. TensorRT-LLM backend: build the shuffled weights/scales and
             ``g1_scale_c``, then delete the original weights and block
             scales. Otherwise: build the swizzled block scales.
          5. Marlin: repack for Marlin and delete the attributes from
             steps 3-4 that the Marlin path in ``apply`` does not read.

        The statement order matters: later steps read attributes that
        earlier steps replace or create.
        """
        # GEMM 1 processing
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if self.allow_flashinfer:
            # Applies to both FlashInfer backends (CUTLASS and TRT-LLM).
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2)

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale,
                                           requires_grad=False)

        # Common processing for w13_weight_scale_2
        if not torch.allclose(layer.w13_weight_scale_2[:, 0],
                              layer.w13_weight_scale_2[:, 1]):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected.")

        # Only column 0 is kept, even when the columns differ.
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
                                             requires_grad=False)

        # Common processing for input scales and alphas
        # (per expert, the larger of the two w13 input scales is used).
        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
            torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False)

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False)

        # GEMM 2 processing
        layer.g2_alphas = Parameter(
            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False)

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)

        # TensorRT-LLM specific processing
        if self.allow_flashinfer and \
            self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            # Prepare static weights for TRT-LLM kernel
            (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
             gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
             ) = self.prepare_static_weight_layouts_for_trtllm_moe(
                 layer.w13_weight,
                 layer.w2_weight,
                 layer.w13_weight_scale,
                 layer.w2_weight_scale,
                 layer.w2_weight.size(-2),  # hidden_size
                 layer.w13_weight.size(-2) // 2,  # intermediate_size
                 layer.w13_weight.size(0),  # num_experts
             )

            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False)
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False)
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False)
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False)

            # Additional parameter needed for TRT-LLM
            # (inverted GEMM-2 input scale times g1_alphas; apply() passes
            # it as output1_scale_scalar).
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(
                    torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        else:
            # Non-TRT-LLM processing (Cutlass or non-flashinfer)
            # NOTE(review): the assert messages below say dim(1) while the
            # checks are on shape[2] (the last dimension).
            assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
                "Expected weight_scale.dim(1) to be divisible by 16")
            assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
                "Weight Blockscale must be represented as FP8-E4M3")
            w13_blockscale_swizzled = swizzle_blockscale(
                layer.w13_weight_scale)
            layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
                                                      requires_grad=False)

            assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
                "Expected weight_scale.dim(1) to be divisible by 16")
            assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
                "Weight Blockscale must be represented as FP8-E4M3")
            w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
            layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
                                                     requires_grad=False)
            layer.w2_weight = Parameter(layer.w2_weight.data,
                                        requires_grad=False)

        if self.use_marlin:
            prepare_moe_fp4_layer_for_marlin(layer)
            # The Marlin branch of apply() reads none of the attributes
            # below. The *_blockscale_swizzled ones exist only if the
            # non-TRT-LLM branch above ran.
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
            del layer.w13_blockscale_swizzled
            del layer.w2_blockscale_swizzled

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ):
        """Run the NVFP4 MoE forward pass on ``x``.

        Kernel selection, first match wins:
          1. FlashInfer TensorRT-LLM backend: quantize ``x`` to FP4 and call
             ``trtllm_fp4_block_scale_moe``, which receives the router
             logits directly (``FusedMoE.select_experts`` is not called, and
             ``renormalize``, ``use_grouped_topk``, ``scoring_func``,
             ``expert_map`` and ``apply_router_weight_on_input`` are not
             forwarded).
          2. Marlin: ``torch.ops.vllm.fused_marlin_moe``.
          3. No modular kernel installed: ``cutlass_moe_fp4``.
          4. Otherwise: ``flashinfer_fp4_cutlass_moe_forward`` with
             ``self.fused_experts``.

        ``expert_load_view``, ``logical_to_physical_map`` and
        ``logical_replica_count`` are accepted but not read.

        Raises:
            NotImplementedError: if ``enable_eplb`` is true.
            AssertionError: if ``activation`` is not ``"silu"``.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
        assert activation == "silu", "Only SiLU activation is supported."

        if self.allow_flashinfer and \
            self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            import flashinfer

            from vllm.model_executor.models.llama4 import Llama4MoE

            # Quantize the activations with the inverted GEMM-1 input scale.
            a1_gscale = layer.w13_input_scale_quant
            (hidden_states_fp4,
             hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
                 x,
                 a1_gscale,
                 is_sf_swizzled_layout=False,
             )
            # Llama4 routing is detected by identity of the routing
            # function; everything else uses the DeepSeekV3 routing type.
            use_llama4_routing = \
                custom_routing_function is Llama4MoE.custom_routing_function
            routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
            if use_llama4_routing:
                routing_method_type = flashinfer.RoutingMethodType.Llama4
            # Router logits are cast to float32 except for Llama4 routing.
            out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
                routing_logits=router_logits
                if use_llama4_routing else router_logits.to(torch.float32),
                routing_bias=e_score_correction_bias,
                hidden_states=hidden_states_fp4,
                hidden_states_scale=hidden_states_scale_linear_fp4.view(
                    torch.float8_e4m3fn).flatten(),
                gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
                gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn),
                gemm1_bias=None,
                gemm1_alpha=None,
                gemm1_beta=None,
                gemm1_clamp_limit=None,
                gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
                gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn),
                gemm2_bias=None,
                output1_scale_scalar=layer.g1_scale_c.data,
                output1_scale_gate_scalar=layer.g1_alphas.data,
                output2_scale_scalar=layer.g2_alphas.data,
                num_experts=global_num_experts,
                top_k=top_k,
                n_group=num_expert_group,
                topk_group=topk_group,
                intermediate_size=layer.intermediate_size_per_partition,
                local_expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                routed_scaling_factor=None,
                tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
                                                     layer.local_num_experts),
                routing_method_type=routing_method_type,
                do_finalize=True,
            )[0]
            return out

        # All remaining paths route tokens on the vLLM side first.
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        if self.use_marlin:
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        if self.fused_experts is None:
            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
            # only (no EP).
            from vllm.model_executor.layers.fused_moe.cutlass_moe import (
                cutlass_moe_fp4)
            out = cutlass_moe_fp4(
                a=x,
                w1_fp4=layer.w13_weight,
                w2_fp4=layer.w2_weight,
                w1_blockscale=layer.w13_blockscale_swizzled,
                w2_blockscale=layer.w2_blockscale_swizzled,
                g1_alphas=layer.g1_alphas,
                g2_alphas=layer.g2_alphas,
                a1_gscale=layer.w13_input_scale_quant,
                a2_gscale=layer.w2_input_scale_quant,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                m=x.shape[0],
                # w2's last dim holds two FP4 values per byte, hence * 2.
                n=layer.w2_weight.shape[2] * 2,
                k=x.shape[1],
                e=layer.w13_weight.shape[0],
                device=x.device,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input)
        else:
            # A modular kernel is only installed by maybe_swap_experts_impl,
            # which builds the FlashInfer CUTLASS kernel.
            assert self.allow_flashinfer and \
               self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
            out = flashinfer_fp4_cutlass_moe_forward(
                self.fused_experts,
                layer,
                x,
                topk_weights,
                topk_ids,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        return out

allow_flashinfer instance-attribute

allow_flashinfer = allow_flashinfer

cutlass_nvfp4_supported instance-attribute

cutlass_nvfp4_supported = cutlass_supported

flashinfer_moe_backend instance-attribute

flashinfer_moe_backend = None

fused_experts instance-attribute

fused_experts: Optional[FusedMoEModularKernel] = None

quant_config instance-attribute

quant_config = quant_config

use_marlin instance-attribute

use_marlin = use_marlin

__init__

__init__(quant_config: ModelOptNvFp4Config) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
    """Probe NVFP4 MoE kernel support and choose the FlashInfer backend.

    The FlashInfer backend is read from ``envs.VLLM_FLASHINFER_MOE_BACKEND``
    only when FlashInfer is allowed: ``"throughput"`` selects the CUTLASS
    kernels and ``"latency"`` selects the TensorRT-LLM kernels.

    Raises:
        ValueError: if FlashInfer is allowed and the configured backend is
            neither ``"throughput"`` nor ``"latency"``.
    """
    self.quant_config = quant_config
    from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
        detect_nvfp4_moe_support)
    support = detect_nvfp4_moe_support(self.__class__.__name__)
    self.cutlass_nvfp4_supported = support.cutlass_supported
    self.allow_flashinfer = support.allow_flashinfer
    self.use_marlin = support.use_marlin
    # Remains None unless FlashInfer is allowed.
    self.flashinfer_moe_backend = None

    if self.allow_flashinfer:
        requested = envs.VLLM_FLASHINFER_MOE_BACKEND
        # Reject unknown values up front, then pick between the two
        # supported backends.
        if requested not in ("throughput", "latency"):
            allowed_backends = ["throughput", "latency"]
            raise ValueError(
                f"Unknown flashinfer moe backend: {requested}"
                f" expected one of {allowed_backends}")
        if requested == "throughput":
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
            logger.info_once("Using FlashInfer CUTLASS kernels for "
                             "ModelOptNvFp4FusedMoE.")
        else:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
            logger.info_once("Using FlashInfer TensorRT-LLM kernels for "
                             "ModelOptNvFp4FusedMoE.")

    # No modular kernel until one is installed after construction.
    self.fused_experts: Optional[
        mk.FusedMoEModularKernel] = None  # type: ignore[assignment]

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
):
    """Run the NVFP4 MoE forward pass on ``x``.

    Kernel selection, first match wins:
      1. FlashInfer TensorRT-LLM backend: ``x`` is quantized to FP4 and the
         router logits are handed to ``trtllm_fp4_block_scale_moe``.
      2. Marlin: ``torch.ops.vllm.fused_marlin_moe``.
      3. A modular kernel is installed: ``flashinfer_fp4_cutlass_moe_forward``.
      4. Otherwise: ``cutlass_moe_fp4``.

    Paths 2-4 first pick experts with ``FusedMoE.select_experts``.
    ``expert_load_view``, ``logical_to_physical_map`` and
    ``logical_replica_count`` are accepted but not read.

    Raises:
        NotImplementedError: if ``enable_eplb`` is true.
        AssertionError: if ``activation`` is not ``"silu"``.
    """
    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
    assert activation == "silu", "Only SiLU activation is supported."

    trtllm_selected = (
        self.allow_flashinfer and
        self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM)
    if trtllm_selected:
        import flashinfer

        from vllm.model_executor.models.llama4 import Llama4MoE

        # Quantize activations with the inverted GEMM-1 input scale.
        quantized_x, quantized_x_scale = flashinfer.fp4_quantize(
            x,
            layer.w13_input_scale_quant,
            is_sf_swizzled_layout=False,
        )

        # Llama4 routing keeps the logits as-is; every other routing
        # function uses the DeepSeekV3 type with float32 logits.
        if custom_routing_function is Llama4MoE.custom_routing_function:
            routing_method_type = flashinfer.RoutingMethodType.Llama4
            routing_logits = router_logits
        else:
            routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
            routing_logits = router_logits.to(torch.float32)

        kernel_outputs = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
            routing_logits=routing_logits,
            routing_bias=e_score_correction_bias,
            hidden_states=quantized_x,
            hidden_states_scale=quantized_x_scale.view(
                torch.float8_e4m3fn).flatten(),
            gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
            gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
                torch.float8_e4m3fn),
            gemm1_bias=None,
            gemm1_alpha=None,
            gemm1_beta=None,
            gemm1_clamp_limit=None,
            gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
            gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
                torch.float8_e4m3fn),
            gemm2_bias=None,
            output1_scale_scalar=layer.g1_scale_c.data,
            output1_scale_gate_scalar=layer.g1_alphas.data,
            output2_scale_scalar=layer.g2_alphas.data,
            num_experts=global_num_experts,
            top_k=top_k,
            n_group=num_expert_group,
            topk_group=topk_group,
            intermediate_size=layer.intermediate_size_per_partition,
            local_expert_offset=layer.ep_rank * layer.local_num_experts,
            local_num_experts=layer.local_num_experts,
            routed_scaling_factor=None,
            tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
                                                 layer.local_num_experts),
            routing_method_type=routing_method_type,
            do_finalize=True,
        )
        return kernel_outputs[0]

    # Every remaining path selects experts on the vLLM side first.
    topk_weights, topk_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias)

    if self.use_marlin:
        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            global_scale1=layer.w13_weight_scale_2,
            global_scale2=layer.w2_weight_scale_2,
            quant_type_id=scalar_types.float4_e2m1f.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map)

    if self.fused_experts is not None:
        # An installed modular kernel implies the FlashInfer CUTLASS backend.
        assert self.allow_flashinfer and \
           self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        return flashinfer_fp4_cutlass_moe_forward(
            self.fused_experts,
            layer,
            x,
            topk_weights,
            topk_ids,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

    # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
    # only (no EP).
    from vllm.model_executor.layers.fused_moe.cutlass_moe import (
        cutlass_moe_fp4)
    return cutlass_moe_fp4(
        a=x,
        w1_fp4=layer.w13_weight,
        w2_fp4=layer.w2_weight,
        w1_blockscale=layer.w13_blockscale_swizzled,
        w2_blockscale=layer.w2_blockscale_swizzled,
        g1_alphas=layer.g1_alphas,
        g2_alphas=layer.g2_alphas,
        a1_gscale=layer.w13_input_scale_quant,
        a2_gscale=layer.w2_input_scale_quant,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        m=x.shape[0],
        n=layer.w2_weight.shape[2] * 2,
        k=x.shape[1],
        e=layer.w13_weight.shape[0],
        device=x.device,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input)

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Allocate and register the NVFP4 MoE parameters on ``layer``.

    Registers, in this order: ``w13_weight``, ``w2_weight``,
    ``w13_weight_scale``, ``w2_weight_scale``, ``w13_weight_scale_2``,
    ``w2_weight_scale_2``, ``w13_input_scale`` and ``w2_input_scale``.
    All tensors are created uninitialised with ``torch.empty``.

    Args:
        layer: Module that receives the parameters; ``num_experts``,
            ``params_dtype`` and ``quant_config`` are also stored on it.
        num_experts: Size of the leading dimension of every parameter.
        hidden_size: Hidden dimension used to size the tensors.
        intermediate_size_per_partition: Intermediate dimension used to
            size the tensors.
        params_dtype: Only recorded on ``layer``; the allocated tensors
            use fixed dtypes (uint8, float8_e4m3fn, float32).
        **extra_weight_attrs: Only the ``"weight_loader"`` entry is read.

    Raises:
        ValueError: If the quant config does not report an
            NVFP4-serialized checkpoint.
    """
    if not self.quant_config.is_checkpoint_nvfp4_serialized:
        raise ValueError("NVFP4 quantization was selected, "
                         " dynamic quantization is not supported.")

    layer.num_experts = num_experts
    layer.params_dtype = params_dtype
    layer.quant_config = self.quant_config

    packed_dtype = torch.uint8
    blockscale_dtype = torch.float8_e4m3fn
    loader = extra_weight_attrs.get("weight_loader")

    def register_expert_matrix(name, rows, cols, dtype):
        # One [num_experts, rows, cols] tensor, tagged input_dim=1 and
        # output_dim=2 exactly as every matrix-shaped parameter here is.
        param = ModelWeightParameter(
            data=torch.empty(num_experts, rows, cols, dtype=dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=loader)
        layer.register_parameter(name, param)

    def register_fp32_scale(name, *dims):
        # float32 per-tensor scale of shape ``dims``.
        param = PerTensorScaleParameter(
            data=torch.empty(*dims, dtype=torch.float32),
            weight_loader=loader)
        layer.register_parameter(name, param)

    # GEMM 1 / GEMM 2 packed weights: two fp4 values share one uint8, so
    # the last dimension is halved.
    register_expert_matrix("w13_weight",
                           2 * intermediate_size_per_partition,
                           hidden_size // 2, packed_dtype)
    register_expert_matrix("w2_weight", hidden_size,
                           intermediate_size_per_partition // 2,
                           packed_dtype)

    # Block scales: the last dimension is divided by the config's
    # group_size.
    register_expert_matrix("w13_weight_scale",
                           2 * intermediate_size_per_partition,
                           hidden_size // self.quant_config.group_size,
                           blockscale_dtype)
    register_expert_matrix(
        "w2_weight_scale", hidden_size,
        intermediate_size_per_partition // self.quant_config.group_size,
        blockscale_dtype)

    # NOTE(review): extra_weight_attrs is not read again in this method,
    # so the two "quant_method" updates below have no visible consumer
    # here - confirm whether they are still needed.
    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})

    # Second-level weight scales; w13 has two entries per expert.
    register_fp32_scale("w13_weight_scale_2", num_experts, 2)
    register_fp32_scale("w2_weight_scale_2", num_experts)

    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

    # Activation (input) scales, same layout as the weight_scale_2 pair.
    register_fp32_scale("w13_input_scale", num_experts, 2)
    register_fp32_scale("w2_input_scale", num_experts)

maybe_swap_experts_impl

maybe_swap_experts_impl(
    moe_parallel_config: FusedMoEParallelConfig,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def maybe_swap_experts_impl(
    self,
    moe_parallel_config: FusedMoEParallelConfig,
):
    """Install the FlashInfer FP4 CUTLASS MoE kernel when it is allowed.

    When ``self.allow_flashinfer`` is falsy nothing happens and
    ``self.fused_experts`` is left as it was. Otherwise it is replaced by
    the kernel built from ``moe_parallel_config``.
    """
    if self.allow_flashinfer:
        self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
            moe_parallel_config)

prepare_static_weight_layouts_for_trtllm_moe

prepare_static_weight_layouts_for_trtllm_moe(
    gemm1_weights: Tensor,
    gemm2_weights: Tensor,
    gemm1_scales_linear_fp4_bytes: Tensor,
    gemm2_scales_linear_fp4_bytes: Tensor,
    hidden_size: int,
    intermediate_size: int,
    num_experts: int,
) -> tuple[Tensor, Tensor, Tensor, Tensor]

Prepare quantized weights for kernel (done offline with weights).

Source code in vllm/model_executor/layers/quantization/modelopt.py
def prepare_static_weight_layouts_for_trtllm_moe(
    self,
    gemm1_weights: torch.Tensor,
    gemm2_weights: torch.Tensor,
    gemm1_scales_linear_fp4_bytes: torch.Tensor,
    gemm2_scales_linear_fp4_bytes: torch.Tensor,
    hidden_size: int,
    intermediate_size: int,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Prepare quantized weights for kernel (done offline with weights).

    Pipeline, per expert:
      1. Reinterpret weights and scales as float8_e4m3fn and reshape them
         to ``[num_experts, rows, cols]``.
      2. Reorder the GEMM 1 rows (weights and scales) with
         ``reorder_rows_for_gated_act_gemm``.
      3. Shuffle weights with ``shuffle_matrix_a`` and scales with
         ``shuffle_matrix_sf_a`` using ``epilogue_tile_m``.
      4. Stack the per-expert results back into single tensors.

    Args:
        gemm1_weights: GEMM 1 packed weights; reshaped to
            ``[num_experts, 2 * intermediate_size, hidden_size // 2]``.
        gemm2_weights: GEMM 2 packed weights; reshaped to
            ``[num_experts, hidden_size, intermediate_size // 2]``.
        gemm1_scales_linear_fp4_bytes: GEMM 1 scales; reshaped to
            ``[num_experts, 2 * intermediate_size, hidden_size // 16]``.
        gemm2_scales_linear_fp4_bytes: GEMM 2 scales; reshaped to
            ``[num_experts, hidden_size, intermediate_size // 16]``.
        hidden_size: Hidden dimension used for the reshapes.
        intermediate_size: Intermediate dimension used for the reshapes.
        num_experts: Number of experts (leading dimension).

    Returns:
        ``(gemm1_weights_shuffled, gemm1_scales_shuffled,
        gemm2_weights_shuffled, gemm2_scales_shuffled)``. The scale
        tensors are viewed as float8_e4m3fn and reshaped to the same
        shapes as in step 1; the weight tensors are returned exactly as
        stacked from ``shuffle_matrix_a``.
    """
    # Imported lazily so flashinfer is only required when this path runs.
    from flashinfer import (reorder_rows_for_gated_act_gemm,
                            shuffle_matrix_a, shuffle_matrix_sf_a)
    epilogue_tile_m = 128  # FIXME: this depends on the kernel internals

    # Convert quantized weights to proper formats
    # NOTE(review): the literal 16 below is presumably the NVFP4 block
    # (group) size - confirm it always matches quant_config.group_size.
    gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
        num_experts, 2 * intermediate_size, hidden_size // 2)  # packed fp4
    gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
        torch.float8_e4m3fn).reshape(num_experts, 2 * intermediate_size,
                                     hidden_size //
                                     16)  # fp8 scaling factors

    gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
        num_experts, hidden_size, intermediate_size // 2)  # packed fp4
    gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
        torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                     intermediate_size //
                                     16)  # fp8 scaling factors

    # Reorder rows of W1 and scales for fused gated activation
    # Each expert slice is cloned before being handed to the helper, so
    # the helper never receives a view of the stacked input.
    gemm1_weights_fp4_interleaved = []
    gemm1_scales_fp4_interleaved = []
    for i in range(num_experts):
        gemm1_weights_fp4_interleaved.append(
            reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
        gemm1_scales_fp4_interleaved.append(
            reorder_rows_for_gated_act_gemm(
                gemm1_scales_linear_fp4[i].clone()))

    # Stack weights and scales for all experts
    # The lists are rebound to stacked tensors with the pre-reorder shapes.
    gemm1_weights_fp4_interleaved = torch.stack(
        gemm1_weights_fp4_interleaved).reshape(num_experts,
                                               2 * intermediate_size,
                                               hidden_size // 2)
    gemm1_scales_fp4_interleaved = torch.stack(
        gemm1_scales_fp4_interleaved).reshape(num_experts,
                                              2 * intermediate_size,
                                              hidden_size // 16)

    # Shuffle weights and scaling factors for transposed mma output
    # The shuffle helpers are fed uint8 views of the per-expert tensors.
    # GEMM 2 uses the un-reordered tensors (only GEMM 1 was interleaved).
    gemm1_weights_fp4_shuffled = []
    gemm1_scales_fp4_shuffled = []
    gemm2_weights_fp4_shuffled = []
    gemm2_scales_fp4_shuffled = []
    for i in range(num_experts):
        gemm1_weights_fp4_shuffled.append(
            shuffle_matrix_a(
                gemm1_weights_fp4_interleaved[i].view(torch.uint8),
                epilogue_tile_m))
        gemm1_scales_fp4_shuffled.append(
            shuffle_matrix_sf_a(
                gemm1_scales_fp4_interleaved[i].view(torch.uint8),
                epilogue_tile_m))

        gemm2_weights_fp4_shuffled.append(
            shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
                             epilogue_tile_m))
        gemm2_scales_fp4_shuffled.append(
            shuffle_matrix_sf_a(
                gemm2_scales_linear_fp4[i].view(torch.uint8),
                epilogue_tile_m))

    # Stack weights for all experts
    # Scales are viewed back as float8_e4m3fn and reshaped; weights are
    # left in whatever dtype/shape the shuffle helper produced.
    gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
    gemm1_scales_fp4_shuffled = (
        torch.stack(gemm1_scales_fp4_shuffled).view(
            torch.float8_e4m3fn).reshape(num_experts,
                                         2 * intermediate_size,
                                         hidden_size // 16))

    gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
    gemm2_scales_fp4_shuffled = (
        torch.stack(gemm2_scales_fp4_shuffled).view(
            torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                         intermediate_size // 16))
    return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
            gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Post-process the loaded NVFP4 MoE tensors on ``layer`` in place.

    Steps, in order:
      1. Re-wrap ``w13_weight`` / ``w13_weight_scale`` as plain
         ``Parameter`` objects (after ``reorder_w1w3_to_w3w1`` when
         ``self.allow_flashinfer`` is set).
      2. Reduce ``w13_weight_scale_2`` to its first column, warning once
         if the two columns are not close.
      3. Derive ``g1_alphas`` / ``g2_alphas`` and the reciprocal input
         scales ``w13_input_scale_quant`` / ``w2_input_scale_quant``.
      4. Either build the TRT-LLM shuffled layouts and delete the original
         weight/scale attributes, or swizzle the block scales.
      5. If ``self.use_marlin`` is set, repack for Marlin and delete the
         tensors created in steps 3-4.

    Args:
        layer: Module carrying the parameters registered by
            ``create_weights``; it is mutated in place.
    """
    # GEMM 1 processing
    gemm1_weight = layer.w13_weight.data
    gemm1_weight_scale = layer.w13_weight_scale.data

    # The reorder is applied along dim=-2 of both weight and scale.
    if self.allow_flashinfer:
        gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
            gemm1_weight, gemm1_weight_scale, dim=-2)

    layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
    layer.w13_weight_scale = Parameter(gemm1_weight_scale,
                                       requires_grad=False)

    # Common processing for w13_weight_scale_2
    # Only column 0 is kept below, so a mismatch between the two columns
    # is reported (once) rather than raised.
    if not torch.allclose(layer.w13_weight_scale_2[:, 0],
                          layer.w13_weight_scale_2[:, 1]):
        logger.warning_once(
            "w1_weight_scale_2 must match w3_weight_scale_2. "
            "Accuracy may be affected.")

    w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
    layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
                                         requires_grad=False)

    # Common processing for input scales and alphas
    # The two per-expert w13 input scales are reduced with max over dim 1.
    w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
        torch.float32)
    layer.g1_alphas = Parameter(
        (w13_input_scale * w13_weight_scale_2).to(torch.float32),
        requires_grad=False)

    # This is for quantization, so we need to invert it.
    layer.w13_input_scale_quant = Parameter(
        (1 / w13_input_scale).to(torch.float32), requires_grad=False)

    # GEMM 2 processing
    layer.g2_alphas = Parameter(
        (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
        requires_grad=False)

    # This is for quantization, so we need to invert it.
    layer.w2_input_scale_quant = Parameter(
        (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)

    # TensorRT-LLM specific processing
    if self.allow_flashinfer and \
        self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
        # Prepare static weights for TRT-LLM kernel
        # Sizes are recovered from the parameter shapes rather than stored
        # separately.
        (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
         gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
         ) = self.prepare_static_weight_layouts_for_trtllm_moe(
             layer.w13_weight,
             layer.w2_weight,
             layer.w13_weight_scale,
             layer.w2_weight_scale,
             layer.w2_weight.size(-2),  # hidden_size
             layer.w13_weight.size(-2) // 2,  # intermediate_size
             layer.w13_weight.size(0),  # num_experts
         )

        layer.gemm1_weights_fp4_shuffled = Parameter(
            gemm1_weights_fp4_shuffled, requires_grad=False)
        layer.gemm2_weights_fp4_shuffled = Parameter(
            gemm2_weights_fp4_shuffled, requires_grad=False)
        layer.gemm1_scales_fp4_shuffled = Parameter(
            gemm1_scales_fp4_shuffled, requires_grad=False)
        layer.gemm2_scales_fp4_shuffled = Parameter(
            gemm2_scales_fp4_shuffled, requires_grad=False)

        # Additional parameter needed for TRT-LLM
        layer.g1_scale_c = Parameter(
            (layer.w2_input_scale_quant * layer.g1_alphas).to(
                torch.float32),
            requires_grad=False,
        )

        # Clean up weights that won't be used by TRT-LLM
        del layer.w2_weight
        del layer.w2_weight_scale
        del layer.w13_weight
        del layer.w13_weight_scale
    else:
        # Non-TRT-LLM processing (Cutlass or non-flashinfer)
        # NOTE(review): the checks below test shape[2] while the messages
        # say dim(1) - confirm which dimension the message should name.
        assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
            "Expected weight_scale.dim(1) to be divisible by 16")
        assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Blockscale must be represented as FP8-E4M3")
        w13_blockscale_swizzled = swizzle_blockscale(
            layer.w13_weight_scale)
        layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
                                                  requires_grad=False)

        assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
            "Expected weight_scale.dim(1) to be divisible by 16")
        assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Blockscale must be represented as FP8-E4M3")
        w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
        layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
                                                 requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data,
                                    requires_grad=False)

    # NOTE(review): the deletions below include the *_blockscale_swizzled
    # attributes, which are only created in the non-TRT-LLM branch above;
    # this presumably relies on use_marlin never being combined with the
    # TRT-LLM backend - confirm.
    if self.use_marlin:
        prepare_moe_fp4_layer_for_marlin(layer)
        del layer.g1_alphas
        del layer.g2_alphas
        del layer.w13_input_scale_quant
        del layer.w2_input_scale_quant
        del layer.w13_blockscale_swizzled
        del layer.w2_blockscale_swizzled

select_gemm_impl

select_gemm_impl(
    prepare_finalize, moe
) -> FusedMoEPermuteExpertsUnpermute
Source code in vllm/model_executor/layers/quantization/modelopt.py
def select_gemm_impl(self, prepare_finalize,
                     moe) -> mk.FusedMoEPermuteExpertsUnpermute:
    """Pick the NVFP4 experts implementation for the given MoE config.

    Both arguments must be non-None. ``prepare_finalize`` is only
    validated; the selection itself is delegated to
    ``select_nvfp4_gemm_impl`` with ``self.allow_flashinfer``, ``moe`` and
    the module logger.
    """
    assert not (moe is None or prepare_finalize is None)

    # Resolved at call time rather than at module import.
    from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (  # noqa: E501
        select_nvfp4_gemm_impl)

    experts_impl = select_nvfp4_gemm_impl(self.allow_flashinfer, moe,
                                          logger)
    return experts_impl

uses_weight_scale_2_pattern

uses_weight_scale_2_pattern() -> bool

FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.

Source code in vllm/model_executor/layers/quantization/modelopt.py
def uses_weight_scale_2_pattern(self) -> bool:
    """Report whether per-tensor weight scales are named 'weight_scale_2'.

    FP4 variants use that pattern, so this always returns True.
    """
    follows_weight_scale_2_naming = True
    return follows_weight_scale_2_naming

ModelOptNvFp4LinearMethod

Bases: LinearMethodBase

Linear method for Model Optimizer NVFP4. Supports loading NVFP4 checkpoints with the following structure:

input_scale: torch.float32, scalar; weight: NVFP4 (represented as bytes), shape [1, X, Y/2]; weight_scale: FP8-E4M3, shape [X, Y], i.e. the per-block scale; weight_scale_2: torch.float32, scalar. Args: quant_config: The ModelOpt quantization config.

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """NVFP4 linear method for Model Optimizer checkpoints.

    The checkpoint is expected to provide, per linear layer:

    * ``input_scale``: torch.float32, scalar.
    * ``weight``: NVFP4 values represented as bytes, shape [1, X, y/2].
    * ``weight_scale``: FP8-E4M3 per-block scale, shape [X, Y].
    * ``weight_scale_2``: torch.float32, scalar.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        """Record the config and choose between the CUTLASS and Marlin paths.

        Raises:
            ValueError: If neither ``cutlass_fp4_supported()`` nor
                ``is_fp4_marlin_supported()`` reports support.
        """
        self.quant_config = quant_config
        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
        self.use_marlin = False

        if self.cutlass_nvfp4_supported:
            return

        # Marlin is only consulted when CUTLASS FP4 is unavailable.
        if not is_fp4_marlin_supported():
            raise ValueError("Current platform does not support NVFP4"
                             " quantization. Please use Blackwell and"
                             " above.")
        self.use_marlin = True

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the NVFP4 linear parameters on ``layer``.

        Registers ``weight``, ``input_scale``, ``weight_scale_2`` and
        ``weight_scale`` (in that order) and records the partition sizes
        on ``layer``.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        # The full (unpartitioned) sizes are not used by this method.
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError("NVFP4 quantization was selected, "
                             " dynamic quantization is not supported.")

        out_features = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        if input_size_per_partition % 16 != 0:
            raise ValueError("Unsupported model when in features size is "
                             "not multiple of 16")

        # Block scales are FP8-E4M3 for serialized checkpoints. The
        # params_dtype fallback is unreachable after the guard above.
        if self.quant_config.is_checkpoint_nvfp4_serialized:
            blockscale_dtype = torch.float8_e4m3fn
        else:
            blockscale_dtype = params_dtype

        # Packed weight: two fp4 values per uint8 along the input dim.
        packed_weight = torch.empty(out_features,
                                    input_size_per_partition // 2,
                                    dtype=torch.uint8)
        layer.register_parameter(
            "weight",
            ModelWeightParameter(data=packed_weight,
                                 input_dim=1,
                                 output_dim=0,
                                 weight_loader=loader))

        # One float32 activation scale per logical output partition.
        num_partitions = len(output_partition_sizes)
        layer.register_parameter(
            "input_scale",
            PerTensorScaleParameter(data=torch.empty(num_partitions,
                                                     dtype=torch.float32),
                                    weight_loader=loader))

        # One float32 global weight scale per logical output partition.
        layer.register_parameter(
            "weight_scale_2",
            PerTensorScaleParameter(data=torch.empty(num_partitions,
                                                     dtype=torch.float32),
                                    weight_loader=loader))

        # Per-block weight scale: one entry per group_size input elements.
        block_scales = torch.empty(
            out_features,
            input_size_per_partition // self.quant_config.group_size,
            dtype=blockscale_dtype,
        )
        layer.register_parameter(
            "weight_scale",
            ModelWeightParameter(data=block_scales,
                                 input_dim=1,
                                 output_dim=0,
                                 weight_loader=loader))

    def process_weights_after_loading(self, layer: Module) -> None:
        """Fold the loaded scales and prepare ``layer`` for the GEMM kernel.

        ``input_scale`` and ``weight_scale_2`` are each reduced to their
        maximum; ``alpha`` is their product; the block scale is swizzled
        into ``weight_scale_swizzled``. When Marlin is in use the layer is
        repacked and the CUTLASS-only tensors are removed.
        """
        # Global scales: collapse the per-partition values with max().
        layer.input_scale = Parameter(
            layer.input_scale.max().to(torch.float32), requires_grad=False)
        layer.weight_scale_2 = Parameter(
            layer.weight_scale_2.max().to(torch.float32),
            requires_grad=False)

        layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                                requires_grad=False)

        # Swizzle the weight block scale (contracting dimension is the
        # input dimension; block_size = 16).
        assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Block scale must be represented as FP8-E4M3")
        layer.weight_scale_swizzled = Parameter(
            swizzle_blockscale(layer.weight_scale), requires_grad=False)
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        if not self.use_marlin:
            return

        prepare_fp4_layer_for_marlin(layer)
        # These tensors are not passed to the Marlin call in apply().
        del layer.alpha, layer.input_scale, layer.weight_scale_swizzled

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the NVFP4 linear layer on ``x``.

        Uses ``apply_fp4_marlin_linear`` when ``self.use_marlin`` is set.
        Otherwise it quantizes ``x`` with ``scaled_fp4_quant`` and calls
        ``cutlass_scaled_fp4_mm``, adding ``bias`` afterwards if given.

        Returns:
            On the CUTLASS path, the result viewed as
            ``[x.shape[0], layer.weight.shape[0]]``. On the Marlin path,
            whatever ``apply_fp4_marlin_linear`` returns.
        """
        if self.use_marlin:
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        result_dtype = x.dtype
        result_shape = [x.shape[0], layer.weight.shape[0]]

        # Quantize BF16/FP16 activations to FP4 plus an interleaved block
        # scale; the quantizer takes the reciprocal of input_scale.
        x_fp4, x_blockscale = scaled_fp4_quant(x, 1 / layer.input_scale)

        # Validate the dtypes of the quantized input, its block scale, the
        # weight and the weight block scale before calling the kernel.
        assert (x_fp4.dtype == torch.uint8)
        assert (layer.weight.dtype == torch.uint8)
        assert (x_blockscale.dtype == torch.float8_e4m3fn)
        assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
        assert (layer.alpha.dtype == torch.float32)

        out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
                                    layer.weight_scale_swizzled, layer.alpha,
                                    result_dtype)
        if bias is not None:
            out = out + bias
        return out.view(*result_shape)

cutlass_nvfp4_supported instance-attribute

cutlass_nvfp4_supported = cutlass_fp4_supported()

quant_config instance-attribute

quant_config = quant_config

use_marlin instance-attribute

use_marlin = False

__init__

__init__(quant_config: ModelOptNvFp4Config) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
    """Record the config and choose between the CUTLASS and Marlin paths.

    Raises:
        ValueError: If neither ``cutlass_fp4_supported()`` nor
            ``is_fp4_marlin_supported()`` reports support.
    """
    self.quant_config = quant_config
    self.cutlass_nvfp4_supported = cutlass_fp4_supported()
    self.use_marlin = False

    if self.cutlass_nvfp4_supported:
        return

    # Marlin is only consulted when CUTLASS FP4 is unavailable.
    if not is_fp4_marlin_supported():
        raise ValueError("Current platform does not support NVFP4"
                         " quantization. Please use Blackwell and"
                         " above.")
    self.use_marlin = True

apply

apply(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the NVFP4 linear layer on ``x``.

    Uses ``apply_fp4_marlin_linear`` when ``self.use_marlin`` is set.
    Otherwise it quantizes ``x`` with ``scaled_fp4_quant`` and calls
    ``cutlass_scaled_fp4_mm``, adding ``bias`` afterwards if given.

    Returns:
        On the CUTLASS path, the result viewed as
        ``[x.shape[0], layer.weight.shape[0]]``. On the Marlin path,
        whatever ``apply_fp4_marlin_linear`` returns.
    """
    if self.use_marlin:
        return apply_fp4_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_scale_2=layer.weight_scale_2,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias)

    result_dtype = x.dtype
    result_shape = [x.shape[0], layer.weight.shape[0]]

    # Quantize BF16/FP16 activations to FP4 plus an interleaved block
    # scale; the quantizer takes the reciprocal of input_scale.
    x_fp4, x_blockscale = scaled_fp4_quant(x, 1 / layer.input_scale)

    # Validate the dtypes of the quantized input, its block scale, the
    # weight and the weight block scale before calling the kernel.
    assert (x_fp4.dtype == torch.uint8)
    assert (layer.weight.dtype == torch.uint8)
    assert (x_blockscale.dtype == torch.float8_e4m3fn)
    assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
    assert (layer.alpha.dtype == torch.float32)

    out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
                                layer.weight_scale_swizzled, layer.alpha,
                                result_dtype)
    if bias is not None:
        out = out + bias
    return out.view(*result_shape)

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Allocate and register the NVFP4 linear parameters on ``layer``.

    Registers ``weight``, ``input_scale``, ``weight_scale_2`` and
    ``weight_scale`` (in that order) and records the partition sizes on
    ``layer``.

    Raises:
        ValueError: If the checkpoint is not NVFP4-serialized, or if
            ``input_size_per_partition`` is not a multiple of 16.
    """
    # The full (unpartitioned) sizes are not used by this method.
    del input_size, output_size
    if not self.quant_config.is_checkpoint_nvfp4_serialized:
        raise ValueError("NVFP4 quantization was selected, "
                         " dynamic quantization is not supported.")

    out_features = sum(output_partition_sizes)
    loader = extra_weight_attrs.get("weight_loader")
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = out_features

    if input_size_per_partition % 16 != 0:
        raise ValueError("Unsupported model when in features size is "
                         "not multiple of 16")

    # Block scales are FP8-E4M3 for serialized checkpoints. The
    # params_dtype fallback is unreachable after the guard above.
    if self.quant_config.is_checkpoint_nvfp4_serialized:
        blockscale_dtype = torch.float8_e4m3fn
    else:
        blockscale_dtype = params_dtype

    # Packed weight: two fp4 values per uint8 along the input dimension.
    packed_weight = torch.empty(out_features,
                                input_size_per_partition // 2,
                                dtype=torch.uint8)
    layer.register_parameter(
        "weight",
        ModelWeightParameter(data=packed_weight,
                             input_dim=1,
                             output_dim=0,
                             weight_loader=loader))

    # One float32 activation scale per logical output partition.
    num_partitions = len(output_partition_sizes)
    layer.register_parameter(
        "input_scale",
        PerTensorScaleParameter(data=torch.empty(num_partitions,
                                                 dtype=torch.float32),
                                weight_loader=loader))

    # One float32 global weight scale per logical output partition.
    layer.register_parameter(
        "weight_scale_2",
        PerTensorScaleParameter(data=torch.empty(num_partitions,
                                                 dtype=torch.float32),
                                weight_loader=loader))

    # Per-block weight scale: one entry per group_size input elements.
    block_scales = torch.empty(
        out_features,
        input_size_per_partition // self.quant_config.group_size,
        dtype=blockscale_dtype,
    )
    layer.register_parameter(
        "weight_scale",
        ModelWeightParameter(data=block_scales,
                             input_dim=1,
                             output_dim=0,
                             weight_loader=loader))

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: Module) -> None:

    # global scales:
    input_scale_2 = layer.input_scale.max().to(torch.float32)
    layer.input_scale = Parameter(input_scale_2, requires_grad=False)

    weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
    layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

    layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                            requires_grad=False)

    # Swizzle the weight blockscale.
    # contracting dimension is input dimension
    # block_size = 16;
    assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
        "Weight Block scale must be represented as FP8-E4M3")
    swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)

    layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                            requires_grad=False)
    layer.weight = Parameter(layer.weight.data, requires_grad=False)

    if self.use_marlin:
        prepare_fp4_layer_for_marlin(layer)
        del layer.alpha
        del layer.input_scale
        del layer.weight_scale_swizzled

_get_tile_tokens_dim

_get_tile_tokens_dim(
    num_tokens: int, top_k: int, num_experts: int
) -> int
Source code in vllm/model_executor/layers/quantization/modelopt.py
def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    """Choose the per-CTA token tile size for the MoE kernel.

    Estimates the tokens routed to each expert under a perfectly even
    distribution, pads that estimate with ``next_power_of_2`` and clamps
    the result to the 8-64 range supported by the kernel.
    """
    # Even-split estimate of tokens per expert (floor division).
    even_share = (num_tokens * top_k) // num_experts
    padded_share = next_power_of_2(even_share)

    # Clamp to [8, 64] tokens per CTA tile.
    if padded_share < 8:
        return 8
    if padded_share > 64:
        return 64
    return padded_share