Skip to content

Commit ec3f8e3

Browse files
oleksandr-pavlykndgrigorian
authored andcommitted
Made Flags cdef class, used typed members, and typed functions, added __eq__
Support for `__eq__` allows to compare two instances of Flags objects and compare Flags object to an integer.
1 parent 2d20e1e commit ec3f8e3

File tree

1 file changed

+33
-16
lines changed

1 file changed

+33
-16
lines changed

dpctl/tensor/_flags.pyx

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,26 @@
1818
# cython: language_level=3
1919
# cython: linetrace=True
2020

21+
from libcpp cimport bool as cpp_bool
22+
2123
from dpctl.tensor._usmarray cimport (
2224
USM_ARRAY_C_CONTIGUOUS,
2325
USM_ARRAY_F_CONTIGUOUS,
2426
USM_ARRAY_WRITEABLE,
27+
usm_ndarray,
2528
)
2629

2730

28-
class Flags:
31+
cdef cpp_bool _check_bit(int flag, int mask):
32+
return (flag & mask) == mask
33+
2934

30-
def __init__(self, arr, flags):
35+
cdef class Flags:
36+
"""Helper class to represent flags of :class:`dpctl.tensor.usm_ndarray`."""
37+
cdef int flags_
38+
cdef usm_ndarray arr_
39+
40+
def __cinit__(self, usm_ndarray arr, int flags):
3141
self.arr_ = arr
3242
self.flags_ = flags
3343

@@ -37,32 +47,29 @@ class Flags:
3747

3848
@property
3949
def c_contiguous(self):
40-
return ((self.flags_ & USM_ARRAY_C_CONTIGUOUS)
41-
== USM_ARRAY_C_CONTIGUOUS)
50+
return _check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
4251

4352
@property
4453
def f_contiguous(self):
45-
return ((self.flags_ & USM_ARRAY_F_CONTIGUOUS)
46-
== USM_ARRAY_F_CONTIGUOUS)
54+
return _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
4755

4856
@property
4957
def writable(self):
50-
return False if ((self.flags & USM_ARRAY_WRITEABLE)
51-
== USM_ARRAY_WRITEABLE) else True
58+
return _check_bit(self.flags_, USM_ARRAY_WRITEABLE)
5259

5360
@property
5461
def forc(self):
55-
return True if (((self.flags_ & USM_ARRAY_F_CONTIGUOUS)
56-
== USM_ARRAY_F_CONTIGUOUS)
57-
or ((self.flags_ & USM_ARRAY_C_CONTIGUOUS)
58-
== USM_ARRAY_C_CONTIGUOUS)) else False
62+
return (
63+
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
64+
or _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
65+
)
5966

6067
@property
6168
def fnc(self):
62-
return True if (((self.flags_ & USM_ARRAY_F_CONTIGUOUS)
63-
== USM_ARRAY_F_CONTIGUOUS)
64-
and not ((self.flags_ & USM_ARRAY_C_CONTIGUOUS)
65-
== USM_ARRAY_C_CONTIGUOUS)) else False
69+
return (
70+
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
71+
and _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
72+
)
6673

6774
@property
6875
def contiguous(self):
@@ -83,3 +90,13 @@ class Flags:
8390
for name in "C_CONTIGUOUS", "F_CONTIGUOUS", "WRITABLE":
8491
out.append(" {} : {}".format(name, self[name]))
8592
return '\n'.join(out)
93+
94+
def __eq__(self, other):
95+
cdef Flags other_
96+
if isinstance(other, self.__class__):
97+
other_ = <Flags>other
98+
return self.flags_ == other_.flags_
99+
elif isinstance(other, int):
100+
return self.flags_ == <int>other
101+
else:
102+
return False

0 commit comments

Comments
 (0)