Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2182,6 +2182,31 @@ def test_qat_nvfp4_training(self, use_per_tensor_scale: bool):
self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
self.assertFalse(torch.equal(new_weight, prev_weight))

@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_nvfp4_fake_quanitzed_linear_mixed_precision(self):
"""
Test `NVFP4FakeQuantizedLinear` with bf16 input activations and fp32 weights.
"""
from torchao.prototype.qat.nvfp4 import (
NVFP4FakeQuantizeConfig,
NVFP4FakeQuantizedLinear,
)

activation_dtype = torch.bfloat16
weight_dtype = torch.float32
linear = torch.nn.Linear(128, 512, dtype=weight_dtype).cuda()
activation_config = NVFP4FakeQuantizeConfig(use_per_tensor_scale=True)
weight_config = NVFP4FakeQuantizeConfig(use_per_tensor_scale=True)
linear = NVFP4FakeQuantizedLinear.from_linear(
linear, activation_config, weight_config
)
x = torch.randn(1, 128, dtype=activation_dtype).cuda()
out = linear(x)
self.assertEqual(linear.weight.dtype, weight_dtype)
self.assertEqual(x.dtype, activation_dtype)
self.assertEqual(out.dtype, activation_dtype)

@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@unittest.skipIf(
not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
Expand Down
4 changes: 2 additions & 2 deletions torchao/prototype/qat/nvfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def forward(
per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
else:
per_tensor_scale = None
input_dtype = _input.dtype
_input = NVFP4Tensor.to_nvfp4(
_input,
per_tensor_scale=per_tensor_scale,
Expand Down Expand Up @@ -84,7 +85,7 @@ def forward(
weight.t(),
None, # aten_op, not used
bias,
)
).to(input_dtype)

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -156,7 +157,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
fq = _NVFP4QuantizedForwardFakeQuantizedBackward.apply(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the cast be inside _NVFP4QuantizedForwardFakeQuantizedBackward?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved it in

x, self.weight, self.bias, self.activation_config, self.weight_config
)
assert fq.dtype == x.dtype
if batch_size is not None:
return fq.view(batch_size, -1, fq.shape[-1])
else:
Expand Down
Loading