Skip to content

Commit afd725e

Browse files
committed
feat: Support torch_tensorrt.Input with FX backend
Signed-off-by: Dheeraj Peri <[email protected]> chore: Replace InputTensorSpec with Input Signed-off-by: Dheeraj Peri <[email protected]> feat: Allow torchtrt.Input support for FX backend Signed-off-by: Dheeraj Peri <[email protected]> refactor: Implement conversions from Input -> PyTorch tensors, add Input utilities etc. Signed-off-by: Dheeraj Peri <[email protected]> chore: Use InputTensorSpec internally Signed-off-by: Dheeraj Peri <[email protected]> chore: Linter fixes Signed-off-by: Dheeraj Peri <[email protected]> chore: add ts_input.py file Signed-off-by: Dheeraj Peri <[email protected]> chore: Linter fixes Signed-off-by: Dheeraj Peri <[email protected]> chore: minor fixes Signed-off-by: Dheeraj Peri <[email protected]> chore: revert FX changes Signed-off-by: Dheeraj Peri <[email protected]> chore: Address TorchScript test case failures Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 83923ee commit afd725e

17 files changed

+306
-153
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,10 @@ TRTEngine::TRTEngine(
148148
}
149149

150150
TRTEngine::~TRTEngine() {
151-
rt.reset();
152151
trt_engine_profiler.reset();
153152
exec_ctx.reset();
154153
cuda_engine.reset();
154+
rt.reset();
155155
}
156156

