Skip to content

Commit c086ae6

Browse files
committed
Update LocalAccessor host model that contains only shape
1 parent fe8d1d2 commit c086ae6

File tree

5 files changed

+125
-34
lines changed

5 files changed

+125
-34
lines changed

numba_dpex/core/types/kernel_api/local_accessor.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from numba.core.pythonapi import unbox
6-
from numba.core.types import Array, Type
5+
from numba.core import cgutils
6+
from numba.core.types import Type, UniTuple, intp
7+
from numba.extending import NativeValue, unbox
78
from numba.np import numpy_support
89

910
from numba_dpex.core.types import USMNdArray
@@ -69,7 +70,15 @@ def unbox_local_accessor(typ, obj, c): # pylint: disable=unused-argument
6970
generate the kernel signature passing in a pointer in the local address
7071
space.
7172
"""
72-
73-
nparrobj = c.pyapi.object_getattr_string(obj, "_data")
74-
nparrtype = Array(typ.dtype, typ.ndim, typ.layout, readonly=False)
75-
return c.unbox(nparrtype, nparrobj)
73+
shape = c.pyapi.object_getattr_string(obj, "_shape")
74+
local_accessor = cgutils.create_struct_proxy(typ)(c.context, c.builder)
75+
76+
ty_unituple = UniTuple(intp, typ.ndim)
77+
ll_shape = c.unbox(ty_unituple, shape)
78+
local_accessor.shape = ll_shape.value
79+
80+
return NativeValue(
81+
c.builder.load(local_accessor._getpointer()),
82+
is_error=ll_shape.is_error,
83+
cleanup=ll_shape.cleanup,
84+
)

numba_dpex/core/utils/kernel_flattened_args_builder.py

Lines changed: 86 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
object.
88
"""
99

10+
from functools import reduce
11+
from math import ceil
1012
from typing import NamedTuple
1113

