Skip to content

Commit 358c67a

Browse files
committed
Merge branch 'fix_where_operator' of https://github.com/vlad-perevezentsev/dpnp into fix_where_operator
2 parents 3926f89 + 16cce9b commit 358c67a

19 files changed

+444
-289
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ env:
1212
CHANNELS: '-c dppy/label/dev -c intel -c main --override-channels'
1313
TEST_SCOPE: >-
1414
test_arraycreation.py
15+
test_dot.py
1516
test_dparray.py
1617
test_fft.py
1718
test_linalg.py

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ enum class DPNPFuncName : size_t
377377
DPNP_FN_VANDER_EXT, /**< Used in numpy.vander() impl, requires extra parameters */
378378
DPNP_FN_VAR, /**< Used in numpy.var() impl */
379379
DPNP_FN_VAR_EXT, /**< Used in numpy.var() impl, requires extra parameters */
380-
DPNP_FN_WHERE_EXT, /**< Used in numpy.var() impl, requires extra parameters */
380+
DPNP_FN_WHERE_EXT, /**< Used in numpy.where() impl, requires extra parameters */
381381
DPNP_FN_ZEROS, /**< Used in numpy.zeros() impl */
382382
DPNP_FN_ZEROS_LIKE, /**< Used in numpy.zeros_like() impl */
383383
DPNP_FN_LAST, /**< The latest element of the enumeration */

dpnp/backend/kernels/dpnp_krnl_bitwise.cpp

Lines changed: 52 additions & 45 deletions
Large diffs are not rendered by default.

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 77 additions & 54 deletions
Large diffs are not rendered by default.

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//*****************************************************************************
2-
// Copyright (c) 2016-2020, Intel Corporation
2+
// Copyright (c) 2016-2023, Intel Corporation
33
// All rights reserved.
44
//
55
// Redistribution and use in source and binary forms, with or without
@@ -114,10 +114,10 @@ DPCTLSyclEventRef (*dpnp_around_ext_c)(DPCTLSyclQueueRef,
114114
const int,
115115
const DPCTLEventVectorRef) = dpnp_around_c<_DataType>;
116116

117-
template <typename _KernelNameSpecialization>
117+
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2>
118118
class dpnp_elemwise_absolute_c_kernel;
119119

120-
template <typename _DataType>
120+
template <typename _DataType_input, typename _DataType_output>
121121
DPCTLSyclEventRef dpnp_elemwise_absolute_c(DPCTLSyclQueueRef q_ref,
122122
const void* input1_in,
123123
void* result1,
@@ -137,43 +137,63 @@ DPCTLSyclEventRef dpnp_elemwise_absolute_c(DPCTLSyclQueueRef q_ref,
137137
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
138138
sycl::event event;
139139

140-
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, input1_in, size);
141-
_DataType* array1 = input1_ptr.get_ptr();
142-
DPNPC_ptr_adapter<_DataType> result1_ptr(q_ref, result1, size, false, true);
143-
_DataType* result = result1_ptr.get_ptr();
140+
_DataType_input* array1 = static_cast<_DataType_input*>(const_cast<void*>(input1_in));
141+
_DataType_output* result = static_cast<_DataType_output*>(result1);
144142

145-
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
143+
if constexpr (is_any_v<_DataType_input, float, double, std::complex<float>, std::complex<double>>)
146144
{
147-
// https://docs.oneapi.com/versions/latest/onemkl/abs.html
148145
event = oneapi::mkl::vm::abs(q, size, array1, result);
149146
}
150147
else
151148
{
152-
sycl::range<1> gws(size);
153-
auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
154-
const size_t idx = global_id[0];
149+
static_assert(is_any_v<_DataType_input, int32_t, int64_t>,
150+
"Integer types are only expected to pass in 'abs' kernel");
151+
static_assert(std::is_same_v<_DataType_input, _DataType_output>, "Result type must match a type of input data");
152+
153+
constexpr size_t lws = 64;
154+
constexpr unsigned int vec_sz = 8;
155+
constexpr sycl::access::address_space global_space = sycl::access::address_space::global_space;
156+
157+
auto gws_range = sycl::range<1>(((size + lws * vec_sz - 1) / (lws * vec_sz)) * lws);
158+
auto lws_range = sycl::range<1>(lws);
155159

156-
if (array1[idx] >= 0)
160+
auto kernel_parallel_for_func = [=](sycl::nd_item<1> nd_it) {
161+
auto sg = nd_it.get_sub_group();
162+
const auto max_sg_size = sg.get_max_local_range()[0];
163+
const size_t start =
164+
vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) + sg.get_group_id()[0] * max_sg_size);
165+
166+
if (start + static_cast<size_t>(vec_sz) * max_sg_size < size)
157167
{
158-
result[idx] = array1[idx];
168+
using input_ptrT = sycl::multi_ptr<_DataType_input, global_space>;
169+
using result_ptrT = sycl::multi_ptr<_DataType_output, global_space>;
170+
171+
sycl::vec<_DataType_input, vec_sz> data_vec = sg.load<vec_sz>(input_ptrT(&array1[start]));
172+
173+
// sycl::abs() returns unsigned integers only, so explicit casting to signed ones is required
174+
using result_absT = typename cl::sycl::detail::make_unsigned<_DataType_output>::type;
175+
sycl::vec<_DataType_output, vec_sz> res_vec =
176+
dpnp_vec_cast<_DataType_output, result_absT, vec_sz>(sycl::abs(data_vec));
177+
178+
sg.store<vec_sz>(result_ptrT(&result[start]), res_vec);
159179
}
160180
else
161181
{
162-
result[idx] = -1 * array1[idx];
182+
for (size_t k = start + sg.get_local_id()[0]; k < size; k += max_sg_size)
183+
{
184+
result[k] = std::abs(array1[k]);
185+
}
163186
}
164187
};
165188

