def apply_flashinfer_per_tensor_scale_fp8(
    layer: torch.nn.Module,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    routing_bias: Optional[torch.Tensor],
    top_k: int,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
    global_num_experts: int,
    apply_router_weight_on_input: bool,
) -> torch.Tensor:
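    """Run the FlashInfer per-tensor-scale FP8 fused MoE kernel for `layer`.

    Routes `hidden_states` with `router_logits` and dispatches to the
    `flashinfer_fused_moe_per_tensor_scale_fp8` custom op. Only the
    Llama4 routing method is supported (enforced by the assertion below).
    """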
    # Import lazily so FlashInfer stays an optional dependency and the
    # Llama4 model module is only loaded when this path is actually used.
    from flashinfer.fused_moe import RoutingMethodType

    from vllm.model_executor.models.llama4 import Llama4MoE

    # This kernel path only understands Llama4-style routing.
    assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
        "FusedMoE flashinfer kernels are only supported for Llama4"

    return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
        routing_logits=router_logits,
        routing_bias=routing_bias,
        hidden_states=hidden_states,
        input_scale=layer.w13_input_scale,
        gemm1_weights=layer.w13_weight,
        gemm1_weights_scale=layer.w13_weight_scale,
        gemm2_weights=layer.w2_weight,
        gemm2_weights_scale=layer.w2_weight_scale,
        activation_scale=layer.w2_input_scale,
        num_experts=global_num_experts,
        top_k=top_k,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=layer.intermediate_size_per_partition,
        # Experts are sharded across EP ranks; this rank owns the contiguous
        # slice starting at ep_rank * local_num_experts.
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        use_routing_scales_on_input=apply_router_weight_on_input,
        routing_method_type=RoutingMethodType.Llama4,
    )