Skip to content

Commit df8eb5f

Browse files
Merge pull request #1328 from IntelPython/fix-some-array-api-test-cases
2 parents a3c00bc + 3c0aeed commit df8eb5f

File tree

2 files changed

+82
-81
lines changed

2 files changed

+82
-81
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 59 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -290,78 +290,6 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
290290
_copy_same_shape(dst, src_same_shape)
291291

292292

293-
def copy(usm_ary, order="K"):
    """copy(usm_ary, order="K")

    Creates a copy of given instance of :class:`dpctl.tensor.usm_ndarray`.

    Args:
        usm_ary (usm_ndarray):
            Input array.
        order ({"C", "F", "A", "K"}, optional):
            Controls the memory layout of the output array.
    Returns:
        usm_ndarray:
            A copy of the input array.
    Raises:
        TypeError:
            If `usm_ary` is not an instance of `dpt.usm_ndarray`.
        ValueError:
            If `order` is not one of "A", "C", "F", or "K".

    Memory layout of the copy is controlled by `order` keyword,
    following NumPy's conventions. The `order` keywords can be
    one of the following:

       - "C": C-contiguous memory layout
       - "F": Fortran-contiguous memory layout
       - "A": Fortran-contiguous if the input array is also Fortran-contiguous,
         otherwise C-contiguous
       - "K": match the layout of `usm_ary` as closely as possible.

    """
    if not isinstance(usm_ary, dpt.usm_ndarray):
        # The exception must be raised; returning it would hand the caller
        # an exception object in place of an array.
        raise TypeError(
            f"Expected object of type dpt.usm_ndarray, got {type(usm_ary)}"
        )
    copy_order = "C"
    if order == "C":
        pass
    elif order == "F":
        copy_order = order
    elif order == "A":
        if usm_ary.flags.f_contiguous:
            copy_order = "F"
    elif order == "K":
        if usm_ary.flags.f_contiguous:
            copy_order = "F"
    else:
        raise ValueError(
            "Unrecognized value of the order keyword. "
            "Recognized values are 'A', 'C', 'F', or 'K'"
        )
    c_contig = usm_ary.flags.c_contiguous
    f_contig = usm_ary.flags.f_contiguous
    # Allocate the destination on the same queue and with the same USM type
    # as the source.
    R = dpt.usm_ndarray(
        usm_ary.shape,
        dtype=usm_ary.dtype,
        buffer=usm_ary.usm_type,
        order=copy_order,
        buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
    )
    if order == "K" and (not c_contig and not f_contig):
        # Source is neither C- nor F-contiguous: re-view the freshly
        # allocated buffer with strides reordered after the source's
        # stride magnitudes (largest first).
        original_strides = usm_ary.strides
        ind = sorted(
            range(usm_ary.ndim),
            key=lambda i: abs(original_strides[i]),
            reverse=True,
        )
        # NOTE(review): this indexes `ind` by its own elements and reuses
        # strides that were computed for the unpermuted shape -- confirm it
        # yields a valid, in-bounds layout for every non-contiguous input.
        new_strides = tuple(R.strides[ind[i]] for i in ind)
        R = dpt.usm_ndarray(
            usm_ary.shape,
            dtype=usm_ary.dtype,
            buffer=R.usm_data,
            strides=new_strides,
        )
    _copy_same_shape(R, usm_ary)
    return R
