|
| 1 | +struct Reduction{O <: CodeOptimizer} <: CodeOptimizer |
| 2 | + inner::O |
| 3 | +end |
| 4 | + |
| 5 | +struct ReductionEA{O <: CodeOptimizer} <: CliqueTrees.EliminationAlgorithm |
| 6 | + inner::O |
| 7 | +end |
| 8 | + |
| 9 | +#= |
| 10 | +function CliqueTrees.permutation(weights::AbstractVector, graph, alg::ReductionEA) |
| 11 | + # reduce graph |
| 12 | + width = CliqueTrees.lowerbound(weights, graph) |
| 13 | + weights, graph, stack, index, width = CliqueTrees.saferules(weights, graph, width) |
| 14 | + |
| 15 | + # feed reduced graph back to OMEinsumContractionOrders |
| 16 | + code = EinCode(maximal_cliques(CliqueTrees.Graph(graph)), Int[]) |
| 17 | + sizes = Dict{Int, Int}(v => round(Int, 2^weights[v]) for v in CliqueTrees.vertices(graph)) |
| 18 | + opt = optimize_code(code, sizes, alg.inner).eins |
| 19 | +
|
| 20 | + # compute ordering |
| 21 | + for v in eincode2order(opt) |
| 22 | + append!(stack, CliqueTrees.neighbors(index, v)) |
| 23 | + end |
| 24 | + |
| 25 | + return stack, invperm(stack) |
| 26 | +end |
| 27 | +=# |
| 28 | + |
| 29 | +function CliqueTrees.permutation(weights::AbstractVector, graph, alg::ReductionEA) |
| 30 | + # reduce graph |
| 31 | + kernel, stack, label, width = CliqueTrees.pr4(graph, CliqueTrees.lowerbound(graph)) |
| 32 | + |
| 33 | + # feed reduced graph back to OMEinsumContractionOrders |
| 34 | + code = EinCode(maximal_cliques(kernel), Int[]) |
| 35 | + sizes = Dict{Int, Int}(v => round(Int, 2^weights[label[v]]) for v in CliqueTrees.vertices(kernel)) |
| 36 | + opt = optimize_code(code, sizes, alg.inner).eins |
| 37 | + |
| 38 | + # compute ordering |
| 39 | + append!(stack, label[eincode2order(opt)]) |
| 40 | + return stack, invperm(stack) |
| 41 | +end |
| 42 | + |
| 43 | +function eincode2order(code::NestedEinsum{L}) where {L} |
| 44 | + elimination_order = Vector{L}() |
| 45 | + isleaf(code) && return elimination_order |
| 46 | + |
| 47 | + for node in PostOrderDFS(code) |
| 48 | + if !(node isa LeafString) |
| 49 | + for id in setdiff(vcat(getixsv(node.eins)...), getiyv(node.eins)) |
| 50 | + push!(elimination_order, id) |
| 51 | + end |
| 52 | + end |
| 53 | + end |
| 54 | + |
| 55 | + return elimination_order |
| 56 | +end |
0 commit comments