@@ -44,14 +44,19 @@ def forward(self, x):
44
44
exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
45
45
trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
46
46
torchtrt .save (trt_module , "/tmp/trt.ep" , inputs = [input ])
47
- # TODO: Enable this serialization issues are fixed
48
- # deser_trt_module = torchtrt.load("/tmp/trt.ep").module()
47
+ deser_trt_module = torchtrt .load ("/tmp/trt.ep" ).module ()
49
48
# Check Pyt and TRT exported program outputs
50
49
cos_sim = cosine_similarity (model (input ), trt_module (input )[0 ])
51
50
assertions .assertTrue (
52
51
cos_sim > COSINE_THRESHOLD ,
53
52
msg = f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
54
53
)
54
+ # Check Pyt and deserialized TRT exported program outputs
55
+ cos_sim = cosine_similarity (model (input ), deser_trt_module (input )[0 ])
56
+ assertions .assertTrue (
57
+ cos_sim > COSINE_THRESHOLD ,
58
+ msg = f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
59
+ )
55
60
# TODO: Enable this serialization issues are fixed
56
61
# # Check Pyt and deserialized TRT exported program outputs
57
62
# cos_sim = cosine_similarity(model(input), deser_trt_module(input)[0])
@@ -95,9 +100,8 @@ def forward(self, x):
95
100
96
101
exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
97
102
trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
98
- torchtrt .save (trt_module , "./trt.ep" , inputs = [input ])
99
- # TODO: Enable this serialization issues are fixed
100
- # deser_trt_module = torchtrt.load("./trt.ep").module()
103
+ torchtrt .save (trt_module , "/tmp/trt.ep" , inputs = [input ])
104
+ deser_trt_module = torchtrt .load ("/tmp/trt.ep" ).module ()
101
105
# Check Pyt and TRT exported program outputs
102
106
outputs_pyt = model (input )
103
107
outputs_trt = trt_module (input )
@@ -108,15 +112,14 @@ def forward(self, x):
108
112
msg = f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
109
113
)
110
114
111
- # TODO: Enable this serialization issues are fixed
112
- # # Check Pyt and deserialized TRT exported program outputs
113
- # outputs_trt_deser = deser_trt_module(input)
114
- # for idx in range(len(outputs_pyt)):
115
- # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
116
- # assertions.assertTrue(
117
- # cos_sim > COSINE_THRESHOLD,
118
- # msg=f"test_base_full_compile_multiple_outputs deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
119
- # )
115
+ # Check Pyt and deserialized TRT exported program outputs
116
+ outputs_trt_deser = deser_trt_module (input )
117
+ for idx in range (len (outputs_pyt )):
118
+ cos_sim = cosine_similarity (outputs_pyt [idx ], outputs_trt_deser [idx ])
119
+ assertions .assertTrue (
120
+ cos_sim > COSINE_THRESHOLD ,
121
+ msg = f"test_base_full_compile_multiple_outputs deserialized TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
122
+ )
120
123
121
124
122
125
@pytest .mark .unit
@@ -152,9 +155,8 @@ def forward(self, x):
152
155
153
156
exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
154
157
trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
155
- torchtrt .save (trt_module , "./trt.ep" , inputs = [input ])
156
- # TODO: Enable this serialization issues are fixed
157
- # deser_trt_module = torchtrt.load("./trt.ep").module()
158
+ torchtrt .save (trt_module , "/tmp/trt.ep" , inputs = [input ])
159
+ deser_trt_module = torchtrt .load ("/tmp/trt.ep" ).module ()
158
160
# Check Pyt and TRT exported program outputs
159
161
outputs_pyt = model (input )
160
162
outputs_trt = trt_module (input )
@@ -165,15 +167,14 @@ def forward(self, x):
165
167
msg = f"test_no_compile TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
166
168
)
167
169
168
- # TODO: Enable this serialization issues are fixed
169
- # # Check Pyt and deserialized TRT exported program outputs
170
- # outputs_trt_deser = deser_trt_module(input)
171
- # for idx in range(len(outputs_pyt)):
172
- # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
173
- # assertions.assertTrue(
174
- # cos_sim > COSINE_THRESHOLD,
175
- # msg=f"test_no_compile deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
176
- # )
170
+ # Check Pyt and deserialized TRT exported program outputs
171
+ outputs_trt_deser = deser_trt_module (input )
172
+ for idx in range (len (outputs_pyt )):
173
+ cos_sim = cosine_similarity (outputs_pyt [idx ], outputs_trt_deser [idx ])
174
+ assertions .assertTrue (
175
+ cos_sim > COSINE_THRESHOLD ,
176
+ msg = f"test_no_compile deserialized TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
177
+ )
177
178
178
179
179
180
@pytest .mark .unit
@@ -212,9 +213,8 @@ def forward(self, x):
212
213
213
214
exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
214
215
trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
215
- torchtrt .save (trt_module , "./trt.ep" , inputs = [input ])
216
- # TODO: Enable this serialization issues are fixed
217
- # deser_trt_module = torchtrt.load("./trt.ep").module()
216
+ torchtrt .save (trt_module , "/tmp/trt.ep" , inputs = [input ])
217
+ deser_trt_module = torchtrt .load ("/tmp/trt.ep" ).module ()
218
218
outputs_pyt = model (input )
219
219
outputs_trt = trt_module (input )
220
220
for idx in range (len (outputs_pyt )):
@@ -224,14 +224,13 @@ def forward(self, x):
224
224
msg = f"test_hybrid_relu_fallback TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
225
225
)
226
226
227
- # TODO: Enable this serialization issues are fixed
228
- # outputs_trt_deser = deser_trt_module(input)
229
- # for idx in range(len(outputs_pyt)):
230
- # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
231
- # assertions.assertTrue(
232
- # cos_sim > COSINE_THRESHOLD,
233
- # msg=f"test_hybrid_relu_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
234
- # )
227
+ outputs_trt_deser = deser_trt_module (input )
228
+ for idx in range (len (outputs_pyt )):
229
+ cos_sim = cosine_similarity (outputs_pyt [idx ], outputs_trt_deser [idx ])
230
+ assertions .assertTrue (
231
+ cos_sim > COSINE_THRESHOLD ,
232
+ msg = f"test_hybrid_relu_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
233
+ )
235
234
236
235
237
236
@pytest .mark .unit
@@ -254,9 +253,8 @@ def test_resnet18(ir):
254
253
255
254
exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
256
255
trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
257
- torchtrt .save (trt_module , "./trt.ep" , inputs = [input ])
258
- # TODO: Enable this serialization issues are fixed
259
- # deser_trt_module = torchtrt.load("./trt.ep").module()
256
+ torchtrt .save (trt_module , "/tmp/trt.ep" , inputs = [input ])
257
+ deser_trt_module = torchtrt .load ("/tmp/trt.ep" ).module ()
260
258
outputs_pyt = model (input )
261
259
outputs_trt = trt_module (input )
262
260
cos_sim = cosine_similarity (outputs_pyt , outputs_trt [0 ])
@@ -265,13 +263,13 @@ def test_resnet18(ir):
265
263
msg = f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
266
264
)
267
265
268
- # TODO: Enable this serialization issues are fixed
269
- # outputs_trt_deser = deser_trt_module(input)
270
- # cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
271
- # assertions.assertTrue(
272
- # cos_sim > COSINE_THRESHOLD,
273
- # msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
274
- # )
266
+ outputs_trt_deser = deser_trt_module ( input )
267
+
268
+ cos_sim = cosine_similarity (outputs_pyt , outputs_trt_deser [0 ])
269
+ assertions .assertTrue (
270
+ cos_sim > COSINE_THRESHOLD ,
271
+ msg = f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
272
+ )
275
273
276
274
277
275
@pytest .mark .unit
@@ -310,9 +308,8 @@ def forward(self, x):
310
308
exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
311
309
trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
312
310
313
- torchtrt .save (trt_module , "./trt.ep" , inputs = [input ])
314
- # TODO: Enable this serialization issues are fixed
315
- # deser_trt_module = torchtrt.load("./trt.ep").module()
311
+ torchtrt .save (trt_module , "/tmp/trt.ep" , inputs = [input ])
312
+ deser_trt_module = torchtrt .load ("/tmp/trt.ep" ).module ()
316
313
outputs_pyt = model (input )
317
314
outputs_trt = trt_module (input )
318
315
@@ -323,14 +320,13 @@ def forward(self, x):
323
320
msg = f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
324
321
)
325
322
326
- # TODO: Enable this serialization issues are fixed
327
- # outputs_trt_deser = deser_trt_module(input)
328
- # for idx in range(len(outputs_pyt)):
329
- # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
330
- # assertions.assertTrue(
331
- # cos_sim > COSINE_THRESHOLD,
332
- # msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
333
- # )
323
+ outputs_trt_deser = deser_trt_module (input )
324
+ for idx in range (len (outputs_pyt )):
325
+ cos_sim = cosine_similarity (outputs_pyt [idx ], outputs_trt_deser [idx ])
326
+ assertions .assertTrue (
327
+ cos_sim > COSINE_THRESHOLD ,
328
+ msg = f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
329
+ )
334
330
335
331
336
332
@pytest .mark .unit
@@ -361,9 +357,9 @@ def forward(self, x):
361
357
)
362
358
outputs_trt = trt_gm (input )
363
359
# Save it as torchscript representation
364
- torchtrt .save (trt_gm , ". /trt.ts" , output_format = "torchscript" , inputs = [input ])
360
+ torchtrt .save (trt_gm , "/tmp /trt.ts" , output_format = "torchscript" , inputs = [input ])
365
361
366
- trt_ts_module = torchtrt .load (". /trt.ts" )
362
+ trt_ts_module = torchtrt .load ("/tmp /trt.ts" )
367
363
outputs_trt_deser = trt_ts_module (input )
368
364
369
365
cos_sim = cosine_similarity (outputs_trt , outputs_trt_deser )
0 commit comments