Skip to content

Commit 0b1345d

Browse files
committed
dpnp.subtract() doesn't work properly with a scalar
1 parent e26c3f1 commit 0b1345d

File tree

10 files changed

+148
-133
lines changed

10 files changed

+148
-133
lines changed

dpnp/backend/include/dpnp_gen_2arg_3type_tbl.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,9 @@ MACRO_2ARG_3TYPES_OP(dpnp_power_c,
184184

185185
MACRO_2ARG_3TYPES_OP(dpnp_subtract_c,
186186
input1_elem - input2_elem,
187-
nullptr,
188-
std::false_type,
187+
sycl::sub_sat(x1, x2),
188+
MACRO_UNPACK_TYPES(int, long),
189189
oneapi::mkl::vm::sub,
190-
MACRO_UNPACK_TYPES(float, double))
190+
MACRO_UNPACK_TYPES(float, double, std::complex<float>, std::complex<double>))
191191

192192
#undef MACRO_2ARG_3TYPES_OP

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,12 @@ static void func_map_elemwise_2arg_3type_core(func_map_t& fmap)
11931193
func_type_map_t::find_type<FT1>,
11941194
func_type_map_t::find_type<FTs>>}),
11951195
...);
1196+
((fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][FT1][FTs] =
1197+
{populate_func_types<FT1, FTs>(),
1198+
(void*)dpnp_subtract_c_ext<func_type_map_t::find_type<populate_func_types<FT1, FTs>()>,
1199+
func_type_map_t::find_type<FT1>,
1200+
func_type_map_t::find_type<FTs>>}),
1201+
...);
11961202
}
11971203

11981204
template <DPNPFuncType... FTs>
@@ -1878,39 +1884,6 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
18781884
fmap[DPNPFuncName::DPNP_FN_SUBTRACT][eft_DBL][eft_DBL] = {
18791885
eft_DBL, (void*)dpnp_subtract_c_default<double, double, double>};
18801886

1881-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_INT][eft_INT] = {
1882-
eft_INT, (void*)dpnp_subtract_c_ext<int32_t, int32_t, int32_t>};
1883-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_INT][eft_LNG] = {
1884-
eft_LNG, (void*)dpnp_subtract_c_ext<int64_t, int32_t, int64_t>};
1885-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_INT][eft_FLT] = {
1886-
eft_DBL, (void*)dpnp_subtract_c_ext<double, int32_t, float>};
1887-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_INT][eft_DBL] = {
1888-
eft_DBL, (void*)dpnp_subtract_c_ext<double, int32_t, double>};
1889-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_LNG][eft_INT] = {
1890-
eft_LNG, (void*)dpnp_subtract_c_ext<int64_t, int64_t, int32_t>};
1891-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_LNG][eft_LNG] = {
1892-
eft_LNG, (void*)dpnp_subtract_c_ext<int64_t, int64_t, int64_t>};
1893-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_LNG][eft_FLT] = {
1894-
eft_DBL, (void*)dpnp_subtract_c_ext<double, int64_t, float>};
1895-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_LNG][eft_DBL] = {
1896-
eft_DBL, (void*)dpnp_subtract_c_ext<double, int64_t, double>};
1897-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_FLT][eft_INT] = {
1898-
eft_DBL, (void*)dpnp_subtract_c_ext<double, float, int32_t>};
1899-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_FLT][eft_LNG] = {
1900-
eft_DBL, (void*)dpnp_subtract_c_ext<double, float, int64_t>};
1901-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_FLT][eft_FLT] = {
1902-
eft_FLT, (void*)dpnp_subtract_c_ext<float, float, float>};
1903-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_FLT][eft_DBL] = {
1904-
eft_DBL, (void*)dpnp_subtract_c_ext<double, float, double>};
1905-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_DBL][eft_INT] = {
1906-
eft_DBL, (void*)dpnp_subtract_c_ext<double, double, int32_t>};
1907-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_DBL][eft_LNG] = {
1908-
eft_DBL, (void*)dpnp_subtract_c_ext<double, double, int64_t>};
1909-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_DBL][eft_FLT] = {
1910-
eft_DBL, (void*)dpnp_subtract_c_ext<double, double, float>};
1911-
fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][eft_DBL][eft_DBL] = {
1912-
eft_DBL, (void*)dpnp_subtract_c_ext<double, double, double>};
1913-
19141887
func_map_elemwise_2arg_3type_helper<eft_BLN, eft_INT, eft_LNG, eft_FLT, eft_DBL, eft_C64, eft_C128>(fmap);
19151888

19161889
return;

dpnp/dpnp_array.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,9 @@ def __rmul__(self, other):
270270
# '__rpow__',
271271
# '__rrshift__',
272272
# '__rshift__',
273-
# '__rsub__',
273+
274+
def __rsub__(self, other):
275+
return dpnp.subtract(other, self)
274276

275277
def __rtruediv__(self, other):
276278
return dpnp.true_divide(other, self)

