Skip to content

Commit 8c33cbb

Browse files
committed
dpnp.divide() doesn't work properly with a scalar
1 parent 0b1345d commit 8c33cbb

File tree

13 files changed

+227
-135
lines changed

13 files changed

+227
-135
lines changed

dpnp/backend/include/dpnp_gen_2arg_3type_tbl.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,10 @@ MACRO_2ARG_3TYPES_OP(dpnp_copysign_c,
132132

133133
MACRO_2ARG_3TYPES_OP(dpnp_divide_c,
134134
input1_elem / input2_elem,
135-
nullptr,
136-
std::false_type,
135+
sycl::native::divide(x1, x2),
136+
MACRO_UNPACK_TYPES(float, double),
137137
oneapi::mkl::vm::div,
138-
MACRO_UNPACK_TYPES(float, double))
138+
MACRO_UNPACK_TYPES(float, double, std::complex<float>, std::complex<double>))
139139

140140
MACRO_2ARG_3TYPES_OP(dpnp_fmod_c,
141141
sycl::fmod((double)input1_elem, (double)input2_elem),

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,26 @@ size_t operator-(DPNPFuncType lhs, DPNPFuncType rhs);
419419
*/
420420
typedef struct DPNPFuncData
421421
{
422-
DPNPFuncType return_type; /**< return type identifier which expected by the @ref ptr function */
423-
void* ptr; /**< C++ backend function pointer */
422+
DPNPFuncData(const DPNPFuncType gen_type, void* gen_ptr, const DPNPFuncType type_no_fp64, void* ptr_no_fp64)
423+
: return_type(gen_type)
424+
, ptr(gen_ptr)
425+
, return_type_no_fp64(type_no_fp64)
426+
, ptr_no_fp64(ptr_no_fp64)
427+
{
428+
}
429+
DPNPFuncData(const DPNPFuncType gen_type, void* gen_ptr)
430+
: DPNPFuncData(gen_type, gen_ptr, DPNPFuncType::DPNP_FT_NONE, nullptr)
431+
{
432+
}
433+
DPNPFuncData()
434+
: DPNPFuncData(DPNPFuncType::DPNP_FT_NONE, nullptr)
435+
{
436+
}
437+
438+
DPNPFuncType return_type; /**< return type identifier which expected by the @ref ptr function */
439+
void* ptr; /**< C++ backend function pointer */
440+
DPNPFuncType return_type_no_fp64; /**< alternative return type identifier when no fp64 support by device */
441+
void* ptr_no_fp64; /**< alternative C++ backend function pointer when no fp64 support by device */
424442
} DPNPFuncData_t;
425443

426444
/**

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 51 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,6 +1178,47 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
11781178

11791179
#include <dpnp_gen_2arg_3type_tbl.hpp>
11801180

1181+
template <DPNPFuncType FT1, DPNPFuncType FT2, typename has_fp64 = std::true_type>
1182+
static constexpr DPNPFuncType get_divide_res_type()
1183+
{
1184+
constexpr auto widest_type = populate_func_types<FT1, FT2>();
1185+
constexpr auto shortes_type = (widest_type == FT1) ? FT2 : FT1;
1186+
1187+
if constexpr (widest_type == DPNPFuncType::DPNP_FT_CMPLX128 || widest_type == DPNPFuncType::DPNP_FT_DOUBLE)
1188+
{
1189+
return widest_type;
1190+
}
1191+
else if constexpr (widest_type == DPNPFuncType::DPNP_FT_CMPLX64)
1192+
{
1193+
if constexpr (shortes_type == DPNPFuncType::DPNP_FT_DOUBLE)
1194+
{
1195+
return DPNPFuncType::DPNP_FT_CMPLX128;
1196+
}
1197+
else if constexpr (has_fp64::value &&
1198+
(shortes_type == DPNPFuncType::DPNP_FT_INT || shortes_type == DPNPFuncType::DPNP_FT_LONG))
1199+
{
1200+
return DPNPFuncType::DPNP_FT_CMPLX128;
1201+
}
1202+
}
1203+
else if constexpr (widest_type == DPNPFuncType::DPNP_FT_FLOAT)
1204+
{
1205+
if constexpr (has_fp64::value &&
1206+
(shortes_type == DPNPFuncType::DPNP_FT_INT || shortes_type == DPNPFuncType::DPNP_FT_LONG))
1207+
{
1208+
return DPNPFuncType::DPNP_FT_DOUBLE;
1209+
}
1210+
}
1211+
else if constexpr (has_fp64::value)
1212+
{
1213+
return DPNPFuncType::DPNP_FT_DOUBLE;
1214+
}
1215+
else
1216+
{
1217+
return DPNPFuncType::DPNP_FT_FLOAT;
1218+
}
1219+
return widest_type;
1220+
}
1221+
11811222
template <DPNPFuncType FT1, DPNPFuncType... FTs>
11821223
static void func_map_elemwise_2arg_3type_core(func_map_t& fmap)
11831224
{
@@ -1199,6 +1240,16 @@ static void func_map_elemwise_2arg_3type_core(func_map_t& fmap)
11991240
func_type_map_t::find_type<FT1>,
12001241
func_type_map_t::find_type<FTs>>}),
12011242
...);
1243+
((fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][FT1][FTs] =
1244+
{get_divide_res_type<FT1, FTs>(),
1245+
(void*)dpnp_divide_c_ext<func_type_map_t::find_type<get_divide_res_type<FT1, FTs>()>,
1246+
func_type_map_t::find_type<FT1>,
1247+
func_type_map_t::find_type<FTs>>,
1248+
get_divide_res_type<FT1, FTs, std::false_type>(),
1249+
(void*)dpnp_divide_c_ext<func_type_map_t::find_type<get_divide_res_type<FT1, FTs, std::false_type>()>,
1250+
func_type_map_t::find_type<FT1>,
1251+
func_type_map_t::find_type<FTs>>}),
1252+
...);
12021253
}
12031254

12041255
template <DPNPFuncType... FTs>
@@ -1407,39 +1458,6 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
14071458
fmap[DPNPFuncName::DPNP_FN_DIVIDE][eft_DBL][eft_DBL] = {eft_DBL,
14081459
(void*)dpnp_divide_c_default<double, double, double>};
14091460

1410-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_INT] = {eft_DBL,
1411-
(void*)dpnp_divide_c_ext<double, int32_t, int32_t>};
1412-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_LNG] = {eft_DBL,
1413-
(void*)dpnp_divide_c_ext<double, int32_t, int64_t>};
1414-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_FLT] = {eft_DBL,
1415-
(void*)dpnp_divide_c_ext<double, int32_t, float>};
1416-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_DBL] = {eft_DBL,
1417-
(void*)dpnp_divide_c_ext<double, int32_t, double>};
1418-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_INT] = {eft_DBL,
1419-
(void*)dpnp_divide_c_ext<double, int64_t, int32_t>};
1420-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_LNG] = {eft_DBL,
1421-
(void*)dpnp_divide_c_ext<double, int64_t, int64_t>};
1422-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_FLT] = {eft_DBL,
1423-
(void*)dpnp_divide_c_ext<double, int64_t, float>};
1424-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_DBL] = {eft_DBL,
1425-
(void*)dpnp_divide_c_ext<double, int64_t, double>};
1426-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_INT] = {eft_DBL,
1427-
(void*)dpnp_divide_c_ext<double, float, int32_t>};
1428-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_LNG] = {eft_DBL,
1429-
(void*)dpnp_divide_c_ext<double, float, int64_t>};
1430-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_FLT] = {eft_FLT,
1431-
(void*)dpnp_divide_c_ext<float, float, float>};
1432-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_DBL] = {eft_DBL,
1433-
(void*)dpnp_divide_c_ext<double, float, double>};
1434-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_INT] = {eft_DBL,
1435-
(void*)dpnp_divide_c_ext<double, double, int32_t>};
1436-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_LNG] = {eft_DBL,
1437-
(void*)dpnp_divide_c_ext<double, double, int64_t>};
1438-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_FLT] = {eft_DBL,
1439-
(void*)dpnp_divide_c_ext<double, double, float>};
1440-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_DBL] = {eft_DBL,
1441-
(void*)dpnp_divide_c_ext<double, double, double>};
1442-
14431461
fmap[DPNPFuncName::DPNP_FN_FMOD][eft_INT][eft_INT] = {eft_INT,
14441462
(void*)dpnp_fmod_c_default<int32_t, int32_t, int32_t>};
14451463
fmap[DPNPFuncName::DPNP_FN_FMOD][eft_INT][eft_LNG] = {eft_LNG,

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,8 @@ cdef extern from "dpnp_iface_fptr.hpp":
374374
struct DPNPFuncData:
375375
DPNPFuncType return_type
376376
void * ptr
377+
DPNPFuncType return_type_no_fp64
378+
void *ptr_no_fp64
377379

378380
DPNPFuncData get_dpnp_function_ptr(DPNPFuncName name, DPNPFuncType first_type, DPNPFuncType second_type) except +
379381

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -481,8 +481,6 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out_strides(DPNPFuncName fptr_name,
481481
# get the FPTR data structure
482482
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(fptr_name, x1_c_type, x2_c_type)
483483

484-
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
485-
486484
# Create result array
487485
cdef shape_type_c x1_shape = x1_obj.shape
488486

@@ -495,6 +493,15 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out_strides(DPNPFuncName fptr_name,
495493

496494
result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(x1_obj, x2_obj)
497495

496+
# get FPTR function and result type
497+
cdef fptr_2in_1out_strides_t func = NULL
498+
if fptr_name != DPNP_FN_DIVIDE_EXT or result_sycl_device.has_aspect_fp64:
499+
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
500+
func = < fptr_2in_1out_strides_t > kernel_data.ptr
501+
else:
502+
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type_no_fp64)
503+
func = < fptr_2in_1out_strides_t > kernel_data.ptr_no_fp64
504+
498505
if out is None:
499506
""" Create result array with type given by FPTR data """
500507
result = utils.create_output_descriptor(result_shape,
@@ -517,11 +524,10 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out_strides(DPNPFuncName fptr_name,
517524

518525
result_obj = result.get_array()
519526

520-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_obj.sycl_queue
527+
cdef c_dpctl.SyclQueue q = < c_dpctl.SyclQueue > result_obj.sycl_queue
521528
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
522529

523530
""" Call FPTR function """
524-
cdef fptr_2in_1out_strides_t func = <fptr_2in_1out_strides_t > kernel_data.ptr
525531
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
526532
result.get_data(),
527533
result.size,

dpnp/dpnp_iface_mathematical.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -544,55 +544,64 @@ def diff(x1, n=1, axis=-1, prepend=numpy._NoValue, append=numpy._NoValue):
544544
return call_origin(numpy.diff, x1, n=n, axis=axis, prepend=prepend, append=append)
545545

546546

547-
def divide(x1, x2, dtype=None, out=None, where=True, **kwargs):
547+
def divide(x1,
548+
x2,
549+
/,
550+
out=None,
551+
*,
552+
where=True,
553+
dtype=None,
554+
subok=True,
555+
**kwargs):
548556
"""
549557
Divide arguments element-wise.
550558
551559
For full documentation refer to :obj:`numpy.divide`.
552560
561+
Returns
562+
-------
563+
y : dpnp.ndarray
564+
The quotient ``x1/x2``, element-wise.
565+
553566
Limitations
554567
-----------
555-
Parameters ``x1`` and ``x2`` are supported as either :obj:`dpnp.ndarray` or scalar.
556-
Parameters ``dtype``, ``out`` and ``where`` are supported with their default values.
568+
Parameters `x1` and `x2` are supported as either :class:`dpnp.ndarray` or scalar,
569+
but not both (at least either `x1` or `x2` should be as :class:`dpnp.ndarray`).
570+
Parameters `out`, `where`, `dtype` and `subok` are supported with their default values.
557571
Keyword arguments ``kwargs`` are currently unsupported.
558-
Otherwise the functions will be executed sequentially on CPU.
572+
Otherwise the function will be executed sequentially on CPU.
559573
Input array data types are limited by supported DPNP :ref:`Data types`.
560574
561575
Examples
562576
--------
563-
>>> import dpnp as np
564-
>>> result = np.divide(np.array([1, -2, 6, -9]), np.array([-2, -2, -2, -2]))
565-
>>> [x for x in result]
577+
>>> import dpnp as dp
578+
>>> result = dp.divide(dp.array([1, -2, 6, -9]), dp.array([-2, -2, -2, -2]))
579+
>>> print(result)
566580
[-0.5, 1.0, -3.0, 4.5]
567581
568582
"""
569583

570-
x1_is_scalar = dpnp.isscalar(x1)
571-
x2_is_scalar = dpnp.isscalar(x2)
572-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False)
573-
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False)
584+
if out is not None:
585+
pass
586+
elif where is not True:
587+
pass
588+
elif dtype is not None:
589+
pass
590+
elif subok is not True:
591+
pass
592+
elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
593+
# at least either x1 or x2 has to be an array
594+
pass
595+
else:
596+
# get a common queue to copy data from the host into a device if any input is scalar
597+
queue = get_common_allocation_queue([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else None
574598

575-
if x1_desc and x2_desc and not kwargs:
576-
if not x1_desc and not x1_is_scalar:
577-
pass
578-
elif not x2_desc and not x2_is_scalar:
579-
pass
580-
elif x1_is_scalar and x2_is_scalar:
581-
pass
582-
elif x1_desc and x1_desc.ndim == 0:
583-
pass
584-
elif x2_desc and x2_desc.ndim == 0:
585-
pass
586-
elif dtype is not None:
587-
pass
588-
elif out is not None:
589-
pass
590-
elif not where:
591-
pass
592-
else:
599+
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False, alloc_queue=queue)
600+
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False, alloc_queue=queue)
601+
if x1_desc and x2_desc:
593602
return dpnp_divide(x1_desc, x2_desc, dtype=dtype, out=out, where=where).get_pyobj()
594603

595-
return call_origin(numpy.divide, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
604+
return call_origin(numpy.divide, x1, x2, out=out, where=where, dtype=dtype, subok=subok, **kwargs)
596605

597606

598607
def ediff1d(x1, to_end=None, to_begin=None):

tests/conftest.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# *****************************************************************************
3-
# Copyright (c) 2016-2020, Intel Corporation
3+
# Copyright (c) 2016-2023, Intel Corporation
44
# All rights reserved.
55
#
66
# Redistribution and use in source and binary forms, with or without
@@ -77,3 +77,22 @@ def pytest_collection_modifyitems(config, items):
7777
@pytest.fixture
7878
def allow_fall_back_on_numpy(monkeypatch):
7979
monkeypatch.setattr(dpnp.config, '__DPNP_RAISE_EXCEPION_ON_NUMPY_FALLBACK__', 0)
80+
81+
@pytest.fixture
82+
def suppress_divide_numpy_warnings():
83+
# divide: treatment for division by zero (infinite result obtained from finite numbers)
84+
old_settings = numpy.seterr(divide='ignore')
85+
yield
86+
numpy.seterr(**old_settings) # reset to default
87+
88+
@pytest.fixture
89+
def suppress_invalid_numpy_warnings():
90+
# invalid: treatment for invalid floating-point operation
91+
# (result is not an expressible number, typically indicates that a NaN was produced)
92+
old_settings = numpy.seterr(invalid='ignore')
93+
yield
94+
numpy.seterr(**old_settings) # reset to default
95+
96+
@pytest.fixture
97+
def suppress_divide_invalid_numpy_warnings(suppress_divide_numpy_warnings, suppress_invalid_numpy_warnings):
98+
yield

tests/helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def get_all_dtypes(no_bool=False,
3232
dtypes.append(dpnp.complex64)
3333
if dev.has_aspect_fp64:
3434
dtypes.append(dpnp.complex128)
35-
35+
3636
# add None value to validate a default dtype
3737
if not no_none:
3838
dtypes.append(None)

0 commit comments

Comments
 (0)