Skip to content

Commit f144045

Browse files
authored
Fix Zygote issues with FunctionTransform (#152)
* Fix Zygote issues for functional transform * Resolve tests * Clean up * Avoid splatting, use reduce instead * Remove rrules for reduce(hcat)
1 parent 0df9e83 commit f144045

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

src/transform/functiontransform.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,21 @@ end
1616
(t::FunctionTransform)(x) = t.f(x)
1717

1818
_map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)
19-
_map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1))
20-
_map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2))
19+
20+
21+
function _map(t::FunctionTransform, x::ColVecs)
22+
vals = map(axes(x.X, 2)) do i
23+
t.f(view(x.X, :, i))
24+
end
25+
return ColVecs(reduce(hcat, vals))
26+
end
27+
28+
function _map(t::FunctionTransform, x::RowVecs)
29+
vals = map(axes(x.X, 1)) do i
30+
t.f(view(x.X, i, :))
31+
end
32+
return RowVecs(reduce(hcat, vals)')
33+
end
2134

2235
duplicate(t::FunctionTransform,f) = FunctionTransform(f)
2336

test/transform/functiontransform.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,5 @@
2828

2929
@test repr(FunctionTransform(sin)) == "Function Transform: $(sin)"
3030
f(a, x) = sin.(a .* x)
31-
test_ADs(x->transform(SEKernel(), FunctionTransform(y->f(x, y))), randn(rng, 3), ADs = [:ForwardDiff, :ReverseDiff])
32-
@test_broken "Zygote is failing"
31+
test_ADs(x->transform(SEKernel(), FunctionTransform(y->f(x, y))), randn(rng, 3))
3332
end

0 commit comments

Comments
 (0)