@@ -57,6 +57,30 @@ def __init__(self, dmm, fe_type):
57
57
]
58
58
super (USMArrayModel , self ).__init__ (dmm , fe_type , members )
59
59
60
+ @property
61
+ def flattened_field_count (self ):
62
+ """Return the number of fields in an instance of a USMArrayModel."""
63
+ flattened_member_count = 0
64
+ members = self ._members
65
+ for member in members :
66
+ if isinstance (member , types .UniTuple ):
67
+ flattened_member_count += member .count
68
+ elif isinstance (
69
+ member ,
70
+ (
71
+ types .scalars .Integer ,
72
+ types .misc .PyObject ,
73
+ types .misc .RawPointer ,
74
+ types .misc .CPointer ,
75
+ types .misc .MemInfoPointer ,
76
+ ),
77
+ ):
78
+ flattened_member_count += 1
79
+ else :
80
+ raise UnreachableError
81
+
82
+ return flattened_member_count
83
+
60
84
61
85
class DpnpNdArrayModel (StructModel ):
62
86
"""Data model for the DpnpNdArray type.
@@ -138,35 +162,54 @@ def __init__(self, dmm, fe_type):
138
162
super (SyclQueueModel , self ).__init__ (dmm , fe_type , members )
139
163
140
164
141
- def _init_data_model_manager ():
165
+ def _init_data_model_manager () -> datamodel .DataModelManager :
166
+ """Initializes a DpexKernelTarget-specific data model manager.
167
+
168
+ SPIRV kernel functions for certain types of devices require an explicit
169
+ address space qualifier for pointers. For OpenCL HD Graphics
170
+ devices, defining a kernel function (spir_kernel calling convention) with
171
+ pointer arguments that have no address space qualifier causes a run time
172
+ crash. For this reason, numba-dpex defines two separate data
173
+ models: USMArrayModel and DpnpNdArrayModel. When a dpnp.ndarray object is
174
+ passed as an argument to a ``numba_dpex.kernel`` decorated function it uses
175
+ the USMArrayModel and when passed to a ``numba_dpex.dpjit`` decorated
176
+ function it uses the DpnpNdArrayModel. The difference is due to the fact
177
+ that inside a ``dpjit`` decorated function a dpnp.ndarray object can be
178
+ passed to any other regular function.
179
+
180
+ Returns:
181
+ DataModelManager: A numba-dpex DpexKernelTarget-specific data model
182
+ manager
183
+ """
142
184
dmm = datamodel .default_manager .copy ()
143
185
dmm .register (types .CPointer , GenericPointerModel )
144
186
dmm .register (Array , USMArrayModel )
187
+
188
+ # Register the USMNdArray type to USMArrayModel in numba_dpex's data model
189
+ # manager. The dpex_data_model_manager is used by the DpexKernelTarget
190
+ dmm .register (USMNdArray , USMArrayModel )
191
+
192
+ # Register the DpnpNdArray type to USMArrayModel in numba_dpex's data model
193
+ # manager. The dpex_data_model_manager is used by the DpexKernelTarget
194
+ dmm .register (DpnpNdArray , USMArrayModel )
195
+
196
+ # Register the DpctlSyclQueue type to SyclQueueModel in numba_dpex's data
197
+ # model manager. The dpex_data_model_manager is used by the DpexKernelTarget
198
+ dmm .register (DpctlSyclQueue , SyclQueueModel )
199
+
145
200
return dmm
146
201
147
202
148
203
dpex_data_model_manager = _init_data_model_manager ()
149
204
150
- # XXX A kernel function has the spir_kernel ABI and requires pointers to have an
151
- # address space attribute. For this reason, the UsmNdArray type uses dpex's
152
- # ArrayModel where the pointers are address space casted to have a SYCL-specific
153
- # address space value. The DpnpNdArray type can be used inside djit functions
154
- # as host function calls arguments, such as dpnp library calls. The DpnpNdArray
155
- # needs to use Numba's array model as its data model. Thus, from a Numba typing
156
- # perspective dpnp.ndarrays cannot be directly passed to a kernel. To get
157
- # around the limitation, the DpexKernelTypingContext does not resolve the type
158
- # of dpnp.array args to a kernel as DpnpNdArray type objects, but uses the
159
- # ``to_usm_ndarray`` utility function to convert them into a UsmNdArray type
160
- # object.
161
-
162
- # Register the USMNdArray type with the dpex ArrayModel
205
+
206
+ # Register the USMNdArray type to USMArrayModel in numba's default data model
207
+ # manager
163
208
register_model (USMNdArray )(USMArrayModel )
164
- dpex_data_model_manager .register (USMNdArray , USMArrayModel )
165
209
166
- # Register the DpnpNdArray type with the Numba ArrayModel
210
+ # Register the DpnpNdArray type to DpnpNdArrayModel in numba's default data
211
+ # model manager
167
212
register_model (DpnpNdArray )(DpnpNdArrayModel )
168
- dpex_data_model_manager .register (DpnpNdArray , DpnpNdArrayModel )
169
213
170
214
# Register the DpctlSyclQueue type
171
215
register_model (DpctlSyclQueue )(SyclQueueModel )
172
- dpex_data_model_manager .register (DpctlSyclQueue , SyclQueueModel )
0 commit comments