Skip to content

Commit 139f99f

Browse files
authored
Merge pull request #82 from SymbolicML/parametric-expressions2
Improve ParametricExpressions
2 parents 3dbded5 + dbb2866 commit 139f99f

File tree

7 files changed

+110
-46
lines changed

7 files changed

+110
-46
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <[email protected]>"]
4-
version = "0.18.0-alpha.1"
4+
version = "0.18.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/DynamicExpressions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ import .NodeModule:
6666
@reexport import .EvaluationHelpersModule
6767
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
6868
@reexport import .RandomModule: NodeSampler
69-
@reexport import .ExpressionModule: AbstractExpression, Expression, with_tree
69+
@reexport import .ExpressionModule:
70+
AbstractExpression, Expression, with_contents, with_metadata, get_contents, get_metadata
7071
import .ExpressionModule:
7172
get_tree, get_operators, get_variable_names, Metadata, default_node_type, node_type
7273
@reexport import .ParseModule: @parse_expression, parse_expression

src/Expression.jl

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using ..NodeModule: AbstractExpressionNode, Node
66
using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
77
using ..UtilsModule: Undefined
88

9+
import ..NodeModule: copy_node, set_node!, count_nodes, tree_mapreduce, constructorof
910
import ..NodeUtilsModule:
1011
preserve_sharing,
1112
count_constants,
@@ -85,7 +86,7 @@ end
8586
end
8687

8788
node_type(::Union{E,Type{E}}) where {N,E<:AbstractExpression{<:Any,N}} = N
88-
@unstable default_node_type(::Type{<:AbstractExpression}) = Node
89+
@unstable default_node_type(_) = Node
8990
default_node_type(::Type{<:AbstractExpression{T}}) where {T} = Node{T}
9091

9192
########################################################
@@ -128,28 +129,52 @@ end
128129
function Base.copy(ex::AbstractExpression; break_sharing::Val=Val(false))
129130
return error("`copy` function must be implemented for $(typeof(ex)) types.")
130131
end
131-
function Base.hash(ex::AbstractExpression, h::UInt)
132-
return error("`hash` function must be implemented for $(typeof(ex)) types.")
133-
end
134-
function Base.:(==)(x::AbstractExpression, y::AbstractExpression)
135-
return error("`==` function must be implemented for $(typeof(x)) types.")
136-
end
137132
function get_constants(ex::AbstractExpression)
138133
return error("`get_constants` function must be implemented for $(typeof(ex)) types.")
139134
end
140135
function set_constants!(ex::AbstractExpression{T}, constants, refs) where {T}
141136
return error("`set_constants!` function must be implemented for $(typeof(ex)) types.")
142137
end
138+
function get_contents(ex::AbstractExpression)
139+
return error("`get_contents` function must be implemented for $(typeof(ex)) types.")
140+
end
141+
function get_metadata(ex::AbstractExpression)
142+
return error("`get_metadata` function must be implemented for $(typeof(ex)) types.")
143+
end
143144
########################################################
144145

145146
"""
146-
with_tree(ex::AbstractExpression, tree::AbstractExpressionNode)
147+
with_contents(ex::AbstractExpression, tree::AbstractExpressionNode)
148+
with_contents(ex::AbstractExpression, tree::AbstractExpression)
147149
148150
Create a new expression based on `ex` but with a different `tree`
149151
"""
150-
function with_tree(ex::AbstractExpression, tree)
151-
return constructorof(typeof(ex))(tree, ex.metadata)
152+
function with_contents(ex::AbstractExpression, tree::AbstractExpression)
153+
return with_contents(ex, get_contents(tree))
154+
end
155+
function with_contents(ex::AbstractExpression, tree)
156+
return constructorof(typeof(ex))(tree, get_metadata(ex))
157+
end
158+
function get_contents(ex::Expression)
159+
return ex.tree
160+
end
161+
162+
"""
163+
with_metadata(ex::AbstractExpression, metadata)
164+
with_metadata(ex::AbstractExpression; metadata...)
165+
166+
Create a new expression based on `ex` but with a different `metadata`.
167+
"""
168+
function with_metadata(ex::AbstractExpression; metadata...)
169+
return with_metadata(ex, Metadata((; metadata...)))
170+
end
171+
function with_metadata(ex::AbstractExpression, metadata::Metadata)
172+
return constructorof(typeof(ex))(get_contents(ex), metadata)
173+
end
174+
function get_metadata(ex::Expression)
175+
return ex.metadata
152176
end
177+
153178
function preserve_sharing(::Union{E,Type{E}}) where {T,N,E<:AbstractExpression{T,N}}
154179
return preserve_sharing(N)
155180
end
@@ -169,25 +194,17 @@ end
169194
function Base.copy(ex::Expression; break_sharing::Val=Val(false))
170195
return Expression(copy(ex.tree; break_sharing), copy(ex.metadata))
171196
end
172-
function Base.hash(ex::Expression, h::UInt)
173-
return hash(ex.tree, hash(ex.metadata, h))
197+
function Base.hash(ex::AbstractExpression, h::UInt)
198+
return hash(get_contents(ex), hash(get_metadata(ex), h))
174199
end
175-
176-
"""
177-
Base.:(==)(x::Expression, y::Expression)
178-
179-
Check equality of two expressions `x` and `y` by comparing their trees and metadata.
180-
"""
181-
function Base.:(==)(x::Expression, y::Expression)
182-
return x.tree == y.tree && x.metadata == y.metadata
200+
function Base.:(==)(x::AbstractExpression, y::AbstractExpression)
201+
return get_contents(x) == get_contents(y) && get_metadata(x) == get_metadata(y)
183202
end
184203

