Skip to content

Commit fdec1ee

Browse files
Cleanups and fixes
1 parent 2715670 commit fdec1ee

File tree

3 files changed

+53
-97
lines changed

3 files changed

+53
-97
lines changed

dpnp/backend/extensions/sycl_ext/dispatcher_utils.hpp

Lines changed: 27 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,30 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
128
#include <array>
229
#include <tuple>
330

@@ -387,86 +414,6 @@ struct UsmArrayMatcher
387414
}
388415
};
389416

390-
// int typenum_to_lookup_id(int typenum) const
391-
// {
392-
// using typenum_t = ::dpctl::tensor::type_dispatch::typenum_t;
393-
// auto const &api = ::dpctl::detail::dpctl_capi::get();
394-
395-
// if (typenum == api.UAR_DOUBLE_) {
396-
// return static_cast<int>(typenum_t::DOUBLE);
397-
// }
398-
// else if (typenum == api.UAR_INT64_) {
399-
// return static_cast<int>(typenum_t::INT64);
400-
// }
401-
// else if (typenum == api.UAR_INT32_) {
402-
// return static_cast<int>(typenum_t::INT32);
403-
// }
404-
// else if (typenum == api.UAR_BOOL_) {
405-
// return static_cast<int>(typenum_t::BOOL);
406-
// }
407-
// else if (typenum == api.UAR_CDOUBLE_) {
408-
// return static_cast<int>(typenum_t::CDOUBLE);
409-
// }
410-
// else if (typenum == api.UAR_FLOAT_) {
411-
// return static_cast<int>(typenum_t::FLOAT);
412-
// }
413-
// else if (typenum == api.UAR_INT16_) {
414-
// return static_cast<int>(typenum_t::INT16);
415-
// }
416-
// else if (typenum == api.UAR_INT8_) {
417-
// return static_cast<int>(typenum_t::INT8);
418-
// }
419-
// else if (typenum == api.UAR_UINT64_) {
420-
// return static_cast<int>(typenum_t::UINT64);
421-
// }
422-
// else if (typenum == api.UAR_UINT32_) {
423-
// return static_cast<int>(typenum_t::UINT32);
424-
// }
425-
// else if (typenum == api.UAR_UINT16_) {
426-
// return static_cast<int>(typenum_t::UINT16);
427-
// }
428-
// else if (typenum == api.UAR_UINT8_) {
429-
// return static_cast<int>(typenum_t::UINT8);
430-
// }
431-
// else if (typenum == api.UAR_CFLOAT_) {
432-
// return static_cast<int>(typenum_t::CFLOAT);
433-
// }
434-
// else if (typenum == api.UAR_HALF_) {
435-
// return static_cast<int>(typenum_t::HALF);
436-
// }
437-
// else if (typenum == api.UAR_INT_ || typenum == api.UAR_UINT_) {
438-
// switch (sizeof(int)) {
439-
// case sizeof(std::int32_t):
440-
// return ((typenum == api.UAR_INT_)
441-
// ? static_cast<int>(typenum_t::INT32)
442-
// : static_cast<int>(typenum_t::UINT32));
443-
// case sizeof(std::int64_t):
444-
// return ((typenum == api.UAR_INT_)
445-
// ? static_cast<int>(typenum_t::INT64)
446-
// : static_cast<int>(typenum_t::UINT64));
447-
// default:
448-
// throw_unrecognized_typenum_error(typenum);
449-
// }
450-
// }
451-
// else if (typenum == api.UAR_LONGLONG_ || typenum == api.UAR_ULONGLONG_)
452-
// {
453-
// switch (sizeof(long long)) {
454-
// case sizeof(std::int64_t):
455-
// return ((typenum == api.UAR_LONGLONG_)
456-
// ? static_cast<int>(typenum_t::INT64)
457-
// : static_cast<int>(typenum_t::UINT64));
458-
// default:
459-
// throw_unrecognized_typenum_error(typenum);
460-
// }
461-
// }
462-
// else {
463-
// throw_unrecognized_typenum_error(typenum);
464-
// }
465-
// // return code signalling error, should never be reached
466-
// assert(false);
467-
// return -1;
468-
// }
469-
470417
template <>
471418
bool UsmArrayMatcher::match<int8_t>(const dpctl::tensor::usm_ndarray &arr)
472419
{

dpnp/backend/extensions/sycl_ext/sum_mean.hpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,23 @@ bool check_limitations(const dpctl::tensor::usm_ndarray &in,
5656
return false;
5757
}
5858

59-
// auto local_mem_size = in.get_sycl_device().local_mem_size();
60-
// auto out_full_size = out.get_size()*out.get_elemsize();
61-
// if (out_full_size > local_mem_size)
62-
// {
63-
// if (throw_on_fail)
64-
// throw py::value_error("Resulting array exceeds local memroy size
65-
// " + std::to_string(local_mem_size));
66-
67-
// return false;
68-
// }
59+
auto device = in.get_queue().get_device();
60+
auto local_mem_size = device.get_info<sycl::info::device::local_mem_size>();
61+
size_t out_full_size = out.get_size() * out.get_elemsize();
62+
if (out_full_size > local_mem_size) {
63+
if (throw_on_fail)
64+
throw py::value_error("Resulting array exceeds local memroy size" +
65+
std::to_string(local_mem_size));
66+
67+
return false;
68+
}
69+
70+
if (out.get_elemsize() == 64 and not device.has(sycl::aspect::atomic64)) {
71+
if (throw_on_fail)
72+
throw py::value_error("64-bit atomics are not supported");
73+
74+
return false;
75+
}
6976

7077
return true;
7178
}

dpnp/dpnp_iface_mathematical.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
import dpctl.tensor as dpt
4444
import numpy
45+
from numpy.core.numeric import normalize_axis_tuple
4546

4647
import dpnp
4748
from dpnp.dpnp_array import dpnp_array
@@ -1814,12 +1815,11 @@ def sum(
18141815
18151816
"""
18161817

1817-
if not isinstance(axis, (tuple, list)):
1818-
axis = (axis,)
1818+
if axis is not None:
1819+
if not isinstance(axis, (tuple, list)):
1820+
axis = (axis,)
18191821

1820-
from numpy.core.numeric import normalize_axis_tuple
1821-
1822-
axis = normalize_axis_tuple(axis, x.ndim, "axis")
1822+
axis = normalize_axis_tuple(axis, x.ndim, "axis")
18231823

18241824
if out is not None:
18251825
pass
@@ -1828,7 +1828,7 @@ def sum(
18281828
elif where is not True:
18291829
pass
18301830
else:
1831-
if (axis == (0,) and len(x.shape) == 2):
1831+
if axis == (0,) and len(x.shape) == 2:
18321832
from dpctl.tensor._reduction import _default_reduction_dtype
18331833

18341834
from dpnp.backend.extensions.sycl_ext import _sycl_ext_impl
@@ -1837,7 +1837,9 @@ def sum(
18371837

18381838
queue = input.sycl_queue
18391839
out_dtype = _default_reduction_dtype(input.dtype, queue)
1840-
output = dpt.empty(input.shape[1], dtype=out_dtype, sycl_queue=queue)
1840+
output = dpt.empty(
1841+
input.shape[1], dtype=out_dtype, sycl_queue=queue
1842+
)
18411843

18421844
get_sum = _sycl_ext_impl._get_sum_over_axis_0
18431845
sum = get_sum(input, output)

0 commit comments

Comments
 (0)