Skip to content

Commit 2f006b9

Browse files
author
Diptorup Deb
authored
Merge pull request #1118 from IntelPython/fix/dpnp_data_model
Use different data models for DpnpNdArray Type for kernel and dpjit targets
2 parents f4487df + 12c5c23 commit 2f006b9

File tree

5 files changed

+78
-81
lines changed

5 files changed

+78
-81
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 61 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,30 @@ def __init__(self, dmm, fe_type):
5757
]
5858
super(USMArrayModel, self).__init__(dmm, fe_type, members)
5959

60+
@property
61+
def flattened_field_count(self):
62+
"""Return the number of fields in an instance of a USMArrayModel."""
63+
flattened_member_count = 0
64+
members = self._members
65+
for member in members:
66+
if isinstance(member, types.UniTuple):
67+
flattened_member_count += member.count
68+
elif isinstance(
69+
member,
70+
(
71+
types.scalars.Integer,
72+
types.misc.PyObject,
73+
types.misc.RawPointer,
74+
types.misc.CPointer,
75+
types.misc.MemInfoPointer,
76+
),
77+
):
78+
flattened_member_count += 1
79+
else:
80+
raise UnreachableError
81+
82+
return flattened_member_count
83+
6084

6185
class DpnpNdArrayModel(StructModel):
6286
"""Data model for the DpnpNdArray type.
@@ -138,35 +162,54 @@ def __init__(self, dmm, fe_type):
138162
super(SyclQueueModel, self).__init__(dmm, fe_type, members)
139163

140164

141-
def _init_data_model_manager():
165+
def _init_data_model_manager() -> datamodel.DataModelManager:
166+
"""Initializes a DpexKernelTarget-specific data model manager.
167+
168+
SPIRV kernel functions for certain types of devices require an explicit
169+
address space qualifier for pointers. For OpenCL HD Graphics
170+
devices, defining a kernel function (spir_kernel calling convention) with
171+
pointer arguments that have no address space qualifier causes a run time
172+
crash. For this reason, numba-dpex defines two separate data
173+
models: USMArrayModel and DpnpNdArrayModel. When a dpnp.ndarray object is
174+
passed as an argument to a ``numba_dpex.kernel`` decorated function it uses
175+
the USMArrayModel and when passed to a ``numba_dpex.dpjit`` decorated
176+
function it uses the DpnpNdArrayModel. The difference is due to the fact
177+
that inside a ``dpjit`` decorated function a dpnp.ndarray object can be
178+
passed to any other regular function.
179+
180+
Returns:
181+
DataModelManager: A numba-dpex DpexKernelTarget-specific data model
182+
manager
183+
"""
142184
dmm = datamodel.default_manager.copy()
143185
dmm.register(types.CPointer, GenericPointerModel)
144186
dmm.register(Array, USMArrayModel)
187+
188+
# Register the USMNdArray type to USMArrayModel in numba_dpex's data model
189+
# manager. The dpex_data_model_manager is used by the DpexKernelTarget
190+
dmm.register(USMNdArray, USMArrayModel)
191+
192+
# Register the DpnpNdArray type to USMArrayModel in numba_dpex's data model
193+
# manager. The dpex_data_model_manager is used by the DpexKernelTarget
194+
dmm.register(DpnpNdArray, USMArrayModel)
195+
196+
# Register the DpctlSyclQueue type to SyclQueueModel in numba_dpex's data
197+
# model manager. The dpex_data_model_manager is used by the DpexKernelTarget
198+
dmm.register(DpctlSyclQueue, SyclQueueModel)
199+
145200
return dmm
146201

147202

148203
dpex_data_model_manager = _init_data_model_manager()
149204

150-
# XXX A kernel function has the spir_kernel ABI and requires pointers to have an
151-
# address space attribute. For this reason, the UsmNdArray type uses dpex's
152-
# ArrayModel where the pointers are address space casted to have a SYCL-specific
153-
# address space value. The DpnpNdArray type can be used inside djit functions
154-
# as host function calls arguments, such as dpnp library calls. The DpnpNdArray
155-
# needs to use Numba's array model as its data model. Thus, from a Numba typing
156-
# perspective dpnp.ndarrays cannot be directly passed to a kernel. To get
157-
# around the limitation, the DpexKernelTypingContext does not resolve the type
158-
# of dpnp.array args to a kernel as DpnpNdArray type objects, but uses the
159-
# ``to_usm_ndarray`` utility function to convert them into a UsmNdArray type
160-
# object.
161-
162-
# Register the USMNdArray type with the dpex ArrayModel
205+
206+
# Register the USMNdArray type to USMArrayModel in numba's default data model
207+
# manager
163208
register_model(USMNdArray)(USMArrayModel)
164-
dpex_data_model_manager.register(USMNdArray, USMArrayModel)
165209

166-
# Register the DpnpNdArray type with the Numba ArrayModel
210+
# Register the DpnpNdArray type to DpnpNdArrayModel in numba's default data
211+
# model manager
167212
register_model(DpnpNdArray)(DpnpNdArrayModel)
168-
dpex_data_model_manager.register(DpnpNdArray, DpnpNdArrayModel)
169213

170214
# Register the DpctlSyclQueue type
171215
register_model(DpctlSyclQueue)(SyclQueueModel)
172-
dpex_data_model_manager.register(DpctlSyclQueue, SyclQueueModel)

numba_dpex/core/parfors/kernel_builder.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -70,30 +70,6 @@ def _compile_kernel_parfor(
7070
func_ir, kernel_name
7171
)
7272

73-
# A cast from DpnpNdArray type to USMNdArray is needed for all arguments of
74-
# DpnpNdArray type. Although, DpnpNdArray derives from USMNdArray the two
75-
# types use different data models. USMNdArray uses the
76-
# numba_dpex.core.datamodel.models.ArrayModel data model that defines all
77-
# CPointer type members in the GLOBAL address space. The DpnpNdArray uses
78-
# Numba's default ArrayModel that does not define pointers in any specific
79-
# address space. For OpenCL HD Graphics devices, defining a kernel function
80-
# (spir_kernel calling convention) with pointer arguments that have no
81-
# address space qualifier causes a run time crash. By casting the argument
82-
# type for parfor arguments from DpnpNdArray type to the USMNdArray type the
83-
# generated kernel always has an address space qualifier, avoiding the issue
84-
# on OpenCL HD graphics devices.
85-
86-
for i, argty in enumerate(argtypes):
87-
if isinstance(argty, DpnpNdArray):
88-
new_argty = USMNdArray(
89-
ndim=argty.ndim,
90-
layout=argty.layout,
91-
dtype=argty.dtype,
92-
usm_type=argty.usm_type,
93-
queue=argty.queue,
94-
)
95-
argtypes[i] = new_argty
96-
9773
# compile the kernel
9874
kernel.compile(
9975
args=argtypes,

numba_dpex/core/targets/kernel_target.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -72,28 +72,6 @@ def resolve_argument_type(self, val):
7272
type=str(type(val)), value=val
7373
)
7474

75-
# A cast from DpnpNdArray type to USMNdArray is needed for all
76-
# arguments of DpnpNdArray type. Although, DpnpNdArray derives from
77-
# USMNdArray the two types use different data models. USMNdArray
78-
# uses the numba_dpex.core.datamodel.models.ArrayModel data model
79-
# that defines all CPointer type members in the GLOBAL address
80-
# space. The DpnpNdArray uses Numba's default ArrayModel that does
81-
# not define pointers in any specific address space. For OpenCL HD
82-
# Graphics devices, defining a kernel function (spir_kernel calling
83-
# convention) with pointer arguments that have no address space
84-
# qualifier causes a run time crash. By casting the argument type
85-
# for parfor arguments from DpnpNdArray type to the USMNdArray type
86-
# the generated kernel always has an address space qualifier,
87-
# avoiding the issue on OpenCL HD graphics devices.
88-
if isinstance(numba_type, DpnpNdArray):
89-
return USMNdArray(
90-
ndim=numba_type.ndim,
91-
layout=numba_type.layout,
92-
dtype=numba_type.dtype,
93-
usm_type=numba_type.usm_type,
94-
queue=numba_type.queue,
95-
)
96-
9775
except ValueError:
9876
# When an array-like kernel argument is not recognized by
9977
# numba-dpex, this additional check sees if the array-like object

numba_dpex/dpnp_iface/arrayobj.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,12 +1075,13 @@ def getitem_arraynd_intp(context, builder, sig, args):
10751075
"""
10761076
ret = np_getitem_arraynd_intp(context, builder, sig, args)
10771077

