
Conversation

danielvegamyhre (Contributor) commented Dec 17, 2025

Stacked PRs:

  • [mxfp8 moe training] integrate new cuda kernel for blocked layout for groups along K

Important notes:

  • Writing scales to the blocked format is much more efficient when they are in row-major format. Currently, the CUDA kernel for mxfp8 dim1 quantization writes the scales in column-major order to avoid uncoalesced writes to GMEM.
  • However, I did repeated microbenchmarking on various devices and found no performance regression when writing the scales in row-major order instead (see benchmarks below).
  • In the quantization kernel there are 64 threads per thread block, so each block computes only two scalar, one-byte E8M0 scale factors. Row-major vs. column-major therefore just means 2 GMEM transactions instead of 1 for writing those 2 bytes, which is negligible compared to the 64 bytes of quantized data the block writes to global memory (see the sketch after this list).
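
To make the transaction-count point concrete, here is a minimal CUDA sketch of just the scale-write addressing. It is not the actual torchao kernel: the launch geometry (one 2x32 tile per 64-thread block on a (K/32, M/2) grid), the (M, K/32) scale shape, and all names are assumptions chosen to match the description above, and the quantization math and data writes are elided.

```cuda
// Minimal sketch of the scale-write addressing only; NOT the torchao kernel.
// Assumptions for illustration: a row-major (M, K) input, 32-element scaling
// groups along dim1, one 2x32 tile per 64-thread block launched on a
// (K/32, M/2) grid, and scales of logical shape (M, K/32).
#include <cstdint>

__global__ void mxfp8_dim1_scale_write_sketch(
    uint8_t* __restrict__ q_out,       // fp8 output (writes elided here)
    uint8_t* __restrict__ scales_out,  // one-byte E8M0 scales, (M, K/32)
    int M, int K, bool scales_row_major) {
  const int tid = threadIdx.x;                 // 0..63
  const int g   = blockIdx.x;                  // which group of 32 along K
  const int r   = blockIdx.y * 2 + (tid & 1);  // which of the tile's 2 rows
  // const int k = g * 32 + (tid >> 1);        // this thread's column in the group

  // ... load the 2x32 tile, reduce an amax per 32-element row group, derive
  //     the two E8M0 scales, quantize each element, and write the block's 64
  //     quantized bytes to q_out (that 64-byte write dominates the traffic) ...
  const uint8_t e8m0 = 0;  // placeholder for a computed E8M0 scale byte

  // Only threads 0 and 1 (same warp) write a scale byte; the layout choice
  // only changes the address math for these 2 bytes.
  if (tid < 2) {
    if (scales_row_major) {
      // (M, K/32) row-major: the block's 2 bytes are K/32 apart in GMEM,
      // so the warp needs 2 write transactions for them.
      scales_out[(long)r * (K / 32) + g] = e8m0;
    } else {
      // (M, K/32) column-major (current layout): the 2 bytes are adjacent,
      // so the warp needs only 1 write transaction.
      scales_out[(long)g * M + r] = e8m0;
    }
  }
}
```

Either way the block's scale traffic is 2 bytes next to its 64 bytes of quantized data, which is why the row-major layout's extra transaction does not show up in the benchmarks below, while row-major scales are cheaper to convert to the blocked format.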

CUDA 2d tensor dim1 mxfp8 quantization kernel benchmarks for writing scales to col major vs row major

Common setup: M=16384, K=16384, BLOCK_SIZE=32, GPU 7 (NVIDIA B200), torch 2.11.0.dev20251216+cu128, triton 3.6.0

| mode                  | scale layout | time_us | mem_bw_gbps |
|-----------------------|--------------|---------|-------------|
| dim1_mxfp8_cuda_floor | col major    | 150.496 | 5406.75     |
| dim1_mxfp8_cuda_floor | row major    | 150.688 | 5399.87     |
| dim1_mxfp8_cuda_rceil | col major    | 156.608 | 5195.74     |
| dim1_mxfp8_cuda_rceil | row major    | 155.776 | 5223.49     |
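
For reference, the reported mem_bw_gbps figures follow from time_us under a simple traffic model. The byte accounting below is an assumption (one bf16 read of the input plus one fp8 byte written per element and one E8M0 scale byte per 32 elements), but it reproduces the numbers in the table.

```cuda
// Host-side sanity check: derive mem_bw_gbps from time_us, assuming the
// benchmark counts a bf16 input read plus fp8 output and E8M0 scale writes.
#include <cstdio>

int main() {
  const long M = 16384, K = 16384, BLOCK_SIZE = 32;
  const double bytes_moved = 2.0 * M * K               // read bf16 input
                           + 1.0 * M * K               // write fp8 output
                           + 1.0 * M * K / BLOCK_SIZE; // write one-byte E8M0 scales
  auto mem_bw_gbps = [&](double time_us) {
    return bytes_moved / (time_us * 1e-6) / 1e9;
  };
  std::printf("%.2f\n", mem_bw_gbps(150.496));  // ~5406.75 (floor, col major)
  std::printf("%.2f\n", mem_bw_gbps(150.688));  // ~5399.87 (floor, row major)
  return 0;
}
```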

pytorch-bot (bot) commented Dec 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3505

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b57a58e with merge base 1f9bfd7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

… groups along K

stack-info: PR: #3505, branch: danielvegamyhre/stack/87
meta-cla bot added the CLA Signed label Dec 17, 2025
danielvegamyhre force-pushed the danielvegamyhre/stack/87 branch from 04d00bc to b57a58e on December 17, 2025 at 23:53
danielvegamyhre force-pushed the danielvegamyhre/stack/86 branch from 26e079e to 0dfef18 on December 17, 2025 at 23:53
danielvegamyhre added the mx, moe, and topic: not user facing labels Dec 17, 2025