
[CUDA Programming Ideas] LinearQuant: How Block-Wise Quantized Matrix Multiplication Is Computed

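Quantized linear operator

  • Goal: store activations and weights in a smaller format (e.g. int8), use a small number of scale factors to restore them to near floating-point precision at compute time, and then perform the matrix multiplication.
  • Approach: split the input-feature dimension K and the output-channel dimension N into blocks, each with its own scale factor. At compute time, apply the block-wise scales to A and B, run the matmul, then add the bias and handle any dtype conversion.
  • Benefit: substantially lower memory and bandwidth use and faster inference, while staying close to the original precision.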
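The article takes the int8 values and their scale tables as given. As a minimal sketch of where such a scale table could come from, the NumPy snippet below quantizes a weight matrix with one scale per (block_n × block_k) tile. The symmetric absmax scheme, the function name `quantize_blockwise`, and the tiny shapes are assumptions chosen for illustration; they are not taken from the original code.

```python
import numpy as np


def quantize_blockwise(w, block_n, block_k):
    """Symmetric absmax int8 quantization, one scale per (block_n x block_k) tile.

    Hypothetical helper for illustration; not part of the article's code.
    """
    n, k = w.shape
    n_tiles = -(-n // block_n)  # ceiling division
    k_tiles = -(-k // block_k)
    q = np.zeros((n, k), dtype=np.int8)
    scales = np.ones((n_tiles, k_tiles), dtype=np.float32)
    for i in range(n_tiles):
        for j in range(k_tiles):
            rows = slice(i * block_n, (i + 1) * block_n)
            cols = slice(j * block_k, (j + 1) * block_k)
            tile = w[rows, cols]
            amax = float(np.abs(tile).max())
            scale = amax / 127.0 if amax > 0 else 1.0
            scales[i, j] = scale
            # Round to the nearest integer and keep the symmetric int8 range
            q[rows, cols] = np.clip(np.rint(tile / scale), -127, 127).astype(np.int8)
    return q, scales


rng = np.random.default_rng(0)
w = rng.standard_normal((6, 8)).astype(np.float32)
q, scales = quantize_blockwise(w, block_n=4, block_k=4)

# Dequantize by looking up the scale of the tile each element belongs to
row_tile = np.arange(6)[:, None] // 4
col_tile = np.arange(8)[None, :] // 4
w_hat = q.astype(np.float32) * scales[row_tile, col_tile]
max_err = float(np.abs(w - w_hat).max())
print("max abs error:", max_err, "largest scale:", float(scales.max()))
```

With rounding to nearest, the reconstruction error of each element stays within about half of its tile's scale, which is why smaller blocks (tighter scales) tend to preserve more precision.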

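Why quantization is needed

  • Floating-point weights and activations (fp32/fp16) take a lot of memory and put pressure on bandwidth.
  • Quantization (e.g. int8) cuts storage to 1/4 of fp32 (1/2 of fp16), but feeding the raw int8 values straight into the multiplication loses precision.
  • The fix: give each block a floating-point scale and restore a floating-point approximation at compute time, balancing size against precision.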
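To put rough numbers on the storage argument, here is a small back-of-the-envelope calculation. The 4096 × 4096 weight shape and the fp32 scale width are assumptions for illustration; the 128 × 128 block shape matches the default `block_size` used later in the Triton code.

```python
def quantized_bytes(n, k, block_n, block_k, scale_bytes=4):
    """Bytes for an int8 [n, k] matrix plus one scale per (block_n x block_k) tile."""
    n_tiles = -(-n // block_n)  # ceiling division
    k_tiles = -(-k // block_k)
    return n * k * 1 + n_tiles * k_tiles * scale_bytes


n, k = 4096, 4096              # hypothetical weight shape
fp32_bytes = n * k * 4
fp16_bytes = n * k * 2
q_bytes = quantized_bytes(n, k, block_n=128, block_k=128)

print("fp32 :", fp32_bytes)
print("fp16 :", fp16_bytes)
print("int8 + block scales:", q_bytes)
print("compression vs fp32:", fp32_bytes / q_bytes)
print("compression vs fp16:", fp16_bytes / q_bytes)
```

At this block size the scale table adds only 32 × 32 = 1024 values, so the overhead of the scales is negligible next to the int8 payload.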

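What this operator actually does

  • Two input matrices: A[M×K] (activations) and B[N×K] (weights)
  • Two scale tables: As[M×T_k] (block-wise scales for A along K) and Bs[T_n×T_k] (block-wise scales for B over N×K), where T_k and T_n are the numbers of blocks along K and N
  • First expand the scale tables over K (and N×K) according to the block layout, producing per-element scale factors S_A^exp and S_B^exp
  • Compute (A ⊙ S_A^exp) @ (B ⊙ S_B^exp)^T + bias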
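The four steps above can be sketched in a few lines of NumPy on a tiny worked example. The helper names `expand_scales` and `linear_quant_reference` and the concrete numbers are made up for illustration; the block size is inferred from the scale-table shape by ceiling division, as the article's reference code does.

```python
import numpy as np


def expand_scales(As, Bs, K, N):
    """Expand As [M, T_k] to [M, K] and Bs [T_n, T_k] to [N, K]."""
    k_tiles, n_tiles = As.shape[1], Bs.shape[0]
    block_k = -(-K // k_tiles)          # ceiling division
    block_n = -(-N // n_tiles)
    k_tile = np.arange(K) // block_k    # tile index of every column, shape [K]
    n_tile = np.arange(N) // block_n    # tile index of every row of B, shape [N]
    SA = As[:, k_tile]                              # [M, K]
    SB = Bs[n_tile[:, None], k_tile[None, :]]       # [N, K]
    return SA, SB


def linear_quant_reference(A, B, As, Bs, bias=None):
    """Compute (A * SA) @ (B * SB)^T + bias in floating point."""
    K, N = A.shape[1], B.shape[0]
    SA, SB = expand_scales(As, Bs, K, N)
    C = (A * SA) @ (B * SB).T
    if bias is not None:
        C = C + bias[None, :]
    return C


# M=2, K=4, N=2 with T_k=2 (block_k=2) and T_n=1 (block_n=2)
A = np.array([[1, 2, 3, 4], [1, 1, 1, 1]], dtype=np.float64)
B = np.array([[1, 0, 1, 0], [0, 1, 0, 1]], dtype=np.float64)
As = np.array([[0.5, 2.0], [1.0, 1.0]])
Bs = np.array([[2.0, 0.5]])
bias = np.array([1.0, -1.0])

C = linear_quant_reference(A, B, As, Bs, bias)
print(C)  # [[5.  5. ] [3.5 1.5]]
```

For the first row, the scaled activations are [0.5, 1, 6, 8] and the scaled weights are [2, 0, 0.5, 0] and [0, 2, 0, 0.5], giving 4 and 6 before the bias and 5 and 5 after it.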

Visualizing the computation


  • Reference code implementation
import torch


class LinearQuantImpl:

    def native_impl(self, out: torch.Tensor, lhs: torch.Tensor,
                    rhs: torch.Tensor, bias: torch.Tensor,
                    lhs_scale: torch.Tensor,
                    rhs_scale: torch.Tensor) -> None:
        """CPU/GPU implementation: applies the scales and uses torch.matmul."""
        device = lhs.device

        # Choose the compute dtype by device type: fp64 on CPU, fp32 on GPU
        if device.type == 'cpu':
            target_dtype = torch.float64
        else:
            target_dtype = torch.float32
        print(f"target_dtype: {target_dtype}")

        # Use target_dtype everywhere
        A = lhs.to(target_dtype).contiguous()
        B = rhs.to(target_dtype).contiguous()
        As = lhs_scale.to(target_dtype).contiguous()
        Bs = rhs_scale.to(target_dtype).contiguous()

        # Dimensions
        A_shape = A.shape
        M = A.numel() // A.size(-1)  # total number of rows (after flattening)
        K = A.size(-1)               # input dimension
        N = B.size(0)                # output dimension

        # Reshape A to [M, K] for the matrix multiplication
        A = A.reshape(M, K)
        As = As.reshape(M, As.size(-1))

        # Infer the number of blocks from the scale tensor shapes
        k_tiles = As.size(1)         # number of blocks along k
        n_tiles = Bs.size(0)         # number of blocks along n

        # Block sizes (ceiling division)
        block_k = (K + k_tiles - 1) // k_tiles
        block_n = (N + n_tiles - 1) // n_tiles

        # Check that the two scale tensors agree on the number of k blocks
        assert k_tiles == Bs.size(1), f"k_tiles mismatch: As has {k_tiles} k_tiles but Bs has {Bs.size(1)}"

        # Fully vectorized: advanced indexing plus broadcasting
        k_indices = torch.arange(K, dtype=torch.long, device=device)
        n_indices = torch.arange(N, dtype=torch.long, device=device)

        # Tile index of every position
        k_tile_indices = k_indices // block_k  # [K]
        n_tile_indices = n_indices // block_n  # [N]

        # As: [M, k_tiles] -> gathered to [M, K]
        # Bs: [n_tiles, k_tiles] -> gathered to [N, K]
        n_tile_expanded = n_tile_indices.unsqueeze(1)  # [N, 1]

        # Index As with the 1-D [K] tensor. A [1, K] index would produce
        # [M, 1, K], and A * scale would then broadcast to [M, M, K].
        lhs_scale_expanded = As[:, k_tile_indices]  # [M, K]
        rhs_scale_expanded = Bs[n_tile_expanded, k_tile_indices]  # [N, K]

        # Apply the scales to the inputs
        A_scaled = A * lhs_scale_expanded  # [M, K]
        B_scaled = B * rhs_scale_expanded  # [N, K]

        # Matrix multiplication: [M, K] @ [K, N] = [M, N]
        C = torch.matmul(A_scaled, B_scaled.t())

        # Add the bias if one is provided
        if bias is not None and bias.numel() > 0:
            bias_tensor = bias.to(target_dtype).to(device)
            C.add_(bias_tensor.unsqueeze(0))  # broadcast bias: [1, N] + [M, N]

        # Reshape back to the original leading dimensions
        output_shape = list(A_shape)
        output_shape[-1] = N
        C = C.reshape(output_shape)

        # Handle dtype conversion and clamping
        if out.dtype == torch.int8:
            # For int8 output, clamp to the valid symmetric range
            C = torch.clamp(C, -127.0, 127.0)
        elif out.dtype == torch.float16:
            C = C.to(torch.float16)
        elif out.dtype == torch.bfloat16:
            C = C.to(torch.bfloat16)

        # Copy the result into the output tensor with the matching dtype
        out.copy_(C.to(out.dtype))
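The reference implementation expands the scales to full size before a single matmul, while the Triton kernel in this article multiplies each K block's partial product by `a_s[:, None] * b_s[None, :]` as it accumulates. The NumPy sketch below checks that the two orderings agree when the kernel's K step equals the quantization block size. The shapes, the random data, and the function names are assumptions for illustration only.

```python
import numpy as np


def expand_then_matmul(A, B, As, Bs, block_n, block_k):
    """Reference ordering: expand scales to [M, K] and [N, K], then one matmul."""
    K, N = A.shape[1], B.shape[0]
    k_tile = np.arange(K) // block_k
    n_tile = np.arange(N) // block_n
    SA = As[:, k_tile]
    SB = Bs[n_tile[:, None], k_tile[None, :]]
    return (A * SA) @ (B * SB).T


def tiled_scaled_matmul(A, B, As, Bs, block_n, block_k):
    """Kernel-style ordering: scale each K block's partial product, then add."""
    M, K = A.shape
    N = B.shape[0]
    n_tile = np.arange(N) // block_n
    acc = np.zeros((M, N), dtype=np.float64)
    for t in range(As.shape[1]):
        ks = slice(t * block_k, min((t + 1) * block_k, K))
        partial = A[:, ks] @ B[:, ks].T      # unscaled dot product of one K block
        a_s = As[:, t]                       # [M] scale of each row for this block
        b_s = Bs[n_tile, t]                  # [N] scale of each output column
        acc += partial * a_s[:, None] * b_s[None, :]
    return acc


rng = np.random.default_rng(0)
M, K, N, block_n, block_k = 3, 10, 6, 4, 4   # last K block and last N block are partial
A = rng.integers(-127, 128, size=(M, K)).astype(np.float64)
B = rng.integers(-127, 128, size=(N, K)).astype(np.float64)
As = rng.uniform(0.01, 0.1, size=(M, -(-K // block_k)))
Bs = rng.uniform(0.01, 0.1, size=(-(-N // block_n), -(-K // block_k)))

reference = expand_then_matmul(A, B, As, Bs, block_n, block_k)
tiled = tiled_scaled_matmul(A, B, As, Bs, block_n, block_k)
print("max difference:", float(np.abs(reference - tiled).max()))
```

The equivalence holds because every element inside one K block shares the same row scale and the same column scale, so the scales factor out of that block's dot product. This lets a kernel run the dot product on the low-precision values and apply the scales once per block.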
  • CUDA Triton implementation
from vllm.triton_utils import tl, triton
from vllm.platforms import current_platform
from vllm.logger import init_logger
import torch
import json  # needed by json.load in get_w8a8_block_fp8_configs
import os
import functools
from typing import Any, Callable, Optional, Union

logger = init_logger(__name__)


@functools.lru_cache
def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
                               block_k: int) -> Optional[dict[int, Any]]:
    """
    Return optimized configurations for the w8a8 block fp8 kernel.
    The return value will be a dictionary that maps an irregular grid of
    batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
    kernel on a given batch size bs, the closest batch size in the grid should
    be picked and the associated configuration chosen to invoke the kernel.
    """

    # First look up if an optimized configuration is available in the configs
    # directory
    device_name = current_platform.get_device_name().replace(" ", "_")
    json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json"  # noqa: E501

    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.info(
                "Using configuration from %s for W8A8 Block FP8 kernel.",
                config_file_path,
            )
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}

    # If no optimized configuration is available, we will use the default
    # configuration
    logger.warning(
        "Using default W8A8 Block FP8 kernel config. Performance might "
        "be sub-optimal! Config file not found at %s",
        config_file_path,
    )
    return None


def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    dot_dtype = None,
    block_size: list[int] = [128, 128],
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """This function performs matrix multiplication with block-wise
    quantization.
    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
    Args:
        A: The input tensor, e.g., activation.
        B: The input tensor, e.g., weight.
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should
            be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.
    Returns:
        torch.Tensor: The result of matmul.
    """
    if isinstance(dot_dtype, int) and dot_dtype == 1:
        dot_dtype = tl.bfloat16
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2 and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N, )
    C = A.new_empty(C_shape, dtype=output_dtype)

    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
    if configs:
        # Get the optimal config if there is one
        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
    else:
        # Default config
        # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
        # BLOCK_SIZE_K must be divisible by block_size[1]
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_size[0],
            "BLOCK_SIZE_K": block_size[1],
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 2,
        }

    def grid(META):
        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
                triton.cdiv(N, META["BLOCK_SIZE_N"]), )

    _w8a8_block_fp8_matmul[grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        # dot_dtype,
        A.stride(-2),
        A.stride(-1),
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        As.stride(-2),
        As.stride(-1),
        Bs.stride(1),
        Bs.stride(0),
        **config,
    )

    return C


def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: Optional[str],
    is_marlin: bool,
    block_shape: Optional[list[int]] = None,
) -> dict[str, int]:
    if dtype == "fp8_w8a8" and block_shape is not None:
        # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
        # BLOCK_SIZE_K must be divisible by block_shape[1]
        # num_stages=3 can cause triton.runtime.errors.OutOfResources
        # on ROCm, set it to 2 instead.
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_shape[0],
            "BLOCK_SIZE_K": block_shape[1],
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            # "num_stages": 3 if not current_platform.is_rocm() else 2,
            "num_stages": 2
        }
    elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
        # moe wna16 kernels
        # only set BLOCK_SIZE_M
        # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later
        bit = 4 if dtype == "int4_w4a16" else 8
        # NOTE: should_moe_wna16_use_cuda is neither defined nor imported
        # in this excerpt.
        use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
                                                       block_shape[1], E, bit)
        if use_moe_wna16_cuda:
            config = {"BLOCK_SIZE_M": min(16, M)}
        elif M <= 20:
            config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1}
        elif M <= 40:
            config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
        else:
            config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
    elif is_marlin:
        for block_size_m in [8, 16, 32, 48, 64]:
            if M * topk / E / block_size_m < 0.9:
                break
        return {"BLOCK_SIZE_M": block_size_m}
    elif M <= E:
        config = {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
        }
    else:
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 64,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
        }
    return config


def try_get_optimal_moe_config(
    w1_shape: tuple[int, ...],
    w2_shape: tuple[int, ...],
    top_k: int,
    dtype: Optional[str],
    M: int,
    is_marlin: bool = False,
    block_shape: Optional[list[int]] = None,
) -> dict[str, int]:
    from vllm.model_executor.layers.fused_moe import get_config
    override_config = get_config()
    if override_config:
        config = override_config
    else:
        # First try to load optimal config from the file
        E, _, N = w2_shape
        if dtype == "int4_w4a16":
            N = N * 2
        block_n = block_shape[0] if block_shape else 0
        block_k = block_shape[1] if block_shape else 0
        # Else use the default config
        config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
                                    is_marlin, block_shape)
    return config


@triton.jit
def _w8a8_block_fp8_matmul(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # dot_dtype,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with block-wise quantization, and
    store the result in output tensor `C`.
    """
    # dot_dtype = tl.bfloat16
    dot_dtype = None
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                    other=0.0)

        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
        if dot_dtype is not None:
            a = a.to(dot_dtype)
            b = b.to(dot_dtype)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def get_config_dtype_str(
        dtype: torch.dtype,
        use_int4_w4a16: Optional[bool] = False,
        use_int8_w8a16: Optional[bool] = False,
        use_fp8_w8a8: Optional[bool] = False,
        use_mxfp4_w4a4: Optional[bool] = False) -> Optional[str]:
    if use_fp8_w8a8:
        return "fp8_w8a8"
    elif use_int8_w8a16:
        return "int8_w8a16"
    elif use_int4_w4a16:
        return "int4_w4a16"
    elif use_mxfp4_w4a4:
        return "mxfp4_w4a4"
    elif dtype == torch.float:
        # avoiding cases where kernel fails when float32 MoE
        # use fp16/bfloat16 configs
        return "float32"
    return None


def invoke_fused_moe_kernel(A: torch.Tensor,
                            B: torch.Tensor,
                            C: torch.Tensor,
                            A_scale: Optional[torch.Tensor],
                            B_scale: Optional[torch.Tensor],
                            B_zp: Optional[torch.Tensor],
                            topk_weights: Optional[torch.Tensor],
                            sorted_token_ids: torch.Tensor,
                            expert_ids: torch.Tensor,
                            num_tokens_post_padded: torch.Tensor,
                            mul_routed_weight: bool,
                            top_k: int,
                            config: dict[str, Any] = None,
                            compute_type: tl.dtype = tl.bfloat16,
                            use_fp8_w8a8: bool = True,
                            use_int8_w8a8: bool = False,
                            use_int8_w8a16: bool = False,
                            use_int4_w4a16: bool = False,
                            per_channel_quant: bool = False,
                            block_shape: Optional[list[int]] = [128, 128],
                            dot_dtype = None) -> None:
    if isinstance(dot_dtype, int) and dot_dtype == 1:
        dot_dtype = tl.bfloat16
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if config is None:
        M = A.size(0)
        config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
                                            use_int8_w8a16=use_int8_w8a16,
                                            use_int4_w4a16=use_int4_w4a16,
                                            use_mxfp4_w4a4=False,
                                            dtype=A.dtype)
        get_config_func = functools.partial(
            try_get_optimal_moe_config,
            B.size(),
            B.size(),
            top_k,
            config_dtype,
            block_shape=block_shape,
        )
        config = get_config_func(M)
        # config = {
        #     'BLOCK_SIZE_K': 128,
        #     'BLOCK_SIZE_M': 64,
        #     'BLOCK_SIZE_N': 128,
        #     'GROUP_SIZE_M': 32,
        #     'num_warps': 4,
        #     'num_stages': 2
        # }

    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
        assert (block_shape is None
                or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2))
        assert (block_shape is None
                or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1))
    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        assert A_scale is None
        assert B_scale is None

    M = A.size(0)
    num_tokens = M * top_k

    EM = sorted_token_ids.size(0)
    if A.size(0) < config["BLOCK_SIZE_M"]:
        # optimize for small batch_size.
        # We assume that top_ids of each token is unique,
        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
        # and we can skip some invalid blocks.
        EM = min(sorted_token_ids.size(0),
                 A.size(0) * top_k * config['BLOCK_SIZE_M'])
    grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
        B.size(1), META['BLOCK_SIZE_N']), )

    config = config.copy()
    BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
    if block_shape is not None:
        BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0],
                                             block_shape[1]))
    fused_moe_kernel[grid](
        A,
        B,
        C,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),
        B.size(2),
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        A_scale.stride(0)
        if A_scale is not None and A_scale.ndim == 2 else 0,
        A_scale.stride(1)
        if A_scale is not None and A_scale.ndim == 2 else 0,
        B_scale.stride(0)
        if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_scale.stride(2)
        if B_scale is not None and B_scale.ndim == 3 else 0,
        B_scale.stride(1)
        if B_scale is not None and B_scale.ndim >= 2 else 0,
        0 if block_shape is None else block_shape[0],
        0 if block_shape is None else block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        per_channel_quant=per_channel_quant,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        **config,
    )


@triton.jit
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
                          token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N,
                          compute_type):
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_asm,
    stride_ask,
    stride_bse,
    stride_bsk,
    stride_bsn,
    # Block size for block-wise quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    per_channel_quant: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.
    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    dot_dtype = tl.bfloat16
    # dot_dtype = None
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
        tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
                              offs_token, token_mask, BLOCK_SIZE_M,
                              BLOCK_SIZE_N, compute_type)
        return

    offs_bn = (pid_n * BLOCK_SIZE_N +
               tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                offs_bn[None, :] * stride_bn)
    if use_int8_w8a16:
        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
            None, :] * stride_bsn
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8 or use_int8_w8a8:
        # block-wise
        if group_k > 0 and group_n > 0:
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse +
                            offs_bsn * stride_bsn)
        # channel-wise
        elif per_channel_quant:
            b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
                None, :] * stride_bsn
            b_scale = tl.load(b_scale_ptrs)
            # Load per-token scale for activations
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:,
                                                                        None]
        # tensor-wise
        else:
            a_scale = tl.load(a_scale_ptr)
            b_scale = tl.load(b_scale_ptr + off_experts)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(a_ptrs,
                    mask=token_mask[:, None] &
                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        if dot_dtype is not None:
            a = a.to(dot_dtype)
            b = b.to(dot_dtype)
        # We accumulate along the K dimension.
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
                                  mask=token_mask,
                                  other=0.0)
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                accumulator += tl.dot(a, b) * a_scale[:,
                                                      None] * b_scale[None, :]
            else:
                if use_fp8_w8a8:
                    # acc used to enable fp8_fast_accum
                    accumulator = tl.dot(a, b, acc=accumulator)
                else:
                    accumulator += tl.dot(a, b)
        else:
            accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]
    if use_int8_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_fp8_w8a8 or use_int8_w8a8:
        if group_k > 0 and group_n > 0:
            accumulator = accumulator.to(compute_type)
        else:
            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
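Both kernels map the one-dimensional program id to an output tile in a "grouped" order, which the kernel comments describe as promoting L2 data reuse. The plain-Python sketch below replays the index arithmetic of `fused_moe_kernel` on the host so the ordering can be inspected. The function name `grouped_pid_to_tile` and the problem sizes are hypothetical; the arithmetic itself is copied from the kernel.

```python
def grouped_pid_to_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    """Host-side replay of the pid -> (pid_m, pid_n) mapping used in the kernel."""
    num_pid_m = -(-M // BLOCK_SIZE_M)   # ceiling division, like tl.cdiv
    num_pid_n = -(-N // BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may contain fewer than GROUP_SIZE_M rows of tiles
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


# Hypothetical problem: 5 x 3 output tiles, visited in groups of 2 tile rows
M, N = 320, 384
BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M = 64, 128, 2
num_programs = (-(-M // BLOCK_SIZE_M)) * (-(-N // BLOCK_SIZE_N))

order = [
    grouped_pid_to_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)
    for pid in range(num_programs)
]
for pid, tile in enumerate(order):
    print(pid, tile)
```

Consecutive program ids walk down a short column of `GROUP_SIZE_M` tile rows before moving to the next tile column, so neighbouring programs share the same block of B and a small set of A rows. Every tile is still visited exactly once, including in the shorter last group.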
