⭐️ TTNN Compiler for PyTorch 2 ⭐️ Enables running PyTorch models on Tenstorrent hardware using eager or compile path

PyTorch 2.0 TTNN Compiler

The PyTorch 2.0 TT-NN Compiler enables seamless execution of PyTorch models on Tenstorrent AI accelerators. By leveraging the TT-NN backend, you can achieve significant performance improvements while maintaining PyTorch's familiar API.

🚀 Quick Start

Installation

Install from the repo:

pip install git+https://github.com/tenstorrent/pytorch2.0_ttnn.git

or as an editable package from source:

git clone https://github.com/tenstorrent/pytorch2.0_ttnn.git
cd pytorch2.0_ttnn
pip install -e .

✨ Basic Usage

Option 1: Eager Mode. Get your model running by moving it to a TT device:

import torch
import ttnn
import torch_ttnn

model = YourModel()

# Open the Tenstorrent device and expose it to PyTorch as a torch device
device = ttnn.open_device(device_id=0)
model.to(torch_ttnn.ttnn_device_as_torch_device(device))

output = model(input_data)

Option 2: Compilation Mode (Recommended). Get higher performance with the JIT compiler:

import torch
import ttnn
import torch_ttnn

model = YourModel()

# Open a 1x2 mesh of devices and split the batch across both (data parallelism)
device = ttnn.open_mesh_device(ttnn.MeshShape(1, 2))
option = torch_ttnn.TorchTtnnOption(device=device, data_parallel=2)

model = torch.compile(model, backend=torch_ttnn.backend, options=option)
output = model(input_data)
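Whichever path you use, it is worth sanity-checking the compiled model's outputs against the original eager outputs. One comparison metric commonly used in Tenstorrent tooling is the Pearson correlation coefficient (PCC); this repo's exact accuracy metric may differ, so the following is only an illustrative plain-Python sketch:

```python
import math

def pcc(a, b):
    """Pearson correlation coefficient between two equal-length
    sequences of floats (e.g. flattened model outputs)."""
    n = len(a)
    mean_a = sum(a) / n
    mean_b = sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    norm_a = math.sqrt(sum((x - mean_a) ** 2 for x in a))
    norm_b = math.sqrt(sum((y - mean_b) ** 2 for y in b))
    return cov / (norm_a * norm_b)

# Identical outputs correlate (near-)perfectly; a PCC close to 1.0
# indicates the compiled model matches the eager reference.
print(pcc([0.1, 0.5, 0.9], [0.1, 0.5, 0.9]))
```

In practice you would flatten both output tensors and require the PCC to exceed a threshold (e.g. 0.99) before trusting the compiled path.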

📊 Model Support

We've extensively tested the compiler across a diverse range of model architectures. Here's a summary of our validation results:

Model | Status | Batch | Compiled First Run (ms) | Original Throughput (Inferences Per Second) | Compiled Throughput (Inferences Per Second) | Accuracy (%) | Torch Ops Before (Unique Ops) | Torch Ops Remain (Unique Ops) | To/From Device Ops
Autoencoder (linear) 1 383.79 0.318679 537.6344086021505 100.0 22 (3) 0 (0) 0
BERT 8 48717.3 0.0107726 40.17476020690002 99.69 1465 (22) 0 (0) 0
DPR 1 18863.61 0.297168 69.06077348066299 99.38 720 (22) 0 (0) 1
HardNet 1 48348.55 0.197276 19.681165124975397 98.45 245 (10) 0 (0) 124
Mnist 1 6792.72 27.8552 348.4320557491289 99.42 14 (8) 0 (0) 1
MobileNetV2 1 109708.43 1.0701 32.414910858995135 99.09 154 (9) 0 (0) 0
OpenPose V2 1 19973.86 0.332744 35.58718861209964 91.49 155 (7) 0 (0) 6
Perceiver IO 1 56393.67 0.0200014 19.149751053236308 99.95 1531 (20) 0 (0) 1
ResNet18 1 18278.14 0.501771 77.33952049497293 99.15 70 (9) 0 (0) 1
ResNet50 4 81666.36 0.766981 46.99800258489014 98.61 176 (9) 0 (0) 1
RoBERTa 1 31621.3 0.0723718 21.96836555360281 28.56 719 (21) 0 (0) 3
U-Net 1 29665.08 0.0155244 57.17552887364209 100.0 68 (6) 0 (0) 12
Unet-brain 1 28840.79 0.0167 64.93506493506493 N/A 68 (6) 0 (0) 12
Unet-carvana 1 82362.64 0.0117354 30.902348578491967 99.69 67 (5) 0 (0) 12
albert/albert-base-v2 1 31540.21 0.525754 39.55696202531645 98.82 791 (21) 0 (0) 3
albert/albert-base-v2-classification 1 10191.74 0.716353 42.881646655231556 99.97 779 (21) 0 (0) 2
albert/albert-large-v2 1 24157.78 0.301436 19.778481012658226 98.95 1547 (21) 0 (0) 3
albert/albert-xlarge-v2 1 44847.71 0.105011 13.017443374121324 97.36 1547 (21) 0 (0) 3
densenet121 1 189079.76 0.313425 13.083867591259976 99.74 432 (10) 0 (0) 597
densenet161 1 198853.88 0.118535 9.37207122774133 99.49 572 (10) 0 (0) 1147
densenet169 1 68503.18 0.259137 9.194556822361161 99.58 600 (10) 0 (0) 1241
densenet201 1 293066.47 0.212029 7.531254707034192 99.39 712 (10) 0 (0) 1905
distilbert-base-uncased 1 27666.99 0.647647 85.1063829787234 72.37 361 (16) 0 (0) 1
dla34.in1k 1 99978.6 0.260171 40.84967320261438 99.48 135 (9) 0 (0) 23
ese_vovnet19b_dw.ra_in1k 1 73426.48 0.537458 40.95004095004095 99.44 111 (12) 0 (0) 19
ghostnet_100.in1k 1 175051.08 1.35422 16.857720836142953 99.6 515 (14) 0 (0) 64
mobilenet_v2 1 83486.0 1.12163 31.908104658583277 99.09 154 (9) 0 (0) 0
mobilenet_v3_large 1 156471.68 1.29134 28.465698832906345 99.15 188 (11) 0 (0) 0
mobilenet_v3_small 1 104853.87 2.04532 33.692722371967655 99.09 158 (11) 0 (0) 0
mobilenetv1_100.ra4_e3600_r224_in1k 1 80375.98 0.655566 57.50431282346176 96.04 85 (7) 0 (0) 0
regnet_x_16gf 1 66710.83 0.0558609 16.168148746968473 99.56 235 (8) 0 (0) 0
regnet_x_1_6gf 1 58812.93 0.527844 31.625553447185325 99.47 195 (8) 0 (0) 0
regnet_x_32gf 1 81282.16 0.0286101 7.94533608771651 99.27 245 (8) 0 (0) 0
regnet_x_3_2gf 1 45123.3 0.254426 22.482014388489212 99.5 265 (8) 0 (0) 0
regnet_x_400mf 1 57789.48 1.03357 25.04382669671926 99.66 235 (8) 0 (0) 0
regnet_x_800mf 1 80269.68 0.847364 34.1646737273659 99.44 175 (8) 0 (0) 0
regnet_x_8gf 1 92651.05 0.107167 18.21825469120058 98.99 245 (8) 0 (0) 0
regnet_y_16gf 1 72168.97 0.0639503 12.087513598452798 99.71 303 (10) 0 (0) 0
regnet_y_1_6gf 1 60978.04 0.505211 15.708451146716936 99.65 447 (10) 0 (0) 0
regnet_y_32gf 1 167374.18 0.0254866 7.950389569088886 99.72 335 (10) 0 (0) 0
regnet_y_3_2gf 1 130392.55 0.284977 19.5160031225605 99.82 351 (10) 0 (0) 0
regnet_y_400mf 1 116188.94 1.14254 26.56748140276302 99.64 271 (10) 0 (0) 0
regnet_y_800mf 1 73770.25 0.815727 29.824038174768862 99.59 239 (10) 0 (0) 0
regnet_y_8gf 1 89672.19 0.116776 18.52537977028529 99.82 287 (10) 0 (0) 0
resnet101 1 17604.23 0.123665 16.16553507921112 99.28 346 (9) 0 (0) 1
resnet152 1 26172.89 0.087908 11.715089034676664 99.14 516 (9) 0 (0) 1
resnet18 1 29809.15 0.426687 72.09805335255949 99.63 70 (9) 0 (0) 1
resnet34 1 6056.91 0.24183 44.34589800443459 98.9 126 (9) 0 (0) 1
resnet50 1 86577.72 0.216245 30.33980582524272 98.61 176 (9) 0 (0) 1
resnext101_32x8d 1 52154.16 0.0640595 8.844078889183692 99.57 346 (9) 0 (0) 1
resnext101_64x4d 1 33489.02 0.0675123 8.88888888888889 99.65 346 (9) 0 (0) 1
resnext50_32x4d 1 83109.06 0.222903 33.090668431502316 99.44 176 (9) 0 (0) 1
textattack/albert-base-v2-imdb 1 36573.89 0.634333 42.14075010535188 100.0 782 (22) 0 (0) 2
tf_efficientnet_lite0.in1k 1 138936.43 0.759653 25.687130747495505 99.3 149 (9) 0 (0) 5
tf_efficientnet_lite1.in1k 1 78078.43 0.758121 19.462826002335536 99.56 194 (9) 0 (0) 5
tf_efficientnet_lite2.in1k 1 121500.84 0.597083 13.053126223730585 99.21 194 (9) 0 (0) 5
twmkn9/albert-base-v2-squad2 1 22967.23 0.560758 44.78280340349306 99.86 783 (23) 0 (0) 2
vgg11 1 56549.18 0.0754884 95.41984732824427 99.65 33 (8) 0 (0) 5
vgg11_bn 1 6534.04 0.074688 87.10801393728222 98.93 41 (9) 0 (0) 5
vgg13 1 6427.48 0.0491899 81.16883116883116 99.35 37 (8) 0 (0) 5
vgg13_bn 1 66303.52 0.0496638 71.94244604316546 97.31 47 (9) 0 (0) 5
vgg16 1 2346.11 0.0368749 70.0770847932726 99.44 43 (8) 0 (0) 5
vgg16_bn 1 3806.76 0.0403229 62.853551225644246 98.37 56 (9) 0 (0) 5
vgg19 1 60381.36 0.0299954 61.61429451632779 99.24 49 (8) 0 (0) 5
vgg19_bn 1 7889.57 0.0294751 56.40157924421884 96.97 65 (9) 0 (0) 5
wide_resnet101_2 1 19504.24 0.044722 17.37619461337967 99.2 346 (9) 0 (0) 1
wide_resnet50_2 1 83259.66 0.0788305 32.1646831778707 98.8 176 (9) 0 (0) 1
xception71.tf_in1k 1 131613.9 0.0553388 3.919416790781532 99.21 393 (9) 0 (0) 0
Autoencoder (conv) 🚧 1 4425.42 0.971931 294.11764705882354 100.0 9 (3) 1 (1) 1
Autoencoder (conv)-train 🚧 1 16578.7 0.427418 149.47683109118086 100.0 24 (7) 11 (4) 0
Autoencoder (linear)-train 🚧 1 14231.39 0.572908 58.89281507656065 100.0 104 (8) 14 (2) 0
Bloom 🚧 1 45113.92 0.0331907 1.337613697164259 98.86 1405 (27) 2 (2) 0
CLIP 🚧 1 59147.29 0.165331 5.279552293965471 99.56 1397 (30) 7 (6) 2
CLIP-train 🚧 1 85043.23 0.0446743 0.6651589729945456 100.0 3944 (44) 265 (16) 5
DETR 🚧 1 154854.73 0.0101889 0.19921231450843363 94.02 1663 (42) 9 (6) 3
DINOv2 🚧 1 32461.38 0.0523196 15.211439002129602 98.99 928 (25) 16 (1) 2
GLPN-KITTI 🚧 1 282224.02 0.00821779 0.016979024113610047 99.77 2959 (26) 22 (2) 6
GPT-2 🚧 1 34057.19 0.375322 31.03662321539417 99.98 745 (29) 2 (2) 2
GaussianSplatting 🚧 1 45029.49 0.133893 0.0605810938522306 49.65 2193 (34) 179 (4) 20
GaussianSplatting-train 🚧 1 73816.52 0.0630212 0.01538980530049518 43.79 7443 (54) 1012 (15) 36
HardNet-train 🚧 1 169358.37 0.0745995 0.10767890581003071 100.0 867 (21) 412 (9) 120
Mnist-train 🚧 1 22987.76 0.3579 29.682398337785695 100.0 46 (15) 10 (6) 0
MobileNetSSD 🚧 1 237858.91 1.61264 0.4763787592238837 43.63 522 (31) 7 (4) 32
OpenPose V2-train 🚧 1 77986.54 0.0958387 0.1270118679889449 100.0 523 (14) 246 (7) 6
ResNet18-train 🚧 1 51117.14 0.169233 0.2388350581205114 100.0 241 (19) 121 (9) 0
ResNet50-train 🚧 1 91707.43 0.0614435 0.08364533041578422 100.0 616 (19) 318 (9) 0
SegFormer 🚧 1 28179.23 0.0252929 4.014774369680424 99.86 676 (22) 16 (1) 4
SegFormer-train 🚧 1 181331.16 0.0123532 0.03109938180648845 100.0 1794 (36) 156 (12) 4
U-Net-train 🚧 1 107817.25 0.00911335 0.021714343040151338 100.0 236 (15) 122 (8) 8
Unet-brain-train 🚧 1 113110.1 0.0094704 0.018023785629419434 100.0 236 (15) 122 (8) 8
Unet-carvana-train 🚧 1 137481.34 0.00584688 0.01143127443506356 100.0 232 (13) 121 (7) 8
YOLOS 🚧 1 44348.45 0.0678037 3.605422555523507 98.46 952 (27) 17 (2) 2
YOLOv3 🚧 1 78405.22 0.0049119 17.692852087756545 98.63 250 (7) 2 (1) 4
albert/albert-xxlarge-v2 🚧 1 19200.87 0.0519827 7.186489399928135 98.54 791 (21) 24 (1) 3
dla34.in1k-train 🚧 1 81457.49 0.100185 0.14373722535409664 100.0 469 (18) 230 (8) 17
ese_vovnet19b_dw.ra_in1k-train 🚧 1 79155.83 0.211436 0.30116037090911285 100.0 383 (25) 176 (10) 16
facebook/deit-base-patch16-224 🚧 1 26811.29 0.0660756 8.956560680698612 98.34 685 (17) 1 (1) 1
facebook/deit-base-patch16-224-train 🚧 1 35467.71 0.0135302 0.898319244693179 100.0 1854 (27) 127 (8) 2
ghostnet_100.in1k-train 🚧 1 209430.97 0.639758 0.5120458793107863 100.0 1469 (33) 562 (12) 64
ghostnetv2_100.in1k 🚧 1 253978.88 0.858752 8.872327211427558 99.65 683 (18) 24 (2) 68
ghostnetv2_100.in1k-train 🚧 1 88191.54 0.518111 0.24007298218658474 100.0 2001 (39) 852 (17) 68
googlenet 🚧 1 96560.6 0.516062 22.482014388489212 99.67 214 (15) 1 (1) 51
hrnet_w18.ms_aug_in1k 🚧 1 148429.34 0.183819 4.358437935843794 99.65 1209 (11) 31 (1) 0
hrnet_w18.ms_aug_in1k-train 🚧 1 188376.68 0.0712728 0.09518651321331584 100.0 3998 (21) 1973 (9) 0
inception_v4.tf_in1k 🚧 1 136626.27 0.0844907 6.105006105006105 99.09 495 (11) 14 (1) 84
inception_v4.tf_in1k-train 🚧 1 169217.47 0.0263513 0.03717647636152339 100.0 1851 (24) 932 (11) 80
mixer_b16_224.goog_in21k 🚧 1 17380.49 0.0860721 9.69838037047813 3.65 356 (11) 1 (1) 0
mixer_b16_224.goog_in21k-train 🚧 1 35730.01 0.0179507 0.8610075510362225 100.0 959 (18) 101 (6) 0
mobilenetv1_100.ra4_e3600_r224_in1k-train 🚧 1 70339.28 0.379328 0.36029933668892117 100.0 258 (16) 164 (7) 0
regnet_y_128gf 🚧 1 351571.29 0.00170038 0.016221957912454635 98.91 447 (10) 3 (1) 0
ssd300_vgg16 🚧 1 179326.71 0.255575 0.6767685654536718 N/A 332 (30) 8 (5) 37
ssdlite320_mobilenet_v3_large 🚧 1 195197.35 1.08724 0.47629480743400937 41.24 522 (31) 7 (4) 32
swin_b 🚧 1 96798.11 0.0690809 3.677822728944465 99.54 2492 (32) 110 (2) 479
swin_s 🚧 1 28686.82 0.119883 3.9404208369453855 99.68 2492 (32) 110 (2) 479
swin_t 🚧 1 134139.18 0.220782 7.700007700007699 99.76 1238 (32) 50 (2) 227
swin_v2_b 🚧 1 103813.77 0.0432882 2.7898672023211697 28.4 3140 (40) 158 (3) 473
swin_v2_s 🚧 1 35019.83 0.074537 3.1170126550713797 40.7 3140 (40) 158 (3) 473
swin_v2_t 🚧 1 137983.51 0.138262 5.311238580837052 51.81 1562 (40) 74 (3) 221
tf_efficientnet_lite0.in1k-train 🚧 1 129277.03 0.391011 0.13884299356601568 100.0 452 (17) 285 (8) 5
tf_efficientnet_lite1.in1k-train 🚧 1 87596.84 0.299663 0.09924435349250806 100.0 587 (17) 370 (8) 5
tf_efficientnet_lite2.in1k-train 🚧 1 150155.19 0.199099 0.07742670787832238 100.0 587 (17) 370 (8) 5
tf_efficientnet_lite3.in1k 🚧 1 142790.91 0.369576 3.7056251389609427 99.15 221 (9) 5 (1) 5
tf_efficientnet_lite3.in1k-train 🚧 1 117077.71 0.123944 0.05340014759800796 100.0 668 (17) 426 (9) 5
tf_efficientnet_lite4.in1k 🚧 1 152262.84 0.197188 2.2350364310938264 99.21 275 (9) 6 (1) 5
tf_efficientnet_lite4.in1k-train 🚧 1 167983.96 0.0553924 0.018335587082578902 100.0 830 (17) 529 (9) 5
vit_b_16 🚧 1 34037.78 0.0648821 6.6212010858769785 99.52 552 (17) 1 (1) 1
vit_b_32 🚧 1 27295.18 0.152805 6.953135864274788 98.73 552 (17) 1 (1) 1
vit_h_14 🚧 1 699214.93 0.00131198 0.3794432808183833 98.14 1452 (17) 1 (1) 1
vit_l_16 🚧 1 69256.68 0.0150174 3.2023569347039422 99.73 1092 (17) 1 (1) 1
vit_l_32 🚧 1 30916.9 0.0462639 4.388274530454626 99.06 1092 (17) 1 (1) 1
xception71.tf_in1k-train 🚧 1 177748.06 0.017083 0.017548428397849546 100.0 1378 (18) 806 (7) 0
FLAN-T5 N/A N/A 0.294142 N/A N/A 20020 (38) N/A N/A
Falcon-7B N/A N/A 0.0116651 N/A N/A 2600 (27) N/A N/A
GPTNeo N/A N/A 0.0837811 N/A N/A 2733 (35) N/A N/A
Llama N/A N/A 0.00549438 N/A N/A 3690 (35) N/A N/A
OPT N/A N/A 0.0391356 N/A N/A 4003 (32) N/A N/A
Stable Diffusion V2 N/A N/A 0.000518239 N/A N/A 1870 (29) N/A N/A
ViLT N/A N/A 0.0632051 N/A N/A 766 (29) N/A N/A
Whisper N/A N/A 0.00418221 N/A N/A 4310 (21) N/A N/A
YOLOv5 N/A N/A 0.0537337 N/A N/A 236 (13) N/A N/A
codegen N/A N/A 0.139108 N/A N/A 9183 (37) N/A N/A
speecht5-tts N/A N/A 0.01951 N/A N/A 6942 (40) N/A N/A
t5-base N/A N/A 0.208124 N/A N/A 14681 (38) N/A N/A
t5-large N/A N/A 0.0863555 N/A N/A 22696 (38) N/A N/A
t5-small N/A N/A 0.399949 N/A N/A 6118 (38) N/A N/A

Explanation of Metrics

  • Model: Name of the model.
  • Status: Indicates whether the model is:
      • ✅ End-to-end on device: all PyTorch operations have been converted to TT-NN operations.
      • 🚧 Compiled: the converted model runs, but some operations still fall back to PyTorch. This may be due to an unsupported operation or configuration.
      • ❌ Traced: the model does not run, but its PyTorch operations are traced for future development. This may indicate a temporary incompatibility with a compiler pass.
  • Batch: Batch size used for inference.
  • Compiled First Run (ms): Time until the first compiled run finishes, including compilation time and cache warming.
  • Original Throughput (Inferences Per Second): Execution throughput of the model before conversion.
  • Compiled Throughput (Inferences Per Second): Execution throughput of the model after conversion, once caches are warm.
  • Accuracy (%): Model accuracy on a predefined test dataset after conversion.
  • Torch Ops Before (Unique Ops): Total number of operations used by the model in the original Torch implementation; the number in parentheses is the count of unique ops.
  • Torch Ops Remain (Unique Ops): Total number of operations remaining after conversion to TT-NN; the number in parentheses is the count of unique ops.
  • To/From Device Ops: Number of to/from_device operations (data transfers to/from the device).
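The throughput and speedup figures in the table reduce to simple arithmetic: throughput is batch size divided by mean per-run latency, and speedup is the ratio of compiled to original throughput. A small illustrative sketch (the helper names here are made up for illustration, not part of this repo):

```python
def throughput(batch_size, latencies_s):
    """Inferences per second: batch size over mean per-run latency (seconds)."""
    mean_latency = sum(latencies_s) / len(latencies_s)
    return batch_size / mean_latency

def speedup(original_ips, compiled_ips):
    """How many times faster the compiled model runs than the original."""
    return compiled_ips / original_ips

# A batch of 8 taking 0.2 s per run yields 40 inferences per second:
print(throughput(8, [0.2, 0.2]))  # -> 40.0

# e.g. the BERT row above (0.0107726 vs 40.17 IPS) works out to
# a speedup of roughly 3700x.
```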

Contributing

Whether you are new to Tenstorrent hardware or an experienced developer, there are many ways to contribute.

Getting Started

Start with our high-level Contribution guide.

We encourage contributions and offer 🤑 Bounties for some issues.

Development Environment

To get started with development, you'll need a Wormhole or Blackhole Tenstorrent accelerator card.

Install the development dependencies:

pip install -r requirements-dev.txt
pip install -e .

You can build the wheel file with:

python -m build

Project Structure

  • torch_ttnn/: Main package directory containing the core implementation
  • tests/: Test files for the project including model suites. We use pytest as our testing framework.
  • tools/: Development and utility scripts
  • docs/: Project documentation and reports
  • demo/: Example code and usage demonstrations

Questions and Support

If you have questions or need help getting started, please:

  1. Review the existing documentation
  2. Ask the PyTorch TT-NN DeepWiki or TT-Metal DeepWiki
  3. Ask on Discord
  4. Open an issue on GitHub
