Faster ChainRules implementation #90

Merged · 9 commits · Jul 1, 2024

65 changes: 29 additions & 36 deletions src/ChainRules.jl

@@ -33,52 +33,45 @@ function CRC.rrule(
     tree::AbstractExpressionNode,
     X::AbstractMatrix,
     operators::OperatorEnum;
-    turbo=Val(false),
-    bumper=Val(false),
+    kws...,
 )
-    primal, complete = eval_tree_array(tree, X, operators; turbo, bumper)
+    primal, complete = eval_tree_array(tree, X, operators; kws...)
 
     if !complete
         primal .= NaN
     end
 
-    # TODO: Preferable to use the primal in the pullback somehow
-    function pullback((dY, _))
-        dtree = let X = X, dY = dY, tree = tree, operators = operators
-            @thunk(
-                let
-                    _, gradient, complete2 = eval_grad_tree_array(
-                        tree, X, operators; variable=Val(false)
-                    )
-                    if !complete2
-                        gradient .= NaN
-                    end
-                    NodeTangent(
-                        tree,
-                        sum(j -> gradient[:, j] * dY[j], eachindex(dY, axes(gradient, 2))),
-                    )
-                end
-            )
-        end
-        dX = let X = X, dY = dY, tree = tree, operators = operators
-            @thunk(
-                let
-                    _, gradient2, complete3 = eval_grad_tree_array(
-                        tree, X, operators; variable=Val(true)
-                    )
-                    if !complete3
-                        gradient2 .= NaN
-                    end
-                    gradient2 .* reshape(dY, 1, length(dY))
-                end
-            )
-        end
-        return (NoTangent(), dtree, dX, NoTangent())
-    end
-
-    return (primal, complete), pullback
+    return (primal, complete), EvalPullback(tree, X, operators)
+end
+
+# Wrap in struct rather than closure to ensure variables are boxed
+struct EvalPullback{N,A,O} <: Function
+    tree::N
+    X::A
+    operators::O
+end
+
+# TODO: Preferable to use the primal in the pullback somehow
+function (e::EvalPullback)((dY, _))
+    _, dX_constants_dY, complete = eval_grad_tree_array(
+        e.tree, e.X, e.operators; variable=Val(:both)
+    )
+
+    if !complete
+        dX_constants_dY .= NaN
+    end
+
+    nfeatures = size(e.X, 1)
+    dX_dY = @view dX_constants_dY[1:nfeatures, :]
+    dconstants_dY = @view dX_constants_dY[(nfeatures + 1):end, :]
+
+    dtree = NodeTangent(
+        e.tree,
+        sum(j -> dconstants_dY[:, j] * dY[j], eachindex(dY, axes(dconstants_dY, 2))),
+    )
+
+    dX = dX_dY .* reshape(dY, 1, length(dY))
+
+    return (NoTangent(), dtree, dX, NoTangent())
 end
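
The net effect of this rewrite: the old pullback ran `eval_grad_tree_array` twice (once for constants, once for features) inside two `@thunk`s, while `EvalPullback` runs it once with `variable=Val(:both)` and slices the stacked result. A minimal usage sketch, not part of the diff, assuming Zygote is installed next to DynamicExpressions (names like `loss` are illustrative):

```julia
using DynamicExpressions, Zygote

operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
x1 = Node{Float64}(; feature=1)
tree = x1 * cos(x1 - 3.2)

X = randn(Float64, 2, 100)

# Zygote dispatches to the rrule above; the pullback runs a single
# eval_grad_tree_array call and splits it into dX and the constant tangents.
loss(X) = sum(first(eval_tree_array(tree, X, operators)))
dX = only(Zygote.gradient(loss, X))
@assert size(dX) == size(X)
```

Note that `turbo` and `bumper` still reach evaluation through `kws...`; the rrule simply no longer hard-codes them.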
1 change: 1 addition & 0 deletions src/DynamicExpressions.jl

@@ -76,6 +76,7 @@ import .ParseModule: parse_leaf
 
 @stable default_mode = "disable" begin
     include("Interfaces.jl")
+    include("NonDifferentiableDeclarations.jl")
     include("PatchMethods.jl")
 end
 
95 changes: 53 additions & 42 deletions src/EvaluateDerivative.jl

@@ -206,17 +206,33 @@ function eval_grad_tree_array(
     variable::Union{Bool,Val}=Val(false),
     turbo::Union{Bool,Val}=Val(false),
 ) where {T<:Number}
-    n_gradients = if isa(variable, Val{true}) || (isa(variable, Bool) && variable)
+    variable_mode = isa(variable, Val{true}) || (isa(variable, Bool) && variable)
+    constant_mode = isa(variable, Val{false}) || (isa(variable, Bool) && !variable)
+    both_mode = isa(variable, Val{:both})
+
+    n_gradients = if variable_mode
         size(cX, 1)::Int
-    else
+    elseif constant_mode
         count_constants(tree)::Int
+    elseif both_mode
+        size(cX, 1) + count_constants(tree)
     end
-    result = if isa(variable, Val{true}) || (variable isa Bool && variable)
+
+    result = if variable_mode
         eval_grad_tree_array(tree, n_gradients, nothing, cX, operators, Val(true))
-    else
+    elseif constant_mode
         index_tree = index_constants(tree)
-        eval_grad_tree_array(tree, n_gradients, index_tree, cX, operators, Val(false))
-    end
+        eval_grad_tree_array(
+            tree, n_gradients, index_tree, cX, operators, Val(false)
+        )
+    elseif both_mode
+        # features come first because we can use size(cX, 1) to skip them
+        index_tree = index_constants(tree)
+        eval_grad_tree_array(
+            tree, n_gradients, index_tree, cX, operators, Val(:both)
+        )
+    end::ResultOk2
+
     return (result.x, result.dx, result.ok)
 end

@@ -226,11 +242,9 @@ function eval_grad_tree_array(
     index_tree::Union{NodeIndex,Nothing},
     cX::AbstractMatrix{T},
     operators::OperatorEnum,
-    ::Val{variable},
-)::ResultOk2 where {T<:Number,variable}
-    result = _eval_grad_tree_array(
-        tree, n_gradients, index_tree, cX, operators, Val(variable)
-    )
+    ::Val{mode},
+)::ResultOk2 where {T<:Number,mode}
+    result = _eval_grad_tree_array(tree, n_gradients, index_tree, cX, operators, Val(mode))
     !result.ok && return result
     return ResultOk2(
         result.x, result.dx, !(is_bad_array(result.x) || is_bad_array(result.dx))

@@ -260,30 +274,18 @@ end
     index_tree::Union{NodeIndex,Nothing},
     cX::AbstractMatrix{T},
     operators::OperatorEnum,
-    ::Val{variable},
-)::ResultOk2 where {T<:Number,variable}
+    ::Val{mode},
+)::ResultOk2 where {T<:Number,mode}
     nuna = get_nuna(operators)
     nbin = get_nbin(operators)
     deg1_branch_skeleton = quote
         grad_deg1_eval(
-            tree,
-            n_gradients,
-            index_tree,
-            cX,
-            operators.unaops[i],
-            operators,
-            Val(variable),
+            tree, n_gradients, index_tree, cX, operators.unaops[i], operators, Val(mode)
         )
     end
     deg2_branch_skeleton = quote
         grad_deg2_eval(
-            tree,
-            n_gradients,
-            index_tree,
-            cX,
-            operators.binops[i],
-            operators,
-            Val(variable),
+            tree, n_gradients, index_tree, cX, operators.binops[i], operators, Val(mode)
         )
     end
     deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN

@@ -310,7 +312,7 @@ end
     end
     quote
         if tree.degree == 0
-            grad_deg0_eval(tree, n_gradients, index_tree, cX, Val(variable))
+            grad_deg0_eval(tree, n_gradients, index_tree, cX, Val(mode))
         elseif tree.degree == 1
             $deg1_branch
         else

@@ -324,8 +326,8 @@ function grad_deg0_eval(
     n_gradients,
     index_tree::Union{NodeIndex,Nothing},
     cX::AbstractMatrix{T},
-    ::Val{variable},
-)::ResultOk2 where {T<:Number,variable}
+    ::Val{mode},
+)::ResultOk2 where {T<:Number,mode}
     const_part = deg0_eval(tree, cX).x
 
     zero_mat = if isa(cX, Array)

@@ -334,17 +336,26 @@ function grad_deg0_eval(
         hcat([fill_similar(zero(T), cX, axes(cX, 2)) for _ in 1:n_gradients]...)'
     end
 
-    if variable == tree.constant
+    if (mode isa Bool && mode == tree.constant)
+        # No gradients at this leaf node
         return ResultOk2(const_part, zero_mat, true)
     end
 
-    index = if variable
-        tree.feature
-    else
+    index = if (mode isa Bool && mode)
+        tree.feature::UInt16
+    elseif (mode isa Bool && !mode)
         (index_tree === nothing ? zero(UInt16) : index_tree.val::UInt16)
+    elseif mode == :both
+        index_tree::NodeIndex
+        if tree.constant
+            index_tree.val::UInt16 + UInt16(size(cX, 1))
+        else
+            tree.feature::UInt16
+        end
     end
 
     derivative_part = zero_mat
-    derivative_part[index, :] .= one(T)
+    fill!(@view(derivative_part[index, :]), one(T))
     return ResultOk2(const_part, derivative_part, true)
 end

@@ -355,15 +366,15 @@ function grad_deg1_eval(
     cX::AbstractMatrix{T},
     op::F,
     operators::OperatorEnum,
-    ::Val{variable},
-)::ResultOk2 where {T<:Number,F,variable}
+    ::Val{mode},
+)::ResultOk2 where {T<:Number,F,mode}
     result = eval_grad_tree_array(
         tree.l,
         n_gradients,
         index_tree === nothing ? index_tree : index_tree.l,
         cX,
         operators,
-        Val(variable),
+        Val(mode),
     )
     !result.ok && return result
 
@@ -389,15 +400,15 @@ function grad_deg2_eval(
     cX::AbstractMatrix{T},
     op::F,
     operators::OperatorEnum,
-    ::Val{variable},
-)::ResultOk2 where {T<:Number,F,variable}
+    ::Val{mode},
+)::ResultOk2 where {T<:Number,F,mode}
     result_l = eval_grad_tree_array(
         tree.l,
         n_gradients,
         index_tree === nothing ? index_tree : index_tree.l,
         cX,
         operators,
-        Val(variable),
+        Val(mode),
     )
     !result_l.ok && return result_l
     result_r = eval_grad_tree_array(

@@ -406,7 +417,7 @@ function grad_deg2_eval(
         index_tree === nothing ? index_tree : index_tree.r,
         cX,
         operators,
-        Val(variable),
+        Val(mode),
     )
     !result_r.ok && return result_r
 
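The new `Val(:both)` mode stacks feature gradients first and constant gradients after them, which is exactly the layout `EvalPullback` splits at `nfeatures`. A hedged sketch of calling it directly (variable and constant names are illustrative):

```julia
using DynamicExpressions

operators = OperatorEnum(; binary_operators=[+, *], unary_operators=[cos])
x1 = Node{Float64}(; feature=1)
c1 = Node{Float64}(; val=0.5)
tree = x1 * cos(x1 + c1)

X = randn(Float64, 2, 10)  # 2 features, 10 samples
_, grad, ok = eval_grad_tree_array(tree, X, operators; variable=Val(:both))

# Rows 1:2 hold d/dx1 and d/dx2; row 3 holds d/dc1. Features come first,
# which is what lets a caller peel them off with size(X, 1).
@assert ok && size(grad) == (2 + 1, 10)
```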
18 changes: 7 additions & 11 deletions src/Expression.jl

@@ -2,6 +2,7 @@
 module ExpressionModule
 
 using DispatchDoctor: @unstable
+
 using ..NodeModule: AbstractExpressionNode, Node
 using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
 using ..UtilsModule: Undefined

@@ -17,7 +18,12 @@ import ..NodeUtilsModule:
     has_constants,
     get_constants,
     set_constants!
+import ..EvaluateModule: eval_tree_array, differentiable_eval_tree_array
+import ..EvaluateDerivativeModule: eval_grad_tree_array
+import ..EvaluationHelpersModule: _grad_evaluator
+import ..StringsModule: string_tree, print_tree
+import ..ChainRulesModule: extract_gradient
+import ..SimplifyModule: combine_operators, simplify_tree!
 
 """A wrapper for a named tuple to avoid piracy."""
 struct Metadata{NT<:NamedTuple}

@@ -65,7 +71,7 @@ expression tree (like `Node`) along with associated metadata for evaluation and
 
 - `tree::N`: The root node of the raw expression tree.
 - `metadata::Metadata{D}`: A named tuple of settings for the expression,
-    such as the operators and variable names.
+  such as the operators and variable names.
 
 # Constructors
 
@@ -278,8 +284,6 @@ function extract_gradient(
     return extract_gradient(gradient.tree, get_tree(ex))
 end
 
-import ..StringsModule: string_tree, print_tree
-
 function string_tree(
     ex::AbstractExpression,
     operators::Union{AbstractOperatorEnum,Nothing}=nothing;

@@ -309,8 +313,6 @@ function Base.show(io::IO, ::MIME"text/plain", ex::AbstractExpression)
     return print(io, string_tree(ex))
 end
 
-import ..EvaluateModule: eval_tree_array, differentiable_eval_tree_array
-
 function max_feature(ex::AbstractExpression)
     return tree_mapreduce(
         leaf -> leaf.constant ? zero(UInt16) : leaf.feature,

@@ -342,8 +344,6 @@ function eval_tree_array(
     return eval_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...)
 end
 
-import ..EvaluateDerivativeModule: eval_grad_tree_array
-
 # skipped (not used much)
 # - eval_diff_tree_array
 # - differentiable_eval_tree_array

@@ -358,8 +358,6 @@ function eval_grad_tree_array(
     return eval_grad_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...)
 end
 
-import ..EvaluationHelpersModule: _grad_evaluator
-
 function Base.adjoint(ex::AbstractExpression)
     return ((args...; kws...) -> _grad_evaluator(ex, args...; kws...))
 end

@@ -380,6 +378,4 @@ function (ex::AbstractExpression)(
     return get_tree(ex)(X, get_operators(ex, operators); kws...)
 end
 
-import ..SimplifyModule: combine_operators, simplify_tree!
-
 end
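
The hunks in this file only hoist imports to the top of the module; behavior is unchanged. For context, the `Base.adjoint` definition kept here is what backs the `ex'` gradient shorthand. A sketch, reusing `tree`, `operators`, and `X` from the examples above, and assuming the keyword constructor described in this file's docstring:

```julia
# Sketch only: ex' routes through _grad_evaluator, which (per
# EvaluationHelpers) defaults to gradients with respect to the features.
ex = Expression(tree; operators, variable_names=["x1", "x2"])
dX_dfeatures = ex'(X)
```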
18 changes: 18 additions & 0 deletions src/NonDifferentiableDeclarations.jl

@@ -0,0 +1,18 @@
+module NonDifferentiableDeclarationsModule
+
+using ChainRulesCore: @non_differentiable
+import ..OperatorEnumModule: AbstractOperatorEnum
+import ..NodeModule: AbstractExpressionNode, AbstractNode
+import ..NodeUtilsModule: tree_mapreduce
+import ..ExpressionModule:
+    AbstractExpression, get_operators, get_variable_names, _validate_input
+
+#! format: off
+@non_differentiable tree_mapreduce(f::Function, op::Function, tree::AbstractNode, result_type::Type)
+@non_differentiable tree_mapreduce(f::Function, f_branch::Function, op::Function, tree::AbstractNode, result_type::Type)
+@non_differentiable get_operators(ex::Union{AbstractExpression,AbstractExpressionNode}, operators::Union{AbstractOperatorEnum,Nothing})
+@non_differentiable get_variable_names(ex::AbstractExpression, variable_names::Union{AbstractVector{<:AbstractString},Nothing})
+@non_differentiable _validate_input(ex::AbstractExpression, X, operators::Union{AbstractOperatorEnum,Nothing})
+#! format: on
+
+end
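
For context, `@non_differentiable` generates `frule`/`rrule` methods that return `NoTangent()` for every argument, so AD backends treat these structural helpers as constants instead of tracing into tree bookkeeping. Roughly, for `get_operators` (a sketch, not the macro's exact expansion):

```julia
import ChainRulesCore
using ChainRulesCore: NoTangent
using DynamicExpressions: get_operators

# Approximately what @non_differentiable generates for get_operators;
# the real macro also defines an frule and keeps the declared signature.
function ChainRulesCore.rrule(::typeof(get_operators), ex, operators)
    get_operators_pullback(_) = (NoTangent(), NoTangent(), NoTangent())
    return get_operators(ex, operators), get_operators_pullback
end
```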
6 changes: 4 additions & 2 deletions src/OperatorEnumConstruction.jl

@@ -228,8 +228,10 @@ function _extend_operators(operators, skip_user_operators, kws, __module__::Module)
         )
     end
 
-    empty_old_operators_idx = findfirst(x -> first(x.args) == :empty_old_operators, kws)
-    internal_idx = findfirst(x -> first(x.args) == :internal, kws)
+    empty_old_operators_idx = findfirst(
+        x -> hasproperty(x, :args) && first(x.args) == :empty_old_operators, kws
+    )
+    internal_idx = findfirst(x -> hasproperty(x, :args) && first(x.args) == :internal, kws)
 
     empty_old_operators = if empty_old_operators_idx !== nothing
         @assert kws[empty_old_operators_idx].head == :(=)
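The added `hasproperty(x, :args)` guard matters because the macro's keyword arguments arrive as `Expr(:(=), name, value)` nodes, while a bare flag arrives as a plain `Symbol` with no `args` field, so the old `first(x.args)` would throw. A small illustration with hypothetical input:

```julia
# Hypothetical kws vector as the macro might receive it: one `key = value`
# expression and one bare Symbol flag.
kws = Any[:(empty_old_operators = true), :some_flag]

# The old predicate, x -> first(x.args) == :internal, would throw when the
# scan reaches :some_flag, since a Symbol has no .args field.
idx = findfirst(x -> hasproperty(x, :args) && first(x.args) == :internal, kws)
@assert idx === nothing  # no error, and no `internal` keyword present
```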