4
4
5
5
from collections .abc import Iterable
6
6
7
+ from llvmlite import ir as llvmir
8
+ from numba .core import cgutils , errors , types
9
+ from numba .core .datamodel import default_manager
10
+ from numba .extending import intrinsic , overload
11
+
7
12
8
13
class Range (tuple ):
9
14
"""A data structure to encapsulate a single kernel launch parameter.
@@ -18,6 +23,8 @@ class Range(tuple):
18
23
the behavior of `sycl::range`.
19
24
"""
20
25
26
+ UNDEFINED_DIMENSION = - 1
27
+
21
28
def __new__ (cls , dim0 , dim1 = None , dim2 = None ):
22
29
"""Constructs a 1, 2, or 3 dimensional range.
23
30
@@ -74,6 +81,50 @@ def size(self):
74
81
else :
75
82
return self [0 ]
76
83
84
+ @property
85
+ def ndim (self ) -> int :
86
+ """Returns the rank of a Range object.
87
+
88
+ Returns:
89
+ int: Number of dimensions in the Range object
90
+ """
91
+ return len (self )
92
+
93
+ @property
94
+ def dim0 (self ) -> int :
95
+ """Return the extent of the first dimension for the Range object.
96
+
97
+ Returns:
98
+ int: Extent of first dimension for the Range object
99
+ """
100
+ return self [0 ]
101
+
102
+ @property
103
+ def dim1 (self ) -> int :
104
+ """Return the extent of the second dimension for the Range object.
105
+
106
+ Returns:
107
+ int: Extent of second dimension for the Range object or -1 for 1D
108
+ Range
109
+ """
110
+ try :
111
+ return self [1 ]
112
+ except IndexError :
113
+ return Range .UNDEFINED_DIMENSION
114
+
115
+ @property
116
+ def dim2 (self ) -> int :
117
+ """Return the extent of the second dimension for the Range object.
118
+
119
+ Returns:
120
+ int: Extent of second dimension for the Range object or -1 for 1D or
121
+ 2D Range
122
+ """
123
+ try :
124
+ return self [2 ]
125
+ except IndexError :
126
+ return Range .UNDEFINED_DIMENSION
127
+
77
128
78
129
class NdRange :
79
130
"""A class to encapsulate all kernel launch parameters.
@@ -169,3 +220,170 @@ def __repr__(self):
169
220
str: str representation for NdRange class.
170
221
"""
171
222
return self .__str__ ()
223
+
224
+ def __eq__ (self , other ):
225
+ if isinstance (other , NdRange ):
226
+ return (
227
+ self .global_range == other .global_range
228
+ and self .local_range == other .local_range
229
+ )
230
+ else :
231
+ return False
232
+
233
+
234
+ @intrinsic
235
+ def _intrin_range_alloc (typingctx , ty_dim0 , ty_dim1 , ty_dim2 , ty_range ):
236
+ ty_retty = ty_range .instance_type
237
+ sig = ty_retty (
238
+ ty_dim0 ,
239
+ ty_dim1 ,
240
+ ty_dim2 ,
241
+ ty_range ,
242
+ )
243
+
244
+ def codegen (context , builder , sig , args ):
245
+ typ = sig .return_type
246
+ dim0 , dim1 , dim2 , _ = args
247
+ range_struct = cgutils .create_struct_proxy (typ )(context , builder )
248
+ range_struct .dim0 = dim0
249
+
250
+ if not isinstance (sig .args [1 ], types .NoneType ):
251
+ range_struct .dim1 = dim1
252
+ else :
253
+ range_struct .dim1 = llvmir .Constant (
254
+ llvmir .types .IntType (64 ), Range .UNDEFINED_DIMENSION
255
+ )
256
+
257
+ if not isinstance (sig .args [2 ], types .NoneType ):
258
+ range_struct .dim2 = dim2
259
+ else :
260
+ range_struct .dim2 = llvmir .Constant (
261
+ llvmir .types .IntType (64 ), Range .UNDEFINED_DIMENSION
262
+ )
263
+
264
+ range_struct .ndim = llvmir .Constant (llvmir .types .IntType (64 ), typ .ndim )
265
+
266
+ return range_struct ._getvalue ()
267
+
268
+ return sig , codegen
269
+
270
+
271
+ @intrinsic
272
+ def _intrin_ndrange_alloc (
273
+ typingctx , ty_global_range , ty_local_range , ty_ndrange
274
+ ):
275
+ ty_retty = ty_ndrange .instance_type
276
+ sig = ty_retty (
277
+ ty_global_range ,
278
+ ty_local_range ,
279
+ ty_ndrange ,
280
+ )
281
+ range_datamodel = default_manager .lookup (ty_global_range )
282
+
283
+ def codegen (context , builder , sig , args ):
284
+ typ = sig .return_type
285
+
286
+ global_range , local_range , _ = args
287
+ ndrange_struct = cgutils .create_struct_proxy (typ )(context , builder )
288
+ ndrange_struct .ndim = llvmir .Constant (
289
+ llvmir .types .IntType (64 ), typ .ndim
290
+ )
291
+ ndrange_struct .gdim0 = builder .extract_value (
292
+ global_range ,
293
+ range_datamodel .get_field_position ("dim0" ),
294
+ )
295
+ ndrange_struct .gdim1 = builder .extract_value (
296
+ global_range ,
297
+ range_datamodel .get_field_position ("dim1" ),
298
+ )
299
+ ndrange_struct .gdim2 = builder .extract_value (
300
+ global_range ,
301
+ range_datamodel .get_field_position ("dim2" ),
302
+ )
303
+ ndrange_struct .ldim0 = builder .extract_value (
304
+ local_range ,
305
+ range_datamodel .get_field_position ("dim0" ),
306
+ )
307
+ ndrange_struct .ldim1 = builder .extract_value (
308
+ local_range ,
309
+ range_datamodel .get_field_position ("dim1" ),
310
+ )
311
+ ndrange_struct .ldim2 = builder .extract_value (
312
+ local_range ,
313
+ range_datamodel .get_field_position ("dim2" ),
314
+ )
315
+
316
+ return ndrange_struct ._getvalue ()
317
+
318
+ return sig , codegen
319
+
320
+
321
+ @overload (Range )
322
+ def _ol_range_init (dim0 , dim1 = None , dim2 = None ):
323
+ """Numba overload of the Range constructor to make it usable inside an
324
+ njit and dpjit decorated function.
325
+
326
+ """
327
+ from numba_dpex .core .types import RangeType
328
+
329
+ ndims = 1
330
+ ty_optional_dims = (dim1 , dim2 )
331
+
332
+ # A Range should at least have the 0th dimension populated
333
+ if not isinstance (dim0 , types .Integer ):
334
+ raise errors .TypingError (
335
+ "Expected a Range's dimension should to be an Integer value, but "
336
+ "encountered " + dim0 .name
337
+ )
338
+
339
+ for ty_dim in ty_optional_dims :
340
+ if isinstance (ty_dim , types .Integer ):
341
+ ndims += 1
342
+ elif ty_dim is not None :
343
+ raise errors .TypingError (
344
+ "Expected a Range's dimension to be an Integer value, "
345
+ f"but { type (ty_dim )} was provided."
346
+ )
347
+
348
+ ret_ty = RangeType (ndims )
349
+
350
+ def impl (dim0 , dim1 = None , dim2 = None ):
351
+ return _intrin_range_alloc (dim0 , dim1 , dim2 , ret_ty )
352
+
353
+ return impl
354
+
355
+
356
+ @overload (NdRange )
357
+ def _ol_ndrange_init (global_range , local_range ):
358
+ """Numba overload of the NdRange constructor to make it usable inside an
359
+ njit and dpjit decorated function.
360
+
361
+ """
362
+ from numba_dpex .core .exceptions import UnmatchedNumberOfRangeDimsError
363
+ from numba_dpex .core .types import NdRangeType , RangeType
364
+
365
+ if not isinstance (global_range , RangeType ):
366
+ raise errors .TypingError (
367
+ "Only global range values specified as a Range are "
368
+ "supported inside dpjit"
369
+ )
370
+
371
+ if not isinstance (local_range , RangeType ):
372
+ raise errors .TypingError (
373
+ "Only local range values specified as a Range are "
374
+ "supported inside dpjit"
375
+ )
376
+
377
+ if not global_range .ndim == local_range .ndim :
378
+ raise UnmatchedNumberOfRangeDimsError (
379
+ kernel_name = "" ,
380
+ global_ndims = global_range .ndim ,
381
+ local_ndims = local_range .ndim ,
382
+ )
383
+
384
+ ret_ty = NdRangeType (global_range .ndim )
385
+
386
+ def impl (global_range , local_range ):
387
+ return _intrin_ndrange_alloc (global_range , local_range , ret_ty )
388
+
389
+ return impl
0 commit comments