Skip to content

Commit fe8d1d2

Browse files
committed
Add DpctlMDLocalAccessorType
1 parent 3b93e8c commit fe8d1d2

File tree

3 files changed

+57
-117
lines changed

3 files changed

+57
-117
lines changed

numba_dpex/core/types/kernel_api/local_accessor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@
1010
from numba_dpex.utils import address_space as AddressSpace
1111

1212

13+
class DpctlMDLocalAccessorType(Type):
14+
"""numba-dpex internal type to represent a dpctl SyclInterface type
15+
`MDLocalAccessorTy`.
16+
"""
17+
18+
def __init__(self):
19+
super().__init__(name="DpctlMDLocalAccessor")
20+
21+
1322
class LocalAccessorType(USMNdArray):
1423
"""numba-dpex internal type to represent a Python object of
1524
:class:`numba_dpex.experimental.kernel_iface.LocalAccessor`.

numba_dpex/core/utils/kernel_flattened_args_builder.py

Lines changed: 21 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111

1212
import dpctl
1313
from llvmlite import ir as llvmir
14-
from numba.core import types
14+
from numba.core import cgutils, types
1515
from numba.core.cpu import CPUContext
1616

1717
from numba_dpex import utils
1818
from numba_dpex.core.types import USMNdArray
19-
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType
19+
from numba_dpex.core.types.kernel_api.local_accessor import (
20+
DpctlMDLocalAccessorType,
21+
LocalAccessorType,
22+
)
2023
from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum
2124

2225

@@ -120,40 +123,6 @@ def print_kernel_arg_list(self) -> None:
120123
for karg in args_list:
121124
print(f" {karg.llvm_val} of typeid {karg.typeid}")
122125

123-
def _allocate_local_accessor_metadata_struct(self):
124-
"""Allocates a struct into the current function to store the metadata
125-
that should be passed to libsyclinterface to allocate a
126-
sycl::local_accessor object. The constructor of the sycl::local_accessor
127-
class is: local_accessor<Ty, Ndim>(range<Ndims> r).
128-
129-
For this reason, the struct is allocated as:
130-
131-
LOCAL_ACCESSOR_MDSTRUCT_TYPE = llvmir.LiteralStructType(
132-
[
133-
llvmir.IntType(64), # Ndim (0..3]
134-
llvmir.IntType(32), # typeid
135-
llvmir.IntType(64), # Dim0 extent
136-
llvmir.IntType(64), # Dim1 extent or NULL
137-
llvmir.IntType(64), # Dim2 extent or NULL
138-
]
139-
)
140-
"""
141-
local_accessor_mdstruct_type = llvmir.LiteralStructType(
142-
[
143-
llvmir.IntType(64),
144-
llvmir.IntType(32),
145-
llvmir.IntType(64),
146-
llvmir.IntType(64),
147-
llvmir.IntType(64),
148-
]
149-
)
150-
151-
struct_ref = None
152-
with self._builder.goto_entry_block():
153-
struct_ref = self._builder.alloca(typ=local_accessor_mdstruct_type)
154-
155-
return struct_ref
156-
157126
def _build_arg(self, llvm_val, numba_type):
158127
"""Returns a KernelArg to be passed to a DPCTLQueue_Submit call.
159128
@@ -250,7 +219,7 @@ def _store_val_into_struct(self, struct_ref, index, val):
250219
)
251220

252221
def _build_local_accessor_metadata_arg(
253-
self, llvm_val, arg_type, data_attr_ty
222+
self, llvm_val, arg_type: LocalAccessorType, data_attr_ty
254223
):
255224
"""Handles the special case of building the kernel argument for the data
256225
attribute of a kernel_api.LocalAccessor object.
@@ -267,91 +236,27 @@ def _build_local_accessor_metadata_arg(
267236
handle proper device memory allocation.
268237
"""
269238

270-
kernel_data_model = self._kernel_dmm.lookup(arg_type)
271-
host_data_model = self._context.data_model_manager.lookup(arg_type)
272-
shape_member = kernel_data_model.get_member_fe_type("shape")
273-
shape_member_pos = host_data_model.get_field_position("shape")
274-
ndim = shape_member.count
275-
276-
mdstruct_ref = self._allocate_local_accessor_metadata_struct()
239+
ndim = arg_type.ndim
277240

