Skip to content

Commit fc4e764

Browse files
committed
update refit usage
1 parent e48dc98 commit fc4e764

File tree

4 files changed

+24
-30
lines changed

4 files changed

+24
-30
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,9 @@ def compile(
236236
logger.debug("Lowered Input graph: " + str(gm.graph))
237237

238238
if cache_built_engines or reuse_cached_engines:
239+
assert (
240+
make_refitable
241+
), "Engine caching requires make_refitable to be set to True"
239242
if custom_engine_cache is None:
240243
custom_engine_cache = DiskEngineCache(engine_cache_dir, engine_cache_size)
241244

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,24 @@ def run(
544544
_LOGGER.info(
545545
"Found the cached engine that corresponds to this graph. It is directly loaded."
546546
)
547+
548+
from torch_tensorrt.dynamo._refit import (
549+
_refit_single_trt_engine_with_gm,
550+
)
551+
552+
runtime = trt.Runtime(TRT_LOGGER)
553+
engine = runtime.deserialize_cuda_engine(serialized_engine)
554+
555+
_refit_single_trt_engine_with_gm(
556+
new_gm=self.module,
557+
old_engine=engine,
558+
input_list=self.input_specs,
559+
settings=self.compilation_settings,
560+
weight_name_map=weight_name_map,
561+
)
562+
563+
serialized_engine = bytes(engine.serialize())
564+
547565
return TRTInterpreterResult(
548566
serialized_engine,
549567
self._input_names,

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -107,36 +107,6 @@ def interpret_module_to_result(
107107
compilation_settings=settings,
108108
)
109109
interpreter_result = interpreter.run()
110-
111-
if settings.make_refitable:
112-
# Run fast refit even if it's the first compilation.
113-
# This is to ensure that the weight name map is correct for future refits.
114-
# If the fast refit fails, remove the weight name map.
115-
from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm
116-
from torch_tensorrt.logging import TRT_LOGGER
117-
118-
runtime = trt.Runtime(TRT_LOGGER)
119-
refit_test_engine = runtime.deserialize_cuda_engine(
120-
interpreter_result.serialized_engine
121-
)
122-
try:
123-
_refit_single_trt_engine_with_gm(
124-
new_gm=module,
125-
old_engine=refit_test_engine,
126-
input_list=inputs,
127-
settings=settings,
128-
weight_name_map=interpreter_result.weight_name_map,
129-
)
130-
except AssertionError:
131-
# TRTInterpreterResult is a tuple, so we need to create a new one
132-
interpreter_result = TRTInterpreterResult(
133-
interpreter_result.serialized_engine,
134-
interpreter_result.input_names,
135-
interpreter_result.output_names,
136-
None,
137-
)
138-
logger.warning("Fast refit test failed. Removing the weight map caching.")
139-
140110
return interpreter_result
141111

142112

py/torch_tensorrt/dynamo/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,9 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
361361
# If cache_built_engines or reuse_cached_engines is True but custom_engine_cache is not provided,
362362
# then create a default disk engine cache
363363
if kwargs.get("cache_built_engines") or kwargs.get("reuse_cached_engines"):
364+
assert kwargs.get(
365+
"make_refitable"
366+
), "Engine caching requires make_refitable to be set to True"
364367
if settings.custom_engine_cache is None:
365368
from torch_tensorrt.dynamo._engine_caching import DiskEngineCache
366369

0 commit comments

Comments
 (0)