Skip to content

Commit 905ebc0

Browse files
authored
Merge pull request #83 from SymbolicML/parametric-expressions2
Fix some method ambiguities in `Expression`
2 parents 139f99f + db17350 commit 905ebc0

File tree

6 files changed

+80
-24
lines changed

6 files changed

+80
-24
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <[email protected]>"]
4-
version = "0.18.0"
4+
version = "0.18.1"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

ext/DynamicExpressionsSymbolicUtilsExt.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,18 @@ end
111111
function Base.convert(
112112
::typeof(SymbolicUtils.Symbolic),
113113
tree::Union{AbstractExpression,AbstractExpressionNode},
114-
operators::AbstractOperatorEnum;
114+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
115115
variable_names::Union{Array{String,1},Nothing}=nothing,
116116
index_functions::Bool=false,
117117
# Deprecated:
118118
varMap=nothing,
119119
)
120120
variable_names = deprecate_varmap(variable_names, varMap, :convert)
121121
return node_to_symbolic(
122-
tree, operators; variable_names=variable_names, index_functions=index_functions
122+
tree,
123+
get_operators(tree, operators);
124+
variable_names=variable_names,
125+
index_functions=index_functions,
123126
)
124127
end
125128

src/Expression.jl

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ or `cur_operators` if it is not `nothing`. If left as default,
101101
it requires `cur_operators` to not be `nothing`.
102102
`cur_operators` would typically be an `OperatorEnum`.
103103
"""
104-
function get_operators(ex::AbstractExpression, operators)
104+
function get_operators(
105+
ex::AbstractExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing
106+
)
105107
return error("`get_operators` function must be implemented for $(typeof(ex)) types.")
106108
end
107109

@@ -110,7 +112,10 @@ end
110112
111113
The same as `operators`, but for variable names.
112114
"""
113-
function get_variable_names(ex::AbstractExpression, variable_names)
115+
function get_variable_names(
116+
ex::AbstractExpression,
117+
variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing,
118+
)
114119
return error(
115120
"`get_variable_names` function must be implemented for $(typeof(ex)) types."
116121
)
@@ -179,10 +184,23 @@ function preserve_sharing(::Union{E,Type{E}}) where {T,N,E<:AbstractExpression{T
179184
return preserve_sharing(N)
180185
end
181186

182-
function get_operators(ex::Expression, operators=nothing)
187+
function get_operators(
188+
tree::AbstractExpressionNode, operators::Union{AbstractOperatorEnum,Nothing}=nothing
189+
)
190+
if operators === nothing
191+
throw(ArgumentError("`operators` must be provided for $(typeof(tree)) types."))
192+
else
193+
return operators
194+
end
195+
end
196+
function get_operators(
197+
ex::Expression, operators::Union{AbstractOperatorEnum,Nothing}=nothing
198+
)
183199
return operators === nothing ? ex.metadata.operators : operators
184200
end
185-
function get_variable_names(ex::Expression, variable_names=nothing)
201+
function get_variable_names(
202+
ex::Expression, variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing
203+
)
186204
return variable_names === nothing ? ex.metadata.variable_names : variable_names
187205
end
188206
function get_tree(ex::Expression)
@@ -249,7 +267,10 @@ end
249267
import ..StringsModule: string_tree, print_tree
250268

251269
function string_tree(
252-
ex::AbstractExpression, operators=nothing; variable_names=nothing, kws...
270+
ex::AbstractExpression,
271+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
272+
variable_names=nothing,
273+
kws...,
253274
)
254275
return string_tree(
255276
get_tree(ex),
@@ -260,7 +281,11 @@ function string_tree(
260281
end
261282
for io in ((), (:(io::IO),))
262283
@eval function print_tree(
263-
$(io...), ex::AbstractExpression, operators=nothing; variable_names=nothing, kws...
284+
$(io...),
285+
ex::AbstractExpression,
286+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
287+
variable_names=nothing,
288+
kws...,
264289
)
265290
return println($(io...), string_tree(ex, operators; variable_names, kws...))
266291
end
@@ -283,7 +308,9 @@ function max_feature(ex::AbstractExpression)
283308
)
284309
end
285310

286-
function _validate_input(ex::AbstractExpression, X, operators)
311+
function _validate_input(
312+
ex::AbstractExpression, X, operators::Union{AbstractOperatorEnum,Nothing}
313+
)
287314
if get_operators(ex, operators) isa OperatorEnum
288315
@assert X isa AbstractMatrix
289316
@assert max_feature(ex) <= size(X, 1)
@@ -292,7 +319,10 @@ function _validate_input(ex::AbstractExpression, X, operators)
292319
end
293320

294321
function eval_tree_array(
295-
ex::AbstractExpression, cX::AbstractMatrix, operators=nothing; kws...
322+
ex::AbstractExpression,
323+
cX::AbstractMatrix,
324+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
325+
kws...,
296326
)
297327
_validate_input(ex, cX, operators)
298328
return eval_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...)
@@ -305,7 +335,10 @@ import ..EvaluateDerivativeModule: eval_grad_tree_array
305335
# - differentiable_eval_tree_array
306336

307337
function eval_grad_tree_array(
308-
ex::AbstractExpression, cX::AbstractMatrix, operators=nothing; kws...
338+
ex::AbstractExpression,
339+
cX::AbstractMatrix,
340+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
341+
kws...,
309342
)
310343
_validate_input(ex, cX, operators)
311344
return eval_grad_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...)
@@ -319,14 +352,16 @@ end
319352
function _grad_evaluator(
320353
ex::AbstractExpression,
321354
cX::AbstractMatrix,
322-
operators=nothing;
355+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
323356
variable=Val(true),
324357
kws...,
325358
)
326359
_validate_input(ex, cX, operators)
327360
return _grad_evaluator(get_tree(ex), cX, get_operators(ex, operators); variable, kws...)
328361
end
329-
function (ex::AbstractExpression)(X, operators=nothing; kws...)
362+
function (ex::AbstractExpression)(
363+
X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws...
364+
)
330365
_validate_input(ex, X, operators)
331366
return get_tree(ex)(X, get_operators(ex, operators); kws...)
332367
end

src/ParametricExpression.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module ParametricExpressionModule
22

33
using DispatchDoctor: @stable, @unstable
44

5+
using ..OperatorEnumModule: AbstractOperatorEnum
56
using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
67
using ..ExpressionModule: AbstractExpression, Metadata
78

@@ -70,7 +71,7 @@ struct ParametricExpression{
7071
end
7172
function ParametricExpression(
7273
tree::ParametricNode{T1};
73-
operators,
74+
operators::Union{AbstractOperatorEnum,Nothing},
7475
variable_names,
7576
parameters::AbstractMatrix{T2},
7677
parameter_names,
@@ -141,10 +142,15 @@ end
141142
get_contents(ex::ParametricExpression) = ex.tree
142143
get_metadata(ex::ParametricExpression) = ex.metadata
143144
get_tree(ex::ParametricExpression) = ex.tree
144-
function get_operators(ex::ParametricExpression, operators=nothing)
145+
function get_operators(
146+
ex::ParametricExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing
147+
)
145148
return operators === nothing ? ex.metadata.operators : operators
146149
end
147-
function get_variable_names(ex::ParametricExpression, variable_names=nothing)
150+
function get_variable_names(
151+
ex::ParametricExpression,
152+
variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing,
153+
)
148154
return variable_names === nothing ? ex.metadata.variable_names : variable_names
149155
end
150156
@inline _copy_with_nothing(x) = copy(x)
@@ -232,15 +238,18 @@ function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T}
232238
)
233239
end
234240
#! format: off
235-
function (ex::ParametricExpression)(X::AbstractMatrix, operators=nothing; kws...)
241+
function (ex::ParametricExpression)(X::AbstractMatrix, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws...)
236242
return eval_tree_array(ex, X, operators; kws...) # Will error
237243
end
238-
function eval_tree_array(::ParametricExpression{T}, ::AbstractMatrix{T}, operators=nothing; kws...) where {T}
244+
function eval_tree_array(::ParametricExpression{T}, ::AbstractMatrix{T}, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws...) where {T}
239245
return error("Incorrect call. You must pass the `classes::Vector` argument when calling `eval_tree_array`.")
240246
end
241247
#! format: on
242248
function (ex::ParametricExpression)(
243-
X::AbstractMatrix{T}, classes::AbstractVector{<:Integer}, operators=nothing; kws...
249+
X::AbstractMatrix{T},
250+
classes::AbstractVector{<:Integer},
251+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
252+
kws...,
244253
) where {T}
245254
(output, flag) = eval_tree_array(ex, X, classes, operators; kws...) # Will error
246255
if !flag
@@ -252,7 +261,7 @@ function eval_tree_array(
252261
ex::ParametricExpression{T},
253262
X::AbstractMatrix{T},
254263
classes::AbstractVector{<:Integer},
255-
operators=nothing;
264+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
256265
kws...,
257266
) where {T}
258267
@assert length(classes) == size(X, 2)
@@ -270,7 +279,7 @@ function eval_tree_array(
270279
end
271280
function string_tree(
272281
ex::ParametricExpression,
273-
operators=nothing;
282+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
274283
variable_names=nothing,
275284
display_variable_names=nothing,
276285
X_sym_units=nothing,

test/test_expressions.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,13 @@ end
226226

227227
@testitem "Miscellaneous expression calls" begin
228228
using DynamicExpressions
229+
using DynamicExpressions: get_tree, get_operators
229230

230231
ex = @parse_expression(x1 + 1.5, binary_operators = [+], variable_names = ["x1"])
231232
@test DynamicExpressions.ExpressionModule.node_type(ex) <: Node
232233

233234
@test !isempty(ex)
235+
236+
tree = get_tree(ex)
237+
@test_throws ArgumentError get_operators(tree, nothing)
234238
end

test/test_multi_expression.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,15 @@
8787
)::Expression{T,N}
8888
return fused_expression.tree
8989
end
90-
function DE.get_operators(ex::MultiScalarExpression, operators=nothing)
90+
function DE.get_operators(
91+
ex::MultiScalarExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing
92+
)
9193
return operators === nothing ? ex.metadata.operators : operators
9294
end
93-
function DE.get_variable_names(ex::MultiScalarExpression, variable_names=nothing)
95+
function DE.get_variable_names(
96+
ex::MultiScalarExpression,
97+
variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing,
98+
)
9499
return variable_names === nothing ? ex.metadata.variable_names : variable_names
95100
end
96101
function Base.copy(ex::MultiScalarExpression)

0 commit comments

Comments
 (0)