Tags: j316chuck/TransformerEngine
Replace deprecated sharding API in JAX test (NVIDIA#332) Replace deprecated sharding API Signed-off-by: Tim Moon <tmoon@nvidia.com> Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
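The commit does not say which deprecated sharding API was replaced. As a minimal sketch, assuming a recent JAX release, the current `jax.sharding` API (`Mesh`, `NamedSharding`, `PartitionSpec`) places an array on a device mesh as shown below. The helper `shard_rows` and the mesh axis name `"data"` are hypothetical illustrations, not code from the commit.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_rows(x: jax.Array) -> jax.Array:
    """Hypothetical helper: split the leading axis of `x` across all devices."""
    # Build a 1-D mesh over every available device and name its axis "data".
    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
    # Shard dimension 0 over the "data" mesh axis; replicate dimension 1.
    sharding = NamedSharding(mesh, PartitionSpec("data", None))
    return jax.device_put(x, sharding)


x = jnp.arange(32, dtype=jnp.float32).reshape(8, 4)
y = shard_rows(x)
print(y.sharding)
```

The leading dimension must be divisible by the number of devices, so 8 rows works on 1, 2, 4, or 8 devices.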
Fix layer_norm ONNX export (NVIDIA#293)
* Fix ONNX export of layer_norm. ONNX has a spec bug: ConstantOfShape supports all dtypes except BF16. As a workaround (WAR), we use dtype FP32 and then cast to BF16. We will also open a PR with the ONNX SIG to change the spec in opset 20. Signed-off-by: Neta Zmora <nzmora@nvidia.com>
* Fix lint. Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
---------
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
JAX bug fixes for dot-product attention (NVIDIA#236)
* Use unfused scale+softmax if bias is present. Signed-off-by: Reese Wang <rewang@nvidia.com>
* Work around (WAR) a causal masking + no_bias bug and add unit tests. Signed-off-by: Reese Wang <rewang@nvidia.com>
* Fix the sharding of the optional args (bias). Signed-off-by: Reese Wang <rewang@nvidia.com>
* Disable fused attention in JAX by default; enable it with NVTE_USE_FUSED_ATTN. Signed-off-by: Reese Wang <rewang@nvidia.com>
* Make the plan cache thread-local. Signed-off-by: Reese Wang <rewang@nvidia.com>
* Rename dbeta to dbias for readability. Signed-off-by: Reese Wang <rewang@nvidia.com>
* Add scaled softmax with dropout test cases. Signed-off-by: Reese Wang <rewang@nvidia.com>
* Update the NVTE_FUSED_ATTN variable name. Signed-off-by: Reese Wang <rewang@nvidia.com>
---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
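When a bias is present, this change falls back to an unfused scale+softmax. The JAX sketch below shows what such an unfused path computes, assuming logits shaped `[..., q_len, kv_len]`; the function `unfused_scaled_softmax` and its arguments are hypothetical and do not reflect Transformer Engine's actual API.

```python
import jax
import jax.numpy as jnp


def unfused_scaled_softmax(logits, scale, bias=None, causal=False):
    """Hypothetical reference path: scale, add bias, mask, then softmax as separate ops."""
    x = logits * scale
    if bias is not None:
        # The bias broadcasts against the [..., q_len, kv_len] logits.
        x = x + bias
    if causal:
        q_len, kv_len = x.shape[-2], x.shape[-1]
        # Query i may attend only to keys j <= i; masked entries get the dtype minimum.
        mask = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))
        x = jnp.where(mask, x, jnp.finfo(x.dtype).min)
    return jax.nn.softmax(x, axis=-1)
```

Because each step is a separate op, the bias add stays visible to autodiff, at the cost of materializing the full `[q_len, kv_len]` score matrix.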
Shriya/tp overlap patch (NVIDIA#205) userbuffer pushsend/recv fix with atomicAdd_system Signed-off-by: Sangkug Lym <slym@nvidia.com> Co-authored-by: Sangkug Lym <slym@nvidia.com>
Re-add support for PyTorch version 1.x (NVIDIA#180) Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Update installation instructions for JAX and add some dependencies (NVIDIA#117)
* Update installation instructions for JAX and add some dependencies. Signed-off-by: Frederic Bastien <fbastien@nvidia.com>
* Bring back support for pybind11 that is not pip-installed. Signed-off-by: Frederic Bastien <fbastien@nvidia.com>
* Apply suggestions from code review. Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by: Frédéric Bastien <frederic.bastien@gmail.com>
* Changes following review. Signed-off-by: Frederic Bastien <fbastien@nvidia.com>
* Change order to make it clearer. Signed-off-by: Frederic Bastien <fbastien@nvidia.com>
* Add other reviewers' suggestions. Signed-off-by: Frederic Bastien <fbastien@nvidia.com>
* pybind11 is needed for all frameworks (FW). Signed-off-by: Frederic Bastien <fbastien@nvidia.com>
* Add flax as a dependency. Signed-off-by: Frederic Bastien <fbastien@nvidia.com>
* Update README.rst. Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by: Frédéric Bastien <frederic.bastien@gmail.com>
---------
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>
Signed-off-by: Frédéric Bastien <frederic.bastien@gmail.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>