Skip to content

Commit bdde52e

Browse files
committed
refactor(//py): Have the python api default input types to PyTorch like
behavior unless user explicitly overrides Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 08b4942 commit bdde52e

File tree

5 files changed

+17
-1
lines changed

5 files changed

+17
-1
lines changed

py/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def run(self):
181181
include_dirs=[
182182
dir_path + "trtorch/csrc",
183183
dir_path + "/../",
184-
dir_path + "/../bazel-trtorch-testing/external/tensorrt/include",
184+
dir_path + "/../bazel-TRTorch/external/tensorrt/include",
185185
],
186186
extra_compile_args=[
187187
"-Wno-deprecated",

py/trtorch/Input.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class _ShapeMode(Enum):
3030
shape_mode = None
3131
shape = None
3232
dtype = _types.dtype.float32
33+
_explicit_set_dtype = False
3334
format = _types.TensorFormat.contiguous
3435

3536
def __init__(self, *args, **kwargs):
@@ -105,6 +106,7 @@ def __init__(self, *args, **kwargs):
105106

106107
if "dtype" in kwargs:
107108
self.dtype = Input._parse_dtype(kwargs["dtype"])
109+
self._explicit_set_dtype = True
108110

109111
if "format" in kwargs:
110112
self.format = Input._parse_format(kwargs["format"])
@@ -128,6 +130,7 @@ def _to_internal(self) -> trtorch._C.Input:
128130
internal_in.opt = self.shape
129131
internal_in.input_is_dynamic = False
130132
internal_in.dtype = self.dtype
133+
internal_in._explicit_set_dtype = self._explicit_set_dtype
131134
internal_in.format = self.format
132135
return internal_in
133136

py/trtorch/_compile_spec.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,16 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
169169

170170
if "enabled_precisions" in compile_spec:
171171
info.enabled_precisions = _parse_enabled_precisions(compile_spec["enabled_precisions"])
172+
# We want default behavior to match PyTorch, so in the case the user did not explicitly set the dtype for inputs they
173+
# will follow PyTorch convetions
174+
for i in info.inputs:
175+
if not i._explicit_set_dtype:
176+
if _types.dtype.int8 in info.enabled_precisions:
177+
i.dtype = _types.dtype.float32
178+
elif _types.dtype.half in info.enabled_precisions:
179+
i.dtype = _types.dtype.float16
180+
else:
181+
i.dtype = _types.dtype.float32
172182

173183
if "calibrator" in compile_spec:
174184
info.ptq_calibrator = compile_spec["calibrator"]

py/trtorch/csrc/tensorrt_classes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,15 @@ struct Input : torch::CustomClassHolder {
4141
std::vector<int64_t> max;
4242

4343
bool input_is_dynamic;
44+
bool explicit_set_dtype;
4445
DataType dtype;
4546
TensorFormat format;
4647

4748
ADD_FIELD_GET_SET(min, std::vector<int64_t>);
4849
ADD_FIELD_GET_SET(opt, std::vector<int64_t>);
4950
ADD_FIELD_GET_SET(max, std::vector<int64_t>);
5051
ADD_FIELD_GET_SET(input_is_dynamic, bool);
52+
ADD_FIELD_GET_SET(explicit_set_dtype, bool);
5153
ADD_ENUM_GET_SET(dtype, DataType, static_cast<int64_t>(DataType::kBool));
5254
ADD_ENUM_GET_SET(format, TensorFormat, static_cast<int64_t>(TensorFormat::kContiguous));
5355

py/trtorch/csrc/trtorch_py.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ PYBIND11_MODULE(_C, m) {
170170
.def_readwrite("opt", &Input::opt)
171171
.def_readwrite("max", &Input::max)
172172
.def_readwrite("input_is_dynamic", &Input::input_is_dynamic)
173+
.def_readwrite("_explicit_set_dtype", &Input::explicit_set_dtype)
173174
.def_readwrite("dtype", &Input::dtype)
174175
.def_readwrite("format", &Input::format);
175176

0 commit comments

Comments
 (0)