Skip to content

Implements dpctl.tensor.repeat, dpctl.tensor.tile #1381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dpctl/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ endif()
set(python_module_name _tensor_impl)
pybind11_add_module(${python_module_name} MODULE
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_py.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
Expand All @@ -49,6 +50,7 @@ pybind11_add_module(${python_module_name} MODULE
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sum_reductions.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
)
set(_clang_prefix "")
if (WIN32)
Expand Down
4 changes: 4 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,13 @@
iinfo,
moveaxis,
permute_dims,
repeat,
result_type,
roll,
squeeze,
stack,
swapaxes,
tile,
unstack,
)
from dpctl.tensor._print import (
Expand Down Expand Up @@ -305,4 +307,6 @@
"tanh",
"trunc",
"allclose",
"repeat",
"tile",
]
303 changes: 301 additions & 2 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
# limitations under the License.


import itertools
import operator
from itertools import chain, repeat

import numpy as np
from numpy import AxisError
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple

import dpctl
Expand Down Expand Up @@ -133,7 +134,9 @@ def _broadcast_shape_impl(shapes):
diff = biggest - nds[i]
if diff > 0:
ty = type(shapes[i])
shapes[i] = ty(chain(repeat(1, diff), shapes[i]))
shapes[i] = ty(
itertools.chain(itertools.repeat(1, diff), shapes[i])
)
common_shape = []
for axis in range(biggest):
lengths = [s[axis] for s in shapes]
Expand Down Expand Up @@ -918,6 +921,302 @@ def swapaxes(X, axis1, axis2):
return dpt.permute_dims(X, tuple(ind))