157157
void TRTEngine::disable_profiling() {

examples/fx/lower_example.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import torch
66
import torchvision
7+
import torch_tensorrt
78
from torch_tensorrt.fx import compile
89
from torch_tensorrt.fx.utils import LowerPrecision
910

@@ -98,13 +99,17 @@ def benchmark(
9899

99100
model = model.cuda().eval()
100101
inputs = [x.cuda() for x in inputs]
101-
102+
# inputs = [torch_tensorrt.Input(shape=(128, 3, 224, 224), dtype=torch.float32)]
103+
# inputs = [torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
104+
# opt_shape=[8, 3, 224, 224],
105+
# max_shape=[32, 3, 224, 224],
106+
# dtype=torch.float32)]
102107
# benchmark base configuration
103108
conf = Configuration(batch_iter=batch_iter, batch_size=batch_size)
104109

105110
configurations = [
106111
# Baseline
107-
replace(conf, name="CUDA Eager", trt=False),
112+
# replace(conf, name="CUDA Eager", trt=False),
108113
# FP32
109114
replace(
110115
conf,
@@ -115,14 +120,14 @@ def benchmark(
115120
accuracy_rtol=1e-3,
116121
),
117122
# FP16
118-
replace(
119-
conf,
120-
name="TRT FP16 Eager",
121-
trt=True,
122-
jit=False,
123-
fp16=True,
124-
accuracy_rtol=1e-2,
125-
),
123+
# replace(
124+
# conf,
125+
# name="TRT FP16 Eager",
126+
# trt=True,
127+
# jit=False,
128+
# fp16=True,
129+
# accuracy_rtol=1e-2,
130+
# ),
126131
]
127132

128133
results = [
@@ -189,8 +194,12 @@ def run_configuration_benchmark(
189194
max_batch_size=conf.batch_size,
190195
lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32,
191196
explicit_batch_dimension=True,
197+
dynamic_batch=False,
198+
)
199+
random_inputs = [torch.randn((128, 3, 224, 224), dtype=torch.float32).cuda()]
200+
time = benchmark_torch_function(
201+
conf.batch_iter, lambda: lowered_module(*random_inputs)
192202
)
193-
time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))
194203
else:
195204
print("Lowering with JIT is not available!", "red")
196205

examples/fx/lower_example_aten.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import torch
66
import torchvision
7+
import torch_tensorrt
78
from torch_tensorrt.fx import compile
89
from torch_tensorrt.fx.utils import LowerPrecision
910

@@ -97,21 +98,25 @@ def benchmark(
9798
"""
9899

99100
model = model.cuda().eval()
100-
inputs = [x.cuda() for x in inputs]
101-
101+
# inputs = [x.cuda() for x in inputs]
102+
inputs = [torch_tensorrt.Input(shape=(128, 3, 224, 224), dtype=torch.float32)]
103+
# inputs = [torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
104+
# opt_shape=[8, 3, 224, 224],
105+
# max_shape=[32, 3, 224, 224],
106+
# dtype=torch.float32)]
102107
# benchmark base configuration
103108
conf = Configuration(batch_iter=batch_iter, batch_size=batch_size)
104109

105110
configurations = [
106111
# Baseline
107-
replace(conf, name="CUDA Eager", trt=False),
112+
# replace(conf, name="CUDA Eager", trt=False),
108113
# FP16
109114
replace(
110115
conf,
111-
name="TRT FP16 Eager",
116+
name="TRT FP32 Eager",
112117
trt=True,
113118
jit=False,
114-
fp16=True,
119+
fp16=False,
115120
accuracy_rtol=1e-2,
116121
),
117122
]
@@ -182,7 +187,10 @@ def run_configuration_benchmark(
182187
explicit_batch_dimension=True,
183188
is_aten=True,
184189
)
185-
time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))
190+
random_inputs = [torch.randn((128, 3, 224, 224), dtype=torch.float32).cuda()]
191+
time = benchmark_torch_function(
192+
conf.batch_iter, lambda: lowered_module(*random_inputs)
193+
)
186194
else:
187195
print("Lowering with JIT is not available!", "red")
188196

py/setup.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def run(self):
350350
if FX_ONLY:
351351
ext_modules = None
352352
packages = [
353+
"torch_tensorrt",
353354
"torch_tensorrt.fx",
354355
"torch_tensorrt.fx.converters",
355356
"torch_tensorrt.fx.passes",
@@ -358,6 +359,7 @@ def run(self):
358359
"torch_tensorrt.fx.tracer.dispatch_tracer",
359360
]
360361
package_dir = {
362+
"torch_tensorrt": "torch_tensorrt/",
361363
"torch_tensorrt.fx": "torch_tensorrt/fx",
362364
"torch_tensorrt.fx.converters": "torch_tensorrt/fx/converters",
363365
"torch_tensorrt.fx.passes": "torch_tensorrt/fx/passes",
@@ -437,7 +439,9 @@ def run(self):
437439
"bin/*",
438440
"BUILD",
439441
"WORKSPACE",
440-
],
442+
]
443+
if not FX_ONLY
444+
else ["_Input.py"]
441445
},
442446
exclude_package_data={
443447
"": ["*.cpp"],

py/torch_tensorrt/_Input.py

Lines changed: 9 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import torch
55

66
from torch_tensorrt import _enums
7-
from torch_tensorrt import _C
87

98

109
class Input(object):
@@ -41,6 +40,7 @@ class _ShapeMode(Enum):
4140
DOMAIN_OFFSET = 2.0
4241
low_tensor_domain_incl = 0.0
4342
high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET
43+
torch_dtype = None
4444

4545
def __init__(self, *args, **kwargs):
4646
"""__init__ Method for torch_tensorrt.Input
@@ -138,6 +138,9 @@ def __init__(self, *args, **kwargs):
138138
)
139139

140140
if "dtype" in kwargs:
141+
if isinstance(kwargs["dtype"], torch.dtype):
142+
self.torch_dtype = kwargs["dtype"]
143+
141144
self.dtype = Input._parse_dtype(kwargs["dtype"])
142145
self._explicit_set_dtype = True
143146

@@ -173,59 +176,6 @@ def __str__(self) -> str:
173176
else:
174177
raise RuntimeError("Unknown input shape mode")
175178

176-
def _to_internal(self) -> _C.Input:
177-
internal_in = _C.Input()
178-
if self.shape_mode == Input._ShapeMode.DYNAMIC:
179-
if not Input._supported_input_size_type(self.shape["min_shape"]):
180-
raise TypeError(
181-
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
182-
+ str(type(self.shape["min_shape"]))
183-
+ " for min_shape"
184-
)
185-
else:
186-
internal_in.min = self.shape["min_shape"]
187-
188-
if not Input._supported_input_size_type(self.shape["opt_shape"]):
189-
raise TypeError(
190-
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
191-
+ str(type(self.shape["opt_shape"]))
192-
+ " for opt_shape"
193-
)
194-
else:
195-
internal_in.opt = self.shape["opt_shape"]
196-
197-
if not Input._supported_input_size_type(self.shape["max_shape"]):
198-
raise TypeError(
199-
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
200-
+ str(type(self.shape["max_shape"]))
201-
+ " for max_shape"
202-
)
203-
else:
204-
internal_in.max = self.shape["max_shape"]
205-
internal_in.input_is_dynamic = True
206-
else:
207-
if not Input._supported_input_size_type(self.shape):
208-
raise TypeError(
209-
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
210-
+ str(type(self.shape))
211-
+ " for shape"
212-
)
213-
else:
214-
internal_in.opt = self.shape
215-
internal_in.input_is_dynamic = False
216-
217-
if self.dtype != _enums.dtype.unknown:
218-
self._explicit_set_dtype = True
219-
else:
220-
self._explicit_set_dtype = False
221-
222-
internal_in.dtype = Input._parse_dtype(self.dtype)
223-
internal_in._explicit_set_dtype = self._explicit_set_dtype
224-
internal_in.format = Input._parse_format(self.format)
225-
226-
internal_in.tensor_domain = Input._parse_tensor_domain(self.tensor_domain)
227-
return internal_in
228-
229179
@staticmethod
230180
def _supported_input_size_type(input_size: Any) -> bool:
231181
if isinstance(input_size, torch.Size):
@@ -304,6 +254,7 @@ def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple:
304254
Input.low_tensor_domain_incl,
305255
Input.high_tensor_domain_excl,
306256
)
257+
307258
elif len(domain) == 2:
308259
domain_lo, domain_hi = domain
309260

@@ -416,8 +367,10 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor
416367
)
417368

