Skip to content

Commit 2518db5

Browse files
committed
fix: distinguish engines based on compilation settings in addition to graph structure
Signed-off-by: Naren Dasan <[email protected]>
1 parent f84be56 commit 2518db5

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

py/torch_tensorrt/_Input.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,34 @@ def __str__(self) -> str:
220220
def __repr__(self) -> str:
221221
return self.__str__()
222222

223+
@staticmethod
def equivalent_spec(a: Input, b: Input) -> bool:
    """Report whether two input specs are interchangeable.

    Two specs are equivalent when they share the same ``shape_mode`` and
    agree on shape, ``dtype``, ``format``, ``low_tensor_domain_incl`` and
    ``high_tensor_domain_excl``. In dynamic shape mode the shape is a dict
    and its ``min_shape``, ``opt_shape`` and ``max_shape`` entries are each
    compared; otherwise the ``shape`` attributes are compared directly.

    Args:
        a: First input spec.
        b: Second input spec.

    Returns:
        ``True`` if every compared field matches, ``False`` otherwise.
    """
    # Specs in different shape modes can never match.
    if a.shape_mode != b.shape_mode:
        return False

    if a.shape_mode == Input._ShapeMode.DYNAMIC:
        # Dynamic specs keep their shape as a dict of named bounds.
        assert isinstance(a.shape, dict)
        assert isinstance(b.shape, dict)
        shape_matches = [
            a.shape[bound] == b.shape[bound]
            for bound in ("min_shape", "opt_shape", "max_shape")
        ]
    else:
        shape_matches = [a.shape == b.shape]

    # Fields compared the same way regardless of shape mode.
    common_matches = [
        a.dtype == b.dtype,
        a.format == b.format,
        a.low_tensor_domain_incl == b.low_tensor_domain_incl,
        a.high_tensor_domain_excl == b.high_tensor_domain_excl,
    ]
    return all(shape_matches + common_matches)
250+
223251
@staticmethod
224252
def _supported_input_size_type(input_size: Any) -> bool:
225253
if isinstance(input_size, torch.Size):

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ def run(
545545
serialized_engine,
546546
self._input_names,
547547
self._output_names,
548-
engine_input_specs,
548+
cached_engine_input_specs,
549549
engine_compilation_settings,
550550
self.weight_name_map,
551551
) = cached_data
@@ -559,6 +559,16 @@ def run(
559559
setting_compatiblity
560560
), f"Attempted to refit a prebuilt engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})"
561561

562+
# Compare each cached input spec with the current spec at the same position;
# the assertion message reports a mismatch as the cached engine having been
# built for a different input size.
# NOTE(review): zip() stops at the shorter sequence, so a difference in the
# number of input specs is not detected here - confirm that is checked
# elsewhere.
for idx, (cached_spec, new_spec) in enumerate(
    zip(cached_engine_input_specs, self.input_specs)
):
    assert Input.equivalent_spec(
        cached_spec, new_spec
    ), f"Found that cached engine was built for a different input size (input: {idx}, cached size: {cached_spec}, new size: {new_spec})"
571+
562572
_LOGGER.info(
563573
"Found the cached engine that corresponds to this graph. It is directly loaded."
564574
)

0 commit comments

Comments
 (0)