Skip to content

Commit 4b9434e

Browse files
authored
Split should_rewrite_ft for call and invoke expressions, and overlay Base._unique_dims (#505)
* Split `should_rewrite_ft` for `call` and `invoke` expressions * fix dispatch on `unique` * fix syntax problem * test * fix typo * Replace `should_rewrite_invoke` of `unique` for overlayed method on `_unique_dims` * remove previous solution
1 parent 2d83f13 commit 4b9434e

File tree

3 files changed

+26
-5
lines changed

3 files changed

+26
-5
lines changed

src/Overlay.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,12 @@ end
142142
end
143143
end
144144
end
145+
146+
## fixes #493
147+
@reactant_overlay @noinline function Base._unique_dims(A::AbstractArray, dims::Colon)
148+
if use_overlayed_version(A)
149+
error("Reactant doesn't have a `Base._unique_dims` with the current interpreter.")
150+
else
151+
Base.inferencebarrier(Base._unique_dims)(A, dims)
152+
end
153+
end

src/utils.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ function has_ancestor(query::Module, target::Module)
8989
end
9090
end
9191

92-
function should_rewrite_ft(@nospecialize(ft))
92+
function should_rewrite_call(@nospecialize(ft))
9393
# Don't rewrite builtin or intrinsics
9494
if ft <: Core.IntrinsicFunction || ft <: Core.Builtin
9595
return false
@@ -178,6 +178,9 @@ function should_rewrite_ft(@nospecialize(ft))
178178
return true
179179
end
180180

181+
# by default, same as `should_rewrite_call`
182+
should_rewrite_invoke(@nospecialize(ft), @nospecialize(args)) = should_rewrite_call(ft)
183+
181184
# Avoid recursively interpreting into methods we define explicitly
182185
# as overloads, which we assume should handle the entirety of the
183186
# translation (and if not they can use call_in_reactant).
@@ -242,7 +245,7 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error)
242245
end
243246
if ft == typeof(Core._apply_iterate)
244247
ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir))
245-
if Base.invokelatest(should_rewrite_ft, ft)
248+
if Base.invokelatest(should_rewrite_call, ft)
246249
if RT === Union{}
247250
rep = Expr(
248251
:call,
@@ -256,7 +259,7 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error)
256259
return true, rep, Any
257260
end
258261
end
259-
elseif Base.invokelatest(should_rewrite_ft, ft)
262+
elseif Base.invokelatest(should_rewrite_call, ft)
260263
if RT === Union{}
261264
rep = Expr(:call, call_with_reactant, MustThrowError(), inst.args...)
262265
return true, rep, Union{}
@@ -270,10 +273,13 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error)
270273
omi = inst.args[1]::Core.MethodInstance
271274
sig = omi.specTypes
272275
ft = sig.parameters[1]
276+
argsig = sig.parameters[2:end]
273277
if ft == typeof(Core.kwcall)
274278
ft = sig.parameters[3]
279+
argsig = sig.parameters[4:end]
275280
end
276-
if Base.invokelatest(should_rewrite_ft, ft) && !is_reactant_method(omi)
281+
argsig = Core.apply_type(Core.Tuple, argsig...)
282+
if Base.invokelatest(should_rewrite_invoke, ft, argsig) && !is_reactant_method(omi)
277283
method = omi.def::Core.Method
278284

279285
min_world = Ref{UInt}(typemin(UInt))

test/compile.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ end
128128
@test !occursin("add", repr(hlo))
129129
end
130130

131-
# While a bit specific, the following is used to check for a bug in `should_rewrite_ft`
131+
# While a bit specific, the following is used to check for a bug in `should_rewrite_call`
132132
function sinusoidal_embedding(
133133
x::AbstractArray{T,4}, min_freq, max_freq, embedding_dims::Int
134134
) where {T}
@@ -146,3 +146,9 @@ end
146146
x_ra = Reactant.to_rarray(rand(Float32, 1, 1, 1, 4))
147147
hlo = @code_hlo sinusoidal_embedding(x_ra, 0.1, 10.0, 4)
148148
end
149+
150+
# test #493
151+
@testset "unique(::Vector{Symbol}) (#493)" begin
152+
x = [:a, :b, :a]
153+
@test @jit(unique(x)) == [:a, :b]
154+
end

0 commit comments

Comments
 (0)