Skip to content

Commit 992b080

Browse files
Format code (#509)
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
1 parent 95341ff commit 992b080

File tree

9 files changed

+194
-141
lines changed

9 files changed

+194
-141
lines changed

deps/ReactantExtra/make-bindings.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ for file in [
2828
"Gpu.jl",
2929
"Affine.jl",
3030
"TPU.jl",
31-
"Triton.jl"
31+
"Triton.jl",
3232
]
3333
build_file(joinpath(src_dir, "mlir", "Dialects", file))
3434
end

ext/ReactantCUDAExt.jl

Lines changed: 95 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ struct LLVMFunc{F,tt}
225225
entry::String
226226
end
227227

228-
function Base.getproperty(f::LLVMFunc{F, tt}, sym::Symbol) where {F, tt}
228+
function Base.getproperty(f::LLVMFunc{F,tt}, sym::Symbol) where {F,tt}
229229
if sym === :fun
230230
f
231231
else
@@ -235,8 +235,14 @@ end
235235

236236
# TODO in the future we may want to avoid doing a second cufunction compilation
237237
# for computing the thread/block count (or potentially do it ourselves).
238-
@noinline function CUDA.launch_configuration(f::LLVMFunc{F, tt}; shmem::Union{Integer, Base.Callable}=0, max_threads::Integer=0) where {F, tt}
239-
CUDA.launch_configuration(Base.inferencebarrier(CUDA.cufunction)(f.f, Tuple{tt.parameters[2:end]...}).fun; shmem, max_threads)
238+
@noinline function CUDA.launch_configuration(
239+
f::LLVMFunc{F,tt}; shmem::Union{Integer,Base.Callable}=0, max_threads::Integer=0
240+
) where {F,tt}
241+
return CUDA.launch_configuration(
242+
Base.inferencebarrier(CUDA.cufunction)(f.f, Tuple{tt.parameters[2:end]...}).fun;
243+
shmem,
244+
max_threads,
245+
)
240246
end
241247

242248
const GPUCompiler = CUDA.GPUCompiler
@@ -282,7 +288,12 @@ function compile(job)
282288
entry = GPUCompiler.JuliaContext() do ctx
283289
mod, meta = GPUCompiler.compile(
284290
# :llvm, job; optimize=false, cleanup=false, validate=false, libraries=true
285-
:llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
291+
:llvm,
292+
job;
293+
optimize=false,
294+
cleanup=false,
295+
validate=false,
296+
libraries=false,
286297
# :llvm, job; optimize=false, cleanup=false, validate=true, libraries=false
287298
# :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
288299
)
@@ -357,19 +368,21 @@ function link(job, compiled)
357368
end
358369

359370
function to_bytes(x)
360-
sz = sizeof(x)
361-
ref = Ref(x)
362-
GC.@preserve ref begin
363-
ptr = Base.reinterpret(Ptr{UInt8}, Base.unsafe_convert(Ptr{Cvoid}, ref))
364-
vec = Vector{UInt8}(undef, sz)
365-
for i in 1:sz
366-
@inbounds vec[i] = Base.unsafe_load(ptr, i)
367-
end
368-
vec
369-
end
370-
end
371-
372-
function Reactant.make_tracer(seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...)
371+
sz = sizeof(x)
372+
ref = Ref(x)
373+
GC.@preserve ref begin
374+
ptr = Base.reinterpret(Ptr{UInt8}, Base.unsafe_convert(Ptr{Cvoid}, ref))
375+
vec = Vector{UInt8}(undef, sz)
376+
for i in 1:sz
377+
@inbounds vec[i] = Base.unsafe_load(ptr, i)
378+
end
379+
vec
380+
end
381+
end
382+
383+
function Reactant.make_tracer(
384+
seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...
385+
)
373386
x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr))::TracedRArray
374387
Reactant.make_tracer(seen, x, path, mode; kwargs...)
375388
return prev
@@ -388,7 +401,9 @@ function get_field_offset(T::Type, path)
388401
findfirst(==(field), fieldnames(current_type))
389402
end
390403
if field_idx === nothing
391-
error("Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path")
404+
error(
405+
"Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path",
406+
)
392407
end
393408

394409
# Add the offset of this field
@@ -419,7 +434,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
419434
rarrays = TracedRArray[]
420435

