
Commit 717e11b

feat: Adding support for native int64 (#2789)
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent e6f9aa2

28 files changed (+366 −98 lines)

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ repos:
           - --fix=lf
         exclude: ^docs
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v18.1.1
+    rev: v14.0.6
     hooks:
       - id: clang-format
         types_or: [c++, c, cuda]

core/runtime/register_jit_hooks.cpp

Lines changed: 5 additions & 0 deletions
@@ -122,6 +122,11 @@ TORCH_LIBRARY(tensorrt, m) {
   m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
     MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
   });
+  m.def("set_logging_level", [](int64_t level) -> void {
+    util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level));
+  });
+  m.def(
+      "get_logging_level", []() -> int64_t { return int64_t(util::logging::get_logger().get_reportable_log_level()); });
 }

 } // namespace
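Since these ops are registered under `TORCH_LIBRARY(tensorrt, m)`, they should surface in Python as `torch.ops.tensorrt.*` once the torch_tensorrt runtime library is loaded. A minimal sketch; the meaning of the integer levels is an assumption (they follow the C++ `util::logging::LogLevel` enum ordering):

```python
import torch
import torch_tensorrt  # importing the package loads the runtime that registers the ops (assumed)

# Set and read back the reportable log level. The int -> LogLevel mapping
# mirrors the C++ enum, so the exact meaning of "2" is an assumption here.
torch.ops.tensorrt.set_logging_level(2)
assert torch.ops.tensorrt.get_logging_level() == 2
```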

core/util/trt_util.cpp

Lines changed: 2 additions & 1 deletion
@@ -292,7 +292,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
       {at::kFloat, nvinfer1::DataType::kFLOAT},
       {at::kHalf, nvinfer1::DataType::kHALF},
       {at::kInt, nvinfer1::DataType::kINT32},
-      {at::kLong, nvinfer1::DataType::kINT32},
+      {at::kLong, nvinfer1::DataType::kINT64},
       {at::kChar, nvinfer1::DataType::kINT8},
       {at::kByte, nvinfer1::DataType::kINT8},
       {at::kBool, nvinfer1::DataType::kBOOL}};
@@ -304,6 +304,7 @@ const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_ma
       {nvinfer1::DataType::kFLOAT, at::kFloat},
       {nvinfer1::DataType::kHALF, at::kHalf},
       {nvinfer1::DataType::kINT32, at::kInt},
+      {nvinfer1::DataType::kINT64, at::kLong},
       {nvinfer1::DataType::kINT8, at::kChar},
       {nvinfer1::DataType::kBOOL, at::kBool},
 };
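The practical effect of the map change is that `at::kLong` no longer silently degrades to INT32. A rough Python-side mirror of the updated mapping, illustrative only and assuming a TensorRT build that exposes `trt.int64`:

```python
import tensorrt as trt
import torch

# Illustrative mirror of get_at_trt_type_map after this commit (not the real API).
at_to_trt = {
    torch.float32: trt.float32,
    torch.float16: trt.float16,
    torch.int32: trt.int32,
    torch.int64: trt.int64,  # previously mapped to trt.int32
    torch.int8: trt.int8,
    torch.bool: trt.bool,
}
assert at_to_trt[torch.int64] == trt.int64
```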

core/util/trt_util.h

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,8 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType&
       return stream << "Int8";
     case nvinfer1::DataType::kINT32:
       return stream << "Int32";
+    case nvinfer1::DataType::kINT64:
+      return stream << "Int64";
     case nvinfer1::DataType::kBOOL:
       return stream << "Bool";
     default:

py/torch_tensorrt/_enums.py

Lines changed: 6 additions & 1 deletion
@@ -5,10 +5,11 @@
 from typing import Any, Optional, Type, Union

 import numpy as np
-import tensorrt as trt
 import torch
 from torch_tensorrt._features import ENABLED_FEATURES

+import tensorrt as trt
+

 class dtype(Enum):
     """Enum to set supported dtypes in the compiler"""
@@ -103,6 +104,8 @@ def _from(
             return dtype.i8
         elif t == trt.int32:
             return dtype.i32
+        elif t == trt.int64:
+            return dtype.i64
         elif t == trt.float16:
             return dtype.f16
         elif t == trt.float32:
@@ -227,6 +230,8 @@ def to(
             return trt.DataType.INT8
         elif self == dtype.i32:
             return trt.DataType.INT32
+        elif self == dtype.i64:
+            return trt.DataType.INT64
         elif self == dtype.f16:
             return trt.DataType.HALF
         elif self == dtype.f32:
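A quick round-trip check of the new branches; this sketch should hold on a TensorRT build that exposes `trt.int64`:

```python
import tensorrt as trt
from torch_tensorrt._enums import dtype

# int64 now round-trips through the enum instead of falling through.
assert dtype._from(trt.int64) == dtype.i64
assert dtype.i64.to(trt.DataType) == trt.DataType.INT64
```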

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 40 additions & 11 deletions
@@ -2,6 +2,7 @@

 import collections.abc
 import logging
+import warnings
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union

 import torch
@@ -22,7 +23,7 @@
     UnsupportedOperatorException,
     convert_module,
     interpret_module_to_result,
-    repair_long_or_double_inputs,
+    repair_double_inputs,
 )
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
@@ -58,7 +59,7 @@ def compile(
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
     dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE,
     dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
-    truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
+    truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
     require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
     min_block_size: int = _defaults.MIN_BLOCK_SIZE,
     torch_executed_ops: Optional[Collection[Target]] = None,
@@ -74,7 +75,7 @@ def compile(
     hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
-    """Compile a TorchScript module for NVIDIA GPUs using TensorRT
+    """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT

     Takes a existing TorchScript module and a set of settings to configure the compiler
     and will convert methods to JIT Graphs which call equivalent TensorRT engines
@@ -115,7 +116,7 @@ def compile(
         dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
         dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
         dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
-        truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
+        truncate_double (bool): Truncate weights provided in double (float64) to float32
         calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
         require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
         min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT
@@ -138,6 +139,19 @@
     if debug:
         set_log_level(logger.parent, logging.DEBUG)

+    if "truncate_long_and_double" in kwargs.keys():
+        if truncate_double is not _defaults.TRUNCATE_DOUBLE:
+            raise ValueError(
+                'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"'
+            )
+        else:
+            truncate_double = kwargs["truncate_long_and_double"]
+            warnings.warn(
+                'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version',
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
     engine_capability = EngineCapability._from(engine_capability)

     if torch_executed_modules is not None and torch_executed_modules:
@@ -185,7 +199,7 @@
         "version_compatible": version_compatible,
         "optimization_level": optimization_level,
         "use_python_runtime": use_python_runtime,
-        "truncate_long_and_double": truncate_long_and_double,
+        "truncate_double": truncate_double,
         "use_fast_partitioner": use_fast_partitioner,
         "num_avg_timing_iters": num_avg_timing_iters,
         "enable_experimental_decompositions": enable_experimental_decompositions,
@@ -349,8 +363,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

     assert submodule_inputs is not None
     # Handle long/double inputs if requested by the user
-    if settings.truncate_long_and_double:
-        submodule_inputs = repair_long_or_double_inputs(
+    if settings.truncate_double:
+        submodule_inputs = repair_double_inputs(
             partitioned_module,
             submodule,
             submodule_inputs,
@@ -423,7 +437,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

 def convert_module_to_trt_engine(
     exported_program: ExportedProgram,
-    inputs: Optional[Sequence[Input | torch.Tensor]] = None,
+    inputs: Tuple[Any, ...],
+    *,
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
     ) = _defaults.ENABLED_PRECISIONS,
@@ -436,7 +451,7 @@ def convert_module_to_trt_engine(
     version_compatible: bool = _defaults.VERSION_COMPATIBLE,
     optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
     use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME,
-    truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
+    truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
     use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
     enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     device: Device = Device._current_device(),
@@ -451,6 +466,7 @@ def convert_module_to_trt_engine(
     dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
     calibrator: object = None,
     allow_shape_tensors: bool = False,
+    **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine

@@ -488,7 +504,7 @@
         use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
             based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
             argument as None
-        truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
+        truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32
         use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
         enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
             or only a selected subset of them
@@ -512,6 +528,19 @@
     if debug:
         set_log_level(logger.parent, logging.DEBUG)

+    if "truncate_long_and_double" in kwargs.keys():
+        if truncate_double is not _defaults.TRUNCATE_DOUBLE:
+            raise ValueError(
+                'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"'
+            )
+        else:
+            truncate_double = kwargs["truncate_long_and_double"]
+            warnings.warn(
+                'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version',
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
     input_list = list(inputs) if inputs is not None else []
     torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
     # Prepare torch_trt inputs
@@ -531,7 +560,7 @@
         "version_compatible": version_compatible,
         "optimization_level": optimization_level,
         "use_python_runtime": use_python_runtime,
-        "truncate_long_and_double": truncate_long_and_double,
+        "truncate_double": truncate_double,
         "use_fast_partitioner": use_fast_partitioner,
         "enable_experimental_decompositions": enable_experimental_decompositions,
         "device": device,

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 VERSION_COMPATIBLE = False
 OPTIMIZATION_LEVEL = None
 SPARSE_WEIGHTS = False
-TRUNCATE_LONG_AND_DOUBLE = False
+TRUNCATE_DOUBLE = False
 USE_PYTHON_RUNTIME = False
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@
     REFIT,
     REQUIRE_FULL_COMPILATION,
     SPARSE_WEIGHTS,
-    TRUNCATE_LONG_AND_DOUBLE,
+    TRUNCATE_DOUBLE,
     USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
     VERSION_COMPATIBLE,
@@ -50,7 +50,7 @@ class CompilationSettings:
         use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
             based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
             argument as None
-        truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
+        truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32
         use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
         enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
             or only a selected subset of them
@@ -81,7 +81,7 @@ class CompilationSettings:
     version_compatible: bool = VERSION_COMPATIBLE
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
-    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
+    truncate_double: bool = TRUNCATE_DOUBLE
     use_fast_partitioner: bool = USE_FAST_PARTITIONER
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
     device: Device = field(default_factory=default_device)
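The renamed field as it would be used; a small sketch assuming the module path matches the file path above:

```python
from torch_tensorrt.dynamo._settings import CompilationSettings

settings = CompilationSettings(truncate_double=True)
assert settings.truncate_double  # replaces settings.truncate_long_and_double
```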

py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py

Lines changed: 5 additions & 4 deletions
@@ -21,19 +21,20 @@
 from torch.fx.node import Argument, Node, Target, _get_qualified_name
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+import tensorrt as trt

 logger = logging.getLogger(__name__)

 LegacyConverterImplSignature = Callable[
     [
-        TRTNetwork,
+        trt.INetworkDefinition,
         Target,
         Tuple[Argument, ...],
         Dict[str, Argument],
         str,
     ],
-    Union[TRTTensor, Sequence[TRTTensor]],
+    Union[trt.ITensor, Sequence[trt.ITensor]],
 ]

 DynamoConverterImplSignature = Callable[
@@ -44,7 +45,7 @@
         Dict[str, Argument],
         str,
     ],
-    Union[TRTTensor, Sequence[TRTTensor]],
+    Union[trt.ITensor, Sequence[trt.ITensor]],
 ]

 ConverterImplSignature = Union[
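A converter conforming to the retyped legacy signature would now be annotated directly against the tensorrt classes rather than the fx aliases; a sketch with the body elided (`example_converter` is a hypothetical name):

```python
from typing import Dict, Sequence, Tuple, Union

import tensorrt as trt
from torch.fx.node import Argument, Target


def example_converter(
    network: trt.INetworkDefinition,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
    # Build TRT layers here and return their output tensor(s).
    raise NotImplementedError
```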

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 17 additions & 10 deletions
@@ -4,7 +4,6 @@
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set

 import numpy as np
-import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
@@ -26,6 +25,7 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER

+import tensorrt as trt
 from packaging import version

 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -498,6 +498,9 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
         )

         for i, output in enumerate(outputs):
+            name = f"output{i}"
+
+            output_dtype = dtype.unknown
             if any(
                 op_name in output.name.split("_")
                 for op_name in (
@@ -514,16 +517,20 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
                     "any",
                 )
             ):
-                output_bool = True
-            else:
-                output_bool = False
-            name = f"output{i}"
-            output.name = name
-            self.ctx.net.mark_output(output)
-            if output_bool:
-                output.dtype = trt.DataType.BOOL
+                output_dtype = dtype.b
             elif self.output_dtypes is not None:
-                output.dtype = self.output_dtypes[i].to(trt.DataType)
+                if self.output_dtypes[i] == dtype.i64:
+                    output = self.ctx.net.add_cast(
+                        output, dtype.i64.to(trt.DataType)
+                    ).get_output(0)
+                    output_dtype = dtype.i64
+                else:
+                    output_dtype = self.output_dtypes[i]
+
+            self.ctx.net.mark_output(output)
+            if output_dtype is not dtype.unknown:
+                output.dtype = output_dtype.to(trt.DataType, use_default=True)
+            output.name = name

             self._output_names.append(name)
             _LOGGER.debug(
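The new output path inserts an explicit cast layer for i64 outputs before marking them. The same pattern with the raw TensorRT Python API looks roughly like this; builder/network setup is elided and `mark_int64_output` is a hypothetical helper:

```python
import tensorrt as trt


def mark_int64_output(network: trt.INetworkDefinition, output: trt.ITensor) -> trt.ITensor:
    # Cast to INT64 first, then mark the cast's output, mirroring the diff above.
    out = network.add_cast(output, trt.int64).get_output(0)
    network.mark_output(out)
    out.dtype = trt.int64
    return out
```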

py/torch_tensorrt/dynamo/conversion/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,4 +3,4 @@
 from ._ConversionContext import ConversionContext
 from ._ConverterRegistry import * # noqa: F403
 from ._TRTInterpreter import * # noqa: F403
-from .truncate_long_and_double import repair_long_or_double_inputs
+from .truncate_double import repair_double_inputs
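Downstream imports move accordingly:

```python
# before (removed): from torch_tensorrt.dynamo.conversion import repair_long_or_double_inputs
from torch_tensorrt.dynamo.conversion import repair_double_inputs
```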
