Skip to content

Commit 68b1f39

Browse files
authored
Merge pull request #1331 from IntelPython/experimental/local_accessors
A sycl::local_accessor-like API for numba-dpex kernel
2 parents 202f460 + 2200c3c commit 68b1f39

File tree

15 files changed

+737
-29
lines changed

15 files changed

+737
-29
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
)
2323

2424

25-
def _get_flattened_member_count(ty):
26-
"""Return the number of fields in an instance of a given StructModel."""
25+
def get_flattened_member_count(ty):
26+
"""Returns the number of fields in an instance of a given StructModel."""
27+
2728
flattened_member_count = 0
2829
members = ty._members
2930
for member in members:
@@ -109,7 +110,7 @@ def flattened_field_count(self):
109110
"""
110111
Return the number of fields in an instance of a USMArrayDeviceModel.
111112
"""
112-
return _get_flattened_member_count(self)
113+
return get_flattened_member_count(self)
113114

114115

115116
class USMArrayHostModel(StructModel):
@@ -143,7 +144,7 @@ def __init__(self, dmm, fe_type):
143144
@property
144145
def flattened_field_count(self):
145146
"""Return the number of fields in an instance of a USMArrayHostModel."""
146-
return _get_flattened_member_count(self)
147+
return get_flattened_member_count(self)
147148

148149

149150
class SyclQueueModel(StructModel):
@@ -223,7 +224,7 @@ def __init__(self, dmm, fe_type):
223224
@property
224225
def flattened_field_count(self):
225226
"""Return the number of fields in an instance of a RangeModel."""
226-
return _get_flattened_member_count(self)
227+
return get_flattened_member_count(self)
227228

228229

229230
class NdRangeModel(StructModel):
@@ -246,7 +247,7 @@ def __init__(self, dmm, fe_type):
246247
@property
247248
def flattened_field_count(self):
248249
"""Return the number of fields in an instance of a NdRangeModel."""
249-
return _get_flattened_member_count(self)
250+
return get_flattened_member_count(self)
250251

251252

252253
def _init_data_model_manager() -> datamodel.DataModelManager:
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from numba.core import cgutils
6+
from numba.core.types import Type, UniTuple, intp
7+
from numba.extending import NativeValue, unbox
8+
from numba.np import numpy_support
9+
10+
from numba_dpex.core.types import USMNdArray
11+
from numba_dpex.utils import address_space as AddressSpace
12+
13+
14+
class DpctlMDLocalAccessorType(Type):
15+
"""numba-dpex internal type to represent a dpctl SyclInterface type
16+
`MDLocalAccessorTy`.
17+
"""
18+
19+
def __init__(self):
20+
super().__init__(name="DpctlMDLocalAccessor")
21+
22+
23+
class LocalAccessorType(USMNdArray):
24+
"""numba-dpex internal type to represent a Python object of
25+
:class:`numba_dpex.experimental.kernel_iface.LocalAccessor`.
26+
"""
27+
28+
def __init__(self, ndim, dtype):
29+
try:
30+
if isinstance(dtype, Type):
31+
parsed_dtype = dtype
32+
else:
33+
parsed_dtype = numpy_support.from_dtype(dtype)
34+
except NotImplementedError as exc:
35+
raise ValueError(f"Unsupported array dtype: {dtype}") from exc
36+
37+
type_name = (
38+
f"LocalAccessor(dtype={parsed_dtype}, ndim={ndim}, "
39+
f"address_space={AddressSpace.LOCAL})"
40+
)
41+
42+
super().__init__(
43+
ndim=ndim,
44+
layout="C",
45+
dtype=parsed_dtype,
46+
addrspace=AddressSpace.LOCAL,
47+
name=type_name,
48+
)
49+
50+
def cast_python_value(self, args):
51+
"""The helper function is not overloaded and using it on the
52+
LocalAccessorType throws a NotImplementedError.
53+
"""
54+
raise NotImplementedError
55+
56+
57+
@unbox(LocalAccessorType)
58+
def unbox_local_accessor(typ, obj, c): # pylint: disable=unused-argument
59+
"""Unboxes a Python LocalAccessor PyObject* into a numba-dpex internal
60+
representation.
61+
62+
A LocalAccessor object is represented internally in numba-dpex with the
63+
same data model as a numpy.ndarray. It is done as a LocalAccessor object
64+
serves only as a placeholder type when passed to ``call_kernel`` and the
65+
data buffer should never be accessed inside a host-side compiled function
66+
such as ``call_kernel``.
67+
68+
When a LocalAccessor object is passed as an argument to a kernel function
69+
it uses the USMArrayDeviceModel. Doing so allows numba-dpex to correctly
70+
generate the kernel signature passing in a pointer in the local address
71+
space.
72+
"""
73+
shape = c.pyapi.object_getattr_string(obj, "_shape")
74+
local_accessor = cgutils.create_struct_proxy(typ)(c.context, c.builder)
75+
76+
ty_unituple = UniTuple(intp, typ.ndim)
77+
ll_shape = c.unbox(ty_unituple, shape)
78+
local_accessor.shape = ll_shape.value
79+
80+
return NativeValue(
81+
c.builder.load(local_accessor._getpointer()),
82+
is_error=ll_shape.is_error,
83+
cleanup=ll_shape.cleanup,
84+
)

