Skip to content

Commit cd5332f

Browse files
author
Diptorup Deb
authored
Merge pull request #1148 from IntelPython/feature/range_ndrange_type_support
Adds Range and NdRange as supported types in numba_dpex.dpjit.
2 parents 4334448 + 7e44902 commit cd5332f

File tree

9 files changed

+685
-0
lines changed

9 files changed

+685
-0
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
DpctlSyclEvent,
1515
DpctlSyclQueue,
1616
DpnpNdArray,
17+
NdRangeType,
18+
RangeType,
1719
USMNdArray,
1820
)
1921

@@ -195,6 +197,39 @@ def __init__(self, dmm, fe_type):
195197
super(SyclEventModel, self).__init__(dmm, fe_type, members)
196198

197199

200+
class RangeModel(StructModel):
201+
"""The native data model for a
202+
numba_dpex.core.kernel_interface.indexers.Range PyObject.
203+
"""
204+
205+
def __init__(self, dmm, fe_type):
206+
members = [
207+
("ndim", types.int64),
208+
("dim0", types.int64),
209+
("dim1", types.int64),
210+
("dim2", types.int64),
211+
]
212+
super(RangeModel, self).__init__(dmm, fe_type, members)
213+
214+
215+
class NdRangeModel(StructModel):
216+
"""The native data model for a
217+
numba_dpex.core.kernel_interface.indexers.NdRange PyObject.
218+
"""
219+
220+
def __init__(self, dmm, fe_type):
221+
members = [
222+
("ndim", types.int64),
223+
("gdim0", types.int64),
224+
("gdim1", types.int64),
225+
("gdim2", types.int64),
226+
("ldim0", types.int64),
227+
("ldim1", types.int64),
228+
("ldim2", types.int64),
229+
]
230+
super(NdRangeModel, self).__init__(dmm, fe_type, members)
231+
232+
198233
def _init_data_model_manager() -> datamodel.DataModelManager:
199234
"""Initializes a DpexKernelTarget-specific data model manager.
200235
@@ -249,3 +284,8 @@ def _init_data_model_manager() -> datamodel.DataModelManager:
249284

250285
# Register the DpctlSyclEvent type
251286
register_model(DpctlSyclEvent)(SyclEventModel)
287+
# Register the RangeType type
288+
register_model(RangeType)(RangeModel)
289+
290+
# Register the NdRangeType type
291+
register_model(NdRangeType)(NdRangeModel)

numba_dpex/core/kernel_interface/indexers.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44

55
from collections.abc import Iterable
66

7+
from llvmlite import ir as llvmir
8+
from numba.core import cgutils, errors, types
9+
from numba.core.datamodel import default_manager
10+
from numba.extending import intrinsic, overload
11+
712

813
class Range(tuple):
914
"""A data structure to encapsulate a single kernel launch parameter.
@@ -18,6 +23,8 @@ class Range(tuple):
1823
the behavior of `sycl::range`.
1924
"""
2025

26+
UNDEFINED_DIMENSION = -1
27+
2128
def __new__(cls, dim0, dim1=None, dim2=None):
2229
"""Constructs a 1, 2, or 3 dimensional range.
2330
@@ -74,6 +81,50 @@ def size(self):
7481
else:
7582
return self[0]
7683

84+
@property
85+
def ndim(self) -> int:
86+
"""Returns the rank of a Range object.
87+
88+
Returns:
89+
int: Number of dimensions in the Range object
90+
"""
91+
return len(self)
92+
93+
@property
94+
def dim0(self) -> int:
95+
"""Return the extent of the first dimension for the Range object.
96+
97+
Returns:
98+
int: Extent of first dimension for the Range object
99+
"""
100+
return self[0]
101+
102+
@property
103+
def dim1(self) -> int:
104+
"""Return the extent of the second dimension for the Range object.
105+
106+
Returns:
107+
int: Extent of second dimension for the Range object or -1 for 1D
108+
Range
109+
"""
110+
try:
111+
return self[1]
112+
except IndexError:
113+
return Range.UNDEFINED_DIMENSION
114+
115+
@property
116+
def dim2(self) -> int:
117+
"""Return the extent of the second dimension for the Range object.
118+
119+
Returns:
120+
int: Extent of second dimension for the Range object or -1 for 1D or
121+
2D Range
122+
"""
123+
try:
124+
return self[2]
125+
except IndexError:
126+
return Range.UNDEFINED_DIMENSION
127+
77128

