@@ -64,11 +64,6 @@ def forward(self, x):
64
64
cos_sim > COSINE_THRESHOLD ,
65
65
msg = f"test_dyn_full_compile model TRT outputs don't match with the pytorch model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
66
66
)
67
- # Clean up model env
68
- torch ._dynamo .reset ()
69
-
70
- with torch .no_grad ():
71
- torch .cuda .empty_cache ()
72
67
73
68
74
69
@unittest .skip (
@@ -128,12 +123,6 @@ def forward(self, x):
128
123
msg = f"test_base_dynamic_fallback model TRT outputs don't match with the pytorch model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
129
124
)
130
125
131
- # Clean up model env
132
- torch ._dynamo .reset ()
133
-
134
- with torch .no_grad ():
135
- torch .cuda .empty_cache ()
136
-
137
126
138
127
@pytest .mark .unit
139
128
def test_view (ir ):
@@ -185,12 +174,6 @@ def forward(self, x):
185
174
msg = f"test_view model TRT outputs don't match with the pytorch model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
186
175
)
187
176
188
- # Clean up model env
189
- torch ._dynamo .reset ()
190
-
191
- with torch .no_grad ():
192
- torch .cuda .empty_cache ()
193
-
194
177
195
178
@pytest .mark .unit
196
179
def test_resnet_dynamic (ir ):
@@ -234,12 +217,6 @@ def test_resnet_dynamic(ir):
234
217
msg = f"test_resnet_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
235
218
)
236
219
237
- # Clean up model env
238
- torch ._dynamo .reset ()
239
-
240
- with torch .no_grad ():
241
- torch .cuda .empty_cache ()
242
-
243
220
244
221
@pytest .mark .unit
245
222
def test_view (ir ):
@@ -284,8 +261,52 @@ def forward(self, x):
284
261
msg = f"test_base_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
285
262
)
286
263
287
- # Clean up model env
288
- torch ._dynamo .reset ()
289
264
290
- with torch .no_grad ():
291
- torch .cuda .empty_cache ()
265
+ @pytest .mark .unit
266
+ def test_linear (ir ):
267
+ """
268
+ Tests the model with linear op and operator.mul (added internally by PyTorch)
269
+ with dynamic shapes
270
+ """
271
+
272
+ class MyModule (torch .nn .Module ):
273
+ def __init__ (self ):
274
+ super ().__init__ ()
275
+ self .linear1 = torch .nn .Linear (10 , 10 )
276
+
277
+ def forward (self , x ):
278
+ return self .linear1 (x )
279
+
280
+ model = MyModule ().eval ().cuda ()
281
+
282
+ compile_spec = {
283
+ "device" : torchtrt .Device ("cuda:0" ),
284
+ "enabled_precisions" : {torch .float },
285
+ "ir" : ir ,
286
+ "min_block_size" : 1 ,
287
+ }
288
+ inputs_bs2 = torch .randn (2 , 2 , 10 ).to ("cuda" )
289
+ if ir == "torch_compile" :
290
+ torch ._dynamo .mark_dynamic (inputs_bs2 , 0 , min = 1 , max = 10 )
291
+ torch ._dynamo .mark_dynamic (inputs_bs2 , 1 , min = 1 , max = 10 )
292
+ # Compile the model
293
+ trt_model = torch .compile (model , backend = "tensorrt" , options = compile_spec )
294
+ trt_model (inputs_bs2 )
295
+ elif ir == "dynamo" :
296
+ dynamic_shapes = (
297
+ {
298
+ 0 : torch .export .Dim ("batch_size" , min = 1 , max = 10 ),
299
+ 1 : torch .export .Dim ("seq_len" , max = 10 ),
300
+ },
301
+ )
302
+ exp_program = torch .export .export (
303
+ model , (inputs_bs2 ,), dynamic_shapes = dynamic_shapes
304
+ )
305
+ trt_model = torchtrt .dynamo .compile (exp_program , [inputs_bs2 ], ** compile_spec )
306
+
307
+ input_bs6_s3 = torch .randn ((6 , 3 , 10 )).to ("cuda" )
308
+ cos_sim = cosine_similarity (model (input_bs6_s3 ), trt_model (input_bs6_s3 ))
309
+ assertions .assertTrue (
310
+ cos_sim > COSINE_THRESHOLD ,
311
+ msg = f"test_linear model TRT outputs don't match with the pytorch model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
312
+ )
0 commit comments