
Conversation

@andrewor14 (Contributor)

**Summary:** This commit adds support for bf16 activations + fp32 weights mixed precision for NVFP4 QAT, which previously threw a dtype assertion error:
```
File "ao/torchao/prototype/qat/nvfp4.py", line 159, in forward
  assert fq.dtype == x.dtype
```

**Test Plan:**
```
python test/quantization/test_qat.py -k test_nvfp4_fake_quanitzed_linear_mixed_precision
```
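
For context, here is a minimal plain-PyTorch sketch (not the torchao API; layer sizes and variable names are made up) of the mixed-precision combination this PR enables, and of why a dtype check like the one above can trip:

```
import torch
import torch.nn as nn

# Minimal sketch (plain PyTorch, not the torchao API):
# bf16 activations + fp32 weights, the combination this PR enables for NVFP4 QAT.
linear = nn.Linear(64, 32)                    # weight and bias stay in fp32
x = torch.randn(8, 64, dtype=torch.bfloat16)  # bf16 activations

# Running the matmul in the weight dtype promotes the result to fp32,
# which is what tripped the `assert fq.dtype == x.dtype` style check.
out = torch.addmm(linear.bias, x.to(torch.float32), linear.weight.t())
print(out.dtype)  # torch.float32

# Casting back to the activation dtype restores the expected contract.
out = out.to(x.dtype)
assert out.dtype == x.dtype
```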

pytorch-bot bot commented Dec 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3501

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 4a5913b with merge base f3342a0:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label on Dec 17, 2025
@andrewor14 added the topic: bug fix label on Dec 17, 2025
@andrewor14 force-pushed the fix-nvfp4-mixed-precision branch from 38c977e to 7826b54 on December 17, 2025 at 16:38
```
    x = x.view(-1, x.shape[-1])
else:
    batch_size = None
fq = _NVFP4QuantizedForwardFakeQuantizedBackward.apply(
```
Contributor:

should the cast be inside _NVFP4QuantizedForwardFakeQuantizedBackward?

Contributor Author (@andrewor14):
moved it in
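
For readers outside the diff, a rough sketch of what "moving the cast inside" the autograd function could look like. The actual NVFP4 quantization and `_addmm_nvfp4_dispatch` call are elided and a plain matmul stands in for them, so only the dtype handling reflects the discussion here:

```
import torch

# Rough sketch of "moving the cast inside" (assumed structure; the real NVFP4
# quantization + addmm dispatch is elided and a plain matmul stands in for it).
class _NVFP4QuantizedForwardFakeQuantizedBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, _input, weight):
        ctx.save_for_backward(_input, weight)
        out = torch.mm(_input.to(weight.dtype), weight.t())  # compute in the weight dtype
        return out.to(_input.dtype)  # cast back to the activation dtype inside the op

    @staticmethod
    def backward(ctx, grad_output):
        _input, weight = ctx.saved_tensors
        grad_input = torch.mm(grad_output.to(weight.dtype), weight).to(_input.dtype)
        grad_weight = torch.mm(grad_output.t().to(weight.dtype), _input.to(weight.dtype))
        return grad_input, grad_weight
```

With the cast inside, every caller of `.apply(...)` gets an output matching the activation dtype, so the `assert fq.dtype == x.dtype` check holds for bf16 activations + fp32 weights.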

@andrewor14 force-pushed the fix-nvfp4-mixed-precision branch from 7826b54 to 4a5913b on December 17, 2025 at 18:27
@andrewor14 requested a review from vkuzo on December 17, 2025 at 18:27

```
ctx.save_for_backward(_input, weight)

return _addmm_nvfp4_dispatch(
```
Contributor:

why is this returning the wrong dtype?

Contributor Author (@andrewor14):
Seems to be caused by adding the bias:

```
result = result + bias
```

Before this line `result` was bf16; after this line it's fp32. Do you think we should cast the bias here instead?
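
For reference, a minimal standalone repro of that type promotion (shapes are made up); adding an fp32 bias to a bf16 tensor promotes the result to fp32, while casting the bias first keeps it in bf16:

```
import torch

# Minimal repro of the promotion described above:
result = torch.randn(8, 32, dtype=torch.bfloat16)  # bf16 matmul output
bias = torch.randn(32, dtype=torch.float32)        # fp32 bias

print((result + bias).dtype)                   # torch.float32 -- bf16 + fp32 promotes to fp32
print((result + bias.to(result.dtype)).dtype)  # torch.bfloat16 -- cast the bias instead
```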

