@@ -34,7 +34,7 @@ bool check_convolution_backward_args(
34
34
bool transposed,
35
35
IntArrayRef output_padding,
36
36
int64_t groups,
37
- ET_UNUSED executorch::aten::ArrayRef<bool > output_mask,
37
+ executorch::aten::ArrayRef<bool > output_mask,
38
38
Tensor& grad_input,
39
39
Tensor& grad_weight,
40
40
Tensor& grad_bias) {
@@ -45,9 +45,18 @@ bool check_convolution_backward_args(
45
45
46
46
ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (weight, input));
47
47
ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (grad_output, input));
48
- ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (grad_input, input));
49
- ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (grad_weight, input));
50
- ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (grad_bias, input));
48
+
49
+ if (output_mask[0 ]) {
50
+ ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (grad_input, input));
51
+ }
52
+
53
+ if (output_mask[1 ]) {
54
+ ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (grad_weight, input));
55
+ }
56
+
57
+ if (output_mask[2 ]) {
58
+ ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (grad_bias, input));
59
+ }
51
60
52
61
ET_LOG_MSG_AND_RETURN_IF_FALSE (
53
62
check_convolution_args (
@@ -267,19 +276,23 @@ std::tuple<Tensor&, Tensor&, Tensor&> convolution_backward_out(
267
276
InvalidArgument,
268
277
ret_val);
269
278
270
- ET_KERNEL_CHECK (
271
- ctx,
272
- resize_tensor (grad_input, input.sizes ()) == Error::Ok,
273
- InvalidArgument,
274
- ret_val);
279
+ if (output_mask[0 ]) {
280
+ ET_KERNEL_CHECK (
281
+ ctx,
282
+ resize_tensor (grad_input, input.sizes ()) == Error::Ok,
283
+ InvalidArgument,
284
+ ret_val);
285
+ }
275
286
276
- ET_KERNEL_CHECK (
277
- ctx,
278
- resize_tensor (grad_weight, weight.sizes ()) == Error::Ok,
279
- InvalidArgument,
280
- ret_val);
287
+ if (output_mask[1 ]) {
288
+ ET_KERNEL_CHECK (
289
+ ctx,
290
+ resize_tensor (grad_weight, weight.sizes ()) == Error::Ok,
291
+ InvalidArgument,
292
+ ret_val);
293
+ }
281
294
282
- if (bias_sizes_opt.has_value ()) {
295
+ if (bias_sizes_opt.has_value () && output_mask[ 2 ] ) {
283
296
ET_KERNEL_CHECK (
284
297
ctx,
285
298
resize_tensor (grad_bias, bias_sizes_opt.value ()) == Error::Ok,
0 commit comments