418369
if self.shape_mode == Input._ShapeMode.STATIC:
419-
return torch.randn(self.shape).to(dtype=self.dtype)
370+
return torch.randn(self.shape).to(
371+
dtype=self.dtype if not self.torch_dtype else self.torch_dtype
372+
)
420373
else:
421374
return torch.randn(self.shape[optimization_profile_field]).to(
422-
dtype=self.dtype
375+
dtype=self.dtype if not self.torch_dtype else self.torch_dtype
423376
)

py/torch_tensorrt/fx/input_tensor_spec.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from .types import Shape, ShapeRange
66
from .utils import get_dynamic_dims
7+
from torch_tensorrt._Input import Input
78

89

910
def generate_input_specs(inputs, lower_setting, additional_inputs=None):

py/torch_tensorrt/fx/lower.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
import torch.fx as fx
99
import torch.nn as nn
10+
import torch_tensorrt
1011
import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
1112
from torch.fx.passes.splitter_base import SplitResult
1213

@@ -22,8 +23,9 @@
2223
from .utils import LowerPrecision
2324

2425
logger = logging.getLogger(__name__)
26+
from torch_tensorrt._Input import Input
2527

26-
Input = Sequence[Any]
28+
# Input = Sequence[Any]
2729

2830

2931
def compile(
@@ -302,6 +304,7 @@ def do_lower(module: nn.Module, inputs: Input) -> nn.Module:
302304
conversion_fn = fp16_conversion_fn
303305

304306
inputs = tuple(conversion_fn(x) for x in inputs)
307+
305308
if lower_setting.is_aten:
306309
pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline(
307310
inputs, additional_inputs

py/torch_tensorrt/fx/lower_setting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import nn
55
from torch.fx.passes.pass_manager import PassManager
66

7-
from .input_tensor_spec import InputTensorSpec
7+
from torch_tensorrt._Input import Input
88
from .passes.lower_basic_pass import fuse_permute_linear, fuse_permute_matmul
99
from .utils import LowerPrecision
1010

@@ -76,7 +76,7 @@ class LowerSetting(LowerSettingBasic):
7676
use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
7777
"""
7878

79-
input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
79+
input_specs: List[Input] = dc.field(default_factory=list)
8080
explicit_batch_dimension: bool = True
8181
explicit_precision: bool = False
8282
max_workspace_size: int = 1 << 30

0 commit comments

Comments
 (0)