Conversation

@danielvegamyhre (Contributor) commented Dec 17, 2025

Stacked PRs:


[mxfp8 moe training] add CUDA kernel for per-group conversion of scale factors to blocked layout
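For context on what "per-group conversion to blocked layout" means here: below is a rough pure-Python sketch, assuming the 128x4 tiled ("swizzled") scale-factor layout that cuBLAS-style block-scaled matmuls expect, with each expert group's rows padded and rearranged independently. The function names (`to_blocked`, `per_group_to_blocked`) and the exact interleaving are illustrative, not taken from this PR's kernel.

```python
def ceil_div(a, b):
    return -(-a // b)

def to_blocked(scales):
    """Rearrange a row-major scale matrix (list of rows) into 128x4 tiles,
    each tile emitted as 32 contiguous groups of 16 scales.
    Illustrative layout only -- the real kernel's swizzle may differ."""
    rows, cols = len(scales), len(scales[0])
    rb, cb = ceil_div(rows, 128), ceil_div(cols, 4)
    get = lambda r, c: scales[r][c] if r < rows and c < cols else 0  # zero-pad
    out = []
    for br in range(rb):
        for bc in range(cb):
            # Within a tile, output group r holds cols 0-3 of source rows
            # r, r+32, r+64, r+96 (the 32x16 interleaving).
            for r in range(32):
                for g in range(4):
                    for c in range(4):
                        out.append(get(br * 128 + g * 32 + r, bc * 4 + c))
    return out

def per_group_to_blocked(scales, group_sizes):
    """Convert each expert's row group independently, so each group gets
    its own row padding to a multiple of 128 (hypothetical grouping API)."""
    out, start = [], 0
    for n in group_sizes:
        out.extend(to_blocked(scales[start:start + n]))
        start += n
    return out
```

The per-group loop is the part this PR moves into a single CUDA kernel; a naive torch implementation pays per-group launch and reshape overhead, which is consistent with the flat ~300 us torch baseline in the benchmarks below.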

Benchmarks

kernel_version             scale_shape      time_us    mem_bw_gbps  fastest_version            speedup_vs_torch
-------------------------  -------------  ---------  -------------  -------------------------  ------------------
torch                      (8192, 512)       304.08         28.449  cuda_pipelined_64_chunks4
triton                     (8192, 512)        20.48        422.4    cuda_pipelined_64_chunks4  14.85x
cuda_pipelined_64_chunks4  (8192, 512)        15.36        563.2    cuda_pipelined_64_chunks4  19.80x
torch                      (8192, 1024)      306.37         55.617  cuda_pipelined_64_chunks4
triton                     (8192, 1024)       27.49        619.884  cuda_pipelined_64_chunks4  11.15x
cuda_pipelined_64_chunks4  (8192, 1024)       19.33        881.589  cuda_pipelined_64_chunks4  15.85x
torch                      (8192, 2048)      322.18        104.963  cuda_pipelined_64_chunks4
triton                     (8192, 2048)       60.7         557.073  cuda_pipelined_64_chunks4  5.31x
cuda_pipelined_64_chunks4  (8192, 2048)       23.74       1424.22   cuda_pipelined_64_chunks4  13.57x
torch                      (8192, 4096)      383.12        175.848  cuda_pipelined_64_chunks4
triton                     (8192, 4096)      125.73        535.847  cuda_pipelined_64_chunks4  3.05x
cuda_pipelined_64_chunks4  (8192, 4096)       33.79       1993.7    cuda_pipelined_64_chunks4  11.34x
torch                      (8192, 32768)    2000.42        268.511  cuda_pipelined_64_chunks4
triton                     (8192, 32768)    1047.55        512.751  cuda_pipelined_64_chunks4  1.91x
cuda_pipelined_64_chunks4  (8192, 32768)     177.22       3030.95   cuda_pipelined_64_chunks4  11.29x
torch                      (5120, 512)       301.2          17.951  cuda_pipelined_64_chunks4
triton                     (5120, 512)        20.42        264.828  cuda_pipelined_64_chunks4  14.75x
cuda_pipelined_64_chunks4  (5120, 512)        13.28        407.133  cuda_pipelined_64_chunks4  22.68x
torch                      (5120, 1024)      303.36         35.105  cuda_pipelined_64_chunks4
triton                     (5120, 1024)       27.65        385.185  cuda_pipelined_64_chunks4  10.97x
cuda_pipelined_64_chunks4  (5120, 1024)       15.39        691.892  cuda_pipelined_64_chunks4  19.71x
torch                      (5120, 2048)      307.17         68.807  cuda_pipelined_64_chunks4
triton                     (5120, 2048)       38.88        543.605  cuda_pipelined_64_chunks4  7.90x
cuda_pipelined_64_chunks4  (5120, 2048)       19.49       1084.53   cuda_pipelined_64_chunks4  15.76x
torch                      (5120, 4096)      316.48        133.048  cuda_pipelined_64_chunks4
triton                     (5120, 4096)       62.46        674.098  cuda_pipelined_64_chunks4  5.07x
cuda_pipelined_64_chunks4  (5120, 4096)       25.63       1642.75   cuda_pipelined_64_chunks4  12.35x
torch                      (5120, 32768)    1258.21        266.815  cuda_pipelined_64_chunks4
triton                     (5120, 32768)     759.01        442.299  cuda_pipelined_64_chunks4  1.66x
cuda_pipelined_64_chunks4  (5120, 32768)     117.89       2847.69   cuda_pipelined_64_chunks4  10.67x
torch                      (7168, 512)       301.54         25.103  cuda_pipelined_64_chunks4
triton                     (7168, 512)        19.68        384.624  cuda_pipelined_64_chunks4  15.32x
cuda_pipelined_64_chunks4  (7168, 512)        15.36        492.8    cuda_pipelined_64_chunks4  19.63x
torch                      (7168, 1024)      305.25         48.844  cuda_pipelined_64_chunks4
triton                     (7168, 1024)       33.82        440.795  cuda_pipelined_64_chunks4  9.02x
cuda_pipelined_64_chunks4  (7168, 1024)       17.44        854.899  cuda_pipelined_64_chunks4  17.50x
torch                      (7168, 2048)      306.43         96.561  cuda_pipelined_64_chunks4
triton                     (7168, 2048)       62.5         473.462  cuda_pipelined_64_chunks4  4.90x
cuda_pipelined_64_chunks4  (7168, 2048)       21.63       1367.86   cuda_pipelined_64_chunks4  14.17x
torch                      (7168, 4096)      366.5         160.847  cuda_pipelined_64_chunks4
triton                     (7168, 4096)       72.74        810.46   cuda_pipelined_64_chunks4  5.04x
cuda_pipelined_64_chunks4  (7168, 4096)       31.74       1857.03   cuda_pipelined_64_chunks4  11.55x
torch                      (7168, 32768)    1770.08        265.52   cuda_pipelined_64_chunks4
triton                     (7168, 32768)     867.3         541.904  cuda_pipelined_64_chunks4  2.04x
cuda_pipelined_64_chunks4  (7168, 32768)     158.75       2960.54   cuda_pipelined_64_chunks4  11.15x
torch                      (2048, 512)       300.61          7.194  cuda_pipelined_64_chunks4
triton                     (2048, 512)        19.65        110.072  cuda_pipelined_64_chunks4  15.30x
cuda_pipelined_64_chunks4  (2048, 512)        11.33        190.915  cuda_pipelined_64_chunks4  26.54x
torch                      (2048, 1024)      302.69         14.073  cuda_pipelined_64_chunks4
triton                     (2048, 1024)       19.87        214.364  cuda_pipelined_64_chunks4  15.23x
cuda_pipelined_64_chunks4  (2048, 1024)       14.4         295.822  cuda_pipelined_64_chunks4  21.02x
torch                      (2048, 2048)      304.26         27.786  cuda_pipelined_64_chunks4
triton                     (2048, 2048)       25.6         330.24   cuda_pipelined_64_chunks4  11.88x
cuda_pipelined_64_chunks4  (2048, 2048)       15.36        550.4    cuda_pipelined_64_chunks4  19.81x
torch                      (2048, 4096)      308.32         54.628  cuda_pipelined_64_chunks4
triton                     (2048, 4096)       44.06        382.234  cuda_pipelined_64_chunks4  7.00x
cuda_pipelined_64_chunks4  (2048, 4096)       17.44        965.754  cuda_pipelined_64_chunks4  17.68x
torch                      (2048, 32768)     596.8         225.005  cuda_pipelined_64_chunks4
triton                     (2048, 32768)     318.5         421.617  cuda_pipelined_64_chunks4  1.87x
cuda_pipelined_64_chunks4  (2048, 32768)      54.27       2474.26   cuda_pipelined_64_chunks4  11.00x
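As a sanity check on how to read the table: the speedup_vs_torch column is the torch row's time_us divided by the given kernel's time_us for the same scale_shape. Reproducing it for the (8192, 512) rows:

```python
# time_us for scale_shape (8192, 512), copied from the table above
time_us = {
    "torch": 304.08,
    "triton": 20.48,
    "cuda_pipelined_64_chunks4": 15.36,
}

for name in ("triton", "cuda_pipelined_64_chunks4"):
    speedup = time_us["torch"] / time_us[name]
    print(f"{name}: {speedup:.2f}x")  # 14.85x and 19.80x, matching the table
```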

@pytorch-bot

pytorch-bot bot commented Dec 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3504

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 0dfef18 with merge base 1f9bfd7:

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label Dec 17, 2025

stack-info: PR: #3504, branch: danielvegamyhre/stack/86

Labels

CLA Signed · moe · mx · topic: not user facing
