@@ -281,9 +281,13 @@ function compile(job)
281
281
# TODO : on 1.9, this actually creates a context. cache those.
282
282
entry = GPUCompiler. JuliaContext () do ctx
283
283
mod, meta = GPUCompiler. compile (
284
+ # :llvm, job; optimize=false, cleanup=false, validate=false, libraries=true
284
285
:llvm , job; optimize= false , cleanup= false , validate= false , libraries= false
286
+ # :llvm, job; optimize=false, cleanup=false, validate=true, libraries=false
287
+ # :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
285
288
)
286
289
290
+ GPUCompiler. link_library! (mod, GPUCompiler. load_runtime (job))
287
291
entryname = LLVM. name (meta. entry)
288
292
289
293
GPUCompiler. optimize_module! (job, mod)
@@ -319,6 +323,8 @@ function compile(job)
319
323
end
320
324
end
321
325
326
+ # GPUCompiler.check_ir(job, mod)
327
+
322
328
LLVM. strip_debuginfo! (mod)
323
329
modstr = string (mod)
324
330
@@ -363,6 +369,38 @@ function to_bytes(x)
363
369
end
364
370
end
365
371
372
+ function Reactant. make_tracer (seen, @nospecialize (prev:: CuTracedArray ), @nospecialize (path), mode; kwargs... )
373
+ x = Base. unsafe_pointer_to_objref (Base. reinterpret (Ptr{Cvoid}, prev. ptr)):: TracedRArray
374
+ Reactant. make_tracer (seen, x, path, mode; kwargs... )
375
+ return prev
376
+ end
377
+
378
+ function get_field_offset (T:: Type , path)
379
+ offset = 0
380
+ current_type = T
381
+
382
+ for field in path
383
+ # Get the field index
384
+ field_idx = if field isa Integer
385
+ field
386
+ else
387
+ @assert field isa Symbol
388
+ findfirst (== (field), fieldnames (current_type))
389
+ end
390
+ if field_idx === nothing
391
+ error (" Field $field not found in type $current_type , fieldnames=$(fieldnames (current_type)) T=$T path=$path " )
392
+ end
393
+
394
+ # Add the offset of this field
395
+ offset += fieldoffset (current_type, field_idx)
396
+
397
+ # Update current_type to the field's type for next iteration
398
+ current_type = fieldtype (current_type, field_idx)
399
+ end
400
+
401
+ return offset
402
+ end
403
+
366
404
Reactant. @reactant_overlay @noinline function (func:: LLVMFunc{F,tt} )(
367
405
args... ;
368
406
convert= Val (false ),
@@ -384,20 +422,19 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
384
422
385
423
wrapper_tys = MLIR. IR. Type[]
386
424
ctx = MLIR. IR. context ()
387
- cullvm_ty = MLIR. IR. Type (MLIR. API. mlirLLVMArrayTypeGet (MLIR. API. mlirLLVMPointerTypeGet (ctx, 1 ), 1 ))
388
- for (i, a) in Tuple{Int, Any}[(0 , func. f), enumerate (args)... ]
389
- if sizeof (a) == 0
425
+ cullvm_ty = MLIR. IR. Type (MLIR. API. mlirLLVMPointerTypeGet (ctx, 1 ))
426
+
427
+ # linearize kernel arguments
428
+ seen = Reactant. OrderedIdDict ()
429
+ prev = Any[func. f, args... ]
430
+ kernelargsym = gensym (" kernelarg" )
431
+ Reactant. make_tracer (seen, prev, (kernelargsym,), Reactant. TracedTrack)
432
+ wrapper_tys = MLIR. IR. Type[]
433
+ for arg in values (seen)
434
+ if ! (arg isa TracedRArray || arg isa TracedRNumber)
390
435
continue
391
436
end
392
- if a isa CuTracedArray
393
- a =
394
- Base. unsafe_pointer_to_objref (Base. reinterpret (Ptr{Cvoid}, a. ptr)):: TracedRArray
395
- end
396
- if a isa TracedRArray || a isa TracedRNumber
397
- push! (wrapper_tys, cullvm_ty)
398
- continue
399
- end
400
- # Per below we assume we can inline all other types directly in
437
+ push! (wrapper_tys, cullvm_ty)
401
438
end
402
439
403
440
sym_name = String (gensym (" call_$fname " ))
@@ -426,20 +463,60 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
426
463
gpu_function_type = MLIR. IR. Type (Reactant. TracedUtils. get_attribute_by_name (gpufunc, " function_type" ))
427
464
428
465
trueidx = 1
429
- for (i, a) in Tuple{Int, Any}[(0 , func. f), enumerate (args)... ]
466
+ allocs = Union{Tuple{MLIR. IR. Value, MLIR. IR. Type}, Nothing}[]
467
+
468
+ llvmptr = MLIR. IR. Type (MLIR. API. mlirLLVMPointerTypeGet (ctx, 0 ))
469
+ i8 = MLIR. IR. Type (UInt8)
470
+ allargs = [func. f, args... ]
471
+ for a in allargs
430
472
if sizeof (a) == 0
473
+ push! (allocs, nothing )
431
474
continue
432
475
end
433
- if a isa CuTracedArray
434
- a =
435
- Base. unsafe_pointer_to_objref (Base. reinterpret (Ptr{Cvoid}, a. ptr)):: TracedRArray
476
+
477
+ # TODO check for only integer and explicitly non cutraced types
478
+ MLIR. IR. block! (wrapbody) do
479
+ argty = MLIR. IR. Type (MLIR. API. mlirLLVMFunctionTypeGetInput (gpu_function_type, trueidx- 1 ))
480
+ trueidx += 1
481
+ c1 = MLIR. IR. result (MLIR. Dialects. llvm. mlir_constant (; res= MLIR. IR. Type (Int64), value= MLIR. IR. Attribute (1 )), 1 )
482
+ alloc = MLIR. IR. result (MLIR. Dialects. llvm. alloca (c1; elem_type= MLIR. IR. Attribute (argty), res= llvmptr), 1 )
483
+ push! (allocs, (alloc, argty))
484
+
485
+ sz = sizeof (a)
486
+ array_ty = MLIR. IR. Type (MLIR. API. mlirLLVMArrayTypeGet (MLIR. IR. Type (Int8), sz))
487
+ cdata = MLIR. IR. result (MLIR. Dialects. llvm. mlir_constant (; res= array_ty, value= MLIR. IR. DenseElementsAttribute (to_bytes (a))), 1 )
488
+ MLIR. Dialects. llvm. store (cdata, alloc)
436
489
end
437
- if a isa TracedRArray || a isa TracedRNumber
438
- push! (rarrays, a)
439
- arg = a. mlir_data
490
+
491
+ end
492
+
493
+ argidx = 1
494
+ for arg in values (seen)
495
+ if ! (arg isa TracedRArray || arg isa TracedRNumber)
496
+ continue
497
+ end
498
+ for p in Reactant. TracedUtils. get_paths (arg)
499
+ if p[1 ] != = kernelargsym
500
+ continue
501
+ end
502
+
503
+ arg = arg. mlir_data
440
504
arg = Reactant. TracedUtils. transpose_val (arg)
441
505
push! (restys, MLIR. IR. type (arg))
442
506
push! (mlir_args, arg)
507
+
508
+ # Get the allocation corresponding to which arg we're doing
509
+ alloc = allocs[p[2 ]][1 ]
510
+
511
+ # we need to now compute the offset in bytes of the path
512
+ julia_arg = allargs[p[2 ]]
513
+
514
+ offset = get_field_offset (typeof (julia_arg), p[3 : end ])
515
+ MLIR. IR. block! (wrapbody) do
516
+ ptr = MLIR. IR. result (MLIR. Dialects. llvm. getelementptr (alloc, MLIR. IR. Value[], res= llvmptr, elem_type= i8, rawConstantIndices= MLIR. IR. Attribute ([Int32 (offset)])), 1 )
517
+ MLIR. Dialects. llvm. store (MLIR. IR. argument (wrapbody, argidx), ptr)
518
+ end
519
+
443
520
push! (
444
521
aliases,
445
522
MLIR. IR. Attribute (
@@ -453,30 +530,20 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
453
530
),
454
531
),
455
532
)
456
- push! (wrapargs, MLIR . IR . argument (wrapbody, argidx))
533
+
457
534
argidx += 1
458
- trueidx += 1
459
- continue
460
- end
461
-
462
- # TODO check for only integer and explicitly non cutraced types
463
- @show " Warning: using fallback for kernel argument type conversion for argument of type $(Core. Typeof (a)) , if this contains a CuTracedArray this will segfault"
464
- MLIR. IR. block! (wrapbody) do
465
- argty = MLIR. IR. Type (MLIR. API. mlirLLVMFunctionTypeGetInput (gpu_function_type, trueidx- 1 ))
466
- trueidx += 1
467
- c1 = MLIR. IR. result (MLIR. Dialects. llvm. mlir_constant (; res= MLIR. IR. Type (Int64), value= MLIR. IR. Attribute (1 )), 1 )
468
- alloc = MLIR. IR. result (MLIR. Dialects. llvm. alloca (c1; elem_type= MLIR. IR. Attribute (argty), res= MLIR. IR. Type (MLIR. API. mlirLLVMPointerTypeGet (ctx, 0 ))), 1 )
469
-
470
- sz = sizeof (a)
471
- array_ty = MLIR. IR. Type (MLIR. API. mlirLLVMArrayTypeGet (MLIR. IR. Type (Int8), sz))
472
- cdata = MLIR. IR. result (MLIR. Dialects. llvm. mlir_constant (; res= array_ty, value= MLIR. IR. DenseElementsAttribute (to_bytes (a))), 1 )
473
- MLIR. Dialects. llvm. store (cdata, alloc)
474
- argres = MLIR. IR. result (MLIR. Dialects. llvm. load (alloc; res= argty), 1 )
475
- push! (wrapargs, argres)
476
535
end
477
536
end
478
537
479
538
MLIR. IR. block! (wrapbody) do
539
+ for arg in allocs
540
+ if arg === nothing
541
+ continue
542
+ end
543
+ alloc, argty = arg
544
+ argres = MLIR. IR. result (MLIR. Dialects. llvm. load (alloc; res= argty), 1 )
545
+ push! (wrapargs, argres)
546
+ end
480
547
MLIR. Dialects. llvm. call (wrapargs, MLIR. IR. Value[]; callee= MLIR. IR. FlatSymbolRefAttribute (Base. String (fname)), op_bundle_sizes= MLIR. IR. Attribute (Int32[]))
481
548
MLIR. Dialects. llvm. return_ (nothing )
482
549
end
@@ -500,8 +567,14 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
500
567
fn= MLIR. IR. FlatSymbolRefAttribute (sym_name),
501
568
output_operand_aliases= MLIR. IR. Attribute (output_operand_aliases)
502
569
)
503
- for (i, res) in enumerate (rarrays)
504
- res. mlir_data = Reactant. TracedUtils. transpose_val (MLIR. IR. result (call, i))
570
+
571
+ argidx = 1
572
+ for arg in values (seen)
573
+ if ! (arg isa TracedRArray || arg isa TracedRNumber)
574
+ continue
575
+ end
576
+ arg. mlir_data = Reactant. TracedUtils. transpose_val (MLIR. IR. result (call, argidx))
577
+ argidx+= 1
505
578
end
506
579
end
507
580
@@ -546,6 +619,12 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
546
619
return Core. Typeof (res)(f, res. entry)
547
620
end
548
621
622
+ function Reactant. traced_type (
623
+ :: Type{A} , seen:: ST , :: Val{mode} , track_numbers
624
+ ) where {A<: CuTracedArray ,ST,mode}
625
+ return A
626
+ end
627
+
549
628
function Reactant. traced_type (
550
629
:: Type{A} , seen:: ST , :: Val{mode} , track_numbers
551
630
) where {T,N,A<: CUDA.CuArray{T,N} ,ST,mode}
0 commit comments