Skip to content

Commit de4957b

Browse files
committed
fix(aten::instance_norm): Handle optional inputs in instance norm converter
1 parent 43831dc commit de4957b

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

core/conversion/converters/impl/batch_norm.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,15 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
134134

135135
auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));
136136

137-
auto scales = args[1].unwrapToTensor(at::ones(shape[1], options)).cpu().contiguous();
138-
auto bias = args[2].unwrapToTensor(at::zeros(shape[1], options)).cpu().contiguous();
139137

138+
auto scales = at::ones(shape[1], options);
139+
if (!args[1].IValue()->isNone()) {
140+
scales = args[1].unwrapToTensor(at::ones(shape[1], options)).cpu().contiguous();
141+
}
142+
auto bias = at::zeros(shape[1], options);
143+
if (!args[2].IValue()->isNone()){
144+
bias = args[2].unwrapToTensor(at::zeros(shape[1], options)).cpu().contiguous();
145+
}
140146
// track_running_stats=True
141147
if (!args[3].IValue()->isNone() || !args[4].IValue()->isNone()) {
142148
auto running_mean = args[3].unwrapToTensor();
@@ -154,6 +160,8 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
154160
return true;
155161
}
156162

163+
// Not sure this actually does something since the cudnn_enabled is from the PyTorch context.
164+
// We need cuDNN either way to run this converter
157165
auto cudnn_enabled = static_cast<bool>(args[8].unwrapToBool(false));
158166
if (!cudnn_enabled) {
159167
LOG_DEBUG(
@@ -162,7 +170,7 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
162170
so for some functionalities, users need to install correct \
163171
cuDNN version by themselves. Please see our support matrix \
164172
here: https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html.");
165-
return false;
173+
//return false;
166174
}
167175

168176
const int relu = 0;

core/util/prelude.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
// A collection of headers from util that will typically get included in most
44
// files
5+
#include <cstdint>
56
#include "core/util/Exception.h"
67
#include "core/util/build_info.h"
78
#include "core/util/jit_util.h"

tests/core/conversion/converters/test_instance_norm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ constexpr auto graph = R"IR(
1818
%running_mean.1 : Tensor?,
1919
%running_var.1 : Tensor?,
2020
%use_input_stats.1 : bool):
21-
%cudnn_enabled.1 : bool = prim::Constant[value=0]()
21+
%cudnn_enabled.1 : bool = prim::Constant[value=1]()
2222
%momentum.1 : float = prim::Constant[value=0.10000000000000001]()
2323
%eps.1 : float = prim::Constant[value=1.0000000000000001e-05]()
2424
%4 : Tensor = aten::instance_norm(%input.1,

0 commit comments

Comments
 (0)