Skip to content

Commit d3b2c04

Browse files
committed
Changed test cases
1 parent 414d972 commit d3b2c04

File tree

2 files changed

+25
-11
lines changed

2 files changed

+25
-11
lines changed

examples/dynamo/refit_engine_example.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
min_block_size=min_block_size,
7171
torch_executed_ops=torch_executed_ops,
7272
make_refitable=True,
73+
reuse_cached_engines=False,
7374
) # Output is a torch.fx.GraphModule
7475

7576
# Save the graph module as an exported program

tests/py/dynamo/models/test_model_refit.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
@pytest.mark.unit
3636
def test_mapping():
3737

38-
model = models.resnet18(pretrained=True).eval().to("cuda")
39-
model2 = models.resnet18(pretrained=False).eval().to("cuda")
38+
model = models.resnet18(pretrained=False).eval().to("cuda")
39+
model2 = models.resnet18(pretrained=True).eval().to("cuda")
4040
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
4141
trt_input = [
4242
torchtrt.Input(i.shape, dtype=torch.float, format=torch.contiguous_format)
@@ -58,6 +58,7 @@ def test_mapping():
5858
debug=debug,
5959
min_block_size=min_block_size,
6060
make_refitable=True,
61+
reuse_cached_engines=False,
6162
)
6263
settings = trt_gm._run_on_acc_0.settings
6364
runtime = trt.Runtime(TRT_LOGGER)
@@ -110,6 +111,7 @@ def test_refit_one_engine_with_weightmap():
110111
debug=debug,
111112
min_block_size=min_block_size,
112113
make_refitable=True,
114+
reuse_cached_engines=False,
113115
)
114116

115117
new_trt_gm = refit_module_weights(
@@ -141,8 +143,8 @@ def test_refit_one_engine_with_weightmap():
141143
@pytest.mark.unit
142144
def test_refit_one_engine_no_map_with_weightmap():
143145

144-
model = models.resnet18(pretrained=True).eval().to("cuda")
145-
model2 = models.resnet18(pretrained=False).eval().to("cuda")
146+
model = models.resnet18(pretrained=False).eval().to("cuda")
147+
model2 = models.resnet18(pretrained=True).eval().to("cuda")
146148
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
147149
enabled_precisions = {torch.float}
148150
debug = False
@@ -160,6 +162,7 @@ def test_refit_one_engine_no_map_with_weightmap():
160162
debug=debug,
161163
min_block_size=min_block_size,
162164
make_refitable=True,
165+
reuse_cached_engines=False,
163166
)
164167

165168
trt_gm._run_on_acc_0.weight_name_map = None
@@ -192,8 +195,8 @@ def test_refit_one_engine_no_map_with_weightmap():
192195
@pytest.mark.unit
193196
def test_refit_one_engine_with_wrong_weightmap():
194197

195-
model = models.resnet18(pretrained=True).eval().to("cuda")
196-
model2 = models.resnet18(pretrained=False).eval().to("cuda")
198+
model = models.resnet18(pretrained=False).eval().to("cuda")
199+
model2 = models.resnet18(pretrained=True).eval().to("cuda")
197200
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
198201
enabled_precisions = {torch.float}
199202
debug = False
@@ -211,6 +214,7 @@ def test_refit_one_engine_with_wrong_weightmap():
211214
debug=debug,
212215
min_block_size=min_block_size,
213216
make_refitable=True,
217+
reuse_cached_engines=False,
214218
)
215219
# Manually Deleted all batch norm layer. This suppose to fail the fast refit
216220
trt_gm._run_on_acc_0.weight_name_map = {
@@ -268,6 +272,7 @@ def test_refit_one_engine_bert_with_weightmap():
268272
debug=debug,
269273
min_block_size=min_block_size,
270274
make_refitable=True,
275+
reuse_cached_engines=False,
271276
)
272277

273278
new_trt_gm = refit_module_weights(
@@ -302,8 +307,8 @@ def test_refit_one_engine_bert_with_weightmap():
302307
@pytest.mark.unit
303308
def test_refit_one_engine_inline_runtime__with_weightmap():
304309
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
305-
model = models.resnet18(pretrained=True).eval().to("cuda")
306-
model2 = models.resnet18(pretrained=False).eval().to("cuda")
310+
model = models.resnet18(pretrained=False).eval().to("cuda")
311+
model2 = models.resnet18(pretrained=True).eval().to("cuda")
307312
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
308313
enabled_precisions = {torch.float}
309314
debug = False
@@ -321,6 +326,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
321326
debug=debug,
322327
min_block_size=min_block_size,
323328
make_refitable=True,
329+
reuse_cached_engines=False,
324330
)
325331
torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
326332
trt_gm = torch.export.load(trt_ep_path)
@@ -348,8 +354,8 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
348354
@pytest.mark.unit
349355
def test_refit_one_engine_python_runtime_with_weightmap():
350356

351-
model = models.resnet18(pretrained=True).eval().to("cuda")
352-
model2 = models.resnet18(pretrained=False).eval().to("cuda")
357+
model = models.resnet18(pretrained=False).eval().to("cuda")
358+
model2 = models.resnet18(pretrained=True).eval().to("cuda")
353359
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
354360
enabled_precisions = {torch.float}
355361
debug = False
@@ -367,6 +373,7 @@ def test_refit_one_engine_python_runtime_with_weightmap():
367373
debug=debug,
368374
min_block_size=min_block_size,
369375
make_refitable=True,
376+
reuse_cached_engines=False,
370377
)
371378

372379
new_trt_gm = refit_module_weights(
@@ -438,6 +445,7 @@ def forward(self, x):
438445
min_block_size=min_block_size,
439446
make_refitable=True,
440447
torch_executed_ops=torch_executed_ops,
448+
reuse_cached_engines=False,
441449
)
442450

443451
new_trt_gm = refit_module_weights(
@@ -487,10 +495,11 @@ def test_refit_one_engine_without_weightmap():
487495
debug=debug,
488496
min_block_size=min_block_size,
489497
make_refitable=True,
498+
reuse_cached_engines=False,
490499
)
491500

492501
new_trt_gm = refit_module_weights(
493-
compiled_module=trt_gm,
502+
compiled_module=trt_gm,
494503
new_weight_module=exp_program2,
495504
arg_inputs=inputs,
496505
use_weight_map_cache=False,
@@ -538,6 +547,7 @@ def test_refit_one_engine_bert_without_weightmap():
538547
debug=debug,
539548
min_block_size=min_block_size,
540549
make_refitable=True,
550+
reuse_cached_engines=False,
541551
)
542552

543553
new_trt_gm = refit_module_weights(
@@ -591,6 +601,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
591601
debug=debug,
592602
min_block_size=min_block_size,
593603
make_refitable=True,
604+
reuse_cached_engines=False,
594605
)
595606
torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
596607
trt_gm = torch.export.load(trt_ep_path)
@@ -637,6 +648,7 @@ def test_refit_one_engine_python_runtime_without_weightmap():
637648
debug=debug,
638649
min_block_size=min_block_size,
639650
make_refitable=True,
651+
reuse_cached_engines=False,
640652
)
641653

642654
new_trt_gm = refit_module_weights(
@@ -708,6 +720,7 @@ def forward(self, x):
708720
min_block_size=min_block_size,
709721
make_refitable=True,
710722
torch_executed_ops=torch_executed_ops,
723+
reuse_cached_engines=False,
711724
)
712725

713726
new_trt_gm = refit_module_weights(

0 commit comments

Comments
 (0)