7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " mlir/Analysis/Presburger/Barvinok.h"
10
+ #include " llvm/ADT/Sequence.h"
10
11
11
12
using namespace mlir ;
12
13
using namespace presburger ;
@@ -24,7 +25,7 @@ ConeV mlir::presburger::detail::getDual(ConeH cone) {
24
25
// is represented as a row [a1, ..., an, b]
25
26
// and that b = 0.
26
27
27
- for (unsigned i = 0 ; i < numIneq; ++i ) {
28
+ for (auto i : llvm::seq< int >( 0 , numIneq) ) {
28
29
assert (cone.atIneq (i, numVar) == 0 &&
29
30
" H-representation of cone is not centred at the origin!" );
30
31
for (unsigned j = 0 ; j < numVar; ++j) {
@@ -63,3 +64,83 @@ MPInt mlir::presburger::detail::getIndex(ConeV cone) {
63
64
64
65
return cone.determinant ();
65
66
}
67
+
68
+ // / Compute the generating function for a unimodular cone.
69
+ // / This consists of a single term of the form
70
+ // / sign * x^num / prod_j (1 - x^den_j)
71
+ // /
72
+ // / sign is either +1 or -1.
73
+ // / den_j is defined as the set of generators of the cone.
74
+ // / num is computed by expressing the vertex as a weighted
75
+ // / sum of the generators, and then taking the floor of the
76
+ // / coefficients.
77
+ GeneratingFunction mlir::presburger::detail::unimodularConeGeneratingFunction (
78
+ ParamPoint vertex, int sign, ConeH cone) {
79
+ // Consider a cone with H-representation [0 -1].
80
+ // [-1 -2]
81
+ // Let the vertex be given by the matrix [ 2 2 0], with 2 params.
82
+ // [-1 -1/2 1]
83
+
84
+ // `cone` must be unimodular.
85
+ assert (getIndex (getDual (cone)) == 1 && " input cone is not unimodular!" );
86
+
87
+ unsigned numVar = cone.getNumVars ();
88
+ unsigned numIneq = cone.getNumInequalities ();
89
+
90
+ // Thus its ray matrix, U, is the inverse of the
91
+ // transpose of its inequality matrix, `cone`.
92
+ // The last column of the inequality matrix is null,
93
+ // so we remove it to obtain a square matrix.
94
+ FracMatrix transp = FracMatrix (cone.getInequalities ()).transpose ();
95
+ transp.removeRow (numVar);
96
+
97
+ FracMatrix generators (numVar, numIneq);
98
+ transp.determinant (/* inverse=*/ &generators); // This is the U-matrix.
99
+ // Thus the generators are given by U = [2 -1].
100
+ // [-1 0]
101
+
102
+ // The powers in the denominator of the generating
103
+ // function are given by the generators of the cone,
104
+ // i.e., the rows of the matrix U.
105
+ std::vector<Point> denominator (numIneq);
106
+ ArrayRef<Fraction> row;
107
+ for (auto i : llvm::seq<int >(0 , numVar)) {
108
+ row = generators.getRow (i);
109
+ denominator[i] = Point (row);
110
+ }
111
+
112
+ // The vertex is v \in Z^{d x (n+1)}
113
+ // We need to find affine functions of parameters λ_i(p)
114
+ // such that v = Σ λ_i(p)*u_i,
115
+ // where u_i are the rows of U (generators)
116
+ // The λ_i are given by the columns of Λ = v^T U^{-1}, and
117
+ // we have transp = U^{-1}.
118
+ // Then the exponent in the numerator will be
119
+ // Σ -floor(-λ_i(p))*u_i.
120
+ // Thus we store the (exponent of the) numerator as the affine function -Λ,
121
+ // since the generators u_i are already stored as the exponent of the
122
+ // denominator. Note that the outer -1 will have to be accounted for, as it is
123
+ // not stored. See end for an example.
124
+
125
+ unsigned numColumns = vertex.getNumColumns ();
126
+ unsigned numRows = vertex.getNumRows ();
127
+ ParamPoint numerator (numColumns, numRows);
128
+ SmallVector<Fraction> ithCol (numRows);
129
+ for (auto i : llvm::seq<int >(0 , numColumns)) {
130
+ for (auto j : llvm::seq<int >(0 , numRows))
131
+ ithCol[j] = vertex (j, i);
132
+ numerator.setRow (i, transp.preMultiplyWithRow (ithCol));
133
+ numerator.negateRow (i);
134
+ }
135
+ // Therefore Λ will be given by [ 1 0 ] and the negation of this will be
136
+ // [ 1/2 -1 ]
137
+ // [ -1 -2 ]
138
+ // stored as the numerator.
139
+ // Algebraically, the numerator exponent is
140
+ // [ -2 ⌊ - N - M/2 + 1 ⌋ + 1 ⌊ 0 + M + 2 ⌋ ] -> first COLUMN of U is [2, -1]
141
+ // [ 1 ⌊ - N - M/2 + 1 ⌋ + 0 ⌊ 0 + M + 2 ⌋ ] -> second COLUMN of U is [-1, 0]
142
+
143
+ return GeneratingFunction (numColumns - 1 , SmallVector<int >(1 , sign),
144
+ std::vector ({numerator}),
145
+ std::vector ({denominator}));
146
+ }
0 commit comments