166189
auto kernel_func = [&](sycl::handler& cgh) {
167-
cgh.parallel_for<class dpnp_elemwise_absolute_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
190+
cgh.parallel_for<class dpnp_elemwise_absolute_c_kernel<_DataType_input, _DataType_output>>(
191+
sycl::nd_range<1>(gws_range, lws_range), kernel_parallel_for_func);
168192
};
169-
170193
event = q.submit(kernel_func);
171194
}
172195

173-
input1_ptr.depends_on(event);
174-
result1_ptr.depends_on(event);
175196
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
176-
177197
return DPCTLEvent_Copy(event_ref);
178198
}
179199

@@ -182,28 +202,24 @@ void dpnp_elemwise_absolute_c(const void* input1_in, void* result1, size_t size)
182202
{
183203
DPCTLSyclQueueRef q_ref = reinterpret_cast<DPCTLSyclQueueRef>(&DPNP_QUEUE);
184204
DPCTLEventVectorRef dep_event_vec_ref = nullptr;
185-
DPCTLSyclEventRef event_ref = dpnp_elemwise_absolute_c<_DataType>(q_ref,
186-
input1_in,
187-
result1,
188-
size,
189-
dep_event_vec_ref);
205+
DPCTLSyclEventRef event_ref = dpnp_elemwise_absolute_c<_DataType, _DataType>(q_ref,
206+
input1_in,
207+
result1,
208+
size,
209+
dep_event_vec_ref);
190210
DPCTLEvent_WaitAndThrow(event_ref);
211+
DPCTLEvent_Delete(event_ref);
191212
}
192213

193214
template <typename _DataType>
194215
void (*dpnp_elemwise_absolute_default_c)(const void*, void*, size_t) = dpnp_elemwise_absolute_c<_DataType>;
195216

196-
template <typename _DataType>
217+
template <typename _DataType_input, typename _DataType_output = _DataType_input>
197218
DPCTLSyclEventRef (*dpnp_elemwise_absolute_ext_c)(DPCTLSyclQueueRef,
198219
const void*,
199220
void*,
200221
size_t,
201-
const DPCTLEventVectorRef) = dpnp_elemwise_absolute_c<_DataType>;
202-
203-
// template void dpnp_elemwise_absolute_c<double>(void* array1_in, void* result1, size_t size);
204-
// template void dpnp_elemwise_absolute_c<float>(void* array1_in, void* result1, size_t size);
205-
// template void dpnp_elemwise_absolute_c<long>(void* array1_in, void* result1, size_t size);
206-
// template void dpnp_elemwise_absolute_c<int>(void* array1_in, void* result1, size_t size);
222+
const DPCTLEventVectorRef) = dpnp_elemwise_absolute_c<_DataType_input, _DataType_output>;
207223

