
Commit 246c4fe

Merge pull request #90 from SymbolicML/faster-gradients
Faster ChainRules implementation
2 parents d7d8802 + c9eaedf commit 246c4fe

File tree

7 files changed: +132 -92 lines

src/ChainRules.jl

Lines changed: 29 additions & 36 deletions
@@ -33,52 +33,45 @@ function CRC.rrule(
     tree::AbstractExpressionNode,
     X::AbstractMatrix,
     operators::OperatorEnum;
-    turbo=Val(false),
-    bumper=Val(false),
+    kws...,
 )
-    primal, complete = eval_tree_array(tree, X, operators; turbo, bumper)
+    primal, complete = eval_tree_array(tree, X, operators; kws...)
 
     if !complete
         primal .= NaN
     end
 
-    # TODO: Preferable to use the primal in the pullback somehow
-    function pullback((dY, _))
-        dtree = let X = X, dY = dY, tree = tree, operators = operators
-            @thunk(
-                let
-                    _, gradient, complete2 = eval_grad_tree_array(
-                        tree, X, operators; variable=Val(false)
-                    )
-                    if !complete2
-                        gradient .= NaN
-                    end
-
-                    NodeTangent(
-                        tree,
-                        sum(j -> gradient[:, j] * dY[j], eachindex(dY, axes(gradient, 2))),
-                    )
-                end
-            )
-        end
-        dX = let X = X, dY = dY, tree = tree, operators = operators
-            @thunk(
-                let
-                    _, gradient2, complete3 = eval_grad_tree_array(
-                        tree, X, operators; variable=Val(true)
-                    )
-                    if !complete3
-                        gradient2 .= NaN
-                    end
-
-                    gradient2 .* reshape(dY, 1, length(dY))
-                end
-            )
-        end
-        return (NoTangent(), dtree, dX, NoTangent())
-    end
-
-    return (primal, complete), pullback
+    return (primal, complete), EvalPullback(tree, X, operators)
+end
+
+# Wrap in struct rather than closure to ensure variables are boxed
+struct EvalPullback{N,A,O} <: Function
+    tree::N
+    X::A
+    operators::O
+end
+
+# TODO: Preferable to use the primal in the pullback somehow
+function (e::EvalPullback)((dY, _))
+    _, dX_constants_dY, complete = eval_grad_tree_array(
+        e.tree, e.X, e.operators; variable=Val(:both)
+    )
+
+    if !complete
+        dX_constants_dY .= NaN
+    end
+
+    nfeatures = size(e.X, 1)
+    dX_dY = @view dX_constants_dY[1:nfeatures, :]
+    dconstants_dY = @view dX_constants_dY[(nfeatures + 1):end, :]
+
+    dtree = NodeTangent(
+        e.tree, sum(j -> dconstants_dY[:, j] * dY[j], eachindex(dY, axes(dconstants_dY, 2)))
+    )
+
+    dX = dX_dY .* reshape(dY, 1, length(dY))
+
+    return (NoTangent(), dtree, dX, NoTangent())
 end
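
For context, this rrule is what an AD backend hits when differentiating through eval_tree_array, and the struct-based pullback now makes a single eval_grad_tree_array call with variable=Val(:both) instead of two thunked evaluations. A minimal sketch of exercising it, assuming Zygote as the backend (the operator set and tree below are illustrative, not part of this commit):

using DynamicExpressions: OperatorEnum, Node, eval_tree_array
using Zygote

operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
x1, x2 = Node{Float64}(; feature=1), Node{Float64}(; feature=2)
tree = x1 * cos(x2 - 3.2)  # constructing an OperatorEnum also defines these operators on nodes

X = randn(Float64, 2, 32)
loss(X) = sum(first(eval_tree_array(tree, X, operators)))

# Differentiating through eval_tree_array hits the rrule above; one
# eval_grad_tree_array call supplies both dX and the constant gradients
# (the latter wrapped in a NodeTangent).
dX = only(Zygote.gradient(loss, X))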

src/DynamicExpressions.jl

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@ import .ParseModule: parse_leaf
 
 @stable default_mode = "disable" begin
     include("Interfaces.jl")
+    include("NonDifferentiableDeclarations.jl")
     include("PatchMethods.jl")
 end
src/EvaluateDerivative.jl

Lines changed: 53 additions & 42 deletions
@@ -206,17 +206,33 @@ function eval_grad_tree_array(
     variable::Union{Bool,Val}=Val(false),
     turbo::Union{Bool,Val}=Val(false),
 ) where {T<:Number}
-    n_gradients = if isa(variable, Val{true}) || (isa(variable, Bool) && variable)
+    variable_mode = isa(variable, Val{true}) || (isa(variable, Bool) && variable)
+    constant_mode = isa(variable, Val{false}) || (isa(variable, Bool) && !variable)
+    both_mode = isa(variable, Val{:both})
+
+    n_gradients = if variable_mode
         size(cX, 1)::Int
-    else
+    elseif constant_mode
         count_constants(tree)::Int
+    elseif both_mode
+        size(cX, 1) + count_constants(tree)
     end
-    result = if isa(variable, Val{true}) || (variable isa Bool && variable)
+
+    result = if variable_mode
         eval_grad_tree_array(tree, n_gradients, nothing, cX, operators, Val(true))
-    else
+    elseif constant_mode
         index_tree = index_constants(tree)
-        eval_grad_tree_array(tree, n_gradients, index_tree, cX, operators, Val(false))
-    end
+        eval_grad_tree_array(
+            tree, n_gradients, index_tree, cX, operators, Val(false)
+        )
+    elseif both_mode
+        # features come first because we can use size(cX, 1) to skip them
+        index_tree = index_constants(tree)
+        eval_grad_tree_array(
+            tree, n_gradients, index_tree, cX, operators, Val(:both)
+        )
+    end::ResultOk2
+
     return (result.x, result.dx, result.ok)
 end
 
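
The resulting call surface, sketched below; the row counts follow directly from the branches above (tree, X, and operators are illustrative placeholders):

# Rows of the returned gradient matrix, per mode:
_, dvars, ok1 = eval_grad_tree_array(tree, X, operators; variable=Val(true))    # size(X, 1) rows
_, dconsts, ok2 = eval_grad_tree_array(tree, X, operators; variable=Val(false)) # count_constants(tree) rows
_, dboth, ok3 = eval_grad_tree_array(tree, X, operators; variable=Val(:both))   # feature rows first, then constant rows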

@@ -226,11 +242,9 @@ function eval_grad_tree_array(
     index_tree::Union{NodeIndex,Nothing},
     cX::AbstractMatrix{T},
     operators::OperatorEnum,
-    ::Val{variable},
-)::ResultOk2 where {T<:Number,variable}
-    result = _eval_grad_tree_array(
-        tree, n_gradients, index_tree, cX, operators, Val(variable)
-    )
+    ::Val{mode},
+)::ResultOk2 where {T<:Number,mode}
+    result = _eval_grad_tree_array(tree, n_gradients, index_tree, cX, operators, Val(mode))
     !result.ok && return result
     return ResultOk2(
         result.x, result.dx, !(is_bad_array(result.x) || is_bad_array(result.dx))
@@ -260,30 +274,18 @@ end
     index_tree::Union{NodeIndex,Nothing},
     cX::AbstractMatrix{T},
     operators::OperatorEnum,
-    ::Val{variable},
-)::ResultOk2 where {T<:Number,variable}
+    ::Val{mode},
+)::ResultOk2 where {T<:Number,mode}
     nuna = get_nuna(operators)
     nbin = get_nbin(operators)
     deg1_branch_skeleton = quote
         grad_deg1_eval(
-            tree,
-            n_gradients,
-            index_tree,
-            cX,
-            operators.unaops[i],
-            operators,
-            Val(variable),
+            tree, n_gradients, index_tree, cX, operators.unaops[i], operators, Val(mode)
         )
     end
     deg2_branch_skeleton = quote
         grad_deg2_eval(
-            tree,
-            n_gradients,
-            index_tree,
-            cX,
-            operators.binops[i],
-            operators,
-            Val(variable),
+            tree, n_gradients, index_tree, cX, operators.binops[i], operators, Val(mode)
         )
     end
     deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
@@ -310,7 +312,7 @@ end
     end
     quote
         if tree.degree == 0
-            grad_deg0_eval(tree, n_gradients, index_tree, cX, Val(variable))
+            grad_deg0_eval(tree, n_gradients, index_tree, cX, Val(mode))
         elseif tree.degree == 1
             $deg1_branch
         else
@@ -324,8 +326,8 @@ function grad_deg0_eval(
     n_gradients,
     index_tree::Union{NodeIndex,Nothing},
     cX::AbstractMatrix{T},
-    ::Val{variable},
-)::ResultOk2 where {T<:Number,variable}
+    ::Val{mode},
+)::ResultOk2 where {T<:Number,mode}
     const_part = deg0_eval(tree, cX).x
 
     zero_mat = if isa(cX, Array)
@@ -334,17 +336,26 @@ function grad_deg0_eval(
         hcat([fill_similar(zero(T), cX, axes(cX, 2)) for _ in 1:n_gradients]...)'
     end
 
-    if variable == tree.constant
+    if (mode isa Bool && mode == tree.constant)
+        # No gradients at this leaf node
         return ResultOk2(const_part, zero_mat, true)
     end
 
-    index = if variable
-        tree.feature
-    else
+    index = if (mode isa Bool && mode)
+        tree.feature::UInt16
+    elseif (mode isa Bool && !mode)
         (index_tree === nothing ? zero(UInt16) : index_tree.val::UInt16)
+    elseif mode == :both
+        index_tree::NodeIndex
+        if tree.constant
+            index_tree.val::UInt16 + UInt16(size(cX, 1))
+        else
+            tree.feature::UInt16
+        end
     end
+
     derivative_part = zero_mat
-    derivative_part[index, :] .= one(T)
+    fill!(@view(derivative_part[index, :]), one(T))
    return ResultOk2(const_part, derivative_part, true)
 end
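
So in :both mode each leaf writes its one-hot derivative row into a combined layout: feature leaves occupy rows 1 through size(cX, 1), and constant leaves are shifted past them. A hypothetical helper restating the indexing rule above (for illustration only, not part of the diff):

# Hypothetical restatement of the :both branch in grad_deg0_eval:
function both_mode_index(leaf, index_tree, nfeatures)
    # constant rows are offset by the number of feature rows
    return leaf.constant ? index_tree.val + UInt16(nfeatures) : leaf.feature
end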

@@ -355,15 +366,15 @@ function grad_deg1_eval(
     cX::AbstractMatrix{T},
     op::F,
     operators::OperatorEnum,
-    ::Val{variable},
-)::ResultOk2 where {T<:Number,F,variable}
+    ::Val{mode},
+)::ResultOk2 where {T<:Number,F,mode}
     result = eval_grad_tree_array(
         tree.l,
         n_gradients,
         index_tree === nothing ? index_tree : index_tree.l,
         cX,
         operators,
-        Val(variable),
+        Val(mode),
     )
     !result.ok && return result

@@ -389,15 +400,15 @@ function grad_deg2_eval(
     cX::AbstractMatrix{T},
     op::F,
     operators::OperatorEnum,
-    ::Val{variable},
-)::ResultOk2 where {T<:Number,F,variable}
+    ::Val{mode},
+)::ResultOk2 where {T<:Number,F,mode}
     result_l = eval_grad_tree_array(
         tree.l,
         n_gradients,
         index_tree === nothing ? index_tree : index_tree.l,
         cX,
         operators,
-        Val(variable),
+        Val(mode),
     )
     !result_l.ok && return result_l
     result_r = eval_grad_tree_array(
@@ -406,7 +417,7 @@ function grad_deg2_eval(
         index_tree === nothing ? index_tree : index_tree.r,
         cX,
         operators,
-        Val(variable),
+        Val(mode),
     )
     !result_r.ok && return result_r
src/Expression.jl

Lines changed: 7 additions & 11 deletions
@@ -2,6 +2,7 @@
 module ExpressionModule
 
 using DispatchDoctor: @unstable
+
 using ..NodeModule: AbstractExpressionNode, Node
 using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
 using ..UtilsModule: Undefined
@@ -17,7 +18,12 @@ import ..NodeUtilsModule:
     has_constants,
     get_constants,
     set_constants!
+import ..EvaluateModule: eval_tree_array, differentiable_eval_tree_array
+import ..EvaluateDerivativeModule: eval_grad_tree_array
+import ..EvaluationHelpersModule: _grad_evaluator
+import ..StringsModule: string_tree, print_tree
 import ..ChainRulesModule: extract_gradient
+import ..SimplifyModule: combine_operators, simplify_tree!
 
 """A wrapper for a named tuple to avoid piracy."""
 struct Metadata{NT<:NamedTuple}
@@ -65,7 +71,7 @@ expression tree (like `Node`) along with associated metadata for evaluation and
 
 - `tree::N`: The root node of the raw expression tree.
 - `metadata::Metadata{D}`: A named tuple of settings for the expression,
-   such as the operators and variable names.
+    such as the operators and variable names.
 
 # Constructors
@@ -278,8 +284,6 @@ function extract_gradient(
     return extract_gradient(gradient.tree, get_tree(ex))
 end
 
-import ..StringsModule: string_tree, print_tree
-
 function string_tree(
     ex::AbstractExpression,
     operators::Union{AbstractOperatorEnum,Nothing}=nothing;
@@ -309,8 +313,6 @@ function Base.show(io::IO, ::MIME"text/plain", ex::AbstractExpression)
     return print(io, string_tree(ex))
 end
 
-import ..EvaluateModule: eval_tree_array, differentiable_eval_tree_array
-
 function max_feature(ex::AbstractExpression)
     return tree_mapreduce(
         leaf -> leaf.constant ? zero(UInt16) : leaf.feature,
@@ -342,8 +344,6 @@ function eval_tree_array(
     return eval_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...)
 end
 
-import ..EvaluateDerivativeModule: eval_grad_tree_array
-
 # skipped (not used much)
 # - eval_diff_tree_array
 # - differentiable_eval_tree_array
@@ -358,8 +358,6 @@ function eval_grad_tree_array(
     return eval_grad_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...)
 end
 
-import ..EvaluationHelpersModule: _grad_evaluator
-
 function Base.adjoint(ex::AbstractExpression)
     return ((args...; kws...) -> _grad_evaluator(ex, args...; kws...))
 end
@@ -380,6 +378,4 @@ function (ex::AbstractExpression)(
     return get_tree(ex)(X, get_operators(ex, operators); kws...)
 end
 
-import ..SimplifyModule: combine_operators, simplify_tree!
-
 end
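
These hunks only hoist mid-file imports to the top of the module; the definitions themselves are unchanged. For instance, the Base.adjoint overload still makes ex' a gradient evaluator. A usage sketch, assuming _grad_evaluator forwards to eval_grad_tree_array over the features:

# ex isa AbstractExpression; X is an (nfeatures x nrows) matrix.
# ex' returns a closure over _grad_evaluator, so this evaluates d(ex)/dX:
grad = ex'(X)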

src/NonDifferentiableDeclarations.jl

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+module NonDifferentiableDeclarationsModule
+
+using ChainRulesCore: @non_differentiable
+import ..OperatorEnumModule: AbstractOperatorEnum
+import ..NodeModule: AbstractExpressionNode, AbstractNode
+import ..NodeUtilsModule: tree_mapreduce
+import ..ExpressionModule:
+    AbstractExpression, get_operators, get_variable_names, _validate_input
+
+#! format: off
+@non_differentiable tree_mapreduce(f::Function, op::Function, tree::AbstractNode, result_type::Type)
+@non_differentiable tree_mapreduce(f::Function, f_branch::Function, op::Function, tree::AbstractNode, result_type::Type)
+@non_differentiable get_operators(ex::Union{AbstractExpression,AbstractExpressionNode}, operators::Union{AbstractOperatorEnum,Nothing})
+@non_differentiable get_variable_names(ex::AbstractExpression, variable_names::Union{AbstractVector{<:AbstractString},Nothing})
+@non_differentiable _validate_input(ex::AbstractExpression, X, operators::Union{AbstractOperatorEnum,Nothing})
+#! format: on
+
+end
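
Each @non_differentiable line generates an rrule whose pullback returns only NoTangent()s, so AD backends treat these helpers as constants rather than tracing into them. A sketch of the effect on get_operators (the call below is illustrative; ex stands for any AbstractExpression):

using ChainRulesCore: rrule, NoTangent

# After the declaration above, differentiating through get_operators is a no-op:
y, back = rrule(get_operators, ex, nothing)
back(1.0) == (NoTangent(), NoTangent(), NoTangent())  # one NoTangent per argument, plus the function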

src/OperatorEnumConstruction.jl

Lines changed: 4 additions & 2 deletions
@@ -228,8 +228,10 @@ function _extend_operators(operators, skip_user_operators, kws, __module__::Module)
         )
     end
 
-    empty_old_operators_idx = findfirst(x -> first(x.args) == :empty_old_operators, kws)
-    internal_idx = findfirst(x -> first(x.args) == :internal, kws)
+    empty_old_operators_idx = findfirst(
+        x -> hasproperty(x, :args) && first(x.args) == :empty_old_operators, kws
+    )
+    internal_idx = findfirst(x -> hasproperty(x, :args) && first(x.args) == :internal, kws)
 
     empty_old_operators = if empty_old_operators_idx !== nothing
         @assert kws[empty_old_operators_idx].head == :(=)
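
The new hasproperty guard matters because keyword-style macro arguments arrive as Expr(:(=), name, value) nodes, while a bare flag arrives as a plain Symbol with no args field, so calling first(x.args) on it throws. A standalone illustration:

hasproperty(:(internal = true), :args)  # true: an Expr, and first(x.args) == :internal
hasproperty(:some_flag, :args)          # false: a Symbol, where first(x.args) would throw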

0 commit comments