Skip to content

Commit 21a6aaa

Browse files
Merge branch 'master' into add-dtypes-and-infos
2 parents a3f8fbd + 8cbed99 commit 21a6aaa

File tree

9 files changed

+271
-28
lines changed

9 files changed

+271
-28
lines changed

.flake8

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ per-file-ignores =
2525
dpctl/program/_program.pyx: E999, E225, E226, E227
2626
dpctl/tensor/_usmarray.pyx: E999, E225, E226, E227
2727
dpctl/tensor/_dlpack.pyx: E999, E225, E226, E227
28+
dpctl/tensor/_flags.pyx: E999, E225, E226, E227
2829
dpctl/tensor/numpy_usm_shared.py: F821
2930
dpctl/tests/_cython_api.pyx: E999, E225, E227, E402
3031
dpctl/utils/_compute_follows_data.pyx: E999, E225, E227

dpctl/memory/_memory.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,6 @@ cdef public api class MemoryUSMHost(_Memory) [object PyMemoryUSMHostObject,
7575
pass
7676

7777

78-
cdef public class MemoryUSMDevice(_Memory) [object PyMemoryUSMDeviceObject,
78+
cdef public api class MemoryUSMDevice(_Memory) [object PyMemoryUSMDeviceObject,
7979
type PyMemoryUSMDeviceType]:
8080
pass

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
full,
3434
full_like,
3535
linspace,
36+
meshgrid,
3637
ones,
3738
ones_like,
3839
tril,
@@ -127,4 +128,5 @@
127128
"finfo",
128129
"can_cast",
129130
"result_type",
131+
"meshgrid",
130132
]

dpctl/tensor/_copy_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -261,18 +261,18 @@ def copy(usm_ary, order="K"):
261261
elif order == "F":
262262
copy_order = order
263263
elif order == "A":
264-
if usm_ary.flags & 2:
264+
if usm_ary.flags.f_contiguous:
265265
copy_order = "F"
266266
elif order == "K":
267-
if usm_ary.flags & 2:
267+
if usm_ary.flags.f_contiguous:
268268
copy_order = "F"
269269
else:
270270
raise ValueError(
271271
"Unrecognized value of the order keyword. "
272272
"Recognized values are 'A', 'C', 'F', or 'K'"
273273
)
274-
c_contig = usm_ary.flags & 1
275-
f_contig = usm_ary.flags & 2
274+
c_contig = usm_ary.flags.c_contiguous
275+
f_contig = usm_ary.flags.f_contiguous
276276
R = dpt.usm_ndarray(
277277
usm_ary.shape,
278278
dtype=usm_ary.dtype,
@@ -325,8 +325,8 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
325325
ary_dtype, newdtype, casting
326326
)
327327
)
328-
c_contig = usm_ary.flags & 1
329-
f_contig = usm_ary.flags & 2
328+
c_contig = usm_ary.flags.c_contiguous
329+
f_contig = usm_ary.flags.f_contiguous
330330
needs_copy = copy or not (ary_dtype == target_dtype)
331331
if not needs_copy and (order != "K"):
332332
needs_copy = (c_contig and order not in ["A", "C"]) or (
@@ -339,10 +339,10 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
339339
elif order == "F":
340340
copy_order = order
341341
elif order == "A":
342-
if usm_ary.flags & 2:
342+
if usm_ary.flags.f_contiguous:
343343
copy_order = "F"
344344
elif order == "K":
345-
if usm_ary.flags & 2:
345+
if usm_ary.flags.f_contiguous:
346346
copy_order = "F"
347347
else:
348348
raise ValueError(

dpctl/tensor/_ctors.py

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,9 @@ def _asarray_from_usm_ndarray(
133133
# sycl_queue is unchanged
134134
can_zero_copy = can_zero_copy and copy_q is usm_ndary.sycl_queue
135135
# order is unchanged
136-
c_contig = usm_ndary.flags & 1
137-
f_contig = usm_ndary.flags & 2
138-
fc_contig = usm_ndary.flags & 3
136+
c_contig = usm_ndary.flags.c_contiguous
137+
f_contig = usm_ndary.flags.f_contiguous
138+
fc_contig = usm_ndary.flags.forc
139139
if can_zero_copy:
140140
if order == "C" and c_contig:
141141
pass
@@ -1130,7 +1130,7 @@ def tril(X, k=0):
11301130
k = operator.index(k)
11311131

11321132
# F_CONTIGUOUS = 2
1133-
order = "F" if (X.flags & 2) else "C"
1133+
order = "F" if (X.flags.f_contiguous) else "C"
11341134

11351135
shape = X.shape
11361136
nd = X.ndim
@@ -1171,7 +1171,7 @@ def triu(X, k=0):
11711171
k = operator.index(k)
11721172

11731173
# F_CONTIGUOUS = 2
1174-
order = "F" if (X.flags & 2) else "C"
1174+
order = "F" if (X.flags.f_contiguous) else "C"
11751175

11761176
shape = X.shape
11771177
nd = X.ndim
@@ -1198,3 +1198,61 @@ def triu(X, k=0):
11981198
hev.wait()
11991199

12001200
return res
1201+
1202+
1203+
def meshgrid(*arrays, indexing="xy"):
1204+
1205+
"""
1206+
meshgrid(*arrays, indexing="xy") -> list[usm_ndarray]
1207+
1208+
Creates list of `usm_ndarray` coordinate matrices from vectors.
1209+
1210+
Args:
1211+
arrays: arbitrary number of one-dimensional `USM_ndarray` objects.
1212+
If vectors are not of the same data type,
1213+
or are not one-dimensional, raises `ValueError.`
1214+
indexing: Cartesian (`xy`) or matrix (`ij`) indexing of output.
1215+
For a set of `n` vectors with lengths N0, N1, N2, ...
1216+
Cartesian indexing results in arrays of shape
1217+
(N1, N0, N2, ...)
1218+
matrix indexing results in arrays of shape
1219+
(n0, N1, N2, ...)
1220+
Default: `xy`.
1221+
"""
1222+
ref_dt = None
1223+
ref_unset = True
1224+
for array in arrays:
1225+
if not isinstance(array, dpt.usm_ndarray):
1226+
raise TypeError(
1227+
f"Expected instance of dpt.usm_ndarray, got {type(array)}."
1228+
)
1229+
if array.ndim != 1:
1230+
raise ValueError("All arrays must be one-dimensional.")
1231+
if ref_unset:
1232+
ref_unset = False
1233+
ref_dt = array.dtype
1234+
else:
1235+
if not ref_dt == array.dtype:
1236+
raise ValueError(
1237+
"All arrays must be of the same numeric data type."
1238+
)
1239+
if indexing not in ["xy", "ij"]:
1240+
raise ValueError(
1241+
"Unrecognized indexing keyword value, expecting 'xy' or 'ij.'"
1242+
)
1243+
n = len(arrays)
1244+
sh = (-1,) + (1,) * (n - 1)
1245+
1246+
res = []
1247+
if n > 1 and indexing == "xy":
1248+
res.append(dpt.reshape(arrays[0], (1, -1) + sh[2:], copy=True))
1249+
res.append(dpt.reshape(arrays[1], sh, copy=True))
1250+
arrays, sh = arrays[2:], sh[-2:] + sh[:-2]
1251+
1252+
for array in arrays:
1253+
res.append(dpt.reshape(array, sh, copy=True))
1254+
sh = sh[-1:] + sh[:-1]
1255+
1256+
output = dpt.broadcast_arrays(*res)
1257+
1258+
return output

dpctl/tensor/_flags.pyx

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2022 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# distutils: language = c++
18+
# cython: language_level=3
19+
# cython: linetrace=True
20+
21+
from libcpp cimport bool as cpp_bool
22+
23+
from dpctl.tensor._usmarray cimport (
24+
USM_ARRAY_C_CONTIGUOUS,
25+
USM_ARRAY_F_CONTIGUOUS,
26+
USM_ARRAY_WRITEABLE,
27+
usm_ndarray,
28+
)
29+
30+
31+
cdef cpp_bool _check_bit(int flag, int mask):
32+
return (flag & mask) == mask
33+
34+
35+
cdef class Flags:
36+
"""Helper class to represent flags of :class:`dpctl.tensor.usm_ndarray`."""
37+
cdef int flags_
38+
cdef usm_ndarray arr_
39+
40+
def __cinit__(self, usm_ndarray arr, int flags):
41+
self.arr_ = arr
42+
self.flags_ = flags
43+
44+
@property
45+
def flags(self):
46+
return self.flags_
47+
48+
@property
49+
def c_contiguous(self):
50+
return _check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
51+
52+
@property
53+
def f_contiguous(self):
54+
return _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
55+
56+
@property
57+
def writable(self):
58+
return _check_bit(self.flags_, USM_ARRAY_WRITEABLE)
59+
60+
@property
61+
def fc(self):
62+
return (
63+
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
64+
and _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
65+
)
66+
67+
@property
68+
def forc(self):
69+
return (
70+
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
71+
or _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
72+
)
73+
74+
@property
75+
def fnc(self):
76+
return (
77+
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
78+
and not _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
79+
)
80+
81+
@property
82+
def contiguous(self):
83+
return self.forc
84+
85+
def __getitem__(self, name):
86+
if name in ["C_CONTIGUOUS", "C"]:
87+
return self.c_contiguous
88+
elif name in ["F_CONTIGUOUS", "F"]:
89+
return self.f_contiguous
90+
elif name == "WRITABLE":
91+
return self.writable
92+
elif name == "FC":
93+
return self.fc
94+
elif name == "CONTIGUOUS":
95+
return self.forc
96+
97+
def __repr__(self):
98+
out = []
99+
for name in "C_CONTIGUOUS", "F_CONTIGUOUS", "WRITABLE":
100+
out.append(" {} : {}".format(name, self[name]))
101+
return '\n'.join(out)
102+
103+
def __eq__(self, other):
104+
cdef Flags other_
105+
if isinstance(other, self.__class__):
106+
other_ = <Flags>other
107+
return self.flags_ == other_.flags_
108+
elif isinstance(other, int):
109+
return self.flags_ == <int>other
110+
else:
111+
return False

dpctl/tensor/_reshape.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ def reshape(X, newshape, order="C", copy=None):
8686
raise TypeError
8787
if not isinstance(newshape, (list, tuple)):
8888
newshape = (newshape,)
89-
if order not in ["C", "F"]:
89+
if order in "cfCF":
90+
order = order.upper()
91+
else:
9092
raise ValueError(
9193
f"Keyword 'order' not recognized. Expecting 'C' or 'F', got {order}"
9294
)

dpctl/tensor/_usmarray.pyx

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ from cpython.tuple cimport PyTuple_New, PyTuple_SetItem
3333
cimport dpctl as c_dpctl
3434
cimport dpctl.memory as c_dpmem
3535
cimport dpctl.tensor._dlpack as c_dlpack
36+
import dpctl.tensor._flags as _flags
3637

3738
include "_stride_utils.pxi"
3839
include "_types.pxi"
@@ -503,9 +504,9 @@ cdef class usm_ndarray:
503504
@property
504505
def flags(self):
505506
"""
506-
Currently returns integer whose bits correspond to the flags.
507+
Returns dpctl.tensor._flags object.
507508
"""
508-
return self.flags_
509+
return _flags.Flags(self, self.flags_)
509510

510511
@property
511512
def usm_type(self):
@@ -663,7 +664,7 @@ cdef class usm_ndarray:
663664
strides=self.strides,
664665
offset=self.get_offset()
665666
)
666-
res.flags_ = self.flags
667+
res.flags_ = self.flags.flags
667668
return res
668669
else:
669670
nbytes = self.usm_data.nbytes
@@ -678,7 +679,7 @@ cdef class usm_ndarray:
678679
strides=self.strides,
679680
offset=self.get_offset()
680681
)
681-
res.flags_ = self.flags
682+
res.flags_ = self.flags.flags
682683
return res
683684

684685
def _set_namespace(self, mod):

0 commit comments

Comments
 (0)