208224
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
209225
DPCTLSyclEventRef dpnp_cross_c(DPCTLSyclQueueRef q_ref,
@@ -1085,10 +1101,12 @@ void func_map_init_mathematical(func_map_t& fmap)
10851101
(void*)dpnp_elemwise_absolute_ext_c<int32_t>};
10861102
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_LNG][eft_LNG] = {eft_LNG,
10871103
(void*)dpnp_elemwise_absolute_ext_c<int64_t>};
1088-
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_FLT][eft_FLT] = {eft_FLT,
1089-
(void*)dpnp_elemwise_absolute_ext_c<float>};
1090-
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_DBL][eft_DBL] = {eft_DBL,
1091-
(void*)dpnp_elemwise_absolute_ext_c<double>};
1104+
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_elemwise_absolute_ext_c<float>};
1105+
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_elemwise_absolute_ext_c<double>};
1106+
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_C64][eft_C64] = {
1107+
eft_FLT, (void*)dpnp_elemwise_absolute_ext_c<std::complex<float>, float>};
1108+
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_C128][eft_C128] = {
1109+
eft_DBL, (void*)dpnp_elemwise_absolute_ext_c<std::complex<double>, double>};
10921110

10931111
fmap[DPNPFuncName::DPNP_FN_AROUND][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_around_default_c<int32_t>};
10941112
fmap[DPNPFuncName::DPNP_FN_AROUND][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_around_default_c<int64_t>};

dpnp/backend/src/dpnp_fptr.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,12 @@ struct is_any : std::disjunction<std::is_same<T, Ts>...> {};
163163
template <typename T, typename... Ts>
164164
struct are_same : std::conjunction<std::is_same<T, Ts>...> {};
165165

166+
/**
167+
* A template constant to check if type T matces any type from Ts.
168+
*/
169+
template <typename T, typename... Ts>
170+
constexpr auto is_any_v = is_any<T, Ts...>::value;
171+
166172
/**
167173
* A template constat to check if both types T1 and T2 match every type from Ts sequence.
168174
*/

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pyx

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# cython: language_level=3
22
# -*- coding: utf-8 -*-
33
# *****************************************************************************
4-
# Copyright (c) 2016-2020, Intel Corporation
4+
# Copyright (c) 2016-2023, Intel Corporation
55
# All rights reserved.
66
#
77
# Redistribution and use in source and binary forms, with or without
@@ -65,8 +65,9 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_2in_1out_matmul_t)(c_dpctl.DPCTLSyclQue
6565
const shape_elem_type *, const shape_elem_type * ,
6666
const c_dpctl.DPCTLEventVectorRef)
6767

68-
cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp_descriptor in_array2):
69-
68+
cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1,
69+
utils.dpnp_descriptor in_array2,
70+
utils.dpnp_descriptor out=None):
7071
cdef shape_type_c shape1, shape2
7172

7273
shape1 = in_array1.shape
@@ -78,6 +79,7 @@ cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp
7879

7980
# get the FPTR data structure
8081
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_DOT_EXT, param1_type, param2_type)
82+
cdef utils.dpnp_descriptor result
8183

8284
ndim1 = in_array1.ndim
8385
ndim2 = in_array2.ndim
@@ -89,7 +91,7 @@ cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp
8991
elif ndim1 == 1 and ndim2 == 1:
9092
result_shape = ()
9193
elif ndim1 == 1: # ndim2 > 1
92-
result_shape = shape2[:-1]
94+
result_shape = shape2[::-2] if ndim2 == 2 else shape2[::2]
9395
elif ndim2 == 1: # ndim1 > 1
9496
result_shape = shape1[:-1]
9597
else:
@@ -101,13 +103,24 @@ cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp
101103

102104
result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(in_array1, in_array2)
103105

104-
# create result array with type given by FPTR data
105-
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape,
106-
kernel_data.return_type,
107-
None,
108-
device=result_sycl_device,
109-
usm_type=result_usm_type,
110-
sycl_queue=result_sycl_queue)
106+
if out is None:
107+
# create result array with type given by FPTR data
108+
result = utils.create_output_descriptor(result_shape,
109+
kernel_data.return_type,
110+
None,
111+
device=result_sycl_device,
112+
usm_type=result_usm_type,
113+
sycl_queue=result_sycl_queue)
114+
else:
115+
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
116+
if out.dtype != result_type:
117+
utils.checker_throw_value_error('dot', 'out.dtype', out.dtype, result_type)
118+
if out.shape != result_shape:
119+
utils.checker_throw_value_error('dot', 'out.shape', out.shape, result_shape)
120+
121+
result = out
122+
123+
utils.get_common_usm_allocation(in_array1, result) # check USM allocation is common
111124

