
Commit 6ddc890

fix: generalize broadcast_in_dims for setindex (#518)
* fix: generalize broadcast_in_dims for setindex
* test: writing with less dims
1 parent ca98c17 commit 6ddc890
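
In practice, the generalization lets setindex! accept a right-hand side with fewer dimensions than the destination and lower it through stablehlo.broadcast_in_dim. A minimal sketch of the pattern this enables, mirroring the new tests in test/basic.jl (the helper name here is hypothetical):

using Reactant

function write_slice!(x, y)
    x[1, :, :] .= reshape(y, 4, 3)   # 2-D right-hand side written into a slice of a 3-D array
    return x
end

x_ra = Reactant.to_rarray(zeros(3, 4, 3))
y_ra = Reactant.to_rarray(rand(3, 4))
@jit write_slice!(x_ra, y_ra)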

File tree: 4 files changed, +67 -44 lines changed

src/Ops.jl

Lines changed: 18 additions & 18 deletions
@@ -936,24 +936,24 @@ end
 end
 
 # broadcast ops
-# function broadcast_in_dim(
-#     x::TracedRArray{T,N},
-#     dims::Vector{Int};
-#     location=mlir_stacktrace(
-#         "broadcast_in_dim", @__FILE__, @__LINE__
-#     ),
-# ) where {T,N}
-#     rsize = restype = MLIR.IR.TensorType([...], mlir_type(T)) # mlir_type(TracedRArray{T,N}, size(x))
-#     res = MLIR.IR.result(
-#         stablehlo.broadcast_in_dim(
-#             x.mlir_data;
-#             result_0=restype,
-#             broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims),
-#             location,
-#         ),
-#     )
-#     return TracedRArray{T,N}((), res, size(x))
-# end
+function broadcast_in_dim(
+    x::TracedRArray{T,N},
+    dims::Vector{Int},
+    result_size::Vector{Int};
+    location=mlir_stacktrace("broadcast_in_dim", @__FILE__, @__LINE__),
+) where {T,N}
+    @assert length(dims) == N
+
+    res = MLIR.IR.result(
+        stablehlo.broadcast_in_dim(
+            x.mlir_data;
+            result_0=MLIR.IR.TensorType(result_size, MLIR.IR.Type(T)),
+            broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims .- 1),
+            location,
+        ),
+    )
+    return TracedRArray{T,Int64(length(result_size))}((), res, Tuple(result_size))
+end
 
 @noinline function sort(
     x::TracedRArray{T,N};
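
A minimal usage sketch of the new three-argument form, assuming Ops is reachable as Reactant.Ops and the call can be traced through @jit like the tests below do; dims are the 1-based positions of each input dimension in the result, and the op shifts them to the 0-based broadcast_dimensions that stablehlo.broadcast_in_dim expects:

using Reactant

# Map a 4×3 input onto result dimensions 2 and 3 of a 3×4×3 output.
to_3d(x) = Reactant.Ops.broadcast_in_dim(x, [2, 3], Int64[3, 4, 3])

x_ra = Reactant.to_rarray(rand(4, 3))
y_ra = @jit to_3d(x_ra)   # size (3, 4, 3); y[i, j, k] == x[j, k] for every i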

src/TracedRArray.jl

Lines changed: 15 additions & 2 deletions
@@ -218,8 +218,21 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
         return v
     end
 
-    v = TracedUtils.broadcast_to_size(v, length.(indices))
-    v = TracedUtils.promote_to(TracedRArray{T,N}, v)
+    if v isa Number
+        v = TracedUtils.broadcast_to_size(v, length.(indices))
+        v = TracedUtils.promote_to(TracedRArray{T,N}, v)
+    else
+        v = TracedUtils.promote_to(TracedRArray{T,ndims(v)}, v)
+        non_integer_indices = [!(idx isa Integer) for idx in indices]
+        broadcast_dims = findall(non_integer_indices)
+        if length(broadcast_dims) == N
+            v = TracedUtils.broadcast_to_size(v, length.(indices))
+        else
+            v = Ops.broadcast_in_dim(
+                materialize_traced_array(v), broadcast_dims, Int64.(length.(indices))
+            )
+        end
+    end
 
     indices = [
         (
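
A self-contained sketch of the dimension-selection logic above, outside of tracing; the index tuple here is hypothetical:

indices = (1, :, 1:3)                          # as in x[1, :, 1:3] .= v on a 3-D array
non_integer_indices = [!(idx isa Integer) for idx in indices]
broadcast_dims = findall(non_integer_indices)  # [2, 3]: the dims the right-hand side keeps
# Since length(broadcast_dims) < N (2 < 3), v is lifted with Ops.broadcast_in_dim
# instead of TracedUtils.broadcast_to_size.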

src/TracedUtils.jl

Lines changed: 1 addition & 24 deletions
@@ -496,30 +496,7 @@ function broadcast_to_size(arg::Broadcast.Extruded, rsize)
 end
 
 @noinline function broadcast_to_size_internal(x::TracedRArray{T}, rsize) where {T}
-    dims = collect(Int64, 0:(length(size(x)) - 1))
-
-    if length(size(MLIR.IR.type(get_mlir_data(x)))) != length(dims)
-        @show x
-        @show arg
-        @show rsize
-        @show rsize2
-        @show dims
-    end
-    @assert length(size(MLIR.IR.type(get_mlir_data(x)))) == length(dims)
-    mlirty = MLIR.IR.type(get_mlir_data(x))
-
-    return TracedRArray{T,Int(length(rsize))}(
-        (),
-        MLIR.IR.result(
-            MLIR.Dialects.stablehlo.broadcast_in_dim(
-                get_mlir_data(x);
-                result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)),
-                broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims),
-            ),
-            1,
-        ),
-        collect(rsize),
-    )
+    return Ops.broadcast_in_dim(x, collect(Int64, 1:ndims(x)), collect(Int64, rsize))
 end
 
 end
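
As a concrete instance of the one-liner above (a hedged sketch; the Reactant.Ops path is assumed): broadcasting a length-3 vector to size (3, 4) keeps its only dimension as result dimension 1, so the lowered stablehlo.broadcast_in_dim carries broadcast_dimensions = [0]:

using Reactant

# What broadcast_to_size_internal(x, (3, 4)) now reduces to:
bcast(x) = Reactant.Ops.broadcast_in_dim(x, collect(Int64, 1:ndims(x)), Int64[3, 4])

x_ra = Reactant.to_rarray(collect(1.0:3.0))
y_ra = @jit bcast(x_ra)   # 3×4 result; every column equals x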

test/basic.jl

Lines changed: 33 additions & 0 deletions
@@ -422,6 +422,39 @@ end
     # get_view_compiled = @compile get_view(x_concrete)
 end
 
+function write_with_broadcast1!(x, y)
+    x[1, :, :] .= reshape(y, 4, 3)
+    return x
+end
+function write_with_broadcast2!(x, y)
+    x[:, 1, :] .= view(y, :, 1:3)
+    return x
+end
+
+@testset "write_with_broadcast" begin
+    x_ra = Reactant.to_rarray(zeros(3, 4, 3))
+    y_ra = Reactant.to_rarray(rand(3, 4))
+
+    res = @jit write_with_broadcast1!(x_ra, y_ra)
+
+    @test res.data === x_ra.data
+
+    res = Array(res)
+    y = Array(y_ra)
+    @test res[1, :, :] ≈ reshape(y, 4, 3)
+
+    x_ra = Reactant.to_rarray(zeros(3, 4, 3))
+    y_ra = Reactant.to_rarray(rand(3, 4))
+
+    res = @jit write_with_broadcast2!(x_ra, y_ra)
+
+    @test res.data === x_ra.data
+
+    res = Array(res)
+    y = Array(y_ra)
+    @test res[:, 1, :] ≈ view(y, :, 1:3)
+end
+
 function masking(x)
     y = similar(x)
     y[1:2, :] .= 0
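
For reference, the same writes in plain Julia with ordinary arrays, which is what the traced results are compared against:

x = zeros(3, 4, 3); y = rand(3, 4)
x[1, :, :] .= reshape(y, 4, 3)      # write_with_broadcast1!
@assert x[1, :, :] == reshape(y, 4, 3)

x2 = zeros(3, 4, 3)
x2[:, 1, :] .= view(y, :, 1:3)      # write_with_broadcast2!
@assert x2[:, 1, :] == view(y, :, 1:3)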