78129
class NdRange:
79130
"""A class to encapsulate all kernel launch parameters.
@@ -169,3 +220,170 @@ def __repr__(self):
169220
str: str representation for NdRange class.
170221
"""
171222
return self.__str__()
223+
224+
def __eq__(self, other):
225+
if isinstance(other, NdRange):
226+
return (
227+
self.global_range == other.global_range
228+
and self.local_range == other.local_range
229+
)
230+
else:
231+
return False
232+
233+
234+
@intrinsic
235+
def _intrin_range_alloc(typingctx, ty_dim0, ty_dim1, ty_dim2, ty_range):
236+
ty_retty = ty_range.instance_type
237+
sig = ty_retty(
238+
ty_dim0,
239+
ty_dim1,
240+
ty_dim2,
241+
ty_range,
242+
)
243+
244+
def codegen(context, builder, sig, args):
245+
typ = sig.return_type
246+
dim0, dim1, dim2, _ = args
247+
range_struct = cgutils.create_struct_proxy(typ)(context, builder)
248+
range_struct.dim0 = dim0
249+
250+
if not isinstance(sig.args[1], types.NoneType):
251+
range_struct.dim1 = dim1
252+
else:
253+
range_struct.dim1 = llvmir.Constant(
254+
llvmir.types.IntType(64), Range.UNDEFINED_DIMENSION
255+
)
256+
257+
if not isinstance(sig.args[2], types.NoneType):
258+
range_struct.dim2 = dim2
259+
else:
260+
range_struct.dim2 = llvmir.Constant(
261+
llvmir.types.IntType(64), Range.UNDEFINED_DIMENSION
262+
)
263+
264+
range_struct.ndim = llvmir.Constant(llvmir.types.IntType(64), typ.ndim)
265+
266+
return range_struct._getvalue()
267+
268+
return sig, codegen
269+
270+
271+
@intrinsic
272+
def _intrin_ndrange_alloc(
273+
typingctx, ty_global_range, ty_local_range, ty_ndrange
274+
):
275+
ty_retty = ty_ndrange.instance_type
276+
sig = ty_retty(
277+
ty_global_range,
278+
ty_local_range,
279+
ty_ndrange,
280+
)
281+
range_datamodel = default_manager.lookup(ty_global_range)
282+
283+
def codegen(context, builder, sig, args):
284+
typ = sig.return_type
285+
286+
global_range, local_range, _ = args
287+
ndrange_struct = cgutils.create_struct_proxy(typ)(context, builder)
288+
ndrange_struct.ndim = llvmir.Constant(
289+
llvmir.types.IntType(64), typ.ndim
290+
)
291+
ndrange_struct.gdim0 = builder.extract_value(
292+
global_range,
293+
range_datamodel.get_field_position("dim0"),
294+
)
295+
ndrange_struct.gdim1 = builder.extract_value(
296+
global_range,
297+
range_datamodel.get_field_position("dim1"),
298+
)
299+
ndrange_struct.gdim2 = builder.extract_value(
300+
global_range,
301+
range_datamodel.get_field_position("dim2"),
302+
)
303+
ndrange_struct.ldim0 = builder.extract_value(
304+
local_range,
305+
range_datamodel.get_field_position("dim0"),
306+
)
307+
ndrange_struct.ldim1 = builder.extract_value(
308+
local_range,
309+
range_datamodel.get_field_position("dim1"),
310+
)
311+
ndrange_struct.ldim2 = builder.extract_value(
312+
local_range,
313+
range_datamodel.get_field_position("dim2"),
314+
)
315+
316+
return ndrange_struct._getvalue()
317+
318+
return sig, codegen
319+
320+
321+
@overload(Range)
322+
def _ol_range_init(dim0, dim1=None, dim2=None):
323+
"""Numba overload of the Range constructor to make it usable inside an
324+
njit and dpjit decorated function.
325+
326+
"""
327+
from numba_dpex.core.types import RangeType
328+
329+
ndims = 1
330+
ty_optional_dims = (dim1, dim2)
331+
332+
# A Range should at least have the 0th dimension populated
333+
if not isinstance(dim0, types.Integer):
334+
raise errors.TypingError(
335+
"Expected a Range's dimension should to be an Integer value, but "
336+
"encountered " + dim0.name
337+
)
338+
339+
for ty_dim in ty_optional_dims:
340+
if isinstance(ty_dim, types.Integer):
341+
ndims += 1
342+
elif ty_dim is not None:
343+
raise errors.TypingError(
344+
"Expected a Range's dimension to be an Integer value, "
345+
f"but {type(ty_dim)} was provided."
346+
)
347+
348+
ret_ty = RangeType(ndims)
349+
350+
def impl(dim0, dim1=None, dim2=None):
351+
return _intrin_range_alloc(dim0, dim1, dim2, ret_ty)
352+
353+
return impl
354+
355+
356+
@overload(NdRange)
357+
def _ol_ndrange_init(global_range, local_range):
358+
"""Numba overload of the NdRange constructor to make it usable inside an
359+
njit and dpjit decorated function.
360+
361+
"""
362+
from numba_dpex.core.exceptions import UnmatchedNumberOfRangeDimsError
363+
from numba_dpex.core.types import NdRangeType, RangeType
364+
365+
if not isinstance(global_range, RangeType):
366+
raise errors.TypingError(
367+
"Only global range values specified as a Range are "
368+
"supported inside dpjit"
369+
)
370+
371+
if not isinstance(local_range, RangeType):
372+
raise errors.TypingError(
373+
"Only local range values specified as a Range are "
374+
"supported inside dpjit"
375+
)
376+
377+
if not global_range.ndim == local_range.ndim:
378+
raise UnmatchedNumberOfRangeDimsError(
379+
kernel_name="",
380+
global_ndims=global_range.ndim,
381+
local_ndims=local_range.ndim,
382+
)
383+
384+
ret_ty = NdRangeType(global_range.ndim)
385+
386+
def impl(global_range, local_range):
387+
return _intrin_ndrange_alloc(global_range, local_range, ret_ty)
388+
389+
return impl

numba_dpex/core/types/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
uint64,
2727
void,
2828
)
29+
from .range_types import NdRangeType, RangeType
2930
from .usm_ndarray_type import USMNdArray
3031

3132
usm_ndarray = USMNdArray
@@ -35,6 +36,8 @@
3536
"DpctlSyclQueue",
3637
"DpctlSyclEvent",
3738
"DpnpNdArray",
39+
"RangeType",
40+
"NdRangeType",
3841
"USMNdArray",
3942
"none",
4043
"boolean",

0 commit comments

Comments
 (0)