Conversation

@danielvegamyhre (Contributor) commented Dec 18, 2025

Summary

  • Add a wgrad_with_hp option to optionally compute the weight gradient (wgrad) in high precision rather than in mxfp8 (a minimal sketch of the two paths follows this list).
  • Currently, the 2d2d mxfp8 grouped GEMM in FBGEMM that we dispatch to for the wgrad computation provides only modest speedups (usually 1.2x to 1.4x, sometimes as low as 0.9x, and generally higher for larger shapes). Once dynamic quantization overhead is added in, this results in a net slowdown for smaller shapes and only a modest speedup for larger ones.
  • We are working with NVIDIA to improve this CUTLASS kernel; we've already tried several things ourselves and they haven't moved the needle much. In the meantime it makes sense to add an option to compute wgrad in bf16.
  • This is also useful for small shapes in general, where even at roofline performance bf16 is strictly better than dynamic quantization plus an mxfp8 grouped GEMM.
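
A minimal sketch of the two wgrad paths described above (illustrative only: quantize_to_mxfp8 and mxfp8_grouped_mm_2d2d are hypothetical stand-ins for the real torchao/FBGEMM kernels, and the high-precision path is written as a per-group loop rather than a real grouped GEMM kernel):

import torch

def grouped_wgrad(
    grad_out: torch.Tensor,   # (M, N) output grads, groups concatenated along M
    x: torch.Tensor,          # (M, K) input activations, groups concatenated along M
    offs: torch.Tensor,       # (G,) end offset of each group along M
    wgrad_with_hp: bool = False,
) -> torch.Tensor:            # (G, N, K): grad_weight[g] = grad_out[g].T @ x[g]
    if wgrad_with_hp:
        # High-precision path: plain bf16 per-group matmuls, no dynamic quantization.
        grads, start = [], 0
        for end in offs.tolist():
            grads.append(grad_out[start:end].t() @ x[start:end])
            start = end
        return torch.stack(grads)
    # mxfp8 path: dynamically quantize both operands with blockwise scales, then
    # dispatch to the FBGEMM 2d2d mxfp8 grouped GEMM. The quantization overhead
    # is what makes this a net loss for small shapes.
    go_mx, go_scales = quantize_to_mxfp8(grad_out)    # hypothetical helper
    x_mx, x_scales = quantize_to_mxfp8(x)             # hypothetical helper
    return mxfp8_grouped_mm_2d2d(go_mx, go_scales, x_mx, x_scales, offs)  # hypothetical helper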

Tests

  • pytest test/prototype/moe_training/test_scaled_grouped_mm.py -k dq

Benchmarks

Benchmarks show the impact of wgrad_with_hp:

  • It helps for the small DeepSeek V3 16b model (shapes like (128000, 1536, 5120, ...)).
  • It is roughly neutral for the larger DSV3 671b model (shapes like (128000, 2048, 7168, ...)).
  • It hurts for Llama4 and for the (128000, 8192, 5120, ...) shapes.

The benchmarks are a bit jittery, but this general trend holds over a few runs (a rough cost model is sketched below).
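
A rough cost model for why this trend appears (all numbers below are illustrative assumptions, not measurements: bf16_tflops, hbm_tbps, and the ~3 bytes of quantization traffic per input element are placeholders; mxfp8_speedup reflects the ~1.2-1.4x kernel-only speedup mentioned above):

def wgrad_estimate_us(m, n, k,
                      bf16_tflops=1000.0,   # assumed dense bf16 GEMM throughput
                      mxfp8_speedup=1.3,    # assumed kernel-only mxfp8 speedup
                      hbm_tbps=5.0):        # assumed HBM bandwidth
    # GEMM time for grad_weight = grad_out.T @ x over all groups.
    flops = 2 * m * n * k
    bf16_us = flops / (bf16_tflops * 1e12) * 1e6
    mxfp8_gemm_us = bf16_us / mxfp8_speedup
    # Dynamic quantization is bandwidth bound: read the bf16 operands, write mxfp8
    # data plus scales, roughly 3 bytes of traffic per input element.
    quant_bytes = 3 * (m * n + m * k)
    quant_us = quant_bytes / (hbm_tbps * 1e12) * 1e6
    return bf16_us, mxfp8_gemm_us + quant_us

# Small-N DSV3-16b-like wgrad shape: quantization overhead erases the GEMM savings.
print(wgrad_estimate_us(128000, 1536, 5120))
# Large-N shape: the GEMM dominates, so the mxfp8 path still comes out ahead.
print(wgrad_estimate_us(128000, 8192, 5120))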

(Look at the scaled_fwd_bwd_speedup column in the middle of each table.)

wgrad_with_hp=False

ype/moe_training/benchmark_scaled_grouped_mm_dq.py --compile
M,N,K,G                  recipe                  bf16_fwd_bwd_us    scaled_fwd_bwd_us  scaled_fwd_bwd_speedup      bf16_fwd_us    scaled_fwd_us  scaled_fwd_speedup
-----------------------  --------------------  -----------------  -------------------  ------------------------  -------------  ---------------  --------------------
(16384, 8192, 5120, 1)   MoEScalingType.MXFP8           3883.07              2831.36   1.371x                         1273.92           954.272  1.335x
(16384, 8192, 5120, 2)   MoEScalingType.MXFP8           4818.91              3183.71   1.514x                         1144.06          1036.29   1.104x
(16384, 8192, 5120, 4)   MoEScalingType.MXFP8           4140.19              3480.67   1.189x                         1163.46           998.768  1.165x
(16384, 8192, 5120, 8)   MoEScalingType.MXFP8           5234.19              3808.16   1.374x                         1068.16          1156.06   0.924x
(128000, 8192, 5120, 1)  MoEScalingType.MXFP8          43548.8              22567.4    1.93x                         14536.8           6346.72   2.29x
(128000, 8192, 5120, 2)  MoEScalingType.MXFP8          43921                23675.5    1.855x                        12642.4           6854.66   1.844x
(128000, 8192, 5120, 4)  MoEScalingType.MXFP8          43187.8              23669.7    1.825x                        13818.4           6421.5    2.152x
(128000, 8192, 5120, 8)  MoEScalingType.MXFP8          41643.6              23096.5    1.803x                        13787.6           6652.96   2.072x
(16384, 1536, 5120, 1)   MoEScalingType.MXFP8            966.72              1008.58   0.958x                          278.56           298.976  0.932x
(16384, 1536, 5120, 2)   MoEScalingType.MXFP8            911.808              951.296  0.958x                          302.112          289.632  1.043x
(16384, 1536, 5120, 4)   MoEScalingType.MXFP8            902.144              968.752  0.931x                          238.72           306.08   0.78x
(16384, 1536, 5120, 8)   MoEScalingType.MXFP8            943.68              1008.67   0.936x                          255.008          310.24   0.822x
(128000, 1536, 5120, 1)  MoEScalingType.MXFP8           8568.98              7410.78   1.156x                         2668.58          2213.06   1.206x
(128000, 1536, 5120, 2)  MoEScalingType.MXFP8           8409.31              6915.33   1.216x                         2691.66          2009.09   1.34x
(128000, 1536, 5120, 4)  MoEScalingType.MXFP8           7865.89              6862.88   1.146x                         2353.09          2037.78   1.155x
(128000, 1536, 5120, 8)  MoEScalingType.MXFP8           8096.9               6400.58   1.265x                         2228.22          1825.79   1.22x
(16384, 2048, 7168, 1)   MoEScalingType.MXFP8           1436.56              1544.26   0.93x                           545.76           456.784  1.195x
(16384, 2048, 7168, 2)   MoEScalingType.MXFP8           1678.46              1447.07   1.16x                           426.016          411.648  1.035x
(16384, 2048, 7168, 4)   MoEScalingType.MXFP8           1686.53              1525.76   1.105x                          488.576          438.192  1.115x
(16384, 2048, 7168, 8)   MoEScalingType.MXFP8           1884.06              1712.26   1.1x                            512.992          526.416  0.974x
(128000, 2048, 7168, 1)  MoEScalingType.MXFP8          15365.7              11003.8    1.396x                         3579.17          3377.6    1.06x
(128000, 2048, 7168, 2)  MoEScalingType.MXFP8          17184.2              10709.9    1.605x                         6145.95          3037.81   2.023x
(128000, 2048, 7168, 4)  MoEScalingType.MXFP8          15260.2              10167      1.501x                         5191.71          3020.51   1.719x
(128000, 2048, 7168, 8)  MoEScalingType.MXFP8          14557.2               9919.68   1.468x                         4398.27          2965.47   1.483x

wgrad_with_hp=True

M,N,K,G                  recipe                  bf16_fwd_bwd_us    scaled_fwd_bwd_us  scaled_fwd_bwd_speedup      bf16_fwd_us    scaled_fwd_us  scaled_fwd_speedup
-----------------------  --------------------  -----------------  -------------------  ------------------------  -------------  ---------------  --------------------
(16384, 8192, 5120, 1)   MoEScalingType.MXFP8           4006.91              3372.54   1.188x                         1247.76           844.832  1.477x
(16384, 8192, 5120, 2)   MoEScalingType.MXFP8           3781.63              3090.4    1.224x                         1491.04           806.88   1.848x
(16384, 8192, 5120, 4)   MoEScalingType.MXFP8           3561.42              4049.89   0.879x                         1064.96           985.952  1.08x
(16384, 8192, 5120, 8)   MoEScalingType.MXFP8           4780.06              4003.46   1.194x                         1274.86          1153.44   1.105x
(128000, 8192, 5120, 1)  MoEScalingType.MXFP8          42967.1              25230.3    1.703x                        15023.2           6571.97   2.286x
(128000, 8192, 5120, 2)  MoEScalingType.MXFP8          43644.4              29649      1.472x                        14359.6           7152.7    2.008x
(128000, 8192, 5120, 4)  MoEScalingType.MXFP8          44191.3              28635.1    1.543x                        14257.3           6545.41   2.178x
(128000, 8192, 5120, 8)  MoEScalingType.MXFP8          42325.1              27963.2    1.514x                        13444.6           6637.5    2.026x
(16384, 1536, 5120, 1)   MoEScalingType.MXFP8            836.64               887.808  0.942x                          291.968          295.968  0.986x
(16384, 1536, 5120, 2)   MoEScalingType.MXFP8            849.408              879.68   0.966x                          246.832          279.424  0.883x
(16384, 1536, 5120, 4)   MoEScalingType.MXFP8            861.152              938.832  0.917x                          254.784          291.872  0.873x
(16384, 1536, 5120, 8)   MoEScalingType.MXFP8            902.112             1016.83   0.887x                          230.464          310.304  0.743x
(128000, 1536, 5120, 1)  MoEScalingType.MXFP8          10010.1               6709.92   1.492x                         2259.06          2226.14   1.015x
(128000, 1536, 5120, 2)  MoEScalingType.MXFP8           8102.96              6354.88   1.275x                         2340.88          2000.78   1.17x
(128000, 1536, 5120, 4)  MoEScalingType.MXFP8           9914.32              6446.08   1.538x                         1855.52          2029.6    0.914x
(128000, 1536, 5120, 8)  MoEScalingType.MXFP8           8258.4               5998.5    1.377x                         2272.26          1833.06   1.24x
(16384, 2048, 7168, 1)   MoEScalingType.MXFP8           1664.96              1499.82   1.11x                           417.024          449.632  0.927x
(16384, 2048, 7168, 2)   MoEScalingType.MXFP8           1617.92              1450.43   1.115x                          490.56           447.328  1.097x
(16384, 2048, 7168, 4)   MoEScalingType.MXFP8           1524.78              1492.82   1.021x                          503.744          440.224  1.144x
(16384, 2048, 7168, 8)   MoEScalingType.MXFP8           1499.94              1610.56   0.931x                          402.432          529.376  0.76x
(128000, 2048, 7168, 1)  MoEScalingType.MXFP8          17734.1              10969.1    1.617x                         4615.28          3345.44   1.38x
(128000, 2048, 7168, 2)  MoEScalingType.MXFP8          15514                10832.9    1.432x                         4462.58          3158.05   1.413x
(128000, 2048, 7168, 4)  MoEScalingType.MXFP8          12506.1              10927.7    1.144x                         4913.15          3071.5    1.6x
(128000, 2048, 7168, 8)  MoEScalingType.MXFP8          14571.5               9601.42   1.518x                         4506.74          2990.61   1.507x

@pytorch-bot (bot) commented Dec 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3508

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 53ea9fc with merge base 1f9bfd7:

BROKEN TRUNK - The following job failed but was already present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label on Dec 18, 2025
@danielvegamyhre added the mx, moe, and topic: new feature labels and removed the CLA Signed label on Dec 18, 2025
@meta-cla bot re-added the CLA Signed label on Dec 18, 2025

A Contributor commented on the new parameter in the updated signature:

out_dtype: Optional[torch.dtype] = torch.bfloat16,
emulated: bool = False,
use_triton_for_dim0_cast: bool = False,
wgrad_with_hp: bool = False,

we already use naming with_gw_hp, do you want to match it?

ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp"
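
A hypothetical illustration of the suggested rename, matching the existing *_with_gw_hp spelling (the function name and full parameter list below are placeholders, not the actual torchao signature):

from typing import Optional

import torch

def _scaled_grouped_mm_sketch(  # placeholder name
    A: torch.Tensor,
    B_t: torch.Tensor,
    offs: Optional[torch.Tensor] = None,
    out_dtype: Optional[torch.dtype] = torch.bfloat16,
    emulated: bool = False,
    use_triton_for_dim0_cast: bool = False,
    with_gw_hp: bool = False,  # renamed from wgrad_with_hp to match ROWWISE_WITH_GW_HP
) -> torch.Tensor:
    ...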