185204
# Overload all methods on AbstractExpressionNode that return an aggregation, or can
186205
# return an entire tree. Methods that only return the nodes are *not* overloaded, so
187206
# that the user must use the low-level interface.
188207

189-
import ..NodeModule: copy_node, set_node!, count_nodes, tree_mapreduce, constructorof
190-
191208
#! format: off
192209
@unstable constructorof(::Type{E}) where {E<:AbstractExpression} = Base.typename(E).wrapper
193210
@unstable constructorof(::Type{<:Expression}) = Expression

src/Interfaces.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ using ..ExpressionModule:
4343
get_tree,
4444
get_operators,
4545
get_variable_names,
46-
with_tree,
46+
get_contents,
47+
get_metadata,
48+
with_contents,
49+
with_metadata,
4750
default_node_type
4851
using ..ParametricExpressionModule: ParametricExpression, ParametricNode
4952

@@ -52,6 +55,14 @@ using ..ParametricExpressionModule: ParametricExpression, ParametricNode
5255
###############################################################################
5356

5457
## mandatory
58+
function _check_get_contents(ex::AbstractExpression)
59+
new_ex = with_contents(ex, get_contents(ex))
60+
return new_ex == ex && new_ex isa typeof(ex)
61+
end
62+
function _check_get_metadata(ex::AbstractExpression)
63+
new_ex = with_metadata(ex, get_metadata(ex))
64+
return new_ex == ex && new_ex isa typeof(ex)
65+
end
5566
function _check_get_tree(ex::AbstractExpression{T,N}) where {T,N}
5667
return get_tree(ex) isa N
5768
end
@@ -67,6 +78,15 @@ function _check_copy(ex::AbstractExpression)
6778
# TODO: Could include checks for aliasing here
6879
return preserves
6980
end
81+
function _check_with_contents(ex::AbstractExpression)
82+
new_ex = with_contents(ex, get_contents(ex))
83+
new_ex2 = with_contents(ex, ex)
84+
return new_ex == ex && new_ex isa typeof(ex) && new_ex2 == ex && new_ex2 isa typeof(ex)
85+
end
86+
function _check_with_metadata(ex::AbstractExpression)
87+
new_ex = with_metadata(ex, get_metadata(ex))
88+
return new_ex == ex && new_ex isa typeof(ex)
89+
end
7090

