18
18
# cython: language_level=3
19
19
# cython: linetrace=True
20
20
21
+ from libcpp cimport bool as cpp_bool
22
+
21
23
from dpctl.tensor._usmarray cimport (
22
24
USM_ARRAY_C_CONTIGUOUS,
23
25
USM_ARRAY_F_CONTIGUOUS,
24
26
USM_ARRAY_WRITEABLE,
27
+ usm_ndarray,
25
28
)
26
29
27
30
28
- class Flags :
31
+ cdef cpp_bool _check_bit(int flag, int mask):
32
+ return (flag & mask) == mask
33
+
29
34
30
- def __init__ (self , arr , flags ):
35
+ cdef class Flags:
36
+ """ Helper class to represent flags of :class:`dpctl.tensor.usm_ndarray`."""
37
+ cdef int flags_
38
+ cdef usm_ndarray arr_
39
+
40
+ def __cinit__ (self , usm_ndarray arr , int flags ):
31
41
self .arr_ = arr
32
42
self .flags_ = flags
33
43
@@ -37,32 +47,29 @@ class Flags:
37
47
38
48
@property
39
49
def c_contiguous (self ):
40
- return ((self .flags_ & USM_ARRAY_C_CONTIGUOUS)
41
- == USM_ARRAY_C_CONTIGUOUS)
50
+ return _check_bit(self .flags_, USM_ARRAY_C_CONTIGUOUS)
42
51
43
52
@property
44
53
def f_contiguous (self ):
45
- return ((self .flags_ & USM_ARRAY_F_CONTIGUOUS)
46
- == USM_ARRAY_F_CONTIGUOUS)
54
+ return _check_bit(self .flags_, USM_ARRAY_F_CONTIGUOUS)
47
55
48
56
@property
49
57
def writable (self ):
50
- return False if ((self .flags & USM_ARRAY_WRITEABLE)
51
- == USM_ARRAY_WRITEABLE) else True
58
+ return _check_bit(self .flags_, USM_ARRAY_WRITEABLE)
52
59
53
60
@property
54
61
def forc (self ):
55
- return True if ((( self .flags_ & USM_ARRAY_F_CONTIGUOUS)
56
- == USM_ARRAY_F_CONTIGUOUS )
57
- or (( self .flags_ & USM_ARRAY_C_CONTIGUOUS )
58
- == USM_ARRAY_C_CONTIGUOUS)) else False
62
+ return (
63
+ _check_bit( self .flags_, USM_ARRAY_C_CONTIGUOUS )
64
+ or _check_bit( self .flags_, USM_ARRAY_F_CONTIGUOUS )
65
+ )
59
66
60
67
@property
61
68
def fnc (self ):
62
- return True if ((( self .flags_ & USM_ARRAY_F_CONTIGUOUS)
63
- == USM_ARRAY_F_CONTIGUOUS )
64
- and not (( self .flags_ & USM_ARRAY_C_CONTIGUOUS )
65
- == USM_ARRAY_C_CONTIGUOUS)) else False
69
+ return (
70
+ _check_bit( self .flags_, USM_ARRAY_C_CONTIGUOUS )
71
+ and _check_bit( self .flags_, USM_ARRAY_F_CONTIGUOUS )
72
+ )
66
73
67
74
@property
68
75
def contiguous (self ):
@@ -83,3 +90,13 @@ class Flags:
83
90
for name in " C_CONTIGUOUS" , " F_CONTIGUOUS" , " WRITABLE" :
84
91
out.append(" {} : {}" .format(name, self [name]))
85
92
return ' \n ' .join(out)
93
+
94
+ def __eq__ (self , other ):
95
+ cdef Flags other_
96
+ if isinstance (other, self .__class__):
97
+ other_ = < Flags> other
98
+ return self .flags_ == other_.flags_
99
+ elif isinstance (other, int ):
100
+ return self .flags_ == < int > other
101
+ else :
102
+ return False
0 commit comments