Tags: aris134/TransformerEngine
Release v2.2 cherrypicks and bugfixes megatron lm (ROCm#362)

* Ensure weight transpose is valid for FP8 training (#1596) (ROCm#276)
* Update usage of weightmat before saving for backward
* Added keep_fp8_weight_transpose_cache checks while updating transpose in fwd pass (ROCm#298)
* Added keep_fp8_weight_transpose_cache checks while updating transpose
* Added unittest for the fix
* Added comment for the unit test
* Fixed comment
* Reverted test for single iteration, added assert statements to check for transpose cache, Modified docstring
* Fixed test_numerics spacing
* Added HIP Guards
* Addressed PR Comments, and moved assertion statements under fp8 check
* Reverting assertion to fix the dev ticket
* Removed spacing

---------

Co-authored-by: Sudharshan Govindan <sugovind@amd.com>

* Bug fix for get_fp8_metas
* Added keep_fp8_transpose_cache fix for base.py
* added _fp8_metas check for None
* Added comment

---------

Co-authored-by: Sudharshan Govindan <sugovind@amd.com>
[CI] deprecate praxis installation and tests

- Removed praxis installation and related test setup from `ci/jax.sh`
- Installed `flax>=0.7.1`, with `typing_extensions>=4.12.2`