29 changes: 27 additions & 2 deletions benchmarks/mx_formats/cast_bench.py
@@ -112,6 +112,7 @@ def run(
"dim0_mxfp4_floor",
"dim0_mxfp8_rceil",
"dim0_mxfp8_triton_floor",
"dim0_mxfp8_triton_rceil",
"dim0_nvfp4",
"dim0_nvfp4_triton_swizzle",
"dim1_mxfp8_floor",
@@ -243,9 +244,33 @@ def run(
y_d0, s_d0 = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)

for _ in range(2):
__ = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
__ = triton_to_mxfp8_dim0(
x, inner_block_size=BLOCK_SIZE, scaling_mode="floor"
)
time_us = benchmark_cuda_function_in_microseconds(
lambda x, b: triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE),
lambda x, b: triton_to_mxfp8_dim0(
x, inner_block_size=BLOCK_SIZE, scaling_mode="floor"
),
x,
BLOCK_SIZE,
)
assert y_d0.dtype == torch.float8_e4m3fn
assert s_d0.dtype == torch.float8_e8m0fnu
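# Effective bandwidth for the cast: bytes read (bf16 input) plus bytes
# written (fp8 data and e8m0 scales), divided by the elapsed time in seconds.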
bytes_r = x.numel() * bytes_per_el_bf16
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim0_mxfp8_triton_rceil":
y_d0, s_d0 = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)

for _ in range(2):
__ = triton_to_mxfp8_dim0(
x, inner_block_size=BLOCK_SIZE, scaling_mode="rceil"
)
time_us = benchmark_cuda_function_in_microseconds(
lambda x, b: triton_to_mxfp8_dim0(
x, inner_block_size=BLOCK_SIZE, scaling_mode="rceil"
),
x,
BLOCK_SIZE,
)
43 changes: 32 additions & 11 deletions test/prototype/mx_formats/test_kernels.py
@@ -45,7 +45,6 @@
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import (
is_cuda_version_at_least,
is_sm_at_least_89,
is_sm_at_least_100,
torch_version_at_least,
)
@@ -423,15 +422,19 @@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):


def triton_to_mxfp8_dim0_reference(
x_hp: torch.Tensor, block_size
x_hp: torch.Tensor,
block_size,
scaling_mode=ScaleCalculationMode.FLOOR,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
A reference version of `triton_to_mxfp8_dim0` for rowwise quantization.
"""
from torchao.prototype.mx_formats.mx_tensor import to_mx

# cast across dim0 (rowwise) - no transpose needed
scale_e8m0_dim0, x_hp_d0_normalized = to_mx(x_hp, torch.float8_e4m3fn, block_size)
scale_e8m0_dim0, x_hp_d0_normalized = to_mx(
x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
)
scale_e8m0_dim0 = scale_e8m0_dim0.view(torch.float8_e8m0fnu)
return (
x_hp_d0_normalized,
@@ -441,8 +444,8 @@ def triton_to_mxfp8_dim0_reference(

@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
not is_sm_at_least_89(),
reason="float8 in triton requires CUDA capability 8.9 or greater",
not is_sm_at_least_100(),
reason="mxfp8 in triton requires CUDA capability 10.0 or greater",
)
@pytest.mark.parametrize("M", (128, 256))
@pytest.mark.parametrize("K", (128, 256))
@@ -461,10 +464,19 @@ def test_triton_mxfp8_dim1_randn(M, K):
)
@pytest.mark.parametrize("M", (128, 256))
@pytest.mark.parametrize("K", (128, 256))
def test_triton_mxfp8_dim0_randn(M, K):
@pytest.mark.parametrize(
"scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_randn(M, K, scaling_mode):
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
x, block_size=32, scaling_mode=scaling_mode
)
x_mx_t, x_s_t = triton_to_mxfp8_dim0(
x,
inner_block_size=32,
scaling_mode=scaling_mode.value.lower(),
)
torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)

@@ -474,10 +486,19 @@ def test_triton_mxfp8_dim0_randn(M, K):
not is_sm_at_least_100(),
reason="mxfp8 requires CUDA capability 10.0 or greater",
)
def test_triton_mxfp8_dim0_zeros():
@pytest.mark.parametrize(
"scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_zeros(scaling_mode):
x = torch.zeros(128, 256, dtype=torch.bfloat16, device="cuda")
x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
x, block_size=32, scaling_mode=scaling_mode
)
x_mx_t, x_s_t = triton_to_mxfp8_dim0(
x,
inner_block_size=32,
scaling_mode=scaling_mode.value.lower(),
)
assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
20 changes: 16 additions & 4 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -121,8 +121,13 @@ def test_linear_eager_vs_hp(
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")

if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
if scale_calculation_mode != ScaleCalculationMode.FLOOR:
pytest.skip("unsupported configuration")
if scale_calculation_mode not in (
ScaleCalculationMode.FLOOR,
ScaleCalculationMode.RCEIL,
):
pytest.skip("triton mxfp8 quantization kernels only require sm100")
if not is_sm_at_least_100():
pytest.skip("triton mxfp8 quantization kernels require sm100")
elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
if scale_calculation_mode not in (
ScaleCalculationMode.FLOOR,
@@ -316,8 +321,15 @@ def test_linear_compile(
pytest.skip("unsupported configuration")

if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
if scale_calculation_mode != ScaleCalculationMode.FLOOR:
pytest.skip("unsupported configuration")
if scale_calculation_mode not in (
ScaleCalculationMode.FLOOR,
ScaleCalculationMode.RCEIL,
):
pytest.skip(
"triton mxfp8 quantization kernels only support FLOOR and RCEIL scaling modes"
)
if not is_sm_at_least_100():
pytest.skip("triton mxfp8 quantization kernels require sm100")
elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
if scale_calculation_mode not in (
ScaleCalculationMode.FLOOR,
80 changes: 59 additions & 21 deletions torchao/prototype/mx_formats/kernels.py
@@ -166,7 +166,7 @@ def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor:
from torch.library import triton_op, wrap_triton

@triton.jit
def _triton_calculate_scale(x, axis):
def _triton_calculate_scale(x, axis, SCALING_MODE: tl.constexpr):
# There is no good support for accessing globals from a jit'ed triton
# function, so we redefine them here. Since this is prototype code which
# we plan to remove after torch.compile catches up, this is fine.
@@ -179,23 +179,48 @@ def _triton_calculate_scale(x, axis):
# Find the maximum absolute value for each row
max_abs = tl.max(x, axis=axis)

# Calculate the e8m0 scale by extracting the exponent (floor)
# TODO(future PR): support other exponent extraction types (ceil, RNE)
max_abs = max_abs.to(tl.bfloat16)
max_abs_int16 = max_abs.to(tl.int16, bitcast=True)
extracted_pow2 = ((max_abs_int16 >> bf16_mbits) & 0b11111111) - bf16_exp_bias
extracted_pow2 = extracted_pow2 - target_max_pow2
scale_e8m0_unbiased = extracted_pow2.to(tl.bfloat16)

# Clamp to exponents that can be represented in e8m0
# Add 1 to capture NaNs
scale_e8m0_unbiased = tl.clamp(
scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias + 1
)
# Compute e8m0 biased scale using either RCEIL or FLOOR rounding.
if SCALING_MODE == "rceil":
# RCEIL scaling mode using PTX instruction supported on sm100.
# The input should be: amax / 448.0
# where 448.0 is the max representable value in FP8 E4M3 format.
F8E4M3_MAX_RCP: tl.constexpr = 1.0 / 448.0
scale_input = max_abs.to(tl.float32) * F8E4M3_MAX_RCP

# The PTX instruction outputs a packed uint16 where:
# - high byte = E8M0 of first input (0.0 in our case)
# - low byte = E8M0 of second input (scale_input)
# Casting uint16 to uint8 naturally truncates to the low byte.
scale_e8m0_biased = tl.inline_asm_elementwise(
asm="cvt.rp.satfinite.ue8m0x2.f32 $0, 0.0, $1;",
constraints="=h,r",
args=[scale_input.to(tl.float32, bitcast=False)],
dtype=tl.uint16,
is_pure=True,
pack=1,
).to(tl.uint8)
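# Illustrative numbers (not part of the kernel logic): for a block whose amax
# is 500.0, scale_input is 500/448 ≈ 1.116 and the round-toward-plus-infinity
# conversion yields a scale of 2.0 (biased e8m0 byte 128), so 500/2.0 = 250
# stays representable in e4m3; floor scaling would keep the scale at 1.0
# (assuming target_max_pow2 == 8 for e4m3), and 500/1.0 would saturate to 448.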
else:
tl.static_assert(SCALING_MODE == "floor")

# Original floor implementation
# Calculate the e8m0 scale by extracting the exponent (floor)
max_abs = max_abs.to(tl.bfloat16)
max_abs_int16 = max_abs.to(tl.int16, bitcast=True)
extracted_pow2 = (
(max_abs_int16 >> bf16_mbits) & 0b11111111
) - bf16_exp_bias
extracted_pow2 = extracted_pow2 - target_max_pow2
scale_e8m0_unbiased = extracted_pow2.to(tl.bfloat16)

# Clamp to exponents that can be represented in e8m0
# Add 1 to capture NaNs
scale_e8m0_unbiased = tl.clamp(
scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias + 1
)

# Create the biased e8m0 representation and cast it to 8 bits
scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)
# Create the biased e8m0 representation and cast it to 8 bits
scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)

# TODO(future PR): add NaN handling here,
# https://github.com/pytorch/pytorch/pull/100572 will likely be useful to
@@ -248,6 +273,7 @@ def to_mxfp8_dim1_kernel(
ROW_TILE_SIZE: tl.constexpr,
COL_TILE_SIZE: tl.constexpr,
INNER_BLOCK_SIZE: tl.constexpr, # should be 32 for MX
SCALING_MODE: tl.constexpr,
):
"""
Example tiling for n_rows==8, n_cols=8, ROW_TILE_SIZE=4, COL_TILE_SIZE=4, INNER_BLOCK_SIZE=2,
@@ -334,7 +360,11 @@ def to_mxfp8_dim1_kernel(

# Find the maximum absolute value for each column
# shape: (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE,)
col_scale_r, col_scale_e8m0_r = _triton_calculate_scale(x_block_abs_t_r, axis=1)
col_scale_r, col_scale_e8m0_r = _triton_calculate_scale(
x_block_abs_t_r,
axis=1,
SCALING_MODE=SCALING_MODE,
)

# Divide each column by scale
# Broadcasting col_scale to match x_block's shape
@@ -397,6 +427,7 @@ def to_mxfp8_dim0_kernel(
ROW_TILE_SIZE: tl.constexpr,
COL_TILE_SIZE: tl.constexpr,
SCALE_BLOCK_SIZE: tl.constexpr, # should be 32 for MX
SCALING_MODE: tl.constexpr,
):
"""
Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).
@@ -432,7 +463,9 @@ def to_mxfp8_dim0_kernel(

# Find the maximum absolute value for each row (across columns)
# shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
scale_fp32_r, scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)
scale_fp32_r, scale_e8m0_r = _triton_calculate_scale(
x_block_abs_r, axis=1, SCALING_MODE=SCALING_MODE
)

# Divide each row by scale
# Broadcasting scale to match x_block's shape
@@ -468,12 +501,15 @@ def to_mxfp8_dim0_kernel(

@triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
def triton_to_mxfp8_dim0(
x: torch.Tensor, inner_block_size: int = 32
x: torch.Tensor,
inner_block_size: int = 32,
scaling_mode: str = "rceil",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Input:
* `x` - input tensor, in row major memory layout
* `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
* `scaling_mode` - e8m0 scale rounding mode, either "floor" or "rceil" ("rceil" uses a PTX instruction supported on sm100)

Output:
* `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
@@ -518,6 +554,7 @@ def triton_to_mxfp8_dim0(
n_rows=n_rows,
n_cols=n_cols,
SCALE_BLOCK_SIZE=inner_block_size,
SCALING_MODE=scaling_mode,
)

# Reshape output back to original shape
@@ -531,7 +568,7 @@

@triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
def triton_to_mxfp8_dim1(
x: torch.Tensor, inner_block_size: int = 32
x: torch.Tensor, inner_block_size: int = 32, scaling_mode: str = "rceil"
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Input:
@@ -583,6 +620,7 @@ def triton_to_mxfp8_dim1(
n_rows=n_rows,
n_cols=n_cols,
INNER_BLOCK_SIZE=inner_block_size,
SCALING_MODE=scaling_mode,
)

return (
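A minimal usage sketch of the new `scaling_mode` argument (assumptions: an sm100 GPU, triton installed so the custom op is registered, and arbitrary example shapes):

import torch
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

# Rowwise (dim0) mxfp8 cast with rceil scale rounding; pass scaling_mode="floor"
# for the original floor behavior.
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
data_fp8, scale_e8m0 = triton_to_mxfp8_dim0(x, inner_block_size=32, scaling_mode="rceil")
assert data_fp8.dtype == torch.float8_e4m3fn
assert scale_e8m0.dtype == torch.float8_e8m0fnu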