dpnp/dpnp_iface_mathematical.py

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def add(x1,
215215
if x1_desc and x2_desc:
216216
return dpnp_add(x1_desc, x2_desc, dtype=dtype, out=out, where=where).get_pyobj()
217217

218-
return call_origin(numpy.add, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
218+
return call_origin(numpy.add, x1, x2, out=out, where=where, dtype=dtype, subok=subok, **kwargs)
219219

220220

221221
def around(x1, decimals=0, out=None):
@@ -1145,7 +1145,7 @@ def multiply(x1,
11451145
if x1_desc and x2_desc:
11461146
return dpnp_multiply(x1_desc, x2_desc, dtype=dtype, out=out, where=where).get_pyobj()
11471147

1148-
return call_origin(numpy.multiply, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
1148+
return call_origin(numpy.multiply, x1, x2, out=out, where=where, dtype=dtype, subok=subok, **kwargs)
11491149

11501150

11511151
def nancumprod(x1, **kwargs):
@@ -1520,60 +1520,67 @@ def sign(x1, **kwargs):
15201520
return call_origin(numpy.sign, x1, **kwargs)
15211521

15221522

1523-
def subtract(x1, x2, dtype=None, out=None, where=True, **kwargs):
1523+
def subtract(x1,
1524+
x2,
1525+
/,
1526+
out=None,
1527+
*,
1528+
where=True,
1529+
dtype=None,
1530+
subok=True,
1531+
**kwargs):
15241532
"""
15251533
Subtract arguments, element-wise.
15261534
15271535
For full documentation refer to :obj:`numpy.subtract`.
15281536
1537+
Returns
1538+
-------
1539+
y : dpnp.ndarray
1540+
The difference of `x1` and `x2`, element-wise.
1541+
15291542
Limitations
15301543
-----------
1531-
Parameters ``x1`` and ``x2`` are supported as either :obj:`dpnp.ndarray` or scalar.
1532-
Parameters ``dtype``, ``out`` and ``where`` are supported with their default values.
1544+
Parameters `x1` and `x2` are supported as either :class:`dpnp.ndarray` or scalar,
1545+
but not both (at least either `x1` or `x2` should be as :class:`dpnp.ndarray`).
1546+
Parameters `out`, `where`, `dtype` and `subok` are supported with their default values.
15331547
Keyword arguments ``kwargs`` are currently unsupported.
1534-
Otherwise the functions will be executed sequentially on CPU.
1548+
Otherwise the function will be executed sequentially on CPU.
15351549
Input array data types are limited by supported DPNP :ref:`Data types`.
15361550
15371551
Example
15381552
-------
1539-
>>> import dpnp as np
1540-
>>> result = np.subtract(np.array([4, 3]), np.array([2, 7]))
1541-
>>> [x for x in result]
1553+
>>> import dpnp as dp
1554+
>>> result = dp.subtract(dp.array([4, 3]), dp.array([2, 7]))
1555+
>>> print(result)
15421556
[2, -4]
15431557
15441558
"""
15451559

1546-
x1_is_scalar = dpnp.isscalar(x1)
1547-
x2_is_scalar = dpnp.isscalar(x2)
1548-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False)
1549-
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False)
1560+
if out is not None:
1561+
pass
1562+
elif where is not True:
1563+
pass
1564+
elif dtype is not None:
1565+
pass
1566+
elif subok is not True:
1567+
pass
1568+
elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
1569+
# at least either x1 or x2 has to be an array
1570+
pass
1571+
else:
1572+
# get a common queue to copy data from the host into a device if any input is scalar
1573+
queue = get_common_allocation_queue([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else None
15501574

1551-
if x1_desc and x2_desc and not kwargs:
1552-
if not x1_desc and not x1_is_scalar:
1553-
pass
1554-
elif not x2_desc and not x2_is_scalar:
1555-
pass
1556-
elif x1_is_scalar and x2_is_scalar:
1557-
pass
1558-
elif x1_desc and x1_desc.ndim == 0:
1559-
pass
1560-
elif x1_desc and x1_desc.dtype == dpnp.bool:
1561-
pass
1562-
elif x2_desc and x2_desc.ndim == 0:
1563-
pass
1564-
elif x2_desc and x2_desc.dtype == dpnp.bool:
1565-
pass
1566-
elif dtype is not None:
1567-
pass
1568-
elif out is not None:
1569-
pass
1570-
elif not where:
1571-
pass
1572-
else:
1573-
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False) if out is not None else None
1574-
return dpnp_subtract(x1_desc, x2_desc, dtype=dtype, out=out_desc, where=where).get_pyobj()
1575+
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False, alloc_queue=queue)
1576+
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False, alloc_queue=queue)
1577+
if x1_desc and x2_desc:
1578+
if x1_desc.dtype == x2_desc.dtype == dpnp.bool:
1579+
raise TypeError("DPNP boolean subtract, the `-` operator, is not supported, "
1580+
"use the bitwise_xor, the `^` operator, or the logical_xor function instead.")
1581+
return dpnp_subtract(x1_desc, x2_desc, dtype=dtype, out=out, where=where).get_pyobj()
15751582

1576-
return call_origin(numpy.subtract, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
1583+
return call_origin(numpy.subtract, x1, x2, out=out, where=where, dtype=dtype, subok=subok, **kwargs)
15771584

15781585

15791586
def sum(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=True):

tests/skipped_tests.tbl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ tests/third_party/cupy/creation_tests/test_from_data.py::TestFromData::test_asar
389389
tests/third_party/cupy/creation_tests/test_from_data.py::TestFromData::test_ascontiguousarray_on_noncontiguous_array
390390
tests/third_party/cupy/creation_tests/test_from_data.py::TestFromData::test_asfortranarray_cuda_array_zero_dim
391391
tests/third_party/cupy/creation_tests/test_from_data.py::TestFromData::test_asfortranarray_cuda_array_zero_dim_dtype
392-
tests/third_party/cupy/creation_tests/test_from_data.py::TestFromData::test_fromfile
392+
393393
tests/third_party/cupy/creation_tests/test_ranges.py::TestMeshgrid_param_0_{copy=False, indexing='xy', sparse=False}::test_meshgrid0
394394
tests/third_party/cupy/creation_tests/test_ranges.py::TestMeshgrid_param_0_{copy=False, indexing='xy', sparse=False}::test_meshgrid1
395395
tests/third_party/cupy/creation_tests/test_ranges.py::TestMeshgrid_param_0_{copy=False, indexing='xy', sparse=False}::test_meshgrid2
@@ -773,7 +773,6 @@ tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNu
773773
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_6_{name='subtract', nargs=2}::test_raises_with_numpy_input
774774
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_8_{name='floor_divide', nargs=2}::test_raises_with_numpy_input
775775

776-
tests/third_party/cupy/math_tests/test_arithmetic.py::TestBoolSubtract_param_3_{shape=(), xp=dpnp}::test_bool_subtract
777776
tests/third_party/cupy/math_tests/test_explog.py::TestExplog::test_logaddexp
778777
tests/third_party/cupy/math_tests/test_explog.py::TestExplog::test_logaddexp2
779778
tests/third_party/cupy/math_tests/test_floating.py::TestFloating::test_copysign_float

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-conjugate-data2]
1818
tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-copy-data3]
1919
tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-cumprod-data4]
2020
tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-cumsum-data5]
21-
tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-diff-data6]
2221
tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-ediff1d-data7]
2322
tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-fabs-data8]
2423
tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-floor-data9]
@@ -29,11 +28,9 @@ tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-conjugate-data2]
2928
tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-copy-data3]
3029
tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-cumprod-data4]
3130
tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-cumsum-data5]
32-
tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-diff-data6]
3331
tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-ediff1d-data7]
3432
tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-fabs-data8]
3533
tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-floor-data9]
36-
tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-gradient-data10]
3734
tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-nancumprod-data11]
3835
tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-nancumsum-data12]
3936
tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-nanprod-data13]
@@ -554,7 +551,6 @@ tests/third_party/cupy/creation_tests/test_from_data.py::TestFromData::test_asar
554551
tests/third_party/cupy/creation_tests/test_from_data.py::TestFromData::test_ascontiguousarray_on_noncontiguous_array
555552
tests/third_party/cupy/creation_tests/test_from_data.py::TestFromData::test_asfortranarray_cuda_array_zero_dim
556553
tests/third_party/cupy/creation_tests/test_from_data.py::TestFromData::test_asfortranarray_cuda_array_zero_dim_dtype
557-
tests/third_party/cupy/creation_tests/test_from_data.py::TestFromData::test_fromfile
558554