112125
cdef shape_type_c result_strides = utils.strides_to_vector(result.strides, result.shape)
113126
cdef shape_type_c in_array1_shape = in_array1.shape

dpnp/dpnp_array.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,8 @@ def diagonal(input, offset=0, axis1=0, axis2=1):
592592

593593
return dpnp.diagonal(input, offset, axis1, axis2)
594594

595-
# 'dot',
595+
def dot(self, other, out=None):
596+
return dpnp.dot(self, other, out)
596597

597598
@property
598599
def dtype(self):

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@
4444
from dpnp.dpnp_algo import *
4545
from dpnp.dpnp_utils import *
4646
import dpnp
47-
import dpnp.config as config
4847

4948
import numpy
49+
import dpctl.tensor as dpt
5050

5151

5252
__all__ = [
@@ -62,18 +62,25 @@
6262
]
6363

6464

65-
def dot(x1, x2, **kwargs):
65+
def dot(x1, x2, out=None, **kwargs):
6666
"""
67-
Returns the dot product of `x1` and `x2`.
67+
Dot product of `x1` and `x2`.
6868
6969
For full documentation refer to :obj:`numpy.dot`.
7070
71+
Returns
72+
-------
73+
y : dpnp.ndarray
74+
Returns the dot product of `x1` and `x2`.
75+
If `out` is given, then it is returned.
76+
7177
Limitations
7278
-----------
73-
Parameters ``x1`` and ``x2`` are supported as :obj:`dpnp.ndarray` of the same type.
74-
Keyword arguments ``kwargs`` are currently unsupported.
75-
Otherwise the functions will be executed sequentially on CPU.
76-
Input array data types are limited by supported DPNP :ref:`Data types`.
79+
Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
80+
or :class:`dpctl.tensor.usm_ndarray`, but both `x1` and `x2` can not be scalars at the same time.
81+
Keyword argument ``kwargs`` is currently unsupported.
82+
Otherwise the functions will be executed sequentially on CPU.
83+
Input array data types are limited by supported DPNP :ref:`Data types`.
7784
7885
See Also
7986
--------
@@ -82,31 +89,37 @@ def dot(x1, x2, **kwargs):
8289
8390
Examples
8491
--------
85-
>>> import dpnp as np
86-
>>> np.dot(3, 4)
87-
12
88-
>>> a = np.array([1, 2, 3])
89-
>>> b = np.array([1, 2, 3])
90-
>>> np.dot(a, b)
92+
>>> import dpnp as dp
93+
>>> a = dp.array([1, 2, 3])
94+
>>> b = dp.array([1, 2, 3])
95+
>>> dp.dot(a, b)
9196
14
9297
9398
"""
9499

95-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False)
96-
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False)
97-
if x1_desc and x2_desc and not kwargs:
98-
# TODO: remove fallback with scalars when muliply backend func will support strides
99-
if(x1_desc.ndim == 0 and x2_desc.strides is not None
100-
or x2_desc.ndim == 0 and x1_desc.strides is not None):
101-
pass
102-
elif (x1_desc.ndim >= 1 and x2_desc.ndim > 1 and x1_desc.shape[-1] != x2_desc.shape[-2]):
103-
pass
104-
elif (x1_desc.ndim > 0 and x2_desc.ndim == 1 and x1_desc.shape[-1] != x2_desc.shape[0]):
105-
pass
106-
else:
107-
return dpnp_dot(x1_desc, x2_desc).get_pyobj()
100+
if kwargs:
101+
pass
102+
elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
103+
# at least either x1 or x2 has to be an array
104+
pass
105+
else:
106+
# get USM type and queue to copy scalar from the host memory into a USM allocation
107+
usm_type, queue = get_usm_allocations([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else (None, None)
108+
109+
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False,
110+
alloc_usm_type=usm_type, alloc_queue=queue)
111+
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False,
112+
alloc_usm_type=usm_type, alloc_queue=queue)
113+
if x1_desc and x2_desc:
114+
if out is not None:
115+
if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
116+
raise TypeError("return array must be of supported array type")
117+
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False)
118+
else:
119+
out_desc = None
120+
return dpnp_dot(x1_desc, x2_desc, out=out_desc).get_pyobj()
108121

109-
return call_origin(numpy.dot, x1, x2, **kwargs)
122+
return call_origin(numpy.dot, x1, x2, out=out, **kwargs)
110123

111124

112125
def einsum(*args, **kwargs):

0 commit comments

Comments
 (0)