1214
import dpctl
@@ -76,7 +78,13 @@ def add_argument(
7678
arg_packed_llvm_val,
7779
):
7880
"""Add flattened representation of a kernel argument."""
79-
if isinstance(arg_type, USMNdArray):
81+
if isinstance(arg_type, LocalAccessorType):
82+
self._kernel_arg_list.extend(
83+
self._build_local_accessor_arg(
84+
arg_type, llvm_val=arg_packed_llvm_val
85+
)
86+
)
87+
elif isinstance(arg_type, USMNdArray):
8088
self._kernel_arg_list.extend(
8189
self._build_array_arg(
8290
arg_type, llvm_array_val=arg_packed_llvm_val
@@ -262,6 +270,77 @@ def _build_local_accessor_metadata_arg(
262270
),
263271
)
264272

273+
def _build_local_accessor_arg(self, arg_type: LocalAccessorType, llvm_val):
274+
"""Creates a list of kernel LLVM Values for an unpacked USMNdArray
275+
kernel argument from the local accessor.
276+
277+
Method generates UsmNdArray fields from local accessor type and value.
278+
"""
279+
# TODO: move extra values build on device side of codegen.
280+
ndim = arg_type.ndim
281+
la_proxy = cgutils.create_struct_proxy(arg_type)(
282+
self._context, self._builder, value=self._builder.load(llvm_val)
283+
)
284+
shape = cgutils.unpack_tuple(self._builder, la_proxy.shape)
285+
ll_size = reduce(self._builder.mul, shape)
286+
287+
size_ptr = cgutils.alloca_once_value(self._builder, ll_size)
288+
itemsize = self._context.get_constant(
289+
types.intp, ceil(arg_type.dtype.bitwidth / types.byte.bitwidth)
290+
)
291+
itemsize_ptr = cgutils.alloca_once_value(self._builder, itemsize)
292+
293+
kernel_arg_list = []
294+
295+
kernel_dm = self._kernel_dmm.lookup(arg_type)
296+
297+
kernel_arg_list.extend(
298+
self._build_arg(
299+
llvm_val=size_ptr,
300+
numba_type=kernel_dm.get_member_fe_type("nitems"),
301+
)
302+
)
303+
304+
# Argument itemsize
305+
kernel_arg_list.extend(
306+
self._build_arg(
307+
llvm_val=itemsize_ptr,
308+
numba_type=kernel_dm.get_member_fe_type("itemsize"),
309+
)
310+
)
311+
312+
# Argument data
313+
data_attr_ty = kernel_dm.get_member_fe_type("data")
314+
315+
kernel_arg_list.extend(
316+
self._build_local_accessor_metadata_arg(
317+
llvm_val=llvm_val,
318+
arg_type=arg_type,
319+
data_attr_ty=data_attr_ty,
320+
)
321+
)
322+
323+
# Arguments for shape
324+
for val in shape:
325+
shape_ptr = cgutils.alloca_once_value(self._builder, val)
326+
kernel_arg_list.extend(
327+
self._build_arg(
328+
llvm_val=shape_ptr,
329+
numba_type=types.int64,
330+
)
331+
)
332+
333+
# Arguments for strides
334+
for i in range(ndim):
335+
kernel_arg_list.extend(
336+
self._build_arg(
337+
llvm_val=itemsize_ptr,
338+
numba_type=types.int64,
339+
)
340+
)
341+
342+
return kernel_arg_list
343+
265344
def _build_array_arg(self, arg_type, llvm_array_val):
266345
"""Creates a list of LLVM Values for an unpacked USMNdArray kernel
267346
argument.
@@ -290,22 +369,13 @@ def _build_array_arg(self, arg_type, llvm_array_val):
290369
data_attr_pos = host_data_model.get_field_position("data")
291370
data_attr_ty = kernel_data_model.get_member_fe_type("data")
292371

293-
if isinstance(arg_type, LocalAccessorType):
294-
kernel_arg_list.extend(
295-
self._build_local_accessor_metadata_arg(
296-
llvm_val=llvm_array_val,
297-
arg_type=arg_type,
298-
data_attr_ty=data_attr_ty,
299-
)
300-
)
301-
else:
302-
kernel_arg_list.extend(
303-
self._build_collections_attr_arg(
304-
llvm_val=llvm_array_val,
305-
attr_index=data_attr_pos,
306-
attr_type=data_attr_ty,
307-
)
372+
kernel_arg_list.extend(
373+
self._build_collections_attr_arg(
374+
llvm_val=llvm_array_val,
375+
attr_index=data_attr_pos,
376+
attr_type=data_attr_ty,
308377
)
378+
)
309379
# Arguments for shape
310380
kernel_arg_list.extend(
311381
self._build_unituple_member_arg(

numba_dpex/experimental/models.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88

99
from numba.core import types
1010
from numba.core.datamodel import DataModelManager, models
11-
from numba.core.datamodel.models import ArrayModel, StructModel
11+
from numba.core.datamodel.models import StructModel
1212
from numba.core.extending import register_model
1313

1414
import numba_dpex.core.datamodel.models as dpex_core_models
15+
from numba_dpex.core.datamodel.models import USMArrayDeviceModel
1516
from numba_dpex.core.types.kernel_api.index_space_ids import (
1617
GroupType,
1718
ItemType,
@@ -68,6 +69,17 @@ def __init__(self, dmm, fe_type):
6869
super().__init__(dmm, fe_type, members)
6970

7071

72+
class LocalAccessorModel(StructModel):
73+
"""Data model for the LocalAccessor type when used in a host-only function."""
74+
75+
def __init__(self, dmm, fe_type):
76+
ndim = fe_type.ndim
77+
members = [
78+
("shape", types.UniTuple(types.intp, ndim)),
79+
]
80+
super().__init__(dmm, fe_type, members)
81+
82+
7183
def _init_exp_data_model_manager() -> DataModelManager:
7284
"""Initializes a DpexExpKernelTarget-specific data model manager.
7385
@@ -84,7 +96,8 @@ def _init_exp_data_model_manager() -> DataModelManager:
8496
# Register the types and data model in the DpexExpTargetContext
8597
dmm.register(AtomicRefType, AtomicRefModel)
8698

87-
dmm.register(LocalAccessorType, dpex_core_models.USMArrayDeviceModel)
99+
# Register the LocalAccessorType type
100+
dmm.register(LocalAccessorType, USMArrayDeviceModel)
88101

89102
# Register the GroupType type
90103
dmm.register(GroupType, EmptyStructModel)
@@ -115,6 +128,5 @@ def _init_exp_data_model_manager() -> DataModelManager:
115128
# Register the MDLocalAccessorType type
116129
register_model(DpctlMDLocalAccessorType)(DpctlMDLocalAccessorModel)
117130

118-
# The LocalAccessorType is registered with the EmptyStructModel in the default
119-
# data manager so that its attributes are not accessible inside dpjit.
120-
register_model(LocalAccessorType)(ArrayModel)
131+
# Register the LocalAccessorType type
132+
register_model(LocalAccessorType)(LocalAccessorModel)

numba_dpex/experimental/typeof.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,4 @@ def typeof_local_accessor(val: LocalAccessor, c) -> LocalAccessorType:
9797
Returns: LocalAccessorType object corresponding to the LocalAccessor object.
9898
"""
9999
# pylint: disable=protected-access
100-
return LocalAccessorType(ndim=val._data.ndim, dtype=val._data.dtype)
100+
return LocalAccessorType(ndim=len(val._shape), dtype=val._dtype)

numba_dpex/kernel_api/local_accessor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,17 @@ def __init__(self, shape, dtype) -> None:
3535
if hasattr(shape, "tolist"):
3636
fn = getattr(shape, "tolist")
3737
if callable(fn):
38-
self._shape = shape.tolist()
38+
self._shape = tuple(shape.tolist())
3939
else:
4040
try:
41-
self._shape = [
42-
shape,
43-
]
41+
self._shape = (shape,)
4442
except Exception as e:
4543
raise TypeError(
4644
"Argument shape must a non-negative integer, "
4745
"or a list/tuple of such integers."
4846
) from e
4947
else:
50-
self._shape = list(shape)
48+
self._shape = tuple(shape)
5149

5250
# Make sure shape is made up a supported types
5351
if not self._verify_positive_integral_list(self._shape):
@@ -118,7 +116,9 @@ class is designed in a way to not have any data container backing up the
118116
"""
119117

120118
def __init__(self, local_accessor: LocalAccessor):
121-
self._data = local_accessor._data
119+
self._data = numpy.empty(
120+
local_accessor._shape, dtype=local_accessor._dtype
121+
)
122122

123123
def __getitem__(self, idx_obj):
124124
"""Returns the value stored at the position represented by idx_obj in

0 commit comments

Comments
 (0)