
Commit 448700d

huydhn and malfet authored
Fix NULL dereference in binary CPU ops (pytorch#115241)
* Fix NULL dereference in binary CPU ops (pytorch#115183)

  Targeted fix for pytorch#113037. A more fundamental fix, in which these functions are not called at all for empty tensors, is coming later.

  Pull Request resolved: pytorch#115183
  Approved by: https://github.com/drisspg, https://github.com/atalman, https://github.com/huydhn

* Fix build after conflict resolution
* Also include pytorch#113262 to pass the test

Co-authored-by: Nikita Shulga <[email protected]>
1 parent 5965649 commit 448700d
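
Based on the regression tests added in this commit, the failure mode reduces to calling binary CPU ops on a 0-dim float16 tensor and an empty float16 tensor. A minimal repro sketch (the tensor construction is taken from the new test_numpy_interop.py test; the printed shapes are what the fixed code returns):

import numpy as np
import torch

# A 0-dim (scalar) tensor and a 0-element tensor, both float16 so the
# reduced-floating-point scalar path in BinaryOpsKernel.cpp is exercised.
x = torch.rand((), dtype=torch.float16)
y = torch.tensor(np.random.rand(0), dtype=torch.float16)

# Before the fix these calls could dereference a NULL scalar data pointer;
# after it they return empty results with y's shape.
print(torch.true_divide(x, y).shape)                  # torch.Size([0])  (pytorch#115068)
print(torch.mul(x, y).shape)                          # torch.Size([0])  (pytorch#115066)
print(torch.div(x, y, rounding_mode='floor').shape)   # torch.Size([0])  (pytorch#113037)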

3 files changed: +62 -29


aten/src/ATen/native/cpu/BinaryOpsKernel.cpp

Lines changed: 41 additions & 27 deletions
@@ -101,7 +101,7 @@ void mul_kernel(TensorIteratorBase& iter) {
           using comp_t = c10::complex<float>;
           return comp_t{a} * comp_t{b};
         });
-  } else if (iter.is_scalar(2) && at::isReducedFloatingType(dtype)) {
+  } else if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
     AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "mul_cpu_reduced_float", [&]() {
       using opmath_t = at::opmath_type<scalar_t>;
       opmath_t b = iter.original_scalar_value<opmath_t>(2);
@@ -125,7 +125,7 @@ void mul_kernel(TensorIteratorBase& iter) {
 
 void div_true_kernel(TensorIteratorBase& iter) {
   const auto dtype = iter.common_dtype();
-  if (iter.is_scalar(2) && at::isReducedFloatingType(dtype)) {
+  if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
     AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "div_cpu_reduced_float", [&]() {
       using opmath_t = at::opmath_type<scalar_t>;
       opmath_t b = iter.original_scalar_value<opmath_t>(2);
@@ -162,19 +162,28 @@ void div_trunc_kernel(TensorIteratorBase& iter) {
         return a / b;
       });
     });
-  } else if (iter.is_scalar(2) && at::isReducedFloatingType(dtype)) {
-    AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "div_trunc_cpu_reduced_float", [&]() {
-      using opmath_t = at::opmath_type<scalar_t>;
-      opmath_t b = iter.original_scalar_value<opmath_t>(2);
-      iter.remove_operand(2);
-      cpu_kernel_vec(iter,
-        [=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
-          return std::trunc(static_cast<opmath_t>(a) / b);
-        },
-        [=](Vectorized<scalar_t> a) {
-          return binary_op_scalar(a, b, [](const Vectorized<opmath_t>& x, const Vectorized<opmath_t>& y) { return (x / y).trunc(); });
+  } else if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(
+        dtype, "div_trunc_cpu_reduced_float", [&]() {
+          using opmath_t = at::opmath_type<scalar_t>;
+          opmath_t b = iter.original_scalar_value<opmath_t>(2);
+          iter.remove_operand(2);
+          cpu_kernel_vec(
+              iter,
+              [=](scalar_t a)
+                  __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
+                return std::trunc(static_cast<opmath_t>(a) / b);
+              },
+              [=](Vectorized<scalar_t> a) {
+                return binary_op_scalar(
+                    a,
+                    b,
+                    [](const Vectorized<opmath_t>& x,
+                       const Vectorized<opmath_t>& y) {
+                      return (x / y).trunc();
+                    });
+              });
         });
-    });
   } else {
     AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, dtype, "div_trunc_cpu", [&]() {
       cpu_kernel_vec(iter,
@@ -223,20 +232,25 @@ void div_floor_kernel(TensorIteratorBase& iter) {
     });
   } else {
     // See NOTE: [Floor Division in Python]
-    if (iter.is_scalar(2) && at::isReducedFloatingType(dtype)) {
-      AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "div_floor_cpu_reduced_float", [&]() {
-        using opmath_t = at::opmath_type<scalar_t>;
-        opmath_t b = iter.original_scalar_value<opmath_t>(2);
-        iter.remove_operand(2);
-        using vec_t = Vectorized<opmath_t>;
-        cpu_kernel_vec(iter,
-          [=](scalar_t a) -> scalar_t {
-            return div_floor_floating(static_cast<opmath_t>(a), b);
-          },
-          [=](Vectorized<scalar_t> a) {
-            return binary_op_scalar(a, b, [](const vec_t& x, const vec_t& y) { return div_floor_floating_vec(x, y); });
+    if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
+      AT_DISPATCH_REDUCED_FLOATING_TYPES(
+          dtype, "div_floor_cpu_reduced_float", [&]() {
+            using opmath_t = at::opmath_type<scalar_t>;
+            opmath_t b = iter.original_scalar_value<opmath_t>(2);
+            iter.remove_operand(2);
+            using vec_t = Vectorized<opmath_t>;
+            cpu_kernel_vec(
+                iter,
+                [=](scalar_t a) -> scalar_t {
+                  return div_floor_floating(static_cast<opmath_t>(a), b);
+                },
+                [=](Vectorized<scalar_t> a) {
+                  return binary_op_scalar(
+                      a, b, [](const vec_t& x, const vec_t& y) {
+                        return div_floor_floating_vec(x, y);
+                      });
+                });
          });
-      });
     } else {
       AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, dtype, "div_floor_cpu", [&]() {
         using vec_t = Vectorized<scalar_t>;

test/test_foreach.py

Lines changed: 6 additions & 2 deletions
@@ -516,16 +516,20 @@ def test_reduce_op(self, device, dtype, op, is_fastpath):
             sum(ref((ref_tensors,), ord=ord)).backward()
             self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
 
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
     def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
         # TODO: enable empty list case
-        for tensors in [[torch.randn([0])]]:
+        for tensors in [[torch.randn([0], device=device, dtype=dtype)],
+                        [torch.empty_strided((0, 1), (0, 0), dtype=dtype, device=device)]]:
             res = torch._foreach_add(tensors, 1)
             self.assertEqual(res, tensors)
 
             torch._foreach_add_(tensors, 1)
             self.assertEqual(res, tensors)
 
+            # Regression test for https://github.com/pytorch/pytorch/issues/113156
+            torch._foreach_mul_(tensors, 1)
+
     @ops(
         filter(lambda op: not op.has_no_out_of_place, foreach_binary_op_db),
         dtypes=OpDTypes.supported,
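
At the Python level, the foreach regression exercised by the test above amounts to running the in-place foreach ops on a list whose only entry is an empty tensor. A minimal sketch (the empty_strided construction is taken from the test; the float16 dtype here is illustrative):

import torch

# A single empty (0-element) strided tensor, as in the new test case.
tensors = [torch.empty_strided((0, 1), (0, 0), dtype=torch.float16)]

res = torch._foreach_add(tensors, 1)  # out-of-place add over the list
torch._foreach_add_(tensors, 1)       # in-place add
torch._foreach_mul_(tensors, 1)       # regression tracked in pytorch#113156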

test/test_numpy_interop.py

Lines changed: 15 additions & 0 deletions
@@ -472,6 +472,21 @@ def test_numpy_scalar_cmp(self, device, dtype):
             else:
                 self.assertTrue(t == a)
 
+    @onlyCPU
+    def test_empty_tensors_interop(self, device):
+        x = torch.rand((), dtype=torch.float16)
+        y = torch.tensor(np.random.rand(0), dtype=torch.float16)
+        # Same can be achieved by running
+        # y = torch.empty_strided((0,), (0,), dtype=torch.float16)
+
+        # Regression test for https://github.com/pytorch/pytorch/issues/115068
+        self.assertEqual(torch.true_divide(x, y).shape, y.shape)
+        # Regression test for https://github.com/pytorch/pytorch/issues/115066
+        self.assertEqual(torch.mul(x, y).shape, y.shape)
+        # Regression test for https://github.com/pytorch/pytorch/issues/113037
+        self.assertEqual(torch.div(x, y, rounding_mode='floor').shape, y.shape)
+
+
 instantiate_device_type_tests(TestNumPyInterop, globals())
 
 if __name__ == '__main__':
