|
11 | 11 | import tensorrt as trt
|
12 | 12 | import torch
|
13 | 13 | from torch._subclasses.fake_tensor import FakeTensor
|
| 14 | +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily |
14 | 15 | from torch_tensorrt._Device import Device
|
15 | 16 | from torch_tensorrt._enums import dtype
|
16 | 17 | from torch_tensorrt._features import ENABLED_FEATURES
|
def prepare_inputs(
    inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
    disable_memory_format_check: bool = False,
) -> Any:
    """Convert a (possibly nested) group of torch.Tensors or scalars into torchtrt ``Input`` objects.

    Recursively walks lists/tuples/dicts, converting each leaf tensor or scalar
    into an ``Input`` via ``Input.from_tensor`` while preserving the container
    structure (lists stay lists, tuples stay tuples, dicts keep their keys).
    ``Input`` instances are passed through unchanged and ``None`` is returned as-is.

    Args:
        inputs: An ``Input``, tensor, scalar (int/float/bool), or a nested
            list/tuple/dict of such values. ``None`` is allowed and returned unchanged.
        disable_memory_format_check: Forwarded to ``Input.from_tensor`` for each leaf.

    Returns:
        The same nested structure with every leaf replaced by a torchtrt ``Input``,
        or ``None`` if ``inputs`` is ``None``.

    Raises:
        ValueError: If an element of unsupported type is encountered.
    """
    # Any tensors created inside this call will be FakeTensors if it's inside a
    # torch.compile session. So, we disable fake mode temporarily.
    with unset_fake_temporarily():
        if inputs is None:
            return None

        elif isinstance(inputs, Input):
            # Already a torchtrt Input; nothing to convert.
            return inputs

        elif isinstance(inputs, (torch.Tensor, int, float, bool)):
            # Scalars are promoted to 0-d tensors so Input.from_tensor accepts them.
            return Input.from_tensor(
                torch.tensor(inputs),
                disable_memory_format_check=disable_memory_format_check,
            )

        elif isinstance(inputs, (list, tuple)):
            # Recurse per element, preserving the original container type.
            torchtrt_input_list = [
                prepare_inputs(
                    input_obj, disable_memory_format_check=disable_memory_format_check
                )
                for input_obj in inputs
            ]
            return (
                torchtrt_input_list
                if isinstance(inputs, list)
                else tuple(torchtrt_input_list)
            )

        elif isinstance(inputs, dict):
            # Recurse per value, preserving keys.
            torchtrt_inputs_dict: Dict[Any, Any] = {
                key: prepare_inputs(
                    input_obj, disable_memory_format_check=disable_memory_format_check
                )
                for key, input_obj in inputs.items()
            }
            return torchtrt_inputs_dict

        else:
            raise ValueError(
                f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. "
                + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
            )
|
289 | 296 |
|
290 | 297 | def parse_complex_tensor_structs(
|
|
0 commit comments