@@ -47,21 +47,21 @@ def forward(self, x):
47
47
exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
48
48
trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
49
49
torchtrt .save (trt_module , trt_ep_path , inputs = [input ])
50
- # TODO: Enable this serialization issues are fixed
51
- # deser_trt_module = torchtrt.load(trt_ep_path).module()
50
+
51
+ deser_trt_module = torchtrt .load (trt_ep_path ).module ()
52
52
# Check Pyt and TRT exported program outputs
53
53
cos_sim = cosine_similarity (model (input ), trt_module (input )[0 ])
54
54
assertions .assertTrue (
55
55
cos_sim > COSINE_THRESHOLD ,
56
56
msg = f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
57
57
)
58
- # TODO: Enable this serialization issues are fixed
59
- # # Check Pyt and deserialized TRT exported program outputs
60
- # cos_sim = cosine_similarity(model(input), deser_trt_module(input)[0])
61
- # assertions.assertTrue(
62
- # cos_sim > COSINE_THRESHOLD,
63
- # msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
64
- # )
58
+
59
+ # Check Pyt and deserialized TRT exported program outputs
60
+ cos_sim = cosine_similarity (model (input ), deser_trt_module (input )[0 ])
61
+ assertions .assertTrue (
62
+ cos_sim > COSINE_THRESHOLD ,
63
+ msg = f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
64
+ )
65
65
66
66
67
67
@pytest .mark .unit
@@ -99,8 +99,8 @@ def forward(self, x):
99
99
exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
100
100
trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
101
101
torchtrt .save (trt_module , trt_ep_path , inputs = [input ])
102
- # TODO: Enable this serialization issues are fixed
103
- # deser_trt_module = torchtrt.load(trt_ep_path).module()
102
+
103
+ deser_trt_module = torchtrt .load (trt_ep_path ).module ()
104
104
# Check Pyt and TRT exported program outputs
105
105
outputs_pyt = model (input )
106
106
outputs_trt = trt_module (input )
@@ -111,15 +111,14 @@ def forward(self, x):
111
111
msg = f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
112
112
)
113
113
114
- # TODO: Enable this serialization issues are fixed
115
114
# # Check Pyt and deserialized TRT exported program outputs
116
- # outputs_trt_deser = deser_trt_module(input)
117
- # for idx in range(len(outputs_pyt)):
118
- # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
119
- # assertions.assertTrue(
120
- # cos_sim > COSINE_THRESHOLD,
121
- # msg=f"test_base_full_compile_multiple_outputs deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
122
- # )
115
+ outputs_trt_deser = deser_trt_module (input )
116
+ for idx in range (len (outputs_pyt )):
117
+ cos_sim = cosine_similarity (outputs_pyt [idx ], outputs_trt_deser [idx ])
118
+ assertions .assertTrue (
119
+ cos_sim > COSINE_THRESHOLD ,
120
+ msg = f"test_base_full_compile_multiple_outputs deserialized TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
121
+ )
123
122
124
123
125
124
@pytest .mark .unit
@@ -156,8 +155,8 @@ def forward(self, x):
156
155
exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
157
156
trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
158
157
torchtrt .save (trt_module , trt_ep_path , inputs = [input ])
159
- # TODO: Enable this serialization issues are fixed
160
- # deser_trt_module = torchtrt.load(trt_ep_path).module()
158
+
159
+ deser_trt_module = torchtrt .load (trt_ep_path ).module ()
161
160
# Check Pyt and TRT exported program outputs
162
161
outputs_pyt = model (input )
163
162
outputs_trt = trt_module (input )
@@ -168,15 +167,14 @@ def forward(self, x):
168
167
msg = f"test_no_compile TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
169
168
)
170
169
171
- # TODO: Enable this serialization issues are fixed
172
170
# # Check Pyt and deserialized TRT exported program outputs
173
- # outputs_trt_deser = deser_trt_module(input)
174
- # for idx in range(len(outputs_pyt)):
175
- # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
176
- # assertions.assertTrue(
177
- # cos_sim > COSINE_THRESHOLD,
178
- # msg=f"test_no_compile deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
179
- # )
171
+ outputs_trt_deser = deser_trt_module (input )
172
+ for idx in range (len (outputs_pyt )):
173
+ cos_sim = cosine_similarity (outputs_pyt [idx ], outputs_trt_deser [idx ])
174
+ assertions .assertTrue (
175
+ cos_sim > COSINE_THRESHOLD ,
176
+ msg = f"test_no_compile deserialized TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
177
+ )
180
178
181
179
182
180
@pytest .mark .unit
@@ -216,8 +214,8 @@ def forward(self, x):
216
214
exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
217
215
trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
218
216
torchtrt .save (trt_module , trt_ep_path , inputs = [input ])
219
- # TODO: Enable this serialization issues are fixed
220
- # deser_trt_module = torchtrt.load(trt_ep_path).module()
217
+
218
+ deser_trt_module = torchtrt .load (trt_ep_path ).module ()
221
219
outputs_pyt = model (input )
222
220
outputs_trt = trt_module (input )
223
221
for idx in range (len (outputs_pyt )):
@@ -227,14 +225,13 @@ def forward(self, x):
227
225
msg = f"test_hybrid_relu_fallback TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
228
226
)
229
227
230
- # TODO: Enable this serialization issues are fixed
231
- # outputs_trt_deser = deser_trt_module(input)
232
- # for idx in range(len(outputs_pyt)):
233
- # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
234
- # assertions.assertTrue(
235
- # cos_sim > COSINE_THRESHOLD,
236
- # msg=f"test_hybrid_relu_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
237
- # )
228
+ outputs_trt_deser = deser_trt_module (input )
229
+ for idx in range (len (outputs_pyt )):
230
+ cos_sim = cosine_similarity (outputs_pyt [idx ], outputs_trt_deser [idx ])
231
+ assertions .assertTrue (
232
+ cos_sim > COSINE_THRESHOLD ,
233
+ msg = f"test_hybrid_relu_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
234
+ )
238
235
239
236
240
237
@pytest .mark .unit
@@ -258,8 +255,8 @@ def test_resnet18(ir):
258
255
exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
259
256
trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
260
257
torchtrt .save (trt_module , trt_ep_path , inputs = [input ])
261
- # TODO: Enable this serialization issues are fixed
262
- # deser_trt_module = torchtrt.load(trt_ep_path).module()
258
+
259
+ deser_trt_module = torchtrt .load (trt_ep_path ).module ()
263
260
outputs_pyt = model (input )
264
261
outputs_trt = trt_module (input )
265
262
cos_sim = cosine_similarity (outputs_pyt , outputs_trt [0 ])
@@ -268,13 +265,12 @@ def test_resnet18(ir):
268
265
msg = f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
269
266
)
270
267
271
- # TODO: Enable this serialization issues are fixed
272
- # outputs_trt_deser = deser_trt_module(input)
273
- # cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
274
- # assertions.assertTrue(
275
- # cos_sim > COSINE_THRESHOLD,
276
- # msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
277
- # )
268
+ outputs_trt_deser = deser_trt_module (input )
269
+ cos_sim = cosine_similarity (outputs_pyt , outputs_trt_deser [0 ])
270
+ assertions .assertTrue (
271
+ cos_sim > COSINE_THRESHOLD ,
272
+ msg = f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
273
+ )
278
274
279
275
280
276
@pytest .mark .unit
@@ -314,8 +310,8 @@ def forward(self, x):
314
310
trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
315
311
316
312
torchtrt .save (trt_module , trt_ep_path , inputs = [input ])
317
- # TODO: Enable this serialization issues are fixed
318
- # deser_trt_module = torchtrt.load(trt_ep_path).module()
313
+
314
+ deser_trt_module = torchtrt .load (trt_ep_path ).module ()
319
315
outputs_pyt = model (input )
320
316
outputs_trt = trt_module (input )
321
317
@@ -326,14 +322,13 @@ def forward(self, x):
326
322
msg = f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
327
323
)
328
324
329
- # TODO: Enable this serialization issues are fixed
330
- # outputs_trt_deser = deser_trt_module(input)
331
- # for idx in range(len(outputs_pyt)):
332
- # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
333
- # assertions.assertTrue(
334
- # cos_sim > COSINE_THRESHOLD,
335
- # msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
336
- # )
325
+ outputs_trt_deser = deser_trt_module (input )
326
+ for idx in range (len (outputs_pyt )):
327
+ cos_sim = cosine_similarity (outputs_pyt [idx ], outputs_trt_deser [idx ])
328
+ assertions .assertTrue (
329
+ cos_sim > COSINE_THRESHOLD ,
330
+ msg = f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
331
+ )
337
332
338
333
339
334
@pytest .mark .unit
0 commit comments