@@ -107,6 +107,36 @@ def interpret_module_to_result(
107
107
compilation_settings = settings ,
108
108
)
109
109
interpreter_result = interpreter .run ()
110
+
111
+ if settings .make_refitable :
112
+ # Run fast refit even if it's the first compilation.
113
+ # This is to ensure that the weight name map is correct for future refits.
114
+ # If the fast refit fails, remove the weight name map.
115
+ from torch_tensorrt .dynamo ._refit import _refit_single_trt_engine_with_gm
116
+ from torch_tensorrt .logging import TRT_LOGGER
117
+
118
+ runtime = trt .Runtime (TRT_LOGGER )
119
+ refit_test_engine = runtime .deserialize_cuda_engine (
120
+ interpreter_result .serialized_engine
121
+ )
122
+ try :
123
+ _refit_single_trt_engine_with_gm (
124
+ new_gm = module ,
125
+ old_engine = refit_test_engine ,
126
+ input_list = inputs ,
127
+ settings = settings ,
128
+ weight_name_map = interpreter_result .weight_name_map ,
129
+ )
130
+ except AssertionError :
131
+ # TRTInterpreterResult is a tuple, so we need to create a new one
132
+ interpreter_result = TRTInterpreterResult (
133
+ interpreter_result .serialized_engine ,
134
+ interpreter_result .input_names ,
135
+ interpreter_result .output_names ,
136
+ None ,
137
+ )
138
+ logger .warning ("Fast refit test failed. Removing the weight map caching." )
139
+
110
140
return interpreter_result
111
141
112
142
@@ -126,28 +156,6 @@ def convert_module(
126
156
PythonTorchTensorRTModule or TorchTensorRTModule
127
157
"""
128
158
interpreter_result = interpret_module_to_result (module , inputs , settings )
129
- # Test fast refit:
130
- from torch_tensorrt .dynamo ._refit import _refit_single_trt_engine_with_gm
131
- from torch_tensorrt .logging import TRT_LOGGER
132
-
133
- runtime = trt .Runtime (TRT_LOGGER )
134
- refit_test_engine = runtime .deserialize_cuda_engine (
135
- interpreter_result .serialized_engine
136
- )
137
- weight_name_map : Any = None
138
- # Do the test refit with cached map if make_refitable is enabled
139
- if settings .make_refitable :
140
- weight_name_map = interpreter_result .weight_name_map
141
- try :
142
- _refit_single_trt_engine_with_gm (
143
- new_gm = module ,
144
- old_engine = refit_test_engine ,
145
- input_list = inputs ,
146
- settings = settings ,
147
- weight_name_map = interpreter_result .weight_name_map ,
148
- )
149
- except AssertionError :
150
- logger .warning ("Fast refit test failed. Removing the weight map caching." )
151
159
152
160
rt_cls = PythonTorchTensorRTModule
153
161
@@ -171,5 +179,5 @@ def convert_module(
171
179
output_binding_names = list (interpreter_result .output_names ),
172
180
name = name ,
173
181
settings = settings ,
174
- weight_name_map = weight_name_map ,
182
+ weight_name_map = interpreter_result . weight_name_map ,
175
183
)
0 commit comments