Skip to content

Commit 73a7aac

Browse files
peri044apbose
authored andcommitted
fix: Fix CI issues due to unintended fake tensor creation in torch.compile tests
1 parent b2637ea commit 73a7aac

File tree

2 files changed

+43
-35
lines changed

2 files changed

+43
-35
lines changed

py/torch_tensorrt/dynamo/utils.py

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import tensorrt as trt
1212
import torch
1313
from torch._subclasses.fake_tensor import FakeTensor
14+
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
1415
from torch_tensorrt._Device import Device
1516
from torch_tensorrt._enums import dtype
1617
from torch_tensorrt._features import ENABLED_FEATURES
@@ -243,48 +244,54 @@ def prepare_inputs(
243244
inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
244245
disable_memory_format_check: bool = False,
245246
) -> Any:
246-
if inputs is None:
247-
return None
248-
249-
elif isinstance(inputs, Input):
250-
return inputs
247+
"""
248+
We take a nested group of torch.Tensors or scalars and convert them into torchtrt.Input's
249+
"""
250+
# Any tensors created inside this call will be FakeTensors if it's inside a torch.compile session
251+
# So, we disable fake mode temporarily.
252+
with unset_fake_temporarily():
253+
if inputs is None:
254+
return None
251255

252-
elif isinstance(inputs, (torch.Tensor, int, float, bool)):
253-
return Input.from_tensor(
254-
torch.tensor(inputs),
255-
disable_memory_format_check=disable_memory_format_check,
256-
)
256+
elif isinstance(inputs, Input):
257+
return inputs
257258

258-
elif isinstance(inputs, (list, tuple)):
259-
torchtrt_input_list = []
260-
for input_obj in inputs:
261-
torchtrt_input = prepare_inputs(
262-
input_obj, disable_memory_format_check=disable_memory_format_check
259+
elif isinstance(inputs, (torch.Tensor, int, float, bool)):
260+
return Input.from_tensor(
261+
torch.tensor(inputs),
262+
disable_memory_format_check=disable_memory_format_check,
263263
)
264-
torchtrt_input_list.append(torchtrt_input)
265-
266-
return (
267-
torchtrt_input_list
268-
if isinstance(inputs, list)
269-
else tuple(torchtrt_input_list)
270-
)
271264

272-
elif isinstance(inputs, dict):
273-
torchtrt_inputs_dict: Dict[Any, Any] = dict()
265+
elif isinstance(inputs, (list, tuple)):
266+
torchtrt_input_list = []
267+
for input_obj in inputs:
268+
torchtrt_input = prepare_inputs(
269+
input_obj, disable_memory_format_check=disable_memory_format_check
270+
)
271+
torchtrt_input_list.append(torchtrt_input)
274272

275-
for key, input_obj in inputs.items():
276-
torchtrt_input = prepare_inputs(
277-
input_obj, disable_memory_format_check=disable_memory_format_check
273+
return (
274+
torchtrt_input_list
275+
if isinstance(inputs, list)
276+
else tuple(torchtrt_input_list)
278277
)
279-
torchtrt_inputs_dict[key] = torchtrt_input
280278

281-
return torchtrt_inputs_dict
279+
elif isinstance(inputs, dict):
280+
torchtrt_inputs_dict: Dict[Any, Any] = dict()
282281

283-
else:
284-
raise ValueError(
285-
f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. "
286-
+ "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
287-
)
282+
for key, input_obj in inputs.items():
283+
torchtrt_input = prepare_inputs(
284+
input_obj, disable_memory_format_check=disable_memory_format_check
285+
)
286+
torchtrt_inputs_dict[key] = torchtrt_input
287+
288+
return torchtrt_inputs_dict
289+
290+
else:
291+
raise ValueError(
292+
f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. "
293+
+ "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
294+
)
288295

289296

290297
def parse_complex_tensor_structs(

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import torch
22
import torch_tensorrt
33
from parameterized import parameterized
4-
from testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
54
from torch.testing._internal.common_utils import TestCase, run_tests
65

6+
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
7+
78

89
class TestLowering(TestCase):
910
def test_lowering_inplace_op(self):

0 commit comments

Comments
 (0)