File tree Expand file tree Collapse file tree 4 files changed +16
-10
lines changed Expand file tree Collapse file tree 4 files changed +16
-10
lines changed Original file line number Diff line number Diff line change @@ -26,7 +26,6 @@ Functions
26
26
27
27
.. autofunction :: refit_module_weights
28
28
29
-
30
29
Classes
31
30
--------
32
31
Original file line number Diff line number Diff line change 91
91
print ("Refit successfully!" )
92
92
93
93
# %%
94
- # Alternative Workflow using Python Runtime
94
+ # Alternative Workflow using Python Runtime
95
95
# -----------------------------
96
96
97
97
# Currently python runtime does not support engine serialization. So the refitting will be done in the same runtime.
Original file line number Diff line number Diff line change @@ -176,6 +176,18 @@ def compile(
176
176
177
177
if kwarg_inputs is None :
178
178
kwarg_inputs = {}
179
+
180
+ if "refit" in kwargs .keys ():
181
+ warnings .warn (
182
+ "Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine." ,
183
+ DeprecationWarning ,
184
+ stacklevel = 2 ,
185
+ )
186
+ if make_refitable :
187
+ raise ValueError ("Use flag make_refitable only. Flag refit is deprecated." )
188
+ else :
189
+ make_refitable = kwargs ["refit" ]
190
+
179
191
engine_capability = EngineCapability ._from (engine_capability )
180
192
181
193
if torch_executed_modules is not None and torch_executed_modules :
Original file line number Diff line number Diff line change 6
6
from typing import Any , Sequence , Tuple
7
7
8
8
import numpy as np
9
+ import tensorrt as trt
9
10
import torch
10
11
from torch .export import ExportedProgram
11
12
from torch_tensorrt ._enums import dtype
42
43
)
43
44
from torch_tensorrt .logging import TRT_LOGGER
44
45
45
- import tensorrt as trt
46
-
47
46
logger = logging .getLogger (__name__ )
48
47
49
48
@@ -96,16 +95,12 @@ def construct_refit_mapping(
96
95
layer_type : str = layer .type .name
97
96
if layer_type in MODULE_MAP :
98
97
# Cast the parent class to child class to access attributes
99
- # For example: ILayer does not have ILayer.kernel /ILayer.bias
98
+ # For example: ILayer does not have ILayer.kernel /ILayer.bias
100
99
# So we cast it to IConvolutionLayer and access the attributes
101
100
layer .__class__ = MODULE_MAP [layer_type ][0 ]
102
101
for weight_type , weight_name in MODULE_MAP [layer_type ][1 ]:
103
102
weight = layer .__getattribute__ (weight_type ).copy ()
104
- weight_dtype_opt = dtype .try_from (weight .dtype )
105
- assert (
106
- weight_dtype_opt is not None
107
- ), f"Weights { weight_name } has unsupported dtype { weight .dtype } "
108
- weight_dtype = weight_dtype_opt .to (trt .DataType )
103
+ weight_dtype = dtype .try_from (weight .dtype ).to (trt .DataType )
109
104
weight_map [f"{ layer .name } { weight_name } " ] = (
110
105
weight ,
111
106
weight_dtype ,
You can’t perform that action at this time.
0 commit comments