Skip to content

Commit 421b270

Browse files
Used new native extension modules
1 parent f74eae0 commit 421b270

File tree

4 files changed

+20
-17
lines changed

4 files changed

+20
-17
lines changed

dpctl/tensor/_clip.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import dpctl
1818
import dpctl.tensor as dpt
19+
import dpctl.tensor._tensor_elementwise_impl as tei
1920
import dpctl.tensor._tensor_impl as ti
2021
from dpctl.tensor._copy_utils import (
2122
_empty_like_orderK,
@@ -429,9 +430,9 @@ def clip(x, min=None, max=None, out=None, order="K"):
429430
"only one of `min` and `max` is permitted to be `None`"
430431
)
431432
elif max is None:
432-
return _clip_none(x, min, out, order, ti._maximum)
433+
return _clip_none(x, min, out, order, tei._maximum)
433434
elif min is None:
434-
return _clip_none(x, max, out, order, ti._minimum)
435+
return _clip_none(x, max, out, order, tei._minimum)
435436
else:
436437
q1, x_usm_type = x.sycl_queue, x.usm_type
437438
q2, min_usm_type = _get_queue_usm_type(min)

dpctl/tensor/_elementwise_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
import dpctl.tensor._tensor_impl as ti
17+
import dpctl.tensor._tensor_elementwise_impl as ti
1818

1919
from ._elementwise_common import BinaryElementwiseFunc, UnaryElementwiseFunc
2020
from ._type_utils import _acceptance_fn_divide

dpctl/tensor/_reduction.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import dpctl
2020
import dpctl.tensor as dpt
2121
import dpctl.tensor._tensor_impl as ti
22+
import dpctl.tensor._tensor_reductions_impl as tri
2223

2324
from ._type_utils import _to_device_supported_dtype
2425

@@ -220,8 +221,8 @@ def sum(x, axis=None, dtype=None, keepdims=False):
220221
axis,
221222
dtype,
222223
keepdims,
223-
ti._sum_over_axis,
224-
ti._sum_over_axis_dtype_supported,
224+
tri._sum_over_axis,
225+
tri._sum_over_axis_dtype_supported,
225226
_default_reduction_dtype,
226227
_identity=0,
227228
)
@@ -281,8 +282,8 @@ def prod(x, axis=None, dtype=None, keepdims=False):
281282
axis,
282283
dtype,
283284
keepdims,
284-
ti._prod_over_axis,
285-
ti._prod_over_axis_dtype_supported,
285+
tri._prod_over_axis,
286+
tri._prod_over_axis_dtype_supported,
286287
_default_reduction_dtype,
287288
_identity=1,
288289
)
@@ -335,8 +336,8 @@ def logsumexp(x, axis=None, dtype=None, keepdims=False):
335336
axis,
336337
dtype,
337338
keepdims,
338-
ti._logsumexp_over_axis,
339-
lambda inp_dt, res_dt, *_: ti._logsumexp_over_axis_dtype_supported(
339+
tri._logsumexp_over_axis,
340+
lambda inp_dt, res_dt, *_: tri._logsumexp_over_axis_dtype_supported(
340341
inp_dt, res_dt
341342
),
342343
_default_reduction_dtype_fp_types,
@@ -391,8 +392,8 @@ def reduce_hypot(x, axis=None, dtype=None, keepdims=False):
391392
axis,
392393
dtype,
393394
keepdims,
394-
ti._hypot_over_axis,
395-
lambda inp_dt, res_dt, *_: ti._hypot_over_axis_dtype_supported(
395+
tri._hypot_over_axis,
396+
lambda inp_dt, res_dt, *_: tri._hypot_over_axis_dtype_supported(
396397
inp_dt, res_dt
397398
),
398399
_default_reduction_dtype_fp_types,
@@ -468,7 +469,7 @@ def max(x, axis=None, keepdims=False):
468469
entire array, a zero-dimensional array is returned. The returned
469470
array has the same data type as `x`.
470471
"""
471-
return _comparison_over_axis(x, axis, keepdims, ti._max_over_axis)
472+
return _comparison_over_axis(x, axis, keepdims, tri._max_over_axis)
472473

473474

474475
def min(x, axis=None, keepdims=False):
@@ -496,7 +497,7 @@ def min(x, axis=None, keepdims=False):
496497
entire array, a zero-dimensional array is returned. The returned
497498
array has the same data type as `x`.
498499
"""
499-
return _comparison_over_axis(x, axis, keepdims, ti._min_over_axis)
500+
return _comparison_over_axis(x, axis, keepdims, tri._min_over_axis)
500501

501502

502503
def _search_over_axis(x, axis, keepdims, _reduction_fn):
@@ -577,7 +578,7 @@ def argmax(x, axis=None, keepdims=False):
577578
zero-dimensional array is returned. The returned array has the
578579
default array index data type for the device of `x`.
579580
"""
580-
return _search_over_axis(x, axis, keepdims, ti._argmax_over_axis)
581+
return _search_over_axis(x, axis, keepdims, tri._argmax_over_axis)
581582

582583

583584
def argmin(x, axis=None, keepdims=False):
@@ -609,4 +610,4 @@ def argmin(x, axis=None, keepdims=False):
609610
zero-dimensional array is returned. The returned array has the
610611
default array index data type for the device of `x`.
611612
"""
612-
return _search_over_axis(x, axis, keepdims, ti._argmin_over_axis)
613+
return _search_over_axis(x, axis, keepdims, tri._argmin_over_axis)

dpctl/tensor/_utility_functions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import dpctl
44
import dpctl.tensor as dpt
55
import dpctl.tensor._tensor_impl as ti
6+
import dpctl.tensor._tensor_reductions_impl as tri
67

78

89
def _boolean_reduction(x, axis, keepdims, func):
@@ -94,7 +95,7 @@ def all(x, axis=None, keepdims=False):
9495
An array with a data type of `bool`
9596
containing the results of the logical AND reduction.
9697
"""
97-
return _boolean_reduction(x, axis, keepdims, ti._all)
98+
return _boolean_reduction(x, axis, keepdims, tri._all)
9899

99100

100101
def any(x, axis=None, keepdims=False):
@@ -122,4 +123,4 @@ def any(x, axis=None, keepdims=False):
122123
An array with a data type of `bool`
123124
containing the results of the logical OR reduction.
124125
"""
125-
return _boolean_reduction(x, axis, keepdims, ti._any)
126+
return _boolean_reduction(x, axis, keepdims, tri._any)

0 commit comments

Comments
 (0)