Tags: j316chuck/TransformerEngine

v0.11

Replace deprecated sharding API in JAX test (NVIDIA#332)

Replace deprecated sharding API

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

v0.10

Fix layer_norm ONNX export (NVIDIA#293)

* Fix ONNX export of layer_norm

ONNX has a spec bug: ConstantOfShape supports all dtypes except for BF16.
To work around this, we use dtype FP32 and then cast to BF16.

Will also issue a PR to the ONNX sig committee to change the spec in opset 20.
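
The workaround described above (build the constant in FP32, then cast to BF16) can be illustrated with a small standard-library sketch. This is not the actual export code from the commit: the helper names `fp32_to_bf16_bits`, `bf16_bits_to_float`, and `constant_of_shape_bf16` are hypothetical, BF16 is emulated with bit manipulation, and NaN inputs are not handled.

```python
import struct


def fp32_to_bf16_bits(value: float) -> int:
    """Round an FP32 value to BF16 (round-to-nearest-even); return the 16-bit pattern."""
    (bits,) = struct.unpack("<I", struct.pack("<f", value))
    # BF16 keeps the upper 16 bits of the FP32 pattern. Adding this bias
    # before truncating rounds to nearest, with ties going to even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_float(bits16: int) -> float:
    """Expand a BF16 bit pattern back into a Python float."""
    (value,) = struct.unpack("<f", struct.pack("<I", bits16 << 16))
    return value


def constant_of_shape_bf16(shape, fill_value):
    """Hypothetical stand-in for the two-step workaround (flat list output)."""
    count = 1
    for dim in shape:
        count *= dim
    # Step 1: create the constant in FP32, a dtype the operator accepts.
    fp32_value = struct.unpack("<f", struct.pack("<f", fill_value))[0]
    fp32_values = [fp32_value] * count
    # Step 2: cast every element from FP32 to BF16.
    return [bf16_bits_to_float(fp32_to_bf16_bits(v)) for v in fp32_values]
```

Values that BF16 represents exactly, such as 1.0, survive the cast unchanged; others, such as 0.1, are rounded to the nearest BF16 value.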

Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* fix lint

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Neta Zmora <nzmora@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

v0.9

Jax bug fixes for the dot product attention (NVIDIA#236)

* Unfused scale+softmax if bias is present

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Work around a causal masking + no_bias bug and add unit tests

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix the optional args (bias) sharding

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Disable fused attn in JAX by default, enable it with NVTE_USE_FUSED_ATTN

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add thread local for the plan cache

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename dbeta to dbias for readability

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add scaled softmax with dropout test cases

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Updated NVTE_FUSED_ATTN variable name

Signed-off-by: Reese Wang <rewang@nvidia.com>

---------

Signed-off-by: Reese Wang <rewang@nvidia.com>
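
The opt-in behaviour described in the v0.9 message (fused attention off by default, enabled through an environment variable) might look like the sketch below. The function name `fused_attn_enabled` is hypothetical, and two things are assumptions rather than facts from the commit: that `NVTE_FUSED_ATTN` is the final variable name after the rename, and that a non-zero integer value enables the feature.

```python
import os


def fused_attn_enabled(environ=None, var_name="NVTE_FUSED_ATTN"):
    """Return True only if the variable is set to a non-zero integer.

    Assumption: the variable name and the "non-zero integer enables"
    convention are illustrative and not confirmed by the commit message.
    """
    env = os.environ if environ is None else environ
    raw = env.get(var_name, "0")  # unset means disabled (the default)
    try:
        return int(raw) != 0
    except ValueError:
        # Unparseable values are treated as "disabled" in this sketch.
        return False
```

Passing a plain dictionary as `environ` makes the check easy to exercise without touching the real process environment, for example `fused_attn_enabled({"NVTE_FUSED_ATTN": "1"})`.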

v0.8.0.1

Shriya/tp overlap patch (NVIDIA#205)

userbuffer pushsend/recv fix with atomicAdd_system

Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Sangkug Lym <slym@nvidia.com>

v0.8

Re-add support for PyTorch version 1.x (NVIDIA#180)

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

v0.8rc

Re-add support for PyTorch version 1.x (NVIDIA#180)

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

v0.7

Update installation instruction for JAX and add some dependencies. (NVIDIA#117)

* Update installation instruction for JAX and add some dependencies.

Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Bring back support for non-pip-installed pybind11.

Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Apply suggestions from code review

Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Frédéric Bastien <frederic.bastien@gmail.com>

* Changes following review.

Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Change order to make it more clear.

Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Add other reviewers' suggestions.

Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* pybind11 is needed for all FW.

Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Add flax as a dep

Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Update README.rst

Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Frédéric Bastien <frederic.bastien@gmail.com>

---------

Signed-off-by: Frederic Bastien <fbastien@nvidia.com>
Signed-off-by: Frédéric Bastien <frederic.bastien@gmail.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

v0.6

fix bug in non-FP8 nvfuser path (NVIDIA#81)

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

v0.5

Address steady memory increase and bloated checkpoints (NVIDIA#63)

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

v0.4

Change version to 0.4