Skip to content

Commit b436b48

Browse files
Generalize precompilation support (#534)
* Generalize precompilation support * CUDA precompilation * Update ReactantCUDAExt.jl * Update ReactantCUDAExt.jl * fixup * fix * Update ReactantCUDAExt.jl * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 4b9434e commit b436b48

File tree

2 files changed

+55
-16
lines changed

2 files changed

+55
-16
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,9 @@ function compile(job)
328328
# :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
329329
)
330330

331-
GPUCompiler.link_library!(mod, GPUCompiler.load_runtime(job))
331+
if !Reactant.precompiling()
332+
GPUCompiler.link_library!(mod, GPUCompiler.load_runtime(job))
333+
end
332334
entryname = LLVM.name(meta.entry)
333335

334336
GPUCompiler.optimize_module!(job, mod)
@@ -788,4 +790,31 @@ function __init__()
788790
return nothing
789791
end
790792

793+
# Precompile workload for the CUDA extension: traces a tiny `@cuda` launch through
# `Reactant.Compiler.compile_mlir` so the methods it exercises get cached.
# Skipped entirely on Apple hosts and on aarch64 (NOTE(review): presumably because
# the CUDA stack is unavailable there -- confirm).
@static if !(Sys.isapple() || Sys.ARCH == :aarch64)
    Reactant.PrecompileTools.@setup_workload begin
        # Setup: register the dialect and create a CPU client for the workload.
        Reactant.initialize_dialect()
        client = Reactant.XLA.CPUClient(; checkcount=false)

        Reactant.PrecompileTools.@compile_workload begin
            # Only run the workload on Julia versions Reactant reports as supported.
            @static if Reactant.precompilation_supported()
                # Kernel: square the element owned by the current thread, in place.
                function square_kernel!(x)
                    tid = CUDA.threadIdx().x
                    x[tid] = x[tid] * x[tid]
                    return nothing
                end

                # Host wrapper: one block, one thread per element of `x`.
                function square!(x)
                    CUDA.@cuda blocks = 1 threads = length(x) square_kernel!(x)
                    return nothing
                end

                input = Reactant.ConcreteRArray([2.0]; client)
                Reactant.Compiler.compile_mlir(square!, (input,); optimize=false)
            end
        end

        # Teardown: release the client and null its handle, undo the dialect
        # registration, then purge the opaque closures created by the workload.
        Reactant.XLA.free_client(client)
        client.client = C_NULL
        Reactant.deinitialize_dialect()
        Reactant.clear_oc_cache()
    end
end
819+
791820
end # module ReactantCUDAExt

src/Precompile.jl

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using PrecompileTools
12
using PrecompileTools: @setup_workload, @compile_workload
23

34
function infer_sig(sig)
@@ -34,15 +35,33 @@ function infer_sig(sig)
3435
end
3536
end
3637

38+
"""
    clear_oc_cache()

Purge every opaque closure recorded in `oc_capture_vec`.

Opaque closures capture the world age of their compilation and are therefore not
relocatable, so everything created during a precompile workload must be dropped
explicitly before the workload finishes.

Entries that are a `Base.RefValue` have the first pointer-sized word of the object
atomically overwritten with `C_NULL`; every other entry is emptied with `empty!`.
"""
function clear_oc_cache()
    for captured in oc_capture_vec
        if !(captured isa Base.RefValue)
            # Plain container: just drop its contents.
            empty!(captured)
            continue
        end
        # Ref: null out the stored pointer in place (monotonic atomic store).
        slot = Ptr{Ptr{Cvoid}}(pointer_from_objref(captured))
        Base.atomic_pointerset(slot, C_NULL, :monotonic)
    end
    return nothing
end
50+
51+
"""
    precompilation_supported()

Return `true` when the running Julia version is at least 1.10.8, and `false`
otherwise.

Precompilation on earlier 1.10 releases hits an apparent bug:
https://github.com/JuliaLang/julia/issues/56947

Any version `>= 1.11` is also `>= 1.10.8`, so a single comparison covers both the
patched 1.10 releases and every later version.
"""
function precompilation_supported()
    return VERSION >= v"1.10.8"
end
55+
56+
"""
    precompiling()

Return `true` when the runtime's `jl_generating_output()` reports `1`, and `false`
otherwise.

NOTE(review): this is presumably the case while Julia is producing a precompile
cache or system image -- confirm against the runtime sources.
"""
function precompiling()
    generating = ccall(:jl_generating_output, Cint, ())
    return generating == 1
end
59+
3760
@setup_workload begin
3861
initialize_dialect()
3962
client = XLA.CPUClient(; checkcount=false)
4063
@compile_workload begin
41-
# Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947
42-
@static if VERSION < v"1.11"
43-
else
44-
# infer_sig(Tuple{typeof(Base.sum), Reactant.TracedRArray{Float64, 2}})
45-
# infer_sig(Tuple{typeof(Base.sin), Reactant.TracedRNumber{Float64}})
64+
@static if precompilation_supported()
4665
x = ConcreteRNumber(2.0; client)
4766
Reactant.compile(sin, (x,); client)
4867

@@ -53,14 +72,5 @@ end
5372
XLA.free_client(client)
5473
client.client = C_NULL
5574
deinitialize_dialect()
56-
# Opaque closures capture the worldage of their compilation and thus are not relocatable
57-
# Therefore we explicitly purge all OC's we have created here
58-
for v in oc_capture_vec
59-
if v isa Base.RefValue
60-
p = Ptr{Ptr{Cvoid}}(pointer_from_objref(v))
61-
Base.atomic_pointerset(p, C_NULL, :monotonic)
62-
else
63-
empty!(v)
64-
end
65-
end
75+
clear_oc_cache()
6676
end

0 commit comments

Comments
 (0)