363-
364-
365293
def _empty_like_orderK(X, dt, usm_type=None, dev=None):
366294
"""Returns empty array like `x`, using order='K'
367295
@@ -452,6 +380,65 @@ def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
452380
return dpt.permute_dims(R, inv_perm)
453381

454382

383+
def copy(usm_ary, order="K"):
    """copy(usm_ary, order="K")

    Creates a copy of given instance of :class:`dpctl.tensor.usm_ndarray`.

    Args:
        usm_ary (usm_ndarray):
            Input array.
        order ({"C", "F", "A", "K"}, optional):
            Controls the memory layout of the output array.
    Returns:
        usm_ndarray:
            A copy of the input array.
    Raises:
        TypeError:
            If `usm_ary` is not an instance of `dpt.usm_ndarray`.
        ValueError:
            If `order` is not one of "A", "C", "F", or "K".

    Memory layout of the copy is controlled by `order` keyword,
    following NumPy's conventions. The `order` keywords can be
    one of the following:

       - "C": C-contiguous memory layout
       - "F": Fortran-contiguous memory layout
       - "A": Fortran-contiguous if the input array is also Fortran-contiguous,
         otherwise C-contiguous
       - "K": match the layout of `usm_ary` as closely as possible.

    """
    if not isinstance(usm_ary, dpt.usm_ndarray):
        # The exception must be raised; returning it would hand the caller
        # an exception object in place of an array.
        raise TypeError(
            f"Expected object of type dpt.usm_ndarray, got {type(usm_ary)}"
        )
    if order == "C" or order == "F":
        copy_order = order
    elif order == "A":
        copy_order = "F" if usm_ary.flags.f_contiguous else "C"
    elif order == "K":
        # Not used: the "K" layout is chosen by _empty_like_orderK below.
        copy_order = None
    else:
        raise ValueError(
            "Unrecognized value of the order keyword. "
            "Recognized values are 'A', 'C', 'F', or 'K'"
        )
    if order == "K":
        R = _empty_like_orderK(usm_ary, usm_ary.dtype)
    else:
        # Allocate the destination on the same queue and with the same USM
        # type as the source, in the requested contiguous layout.
        R = dpt.usm_ndarray(
            usm_ary.shape,
            dtype=usm_ary.dtype,
            buffer=usm_ary.usm_type,
            order=copy_order,
            buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
        )
    _copy_same_shape(R, usm_ary)
    return R
440+
441+
455442
def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
456443
""" astype(array, new_dtype, order="K", casting="unsafe", \
457444
copy=True)

dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <CL/sycl.hpp>
2929
#include <cstddef>
3030
#include <cstdint>
31+
#include <limits>
3132
#include <type_traits>
3233

3334
#include "utils/offset_utils.hpp"
@@ -55,16 +56,12 @@ using dpctl::tensor::type_utils::vec_cast;
5556

5657
template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
5758
{
58-
using supports_sg_loadstore = typename std::negation<
59-
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
60-
using supports_vec = typename std::negation<
61-
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
59+
using supports_sg_loadstore = std::true_type;
60+
using supports_vec = std::true_type;
6261

6362
resT operator()(const argT1 &in1, const argT2 &in2)
6463
{
65-
resT max = std::max<resT>(in1, in2);
66-
resT min = std::min<resT>(in1, in2);
67-
return max + std::log1p(std::exp(min - max));
64+
return impl<resT>(in1, in2);
6865
}
6966

7067
template <int vec_sz>
@@ -76,12 +73,29 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
7673

7774
#pragma unroll
7875
for (int i = 0; i < vec_sz; ++i) {
79-
resT max = std::max<resT>(in1[i], in2[i]);
80-
res[i] = max + std::log1p(std::exp(std::abs(diff[i])));
76+
res[i] = impl<resT>(in1[i], in2[i]);
8177
}
8278

8379
return res;
8480
}
81+
82+
private:
83+
/// @brief Computes log(exp(in1) + exp(in2)) without overflowing for
///        large-magnitude arguments.
///
/// The sum is evaluated as max + log1p(exp(min - max)), so the exponent
/// passed to exp() is never positive.
///
/// @param in1 First operand.
/// @param in2 Second operand.
/// @return NaN if either operand is NaN; the larger operand if it is
///         infinite (+inf, or -inf when both operands are -inf);
///         otherwise log(exp(in1) + exp(in2)).
template <typename T> T impl(T const &in1, T const &in2)
{
    // Test both operands explicitly. std::max(a, b) returns `a` whenever
    // `a < b` is false, so a NaN in the second position would otherwise
    // go unnoticed and the result would be in1 + log(2).
    if (std::isnan(in1) || std::isnan(in2)) {
        return std::numeric_limits<T>::quiet_NaN();
    }

    T max = std::max<T>(in1, in2);
    if (std::isinf(max)) {
        // max == +inf: the sum is +inf.
        // max == -inf: both args are -inf, hence the result is -inf too.
        // Returning early also avoids evaluating inf - inf below.
        return max;
    }

    T min = std::min<T>(in1, in2);
    return max + std::log1p(std::exp(min - max));
}
8599
};
86100

87101
template <typename argT1,

0 commit comments

Comments
 (0)