diff --git a/benchmarks/mx_formats/cast_bench.py b/benchmarks/mx_formats/cast_bench.py
index d133636ce8..b656aca30f 100644
--- a/benchmarks/mx_formats/cast_bench.py
+++ b/benchmarks/mx_formats/cast_bench.py
@@ -112,6 +112,7 @@ def run(
         "dim0_mxfp4_floor",
         "dim0_mxfp8_rceil",
         "dim0_mxfp8_triton_floor",
+        "dim0_mxfp8_triton_rceil",
         "dim0_nvfp4",
         "dim0_nvfp4_triton_swizzle",
         "dim1_mxfp8_floor",
@@ -243,9 +244,33 @@ def run(
         y_d0, s_d0 = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)

         for _ in range(2):
-            __ = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
+            __ = triton_to_mxfp8_dim0(
+                x, inner_block_size=BLOCK_SIZE, scaling_mode="floor"
+            )
         time_us = benchmark_cuda_function_in_microseconds(
-            lambda x, b: triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE),
+            lambda x, b: triton_to_mxfp8_dim0(
+                x, inner_block_size=BLOCK_SIZE, scaling_mode="floor"
+            ),
+            x,
+            BLOCK_SIZE,
+        )
+        assert y_d0.dtype == torch.float8_e4m3fn
+        assert s_d0.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
+    elif mode == "dim0_mxfp8_triton_rceil":
+        y_d0, s_d0 = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = triton_to_mxfp8_dim0(
+                x, inner_block_size=BLOCK_SIZE, scaling_mode="rceil"
+            )
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: triton_to_mxfp8_dim0(
+                x, inner_block_size=BLOCK_SIZE, scaling_mode="rceil"
+            ),
             x,
             BLOCK_SIZE,
         )
diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py
index a4839f7c61..c7dc281fa4 100644
--- a/test/prototype/mx_formats/test_kernels.py
+++ b/test/prototype/mx_formats/test_kernels.py
@@ -45,7 +45,6 @@
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
     is_cuda_version_at_least,
-    is_sm_at_least_89,
     is_sm_at_least_100,
     torch_version_at_least,
 )
@@ -423,7 +422,9 @@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):


 def triton_to_mxfp8_dim0_reference(
-    x_hp: torch.Tensor, block_size
+    x_hp: torch.Tensor,
+    block_size,
+    scaling_mode=ScaleCalculationMode.FLOOR,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     A reference version of `triton_to_mxfp8_dim0` for rowwise quantization.
@@ -431,7 +432,9 @@ def triton_to_mxfp8_dim0_reference(
     from torchao.prototype.mx_formats.mx_tensor import to_mx

     # cast across dim0 (rowwise) - no transpose needed
-    scale_e8m0_dim0, x_hp_d0_normalized = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+    scale_e8m0_dim0, x_hp_d0_normalized = to_mx(
+        x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
+    )
     scale_e8m0_dim0 = scale_e8m0_dim0.view(torch.float8_e8m0fnu)
     return (
         x_hp_d0_normalized,
@@ -441,8 +444,8 @@ def triton_to_mxfp8_dim0_reference(

 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
-    reason="float8 in triton requires CUDA capability 8.9 or greater",
+    not is_sm_at_least_100(),
+    reason="mxfp8 in triton requires CUDA capability 10.0 or greater",
 )
 @pytest.mark.parametrize("M", (128, 256))
 @pytest.mark.parametrize("K", (128, 256))
@@ -461,10 +464,19 @@ def test_triton_mxfp8_dim1_randn(M, K):
 )
 @pytest.mark.parametrize("M", (128, 256))
 @pytest.mark.parametrize("K", (128, 256))
-def test_triton_mxfp8_dim0_randn(M, K):
+@pytest.mark.parametrize(
+    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
+)
+def test_triton_mxfp8_dim0_randn(M, K, scaling_mode):
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
-    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
-    x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
+    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
+        x, block_size=32, scaling_mode=scaling_mode
+    )
+    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
+        x,
+        inner_block_size=32,
+        scaling_mode=scaling_mode.value.lower(),
+    )
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
     torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)

@@ -474,10 +486,19 @@ def test_triton_mxfp8_dim0_randn(M, K):
     not is_sm_at_least_100(),
     reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
-def test_triton_mxfp8_dim0_zeros():
+@pytest.mark.parametrize(
+    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
+)
+def test_triton_mxfp8_dim0_zeros(scaling_mode):
     x = torch.zeros(128, 256, dtype=torch.bfloat16, device="cuda")
-    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
-    x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
+    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
+        x, block_size=32, scaling_mode=scaling_mode
+    )
+    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
+        x,
+        inner_block_size=32,
+        scaling_mode=scaling_mode.value.lower(),
+    )
     assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
     torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index 211c9b16a9..52f0351617 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -121,8 +121,13 @@ def test_linear_eager_vs_hp(
         pytest.skip("CUDA capability >= 8.9 required for float8 in triton")

     if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
-        if scale_calculation_mode != ScaleCalculationMode.FLOOR:
-            pytest.skip("unsupported configuration")
+        if scale_calculation_mode not in (
+            ScaleCalculationMode.FLOOR,
+            ScaleCalculationMode.RCEIL,
+        ):
+            pytest.skip("triton mxfp8 quantization kernels only support FLOOR and RCEIL scaling modes")
+        if not is_sm_at_least_100():
+            pytest.skip("triton mxfp8 quantization kernels require sm100")
     elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
         if scale_calculation_mode not in (
             ScaleCalculationMode.FLOOR,
@@ -316,8 +321,15 @@ def test_linear_compile(
         pytest.skip("unsupported configuration")

     if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
-        if scale_calculation_mode != ScaleCalculationMode.FLOOR:
-            pytest.skip("unsupported configuration")
+        if scale_calculation_mode not in (
+            ScaleCalculationMode.FLOOR,
+            ScaleCalculationMode.RCEIL,
+        ):
+            pytest.skip(
+                "triton mxfp8 quantization kernels only support FLOOR and RCEIL scaling modes"
+            )
+        if not is_sm_at_least_100():
+            pytest.skip("triton mxfp8 quantization kernels require sm100")
     elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
         if scale_calculation_mode not in (
             ScaleCalculationMode.FLOOR,
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index ad7ab5d596..236aa3db53 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -166,7 +166,7 @@ def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor:
     from torch.library import triton_op, wrap_triton

     @triton.jit
-    def _triton_calculate_scale(x, axis):
+    def _triton_calculate_scale(x, axis, SCALING_MODE: tl.constexpr):
         # There is no good support for accessing globals from a jit'ed triton
         # function, so we redefine them here. Since this is prototype code which
         # we plan to remove after torch.compile catches up, this is fine.
@@ -179,23 +179,48 @@ def _triton_calculate_scale(x, axis):
         # Find the maximum absolute value for each row
         max_abs = tl.max(x, axis=axis)

-        # Calculate the e8m0 scale by extracting the exponent (floor)
-        # TODO(future PR): support other exponent extraction types (ceil, RNE)
-        max_abs = max_abs.to(tl.bfloat16)
-        max_abs_int16 = max_abs.to(tl.int16, bitcast=True)
-        extracted_pow2 = ((max_abs_int16 >> bf16_mbits) & 0b11111111) - bf16_exp_bias
-        extracted_pow2 = extracted_pow2 - target_max_pow2
-        scale_e8m0_unbiased = extracted_pow2.to(tl.bfloat16)
-
-        # Clamp to exponents that can be represented in e8m0
-        # Add 1 to capture NaNs
-        scale_e8m0_unbiased = tl.clamp(
-            scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias + 1
-        )
+        # Compute e8m0 biased scale using either RCEIL or FLOOR rounding.
+        if SCALING_MODE == "rceil":
+            # RCEIL scaling mode using PTX instruction supported on sm100.
+            # The input should be: amax / 448.0
+            # where 448.0 is the max representable value in FP8 E4M3 format.
+            F8E4M3_MAX_RCP: tl.constexpr = 1.0 / 448.0
+            scale_input = max_abs.to(tl.float32) * F8E4M3_MAX_RCP
+
+            # The PTX instruction outputs a packed uint16 where:
+            # - high byte = E8M0 of first input (0.0 in our case)
+            # - low byte = E8M0 of second input (scale_input)
+            # Casting uint16 to uint8 naturally truncates to the low byte.
+            scale_e8m0_biased = tl.inline_asm_elementwise(
+                asm="cvt.rp.satfinite.ue8m0x2.f32 $0, 0.0, $1;",
+                constraints="=h,r",
+                args=[scale_input.to(tl.float32, bitcast=False)],
+                dtype=tl.uint16,
+                is_pure=True,
+                pack=1,
+            ).to(tl.uint8)
+        else:
+            tl.static_assert(SCALING_MODE == "floor")
+
+            # Original floor implementation
+            # Calculate the e8m0 scale by extracting the exponent (floor)
+            max_abs = max_abs.to(tl.bfloat16)
+            max_abs_int16 = max_abs.to(tl.int16, bitcast=True)
+            extracted_pow2 = (
+                (max_abs_int16 >> bf16_mbits) & 0b11111111
+            ) - bf16_exp_bias
+            extracted_pow2 = extracted_pow2 - target_max_pow2
+            scale_e8m0_unbiased = extracted_pow2.to(tl.bfloat16)
+
+            # Clamp to exponents that can be represented in e8m0
+            # Add 1 to capture NaNs
+            scale_e8m0_unbiased = tl.clamp(
+                scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias + 1
+            )

-        # Create the biased e8m0 representation and cast it to 8 bits
-        scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
-        scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)
+            # Create the biased e8m0 representation and cast it to 8 bits
+            scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
+            scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)

         # TODO(future PR): add NaN handling here,
         # https://github.com/pytorch/pytorch/pull/100572 will likely be useful to
@@ -248,6 +273,7 @@ def to_mxfp8_dim1_kernel(
         ROW_TILE_SIZE: tl.constexpr,
         COL_TILE_SIZE: tl.constexpr,
         INNER_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+        SCALING_MODE: tl.constexpr,
     ):
         """
         Example tiling for n_rows==8, n_cols=8, ROW_TILE_SIZE=4, COL_TILE_SIZE=4, INNER_BLOCK_SIZE=2,
@@ -334,7 +360,11 @@ def to_mxfp8_dim1_kernel(

         # Find the maximum absolute value for each column
         # shape: (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE,)
-        col_scale_r, col_scale_e8m0_r = _triton_calculate_scale(x_block_abs_t_r, axis=1)
+        col_scale_r, col_scale_e8m0_r = _triton_calculate_scale(
+            x_block_abs_t_r,
+            axis=1,
+            SCALING_MODE=SCALING_MODE,
+        )

         # Divide each column by scale
         # Broadcasting col_scale to match x_block's shape
@@ -397,6 +427,7 @@ def to_mxfp8_dim0_kernel(
         ROW_TILE_SIZE: tl.constexpr,
         COL_TILE_SIZE: tl.constexpr,
         SCALE_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+        SCALING_MODE: tl.constexpr,
     ):
         """
         Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).
@@ -432,7 +463,9 @@ def to_mxfp8_dim0_kernel(

         # Find the maximum absolute value for each row (across columns)
         # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
-        scale_fp32_r, scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)
+        scale_fp32_r, scale_e8m0_r = _triton_calculate_scale(
+            x_block_abs_r, axis=1, SCALING_MODE=SCALING_MODE
+        )

         # Divide each row by scale
         # Broadcasting scale to match x_block's shape
@@ -468,12 +501,15 @@ def to_mxfp8_dim0_kernel(

     @triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
     def triton_to_mxfp8_dim0(
-        x: torch.Tensor, inner_block_size: int = 32
+        x: torch.Tensor,
+        inner_block_size: int = 32,
+        scaling_mode: str = "rceil",
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Input:
         * `x` - input tensor, in row major memory layout
         * `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
+        * `scaling_mode` - "floor" or "rceil", defaults to "rceil"

         Output:
         * `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
@@ -518,6 +554,7 @@ def triton_to_mxfp8_dim0(
             n_rows=n_rows,
             n_cols=n_cols,
             SCALE_BLOCK_SIZE=inner_block_size,
+            SCALING_MODE=scaling_mode,
         )

         # Reshape output back to original shape
@@ -531,7 +568,7 @@ def triton_to_mxfp8_dim0(

     @triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
     def triton_to_mxfp8_dim1(
-        x: torch.Tensor, inner_block_size: int = 32
+        x: torch.Tensor, inner_block_size: int = 32, scaling_mode: str = "rceil"
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Input:
@@ -583,6 +620,7 @@ def triton_to_mxfp8_dim1(
             n_rows=n_rows,
             n_cols=n_cols,
             INNER_BLOCK_SIZE=inner_block_size,
+            SCALING_MODE=scaling_mode,
         )

         return (
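Usage sketch (illustrative only, not part of the diff): the snippet below assumes triton_to_mxfp8_dim0 is importable from torchao.prototype.mx_formats.kernels as in the tests above, and that an sm100 GPU is available for the rceil path; the dtype asserts mirror the ones added to cast_bench.py.

import torch

from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

# Cast a bf16 tensor to mxfp8 rowwise (1x32 blocks); "floor" preserves the
# previous scale rounding behavior, "rceil" uses the new sm100 PTX path.
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
data_fp8, scale_e8m0 = triton_to_mxfp8_dim0(x, inner_block_size=32, scaling_mode="rceil")

assert data_fp8.dtype == torch.float8_e4m3fn
assert scale_e8m0.dtype == torch.float8_e8m0fnu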