421436
fname = func.entry
422-
437+
423438
wrapper_tys = MLIR.IR.Type[]
424439
ctx = MLIR.IR.context()
425440
cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1))
@@ -436,19 +451,23 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
436451
end
437452
push!(wrapper_tys, cullvm_ty)
438453
end
439-
454+
440455
sym_name = String(gensym("call_$fname"))
441456
mod = MLIR.IR.mmodule()
442-
CConv=MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvPTX_Kernel))
457+
CConv = MLIR.IR.Attribute(
458+
MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvPTX_Kernel)
459+
)
443460
voidty = MLIR.IR.Type(MLIR.API.mlirLLVMVoidTypeGet(ctx))
444-
wrapftype = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGet(voidty, length(wrapper_tys), wrapper_tys, false))
461+
wrapftype = MLIR.IR.Type(
462+
MLIR.API.mlirLLVMFunctionTypeGet(voidty, length(wrapper_tys), wrapper_tys, false)
463+
)
445464
wrapfunc = MLIR.IR.block!(MLIR.IR.body(mod)) do
446465
return MLIR.Dialects.llvm.func(;
447466
sym_name,
448467
sym_visibility=MLIR.IR.Attribute("private"),
449468
function_type=wrapftype,
450469
body=MLIR.IR.Region(),
451-
CConv
470+
CConv,
452471
)
453472
end
454473
wrapbody = MLIR.IR.Block(wrapper_tys, [MLIR.IR.Location() for _ in wrapper_tys])
@@ -459,11 +478,17 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
459478

460479
symtab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod))
461480
gpufunc = MLIR.IR.lookup(symtab, fname)
462-
MLIR.IR.attr!(gpufunc, "CConv", MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvC)))
463-
gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type"))
481+
MLIR.IR.attr!(
482+
gpufunc,
483+
"CConv",
484+
MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvC)),
485+
)
486+
gpu_function_type = MLIR.IR.Type(
487+
Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type")
488+
)
464489

465490
trueidx = 1
466-
allocs = Union{Tuple{MLIR.IR.Value, MLIR.IR.Type}, Nothing}[]
491+
allocs = Union{Tuple{MLIR.IR.Value,MLIR.IR.Type},Nothing}[]
467492

468493
llvmptr = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0))
469494
i8 = MLIR.IR.Type(UInt8)
@@ -476,18 +501,34 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
476501

477502
# TODO check for only integer and explicitly non cutraced types
478503
MLIR.IR.block!(wrapbody) do
479-
argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx-1))
504+
argty = MLIR.IR.Type(
505+
MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx - 1)
506+
)
480507
trueidx += 1
481-
c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1)
482-
alloc = MLIR.IR.result(MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr), 1)
508+
c1 = MLIR.IR.result(
509+
MLIR.Dialects.llvm.mlir_constant(;
510+
res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)
511+
),
512+
1,
513+
)
514+
alloc = MLIR.IR.result(
515+
MLIR.Dialects.llvm.alloca(
516+
c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr
517+
),
518+
1,
519+
)
483520
push!(allocs, (alloc, argty))
484521

485522
sz = sizeof(a)
486523
array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
487-
cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1)
524+
cdata = MLIR.IR.result(
525+
MLIR.Dialects.llvm.mlir_constant(;
526+
res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))
527+
),
528+
1,
529+
)
488530
MLIR.Dialects.llvm.store(cdata, alloc)
489531
end
490-
491532
end
492533

493534
argidx = 1
@@ -499,21 +540,30 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
499540
if p[1] !== kernelargsym
500541
continue
501542
end
502-
543+
503544
arg = arg.mlir_data
504545
arg = Reactant.TracedUtils.transpose_val(arg)
505546
push!(restys, MLIR.IR.type(arg))
506547
push!(mlir_args, arg)
507-
548+
508549
# Get the allocation corresponding to which arg we're doing
509550
alloc = allocs[p[2]][1]
510551

511552
# we need to now compute the offset in bytes of the path
512553
julia_arg = allargs[p[2]]
513-
554+
514555
offset = get_field_offset(typeof(julia_arg), p[3:end])
515556
MLIR.IR.block!(wrapbody) do
516-
ptr = MLIR.IR.result(MLIR.Dialects.llvm.getelementptr(alloc, MLIR.IR.Value[], res=llvmptr, elem_type=i8, rawConstantIndices=MLIR.IR.Attribute([Int32(offset)])), 1)
557+
ptr = MLIR.IR.result(
558+
MLIR.Dialects.llvm.getelementptr(
559+
alloc,
560+
MLIR.IR.Value[];
561+
res=llvmptr,
562+
elem_type=i8,
563+
rawConstantIndices=MLIR.IR.Attribute([Int32(offset)]),
564+
),
565+
1,
566+
)
517567
MLIR.Dialects.llvm.store(MLIR.IR.argument(wrapbody, argidx), ptr)
518568
end
519569

