Skip to content

Commit e314ad6

Browse files
lanluo-nvidia, cehongwang, zewenli98
authored
cherry pick refit error 3170 from main to release/2.5 branch (#3236)
Co-authored-by: cehongwang <[email protected]> Co-authored-by: Evan Li <[email protected]>
1 parent cd12eb4 commit e314ad6

File tree

4 files changed

+54
-21
lines changed

4 files changed

+54
-21
lines changed

examples/dynamo/refit_engine_example.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
min_block_size=min_block_size,
7171
torch_executed_ops=torch_executed_ops,
7272
make_refittable=True,
73+
reuse_cached_engines=False,
7374
) # Output is a torch.fx.GraphModule
7475

7576
# Save the graph module as an exported program

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -476,12 +476,18 @@ def _save_weight_mapping(self) -> None:
476476
# Retrieve each weight name(s) in state_dict
477477
if layer_type == "CONSTANT":
478478
if "embedding" in suffix:
479-
sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
479+
sd_weight_name = f"{sd_weight_name}.weight"
480480
elif "weight" in suffix or "mm_other" in suffix:
481481
# Linear layer weight
482-
sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
482+
sd_weight_name = f"{sd_weight_name}.weight"
483+
elif "running_mean" in suffix:
484+
# Linear layer weight
485+
sd_weight_name = f"{sd_weight_name}.running_mean"
486+
elif "running_var" in suffix:
487+
# Linear layer weight
488+
sd_weight_name = f"{sd_weight_name}.running_var"
483489
else:
484-
sd_weight_name = f"{sd_weight_name}.{torch_attr[1]}"
490+
sd_weight_name = f"{sd_weight_name}.bias"
485491
elif layer_type == "SCALE":
486492
# Batch norm needs all weights to calculate scale and shift
487493
sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr]

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,27 @@ def batch_norm(
5050
# Save the original output shape for later use
5151
output_shape = input.shape
5252

53-
if weight is None:
54-
weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
55-
if bias is None:
56-
bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
57-
if running_mean is None:
58-
running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
59-
if running_var is None:
60-
running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")
53+
# We name the weight here according to the state_dict name
54+
weight = (
55+
get_trt_tensor(ctx, 1.0, f"{name}_weight")
56+
if weight is None
57+
else get_trt_tensor(ctx, weight, f"{name}_weight")
58+
)
59+
bias = (
60+
get_trt_tensor(ctx, 0.0, f"{name}_bias")
61+
if bias is None
62+
else get_trt_tensor(ctx, bias, f"{name}_bias")
63+
)
64+
running_mean = (
65+
get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
66+
if running_mean is None
67+
else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
68+
)
69+
running_var = (
70+
get_trt_tensor(ctx, 1.0, f"{name}_running_var")
71+
if running_var is None
72+
else get_trt_tensor(ctx, running_var, f"{name}_running_var")
73+
)
6174

6275
# eps_tensor for numerical stability
6376
eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")

tests/py/dynamo/models/test_model_refit.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
@pytest.mark.unit
3636
def test_mapping():
3737

38-
model = models.resnet18(pretrained=True).eval().to("cuda")
39-
model2 = models.resnet18(pretrained=False).eval().to("cuda")
38+
model = models.resnet18(pretrained=False).eval().to("cuda")
39+
model2 = models.resnet18(pretrained=True).eval().to("cuda")
4040
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
4141
trt_input = [
4242
torchtrt.Input(i.shape, dtype=torch.float, format=torch.contiguous_format)
@@ -58,6 +58,7 @@ def test_mapping():
5858
debug=debug,
5959
min_block_size=min_block_size,
6060
make_refittable=True,
61+
reuse_cached_engines=False,
6162
)
6263
settings = trt_gm._run_on_acc_0.settings
6364
runtime = trt.Runtime(TRT_LOGGER)
@@ -110,6 +111,7 @@ def test_refit_one_engine_with_weightmap():
110111
debug=debug,
111112
min_block_size=min_block_size,
112113
make_refittable=True,
114+
reuse_cached_engines=False,
113115
)
114116

115117
new_trt_gm = refit_module_weights(
@@ -141,8 +143,8 @@ def test_refit_one_engine_with_weightmap():
141143
@pytest.mark.unit
142144
def test_refit_one_engine_no_map_with_weightmap():
143145

144-
model = models.resnet18(pretrained=True).eval().to("cuda")
145-
model2 = models.resnet18(pretrained=False).eval().to("cuda")
146+
model = models.resnet18(pretrained=False).eval().to("cuda")
147+
model2 = models.resnet18(pretrained=True).eval().to("cuda")
146148
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
147149
enabled_precisions = {torch.float}
148150
debug = False
@@ -160,6 +162,7 @@ def test_refit_one_engine_no_map_with_weightmap():
160162
debug=debug,
161163
min_block_size=min_block_size,
162164
make_refittable=True,
165+
reuse_cached_engines=False,
163166
)
164167

