11
11
12
12
import dpctl
13
13
from llvmlite import ir as llvmir
14
- from numba .core import types
14
+ from numba .core import cgutils , types
15
15
from numba .core .cpu import CPUContext
16
16
17
17
from numba_dpex import utils
18
18
from numba_dpex .core .types import USMNdArray
19
- from numba_dpex .core .types .kernel_api .local_accessor import LocalAccessorType
19
+ from numba_dpex .core .types .kernel_api .local_accessor import (
20
+ DpctlMDLocalAccessorType ,
21
+ LocalAccessorType ,
22
+ )
20
23
from numba_dpex .dpctl_iface ._helpers import numba_type_to_dpctl_typenum
21
24
22
25
@@ -120,40 +123,6 @@ def print_kernel_arg_list(self) -> None:
120
123
for karg in args_list :
121
124
print (f" { karg .llvm_val } of typeid { karg .typeid } " )
122
125
123
- def _allocate_local_accessor_metadata_struct (self ):
124
- """Allocates a struct into the current function to store the metadata
125
- that should be passed to libsyclinterface to allocate a
126
- sycl::local_accessor object. The constructor of the sycl::local_accessor
127
- class is: local_accessor<Ty, Ndim>(range<Ndims> r).
128
-
129
- For this reason, the struct is allocated as:
130
-
131
- LOCAL_ACCESSOR_MDSTRUCT_TYPE = llvmir.LiteralStructType(
132
- [
133
- llvmir.IntType(64), # Ndim (0..3]
134
- llvmir.IntType(32), # typeid
135
- llvmir.IntType(64), # Dim0 extent
136
- llvmir.IntType(64), # Dim1 extent or NULL
137
- llvmir.IntType(64), # Dim2 extent or NULL
138
- ]
139
- )
140
- """
141
- local_accessor_mdstruct_type = llvmir .LiteralStructType (
142
- [
143
- llvmir .IntType (64 ),
144
- llvmir .IntType (32 ),
145
- llvmir .IntType (64 ),
146
- llvmir .IntType (64 ),
147
- llvmir .IntType (64 ),
148
- ]
149
- )
150
-
151
- struct_ref = None
152
- with self ._builder .goto_entry_block ():
153
- struct_ref = self ._builder .alloca (typ = local_accessor_mdstruct_type )
154
-
155
- return struct_ref
156
-
157
126
def _build_arg (self , llvm_val , numba_type ):
158
127
"""Returns a KernelArg to be passed to a DPCTLQueue_Submit call.
159
128
@@ -250,7 +219,7 @@ def _store_val_into_struct(self, struct_ref, index, val):
250
219
)
251
220
252
221
def _build_local_accessor_metadata_arg (
253
- self , llvm_val , arg_type , data_attr_ty
222
+ self , llvm_val , arg_type : LocalAccessorType , data_attr_ty
254
223
):
255
224
"""Handles the special case of building the kernel argument for the data
256
225
attribute of a kernel_api.LocalAccessor object.
@@ -267,91 +236,27 @@ def _build_local_accessor_metadata_arg(
267
236
handle proper device memory allocation.
268
237
"""
269
238
270
- kernel_data_model = self ._kernel_dmm .lookup (arg_type )
271
- host_data_model = self ._context .data_model_manager .lookup (arg_type )
272
- shape_member = kernel_data_model .get_member_fe_type ("shape" )
273
- shape_member_pos = host_data_model .get_field_position ("shape" )
274
- ndim = shape_member .count
275
-
276
- mdstruct_ref = self ._allocate_local_accessor_metadata_struct ()
239
+ ndim = arg_type .ndim
277
240
278
- # Store the number of dimensions in the local accessor
279
- self ._store_val_into_struct (
280
- mdstruct_ref ,
281
- index = 0 ,
282
- val = self ._context .get_constant (types .int64 , ndim ),
283
- )
284
- # Get the underlying dtype of the data (a CPointer) attribute of a
285
- # local_accessor object
286
- self ._store_val_into_struct (
287
- mdstruct_ref ,
288
- index = 1 ,
289
- val = numba_type_to_dpctl_typenum (self ._context , data_attr_ty .dtype ),
290
- )
291
- # Extract and store the shape values from array into mdstruct
292
- shape_attr = self ._builder .gep (
293
- llvm_val ,
294
- [
295
- self ._context .get_constant (types .int32 , 0 ),
296
- self ._context .get_constant (types .int32 , shape_member_pos ),
297
- ],
298
- )
299
- # Store the extent of the 1st dimension of the local accessor
300
- dim0_shape_ext = self ._builder .gep (
301
- shape_attr ,
302
- [
303
- self ._context .get_constant (types .int32 , 0 ),
304
- self ._context .get_constant (types .int32 , 0 ),
305
- ],
241
+ md_proxy = cgutils .create_struct_proxy (DpctlMDLocalAccessorType ())(
242
+ self ._context ,
243
+ self ._builder ,
306
244
)
307
- self ._store_val_into_struct (
308
- mdstruct_ref ,
309
- index = 2 ,
310
- val = self ._builder .load (dim0_shape_ext ),
245
+ la_proxy = cgutils .create_struct_proxy (arg_type )(
246
+ self ._context , self ._builder , value = self ._builder .load (llvm_val )
311
247
)
312
248
313
- if ndim == 2 :
314
- dim1_shape_ext = self ._builder .gep (
315
- shape_attr ,
316
- [
317
- self ._context .get_constant (types .int32 , 0 ),
318
- self ._context .get_constant (types .int32 , 1 ),
319
- ],
320
- )
321
- self ._store_val_into_struct (
322
- mdstruct_ref ,
323
- index = 3 ,
324
- val = self ._builder .load (dim1_shape_ext ),
325
- )
326
- else :
327
- self ._store_val_into_struct (
328
- mdstruct_ref ,
329
- index = 3 ,
330
- val = self ._context .get_constant (types .int64 , 1 ),
331
- )
332
-
333
- if ndim == 3 :
334
- dim2_shape_ext = self ._builder .gep (
335
- shape_attr ,
336
- [
337
- self ._context .get_constant (types .int32 , 0 ),
338
- self ._context .get_constant (types .int32 , 2 ),
339
- ],
340
- )
341
- self ._store_val_into_struct (
342
- mdstruct_ref ,
343
- index = 4 ,
344
- val = self ._builder .load (dim2_shape_ext ),
345
- )
346
- else :
347
- self ._store_val_into_struct (
348
- mdstruct_ref ,
349
- index = 4 ,
350
- val = self ._context .get_constant (types .int64 , 1 ),
351
- )
249
+ md_proxy .ndim = self ._context .get_constant (types .int64 , ndim )
250
+ md_proxy .dpctl_type_id = numba_type_to_dpctl_typenum (
251
+ self ._context , data_attr_ty .dtype
252
+ )
253
+ for i , val in enumerate (
254
+ cgutils .unpack_tuple (self ._builder , la_proxy .shape )
255
+ ):
256
+ setattr (md_proxy , f"dim{ i } " , val )
352
257
353
258
return self ._build_arg (
354
- llvm_val = mdstruct_ref ,
259
+ llvm_val = md_proxy . _getpointer () ,
355
260
numba_type = LocalAccessorType (
356
261
ndim , dpctl .tensor .dtype (data_attr_ty .dtype .name )
357
262
),
0 commit comments