
Commit 50ba223

Implemented basic pipeline for Refitting (#2886)
1 parent cf9a9bb

4 files changed (+16, -10 lines)


docsrc/py_api/dynamo.rst

Lines changed: 0 additions & 1 deletion

@@ -26,7 +26,6 @@ Functions
 
 .. autofunction:: refit_module_weights
 
-
 Classes
 --------
 
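refit_module_weights, the function documented by the directive above, swaps freshly trained weights into an already compiled TensorRT module without rebuilding the engine. Below is a minimal usage sketch, assuming the torch_tensorrt API of this commit's era; argument names (inputs in particular) have shifted between releases, and MyModel is a hypothetical stand-in.

import torch
import torch_tensorrt

class MyModel(torch.nn.Module):  # hypothetical stand-in model
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.conv(x))

inputs = [torch.randn(1, 3, 32, 32).cuda()]

# Build a refittable engine; after this commit the supported spelling is
# make_refitable=True (refit=True only triggers the deprecation path below).
trt_gm = torch_tensorrt.compile(
    MyModel().eval().cuda(), ir="dynamo", inputs=inputs, make_refitable=True
)

# Export a module that carries the new weights, then refit the compiled
# module in place of a full recompilation.
new_model = MyModel().eval().cuda()  # pretend this holds newly trained weights
exp_program = torch.export.export(new_model, tuple(inputs))
refitted_gm = torch_tensorrt.dynamo.refit_module_weights(
    compiled_module=trt_gm, new_weight_module=exp_program, inputs=inputs
)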

examples/dynamo/refit_engine_example.py

Lines changed: 1 addition & 1 deletion

@@ -91,7 +91,7 @@
 print("Refit successfully!")
 
 # %%
-# Alternative Workflow using Python Runtime
+# Alterative Workflow using Python Runtime
 # -----------------------------
 
 # Currently python runtime does not support engine serialization. So the refitting will be done in the same runtime.
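As the context line quoted above notes, the Python runtime could not serialize engines at this point, so the alternative workflow refits and verifies in one process. A short continuation of the previous sketch (same hypothetical names) showing that in-process check:

# Same process, no save/load round trip: run the refitted module directly
# and compare against the PyTorch module that supplied the new weights.
with torch.no_grad():
    expected = new_model(*inputs)
    actual = refitted_gm(*inputs)
    # Loose tolerance: TensorRT kernels need not match PyTorch bit-for-bit.
    print("Refit successful:", torch.allclose(expected, actual, atol=1e-2))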

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 12 additions & 0 deletions

@@ -176,6 +176,18 @@ def compile(
 
     if kwarg_inputs is None:
         kwarg_inputs = {}
+
+    if "refit" in kwargs.keys():
+        warnings.warn(
+            "Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        if make_refitable:
+            raise ValueError("Use flag make_refitable only. Flag refit is deprecated.")
+        else:
+            make_refitable = kwargs["refit"]
+
     engine_capability = EngineCapability._from(engine_capability)
 
     if torch_executed_modules is not None and torch_executed_modules:
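The hunk above wires the deprecated refit kwarg into the new make_refitable flag. A self-contained sketch of the same pattern, with compile_sketch as a hypothetical stand-in for the real compile function:

import warnings
from typing import Any

def compile_sketch(make_refitable: bool = False, **kwargs: Any) -> bool:
    # Mirror of the hunk above: accept the legacy "refit" kwarg, warn, and
    # fold it into the replacement flag unless both flags were supplied.
    if "refit" in kwargs:
        warnings.warn(
            "Refit is deprecated. Please use make_refitable=True.",
            DeprecationWarning,
            stacklevel=2,
        )
        if make_refitable:
            raise ValueError("Use flag make_refitable only. Flag refit is deprecated.")
        make_refitable = kwargs["refit"]
    return make_refitable

assert compile_sketch(refit=True) is True           # legacy spelling, warns
assert compile_sketch(make_refitable=True) is True  # new spelling, silent

Failing fast when both spellings are supplied is the safer choice here, since the two flags could otherwise disagree silently.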

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 3 additions & 8 deletions

@@ -6,6 +6,7 @@
 from typing import Any, Sequence, Tuple
 
 import numpy as np
+import tensorrt as trt
 import torch
 from torch.export import ExportedProgram
 from torch_tensorrt._enums import dtype
@@ -42,8 +43,6 @@
 )
 from torch_tensorrt.logging import TRT_LOGGER
 
-import tensorrt as trt
-
 logger = logging.getLogger(__name__)
 
 
@@ -96,16 +95,12 @@ def construct_refit_mapping(
         layer_type: str = layer.type.name
         if layer_type in MODULE_MAP:
             # Cast the parent class to child class to access attributes
-            # For example: ILayer does not have ILayer.kernel/ILayer.bias
+            # For example: ILayer does not have ILayer.kernal/ILayer.bias
             # So we cast it to IConvolutionLayer and access the attributes
             layer.__class__ = MODULE_MAP[layer_type][0]
             for weight_type, weight_name in MODULE_MAP[layer_type][1]:
                 weight = layer.__getattribute__(weight_type).copy()
-                weight_dtype_opt = dtype.try_from(weight.dtype)
-                assert (
-                    weight_dtype_opt is not None
-                ), f"Weights {weight_name} has unsupported dtype {weight.dtype}"
-                weight_dtype = weight_dtype_opt.to(trt.DataType)
+                weight_dtype = dtype.try_from(weight.dtype).to(trt.DataType)
                 weight_map[f"{layer.name} {weight_name}"] = (
                     weight,
                     weight_dtype,
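Two things are worth unpacking in the last hunk. First, collapsing the assert means dtype.try_from, which returns None for unsupported dtypes, is now dereferenced immediately, so an unsupported weight dtype surfaces as an AttributeError instead of the earlier descriptive AssertionError. Second, the __class__ reassignment trick is plain Python; a runnable illustration with hypothetical stand-in classes (not TensorRT types):

class ILayerLike:  # stand-in for trt.ILayer
    def __init__(self) -> None:
        self._params = {"kernel": [1.0, 2.0], "bias": [0.5]}

class IConvolutionLayerLike(ILayerLike):  # stand-in for trt.IConvolutionLayer
    @property
    def kernel(self) -> list:
        return self._params["kernel"]

    @property
    def bias(self) -> list:
        return self._params["bias"]

layer = ILayerLike()
# layer.kernel would raise AttributeError here: the base class lacks the
# accessor, just as ILayer lacks ILayer.kernel/ILayer.bias.
layer.__class__ = IConvolutionLayerLike
print(layer.kernel, layer.bias)  # lookup now resolves through the subclass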
