Skip to content

Commit 22faab8

Browse files
committed
refactor: Require output types to be provided to TRTInterpreter
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent e6da7e4 commit 22faab8

22 files changed

+109
-80
lines changed

py/torch_tensorrt/csrc/tensorrt_classes.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -235,23 +235,23 @@ std::string Device::to_str() {
235235

236236
std::string to_str(EngineCapability value) {
237237
switch (value) {
238-
case EngineCapability::kSAFE_GPU:
239-
return "Safe GPU";
240-
case EngineCapability::kSAFE_DLA:
241-
return "Safe DLA";
242-
case EngineCapability::kDEFAULT:
238+
case EngineCapability::kDLA_STANDALONE:
239+
return "DLA Standalone";
240+
case EngineCapability::kSAFETY:
241+
return "Safety";
242+
case EngineCapability::kSTANDARD:
243243
default:
244-
return "Default";
244+
return "Standard";
245245
}
246246
}
247247

248248
nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) {
249249
switch (value) {
250-
case EngineCapability::kSAFE_DLA:
250+
case EngineCapability::kDLA_STANDALONE:
251251
return TRT_ENGINE_CAPABILITY_DLA_STANDALONE;
252-
case EngineCapability::kSAFE_GPU:
252+
case EngineCapability::kSAFETY:
253253
return TRT_ENGINE_CAPABILITY_SAFETY;
254-
case EngineCapability::kDEFAULT:
254+
case EngineCapability::kSTANDARD:
255255
default:
256256
return TRT_ENGINE_CAPABILITY_STANDARD;
257257
}

py/torch_tensorrt/csrc/tensorrt_classes.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ struct TorchFallback : torch::CustomClassHolder {
114114
};
115115

116116
enum class EngineCapability : int8_t {
117-
kDEFAULT,
118-
kSAFE_GPU,
119-
kSAFE_DLA,
117+
kSTANDARD,
118+
kSAFETY,
119+
kDLA_STANDALONE,
120120
};
121121

122122
std::string to_str(EngineCapability value);
@@ -160,7 +160,7 @@ struct CompileSpec : torch::CustomClassHolder {
160160
ADD_FIELD_GET_SET(sparse_weights, bool);
161161
ADD_FIELD_GET_SET(refit, bool);
162162
ADD_FIELD_GET_SET(debug, bool);
163-
ADD_ENUM_GET_SET(capability, EngineCapability, static_cast<int64_t>(EngineCapability::kSAFE_DLA));
163+
ADD_ENUM_GET_SET(capability, EngineCapability, static_cast<int64_t>(EngineCapability::kSTANDARD));
164164
ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t);
165165
ADD_FIELD_GET_SET(workspace_size, int64_t);
166166
ADD_FIELD_GET_SET(dla_sram_size, int64_t);
@@ -184,7 +184,7 @@ struct CompileSpec : torch::CustomClassHolder {
184184
bool allow_shape_tensors = false;
185185
Device device;
186186
TorchFallback torch_fallback;
187-
EngineCapability capability = EngineCapability::kDEFAULT;
187+
EngineCapability capability = EngineCapability::kSTANDARD;
188188
int64_t num_avg_timing_iters = 1;
189189
int64_t workspace_size = 0;
190190
int64_t dla_sram_size = 1048576;

py/torch_tensorrt/csrc/torch_tensorrt_py.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,9 @@ PYBIND11_MODULE(_C, m) {
261261
m,
262262
"EngineCapability",
263263
"Enum to specify engine capability settings (selections of kernels to meet safety requirements)")
264-
.value("safe_gpu", EngineCapability::kSAFE_GPU, "Use safety GPU kernels only")
265-
.value("safe_dla", EngineCapability::kSAFE_DLA, "Use safety DLA kernels only")
266-
.value("default", EngineCapability::kDEFAULT, "Use default behavior");
264+
.value("SAFETY", EngineCapability::kSAFETY, "Use safe kernels only")
265+
.value("DLA_STANDALONE", EngineCapability::kDLA_STANDALONE, "Use DLA kernels only")
266+
.value("STANDARD", EngineCapability::kSTANDARD, "Use default behavior");
267267

268268
py::enum_<TensorFormat>(m, "TensorFormat", "Enum to specifiy the memory layout of tensors")
269269
.value("contiguous", TensorFormat::kContiguous, "Contiguous memory layout (NCHW / Linear)")

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ def __init__(
105105
[dtype._from(o) for o in output_dtypes] if output_dtypes else None
106106
)
107107

108+
_LOGGER.debug(f"Graph to be compiled to TensorRT: {self.module.graph}")
109+
108110
def validate_conversion(self) -> Set[str]:
109111
missing_converters: Set[str] = set()
110112

@@ -121,6 +123,18 @@ def validate_conversion(self) -> Set[str]:
121123

122124
return missing_converters
123125

126+
@staticmethod
127+
def _args_str(args: List[Any]) -> str:
128+
args_ = [
129+
(
130+
f"ITensor {a.name} (shape: {a.shape}, dtype: {a.dtype})"
131+
if isinstance(a, trt.ITensor)
132+
else a
133+
)
134+
for a in args
135+
]
136+
return str(tuple(args_))
137+
124138
@staticmethod
125139
def _all_precisions_supported(enabled_precisions: Set[dtype]) -> bool:
126140
return enabled_precisions.issubset(_defaults.SUPPORTED_KERNEL_PRECISIONS)
@@ -359,10 +373,14 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
359373
f"Unable to access shape spec for input: {target} (got: {current_input})"
360374
)
361375

376+
trt_input_dtype = current_input.dtype.to(trt.DataType, use_default=True)
377+
_LOGGER.debug(
378+
f"Adding input to in-progress INetwork: {target} (shape={shape}, dtype={trt_input_dtype})"
379+
)
362380
return self.ctx.net.add_input(
363381
name=target,
364382
shape=tuple(shape),
365-
dtype=current_input.dtype.to(trt.DataType, use_default=True),
383+
dtype=trt_input_dtype,
366384
)
367385

368386
def call_module(
@@ -381,6 +399,9 @@ def call_module(
381399
converter, calling_convention = converter_packet
382400

383401
assert self._cur_node_name is not None
402+
_LOGGER.debug(
403+
f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})"
404+
)
384405
if calling_convention is CallingConvention.LEGACY:
385406
return converter(self.ctx.net, submod, args, kwargs, self._cur_node_name)
386407
else:
@@ -397,6 +418,9 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
397418
converter, calling_convention = converter_packet
398419

399420
assert self._cur_node_name is not None
421+
_LOGGER.debug(
422+
f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})"
423+
)
400424
if calling_convention is CallingConvention.LEGACY:
401425
return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
402426
else:
@@ -428,6 +452,9 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
428452
converter, calling_convention = converter_packet
429453