numba_dpex/core/utils/kernel_flattened_args_builder.py

Lines changed: 132 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,21 @@
77
object.
88
"""
99

10+
from functools import reduce
11+
from math import ceil
1012
from typing import NamedTuple
1113

14+
import dpctl
1215
from llvmlite import ir as llvmir
13-
from numba.core import types
16+
from numba.core import cgutils, types
1417
from numba.core.cpu import CPUContext
1518

1619
from numba_dpex import utils
1720
from numba_dpex.core.types import USMNdArray
21+
from numba_dpex.core.types.kernel_api.local_accessor import (
22+
DpctlMDLocalAccessorType,
23+
LocalAccessorType,
24+
)
1825
from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum
1926

2027

@@ -70,8 +77,14 @@ def add_argument(
7077
arg_type,
7178
arg_packed_llvm_val,
7279
):
73-
"""Add kernel argument that need to be flatten."""
74-
if isinstance(arg_type, USMNdArray):
80+
"""Add flattened representation of a kernel argument."""
81+
if isinstance(arg_type, LocalAccessorType):
82+
self._kernel_arg_list.extend(
83+
self._build_local_accessor_arg(
84+
arg_type, llvm_val=arg_packed_llvm_val
85+
)
86+
)
87+
elif isinstance(arg_type, USMNdArray):
7588
self._kernel_arg_list.extend(
7689
self._build_array_arg(
7790
arg_type, llvm_array_val=arg_packed_llvm_val
@@ -213,6 +226,121 @@ def _store_val_into_struct(self, struct_ref, index, val):
213226
),
214227
)
215228

229+
def _build_local_accessor_metadata_arg(
230+
self, llvm_val, arg_type: LocalAccessorType, data_attr_ty
231+
):
232+
"""Handles the special case of building the kernel argument for the data
233+
attribute of a kernel_api.LocalAccessor object.
234+
235+
A kernel_api.LocalAccessor conceptually represents a device-only memory
236+
allocation. The mock kernel_api.LocalAccessor uses a numpy.ndarray to
237+
represent the data allocation. The numpy.ndarray cannot be passed to the
238+
kernel and is ignored when building the kernel argument. Instead, a
239+
struct is allocated to store the metadata about the size of the device
240+
memory allocation and a reference to the struct is passed to the
241+
DPCTLQueue_Submit call. The DPCTLQueue_Submit then constructs a
242+
sycl::local_accessor object using the metadata and passes the
243+
sycl::local_accessor as the kernel argument, letting the DPC++ runtime
244+
handle proper device memory allocation.
245+
"""
246+
247+
ndim = arg_type.ndim
248+
249+
md_proxy = cgutils.create_struct_proxy(DpctlMDLocalAccessorType())(
250+
self._context,
251+
self._builder,
252+
)
253+
la_proxy = cgutils.create_struct_proxy(arg_type)(
254+
self._context, self._builder, value=self._builder.load(llvm_val)
255+
)
256+
257+
md_proxy.ndim = self._context.get_constant(types.int64, ndim)
258+
md_proxy.dpctl_type_id = numba_type_to_dpctl_typenum(
259+
self._context, data_attr_ty.dtype
260+
)
261+
for i, val in enumerate(
262+
cgutils.unpack_tuple(self._builder, la_proxy.shape)
263+
):
264+
setattr(md_proxy, f"dim{i}", val)
265+
266+
return self._build_arg(
267+
llvm_val=md_proxy._getpointer(),
268+
numba_type=LocalAccessorType(
269+
ndim, dpctl.tensor.dtype(data_attr_ty.dtype.name)
270+
),
271+
)
272+
273+
def _build_local_accessor_arg(self, arg_type: LocalAccessorType, llvm_val):
274+
"""Creates a list of kernel LLVM Values for an unpacked USMNdArray
275+
kernel argument from the local accessor.
276+
277+
Method generates UsmNdArray fields from local accessor type and value.
278+
"""
279+
# TODO: move extra values build on device side of codegen.
280+
ndim = arg_type.ndim
281+
la_proxy = cgutils.create_struct_proxy(arg_type)(
282+
self._context, self._builder, value=self._builder.load(llvm_val)
283+
)
284+
shape = cgutils.unpack_tuple(self._builder, la_proxy.shape)
285+
ll_size = reduce(self._builder.mul, shape)
286+
287+
size_ptr = cgutils.alloca_once_value(self._builder, ll_size)
288+
itemsize = self._context.get_constant(
289+
types.intp, ceil(arg_type.dtype.bitwidth / types.byte.bitwidth)
290+
)
291+
itemsize_ptr = cgutils.alloca_once_value(self._builder, itemsize)
292+
293+
kernel_arg_list = []
294+
295+
kernel_dm = self._kernel_dmm.lookup(arg_type)
296+
297+
kernel_arg_list.extend(
298+
self._build_arg(
299+
llvm_val=size_ptr,
300+
numba_type=kernel_dm.get_member_fe_type("nitems"),
301+
)
302+
)
303+
304+
# Argument itemsize
305+
kernel_arg_list.extend(
306+
self._build_arg(
307+
llvm_val=itemsize_ptr,
308+
numba_type=kernel_dm.get_member_fe_type("itemsize"),
309+
)
310+
)
311+
312+
# Argument data
313+
data_attr_ty = kernel_dm.get_member_fe_type("data")
314+
315+
kernel_arg_list.extend(
316+
self._build_local_accessor_metadata_arg(
317+
llvm_val=llvm_val,
318+
arg_type=arg_type,
319+
data_attr_ty=data_attr_ty,
320+
)
321+
)
322+
323+
# Arguments for shape
324+
for val in shape:
325+
shape_ptr = cgutils.alloca_once_value(self._builder, val)
326+
kernel_arg_list.extend(
327+
self._build_arg(
328+
llvm_val=shape_ptr,
329+
numba_type=types.int64,
330+
)
331+
)
332+
333+
# Arguments for strides
334+
for i in range(ndim):
335+
kernel_arg_list.extend(
336+
self._build_arg(
337+
llvm_val=itemsize_ptr,
338+
numba_type=types.int64,
339+
)
340+
)
341+
342+
return kernel_arg_list
343+
216344
def _build_array_arg(self, arg_type, llvm_array_val):
217345
"""Creates a list of LLVM Values for an unpacked USMNdArray kernel
218346
argument.
@@ -240,6 +368,7 @@ def _build_array_arg(self, arg_type, llvm_array_val):
240368
# Argument data
241369
data_attr_pos = host_data_model.get_field_position("data")
242370
data_attr_ty = kernel_data_model.get_member_fe_type("data")
371+
243372
kernel_arg_list.extend(
244373
self._build_collections_attr_arg(
245374
llvm_val=llvm_array_val,

numba_dpex/core/utils/kernel_launcher.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from numba_dpex.core.exceptions import UnreachableError
2222
from numba_dpex.core.runtime.context import DpexRTContext
2323
from numba_dpex.core.types import USMNdArray
24+
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType
2425
from numba_dpex.core.types.kernel_api.ranges import NdRangeType, RangeType
2526
from numba_dpex.core.utils.kernel_flattened_args_builder import (
2627
KernelFlattenedArgsBuilder,
@@ -675,7 +676,9 @@ def get_queue_from_llvm_values(
675676
the queue from the first USMNdArray argument can be extracted.
676677
"""
677678
for arg_num, argty in enumerate(ty_kernel_args):
678-
if isinstance(argty, USMNdArray):
679+
if isinstance(argty, USMNdArray) and not isinstance(
680+
argty, LocalAccessorType
681+
):
679682
llvm_val = ll_kernel_args[arg_num]
680683
datamodel = ctx.data_model_manager.lookup(argty)
681684
sycl_queue_attr_pos = datamodel.get_field_position("sycl_queue")

numba_dpex/dpctl_iface/_helpers.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from numba.core import types
66

77
from numba_dpex import dpctl_sem_version
8+
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType
89

910

1011
def numba_type_to_dpctl_typenum(context, ty):
@@ -34,6 +35,10 @@ def numba_type_to_dpctl_typenum(context, ty):
3435
return context.get_constant(
3536
types.int32, kargty.dpctl_void_ptr.value
3637
)
38+
elif isinstance(ty, LocalAccessorType):
39+
return context.get_constant(
40+
types.int32, kargty.dpctl_local_accessor.value
41+
)
3742
else:
3843
raise NotImplementedError
3944
else:
@@ -61,5 +66,9 @@ def numba_type_to_dpctl_typenum(context, ty):
6166
elif ty == types.voidptr or isinstance(ty, types.CPointer):
6267
# DPCTL_VOID_PTR
6368
return context.get_constant(types.int32, 15)
69+
elif isinstance(ty, LocalAccessorType):
70+
raise NotImplementedError(
71+
"LocalAccessor args for kernels requires dpctl 0.17 or greater."
72+
)
6473
else:
6574
raise NotImplementedError

numba_dpex/experimental/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from numba_dpex.core.boxing import *
1313
from numba_dpex.kernel_api_impl.spirv.dispatcher import SPIRVKernelDispatcher
1414

15+
from . import typeof
1516
from ._kernel_dpcpp_spirv_overloads import (
1617
_atomic_fence_overloads,
1718
_atomic_ref_overloads,

0 commit comments

Comments
 (0)