Skip to content

Commit a86260e

Browse files
committed
move refit into interpret_module_to_result
1 parent 132547f commit a86260e

File tree

3 files changed

+34
-27
lines changed

3 files changed

+34
-27
lines changed

py/torch_tensorrt/dynamo/_engine_caching.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@ def pack(
4848
serialized_engine: bytes,
4949
input_names: List[str],
5050
output_names: List[str],
51-
weight_name_map: Optional[Dict[str, Any]],
51+
weight_name_map: Optional[Dict[Any, Any]],
5252
) -> bytes:
5353
"""Pack serialized engine, input names, output names, and weight map into a single blob
5454
5555
Args:
5656
serialized_engine (bytes): serialized TRT engine
5757
input_names (List[str]): input names of TRT engine
5858
output_names (List[str]): output names of TRT engine
59-
weight_name_map (Optional[Dict[str, Any]]): weight name map for refitting
59+
weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting
6060
6161
Returns:
6262
bytes: packed blob
@@ -73,7 +73,7 @@ def pack(
7373
@staticmethod
7474
def unpack(
7575
packed_obj: bytes,
76-
) -> Tuple[bytes, List[str], List[str], Optional[Dict[str, Any]]]:
76+
) -> Tuple[bytes, List[str], List[str], Optional[Dict[Any, Any]]]:
7777
"""Unpack packed blob into serialized engine, input names, output names, and weight map
7878
7979
Args:

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,6 @@ def run(
501501
_LOGGER.info(
502502
"Found the cached engine that corresponds to this graph. It is directly loaded."
503503
)
504-
# TODO: refit the engine here or outside (within convert_module)?
505504
return TRTInterpreterResult(
506505
serialized_engine,
507506
self._input_names,

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,36 @@ def interpret_module_to_result(
107107
compilation_settings=settings,
108108
)
109109
interpreter_result = interpreter.run()
110+
111+
if settings.make_refitable:
112+
# Run fast refit even if it's the first compilation.
113+
# This is to ensure that the weight name map is correct for future refits.
114+
# If the fast refit fails, remove the weight name map.
115+
from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm
116+
from torch_tensorrt.logging import TRT_LOGGER
117+
118+
runtime = trt.Runtime(TRT_LOGGER)
119+
refit_test_engine = runtime.deserialize_cuda_engine(
120+
interpreter_result.serialized_engine
121+
)
122+
try:
123+
_refit_single_trt_engine_with_gm(
124+
new_gm=module,
125+
old_engine=refit_test_engine,
126+
input_list=inputs,
127+
settings=settings,
128+
weight_name_map=interpreter_result.weight_name_map,
129+
)
130+
except AssertionError:
131+
# TRTInterpreterResult is a tuple, so we need to create a new one
132+
interpreter_result = TRTInterpreterResult(
133+
interpreter_result.serialized_engine,
134+
interpreter_result.input_names,
135+
interpreter_result.output_names,
136+
None,
137+
)
138+
logger.warning("Fast refit test failed. Removing the weight map caching.")
139+
110140
return interpreter_result
111141

112142

@@ -126,28 +156,6 @@ def convert_module(
126156
PythonTorchTensorRTModule or TorchTensorRTModule
127157
"""
128158
interpreter_result = interpret_module_to_result(module, inputs, settings)
129-
# Test fast refit:
130-
from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm
131-
from torch_tensorrt.logging import TRT_LOGGER
132-
133-
runtime = trt.Runtime(TRT_LOGGER)
134-
refit_test_engine = runtime.deserialize_cuda_engine(
135-
interpreter_result.serialized_engine
136-
)
137-
weight_name_map: Any = None
138-
# Do the test refit with cached map if make_refitable is enabled
139-
if settings.make_refitable:
140-
weight_name_map = interpreter_result.weight_name_map
141-
try:
142-
_refit_single_trt_engine_with_gm(
143-
new_gm=module,
144-
old_engine=refit_test_engine,
145-
input_list=inputs,
146-
settings=settings,
147-
weight_name_map=interpreter_result.weight_name_map,
148-
)
149-
except AssertionError:
150-
logger.warning("Fast refit test failed. Removing the weight map caching.")
151159

152160
rt_cls = PythonTorchTensorRTModule
153161

@@ -171,5 +179,5 @@ def convert_module(
171179
output_binding_names=list(interpreter_result.output_names),
172180
name=name,
173181
settings=settings,
174-
weight_name_map=weight_name_map,
182+
weight_name_map=interpreter_result.weight_name_map,
175183
)

0 commit comments

Comments
 (0)