@@ -35,8 +35,8 @@
 @pytest.mark.unit
 def test_mapping():

-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    model2 = models.resnet18(pretrained=False).eval().to("cuda")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     trt_input = [
         torchtrt.Input(i.shape, dtype=torch.float, format=torch.contiguous_format)
@@ -58,6 +58,7 @@ def test_mapping():
         debug=debug,
         min_block_size=min_block_size,
         make_refitable=True,
+        reuse_cached_engines=False,
     )
     settings = trt_gm._run_on_acc_0.settings
     runtime = trt.Runtime(TRT_LOGGER)
@@ -110,6 +111,7 @@ def test_refit_one_engine_with_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refitable=True,
+        reuse_cached_engines=False,
     )

     new_trt_gm = refit_module_weights(
@@ -141,8 +143,8 @@ def test_refit_one_engine_with_weightmap():
 @pytest.mark.unit
 def test_refit_one_engine_no_map_with_weightmap():

-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    model2 = models.resnet18(pretrained=False).eval().to("cuda")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -160,6 +162,7 @@ def test_refit_one_engine_no_map_with_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refitable=True,
+        reuse_cached_engines=False,
     )

     trt_gm._run_on_acc_0.weight_name_map = None
@@ -192,8 +195,8 @@ def test_refit_one_engine_no_map_with_weightmap():
 @pytest.mark.unit
 def test_refit_one_engine_with_wrong_weightmap():

-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    model2 = models.resnet18(pretrained=False).eval().to("cuda")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -211,6 +214,7 @@ def test_refit_one_engine_with_wrong_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refitable=True,
+        reuse_cached_engines=False,
     )
     # Manually Deleted all batch norm layer. This suppose to fail the fast refit
     trt_gm._run_on_acc_0.weight_name_map = {
@@ -268,6 +272,7 @@ def test_refit_one_engine_bert_with_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refitable=True,
+        reuse_cached_engines=False,
     )

     new_trt_gm = refit_module_weights(
@@ -302,8 +307,8 @@ def test_refit_one_engine_bert_with_weightmap():
 @pytest.mark.unit
 def test_refit_one_engine_inline_runtime__with_weightmap():
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    model2 = models.resnet18(pretrained=False).eval().to("cuda")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -321,6 +326,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refitable=True,
+        reuse_cached_engines=False,
     )
     torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
     trt_gm = torch.export.load(trt_ep_path)
@@ -348,8 +354,8 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
 @pytest.mark.unit
 def test_refit_one_engine_python_runtime_with_weightmap():

-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    model2 = models.resnet18(pretrained=False).eval().to("cuda")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -367,6 +373,7 @@ def test_refit_one_engine_python_runtime_with_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refitable=True,
+        reuse_cached_engines=False,
     )

     new_trt_gm = refit_module_weights(
@@ -438,6 +445,7 @@ def forward(self, x):
         min_block_size=min_block_size,
         make_refitable=True,
         torch_executed_ops=torch_executed_ops,
+        reuse_cached_engines=False,
     )

     new_trt_gm = refit_module_weights(
@@ -487,10 +495,11 @@ def test_refit_one_engine_without_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refitable=True,
+        reuse_cached_engines=False,
     )

     new_trt_gm = refit_module_weights(
-        compiled_module=trt_gm,
+        compiled_module=trt_gm,
         new_weight_module=exp_program2,
         arg_inputs=inputs,
         use_weight_map_cache=False,
@@ -538,6 +547,7 @@ def test_refit_one_engine_bert_without_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refitable=True,
+        reuse_cached_engines=False,
     )

     new_trt_gm = refit_module_weights(
@@ -591,6 +601,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refitable=True,
+        reuse_cached_engines=False,
     )
     torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
     trt_gm = torch.export.load(trt_ep_path)
@@ -637,6 +648,7 @@ def test_refit_one_engine_python_runtime_without_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refitable=True,
+        reuse_cached_engines=False,
     )

     new_trt_gm = refit_module_weights(
@@ -708,6 +720,7 @@ def forward(self, x):
         min_block_size=min_block_size,
         make_refitable=True,
         torch_executed_ops=torch_executed_ops,
+        reuse_cached_engines=False,
     )

     new_trt_gm = refit_module_weights(
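The hunks above make two kinds of adjustment: every compile call gains reuse_cached_engines=False, so each test builds a fresh engine instead of reusing one from the engine cache, and in the ResNet-18 tests the pretrained/random model pair is swapped so the engine is built from randomly initialized weights and then refit with the pretrained ones. Below is a minimal sketch of the compile-then-refit pattern these tests exercise; the entry points, imports, and literal values that do not appear in the hunks (torchtrt.dynamo.compile, torch.export.export, the import paths, min_block_size=1) are assumptions based on the surrounding test file, not part of this change.

# Hedged sketch of the compile/refit pattern exercised by the tests above.
# Anything not shown in the diff hunks is an assumption.
import torch
import torch_tensorrt as torchtrt
import torchvision.models as models
from torch_tensorrt.dynamo import refit_module_weights

# Build the engine from randomly initialized weights ...
model = models.resnet18(pretrained=False).eval().to("cuda")
# ... and keep a pretrained copy whose weights will be refit into that engine.
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]

exp_program = torch.export.export(model, tuple(inputs))
exp_program2 = torch.export.export(model2, tuple(inputs))

trt_gm = torchtrt.dynamo.compile(
    exp_program,
    tuple(inputs),
    enabled_precisions={torch.float},
    min_block_size=1,
    make_refitable=True,          # engine must be refit-capable
    reuse_cached_engines=False,   # force a fresh build; do not load a cached engine
)

# Swap in the pretrained weights without rebuilding the engine.
new_trt_gm = refit_module_weights(
    compiled_module=trt_gm,
    new_weight_module=exp_program2,
    arg_inputs=inputs,
    use_weight_map_cache=True,
)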