diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py
index e553946413..0957bf0fb9 100644
--- a/test/prototype/mx_formats/test_kernels.py
+++ b/test/prototype/mx_formats/test_kernels.py
@@ -41,7 +41,7 @@
     triton_to_mxfp8_dim1_reference,
     unpack_uint4,
 )
-from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode, to_mx
+from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
@@ -326,28 +326,6 @@ def test_fp4_pack_unpack():
     assert torch.all(orig_vals_dq == orig_vals)
 
 
-# TODO(future PR): fix or delete this test
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.skipif(is_sm_at_least_89(), reason="broken on CUDA capability 8.9+")
-def test_fp4_triton_scaled_cast():
-    size = (256,)
-    orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
-    mxtensor_ref = MXTensor.to_mx(
-        orig_vals, block_size=32, elem_dtype=torch.float4_e2m1fn_x2
-    )
-    mxtensor_triton = MXTensor.to_mx(
-        orig_vals,
-        block_size=32,
-        elem_dtype=torch.float4_e2m1fn_x2,
-        use_fp4_custom_triton_dequant_kernel=True,
-    )
-
-    f32_ref = mxtensor_ref.to_dtype(torch.float)
-    f32_triton = mxtensor_triton.to_dtype(torch.float)
-    assert torch.all(torch.eq(f32_ref, f32_triton))
-
-
 @pytest.mark.parametrize("dtype_name", (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2))
 def test_fp6_values(dtype_name):
     """
diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index f4af52bafa..ea1b7c6459 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -380,14 +380,12 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
     else:
         raise AssertionError("unsupported")
     block_size = 4
-    use_fp4_custom_triton_dequant_kernel = False
     tensor_mx = MXTensor(
         data_bits,
         scale_e8m0,
         elem_dtype,
         block_size,
         torch.float,
-        use_fp4_custom_triton_dequant_kernel,
         MXGemmKernelChoice.EMULATED,
         pack_fp6,
         None,
@@ -427,14 +425,10 @@ def test_block_sizes(elem_dtype, B):
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-@pytest.mark.parametrize("fp4_triton", [False, True])
-def test_transpose(elem_dtype, fp4_triton):
+def test_transpose(elem_dtype):
     """
     Verify that transposing an MX tensor works
     """
-    if elem_dtype != torch.float4_e2m1fn_x2 and fp4_triton:
-        pytest.skip("unsupported configuration")
-
     M, K = 128, 256
     block_size = 32
     tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
@@ -442,7 +436,6 @@ def test_transpose(elem_dtype, fp4_triton):
         tensor_hp,
         elem_dtype,
         block_size,
-        use_fp4_custom_triton_dequant_kernel=fp4_triton,
     )
     tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t()
 
@@ -510,7 +503,6 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
 
     to_dtype_c = torch.compile(to_dtype, fullgraph=True)
 
-    use_fp4_custom_triton_dequant_kernel = False
     pack_fp6 = False
     x_mx_dq = to_dtype(
         x_mx.qdata,
@@ -518,7 +510,6 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         x_mx._elem_dtype,
         x_mx._block_size,
         hp_dtype,  # noqa: E501
-        use_fp4_custom_triton_dequant_kernel,
         pack_fp6,
     )
     x_mx_c_dq = to_dtype_c(
@@ -527,7 +518,6 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         x_mx_c._elem_dtype,
         x_mx_c._block_size,
         hp_dtype,
-        use_fp4_custom_triton_dequant_kernel,
         pack_fp6,
     )
     torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)
diff --git a/torchao/prototype/mx_formats/benchmarks/bench_qdq.py b/torchao/prototype/mx_formats/benchmarks/bench_qdq.py
deleted file mode 100644
index ca0b926ce5..0000000000
--- a/torchao/prototype/mx_formats/benchmarks/bench_qdq.py
+++ /dev/null
@@ -1,146 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Benchmarking mx quantize/dequantize
-"""
-
-from typing import Optional
-
-import fire
-import tabulate
-import torch
-from torch.profiler import ProfilerActivity, profile
-
-from torchao.prototype.mx_formats import config
-from torchao.prototype.mx_formats.constants import (  # noqa: E501
-    SUPPORTED_ELEM_DTYPES,
-)
-from torchao.prototype.mx_formats.mx_tensor import MXTensor
-from torchao.utils import benchmark_torch_function_in_microseconds
-
-
-def run(profile_folder: Optional[str] = None):
-    headers = [
-        "elem_dtype",
-        "use_fp4_custom_triton_dequant_kernel",
-        "q_time_us",
-        "q_mem_bw_tb_s",
-        "dq_time_us",
-        "dq_mem_bw_tb_s",
-    ]
-    results = []
-
-    data_hp = torch.randn(1, 4096, 11008, dtype=torch.bfloat16, device="cuda")
-
-    for elem_dtype in SUPPORTED_ELEM_DTYPES:
-        for use_fp4_custom_triton_dequant_kernel in (False, True):
-            config.use_fp4_custom_triton_dequant_kernel = (
-                use_fp4_custom_triton_dequant_kernel
-            )
-
-            if (
-                elem_dtype != torch.float4_e2m1fn_x2
-                and use_fp4_custom_triton_dequant_kernel  # noqa: E501
-            ):
-                # custom_triton_kernels only works for fp4
-                continue
-
-            print(
-                "elem_dtype",
-                elem_dtype,
-                "use_fp4_custom_triton_dequant_kernel",
-                use_fp4_custom_triton_dequant_kernel,
-            )
-
-            data_lp = MXTensor.to_mx(data_hp, elem_dtype, block_size=32)
-
-            if not use_fp4_custom_triton_dequant_kernel:
-                quant = torch.compile(MXTensor.to_mx, fullgraph=True)
-                dequant = torch.compile(data_lp.to_dtype, fullgraph=True)
-            else:
-                # As of 2024-04, torch.compile didn't work with the
-                # handwritten triton kernel,
-                # crashed on tl.interleave:
-                # https://github.com/pytorch/pytorch/issues/123967
-                # As of 2024-05-24, now there is message asking to convert to
-                # an opaque custom op:
-                # https://gist.github.com/vkuzo/0b0b90dca03bdb8e0446e4135644238a  # noqa: E501
-                # TODO(future): make this better
-                quant = MXTensor.to_mx
-                dequant = data_lp.to_dtype
-
-            # warm up
-            quant(data_hp, elem_dtype, block_size=32)
-            res = dequant(torch.bfloat16)
-
-            if profile_folder is not None:
-                with profile(
-                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
-                    record_shapes=True,
-                ) as prof:
-                    for _ in range(5):
-                        quant(data_hp, elem_dtype, block_size=32)
-                        dequant(torch.bfloat16)
-                prof.export_chrome_trace(
-                    profile_folder
-                    + f"/mx_qdq_{elem_dtype}_{use_fp4_custom_triton_dequant_kernel}.json"  # noqa: E501
-                )
-
-            q_execution_time_us = benchmark_torch_function_in_microseconds(
-                quant, data_hp, elem_dtype, block_size=32
-            )
-            dq_execution_time_us = benchmark_torch_function_in_microseconds(
-                dequant, torch.bfloat16
-            )
-            print(f"q time: {q_execution_time_us} us")
-            print(f"dq time: {dq_execution_time_us} us")
-
-            # memory reads per element:
-            byte_per_stored_element = 1.0  # fp8 or 2xfp4
-            byte_per_stored_exp_element = 1.0  # e8m0
-            byte_per_dequantized_element = 2.0  # bfloat16
-            mem_reads_writes_bytes = (
-                # read raw data
-                (data_lp._data.numel() * byte_per_stored_element)
-                +
-                # read exponent
-                (data_lp._scale_e8m0.numel() * byte_per_stored_exp_element)
-                +
-                # write dequant
-                (res.numel() * byte_per_dequantized_element)
-            )
-            # note: the above also works for quant, with reads/writes in
-            # reverse
-
-            q_mem_bw_tb_s = (mem_reads_writes_bytes / 1e12) / (
-                q_execution_time_us / 1e6
-            )
-            dq_mem_bw_tb_s = (mem_reads_writes_bytes / 1e12) / (
-                dq_execution_time_us / 1e6
-            )
-            print(f"q mem bw: {q_mem_bw_tb_s} TB/s")
-            print(f"dq mem bw: {dq_mem_bw_tb_s} TB/s")
-
-            results.append(
-                (
-                    elem_dtype,
-                    use_fp4_custom_triton_dequant_kernel,
-                    q_execution_time_us,
-                    q_mem_bw_tb_s,
-                    dq_execution_time_us,
-                    dq_mem_bw_tb_s,
-                )
-            )
-            config.use_fp4_custom_triton_dequant_kernel = False
-
-            torch._dynamo.reset()
-
-    print(tabulate.tabulate(results, headers=headers, floatfmt=".2f"))
-
-
-if __name__ == "__main__":
-    fire.Fire(run)
diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py
index 7de90daa1c..388af07874 100644
--- a/torchao/prototype/mx_formats/config.py
+++ b/torchao/prototype/mx_formats/config.py
@@ -146,9 +146,6 @@ class MXLinearConfig(AOBaseConfig):
         MXFP8Dim1CastKernelChoice.TORCH
     )
 
-    # If True, uses a custom triton kernel for fp4 dequantize
-    use_fp4_custom_triton_dequant_kernel: bool = False
-
     scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR
 
     def __post_init__(self):
@@ -217,8 +214,6 @@ def short_str(self) -> str:
             s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
         s += f", kernel={self.gemm_kernel_choice.value}"
         s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}"
-        if self.use_fp4_custom_triton_dequant_kernel:
-            s += ", use_fp4_custom_triton_dequant_kernel=True"
         if self.scale_calculation_mode != ScaleCalculationMode.FLOOR:
             s += f", scale_calculation_mode={self.scale_calculation_mode}"
         return s
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index 732af4df2a..cd605917af 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -30,7 +30,6 @@
 from torchao.prototype.mx_formats.constants import (
     E8M0_EXPONENT_BIAS,
     E8M0_EXPONENT_NAN_VAL,
-    F4_E2M1_EXP_BIAS,
     F6_E2M3_EXP_BIAS,
     F6_E3M2_EXP_BIAS,
     F32_EXP_BIAS,
@@ -196,89 +195,6 @@ def _fp4_packed_to_bf16(
         output = output.to(tl.bfloat16)
         return output
 
-    @triton.autotune(
-        configs=[
-            triton.Config({"BLOCK_SIZE_IN": 128}),
-            triton.Config({"BLOCK_SIZE_IN": 256}),
-            triton.Config({"BLOCK_SIZE_IN": 512}),
-            triton.Config({"BLOCK_SIZE_IN": 1024}),
-            triton.Config({"BLOCK_SIZE_IN": 2048}),
-        ],
-        key=["n_elements_in"],
-    )
-    @triton.jit
-    def triton_f4_to_scaled_bf16_kernel(
-        x_ptr,
-        s_ptr,
-        output_ptr,
-        n_elements_in,
-        mx_block_size: tl.constexpr,
-        sign_mask_f4: tl.constexpr,
-        mantissa_mask_f4: tl.constexpr,
-        mbits_f4_e2m1: tl.constexpr,
-        ebits_f4_e2m1: tl.constexpr,
-        f4_e2m1_exp_bias: tl.constexpr,
-        mbits_f32: tl.constexpr,
-        ebits_f32: tl.constexpr,
-        f32_exp_bias: tl.constexpr,
-        zero_bits_f32: tl.constexpr,
-        zero_point_five_bits_f32: tl.constexpr,
-        e8m0_exponent_bias: tl.constexpr,
-        e8m0_exponent_nan_val: tl.constexpr,
-        BLOCK_SIZE_IN: tl.constexpr,
-    ):
-        pid = tl.program_id(axis=0)
-        n_elements_out = n_elements_in * 2
-        n_elements_s = n_elements_out // 32
-
-        BLOCK_SIZE_S: tl.constexpr = BLOCK_SIZE_IN // 16
-        BLOCK_SIZE_OUT: tl.constexpr = BLOCK_SIZE_IN * 2
-
-        block_start_in = pid * BLOCK_SIZE_IN
-        offsets_in = block_start_in + tl.arange(0, BLOCK_SIZE_IN)
-        mask_in = offsets_in < n_elements_in
-        # packed uint8
-        x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
-        output = _fp4_packed_to_bf16(
-            x_packed,
-            sign_mask_f4,
-            mantissa_mask_f4,
-            mbits_f4_e2m1,
-            ebits_f4_e2m1,
-            f4_e2m1_exp_bias,
-            mbits_f32,
-            ebits_f32,
-            f32_exp_bias,
-            zero_bits_f32,
-            zero_point_five_bits_f32,
-        )
-
-        # load scale
-        block_start_s = pid * BLOCK_SIZE_S
-        offsets_s = block_start_s + tl.arange(0, BLOCK_SIZE_S)
-        mask_s = offsets_s < n_elements_s
-        s = tl.load(s_ptr + offsets_s, mask=mask_s)
-
-        # create the scale in bf16
-        s_offset = s.to(tl.int16) - e8m0_exponent_bias
-        s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16)
-        s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan"))
-
-        # multiply output by scale
-        # TODO(later): see if manipulating the exponent instead of fp
-        # multiplication is going to give a significant speedup
-        output = tl.reshape(output, (BLOCK_SIZE_OUT // mx_block_size, mx_block_size))  # noqa: E501
-        s_fp = tl.reshape(s_fp, (BLOCK_SIZE_S // 1, 1))
-        output = output * s_fp
-        output = tl.reshape(output, (BLOCK_SIZE_OUT,))
-
-        # set up output offsets
-        block_start_out = pid * BLOCK_SIZE_OUT
-        offsets_out = block_start_out + tl.arange(0, BLOCK_SIZE_OUT)
-        mask_out = offsets_out < n_elements_out
-
-        tl.store(output_ptr + offsets_out, output, mask=mask_out)
-
     @triton.jit
     def _fp6_packed_to_bf16(
         packed_4bits_a,
@@ -575,28 +491,6 @@ def triton_pack_uint6_kernel(
 
 else:
 
-    def triton_f4_to_scaled_bf16_kernel(
-        x_ptr,
-        s_ptr,
-        output_ptr,
-        n_elements_in,
-        mx_block_size,
-        sign_mask_f4,
-        mantissa_mask_f4,
-        mbits_f4_e2m1,
-        ebits_f4_e2m1,
-        f4_e2m1_exp_bias,
-        mbits_f32,
-        ebits_f32,
-        f32_exp_bias,
-        zero_bits_f32,
-        zero_point_five_bits_f32,
-        e8m0_exponent_bias,
-        e8m0_exponent_nan_val,
-        BLOCK_SIZE_IN,
-    ):
-        raise AssertionError("unsupported without triton")
-
     def triton_f6_to_bf16_kernel(
         x_ptr,
         output_ptr,
@@ -638,47 +532,6 @@ def triton_pack_uint6_kernel(
         raise AssertionError("unsupported without triton")
 
 
-def triton_f4_to_scaled_bf16(
-    x: torch.Tensor,
-    s_e8m0: torch.Tensor,
-    mx_block_size: int,
-):
-    """
-    Input: a tensor of packed fp4 values, and a scale in e8m0 format. The block
-    size is currently assumed to be 32.
-    Output: a tensor of bfloat16 values, multiplied by the encoded scale
-    """
-    s_e8m0 = s_e8m0.view(torch.uint8)
-    new_shape = (*x.shape[:-1], x.shape[-1] * 2)
-    output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
-    assert x.is_contiguous()
-    assert x.is_cuda and output.is_cuda
-    n_elements_in = x.numel()
-    grid = lambda meta: (  # noqa: E731
-        triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
-    )
-    triton_f4_to_scaled_bf16_kernel[grid](
-        x,
-        s_e8m0,
-        output,
-        n_elements_in,
-        mx_block_size,
-        sign_mask_f4=SIGN_MASK_F4,
-        mantissa_mask_f4=MANTISSA_MASK_F4,
-        mbits_f4_e2m1=MBITS_F4_E2M1,
-        ebits_f4_e2m1=EBITS_F4_E2M1,
-        f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
-        mbits_f32=MBITS_F32,
-        ebits_f32=EBITS_F32,
-        f32_exp_bias=F32_EXP_BIAS,
-        zero_bits_f32=ZERO_BITS_F32,
-        zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
-        e8m0_exponent_bias=E8M0_EXPONENT_BIAS,
-        e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL,
-    )
-    return output
-
-
 def triton_f6_e2m3_to_bf16(x: torch.Tensor) -> torch.Tensor:
     """
     Input: a tensor of packed fp6 values
diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py
index 1a033a1096..161fcd6064 100644
--- a/torchao/prototype/mx_formats/mx_linear.py
+++ b/torchao/prototype/mx_formats/mx_linear.py
@@ -65,7 +65,6 @@ def _to_mxfp8_dim1_kernel_wrapper(
         elem_dtype,
         block_size,
         hp_dtype,
-        False,
         gemm_kernel_choice,
         False,
         None,
@@ -85,7 +84,6 @@ def _to_mxfp8_dim1_kernel_wrapper(
         elem_dtype,
         block_size,
         hp_dtype,
-        False,
         gemm_kernel_choice,
         False,
         None,
diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py
index bd4efd379b..07e47eed66 100644
--- a/torchao/prototype/mx_formats/mx_ops.py
+++ b/torchao/prototype/mx_formats/mx_ops.py
@@ -95,7 +95,6 @@ def _addmm_mx_dispatch(
             k.elem_dtype,
             k.block_size,
             k.scaling_mode,
-            k.use_fp4_custom_triton_dequant_kernel,
             k.gemm_kernel_choice,
             k.pack_fp6,
         )
@@ -186,7 +185,6 @@ def mx_t(func, types, args, kwargs):
         old._elem_dtype,
         old._block_size,
         old._orig_dtype,
-        old._use_fp4_custom_triton_dequant_kernel,
         old._gemm_kernel_choice,
         old._pack_fp6,
         old.act_quant_kwargs,
@@ -231,7 +229,6 @@ def mx_view_op(func, types, args, kwargs):
         args[0]._elem_dtype,
         args[0]._block_size,
         args[0]._orig_dtype,
-        args[0]._use_fp4_custom_triton_dequant_kernel,
         args[0]._gemm_kernel_choice,
         args[0]._pack_fp6,
         args[0].act_quant_kwargs,
@@ -293,7 +290,6 @@ def mx_slice(func, types, args, kwargs):
             x._elem_dtype,
             x._block_size,
             x._orig_dtype,
-            x._use_fp4_custom_triton_dequant_kernel,
             x._gemm_kernel_choice,
             x._pack_fp6,
             x.act_quant_kwargs,
@@ -348,7 +344,6 @@ def autocast_to_copy(func, types, args, kwargs):
         tensor._elem_dtype,
         tensor._block_size,
         dtype,
-        tensor._use_fp4_custom_triton_dequant_kernel,
         tensor._gemm_kernel_choice,
         tensor._pack_fp6,
         tensor.act_quant_kwargs,
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 533e186acd..273f1b2b56 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -53,7 +53,6 @@
     f32_to_f6_e3m2_unpacked,
     pack_uint4,
     pack_uint6,
-    triton_f4_to_scaled_bf16,
     triton_f6_e2m3_to_scaled_bf16,
     triton_f6_e3m2_to_scaled_bf16,
     unpack_uint4,
@@ -77,7 +76,6 @@ class QuantizeTensorToMXKwargs(QuantizeTensorKwargs):
     elem_dtype: Union[torch.dtype, str] = torch.float8_e4m3fn
     block_size: int = 32
     scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR
-    use_fp4_custom_triton_dequant_kernel: bool = False
     gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED
     pack_fp6: bool = False
 
@@ -349,7 +347,6 @@ def to_dtype(
     elem_dtype,
     block_size,
     target_dtype,
-    use_fp4_custom_triton_dequant_kernel,
     pack_fp6,
 ):
     orig_shape = data_lp.shape
@@ -392,25 +389,15 @@ def to_dtype(
         data_hp = f6_e3m2_unpacked_to_f32(data_lp)
         data_hp = data_hp.to(target_dtype).reshape(orig_shape)
     elif elem_dtype == torch.float4_e2m1fn_x2:
-        if use_fp4_custom_triton_dequant_kernel:
-            data_hp_rescaled = triton_f4_to_scaled_bf16(
-                data_lp,
-                scale_e8m0,
-                block_size,
-            )
-            if is_transposed:
-                data_hp_rescaled = data_hp_rescaled.t()
-            return data_hp_rescaled.to(target_dtype)
-        else:
-            # fp4
-            f4_unpacked = unpack_uint4(data_lp)
-            # for now we only have a cast to f32
-            # TODO(future PR): add cast directly to bf16
-            f32 = f4_unpacked_to_f32(f4_unpacked)
-            data_hp = f32.to(target_dtype)
-            # manually adjust shape to account for the unpacking
-            # TODO(future PR): clean up the shape code and remove the hack
-            # below
+        # fp4
+        f4_unpacked = unpack_uint4(data_lp)
+        # for now we only have a cast to f32
+        # TODO(future PR): add cast directly to bf16
+        f32 = f4_unpacked_to_f32(f4_unpacked)
+        data_hp = f32.to(target_dtype)
+        # manually adjust shape to account for the unpacking
+        # TODO(future PR): clean up the shape code and remove the hack
+        # below
         orig_shape = (*orig_shape[:-1], orig_shape[-1] * 2)
     else:
         raise AssertionError("unsupported")
@@ -469,7 +456,6 @@ class MXTensor(TorchAOBaseTensor):
         "_elem_dtype",
         "_block_size",
         "_orig_dtype",
-        "_use_fp4_custom_triton_dequant_kernel",
         "_gemm_kernel_choice",
         "_pack_fp6",
         "act_quant_kwargs",
@@ -482,7 +468,6 @@ def __new__(
         elem_dtype,
         block_size,
         orig_dtype,
-        use_fp4_custom_triton_dequant_kernel,
         gemm_kernel_choice,
         pack_fp6,
         act_quant_kwargs,
@@ -551,9 +536,6 @@ def __new__(
         self._elem_dtype = elem_dtype
         self._block_size = block_size
         self._orig_dtype = orig_dtype
-        self._use_fp4_custom_triton_dequant_kernel = (
-            use_fp4_custom_triton_dequant_kernel
-        )
         self._gemm_kernel_choice = gemm_kernel_choice
         self._pack_fp6 = pack_fp6
         self.act_quant_kwargs = act_quant_kwargs
@@ -587,7 +569,6 @@ def to_dtype(self, target_dtype):
             self._elem_dtype,
             self._block_size,
             target_dtype,
-            self._use_fp4_custom_triton_dequant_kernel,
             self._pack_fp6,
         )
 
@@ -598,7 +579,6 @@ def to_mx(
         elem_dtype: Union[torch.dtype, str],
         block_size: int = BLOCK_SIZE_DEFAULT,
         scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
-        use_fp4_custom_triton_dequant_kernel: bool = False,
         # TODO(future PR): switch default gemm to cublas
         gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED,
         pack_fp6: bool = False,
@@ -617,7 +597,6 @@ def to_mx(
             elem_dtype,
             block_size,
             data_hp.dtype,
-            use_fp4_custom_triton_dequant_kernel,
             gemm_kernel_choice,
             pack_fp6,
             act_quant_kwargs,
@@ -636,7 +615,6 @@ def to_mx(
             elem_dtype,
             block_size,
             data_hp.dtype,
-            use_fp4_custom_triton_dequant_kernel,
             gemm_kernel_choice,
             pack_fp6,
             act_quant_kwargs,
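
For reviewers who want the scaling math without reading the Triton source: the deleted triton_f4_to_scaled_bf16_kernel decoded packed fp4 to bfloat16 and then multiplied each 32-element block by 2**(scale_e8m0 - bias). Below is a minimal plain-PyTorch sketch of that per-block scaling step, not part of the patch; the helper name e8m0_scaled_dequant and the bias value 127 (E8M0_EXPONENT_BIAS in torchao.prototype.mx_formats.constants) are illustrative assumptions, and handling of the e8m0 NaN encoding is omitted.

import torch

def e8m0_scaled_dequant(
    decoded_vals: torch.Tensor,   # already-decoded fp4 values, e.g. float32/bfloat16
    scale_e8m0: torch.Tensor,     # one uint8 e8m0 scale per block
    block_size: int = 32,
    e8m0_bias: int = 127,         # assumed value of E8M0_EXPONENT_BIAS
) -> torch.Tensor:
    """Apply per-block e8m0 scales: out = decoded * 2**(scale - bias).

    Assumes decoded_vals.numel() is a multiple of block_size, with one
    scale per block, mirroring the math of the removed Triton kernel.
    """
    scale = torch.exp2(scale_e8m0.to(torch.float32) - e8m0_bias)
    out = decoded_vals.reshape(-1, block_size) * scale.reshape(-1, 1)
    return out.reshape(decoded_vals.shape).to(torch.bfloat16)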