278-
# Store the number of dimensions in the local accessor
279-
self._store_val_into_struct(
280-
mdstruct_ref,
281-
index=0,
282-
val=self._context.get_constant(types.int64, ndim),
283-
)
284-
# Get the underlying dtype of the data (a CPointer) attribute of a
285-
# local_accessor object
286-
self._store_val_into_struct(
287-
mdstruct_ref,
288-
index=1,
289-
val=numba_type_to_dpctl_typenum(self._context, data_attr_ty.dtype),
290-
)
291-
# Extract and store the shape values from array into mdstruct
292-
shape_attr = self._builder.gep(
293-
llvm_val,
294-
[
295-
self._context.get_constant(types.int32, 0),
296-
self._context.get_constant(types.int32, shape_member_pos),
297-
],
298-
)
299-
# Store the extent of the 1st dimension of the local accessor
300-
dim0_shape_ext = self._builder.gep(
301-
shape_attr,
302-
[
303-
self._context.get_constant(types.int32, 0),
304-
self._context.get_constant(types.int32, 0),
305-
],
241+
md_proxy = cgutils.create_struct_proxy(DpctlMDLocalAccessorType())(
242+
self._context,
243+
self._builder,
306244
)
307-
self._store_val_into_struct(
308-
mdstruct_ref,
309-
index=2,
310-
val=self._builder.load(dim0_shape_ext),
245+
la_proxy = cgutils.create_struct_proxy(arg_type)(
246+
self._context, self._builder, value=self._builder.load(llvm_val)
311247
)
312248

313-
if ndim == 2:
314-
dim1_shape_ext = self._builder.gep(
315-
shape_attr,
316-
[
317-
self._context.get_constant(types.int32, 0),
318-
self._context.get_constant(types.int32, 1),
319-
],
320-
)
321-
self._store_val_into_struct(
322-
mdstruct_ref,
323-
index=3,
324-
val=self._builder.load(dim1_shape_ext),
325-
)
326-
else:
327-
self._store_val_into_struct(
328-
mdstruct_ref,
329-
index=3,
330-
val=self._context.get_constant(types.int64, 1),
331-
)
332-
333-
if ndim == 3:
334-
dim2_shape_ext = self._builder.gep(
335-
shape_attr,
336-
[
337-
self._context.get_constant(types.int32, 0),
338-
self._context.get_constant(types.int32, 2),
339-
],
340-
)
341-
self._store_val_into_struct(
342-
mdstruct_ref,
343-
index=4,
344-
val=self._builder.load(dim2_shape_ext),
345-
)
346-
else:
347-
self._store_val_into_struct(
348-
mdstruct_ref,
349-
index=4,
350-
val=self._context.get_constant(types.int64, 1),
351-
)
249+
md_proxy.ndim = self._context.get_constant(types.int64, ndim)
250+
md_proxy.dpctl_type_id = numba_type_to_dpctl_typenum(
251+
self._context, data_attr_ty.dtype
252+
)
253+
for i, val in enumerate(
254+
cgutils.unpack_tuple(self._builder, la_proxy.shape)
255+
):
256+
setattr(md_proxy, f"dim{i}", val)
352257

353258
return self._build_arg(
354-
llvm_val=mdstruct_ref,
259+
llvm_val=md_proxy._getpointer(),
355260
numba_type=LocalAccessorType(
356261
ndim, dpctl.tensor.dtype(data_attr_ty.dtype.name)
357262
),

numba_dpex/experimental/models.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
)
2020

2121
from ..core.types.kernel_api.atomic_ref import AtomicRefType
22-
from ..core.types.kernel_api.local_accessor import LocalAccessorType
22+
from ..core.types.kernel_api.local_accessor import (
23+
DpctlMDLocalAccessorType,
24+
LocalAccessorType,
25+
)
2326
from .types import KernelDispatcherType
2427

2528

@@ -45,6 +48,26 @@ def __init__(self, dmm, fe_type):
4548
super().__init__(dmm, fe_type, members)
4649

4750

51+
class DpctlMDLocalAccessorModel(StructModel):
52+
"""Data model to represent DpctlMDLocalAccessorType.
53+
54+
Must be the same structure as
55+
dpctl/syclinterface/dpctl_sycl_queue_interface.h::MDLocalAccessor.
56+
57+
Structure intended to be used only on host side of the kernel call.
58+
"""
59+
60+
def __init__(self, dmm, fe_type):
61+
members = [
62+
("ndim", types.size_t),
63+
("dpctl_type_id", types.int32),
64+
("dim0", types.size_t),
65+
("dim1", types.size_t),
66+
("dim2", types.size_t),
67+
]
68+
super().__init__(dmm, fe_type, members)
69+
70+
4871
def _init_exp_data_model_manager() -> DataModelManager:
4972
"""Initializes a DpexExpKernelTarget-specific data model manager.
5073
@@ -89,6 +112,9 @@ def _init_exp_data_model_manager() -> DataModelManager:
89112
# Register the NdItemType type
90113
register_model(NdItemType)(EmptyStructModel)
91114

115+
# Register the MDLocalAccessorType type
116+
register_model(DpctlMDLocalAccessorType)(DpctlMDLocalAccessorModel)
117+
92118
# The LocalAccessorType is registered with the EmptyStructModel in the default
93119
# data manager so that its attributes are not accessible inside dpjit.
94120
register_model(LocalAccessorType)(ArrayModel)

0 commit comments

Comments
 (0)