559555
tests/third_party/cupy/creation_tests/test_ranges.py::TestMeshgrid_param_0_{copy=False, indexing='xy', sparse=False}::test_meshgrid0
560556
tests/third_party/cupy/creation_tests/test_ranges.py::TestMeshgrid_param_0_{copy=False, indexing='xy', sparse=False}::test_meshgrid1

tests/test_arraycreation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_frombuffer(dtype):
109109

110110

111111
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
112-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False))
112+
@pytest.mark.parametrize("dtype", get_all_dtypes())
113113
def test_fromfile(dtype):
114114
with tempfile.TemporaryFile() as fh:
115115
fh.write(b"\x00\x01\x02\x03\x04\x05\x06\x07\x08")
@@ -275,6 +275,7 @@ def test_tri_default_dtype():
275275
'[[1, 2], [3, 4]]',
276276
'[[0, 1, 2], [3, 4, 5], [6, 7, 8]]',
277277
'[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]'])
278+
# TODO: add fixture 'allow_fall_back_on_numpy' and remove operator.index()
278279
def test_tril(m, k):
279280
a = numpy.array(m)
280281
ia = dpnp.array(a)
@@ -295,6 +296,7 @@ def test_tril(m, k):
295296
'[[1, 2], [3, 4]]',
296297
'[[0, 1, 2], [3, 4, 5], [6, 7, 8]]',
297298
'[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]'])
299+
# TODO: add fixture 'allow_fall_back_on_numpy' and remove operator.index()
298300
def test_triu(m, k):
299301
a = numpy.array(m)
300302
ia = dpnp.array(a)

0 commit comments

Comments
 (0)