Skip to content

Commit 832f855

Browse files
Dont error check unused tensors in convolution_backward.out
Differential Revision: D69211399 Pull Request resolved: #8285
1 parent b1d76c9 commit 832f855

File tree

1 file changed

+28
-15
lines changed

1 file changed

+28
-15
lines changed

kernels/portable/cpu/op_convolution_backward.cpp

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ bool check_convolution_backward_args(
3434
bool transposed,
3535
IntArrayRef output_padding,
3636
int64_t groups,
37-
ET_UNUSED executorch::aten::ArrayRef<bool> output_mask,
37+
executorch::aten::ArrayRef<bool> output_mask,
3838
Tensor& grad_input,
3939
Tensor& grad_weight,
4040
Tensor& grad_bias) {
@@ -45,9 +45,18 @@ bool check_convolution_backward_args(
4545

4646
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(weight, input));
4747
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_output, input));
48-
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_input, input));
49-
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_weight, input));
50-
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_bias, input));
48+
49+
if (output_mask[0]) {
50+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_input, input));
51+
}
52+
53+
if (output_mask[1]) {
54+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_weight, input));
55+
}
56+
57+
if (output_mask[2]) {
58+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_bias, input));
59+
}
5160

5261
ET_LOG_MSG_AND_RETURN_IF_FALSE(
5362
check_convolution_args(
@@ -267,19 +276,23 @@ std::tuple<Tensor&, Tensor&, Tensor&> convolution_backward_out(
267276
InvalidArgument,
268277
ret_val);
269278

270-
ET_KERNEL_CHECK(
271-
ctx,
272-
resize_tensor(grad_input, input.sizes()) == Error::Ok,
273-
InvalidArgument,
274-
ret_val);
279+
if (output_mask[0]) {
280+
ET_KERNEL_CHECK(
281+
ctx,
282+
resize_tensor(grad_input, input.sizes()) == Error::Ok,
283+
InvalidArgument,
284+
ret_val);
285+
}
275286

276-
ET_KERNEL_CHECK(
277-
ctx,
278-
resize_tensor(grad_weight, weight.sizes()) == Error::Ok,
279-
InvalidArgument,
280-
ret_val);
287+
if (output_mask[1]) {
288+
ET_KERNEL_CHECK(
289+
ctx,
290+
resize_tensor(grad_weight, weight.sizes()) == Error::Ok,
291+
InvalidArgument,
292+
ret_val);
293+
}
281294

282-
if (bias_sizes_opt.has_value()) {
295+
if (bias_sizes_opt.has_value() && output_mask[2]) {
283296
ET_KERNEL_CHECK(
284297
ctx,
285298
resize_tensor(grad_bias, bias_sizes_opt.value()) == Error::Ok,

0 commit comments

Comments
 (0)