@@ -9,6 +9,7 @@
 import tensorrt as trt
 import torch
 from torch.export import ExportedProgram
+from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import needs_refit
 from torch_tensorrt._Input import Input
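
# Why this import: refit runs inside torch.export/torch.fx tracing flows where
# FakeTensorMode can be active, and tensors materialized under that mode are
# fake (shape-only, no real storage). A minimal sketch of the pattern the
# import enables (illustrative, not part of this file):
#
#     from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
#
#     with unset_fake_temporarily():
#         w = torch.zeros(4)  # a real tensor, even if fake mode was active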
@@ -61,26 +62,13 @@ def construct_refit_mapping(
     Returns:
         Mapping from weight name in TensorRT to actual weight value in np.ndarray
     """
-    MODULE_MAP = {
-        "SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]),
-        "CONVOLUTION": (
-            trt.IConvolutionLayer,
-            [("kernel", "KERNEL"), ("bias", "BIAS")],
-        ),
-        "DECONVOLUTION": (
-            trt.IDeconvolutionLayer,
-            [("kernel", "KERNEL"), ("bias", "BIAS")],
-        ),
-        "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]),
-    }
 
     output_dtypes = infer_module_output_dtypes(
         module,
         truncate_double=settings.truncate_double,
     )
 
     # Use Interpreter
-    weight_map = {}
     interpreter = TRTInterpreter(
         module,
         inputs,
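
# The MODULE_MAP table above drove a per-layer walk (removed in the next hunk)
# that re-assigned each generic ILayer's class to the matching TensorRT child
# binding so the role-specific attributes became visible, roughly:
#
#     layer.__class__ = trt.IConvolutionLayer  # ILayer has no .kernel/.bias
#     kernel = layer.kernel.copy()
#
# That extraction now happens inside the interpreter while it builds the
# network, so both the table and the walk can go.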
@@ -89,24 +77,8 @@ def construct_refit_mapping(
         compilation_settings=settings,
     )
     interpreter._construct_trt_network_def()
-    net = interpreter.ctx.net
-    for i in range(net.num_layers):
-        layer = net[i]
-        layer_type: str = layer.type.name
-        if layer_type in MODULE_MAP:
-            # Cast the parent class to child class to access attributes
-            # For example: ILayer does not have ILayer.kernel/ILayer.bias
-            # So we cast it to IConvolutionLayer and access the attributes
-            layer.__class__ = MODULE_MAP[layer_type][0]
-            for weight_type, weight_name in MODULE_MAP[layer_type][1]:
-                weight = layer.__getattribute__(weight_type).copy()
-                weight_dtype = dtype.try_from(weight.dtype).to(trt.DataType)
-                weight_map[f"{layer.name} {weight_name}"] = (
-                    weight,
-                    weight_dtype,
-                )
 
-    return weight_map
+    return interpreter.ctx.mapping
 
 
 @needs_refit
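
# Assuming interpreter.ctx.mapping keeps the "<layer name> <weight role>" keys
# that the removed walk produced, the returned mapping would look roughly like
# this (illustrative names and values):
#
#     {
#         "conv1 KERNEL": np.array([...]),  # plain ndarrays; the slow-refit
#         "conv1 BIAS": np.array([...]),    # path below re-derives the TRT
#     }                                     # dtype from each array's dtype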
@@ -117,13 +89,12 @@ def construct_refit_mapping_from_weight_name_map(
 ) -> dict[Any, Any]:
     engine_weight_map = {}
     for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
-        trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
-        torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
-
         if sd_weight_name not in state_dict:
             # If weights is not in sd, we can leave it unchanged
             continue
         else:
+            trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
+            torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
             engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
                 to_torch_device(settings.device)
             )
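
# dtype._from() is presumably the strict counterpart of the removed
# dtype.try_from(): it raises on an unmappable dtype instead of returning None.
# Hoisting the conversion into the else branch also skips it entirely for
# weights absent from the state_dict. For example (illustrative):
#
#     dtype._from(np.float32).to(trt.DataType)  # -> trt.DataType.FLOAT
#     dtype._from(np.float32).to(torch.dtype)   # -> torch.float32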
@@ -152,71 +123,73 @@ def _refit_single_trt_engine_with_gm(
     Refit a TensorRT Engine in place
     """
 
-    refitted = set()
-    torch_device = get_model_device(new_gm)
-    refitter = trt.Refitter(old_engine, TRT_LOGGER)
-    weight_list = refitter.get_all_weights()
-
-    if weight_name_map:
-        # Get the refitting mapping
-        trt_wt_location = (
-            trt.TensorLocation.DEVICE
-            if torch_device.type == "cuda"
-            else trt.TensorLocation.HOST
-        )
+    with unset_fake_temporarily():
+        refitted = set()
+        torch_device = get_model_device(new_gm)
+        refitter = trt.Refitter(old_engine, TRT_LOGGER)
+        weight_list = refitter.get_all_weights()
+
+        if weight_name_map:
+            # Get the refitting mapping
+            trt_wt_location = (
+                trt.TensorLocation.DEVICE
+                if torch_device.type == "cuda"
+                else trt.TensorLocation.HOST
+            )
 
-        constant_mapping: dict[str, Any] = weight_name_map.pop(
-            "constant_mapping", {}
-        )  # type: ignore
-        mapping = construct_refit_mapping_from_weight_name_map(
-            weight_name_map, new_gm.state_dict(), settings
-        )
-        constant_mapping_with_type = {}
-
-        for constant_name, val in constant_mapping.items():
-            np_weight_type = val.dtype
-            val_tensor = torch.from_numpy(val).cuda()
-            trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
-            torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
-            constant_mapping_with_type[constant_name] = (
-                val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
-                trt_dtype,
+            constant_mapping: dict[str, Any] = weight_name_map.pop(
+                "constant_mapping", {}
+            )  # type: ignore
+            mapping = construct_refit_mapping_from_weight_name_map(
+                weight_name_map, new_gm.state_dict(), settings
             )
+            constant_mapping_with_type = {}
+
+            for constant_name, val in constant_mapping.items():
+                np_weight_type = val.dtype
+                val_tensor = torch.from_numpy(val).cuda()
+                trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
+                torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
+                constant_mapping_with_type[constant_name] = (
+                    val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
+                    trt_dtype,
+                )
 
-        mapping.update(constant_mapping_with_type)
+            mapping.update(constant_mapping_with_type)
 
-        for layer_name in weight_list:
-            if layer_name not in mapping:
-                logger.warning(f"{layer_name} is not found in weight mapping.")
-                continue
-            # Use Numpy to create weights
-            weight, weight_dtype = mapping[layer_name]
-            trt_wt_tensor = trt.Weights(
-                weight_dtype, weight.data_ptr(), torch.numel(weight)
-            )
-            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
-        assert (
-            len(refitter.get_missing_weights()) == 0
-        ), "Fast refitting failed due to incomplete mapping"
+            for layer_name in weight_list:
+                if layer_name not in mapping:
+                    logger.warning(f"{layer_name} is not found in weight mapping.")
+                    continue
+                # Use Numpy to create weights
+                weight, weight_dtype = mapping[layer_name]
+                trt_wt_tensor = trt.Weights(
+                    weight_dtype, weight.data_ptr(), torch.numel(weight)
+                )
+                refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+            assert (
+                len(refitter.get_missing_weights()) == 0
+            ), "Fast refitting failed due to incomplete mapping"
 
-    else:
-        mapping = construct_refit_mapping(new_gm, input_list, settings)
-        trt_wt_location = trt.TensorLocation.HOST
-        for layer_name in weight_list:
-            if layer_name not in mapping:
-                raise AssertionError(f"{layer_name} is not found in weight mapping")
-            # Use Numpy to create weights
-            weight, datatype = mapping[layer_name]
-            trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
-            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
-            refitted.add(layer_name)
-
-        if len(refitted) != len(weight_list):
-            logger.warning("Not all weights have been refitted!!!")
-
-    if not refitter.refit_cuda_engine():
-        logger.error("Error: failed to refit new weights.")
-        raise AssertionError("Refitting failed.")
+        else:
+            mapping = construct_refit_mapping(new_gm, input_list, settings)
+            trt_wt_location = trt.TensorLocation.HOST
+            for layer_name in weight_list:
+                if layer_name not in mapping:
+                    raise AssertionError(f"{layer_name} is not found in weight mapping")
+                # Use Numpy to create weights
+                weight = mapping[layer_name]
+                trt_dtype = dtype._from(weight.dtype).to(trt.DataType)
+                trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)
+                refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+                refitted.add(layer_name)
+
+            if len(refitted) != len(weight_list):
+                logger.warning("Not all weights have been refitted!!!")
+
+        if not refitter.refit_cuda_engine():
+            logger.error("Error: failed to refit new weights.")
+            raise AssertionError("Refitting failed.")
 
 
 @needs_refit
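
# For context, a sketch of how this refit path is typically reached end to end
# (hedged: the compile-settings flag that makes an engine refittable varies
# across torch_tensorrt releases, and MyModel is a stand-in):
#
#     import torch
#     import torch_tensorrt
#
#     model = MyModel().eval().cuda()
#     inputs = [torch.randn(1, 3, 224, 224).cuda()]
#     exp_program = torch.export.export(model, tuple(inputs))
#     trt_gm = torch_tensorrt.dynamo.compile(
#         exp_program, inputs  # enable refittable weights in your release's settings
#     )
#
#     model2 = MyModel().eval().cuda()  # same architecture, updated weights
#     exp_program2 = torch.export.export(model2, tuple(inputs))
#     trt_gm = torch_tensorrt.dynamo.refit_module_weights(
#         trt_gm, exp_program2, arg_inputs=inputs
#     )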