Skip to content

Commit 7c2e390

Browse files
jumerckxgithub-actions[bot]wsmoses
authored
linearize aliased kernel args (#504)
* Add NoStopTracedTrack mode and use to handle aliased inputs * aliasing test * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix test * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: William S. Moses <[email protected]>
1 parent 6ddc890 commit 7c2e390

File tree

3 files changed

+72
-13
lines changed

3 files changed

+72
-13
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
443443
seen = Reactant.OrderedIdDict()
444444
prev = Any[func.f, args...]
445445
kernelargsym = gensym("kernelarg")
446-
Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.TracedTrack)
446+
Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.NoStopTracedTrack)
447447
wrapper_tys = MLIR.IR.Type[]
448448
for arg in values(seen)
449449
if !(arg isa TracedRArray || arg isa TracedRNumber)
@@ -536,16 +536,18 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
536536
if !(arg isa TracedRArray || arg isa TracedRNumber)
537537
continue
538538
end
539-
for p in Reactant.TracedUtils.get_paths(arg)
539+
540+
paths = Reactant.TracedUtils.get_paths(arg)
541+
542+
arg = arg.mlir_data
543+
arg = Reactant.TracedUtils.transpose_val(arg)
544+
push!(restys, MLIR.IR.type(arg))
545+
push!(mlir_args, arg)
546+
547+
for p in paths
540548
if p[1] !== kernelargsym
541549
continue
542550
end
543-
544-
arg = arg.mlir_data
545-
arg = Reactant.TracedUtils.transpose_val(arg)
546-
push!(restys, MLIR.IR.type(arg))
547-
push!(mlir_args, arg)
548-
549551
# Get the allocation corresponding to which arg we're doing
550552
alloc = allocs[p[2]][1]
551553

@@ -580,9 +582,8 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
580582
),
581583
),
582584
)
583-
584-
argidx += 1
585585
end
586+
argidx += 1
586587
end
587588

588589
MLIR.IR.block!(wrapbody) do

src/Tracing.jl

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
TracedToConcrete = 3
55
ArrayToConcrete = 4
66
TracedSetPath = 5
7+
NoStopTracedTrack = 6
78
end
89

910
for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RNumber)
@@ -249,7 +250,7 @@ function traced_type(
249250
@inline base_typec(TV::TT) where {TT<:DataType} =
250251
(T <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...}
251252
return base_typec(T)
252-
elseif mode == TracedTrack || mode == TracedSetPath
253+
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
253254
return T
254255
else
255256
throw("Abstract RArray $T cannot be made concrete in mode $mode")
@@ -261,7 +262,7 @@ function traced_type(::Type{T}, seen, ::Val{mode}, track_numbers) where {T<:Trac
261262
throw("TracedRNG cannot be traced")
262263
elseif mode == TracedToConcrete
263264
return ConcreteRNG
264-
elseif mode == TracedTrack || mode == TracedSetPath
265+
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
265266
return T
266267
else
267268
throw("Unsupported mode: $mode")
@@ -329,7 +330,7 @@ function make_tracer(
329330
track_numbers=(),
330331
kwargs...,
331332
) where {RT}
332-
if haskey(seen, prev)
333+
if mode != NoStopTracedTrack && haskey(seen, prev)
333334
return seen[prev]
334335
end
335336
TT = traced_type(RT, (), Val(mode), track_numbers)
@@ -460,6 +461,13 @@ function make_tracer(
460461
end
461462
return prev
462463
end
464+
if mode == NoStopTracedTrack
465+
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
466+
if !haskey(seen, prev)
467+
seen[prev] = prev # don't return!
468+
end
469+
return prev
470+
end
463471
if mode == TracedSetPath
464472
if haskey(seen, prev)
465473
return seen[prev]
@@ -506,6 +514,13 @@ function make_tracer(
506514
end
507515
return prev
508516
end
517+
if mode == NoStopTracedTrack
518+
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
519+
if !haskey(seen, prev)
520+
seen[prev] = prev # don't return!
521+
end
522+
return prev
523+
end
509524
if mode == TracedSetPath
510525
if haskey(seen, prev)
511526
return seen[prev]
@@ -546,6 +561,13 @@ function make_tracer(
546561
end
547562
return prev
548563
end
564+
if mode == NoStopTracedTrack
565+
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
566+
if !haskey(seen, prev)
567+
seen[prev] = prev # don't return!
568+
end
569+
return prev
570+
end
549571
if mode == TracedSetPath
550572
haskey(seen, prev) && return seen[prev]
551573
res = MissingTracedValue((path,))

test/integration/cuda.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,40 @@ tuplef2(a) = @cuda threads = 1 tuplef2!((5, a))
115115
@code_hlo optimize = :before_kernel tuplef2(A)
116116
end
117117
end
118+
A = ConcreteRArray(fill(1))
119+
if CUDA.functional()
120+
@jit tuplef2(A)
121+
@test all(Array(A) .≈ 5)
122+
else
123+
@code_hlo optimize = :before_kernel tuplef2(A)
124+
end
125+
end
126+
127+
# TODO this same code fails if we use a 0-d array...?
128+
# maybe weird cuda things
129+
function aliased!(tup)
130+
x, y = tup
131+
x[2][1] *= y[2][1]
132+
return nothing
133+
end
134+
135+
function aliased(s)
136+
tup = (s, s)
137+
@cuda threads = 1 aliased!(tup)
138+
return nothing
139+
end
140+
141+
@static if !Sys.isapple()
142+
@testset "Aliasing arguments" begin
143+
a = ConcreteRArray([3])
144+
145+
s = (10, a)
146+
147+
if CUDA.functional()
148+
@jit aliased((s, s))
149+
@test all(Array(a) == 9)
150+
else
151+
@code_hlo optimize = :before_kernel aliased(s)
152+
end
153+
end
118154
end

0 commit comments

Comments
 (0)