2
2
module ExpressionModule
3
3
4
4
using DispatchDoctor: @unstable
5
- using ChainRulesCore: @ignore_derivatives
5
+ using ChainRulesCore: CRC
6
6
7
7
using .. NodeModule: AbstractExpressionNode, Node
8
8
using .. OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
@@ -19,7 +19,12 @@ import ..NodeUtilsModule:
19
19
has_constants,
20
20
get_constants,
21
21
set_constants!
22
+ import .. EvaluateModule: eval_tree_array, differentiable_eval_tree_array
23
+ import .. EvaluateDerivativeModule: eval_grad_tree_array
24
+ import .. EvaluationHelpersModule: _grad_evaluator
25
+ import .. StringsModule: string_tree, print_tree
22
26
import .. ChainRulesModule: extract_gradient
27
+ import .. SimplifyModule: combine_operators, simplify_tree!
23
28
24
29
""" A wrapper for a named tuple to avoid piracy."""
25
30
struct Metadata{NT<: NamedTuple }
@@ -280,8 +285,6 @@ function extract_gradient(
280
285
return extract_gradient (gradient. tree, get_tree (ex))
281
286
end
282
287
283
- import .. StringsModule: string_tree, print_tree
284
-
285
288
function string_tree (
286
289
ex:: AbstractExpression ,
287
290
operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
@@ -311,8 +314,6 @@ function Base.show(io::IO, ::MIME"text/plain", ex::AbstractExpression)
311
314
return print (io, string_tree (ex))
312
315
end
313
316
314
- import .. EvaluateModule: eval_tree_array, differentiable_eval_tree_array
315
-
316
317
function max_feature (ex:: AbstractExpression )
317
318
return tree_mapreduce (
318
319
leaf -> leaf. constant ? zero (UInt16) : leaf. feature,
@@ -344,8 +345,6 @@ function eval_tree_array(
344
345
return eval_tree_array (get_tree (ex), cX, get_operators (ex, operators); kws... )
345
346
end
346
347
347
- import .. EvaluateDerivativeModule: eval_grad_tree_array
348
-
349
348
# skipped (not used much)
350
349
# - eval_diff_tree_array
351
350
# - differentiable_eval_tree_array
@@ -360,8 +359,6 @@ function eval_grad_tree_array(
360
359
return eval_grad_tree_array (get_tree (ex), cX, get_operators (ex, operators); kws... )
361
360
end
362
361
363
- import .. EvaluationHelpersModule: _grad_evaluator
364
-
365
362
function Base. adjoint (ex:: AbstractExpression )
366
363
return ((args... ; kws... ) -> _grad_evaluator (ex, args... ; kws... ))
367
364
end
@@ -382,6 +379,4 @@ function (ex::AbstractExpression)(
382
379
return get_tree (ex)(X, get_operators (ex, operators); kws... )
383
380
end
384
381
385
- import .. SimplifyModule: combine_operators, simplify_tree!
386
-
387
382
end
0 commit comments