@@ -530,11 +580,11 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
530580
),
531581
),
532582
)
533-
583+
534584
argidx += 1
535585
end
536586
end
537-
587+
538588
MLIR.IR.block!(wrapbody) do
539589
for arg in allocs
540590
if arg === nothing
@@ -544,7 +594,12 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
544594
argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1)
545595
push!(wrapargs, argres)
546596
end
547-
MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[]))
597+
MLIR.Dialects.llvm.call(
598+
wrapargs,
599+
MLIR.IR.Value[];
600+
callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)),
601+
op_bundle_sizes=MLIR.IR.Attribute(Int32[]),
602+
)
548603
MLIR.Dialects.llvm.return_(nothing)
549604
end
550605

@@ -565,7 +620,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
565620
mlir_args;
566621
result_0=restys,
567622
fn=MLIR.IR.FlatSymbolRefAttribute(sym_name),
568-
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases)
623+
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases),
569624
)
570625

571626
argidx = 1
@@ -574,7 +629,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
574629
continue
575630
end
576631
arg.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, argidx))
577-
argidx+=1
632+
argidx += 1
578633
end
579634
end
580635

src/Compiler.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,12 @@ function optimization_passes(; no_nan::Bool=false)
293293
)
294294
func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",")
295295
return join(
296-
["inline{default-pipeline=canonicalize max-iterations=4}", "libdevice-funcs-raise", func_passes], ','
296+
[
297+
"inline{default-pipeline=canonicalize max-iterations=4}",
298+
"libdevice-funcs-raise",
299+
func_passes,
300+
],
301+
',',
297302
)
298303
end
299304

src/mlir/Dialects/Nvvm.jl

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,15 @@ function barrier(
7878
attributes = NamedAttribute[]
7979
!isnothing(barrierId) && push!(operands, barrierId)
8080
!isnothing(numberOfThreads) && push!(operands, numberOfThreads)
81-
push!(
82-
attributes,
83-
operandsegmentsizes([
84-
if (barrierId == nothing)
85-
0
86-
elseif 1(numberOfThreads == nothing)
87-
0
88-
else
89-
1
90-
end
91-
]),
92-
)
81+
push!(attributes, operandsegmentsizes([
82+
if (barrierId == nothing)
83+
0
84+
elseif 1(numberOfThreads == nothing)
85+
0
86+
else
87+
1
88+
end,
89+
]))
9390

9491
return create_operation(
9592
"nvvm.barrier",

src/mlir/Dialects/TPU.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -902,18 +902,17 @@ function sem_signal(
902902
attributes = NamedAttribute[]
903903
!isnothing(device_id) && push!(operands, device_id)
904904
!isnothing(core_id) && push!(operands, core_id)
905-
push!(
906-
attributes,
907-
operandsegmentsizes([
908-
1, 1, if (device_id == nothing)
909-
0
910-
elseif 1(core_id == nothing)
911-
0
912-
else
913-
1
914-
end
915-
]),
916-
)
905+
push!(attributes, operandsegmentsizes([
906+
1,
907+
1,
908+
if (device_id == nothing)
909+
0
910+
elseif 1(core_id == nothing)
911+
0
912+
else
913+
1
914+
end,
915+
]))
917916
!isnothing(core_type) && push!(attributes, namedattribute("core_type", core_type))
918917

919918
return create_operation(

src/mlir/Dialects/Triton.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -482,18 +482,18 @@ function dot_scaled(
482482
]
483483
!isnothing(lhs_scale) && push!(operands, lhs_scale)
484484
!isnothing(rhs_scale) && push!(operands, rhs_scale)
485-
push!(
486-
attributes,
487-
operandsegmentsizes([
488-
1, 1, 1, if (lhs_scale == nothing)
489-
0
490-
elseif 1(rhs_scale == nothing)
491-
0
492-
else
493-
1
494-
end
495-
]),
496-
)
485+
push!(attributes, operandsegmentsizes([
486+
1,
487+
1,
488+
1,
489+
if (lhs_scale == nothing)
490+
0
491+
elseif 1(rhs_scale == nothing)
492+
0
493+
else
494+
1
495+
end,
496+
]))
497497

498498
return create_operation(
499499
"tt.dot_scaled",
@@ -949,16 +949,16 @@ function load(
949949
attributes = NamedAttribute[]
950950
!isnothing(mask) && push!(operands, mask)
951951
!isnothing(other) && push!(operands, other)
952-
push!(
953-
attributes,
954-
operandsegmentsizes([1, if (mask == nothing)
952+
push!(attributes, operandsegmentsizes([
953+
1,
954+
if (mask == nothing)
955955
0
956956
elseif 1(other == nothing)
957957
0
958958
else
959959
1
960-
end]),
961-
)
960+
end,
961+
]))
962962
!isnothing(result) && push!(op_ty_results, result)
963963
!isnothing(boundaryCheck) &&
964964
push!(attributes, namedattribute("boundaryCheck", boundaryCheck))

0 commit comments

Comments
 (0)