Commit e5c2026

feat: Support torch_tensorrt.Input with FX backend
Signed-off-by: Dheeraj Peri <[email protected]>

Squashed commits (each signed off by Dheeraj Peri <[email protected]>):

chore: Replace InputTensorSpec with Input
feat: Allow torchtrt.Input support for FX backend
refactor: Implement conversions from Input -> Pyt tensors, add Input utilities etc.
chore: Use InputTensorSpec internally
chore: Linter fixes
chore: add ts_input.py file
chore: Linter fixes
chore: minor fixes
chore: revert FX changes
chore: Address Torchscript test case failures
chore: remove device placement of input tensors
chore: Linter fixes
chore: refactor code
chore: Remove max_batch_size and replace generate_input_specs calls
chore: linter fixes
1 parent 83923ee commit e5c2026

22 files changed: +502 / -255 lines

core/runtime/TRTEngine.cpp

Lines changed: 1 addition & 1 deletion
@@ -148,10 +148,10 @@ TRTEngine::TRTEngine(
 }
 
 TRTEngine::~TRTEngine() {
-  rt.reset();
   trt_engine_profiler.reset();
   exec_ctx.reset();
   cuda_engine.reset();
+  rt.reset();
 }
 
 void TRTEngine::disable_profiling() {
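The reordering appears deliberate: the profiler, execution context, and engine are now released before the runtime (`rt`), consistent with TensorRT's expectation that a runtime outlive the objects deserialized from it.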

examples/fx/lower_example.py

Lines changed: 20 additions & 11 deletions
@@ -4,6 +4,7 @@
 
 import torch
 import torchvision
+import torch_tensorrt
 from torch_tensorrt.fx import compile
 from torch_tensorrt.fx.utils import LowerPrecision
 
@@ -98,13 +99,17 @@ def benchmark(
 
     model = model.cuda().eval()
     inputs = [x.cuda() for x in inputs]
-
+    # inputs = [torch_tensorrt.Input(shape=(128, 3, 224, 224), dtype=torch.float32)]
+    # inputs = [torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
+    #                                opt_shape=[8, 3, 224, 224],
+    #                                max_shape=[32, 3, 224, 224],
+    #                                dtype=torch.float32)]
     # benchmark base configuration
     conf = Configuration(batch_iter=batch_iter, batch_size=batch_size)
 
     configurations = [
         # Baseline
-        replace(conf, name="CUDA Eager", trt=False),
+        # replace(conf, name="CUDA Eager", trt=False),
         # FP32
         replace(
             conf,
@@ -115,14 +120,14 @@ def benchmark(
             accuracy_rtol=1e-3,
         ),
         # FP16
-        replace(
-            conf,
-            name="TRT FP16 Eager",
-            trt=True,
-            jit=False,
-            fp16=True,
-            accuracy_rtol=1e-2,
-        ),
+        # replace(
+        #     conf,
+        #     name="TRT FP16 Eager",
+        #     trt=True,
+        #     jit=False,
+        #     fp16=True,
+        #     accuracy_rtol=1e-2,
+        # ),
     ]
 
     results = [
@@ -189,8 +194,12 @@ def run_configuration_benchmark(
             max_batch_size=conf.batch_size,
             lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32,
             explicit_batch_dimension=True,
+            dynamic_batch=False,
+        )
+        random_inputs = [torch.randn((128, 3, 224, 224), dtype=torch.float32).cuda()]
+        time = benchmark_torch_function(
+            conf.batch_iter, lambda: lowered_module(*random_inputs)
         )
-        time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))
    else:
        print("Lowering with JIT is not available!", "red")

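For reference, a minimal end-to-end sketch of what the commented-out specs above enable; the model choice (torchvision resnet18) is an assumption for illustration, while the compile keywords follow this commit's _compile.py and examples:

import torch
import torch_tensorrt
import torchvision

model = torchvision.models.resnet18().cuda().eval()  # placeholder model

# Static single-shape spec, mirroring the commented line above
inputs = [torch_tensorrt.Input(shape=(128, 3, 224, 224), dtype=torch.float32)]

# Dynamic-range alternative (min/opt/max optimization profile)
# inputs = [torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
#                                opt_shape=[8, 3, 224, 224],
#                                max_shape=[32, 3, 224, 224],
#                                dtype=torch.float32)]

trt_mod = torch_tensorrt.fx.compile(
    model,
    inputs,
    explicit_batch_dimension=True,
    dynamic_batch=False,
)
print(trt_mod(torch.randn(128, 3, 224, 224).cuda()).shape)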

examples/fx/lower_example_aten.py

Lines changed: 14 additions & 6 deletions
@@ -4,6 +4,7 @@
 
 import torch
 import torchvision
+import torch_tensorrt
 from torch_tensorrt.fx import compile
 from torch_tensorrt.fx.utils import LowerPrecision
 
@@ -97,21 +98,25 @@ def benchmark(
     """
 
     model = model.cuda().eval()
-    inputs = [x.cuda() for x in inputs]
-
+    # inputs = [x.cuda() for x in inputs]
+    inputs = [torch_tensorrt.Input(shape=(128, 3, 224, 224), dtype=torch.float32)]
+    # inputs = [torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
+    #                                opt_shape=[8, 3, 224, 224],
+    #                                max_shape=[32, 3, 224, 224],
+    #                                dtype=torch.float32)]
     # benchmark base configuration
     conf = Configuration(batch_iter=batch_iter, batch_size=batch_size)
 
     configurations = [
         # Baseline
-        replace(conf, name="CUDA Eager", trt=False),
+        # replace(conf, name="CUDA Eager", trt=False),
         # FP16
         replace(
             conf,
-            name="TRT FP16 Eager",
+            name="TRT FP32 Eager",
             trt=True,
             jit=False,
-            fp16=True,
+            fp16=False,
             accuracy_rtol=1e-2,
         ),
     ]
@@ -182,7 +187,10 @@ def run_configuration_benchmark(
             explicit_batch_dimension=True,
             is_aten=True,
         )
-        time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))
+        random_inputs = [torch.randn((128, 3, 224, 224), dtype=torch.float32).cuda()]
+        time = benchmark_torch_function(
+            conf.batch_iter, lambda: lowered_module(*random_inputs)
+        )
    else:
        print("Lowering with JIT is not available!", "red")

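The aten example differs from the previous sketch only in routing through the aten tracer; per the context lines above, the extra flag is is_aten=True. Reusing model and inputs from the sketch after lower_example.py:

from torch_tensorrt.fx import compile

# Same Input specs; is_aten=True selects the aten tracer path
trt_mod = compile(
    model,
    inputs,
    explicit_batch_dimension=True,
    is_aten=True,
)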

py/setup.py

Lines changed: 5 additions & 1 deletion
@@ -350,6 +350,7 @@ def run(self):
 if FX_ONLY:
     ext_modules = None
     packages = [
+        "torch_tensorrt",
         "torch_tensorrt.fx",
         "torch_tensorrt.fx.converters",
         "torch_tensorrt.fx.passes",
@@ -358,6 +359,7 @@ def run(self):
         "torch_tensorrt.fx.tracer.dispatch_tracer",
     ]
     package_dir = {
+        "torch_tensorrt": "torch_tensorrt/",
         "torch_tensorrt.fx": "torch_tensorrt/fx",
         "torch_tensorrt.fx.converters": "torch_tensorrt/fx/converters",
         "torch_tensorrt.fx.passes": "torch_tensorrt/fx/passes",
@@ -437,7 +439,9 @@ def run(self):
             "bin/*",
             "BUILD",
             "WORKSPACE",
-        ],
+        ]
+        if not FX_ONLY
+        else ["_Input.py"]
     },
     exclude_package_data={
         "": ["*.cpp"],

py/torch_tensorrt/_Input.py

Lines changed: 9 additions & 56 deletions
@@ -4,7 +4,6 @@
 import torch
 
 from torch_tensorrt import _enums
-from torch_tensorrt import _C
 
 
 class Input(object):
@@ -41,6 +40,7 @@ class _ShapeMode(Enum):
     DOMAIN_OFFSET = 2.0
     low_tensor_domain_incl = 0.0
     high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET
+    torch_dtype = None
 
     def __init__(self, *args, **kwargs):
         """__init__ Method for torch_tensorrt.Input
@@ -138,6 +138,9 @@ def __init__(self, *args, **kwargs):
             )
 
         if "dtype" in kwargs:
+            if isinstance(kwargs["dtype"], torch.dtype):
+                self.torch_dtype = kwargs["dtype"]
+
             self.dtype = Input._parse_dtype(kwargs["dtype"])
             self._explicit_set_dtype = True
 
@@ -173,59 +176,6 @@ def __str__(self) -> str:
         else:
             raise RuntimeError("Unknown input shape mode")
 
-    def _to_internal(self) -> _C.Input:
-        internal_in = _C.Input()
-        if self.shape_mode == Input._ShapeMode.DYNAMIC:
-            if not Input._supported_input_size_type(self.shape["min_shape"]):
-                raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(self.shape["min_shape"]))
-                    + " for min_shape"
-                )
-            else:
-                internal_in.min = self.shape["min_shape"]
-
-            if not Input._supported_input_size_type(self.shape["opt_shape"]):
-                raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(self.shape["opt_shape"]))
-                    + " for opt_shape"
-                )
-            else:
-                internal_in.opt = self.shape["opt_shape"]
-
-            if not Input._supported_input_size_type(self.shape["max_shape"]):
-                raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(self.shape["max_shape"]))
-                    + " for max_shape"
-                )
-            else:
-                internal_in.max = self.shape["max_shape"]
-            internal_in.input_is_dynamic = True
-        else:
-            if not Input._supported_input_size_type(self.shape):
-                raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(self.shape))
-                    + " for shape"
-                )
-            else:
-                internal_in.opt = self.shape
-            internal_in.input_is_dynamic = False
-
-        if self.dtype != _enums.dtype.unknown:
-            self._explicit_set_dtype = True
-        else:
-            self._explicit_set_dtype = False
-
-        internal_in.dtype = Input._parse_dtype(self.dtype)
-        internal_in._explicit_set_dtype = self._explicit_set_dtype
-        internal_in.format = Input._parse_format(self.format)
-
-        internal_in.tensor_domain = Input._parse_tensor_domain(self.tensor_domain)
-        return internal_in
-
     @staticmethod
     def _supported_input_size_type(input_size: Any) -> bool:
         if isinstance(input_size, torch.Size):
@@ -304,6 +254,7 @@ def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple:
                 Input.low_tensor_domain_incl,
                 Input.high_tensor_domain_excl,
             )
+
         elif len(domain) == 2:
             domain_lo, domain_hi = domain
 
@@ -416,8 +367,10 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor
         )
 
         if self.shape_mode == Input._ShapeMode.STATIC:
-            return torch.randn(self.shape).to(dtype=self.dtype)
+            return torch.randn(self.shape).to(
+                dtype=self.dtype if not self.torch_dtype else self.torch_dtype
+            )
         else:
             return torch.randn(self.shape[optimization_profile_field]).to(
-                dtype=self.dtype
+                dtype=self.dtype if not self.torch_dtype else self.torch_dtype
            )
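A small sketch of the new torch_dtype bookkeeping, with behavior inferred from the hunks above (Input here is the pure-Python class from _Input.py):

import torch
import torch_tensorrt

spec = torch_tensorrt.Input(shape=(1, 3, 224, 224), dtype=torch.half)
# __init__ received a torch.dtype, so it is cached on spec.torch_dtype while
# spec.dtype holds the parsed torch_tensorrt enum value.
t = spec.example_tensor()  # torch.randn(shape).to(dtype=spec.torch_dtype)
assert t.dtype == torch.half and t.shape == (1, 3, 224, 224)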

py/torch_tensorrt/_compile.py

Lines changed: 1 addition & 2 deletions
@@ -142,9 +142,8 @@ def compile(
         return torch_tensorrt.fx.compile(
             module,
             inputs,
-            lower_precision=lower_precision,
-            max_batch_size=inputs[0].size(0),
             explicit_batch_dimension=True,
+            lower_precision=lower_precision,
             dynamic_batch=False,
             **kwargs,
         )
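With max_batch_size no longer forwarded, the FX branch relies purely on the explicit batch dimension, which also means inputs need not be concrete tensors with a .size(0). A hedged usage sketch of the frontend dispatch, assuming this branch is selected via ir="fx" and using a placeholder module:

import torch
import torch_tensorrt

class TinyModel(torch.nn.Module):  # placeholder module for illustration
    def forward(self, x):
        return torch.relu(x)

model = TinyModel().cuda().eval()
trt_mod = torch_tensorrt.compile(
    model,
    ir="fx",
    inputs=[torch_tensorrt.Input(shape=(8, 3, 32, 32), dtype=torch.float32)],
)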

py/torch_tensorrt/fx/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
     tensorrt_converter,
 )
 from .fx2trt import TRTInterpreter, TRTInterpreterResult # noqa
-from .input_tensor_spec import generate_input_specs, InputTensorSpec # noqa
+from .input_tensor_spec import InputTensorSpec # noqa
 from .lower_setting import LowerSetting # noqa
 from .trt_module import TRTModule # noqa
 from .lower import compile # usort: skip #noqa
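generate_input_specs is no longer re-exported; specs are built directly from tensors or from torch_tensorrt.Input objects. For example (the from_tensors classmethod predates this change):

import torch
from torch_tensorrt.fx import InputTensorSpec

# Build a spec directly from example tensors instead of generate_input_specs
specs = InputTensorSpec.from_tensors([torch.randn(1, 3, 224, 224)])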

py/torch_tensorrt/fx/fx2trt.py

Lines changed: 0 additions & 3 deletions
@@ -153,7 +153,6 @@ def validate_conversion(self):
 
     def run(
         self,
-        max_batch_size=64,
         max_workspace_size=1 << 25,
         lower_precision=LowerPrecision.FP16,
         sparse_weights=False,
@@ -167,7 +166,6 @@ def run(
         """
        Build TensorRT engine with some configs.
        Args:
-           max_batch_size: set accordingly for maximum batch size you will use.
            max_workspace_size: set to the maximum size we can afford for temporary buffer
            lower_precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
            sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
@@ -207,7 +205,6 @@ def run(
         )
         build_engine_start_time = datetime.now()
 
-        self.builder.max_batch_size = max_batch_size
         builder_config = self.builder.create_builder_config()
         builder_config.max_workspace_size = max_workspace_size
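Since the networks here use an explicit batch dimension, the builder-level implicit-batch limit is meaningless, hence its removal. A sketch of the new call shape, with interpreter construction elided and assumed:

from torch_tensorrt.fx.utils import LowerPrecision

# `interpreter` is an already-constructed TRTInterpreter (setup elided);
# batch limits now come from the explicit batch dimension in the input specs.
result = interpreter.run(
    max_workspace_size=1 << 25,
    lower_precision=LowerPrecision.FP16,
)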
