@@ -83,12 +83,15 @@ def setUp(self):
         self.ort_version = onnxruntime.__version__

     def test_simple_function(self):
-        def func(x):
-            # TODO(justinchuby): Replicate torch's type casting policy
-            # in the exporter for type promotion support
-            y = x + 1.0
-            z = y.relu()
-            return (y, z)
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                # TODO(justinchuby): Replicate torch's type casting policy
+                # in the exporter for type promotion support
+                y = x + 1.0
+                z = y.relu()
+                return (y, z)
+
+        func = Foo()

         tensor_x = torch.randn(1, 1, 2, dtype=torch.float32)

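Each hunk in this diff applies the same mechanical change: the free function `func` becomes the `forward` of a `torch.nn.Module` subclass, and `func = Foo()` keeps the rest of the test body working unchanged. For reference, such a module can also be exported directly with the dynamo-based exporter outside the test harness; a minimal sketch, assuming torch >= 2.1 (the output path is illustrative, not from this diff):

    import torch

    class Foo(torch.nn.Module):
        def forward(self, x):
            # Same body as the module in test_simple_function above.
            y = x + 1.0
            z = y.relu()
            return (y, z)

    func = Foo()
    tensor_x = torch.randn(1, 1, 2, dtype=torch.float32)

    # dynamo_export traces the module and returns an object that can be
    # serialized to an .onnx file.
    onnx_program = torch.onnx.dynamo_export(func, tensor_x)
    onnx_program.save("foo.onnx")  # illustrative path
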
@@ -118,10 +121,13 @@ def test_func_with_args_and_tensor_kwargs(self):
         # practice to set mutable default values.
         # `DynamoOptimizeExporter` applies a workaround by binding args and kwargs to
         # model signature and fill in the default values of unprovided optional arguments.
-        def func(x, b=torch.tensor(1.0)):
-            y = x + b
-            z = y.relu()
-            return (y, z)
+        class Foo(torch.nn.Module):
+            def forward(self, x, b=torch.tensor(1.0)):
+                y = x + b
+                z = y.relu()
+                return (y, z)
+
+        func = Foo()

         tensor_x = torch.randn(1, 2, 3, dtype=torch.float32)

@@ -140,21 +146,24 @@ def func(x, b=torch.tensor(1.0)):
         "sympy operation tests don't need dynamic shape"
     )
     def test_sympy_operatons_return_numeric(self):
-        def func(x, y):
-            # TODO: add boolean tests when SymBool is supported
-            # to infer types
-            return (
-                torch.tensor([operator.add(x.item(), y.item())]),
-                torch.tensor([operator.sub(x.item(), y.item())]),
-                torch.tensor([operator.mul(x.item(), y.item())]),
-                torch.tensor([operator.truediv(x.item(), y.item())]),
-                torch.tensor([operator.floordiv(x.item(), y.item())]),
-                torch.tensor([operator.pow(x.item(), y.item())]),
-                torch.tensor([operator.abs(x.item())]),
-                torch.tensor([operator.neg(x.item())]),
-                torch.tensor([math.ceil(x.item())]),
-                torch.tensor([math.floor(x.item())]),
-            )
+        class Foo(torch.nn.Module):
+            def forward(self, x, y):
+                # TODO: add boolean tests when SymBool is supported
+                # to infer types
+                return (
+                    torch.tensor([operator.add(x.item(), y.item())]),
+                    torch.tensor([operator.sub(x.item(), y.item())]),
+                    torch.tensor([operator.mul(x.item(), y.item())]),
+                    torch.tensor([operator.truediv(x.item(), y.item())]),
+                    torch.tensor([operator.floordiv(x.item(), y.item())]),
+                    torch.tensor([operator.pow(x.item(), y.item())]),
+                    torch.tensor([operator.abs(x.item())]),
+                    torch.tensor([operator.neg(x.item())]),
+                    torch.tensor([math.ceil(x.item())]),
+                    torch.tensor([math.floor(x.item())]),
+                )
+
+        func = Foo()

         x = torch.randn(1, dtype=torch.float32)
         y = torch.randn(1, dtype=torch.float32)
@@ -171,10 +180,13 @@ def func(x, y):
         reason="https://github.com/pytorch/pytorch/issues/99534",
     )
     def test_xfail_func_with_non_tensor_args(self):
-        def func(x, b=1.0):
-            y = x + b
-            z = y.relu()
-            return (y, z)
+        class Foo(torch.nn.Module):
+            def forward(self, x, b=1.0):
+                y = x + b
+                z = y.relu()
+                return (y, z)
+
+        func = Foo()

         tensor_x = torch.randn(1, 1, 2, dtype=torch.float32)

@@ -202,25 +214,29 @@ def func(x, b=1.0):
         torch.testing.assert_close(ref_output, torch.tensor(ort_output))

     def test_func_with_nested_input_structure(self):
-        def func(
-            x_dict: Dict[str, torch.Tensor],
-            y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
-            z_list: List[List[torch.Tensor]],
-        ):
-            if "a" in x_dict:
-                x = x_dict["a"]
-            elif "b" in x_dict:
-                x = x_dict["b"]
-            else:
-                x = torch.randn(3)
+        class Foo(torch.nn.Module):
+            def forward(
+                self,
+                x_dict: Dict[str, torch.Tensor],
+                y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+                z_list: List[List[torch.Tensor]],
+            ):
+                if "a" in x_dict:
+                    x = x_dict["a"]
+                elif "b" in x_dict:
+                    x = x_dict["b"]
+                else:
+                    x = torch.randn(3)

-            y1, (y2, y3) = y_tuple
+                y1, (y2, y3) = y_tuple

-            z = x + y1 + y2 + y3
-            for z_sub_list in z_list:
-                z = z + torch.stack(z_sub_list).sum()
+                z = x + y1 + y2 + y3
+                for z_sub_list in z_list:
+                    z = z + torch.stack(z_sub_list).sum()

-            return z
+                return z
+
+        func = Foo()

         x_dict = {"a": torch.randn(3), "c": torch.randn(3)}
         y_tuple = (torch.randn(3), (torch.randn(3), torch.randn(3)))
@@ -233,14 +249,17 @@ def func(
     )

     def test_func_with_nested_output_structure(self):
-        def func(x, y, z):
-            x = x + y
-            y = y + z
-            z = x + y
-            out1 = (x, (y, z))
-            out2 = [[x, y], [y, z]]
-            out3 = {"z": z, "x": x}
-            return out1, out2, out3
+        class Foo(torch.nn.Module):
+            def forward(self, x, y, z):
+                x = x + y
+                y = y + z
+                z = x + y
+                out1 = (x, (y, z))
+                out2 = [[x, y], [y, z]]
+                out3 = {"z": z, "x": x}
+                return out1, out2, out3
+
+        func = Foo()

         x = torch.randn(3)
         y = torch.randn(3)
@@ -535,19 +554,22 @@ def forward(self, x):

     @pytorch_test_common.skipIfNoCuda
     def test__scaled_dot_product_flash_attention(self):
-        def func(x):
-            (
-                output,
-                _,
-                _,
-                _,
-                _,
-                _,
-                _,
-                _,
-                _,
-            ) = torch.ops.aten._scaled_dot_product_flash_attention(x, x, x)
-            return output
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                (
+                    output,
+                    _,
+                    _,
+                    _,
+                    _,
+                    _,
+                    _,
+                    _,
+                    _,
+                ) = torch.ops.aten._scaled_dot_product_flash_attention(x, x, x)
+                return output
+
+        func = Foo()

         x = torch.randn(1, 1, 1, 32, device=torch.device("cuda"))
         self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (x,))
@@ -597,9 +619,12 @@ def forward(
         )

     def test_operator_with_data_dependent_output(self):
-        def func(x):
-            # Repro from llama. Emits `torch.ops.aten._local_scalar_dense`.
-            return x + torch.full(x.shape, torch.tensor(torch.finfo(x.dtype).min))
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                # Repro from llama. Emits `torch.ops.aten._local_scalar_dense`.
+                return x + torch.full(x.shape, torch.tensor(torch.finfo(x.dtype).min))
+
+        func = Foo()

         self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
             func, (torch.randn(3, 4),)
@@ -610,8 +635,11 @@ def func(x):
         reason="https://github.com/pytorch/pytorch/issues/112622",
     )
     def test_operator_with_scalar_output(self):
-        def func(x, y):
-            return x.item() + y
+        class Foo(torch.nn.Module):
+            def forward(self, x, y):
+                return x.item() + y
+
+        func = Foo()

         self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
             func, (torch.tensor([1]), torch.randn(3, 4))
@@ -622,8 +650,11 @@ def func(x, y):
         reason="https://github.com/pytorch/pytorch/issues/112622",
     )
     def test_operator_with_dynamic_output_shape(self):
-        def func(x):
-            return x.nonzero()
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                return x.nonzero()
+
+        func = Foo()

         self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
             func, (torch.randn(3, 4),)
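Continuing the sketch above: judging by its name and the assert_close context line visible in the diff, run_test_with_fx_to_onnx_exporter_and_onnx_runtime checks, roughly, that onnxruntime reproduces the module's eager outputs. An approximation of that check, assuming onnxruntime is installed and foo.onnx was saved as in the earlier sketch (the real helper lives in the test harness and does more, e.g. exporter option handling):

    import onnxruntime
    import torch

    session = onnxruntime.InferenceSession(
        "foo.onnx", providers=["CPUExecutionProvider"]
    )
    input_name = session.get_inputs()[0].name
    ort_outputs = session.run(None, {input_name: tensor_x.numpy()})

    # Compare against eager PyTorch, mirroring the torch.testing.assert_close
    # call visible in the diff context above.
    ref_y, ref_z = func(tensor_x)
    torch.testing.assert_close(ref_y, torch.tensor(ort_outputs[0]))
    torch.testing.assert_close(ref_z, torch.tensor(ort_outputs[1]))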