@@ -225,7 +225,7 @@ struct LLVMFunc{F,tt}
225
225
entry:: String
226
226
end
227
227
228
- function Base. getproperty (f:: LLVMFunc{F, tt} , sym:: Symbol ) where {F, tt}
228
+ function Base. getproperty (f:: LLVMFunc{F,tt} , sym:: Symbol ) where {F,tt}
229
229
if sym === :fun
230
230
f
231
231
else
235
235
236
236
# TODO in the future we may want to avoid doing a second cufunction compilation
237
237
# for computing the thread/block count (or potentially do it ourselves).
238
# Launch-configuration query for an `LLVMFunc`.
#
# The wrapped Julia function `f.f` is compiled through `CUDA.cufunction` for the
# argument types `tt.parameters[2:end]`, and the `fun` property of the result is
# handed to `CUDA.launch_configuration` together with the `shmem` and
# `max_threads` keywords, which are forwarded unchanged.
@noinline function CUDA.launch_configuration(
    f::LLVMFunc{F,tt}; shmem::Union{Integer,Base.Callable}=0, max_threads::Integer=0
) where {F,tt}
    # `Base.inferencebarrier` hides the callee from type inference.
    uninferred_cufunction = Base.inferencebarrier(CUDA.cufunction)
    kernel = uninferred_cufunction(f.f, Tuple{tt.parameters[2:end]...})
    return CUDA.launch_configuration(kernel.fun; shmem, max_threads)
end
241
247
242
248
const GPUCompiler = CUDA. GPUCompiler
@@ -282,7 +288,12 @@ function compile(job)
282
288
entry = GPUCompiler. JuliaContext () do ctx
283
289
mod, meta = GPUCompiler. compile (
284
290
# :llvm, job; optimize=false, cleanup=false, validate=false, libraries=true
285
- :llvm , job; optimize= false , cleanup= false , validate= false , libraries= false
291
+ :llvm ,
292
+ job;
293
+ optimize= false ,
294
+ cleanup= false ,
295
+ validate= false ,
296
+ libraries= false ,
286
297
# :llvm, job; optimize=false, cleanup=false, validate=true, libraries=false
287
298
# :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
288
299
)
@@ -357,19 +368,21 @@ function link(job, compiled)
357
368
end
358
369
359
370
"""
    to_bytes(x)

Return a freshly allocated `Vector{UInt8}` of length `sizeof(x)` containing a copy
of the bytes found at the address obtained from a `Ref` wrapping `x`.
"""
function to_bytes(x)
    sz = sizeof(x)
    ref = Ref(x)
    vec = Vector{UInt8}(undef, sz)
    # Keep both the source `Ref` and the destination buffer rooted while raw
    # pointers into them are in use.
    GC.@preserve ref vec begin
        src = Base.reinterpret(Ptr{UInt8}, Base.unsafe_convert(Ptr{Cvoid}, ref))
        # Bulk copy instead of loading one byte at a time.
        unsafe_copyto!(pointer(vec), src, sz)
    end
    return vec
