Commit 0decda6

Use Torch ExportedModule to import initial MLIR module (#416)
Use a `torch.export.ExportedProgram` to generate the initial MLIR module. This requires us to create an `ExportedProgram` from the initial `GraphModule`.

Benefits:
- We can use torch-mlir's official entrypoint
  - This handles in-place ops for us
- We can run decompositions and **keep** location data
  - This location data will stick around throughout the compile process

Issues:
- `aten.clamp` is decomposed by torch-mlir to `maximum(minimum(input, max), min)`. `ttnn.maximum` requires that the operand which needs to be broadcasted is on the RHS. Currently, in tt-mlir the `PartiallyBroadcastable` op trait only enforces that the broadcasted operand is on the LHS
  - tt-torch issue: #431
  - tt-mlir issue: tenstorrent/tt-mlir#2458
- Graph parameters are inlined as constants in the graph. To have the `FxImporter` treat them as graph inputs we need to edit the `ExportedModule`'s `ExportedGraphSignature` and force all "parameter" types to "user inputs"
  - This is a hack, as the `ExportedGraphSignature` is meant to be a private member of `ExportedProgram`
  - Ideally we could configure the `FxImporter` to _not_ inline the parameters by passing a flag of some sort. Perhaps a future contribution to torch-mlir.

Other Info:
- We need to upgrade to PyTorch 2.6.0 as it contains crucial changes which allow us to use custom decompositions (necessary to support interpolation)
- AdaptiveAvgPool2d is lowered to AvgPool2d and eventually to `stablehlo.reduce_window` **even in the case where the op is equivalent to a global average**. Since we do not have support for lowering a sum_pool in `StablehloToTTIRPatterns.cpp` (sum because the division is afterward), I've temporarily added a custom decomposition of `aten.avg_pool2d` which converts it to a mean over the spatial dimensions when the `avg_pool2d` is equivalent to one.
- `aten.split` is no longer lowered to a series of `narrow` ops. Instead it is now lowered to a series of `as_strided` ops.
  - `narrow` is lowered to `slice`, which can be lowered to `stablehlo.slice`. `as_strided` cannot be lowered from Torch Backend IR to StableHLO. I've temporarily added back the old decomposition from PyTorch 2.5.0, which uses `narrow`, as a custom decomposition.
  - I've made a PR which adds a lowering of `AtenAsStridedOp` to `stablehlo::SliceOp` in our fork of torch-mlir: tenstorrent/llvm-torch-mlir#4
- The tracer which generates the `GraphModule` that is passed to `backend` does not account for control flow. I believe that in PyTorch 2.5.0 a graph break would be triggered during `.generate` methods in `transformers` LLMs. It no longer is, so `.generate` will run until the max length is reached.
  - **this means that the entire generation becomes one program**
  - Once the first EOS token is generated, the rest of the length is filled with padding. We cannot compare the golden output to the result from the `GraphModule` as the output shapes are different.
  - Since the outputs of `.generate` graphs are integers, PCC/atol verification is not very useful, but it does return `True` when the outputs are _identical_
  - The tokenizer can decode the outputs and strip padding.
  - I've added a flag to `ModelTester` that informs the `ModelTester` it is testing a `.generate` call. It will decode the output tokens and we compare the resulting strings.
  - PyTorch has an experimental `torch.cond` which they seem to intend to use to trace data-dependent control flow. There's a note in the `transformers` source that says they intend to use it when it is no longer experimental
- When the graph is compiled, the user inputs are placed **at the end** of the arguments passed to the program rather than the front. That is, graph constants first, then inputs.
- I needed to implement an `FxImporter` hook for importing literals to the graph. By default it will make all non-scalars `DenseElementsResourceAttr`s; however, this causes the process to hang upon cleanup whether the test fails or not. So the hook just uses `DenseElementsAttr` for all literals.
  - Someone has mentioned this problem in an IREE issue as well: iree-org/iree#20102
    - They've traced it down to this PR in llvm that adds a GIL acquire when destroying the `DenseElementsResourceAttr`: llvm/llvm-project#124832
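Three of the equivalences mentioned in the commit message can be checked directly in eager PyTorch. The sketch below is only an illustration under stated assumptions: the helper names (`clamp_via_min_max`, `global_avg_pool_as_mean`, `split_via_narrow`) are hypothetical and are not the decompositions registered in tt-torch or torch-mlir. They restate the described behaviour with public `torch` calls.

```python
import torch
import torch.nn.functional as F


def clamp_via_min_max(x, lo, hi):
    # Restates the decomposition quoted above: maximum(minimum(input, max), min).
    # The scalar bounds are 0-dim tensors, so they are the broadcasted operands.
    hi_t = torch.tensor(hi, dtype=x.dtype)
    lo_t = torch.tensor(lo, dtype=x.dtype)
    return torch.maximum(torch.minimum(x, hi_t), lo_t)


def global_avg_pool_as_mean(x):
    # When the pooling window covers the whole spatial extent,
    # avg_pool2d reduces to a mean over the H and W dimensions.
    return x.mean(dim=(-2, -1), keepdim=True)


def split_via_narrow(x, split_size, dim=0):
    # Hypothetical re-statement of splitting with narrow ops;
    # not copied from the PyTorch 2.5.0 decomposition.
    length = x.shape[dim]
    return [
        torch.narrow(x, dim, start, min(split_size, length - start))
        for start in range(0, length, split_size)
    ]


torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)

clamp_ok = torch.equal(
    clamp_via_min_max(x, -0.5, 0.5), torch.clamp(x, min=-0.5, max=0.5)
)
pool_ok = torch.allclose(
    global_avg_pool_as_mean(x), F.avg_pool2d(x, kernel_size=(8, 8)), atol=1e-6
)
narrow_chunks = split_via_narrow(x, 3, dim=2)
split_chunks = torch.split(x, 3, dim=2)
split_ok = len(narrow_chunks) == len(split_chunks) and all(
    torch.equal(a, b) for a, b in zip(narrow_chunks, split_chunks)
)
```

Each `*_ok` flag compares the hand-written form against the stock PyTorch op on the same input, which is roughly what a custom decomposition has to preserve.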
1 parent c79248a commit 0decda6

30 files changed: +540 -316 lines changed

docs/src/controlling.md

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ You can use the following environment variables to override default behaviour:
 | TT_TORCH_VERIFY_INTERMEDIATES | Sets whether to verify runtime intermediates during execution. | False |
 | TT_TORCH_CONSTEVAL | Enables evaluation of constant expressions (consteval) in the Torch FX graph prior to compilation. | False |
 | TT_TORCH_CONSTEVAL_PARAMETERS | Extends consteval to include parameters (e.g., model weights) as well as embedded constants. | False |
-| TT_TORCH_EMBEDDEDD_CONSTANTS | Remove embedded constants from the Torch FX graph and convert them to constant inputs | False |
+| TT_TORCH_INLINE_PARAMETERS | Inlines parameters in the MLIR module (and thus flatbuffer executable) rather than requiring them as inputs. NOTE: The maximum size of a flatbuffer is 2GB so this will cause compilation to fail for sufficiently large models | False |
 | TT_TORCH_IR_LOG_LEVEL | Enables printing MLIR from Torch to TTNN. It supports two modes; `INFO` and `DEBUG`. `INFO` prints MLIR for all conversions steps (Torch, StableHLO, TTIR and TTNN MLIR graphs). `DEBUG` prints intermediate MLIR for all passes (IR dump before and after each pass) additionally. Be warned, `DEBUG` IR printing forces single core compile, so it is much slower. | Disable |

 ### Controlling Compiler Behaviour Programatically
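The 2GB note on `TT_TORCH_INLINE_PARAMETERS` suggests a quick size estimate before enabling it. The following is a rough sketch, not part of tt-torch: the helper names, the dtype table, and the reading of "2GB" as 2 * 1024**3 bytes are all assumptions, and the estimate ignores everything in the flatbuffer other than raw parameter data.

```python
from math import prod

# Assumed interpretation of the "2GB" flatbuffer ceiling mentioned in the table.
FLATBUFFER_LIMIT_BYTES = 2 * 1024**3

# Bytes per element for a few common dtypes (illustrative subset only).
DTYPE_BYTES = {"float32": 4, "bfloat16": 2, "float16": 2, "int32": 4, "int8": 1}


def parameter_bytes(shapes_and_dtypes):
    """Sum the raw byte size of parameters given as (shape, dtype_name) pairs."""
    return sum(prod(shape) * DTYPE_BYTES[dtype] for shape, dtype in shapes_and_dtypes)


def can_inline_parameters(shapes_and_dtypes, limit=FLATBUFFER_LIMIT_BYTES):
    """Lower-bound check: raw parameter bytes alone must fit under the limit."""
    return parameter_bytes(shapes_and_dtypes) < limit


# A single nn.Linear(32, 32) with bias, as in the basic tests.
small_model = [((32, 32), "float32"), ((32,), "float32")]
# Sixteen hypothetical bfloat16 matrices of 32000 x 4096 elements.
large_model = [((32000, 4096), "bfloat16")] * 16
```

Here `can_inline_parameters(small_model)` is true while `can_inline_parameters(large_model)` is false, since the latter already needs about 4.2 billion bytes for weights alone.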

env/activate

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ else
 cd $TT_TORCH_HOME/third_party
 git clone https://github.com/pytorch/vision.git
 cd vision
-git checkout v0.20.0
+git checkout v0.21.0
 pip uninstall -y torchvision
 TORCHVISION_USE_VIDEO_CODEC=0 TORCHVISION_USE_FFMPEG=0 CC=clang CXX=clang++ _GLIBCXX_USE_CXX11_ABI=1 USE_CUDA=OFF python setup.py bdist_wheel

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-torch@https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.5.0%2Bcpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
+torch@https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.6.0%2Bcpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
 black
 mdutils
 ninja

setup.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def run(self):
     },
     zip_safe=False,
     install_requires=[
-        "torch@https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.5.0%2Bcpu.cxx11.abi-cp311-cp311-linux_x86_64.whl",
+        "torch@https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.6.0%2Bcpu.cxx11.abi-cp311-cp311-linux_x86_64.whl",
         "numpy",
     ],
 )

tests/models/Qwen/test_qwen2_casual_lm.py

Lines changed: 5 additions & 1 deletion
@@ -57,7 +57,11 @@ def test_qwen2_casual_lm(record_property, model_name, mode, op_by_op):
         cc.op_by_op_backend = OpByOpBackend.STABLEHLO

     tester = ThisTester(
-        model_name, mode, compiler_config=cc, record_property_handle=record_property
+        model_name,
+        mode,
+        compiler_config=cc,
+        record_property_handle=record_property,
+        is_token_output=True,
     )
     results = tester.test_model()
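The commit message explains why token outputs are compared as decoded strings: the compiled program runs to the maximum length and pads after the first EOS, so its shape differs from the golden output. A minimal sketch of that comparison follows. It is not the `ModelTester` implementation; the token ids, the toy vocabulary, and the helper names are hypothetical stand-ins for a real tokenizer.

```python
PAD_ID = 0  # hypothetical padding token id
EOS_ID = 2  # hypothetical end-of-sequence token id

# Toy vocabulary standing in for a real tokenizer's decode table.
VOCAB = {3: "hello", 4: "world", 5: "again"}


def strip_after_eos(token_ids, eos_id=EOS_ID, pad_id=PAD_ID):
    """Keep tokens up to the first EOS, dropping the EOS and any padding."""
    kept = []
    for tok in token_ids:
        if tok == eos_id:
            break
        if tok != pad_id:
            kept.append(tok)
    return kept


def decode(token_ids):
    """Decode to a string after removing EOS and padding."""
    return " ".join(VOCAB[t] for t in strip_after_eos(token_ids))


def outputs_match(golden_ids, compiled_ids):
    """Compare decoded strings, so differing sequence lengths do not matter."""
    return decode(golden_ids) == decode(compiled_ids)


golden = [3, 4, 2]                   # eager run: stops at EOS
compiled = [3, 4, 2, 0, 0, 0, 0]     # compiled run: padded to max length
different = [3, 5, 2, 0, 0, 0, 0]    # same length as compiled, different text
```

With these inputs, `outputs_match(golden, compiled)` holds even though the lists differ in length, and `outputs_match(golden, different)` does not.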

tests/models/RMBG/test_RMBG.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ def _load_inputs(self):
     "mode",
     ["train", "eval"],
 )
-@pytest.mark.xfail(reason="Fails due pt2 compile issue, graph is traced")
+@pytest.mark.skip(reason="Python bus error at the end of torch op-by-op flow")
 @pytest.mark.parametrize(
     "op_by_op",
     [OpByOpBackend.STABLEHLO, OpByOpBackend.TORCH, None],

tests/models/beit/test_beit_image_classification.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def test_beit_image_classification(record_property, model_name, mode, op_by_op):
     if op_by_op == OpByOpBackend.STABLEHLO:
         cc.op_by_op_backend = OpByOpBackend.STABLEHLO

-    required_atol = 0.032 if model_name == "microsoft/beit-base-patch16-224" else 0.05
+    required_atol = 0.032 if model_name == "microsoft/beit-base-patch16-224" else 0.065
     tester = ThisTester(
         model_name,
         mode,

tests/models/codegen/test_codegen.py

Lines changed: 5 additions & 1 deletion
@@ -46,7 +46,11 @@ def test_codegen(record_property, mode, op_by_op):
         cc.op_by_op_backend = OpByOpBackend.STABLEHLO

     tester = ThisTester(
-        model_name, mode, compiler_config=cc, record_property_handle=record_property
+        model_name,
+        mode,
+        compiler_config=cc,
+        record_property_handle=record_property,
+        is_transformers_generation=True,
     )
     results = tester.test_model()

tests/models/deit/test_deit.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def test_deit(record_property, model_name, mode, op_by_op):
     tester = ThisTester(
         model_name,
         mode,
-        relative_atol=0.01,
+        relative_atol=0.015,
         compiler_config=cc,
         record_property_handle=record_property,
     )

tests/models/falcon/test_falcon.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def test_falcon(record_property, mode, op_by_op):
     tester = ThisTester(
         model_name,
         mode,
-        relative_atol=0.013,
+        relative_atol=0.015,
         compiler_config=cc,
         record_property_handle=record_property,
     )

tests/models/flan_t5/test_flan_t5.py

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ def test_flan_t5(record_property, mode, op_by_op):
         record_property_handle=record_property,
         assert_pcc=False,
         assert_atol=False,
+        is_token_output=True,
     )
     results = tester.test_model()
     if mode == "eval":

tests/models/gpt_neo/test_gpt_neo.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ def test_gpt_neo(record_property, mode, op_by_op):
         record_property_handle=record_property,
         assert_pcc=False,
         assert_atol=False,
+        is_token_output=True,
     )
     results = tester.test_model()
     if mode == "eval":

tests/models/mamba/test_mamba.py

Lines changed: 5 additions & 1 deletion
@@ -69,7 +69,11 @@ def test_mamba(record_property, model_name, mode, op_by_op):
         cc.op_by_op_backend = OpByOpBackend.STABLEHLO

     tester = ThisTester(
-        model_name, mode, compiler_config=cc, record_property_handle=record_property
+        model_name,
+        mode,
+        compiler_config=cc,
+        record_property_handle=record_property,
+        is_token_output=True,
     )
     results = tester.test_model()

tests/models/mgp-str-base/test_mgp_str_base.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def test_mgp_str_base(record_property, mode, op_by_op):
     tester = ThisTester(
         model_name,
         mode,
-        relative_atol=0.01,
+        relative_atol=0.02,
         compiler_config=cc,
         record_property_handle=record_property,
     )

tests/models/musicgen_small/test_musicgen_small.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ def test_musicgen_small(record_property, mode, op_by_op):
         assert_atol=False,
         assert_pcc=False,
         record_property_handle=record_property,
+        is_token_output=True,
     )
     results = tester.test_model()
     tester.finalize()

tests/models/opt/test_opt.py

Lines changed: 5 additions & 1 deletion
@@ -52,7 +52,11 @@ def test_opt(record_property, mode, op_by_op):
         cc.op_by_op_backend = OpByOpBackend.STABLEHLO

     tester = ThisTester(
-        model_name, mode, compiler_config=cc, record_property_handle=record_property
+        model_name,
+        mode,
+        compiler_config=cc,
+        record_property_handle=record_property,
+        is_token_output=True,
     )
     results = tester.test_model()
     if mode == "eval":

tests/models/speecht5_tts/test_speecht5_tts.py

Lines changed: 5 additions & 1 deletion
@@ -66,7 +66,11 @@ def test_speecht5_tts(record_property, mode, op_by_op):
         cc.op_by_op_backend = OpByOpBackend.STABLEHLO

     tester = ThisTester(
-        model_name, mode, compiler_config=cc, record_property_handle=record_property
+        model_name,
+        mode,
+        compiler_config=cc,
+        record_property_handle=record_property,
+        is_token_output=True,
     )
     tester.test_model()
     # if mode == "eval":

tests/models/t5/test_t5.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ def test_t5(record_property, model_name, mode, op_by_op):
         record_property_handle=record_property,
         assert_pcc=False,
         assert_atol=False,
+        is_token_output=True,
     )
     results = tester.test_model()
     if mode == "eval":

tests/models/whisper/test_whisper.py

Lines changed: 5 additions & 1 deletion
@@ -69,7 +69,11 @@ def test_whisper(record_property, mode, op_by_op):
         cc.op_by_op_backend = OpByOpBackend.STABLEHLO

     tester = ThisTester(
-        model_name, mode, compiler_config=cc, record_property_handle=record_property
+        model_name,
+        mode,
+        compiler_config=cc,
+        record_property_handle=record_property,
+        is_token_output=True,
     )
     tester.test_model()
     tester.finalize()

tests/torch/test_basic.py

Lines changed: 0 additions & 15 deletions
@@ -198,21 +198,6 @@ def forward(self, x):
     verify_module(Basic(), input_shapes=[(32, 32)])


-def test_linear_with_bias_no_embedded_constants():
-    class Basic(nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.linear_a = nn.Linear(32, 32)
-
-        def forward(self, x):
-            x = self.linear_a(x)
-            return x
-
-    cc = CompilerConfig()
-    cc.remove_embedded_constants = True
-    verify_module(Basic(), input_shapes=[(32, 32)], compiler_config=cc)
-
-
 @pytest.mark.parametrize(
     ("input_type"),
     [

tests/torch/test_constant_fold.py

Lines changed: 0 additions & 1 deletion
@@ -50,7 +50,6 @@ def forward(self, x):
     inH = 5
     inW = 5
     inC = 1
-    scale_factor = 3

     input_shape = (1, inC, inH, inW)
     small = (

tests/torch/test_interpolation.py

Lines changed: 134 additions & 4 deletions
@@ -11,10 +11,41 @@
 import torch.nn.functional as F


-@pytest.mark.parametrize("inH", [50, 128, 224, 960])
 @pytest.mark.parametrize("inW", [50, 128, 224, 540])
 @pytest.mark.parametrize("scale_factor", [0.5, 2])
 @pytest.mark.parametrize("align_corners", [False, True])
+def test_linear_upsample(inW, scale_factor, align_corners):
+    pytest.skip()  # https://github.com/tenstorrent/tt-torch/issues/405
+
+    class Interpolate(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return F.interpolate(
+                x,
+                scale_factor=scale_factor,
+                mode="linear",
+                align_corners=align_corners,
+            )
+
+    input_shape = (1, 1, inW)
+    small = torch.randn(input_shape, dtype=torch.bfloat16)
+
+    cc = CompilerConfig()
+    cc.enable_consteval = True
+    verify_module(
+        Interpolate(),
+        inputs=[small],
+        compiler_config=cc,
+        required_atol=0.07,
+    )
+
+
+@pytest.mark.parametrize("inH", [128, 224, 960])
+@pytest.mark.parametrize("inW", [128, 224, 540])
+@pytest.mark.parametrize("scale_factor", [0.5, 2])
+@pytest.mark.parametrize("align_corners", [False, True])
 def test_bilinear_upsample(inH, inW, scale_factor, align_corners):
     pytest.skip()  # https://github.com/tenstorrent/tt-torch/issues/405

@@ -43,10 +74,72 @@ def forward(self, x):
     )


-@pytest.mark.parametrize("inH", [50, 128, 224, 960])
+@pytest.mark.parametrize("inZ", [4, 8])
+@pytest.mark.parametrize("inH", [224, 960])
+@pytest.mark.parametrize("inW", [224, 540])
+@pytest.mark.parametrize("scale_factor", [0.5, 2])
+@pytest.mark.parametrize("align_corners", [False, True])
+def test_trilinear_upsample(inZ, inH, inW, scale_factor, align_corners):
+    pytest.skip()  # https://github.com/tenstorrent/tt-torch/issues/405
+
+    class Interpolate(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return F.interpolate(
+                x,
+                scale_factor=scale_factor,
+                mode="trilinear",
+                align_corners=align_corners,
+            )
+
+    input_shape = (1, 1, inZ, inH, inW)
+    small = torch.randn(input_shape, dtype=torch.bfloat16)
+
+    cc = CompilerConfig()
+    cc.enable_consteval = True
+    verify_module(
+        Interpolate(),
+        inputs=[small],
+        compiler_config=cc,
+        required_atol=0.08,
+    )
+
+
 @pytest.mark.parametrize("inW", [50, 128, 224, 540])
 @pytest.mark.parametrize("scale_factor", [0.5, 2])
-def test_nearest_upsample(inH, inW, scale_factor):
+def test_nearest_upsample1d(inW, scale_factor):
+    pytest.skip()  # https://github.com/tenstorrent/tt-torch/issues/405
+
+    class Interpolate(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return F.interpolate(
+                x,
+                scale_factor=scale_factor,
+                mode="nearest",
+            )
+
+    input_shape = (1, 1, inW)
+    small = torch.randn(input_shape, dtype=torch.bfloat16)
+
+    cc = CompilerConfig()
+    cc.enable_consteval = True
+    verify_module(
+        Interpolate(),
+        inputs=[small],
+        compiler_config=cc,
+        required_atol=0.07,
+    )
+
+
+@pytest.mark.parametrize("inH", [128, 224, 960])
+@pytest.mark.parametrize("inW", [128, 224, 540])
+@pytest.mark.parametrize("scale_factor", [0.5, 2])
+def test_nearest_upsample2d(inH, inW, scale_factor):
     pytest.skip()  # https://github.com/tenstorrent/tt-torch/issues/405

     class Interpolate(nn.Module):
@@ -57,11 +150,48 @@ def forward(self, x):
             return F.interpolate(
                 x,
                 scale_factor=scale_factor,
+                mode="nearest",
             )

     input_shape = (1, 1, inH, inW)
     small = torch.randn(input_shape, dtype=torch.bfloat16)

     cc = CompilerConfig()
     cc.enable_consteval = True
-    verify_module(Interpolate(), inputs=[small], compiler_config=cc, required_atol=0.02)
+    verify_module(
+        Interpolate(),
+        inputs=[small],
+        compiler_config=cc,
+        required_atol=0.07,
+    )
+
+
+@pytest.mark.parametrize("inZ", [4, 8])
+@pytest.mark.parametrize("inH", [224, 960])
+@pytest.mark.parametrize("inW", [224, 540])
+@pytest.mark.parametrize("scale_factor", [0.5, 2])
+def test_nearest_upsample3d(inZ, inH, inW, scale_factor):
+    pytest.skip()  # https://github.com/tenstorrent/tt-torch/issues/405
+
+    class Interpolate(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return F.interpolate(
+                x,
+                scale_factor=scale_factor,
+                mode="nearest",
+            )
+
+    input_shape = (1, 1, inZ, inH, inW)
+    small = torch.randn(input_shape, dtype=torch.bfloat16)
+
+    cc = CompilerConfig()
+    cc.enable_consteval = True
+    verify_module(
+        Interpolate(),
+        inputs=[small],
+        compiler_config=cc,
+        required_atol=0.08,
+    )