7191
## optional
7292
function _check_count_nodes(ex::AbstractExpression)
@@ -116,10 +136,14 @@ end
116136
#! format: off
117137
ei_components = (
118138
mandatory = (
139+
get_contents = "extracts the runtime contents of an expression" => _check_get_contents,
140+
get_metadata = "extracts the runtime metadata of an expression" => _check_get_metadata,
119141
get_tree = "extracts the expression tree from [`AbstractExpression`](@ref)" => _check_get_tree,
120142
get_operators = "returns the operators used in the expression (or pass `operators` explicitly to override)" => _check_get_operators,
121143
get_variable_names = "returns the variable names used in the expression (or pass `variable_names` explicitly to override)" => _check_get_variable_names,
122144
copy = "returns a copy of the expression" => _check_copy,
145+
with_contents = "returns the expression with different tree" => _check_with_contents,
146+
with_metadata = "returns the expression with different metadata" => _check_with_metadata,
123147
),
124148
optional = (
125149
count_nodes = "counts the number of nodes in the expression tree" => _check_count_nodes,

src/ParametricExpression.jl

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@ import ..EvaluateModule: eval_tree_array
1919
import ..EvaluateDerivativeModule: eval_grad_tree_array
2020
import ..EvaluationHelpersModule: _grad_evaluator
2121
import ..ExpressionModule:
22-
get_tree, get_operators, get_variable_names, max_feature, default_node_type
22+
get_contents,
23+
get_metadata,
24+
get_tree,
25+
get_operators,
26+
get_variable_names,
27+
max_feature,
28+
default_node_type
2329
import ..ParseModule: parse_leaf
2430

2531
"""A type of expression node that also stores a parameter index"""
@@ -48,9 +54,13 @@ end
4854
"""
4955
ParametricExpression{T,N<:ParametricNode{T},D<:NamedTuple} <: AbstractExpression{T,N}
5056
51-
An expression to store parameters for a tree
57+
(Experimental) An expression to store parameters for a tree
5258
"""
53-
struct ParametricExpression{T,N<:ParametricNode{T},D<:NamedTuple} <: AbstractExpression{T,N}
59+
struct ParametricExpression{
60+
T,
61+
N<:ParametricNode{T},
62+
D<:NamedTuple{(:operators, :variable_names, :parameters, :parameter_names)},
63+
} <: AbstractExpression{T,N}
5464
tree::N
5565
metadata::Metadata{D}
5666

@@ -65,8 +75,9 @@ function ParametricExpression(
6575
parameters::AbstractMatrix{T2},
6676
parameter_names,
6777
) where {T1,T2}
68-
@assert (isempty(parameters) && isnothing(parameter_names)) ||
69-
size(parameters, 1) == length(parameter_names)
78+
if !isnothing(parameter_names)
79+
@assert size(parameters, 1) == length(parameter_names)
80+
end
7081
T = promote_type(T1, T2)
7182
t = T === T1 ? tree : convert(ParametricNode{T}, tree)
7283
m = Metadata((;
@@ -127,9 +138,9 @@ end
127138
###############################################################################
128139
# Abstract expression interface ###############################################
129140
###############################################################################
130-
function get_tree(ex::ParametricExpression)
131-
return ex.tree
132-
end
141+
get_contents(ex::ParametricExpression) = ex.tree
142+
get_metadata(ex::ParametricExpression) = ex.metadata
143+
get_tree(ex::ParametricExpression) = ex.tree
133144
function get_operators(ex::ParametricExpression, operators=nothing)
134145
return operators === nothing ? ex.metadata.operators : operators
135146
end
@@ -147,12 +158,6 @@ function Base.copy(ex::ParametricExpression; break_sharing::Val=Val(false))
147158
parameter_names=_copy_with_nothing(ex.metadata.parameter_names),
148159
)
149160
end
150-
function Base.hash(ex::ParametricExpression, h::UInt)
151-
return hash(ex.tree, hash(ex.metadata, h))
152-
end
153-
function Base.:(==)(x::ParametricExpression, y::ParametricExpression)
154-
return x.tree == y.tree && x.metadata == y.metadata
155-
end
156161
###############################################################################
157162

158163
###############################################################################
@@ -283,10 +288,16 @@ function string_tree(
283288
UInt16(0)
284289
end
285290
end
291+
_parameter_names = ex.metadata.parameter_names
292+
parameter_names = if _parameter_names === nothing
293+
["p$(i)" for i in 1:num_params]
294+
else
295+
_parameter_names
296+
end
286297
variable_names3 = if variable_names2 === nothing
287-
vcat(["p$(i)" for i in 1:num_params], ["x$(i)" for i in 1:max_feature])
298+
vcat(parameter_names, ["x$(i)" for i in 1:max_feature])
288299
else
289-
vcat(ex.metadata.parameter_names, variable_names2)
300+
vcat(parameter_names, variable_names2)
290301
end
291302
@assert length(variable_names3) >= num_params + max_feature
292303
return string_tree(

test/test_expressions.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,14 @@ end
168168
@test has_constants(ex) == false
169169
end
170170

171-
@testitem "Expression with_tree" begin
171+
@testitem "Expression with_contents" begin
172172
using DynamicExpressions
173173

174174
ex = @parse_expression(x1 + 1.5, binary_operators = [+, *], variable_names = ["x1"])
175175
ex2 = @parse_expression(x1 + 3.0, binary_operators = [+], variable_names = ["x1"])
176176

177-
t2 = DynamicExpressions.get_tree(ex2)
178-
ex_modified = DynamicExpressions.with_tree(ex, t2)
177+
t2 = DynamicExpressions.get_contents(ex2)
178+
ex_modified = DynamicExpressions.with_contents(ex, t2)
179179
@test DynamicExpressions.get_tree(ex_modified) == t2
180180
end
181181

test/test_multi_expression.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@
1010
trees::TREES
1111
metadata::Metadata{D}
1212

13+
function MultiScalarExpression(trees::NamedTuple, metadata::Metadata{D}) where {D}
14+
example_tree = first(values(trees))
15+
N = typeof(example_tree)
16+
T = eltype(example_tree)
17+
return new{T,N,typeof(trees),D}(trees, metadata)
18+
end
19+
1320
"""
1421
Create a multi-expression expression type.
1522
@@ -54,8 +61,6 @@
5461
)
5562
@test_throws "`get_tree` function must be implemented for" DE.get_tree(multi_ex)
5663
@test_throws "`copy` function must be implemented for" copy(multi_ex)
57-
@test_throws "`hash` function must be implemented for" hash(multi_ex, UInt(0))
58-
@test_throws "`==` function must be implemented for" multi_ex == multi_ex
5964
@test_throws "`get_constants` function must be implemented for" get_constants(
6065
multi_ex
6166
)
@@ -65,6 +70,12 @@
6570
end
6671

6772
tree_factory(f::F, trees) where {F} = f(; trees...)
73+
function DE.get_contents(ex::MultiScalarExpression)
74+
return ex.trees
75+
end
76+
function DE.get_metadata(ex::MultiScalarExpression)
77+
return ex.metadata
78+
end
6879
function DE.get_tree(ex::MultiScalarExpression{T,N}) where {T,N}
6980
fused_expression = parse_expression(
7081
tree_factory(ex.metadata.tree_factory, ex.trees)::Expr;

0 commit comments

Comments
 (0)