430454
assert self._cur_node_name is not None
455+
_LOGGER.debug(
456+
f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})"
457+
)
431458
if calling_convention is CallingConvention.LEGACY:
432459
return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
433460
else:
@@ -485,8 +512,10 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
485512
output.dtype = trt.DataType.BOOL
486513
elif self.output_dtypes is not None:
487514
output.dtype = self.output_dtypes[i].to(trt.DataType)
488-
elif self.output_fp16 and output.dtype == trt.DataType.FLOAT:
489-
output.dtype = trt.DataType.HALF
515+
490516
self._output_names.append(name)
517+
_LOGGER.debug(
518+
f"Marking output {name} (shape: {output.shape}, dtype: {output.dtype})"
519+
)
491520

492521
return list(outputs)

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import io
44
import logging
5-
from typing import Sequence
5+
from typing import List, Sequence
66

77
import torch
8+
from torch_tensorrt._Device import Device
89
from torch_tensorrt._enums import dtype
910
from torch_tensorrt._features import ENABLED_FEATURES
1011
from torch_tensorrt._Input import Input
@@ -21,20 +22,14 @@
2122
logger = logging.getLogger(__name__)
2223

2324

24-
def interpret_module_to_result(
25+
def infer_module_output_dtypes(
2526
module: torch.fx.GraphModule,
2627
inputs: Sequence[Input],
27-
settings: CompilationSettings = CompilationSettings(),
28-
) -> TRTInterpreterResult:
29-
"""Interpret an FX module to a TRTInterpreterResult
30-
Args:
31-
module: FX GraphModule to interpret
32-
inputs: Sequence of Tensors representing inputs to the module
33-
settings: Compilation settings
34-
Returns:
35-
TRTInterpreterResult
36-
"""
37-
torch_inputs = get_torch_inputs(inputs, settings.device)
28+
device: Device,
29+
truncate_long_and_double: bool = False,
30+
) -> List[dtype]:
31+
torch_inputs = get_torch_inputs(inputs, device)
32+
module = module.to(device.to(torch.device))
3833
module_outputs = module(*torch_inputs)
3934

