Skip to content

Commit 1e7cf1d

Browse files
committed
fix: JET errors in non differentiable errors
1 parent 92960e9 commit 1e7cf1d

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

src/NonDifferentiableDeclarations.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
module NonDifferentiableDeclarationsModule
22

33
using ChainRulesCore: @non_differentiable
4+
import ..OperatorEnumModule: AbstractOperatorEnum
5+
import ..NodeModule: AbstractExpressionNode, AbstractNode
46
import ..NodeUtilsModule: tree_mapreduce
5-
import ..ExpressionModule: get_operators, get_variable_names, _validate_input
7+
import ..ExpressionModule:
8+
AbstractExpression, get_operators, get_variable_names, _validate_input
69

7-
@non_differentiable tree_mapreduce(args...)
8-
@non_differentiable get_operators(ex, operators)
9-
@non_differentiable get_variable_names(ex, variable_names)
10-
@non_differentiable _validate_input(ex, X, operators)
10+
#! format: off
11+
@non_differentiable tree_mapreduce(f::Function, op::Function, tree::AbstractNode, result_type::Type)
12+
@non_differentiable tree_mapreduce(f::Function, f_branch::Function, op::Function, tree::AbstractNode, result_type::Type)
13+
@non_differentiable get_operators(ex::Union{AbstractExpression,AbstractExpressionNode}, operators::Union{AbstractOperatorEnum,Nothing})
14+
@non_differentiable get_variable_names(ex::AbstractExpression, variable_names::Union{AbstractVector{<:AbstractString},Nothing})
15+
@non_differentiable _validate_input(ex::AbstractExpression, X, operators::Union{AbstractOperatorEnum,Nothing})
16+
#! format: on
1117

1218
end

0 commit comments

Comments
 (0)