end
382
+
383
+ function Reactant. make_tracer (
384
+ seen, @nospecialize (prev:: CuTracedArray ), @nospecialize (path), mode; kwargs...
385
+ )
373
386
x = Base. unsafe_pointer_to_objref (Base. reinterpret (Ptr{Cvoid}, prev. ptr)):: TracedRArray
374
387
Reactant. make_tracer (seen, x, path, mode; kwargs... )
375
388
return prev
@@ -388,7 +401,9 @@ function get_field_offset(T::Type, path)
388
401
findfirst (== (field), fieldnames (current_type))
389
402
end
390
403
if field_idx === nothing
391
- error (" Field $field not found in type $current_type , fieldnames=$(fieldnames (current_type)) T=$T path=$path " )
404
+ error (
405
+ " Field $field not found in type $current_type , fieldnames=$(fieldnames (current_type)) T=$T path=$path " ,
406
+ )
392
407
end
393
408
394
409
# Add the offset of this field
@@ -419,7 +434,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
419
434
rarrays = TracedRArray[]
420
435
421
436
fname = func. entry
422
-
437
+
423
438
wrapper_tys = MLIR. IR. Type[]
424
439
ctx = MLIR. IR. context ()
425
440
cullvm_ty = MLIR. IR. Type (MLIR. API. mlirLLVMPointerTypeGet (ctx, 1 ))
@@ -436,19 +451,23 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
436
451
end
437
452
push! (wrapper_tys, cullvm_ty)
438
453
end
439
-
454
+
440
455
sym_name = String (gensym (" call_$fname " ))
441
456
mod = MLIR. IR. mmodule ()
442
- CConv= MLIR. IR. Attribute (MLIR. API. mlirLLVMCConvAttrGet (ctx, MLIR. API. MlirLLVMCConvPTX_Kernel))
457
+ CConv = MLIR. IR. Attribute (
458
+ MLIR. API. mlirLLVMCConvAttrGet (ctx, MLIR. API. MlirLLVMCConvPTX_Kernel)
459
+ )
443
460
voidty = MLIR. IR. Type (MLIR. API. mlirLLVMVoidTypeGet (ctx))
444
- wrapftype = MLIR. IR. Type (MLIR. API. mlirLLVMFunctionTypeGet (voidty, length (wrapper_tys), wrapper_tys, false ))
461
+ wrapftype = MLIR. IR. Type (
462
+ MLIR. API. mlirLLVMFunctionTypeGet (voidty, length (wrapper_tys), wrapper_tys, false )
463
+ )
445
464
wrapfunc = MLIR. IR. block! (MLIR. IR. body (mod)) do
446
465
return MLIR. Dialects. llvm. func (;
447
466
sym_name,
448
467
sym_visibility= MLIR. IR. Attribute (" private" ),
449
468
function_type= wrapftype,
450
469
body= MLIR. IR. Region (),
451
- CConv
470
+ CConv,
452
471
)
453
472
end
454
473
wrapbody = MLIR. IR. Block (wrapper_tys, [MLIR. IR. Location () for _ in wrapper_tys])
@@ -459,11 +478,17 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
459
478
460
479
symtab = MLIR. IR. SymbolTable (MLIR. IR. Operation (mod))
461
480
gpufunc = MLIR. IR. lookup (symtab, fname)
462
- MLIR. IR. attr! (gpufunc, " CConv" , MLIR. IR. Attribute (MLIR. API. mlirLLVMCConvAttrGet (ctx, MLIR. API. MlirLLVMCConvC)))
463
- gpu_function_type = MLIR. IR. Type (Reactant. TracedUtils. get_attribute_by_name (gpufunc, " function_type" ))
481
+ MLIR. IR. attr! (
482
+ gpufunc,
483
+ " CConv" ,
484
+ MLIR. IR. Attribute (MLIR. API. mlirLLVMCConvAttrGet (ctx, MLIR. API. MlirLLVMCConvC)),
485
+ )
486
+ gpu_function_type = MLIR. IR. Type (
487
+ Reactant. TracedUtils. get_attribute_by_name (gpufunc, " function_type" )
488
+ )
464
489
465
490
trueidx = 1
466
- allocs = Union{Tuple{MLIR. IR. Value, MLIR. IR. Type}, Nothing}[]
491
+ allocs = Union{Tuple{MLIR. IR. Value,MLIR. IR. Type},Nothing}[]
467
492
468
493
llvmptr = MLIR. IR. Type (MLIR. API. mlirLLVMPointerTypeGet (ctx, 0 ))
469
494
i8 = MLIR. IR. Type (UInt8)
@@ -476,18 +501,34 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
476
501
477
502
# TODO check for only integer and explicitly non cutraced types
478
503
MLIR. IR. block! (wrapbody) do
479
- argty = MLIR. IR. Type (MLIR. API. mlirLLVMFunctionTypeGetInput (gpu_function_type, trueidx- 1 ))
504
+ argty = MLIR. IR. Type (
505
+ MLIR. API. mlirLLVMFunctionTypeGetInput (gpu_function_type, trueidx - 1 )
506
+ )
480
507
trueidx += 1
481
- c1 = MLIR. IR. result (MLIR. Dialects. llvm. mlir_constant (; res= MLIR. IR. Type (Int64), value= MLIR. IR. Attribute (1 )), 1 )
482
- alloc = MLIR. IR. result (MLIR. Dialects. llvm. alloca (c1; elem_type= MLIR. IR. Attribute (argty), res= llvmptr), 1 )
508
+ c1 = MLIR. IR. result (
509
+ MLIR. Dialects. llvm. mlir_constant (;
510
+ res= MLIR. IR. Type (Int64), value= MLIR. IR. Attribute (1 )
511
+ ),
512
+ 1 ,
513
+ )
514
+ alloc = MLIR. IR. result (
515
+ MLIR. Dialects. llvm. alloca (
516
+ c1; elem_type= MLIR. IR. Attribute (argty), res= llvmptr
517
+ ),
518
+ 1 ,
519
+ )
483
520
push! (allocs, (alloc, argty))
484
521
485
522
sz = sizeof (a)
486
523
array_ty = MLIR. IR. Type (MLIR. API. mlirLLVMArrayTypeGet (MLIR. IR. Type (Int8), sz))
487
- cdata = MLIR. IR. result (MLIR. Dialects. llvm. mlir_constant (; res= array_ty, value= MLIR. IR. DenseElementsAttribute (to_bytes (a))), 1 )
524
+ cdata = MLIR. IR. result (
525
+ MLIR. Dialects. llvm. mlir_constant (;
526
+ res= array_ty, value= MLIR. IR. DenseElementsAttribute (to_bytes (a))
527
+ ),
528
+ 1 ,
529
+ )
488
530
MLIR. Dialects. llvm. store (cdata, alloc)
489
531
end
490
-
491
532
end
492
533
493
534
argidx = 1
@@ -499,21 +540,30 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
499
540
if p[1 ] != = kernelargsym
500
541
continue
501
542
end
502
-
543
+
503
544
arg = arg. mlir_data
504
545
arg = Reactant. TracedUtils. transpose_val (arg)
505
546
push! (restys, MLIR. IR. type (arg))
506
547
push! (mlir_args, arg)
507
-
548
+
508
549
# Get the allocation corresponding to which arg we're doing
509
550
alloc = allocs[p[2 ]][1 ]
510
551
511
552
# we need to now compute the offset in bytes of the path
512
553
julia_arg = allargs[p[2 ]]
513
-
554
+
514
555
offset = get_field_offset (typeof (julia_arg), p[3 : end ])
515
556
MLIR. IR. block! (wrapbody) do
516
- ptr = MLIR. IR. result (MLIR. Dialects. llvm. getelementptr (alloc, MLIR. IR. Value[], res= llvmptr, elem_type= i8, rawConstantIndices= MLIR. IR. Attribute ([Int32 (offset)])), 1 )
557
+ ptr = MLIR. IR. result (
558
+ MLIR. Dialects. llvm. getelementptr (
559
+ alloc,
560
+ MLIR. IR. Value[];
561
+ res= llvmptr,
562
+ elem_type= i8,
563
+ rawConstantIndices= MLIR. IR. Attribute ([Int32 (offset)]),
564
+ ),
565
+ 1 ,
566
+ )
517
567
MLIR. Dialects. llvm. store (MLIR. IR. argument (wrapbody, argidx), ptr)
518
568
end
519
569
@@ -530,11 +580,11 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
530
580
),
531
581
),
532
582
)
533
-
583
+
534
584
argidx += 1
535
585
end
536
586
end
537
-
587
+
538
588
MLIR. IR. block! (wrapbody) do
539
589
for arg in allocs
540
590
if arg === nothing
@@ -544,7 +594,12 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
544
594
argres = MLIR. IR. result (MLIR. Dialects. llvm. load (alloc; res= argty), 1 )
545
595
push! (wrapargs, argres)
546
596
end
547
- MLIR. Dialects. llvm. call (wrapargs, MLIR. IR. Value[]; callee= MLIR. IR. FlatSymbolRefAttribute (Base. String (fname)), op_bundle_sizes= MLIR. IR. Attribute (Int32[]))
597
+ MLIR. Dialects. llvm. call (
598
+ wrapargs,
599
+ MLIR. IR. Value[];
600
+ callee= MLIR. IR. FlatSymbolRefAttribute (Base. String (fname)),
601
+ op_bundle_sizes= MLIR. IR. Attribute (Int32[]),
602
+ )
548
603
MLIR. Dialects. llvm. return_ (nothing )
549
604
end
550
605
@@ -565,7 +620,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
565
620
mlir_args;
566
621
result_0= restys,
567
622
fn= MLIR. IR. FlatSymbolRefAttribute (sym_name),
568
- output_operand_aliases= MLIR. IR. Attribute (output_operand_aliases)
623
+ output_operand_aliases= MLIR. IR. Attribute (output_operand_aliases),
569
624
)
570
625
571
626
argidx = 1
@@ -574,7 +629,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
574
629
continue
575
630
end
576
631
arg. mlir_data = Reactant. TracedUtils. transpose_val (MLIR. IR. result (call, argidx))
577
- argidx+= 1
632
+ argidx += 1
578
633
end
579
634
end
580
635
0 commit comments