Skip to content

Commit f140479

Browse files
committed
Add optimizer Reduction.
1 parent e4c5210 commit f140479

File tree

3 files changed

+63
-0
lines changed

3 files changed

+63
-0
lines changed

src/OMEinsumContractionOrders.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ include("treesa.jl")
4242
# tree width method
4343
include("treewidth.jl")
4444

45+
# preprocessor
46+
include("reduction.jl")
47+
4548
# simplification passes
4649
include("simplify.jl")
4750

src/interfaces.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,7 @@ function _optimize_code(code, size_dict, optimizer::TreeSA)
5555
sc_weight=optimizer.sc_weight, rw_weight=optimizer.rw_weight, initializer=optimizer.initializer,
5656
greedy_method=optimizer.greedy_config, fixed_slices=optimizer.fixed_slices)
5757
end
58+
function _optimize_code(code, size_dict, optimizer::Reduction)
59+
_optimizer = Treewidth(; alg=ReductionEA(optimizer.inner))
60+
_optimize_code(code, size_dict, _optimizer)
61+
end

src/reduction.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
struct Reduction{O <: CodeOptimizer} <: CodeOptimizer
2+
inner::O
3+
end
4+
5+
struct ReductionEA{O <: CodeOptimizer} <: CliqueTrees.EliminationAlgorithm
6+
inner::O
7+
end
8+
9+
#=
10+
function CliqueTrees.permutation(weights::AbstractVector, graph, alg::ReductionEA)
11+
# reduce graph
12+
width = CliqueTrees.lowerbound(weights, graph)
13+
weights, graph, stack, index, width = CliqueTrees.saferules(weights, graph, width)
14+
15+
# feed reduced graph back to OMEinsumContractionOrders
16+
code = EinCode(maximal_cliques(CliqueTrees.Graph(graph)), Int[])
17+
sizes = Dict{Int, Int}(v => round(Int, 2^weights[v]) for v in CliqueTrees.vertices(graph))
18+
opt = optimize_code(code, sizes, alg.inner).eins
19+
20+
# compute ordering
21+
for v in eincode2order(opt)
22+
append!(stack, CliqueTrees.neighbors(index, v))
23+
end
24+
25+
return stack, invperm(stack)
26+
end
27+
=#
28+
29+
function CliqueTrees.permutation(weights::AbstractVector, graph, alg::ReductionEA)
30+
# reduce graph
31+
kernel, stack, label, width = CliqueTrees.pr4(graph, CliqueTrees.lowerbound(graph))
32+
33+
# feed reduced graph back to OMEinsumContractionOrders
34+
code = EinCode(maximal_cliques(kernel), Int[])
35+
sizes = Dict{Int, Int}(v => round(Int, 2^weights[label[v]]) for v in CliqueTrees.vertices(kernel))
36+
opt = optimize_code(code, sizes, alg.inner).eins
37+
38+
# compute ordering
39+
append!(stack, label[eincode2order(opt)])
40+
return stack, invperm(stack)
41+
end
42+
43+
function eincode2order(code::NestedEinsum{L}) where {L}
44+
elimination_order = Vector{L}()
45+
isleaf(code) && return elimination_order
46+
47+
for node in PostOrderDFS(code)
48+
if !(node isa LeafString)
49+
for id in setdiff(vcat(getixsv(node.eins)...), getiyv(node.eins))
50+
push!(elimination_order, id)
51+
end
52+
end
53+
end
54+
55+
return elimination_order
56+
end

0 commit comments

Comments
 (0)