def repeat(x, repeats, axis=None):
    """repeat(x, repeats, axis=None)

    Repeat elements of an array.

    Args:
        x (usm_ndarray): input array

        repeats (Union[int, Tuple[int, ...], usm_ndarray]):
            The number of repetitions for each element.
            A Python integer, a one-element tuple, or a one-element
            array is applied to every element along `axis`; otherwise
            the number of entries must equal the size of `axis`.
            Array-valued `repeats` must be 0- or 1-dimensional.

        axis (Optional[int]):
            The axis along which to repeat values. The `axis` is required
            if input array has more than one dimension.

    Returns:
        usm_ndarray:
            Array with repeated elements.
            The returned array must have the same data type as `x`,
            is created on the same device as `x` and has the same USM
            allocation type as `x`.

    Raises:
        AxisError: if `axis` value is invalid.
        TypeError: if `x` is not a `usm_ndarray`, if `repeats` is not an
            int, tuple, or `usm_ndarray`, or if the data type of an array
            `repeats` can not be cast to `int64`.
        ValueError: if `axis` is `None` for a multi-dimensional input,
            if `repeats` contains a negative value, or if its length does
            not match the size of the repeated axis.
        ExecutionPlacementError: if `x` and an array `repeats` do not
            share a common execution queue.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")

    x_ndim = x.ndim
    if axis is None:
        if x_ndim > 1:
            raise ValueError(
                f"`axis` cannot be `None` for array of dimension {x_ndim}"
            )
        axis = 0

    x_shape = x.shape
    if x_ndim > 0:
        axis = normalize_axis_index(operator.index(axis), x_ndim)
        axis_size = x_shape[axis]
    else:
        if axis != 0:
            # the exception must actually be raised, not just constructed
            raise AxisError("`axis` must be `0` for input of dimension `0`")
        axis_size = x.size

    scalar = False
    if isinstance(repeats, int):
        if repeats < 0:
            raise ValueError("`repeats` must be a positive integer")
        usm_type = x.usm_type
        exec_q = x.sycl_queue
        scalar = True
    elif isinstance(repeats, dpt.usm_ndarray):
        if repeats.ndim > 1:
            raise ValueError(
                "`repeats` array must be 0- or 1-dimensional, got "
                f"{repeats.ndim}"
            )
        exec_q = dpctl.utils.get_execution_queue(
            (x.sycl_queue, repeats.sycl_queue)
        )
        if exec_q is None:
            raise dputils.ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments."
            )
        usm_type = dpctl.utils.get_coerced_usm_type(
            (
                x.usm_type,
                repeats.usm_type,
            )
        )
        dpctl.utils.validate_usm_type(usm_type, allow_none=False)
        if not dpt.can_cast(repeats.dtype, dpt.int64, casting="same_kind"):
            raise TypeError(
                f"`repeats` data type `{repeats.dtype}` cannot be cast to "
                "`int64` according to the casting rule ''safe.''"
            )
        if repeats.size == 1:
            scalar = True
            # bring the single element to the host
            repeats = int(repeats)
            if repeats < 0:
                raise ValueError("`repeats` elements must be positive")
        else:
            if repeats.size != axis_size:
                raise ValueError(
                    "`repeats` array must be broadcastable to the size of "
                    "the repeated axis"
                )
            if not dpt.all(repeats >= 0):
                raise ValueError("`repeats` elements must be positive")
    elif isinstance(repeats, tuple):
        usm_type = x.usm_type
        exec_q = x.sycl_queue

        len_reps = len(repeats)
        # A one-element tuple broadcasts like a scalar, so it has to be
        # recognized before the length is compared with the axis size.
        if len_reps == 1:
            repeats = repeats[0]
            if repeats < 0:
                raise ValueError("`repeats` elements must be positive")
            scalar = True
        else:
            if len_reps != axis_size:
                raise ValueError(
                    "`repeats` tuple must have the same length as the "
                    "repeated axis"
                )
            repeats = dpt.asarray(
                repeats, dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q
            )
            if not dpt.all(repeats >= 0):
                raise ValueError("`repeats` elements must be positive")
    else:
        raise TypeError(
            "Expected int, tuple, or `usm_ndarray` for second argument, "
            f"got {type(repeats)}"
        )

    if axis_size == 0:
        # nothing to repeat; keep the coerced USM type for the result
        return dpt.empty(
            x_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
        )

    if scalar:
        res_axis_size = repeats * axis_size
        res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
        res = dpt.empty(
            res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
        )
        if res_axis_size > 0:
            ht_rep_ev, _ = ti._repeat_by_scalar(
                src=x,
                dst=res,
                reps=repeats,
                axis=axis,
                sycl_queue=exec_q,
            )
            ht_rep_ev.wait()
        return res

    # Per-element repetition counts: the kernels are fed an int64 array of
    # counts together with its cumulative sum.
    cumsum = dpt.empty(
        (axis_size,),
        dtype=dpt.int64,
        usm_type=usm_type,
        sycl_queue=exec_q,
    )
    ht_copy_ev = None
    if repeats.dtype != dpt.int64:
        # counts of another integer type are first copied into int64 storage
        rep_buf = dpt.empty(
            repeats.shape,
            dtype=dpt.int64,
            usm_type=usm_type,
            sycl_queue=exec_q,
        )
        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=repeats, dst=rep_buf, sycl_queue=exec_q
        )
        # _cumsum_1d synchronizes so `depends` ends here safely
        res_axis_size = ti._cumsum_1d(
            rep_buf, cumsum, sycl_queue=exec_q, depends=[copy_ev]
        )
    else:
        rep_buf = repeats
        # _cumsum_1d synchronizes so `depends` ends here safely
        res_axis_size = ti._cumsum_1d(rep_buf, cumsum, sycl_queue=exec_q)

    res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
    res = dpt.empty(
        res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
    )
    if res_axis_size > 0:
        ht_rep_ev, _ = ti._repeat_by_sequence(
            src=x,
            dst=res,
            reps=rep_buf,
            cumsum=cumsum,
            axis=axis,
            sycl_queue=exec_q,
        )
        ht_rep_ev.wait()
    if ht_copy_ev is not None:
        # always wait on the host event of the conversion copy, even when
        # the result is empty and no repeat kernel was submitted
        ht_copy_ev.wait()
    return res


def tile(x, repetitions):
    """tile(x, repetitions)

    Repeat an input array `x` along each axis a number of times given by
    `repetitions`.

    For `N` = len(`repetitions`) and `M` = len(`x.shape`):
    - if `M < N`, `x` will have `N - M` new axes prepended to its shape
    - if `M > N`, `repetitions` will have `M - N` new axes 1 prepended to it

    Args:
        x (usm_ndarray): input array

        repetitions (Union[int, Tuple[int, ...]]):
            The number of repetitions for each dimension.

    Returns:
        usm_ndarray:
            Array with tiled elements.
            The returned array must have the same data type as `x`,
            is created on the same device as `x` and has the same USM
            allocation type as `x`.

    Raises:
        TypeError: if `x` is not a `usm_ndarray` or `repetitions` is
            neither a tuple nor an integer.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")

    if not isinstance(repetitions, tuple):
        if isinstance(repetitions, int):
            repetitions = (repetitions,)
        else:
            raise TypeError(
                f"Expected tuple or integer type, got {type(repetitions)}."
            )

    # Single-element inputs take the general path below: a dedicated
    # shortcut that fills an array of shape `repetitions` would ignore the
    # rank of `x` and break the `M > N` rule stated in the docstring.
    rep_dims = len(repetitions)
    x_dims = x.ndim
    if rep_dims < x_dims:
        repetitions = (x_dims - rep_dims) * (1,) + repetitions
    elif x_dims < rep_dims:
        x = dpt.reshape(x, (rep_dims - x_dims) * (1,) + x.shape)
    in_sh = x.shape
    res_shape = tuple(sh * rep for sh, rep in zip(in_sh, repetitions))
    # case of empty input
    if x.size == 0:
        return dpt.empty(
            res_shape,
            dtype=x.dtype,
            usm_type=x.usm_type,
            sycl_queue=x.sycl_queue,
        )
    if res_shape == in_sh:
        # every repetition is 1: the result is a plain copy
        return dpt.copy(x)

    # Build two shapes of equal rank: `expanded_sh` is a reshape of `x`
    # with a unit axis inserted before each tiled dimension, and
    # `broadcast_sh` stretches those unit axes to the repetition counts.
    expanded_sh = []
    broadcast_sh = []
    out_sz = 1
    for reps, sh in zip(repetitions, in_sh):
        out_sz *= reps * sh
        if reps == 1:
            # dimension will be unchanged
            broadcast_sh.append(sh)
            expanded_sh.append(sh)
        elif sh == 1:
            # dimension will be broadcast
            broadcast_sh.append(reps)
            expanded_sh.append(sh)
        else:
            broadcast_sh.extend([reps, sh])
            expanded_sh.extend([1, sh])

    exec_q = x.sycl_queue
    res = dpt.empty(
        (out_sz,), dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
    )
    # no need to copy data for empty output
    if out_sz > 0:
        x = dpt.broadcast_to(
            # this reshape should never copy
            dpt.reshape(x, expanded_sh),
            broadcast_sh,
        )
        # copy broadcast input into flat array
        hev, _ = ti._copy_usm_ndarray_for_reshape(
            src=x, dst=res, sycl_queue=exec_q
        )
        hev.wait()
    return dpt.reshape(res, res_shape)


def _supported_dtype(dtypes):
for dtype in dtypes:
if dtype.char not in "?bBhHiIlLqQefdFD":
Expand Down
Loading