Skip to content

Commit 30f5094

Browse files
peri044narendasan
authored andcommitted
feat: cherry-pick of Selectively enable different frontends (#2693) (#2761)
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]> Co-authored-by: Naren Dasan <[email protected]>
1 parent a5079ad commit 30f5094

File tree

18 files changed

+56
-81
lines changed

18 files changed

+56
-81
lines changed

.github/workflows/build-test.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,6 @@ jobs:
264264
pre-script: ${{ matrix.pre-script }}
265265
script: |
266266
export USE_HOST_DEPS=1
267-
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
268267
pushd .
269268
cd tests/py/core
270269
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver

py/torch_tensorrt/_Device.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99
else:
1010
from typing_extensions import Self
1111

12+
import tensorrt as trt
1213
import torch
1314
from torch_tensorrt._enums import DeviceType
1415
from torch_tensorrt._features import ENABLED_FEATURES
1516

16-
import tensorrt as trt
17-
1817

1918
class Device(object):
2019
"""

py/torch_tensorrt/_compile.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77

88
import torch
99
import torch.fx
10-
import torch_tensorrt.dynamo
11-
import torch_tensorrt.ts
1210
from torch_tensorrt._enums import dtype
1311
from torch_tensorrt._features import ENABLED_FEATURES
1412
from torch_tensorrt._Input import Input
@@ -343,18 +341,8 @@ def convert_method_to_trt_engine(
343341
"convert_method_to_trt_engine call is not supported for ir=fx"
344342
)
345343
elif target_ir == _IRType.dynamo:
346-
# Prepare torch and torchtrt inputs
347-
from torch_tensorrt.dynamo.utils import prepare_inputs
348-
349-
if not isinstance(inputs, collections.abc.Sequence):
350-
inputs = [inputs]
351-
352-
# Export the module
353-
torchtrt_inputs = prepare_inputs(inputs)
354-
exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
355-
356344
return dynamo_convert_module_to_trt_engine( # type: ignore[no-any-return]
357-
exp_program,
345+
module,
358346
inputs=inputs,
359347
enabled_precisions=enabled_precisions_set,
360348
**kwargs,

py/torch_tensorrt/_enums.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _from(
107107
return dtype.f16
108108
elif t == trt.float32:
109109
return dtype.f32
110-
elif t == trt.bool:
110+
elif trt.__version__ >= "7.0" and t == trt.bool:
111111
return dtype.b
112112
else:
113113
raise TypeError(

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,33 +10,7 @@
1010
from torch_tensorrt._Device import Device
1111
from torch_tensorrt._enums import EngineCapability, dtype
1212
from torch_tensorrt._Input import Input
13-
from torch_tensorrt.dynamo import partitioning
14-
from torch_tensorrt.dynamo._defaults import (
15-
DEBUG,
16-
DEVICE,
17-
DISABLE_TF32,
18-
DLA_GLOBAL_DRAM_SIZE,
19-
DLA_LOCAL_DRAM_SIZE,
20-
DLA_SRAM_SIZE,
21-
DRYRUN,
22-
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
23-
ENGINE_CAPABILITY,
24-
HARDWARE_COMPATIBLE,
25-
MAX_AUX_STREAMS,
26-
MIN_BLOCK_SIZE,
27-
NUM_AVG_TIMING_ITERS,
28-
OPTIMIZATION_LEVEL,
29-
PASS_THROUGH_BUILD_FAILURES,
30-
PRECISION,
31-
REFIT,
32-
REQUIRE_FULL_COMPILATION,
33-
SPARSE_WEIGHTS,
34-
TRUNCATE_LONG_AND_DOUBLE,
35-
USE_FAST_PARTITIONER,
36-
USE_PYTHON_RUNTIME,
37-
VERSION_COMPATIBLE,
38-
WORKSPACE_SIZE,
39-
)
13+
from torch_tensorrt.dynamo import _defaults, partitioning
4014
from torch_tensorrt.dynamo._DryRunTracker import (
4115
DryRunTracker,
4216
PerSubgraphData,
@@ -89,15 +63,15 @@ def compile(
8963
min_block_size: int = _defaults.MIN_BLOCK_SIZE,
9064
torch_executed_ops: Optional[Collection[Target]] = None,
9165
torch_executed_modules: Optional[List[str]] = None,
92-
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
93-
max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
94-
version_compatible: bool = VERSION_COMPATIBLE,
95-
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
96-
use_python_runtime: bool = USE_PYTHON_RUNTIME,
97-
use_fast_partitioner: bool = USE_FAST_PARTITIONER,
98-
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
99-
dryrun: bool = DRYRUN,
100-
hardware_compatible: bool = HARDWARE_COMPATIBLE,
66+
pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES,
67+
max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS,
68+
version_compatible: bool = _defaults.VERSION_COMPATIBLE,
69+
optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
70+
use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME,
71+
use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
72+
enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
73+
dryrun: bool = _defaults.DRYRUN,
74+
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
10175
**kwargs: Any,
10276
) -> torch.fx.GraphModule:
10377
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
REQUIRE_FULL_COMPILATION = False
2727
DRYRUN = False
2828
HARDWARE_COMPATIBLE = False
29+
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8}
2930

3031

3132
def default_device() -> Device:

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def run(
313313
)
314314
timing_cache = self._create_timing_cache(builder_config, existing_cache)
315315

316-
engine = self.builder.build_serialized_network(self.ctx.net, builder_config)
316+
engine = self.builder.build_engine(self.ctx.net, builder_config)
317317
assert engine
318318

319319
serialized_cache = (

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def infer_module_output_dtypes(
3838
# such as aten.sum - such outputs can be truncated
3939
output_dtypes = []
4040
for output in module_outputs:
41+
if not isinstance(output, torch.Tensor):
42+
output = torch.tensor(output)
4143
if truncate_long_and_double and output.dtype == dtype.float64:
4244
output_dtypes.append(dtype.float32)
4345
elif truncate_long_and_double and output.dtype == dtype.int64:

py/torch_tensorrt/dynamo/conversion/impl/cast.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Optional, Union
33

44
import numpy as np
5+
import tensorrt as trt
56
import torch
67
from torch.fx.node import Target
78
from torch_tensorrt import _enums

py/torch_tensorrt/dynamo/conversion/impl/matmul.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional
22

3+
import tensorrt as trt
34
import torch
45
from torch.fx.node import Target
56
from torch_tensorrt import _enums
@@ -9,8 +10,6 @@
910
from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
1011
from torch_tensorrt.fx.types import TRTTensor
1112

12-
import tensorrt as trt
13-
1413

1514
def matrix_multiply(
1615
ctx: ConversionContext,

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
_select_rt_device,
1616
multi_gpu_device_check,
1717
)
18-
from torch_tensorrt.logging import TRT_LOGGER
1918

2019
logger = logging.getLogger(__name__)
2120

@@ -65,19 +64,35 @@ def _initialize(self) -> None:
6564
) == (len(self.input_names) + len(self.output_names))
6665

6766
self.input_dtypes = [
68-
dtype._from(self.engine.get_tensor_dtype(input_name))
69-
for input_name in self.input_names
67+
dtype._from(self.engine.get_binding_dtype(idx))
68+
for idx in self.input_binding_indices_in_order
7069
]
7170
self.input_shapes = [
7271
self.engine.get_tensor_shape(input_name) for input_name in self.input_names
7372
]
7473
self.output_dtypes = [
75-
dtype._from(self.engine.get_tensor_dtype(output_name))
76-
for output_name in self.output_names
74+
dtype._from(self.engine.get_binding_dtype(idx))
75+
for idx in self.output_binding_indices_in_order
7776
]
7877
self.output_shapes = [
79-
self.engine.get_tensor_shape(output_name)
80-
for output_name in self.output_names
78+
(
79+
tuple(self.engine.get_binding_shape(idx))
80+
if self.engine.has_implicit_batch_dimension
81+
else tuple()
82+
)
83+
for idx in self.output_binding_indices_in_order
84+
]
85+
self.hidden_output_dtypes = [
86+
dtype._from(self.engine.get_binding_dtype(idx))
87+
for idx in self.hidden_output_binding_indices_in_order
88+
]
89+
self.hidden_output_shapes = [
90+
(
91+
tuple(self.engine.get_binding_shape(idx))
92+
if self.engine.has_implicit_batch_dimension
93+
else tuple()
94+
)
95+
for idx in self.hidden_output_binding_indices_in_order
8196
]
8297

8398
def _check_initialized(self) -> None:
@@ -219,11 +234,15 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
219234
bindings.append(output.data_ptr())
220235
outputs.append(output)
221236

222-
# Assign tensor address appropriately
223-
for idx in range(self.engine.num_io_tensors):
224-
self.context.set_tensor_address(
225-
self.engine.get_tensor_name(idx), bindings[idx]
226-
)
237+
for i, idx in enumerate(self.hidden_output_binding_indices_in_order):
238+
shape = tuple(self.context.get_binding_shape(idx))
239+
240+
output = torch.empty(
241+
size=shape,
242+
dtype=self.hidden_output_dtypes[i].to(torch.dtype),
243+
device=torch.cuda.current_device(),
244+
)
245+
bindings[idx] = output.data_ptr()
227246

228247
with (
229248
torch.autograd.profiler.record_function(

py/torch_tensorrt/logging.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import logging
22
from typing import Any
33

4-
from torch_tensorrt._features import ENABLED_FEATURES
5-
64
import tensorrt as trt
5+
from torch_tensorrt._features import ENABLED_FEATURES
76

87
logging.captureWarnings(True)
98
_LOGGER = logging.getLogger("torch_tensorrt [TensorRT Conversion Context]")

py/torch_tensorrt/ts/_compile_spec.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
from torch_tensorrt.ts._Input import TorchScriptInput
1414
from torch_tensorrt.ts.logging import Level, log
1515

16-
import tensorrt as trt
17-
1816

1917
def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input:
2018
clone = torch.classes.tensorrt._Input()

py/torch_tensorrt/ts/_enums.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
1-
from torch_tensorrt._C import EngineCapability, TensorFormat, dtype # noqa: F401
2-
31
from tensorrt import DeviceType # noqa: F401
2+
from torch_tensorrt._C import EngineCapability, TensorFormat, dtype # noqa: F401

tests/py/core/test_classes.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22
import unittest
33
from typing import Dict
44

5+
import tensorrt as trt
56
import torch
67
import torch_tensorrt as torchtrt
78
import torchvision.models as models
89
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule
910

10-
import tensorrt as trt
11-
1211

1312
class TestDevice(unittest.TestCase):
1413
def test_from_string_constructor(self):

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import unittest
2+
13
import torch
24
import torch_tensorrt
35
from torch.testing._internal.common_utils import TestCase, run_tests

tests/py/dynamo/runtime/test_hw_compat.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,6 @@ def forward(self, x):
7474
not torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8,
7575
"HW Compatibility is not supported on cards older than Ampere",
7676
)
77-
@unittest.skip(
78-
"Skipping this test because the hw_compat.ts can't be generated using torch nightly"
79-
)
8077
def test_hw_compat_3080_build(self):
8178
inputs = [torch.randn(1, 3, 224, 224).cuda()]
8279

tests/py/ts/ptq/test_ptq_trt_calibrator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import unittest
33

4+
import tensorrt as trt
45
import torch
56
import torch.nn as nn
67
import torch_tensorrt as torchtrt
@@ -9,8 +10,6 @@
910
from torch.nn import functional as F
1011
from torch_tensorrt.ts.logging import *
1112

12-
import tensorrt as trt
13-
1413

1514
def find_repo_root(max_depth=10):
1615
dir_path = os.path.dirname(os.path.realpath(__file__))

0 commit comments

Comments
 (0)