@@ -134,9 +134,15 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
134
134
135
135
auto eps = static_cast <float >(args[7 ].unwrapToDouble (1e-5f ));
136
136
137
- auto scales = args[1 ].unwrapToTensor (at::ones (shape[1 ], options)).cpu ().contiguous ();
138
- auto bias = args[2 ].unwrapToTensor (at::zeros (shape[1 ], options)).cpu ().contiguous ();
139
137
138
+ auto scales = at::ones (shape[1 ], options);
139
+ if (!args[1 ].IValue ()->isNone ()) {
140
+ scales = args[1 ].unwrapToTensor (at::ones (shape[1 ], options)).cpu ().contiguous ();
141
+ }
142
+ auto bias = at::zeros (shape[1 ], options);
143
+ if (!args[2 ].IValue ()->isNone ()){
144
+ bias = args[2 ].unwrapToTensor (at::zeros (shape[1 ], options)).cpu ().contiguous ();
145
+ }
140
146
// track_running_stats=True
141
147
if (!args[3 ].IValue ()->isNone () || !args[4 ].IValue ()->isNone ()) {
142
148
auto running_mean = args[3 ].unwrapToTensor ();
@@ -154,6 +160,8 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
154
160
return true ;
155
161
}
156
162
163
+ // NOTE: It is unclear whether this check has any effect, since cudnn_enabled comes from the PyTorch context.
164
+ // cuDNN is required to run this converter either way.
157
165
auto cudnn_enabled = static_cast <bool >(args[8 ].unwrapToBool (false ));
158
166
if (!cudnn_enabled) {
159
167
LOG_DEBUG (
@@ -162,7 +170,7 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
162
170
so for some functionalities, users need to install correct \
163
171
cuDNN version by themselves. Please see our support matrix \
164
172
here: https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html." );
165
- return false ;
173
+ // return false;
166
174
}
167
175
168
176
const int relu = 0 ;
0 commit comments