165168
trt_gm._run_on_acc_0.weight_name_map = None
@@ -192,8 +195,8 @@ def test_refit_one_engine_no_map_with_weightmap():
192195
@pytest.mark.unit
193196
def test_refit_one_engine_with_wrong_weightmap():
194197

195-
model = models.resnet18(pretrained=True).eval().to("cuda")
196-
model2 = models.resnet18(pretrained=False).eval().to("cuda")
198+
model = models.resnet18(pretrained=False).eval().to("cuda")
199+
model2 = models.resnet18(pretrained=True).eval().to("cuda")
197200
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
198201
enabled_precisions = {torch.float}
199202
debug = False
@@ -211,6 +214,7 @@ def test_refit_one_engine_with_wrong_weightmap():
211214
debug=debug,
212215
min_block_size=min_block_size,
213216
make_refittable=True,
217+
reuse_cached_engines=False,
214218
)
215219
# Manually Deleted all batch norm layer. This suppose to fail the fast refit
216220
trt_gm._run_on_acc_0.weight_name_map = {
@@ -268,6 +272,7 @@ def test_refit_one_engine_bert_with_weightmap():
268272
debug=debug,
269273
min_block_size=min_block_size,
270274
make_refittable=True,
275+
reuse_cached_engines=False,
271276
)
272277

273278
new_trt_gm = refit_module_weights(
@@ -302,8 +307,8 @@ def test_refit_one_engine_bert_with_weightmap():
302307
@pytest.mark.unit
303308
def test_refit_one_engine_inline_runtime__with_weightmap():
304309
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
305-
model = models.resnet18(pretrained=True).eval().to("cuda")
306-
model2 = models.resnet18(pretrained=False).eval().to("cuda")
310+
model = models.resnet18(pretrained=False).eval().to("cuda")
311+
model2 = models.resnet18(pretrained=True).eval().to("cuda")
307312
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
308313
enabled_precisions = {torch.float}
309314
debug = False
@@ -321,6 +326,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
321326
debug=debug,
322327
min_block_size=min_block_size,
323328
make_refittable=True,
329+
reuse_cached_engines=False,
324330
)
325331
torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
326332
trt_gm = torch.export.load(trt_ep_path)
@@ -348,8 +354,8 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
348354
@pytest.mark.unit
349355
def test_refit_one_engine_python_runtime_with_weightmap():
350356

351-
model = models.resnet18(pretrained=True).eval().to("cuda")
352-
model2 = models.resnet18(pretrained=False).eval().to("cuda")
357+
model = models.resnet18(pretrained=False).eval().to("cuda")
358+
model2 = models.resnet18(pretrained=True).eval().to("cuda")
353359
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
354360
enabled_precisions = {torch.float}
355361
debug = False
@@ -367,6 +373,7 @@ def test_refit_one_engine_python_runtime_with_weightmap():
367373
debug=debug,
368374
min_block_size=min_block_size,
369375
make_refittable=True,
376+
reuse_cached_engines=False,
370377
)
371378

372379
new_trt_gm = refit_module_weights(
@@ -438,6 +445,7 @@ def forward(self, x):
438445
min_block_size=min_block_size,
439446
make_refittable=True,
440447
torch_executed_ops=torch_executed_ops,
448+
reuse_cached_engines=False,
441449
)
442450

443451
new_trt_gm = refit_module_weights(
@@ -487,6 +495,7 @@ def test_refit_one_engine_without_weightmap():
487495
debug=debug,
488496
min_block_size=min_block_size,
489497
make_refittable=True,
498+
reuse_cached_engines=False,
490499
)
491500

492501
new_trt_gm = refit_module_weights(
@@ -538,6 +547,7 @@ def test_refit_one_engine_bert_without_weightmap():
538547
debug=debug,
539548
min_block_size=min_block_size,
540549
make_refittable=True,
550+
reuse_cached_engines=False,
541551
)
542552

543553
new_trt_gm = refit_module_weights(
@@ -591,6 +601,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
591601
debug=debug,
592602
min_block_size=min_block_size,
593603
make_refittable=True,
604+
reuse_cached_engines=False,
594605
)
595606
torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
596607
trt_gm = torch.export.load(trt_ep_path)
@@ -637,6 +648,7 @@ def test_refit_one_engine_python_runtime_without_weightmap():
637648
debug=debug,
638649
min_block_size=min_block_size,
639650
make_refittable=True,
651+
reuse_cached_engines=False,
640652
)
641653

642654
new_trt_gm = refit_module_weights(
@@ -708,6 +720,7 @@ def forward(self, x):
708720
min_block_size=min_block_size,
709721
make_refittable=True,
710722
torch_executed_ops=torch_executed_ops,
723+
reuse_cached_engines=False,
711724
)
712725

713726
new_trt_gm = refit_module_weights(

0 commit comments

Comments (0)