1078-
array_val = args[0]
1079-
array_ty = sig.args[0]
1080-
sycl_queue_attr_pos = dpex_dmm.lookup(array_ty).get_field_position(
1081-
"sycl_queue"
1082-
)
1083-
sycl_queue_attr = builder.extract_value(array_val, sycl_queue_attr_pos)
1084-
ret = builder.insert_value(ret, sycl_queue_attr, sycl_queue_attr_pos)
1078+
if isinstance(sig.return_type, DpnpNdArray):
1079+
array_val = args[0]
1080+
array_ty = sig.args[0]
1081+
sycl_queue_attr_pos = dpex_dmm.lookup(array_ty).get_field_position(
1082+
"sycl_queue"
1083+
)
1084+
sycl_queue_attr = builder.extract_value(array_val, sycl_queue_attr_pos)
1085+
ret = builder.insert_value(ret, sycl_queue_attr, sycl_queue_attr_pos)
10851086

10861087
return ret

numba_dpex/tests/core/types/DpnpNdArray/test_models.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,25 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from numba import types
6-
from numba.core.datamodel import models
6+
from numba.core.datamodel import default_manager, models
77

88
from numba_dpex.core.datamodel.models import (
99
DpnpNdArrayModel,
10+
USMArrayModel,
1011
dpex_data_model_manager,
1112
)
1213
from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray
1314

1415

1516
def test_model_for_DpnpNdArray():
16-
"""Test that model is registered for DpnpNdArray instances.
17-
18-
The model for DpnpNdArray is dpex's ArrayModel.
19-
17+
"""Test the datamodel for DpnpNdArray that is registered with numba's
18+
default datamodel manager and numba_dpex's kernel data model manager.
2019
"""
21-
22-
model = dpex_data_model_manager.lookup(
23-
DpnpNdArray(ndim=1, dtype=types.float64, layout="C")
24-
)
25-
assert isinstance(model, DpnpNdArrayModel)
20+
dpnp_ndarray = DpnpNdArray(ndim=1, dtype=types.float64, layout="C")
21+
model = dpex_data_model_manager.lookup(dpnp_ndarray)
22+
assert isinstance(model, USMArrayModel)
23+
default_model = default_manager.lookup(dpnp_ndarray)
24+
assert isinstance(default_model, DpnpNdArrayModel)
2625

2726

2827
def test_dpnp_ndarray_Model():

0 commit comments

Comments
 (0)