4035
if not isinstance(module_outputs, (list, tuple)):
@@ -44,13 +39,36 @@ def interpret_module_to_result(
4439
# such as aten.sum - such outputs can be truncated
4540
output_dtypes = []
4641
for output in module_outputs:
47-
if settings.truncate_long_and_double and output.dtype == dtype.float64:
42+
if truncate_long_and_double and output.dtype == dtype.float64:
4843
output_dtypes.append(dtype.float32)
49-
elif settings.truncate_long_and_double and output.dtype == dtype.int64:
44+
elif truncate_long_and_double and output.dtype == dtype.int64:
5045
output_dtypes.append(dtype.int32)
5146
else:
5247
output_dtypes.append(dtype._from(output.dtype))
5348

49+
return output_dtypes
50+
51+
52+
def interpret_module_to_result(
53+
module: torch.fx.GraphModule,
54+
inputs: Sequence[Input],
55+
settings: CompilationSettings = CompilationSettings(),
56+
) -> TRTInterpreterResult:
57+
"""Interpret an FX module to a TRTInterpreterResult
58+
Args:
59+
module: FX GraphModule to interpret
60+
inputs: Sequence of Tensors representing inputs to the module
61+
settings: Compilation settings
62+
Returns:
63+
TRTInterpreterResult
64+
"""
65+
output_dtypes = infer_module_output_dtypes(
66+
module,
67+
inputs,
68+
settings.device,
69+
truncate_long_and_double=settings.truncate_long_and_double,
70+
)
71+
5472
interpreter = TRTInterpreter(
5573
module,
5674
inputs,

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def get_node_name(node: torch.fx.Node) -> str:
4242
# like the node.meta['source_fn'] attr
4343
pass
4444

45-
_LOGGER.debug(f"Node meta name {node_name}")
4645
return node_name
4746

4847

tests/py/dynamo/conversion/harness.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
1515
from torch_tensorrt.dynamo.conversion import TRTInterpreter
16+
from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
1617
from torch_tensorrt.dynamo.lowering import apply_lowering_passes
1718
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
1819

@@ -71,7 +72,8 @@ def run_test(
7172
interpreter_result.output_names,
7273
)
7374

74-
ref_outputs = mod(*inputs)
75+
mod = mod.cuda()
76+
ref_outputs = mod(*cuda_inputs)
7577

7678
torch.cuda.synchronize()
7779
start_event = torch.cuda.Event(enable_timing=True)
@@ -147,7 +149,7 @@ def run_test_custom_compare_results(
147149
interpreter_result.output_names,
148150
)
149151
res_trt = trt_mod(*cuda_inputs).cpu()
150-
res_cpu = mod(*inputs)
152+
res_cpu = mod(*cuda_inputs).cpu()
151153
assert len(res_trt) == len(res_cpu)
152154
assert len(res_cpu) == len(comparators)
153155
for output_trt, output_cpu, comparator in zip(
@@ -211,7 +213,6 @@ def generate_graph(
211213
fx_module = torch.fx.symbolic_trace(mod)
212214
if enable_passes:
213215
fx_module = apply_lowering_passes(fx_module, original_inputs)
214-
_LOGGER.info(f"FX graph= {fx_module.graph}")
215216
return fx_module
216217

217218
def run_test(
@@ -222,7 +223,6 @@ def run_test(
222223
atol=1e-03,
223224
precision=dtype.f32,
224225
check_dtype=True,
225-
output_dtypes=None,
226226
use_dynamo_tracer=False,
227227
enable_passes=False,
228228
):
@@ -237,12 +237,24 @@ def run_test(
237237
# Previous instance of the interpreter auto-casted 64-bit inputs
238238
# We replicate this behavior here
239239
compilation_settings = CompilationSettings(
240-
enabled_precisions={dtype._from(precision)}, truncate_long_and_double=True
240+
enabled_precisions={dtype._from(precision)},
241+
truncate_long_and_double=True,
241242
)
242243

244+
input_specs = [Input.from_tensor(i) for i in inputs]
245+
246+
output_dtypes = None
247+
if check_dtype:
248+
output_dtypes = infer_module_output_dtypes(
249+
mod,
250+
input_specs,
251+
compilation_settings.device,
252+
truncate_long_and_double=compilation_settings.truncate_long_and_double,
253+
)
254+
243255
interp = TRTInterpreter(
244256
mod,
245-
Input.from_tensors(inputs),
257+
input_specs,
246258
output_dtypes=output_dtypes,
247259
compilation_settings=compilation_settings,
248260
)

tests/py/dynamo/conversion/test_abs_aten.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def forward(self, input):
4242
self.run_test(
4343
abs(),
4444
inputs,
45-
output_dtypes=[torch.int],
4645
)
4746

4847

0 commit comments

Comments
 (0)