11
11
12
12
#include " mlir/Support/LLVM.h"
13
13
#include " mlir/Support/LogicalResult.h"
14
+ #include " llvm/ADT/APFloat.h"
14
15
#include " llvm/ADT/APInt.h"
15
16
#include " llvm/ADT/ArrayRef.h"
16
17
#include " llvm/ADT/Hashing.h"
17
- #include " llvm/ADT/SmallVector.h"
18
+ #include " llvm/ADT/SmallString.h"
19
+ #include " llvm/ADT/Twine.h"
20
+ #include " llvm/Support/raw_ostream.h"
18
21
19
22
namespace mlir {
20
23
@@ -27,98 +30,201 @@ namespace polynomial {
27
30
// / would want to specify 128-bit polynomials statically in the source code.
28
31
constexpr unsigned apintBitWidth = 64 ;
29
32
30
- // / A class representing a monomial of a single-variable polynomial with integer
31
- // / coefficients.
32
- class Monomial {
33
+ template <typename CoefficientType>
34
+ class MonomialBase {
33
35
public:
34
- Monomial (int64_t coeff, uint64_t expo)
35
- : coefficient(apintBitWidth, coeff), exponent(apintBitWidth, expo) {}
36
-
37
- Monomial (const APInt &coeff, const APInt &expo)
36
+ MonomialBase (const CoefficientType &coeff, const APInt &expo)
38
37
: coefficient(coeff), exponent(expo) {}
38
+ virtual ~MonomialBase () = 0 ;
39
39
40
- Monomial () : coefficient(apintBitWidth, 0 ), exponent(apintBitWidth, 0 ) {}
40
+ const CoefficientType &getCoefficient () const { return coefficient; }
41
+ CoefficientType &getMutableCoefficient () { return coefficient; }
42
+ const APInt &getExponent () const { return exponent; }
43
+ void setCoefficient (const CoefficientType &coeff) { coefficient = coeff; }
44
+ void setExponent (const APInt &exp) { exponent = exp; }
41
45
42
- bool operator ==(const Monomial &other) const {
46
+ bool operator ==(const MonomialBase &other) const {
43
47
return other.coefficient == coefficient && other.exponent == exponent;
44
48
}
45
- bool operator !=(const Monomial &other) const {
49
+ bool operator !=(const MonomialBase &other) const {
46
50
return other.coefficient != coefficient || other.exponent != exponent;
47
51
}
48
52
49
53
// / Monomials are ordered by exponent.
50
- bool operator <(const Monomial &other) const {
54
+ bool operator <(const MonomialBase &other) const {
51
55
return (exponent.ult (other.exponent ));
52
56
}
53
57
54
- friend ::llvm::hash_code hash_value (const Monomial &arg);
58
+ virtual bool isMonic () const = 0;
59
+ virtual void coefficientToString (llvm::SmallString<16 > &coeffString) const = 0;
55
60
56
- public:
57
- APInt coefficient ;
61
+ template < typename T>
62
+ friend ::llvm::hash_code hash_value ( const MonomialBase<T> &arg) ;
58
63
59
- // Always unsigned
64
+ protected:
65
+ CoefficientType coefficient;
60
66
APInt exponent;
61
67
};
62
68
63
- // / A single-variable polynomial with integer coefficients.
64
- // /
65
- // / Eg: x^1024 + x + 1
66
- // /
67
- // / The symbols used as the polynomial's indeterminate don't matter, so long as
68
- // / it is used consistently throughout the polynomial.
69
- class Polynomial {
69
+ // / A class representing a monomial of a single-variable polynomial with integer
70
+ // / coefficients.
71
+ class IntMonomial : public MonomialBase <APInt> {
70
72
public:
71
- Polynomial () = delete ;
73
+ IntMonomial (int64_t coeff, uint64_t expo)
74
+ : MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
72
75
73
- explicit Polynomial (ArrayRef<Monomial> terms) : terms(terms){};
76
+ IntMonomial ()
77
+ : MonomialBase(APInt(apintBitWidth, 0 ), APInt(apintBitWidth, 0 )) {}
74
78
75
- // Returns a Polynomial from a list of monomials.
76
- // Fails if two monomials have the same exponent.
77
- static FailureOr<Polynomial> fromMonomials (ArrayRef<Monomial> monomials);
79
+ ~IntMonomial () = default ;
78
80
79
- // / Returns a polynomial with coefficients given by `coeffs`. The value
80
- // / coeffs[i] is converted to a monomial with exponent i.
81
- static Polynomial fromCoefficients (ArrayRef<int64_t > coeffs);
81
+ bool isMonic () const override { return coefficient == 1 ; }
82
+
83
+ void coefficientToString (llvm::SmallString<16 > &coeffString) const override {
84
+ coefficient.toStringSigned (coeffString);
85
+ }
86
+ };
87
+
88
+ // / A class representing a monomial of a single-variable polynomial with integer
89
+ // / coefficients.
90
+ class FloatMonomial : public MonomialBase <APFloat> {
91
+ public:
92
+ FloatMonomial (double coeff, uint64_t expo)
93
+ : MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
94
+
95
+ FloatMonomial () : MonomialBase(APFloat((double )0 ), APInt(apintBitWidth, 0 )) {}
96
+
97
+ ~FloatMonomial () = default ;
98
+
99
+ bool isMonic () const override { return coefficient == APFloat (1.0 ); }
100
+
101
+ void coefficientToString (llvm::SmallString<16 > &coeffString) const override {
102
+ coefficient.toString (coeffString);
103
+ }
104
+ };
105
+
106
+ template <typename Monomial>
107
+ class PolynomialBase {
108
+ public:
109
+ PolynomialBase () = delete ;
110
+
111
+ explicit PolynomialBase (ArrayRef<Monomial> terms) : terms(terms){};
82
112
83
113
explicit operator bool () const { return !terms.empty (); }
84
- bool operator ==(const Polynomial &other) const {
114
+ bool operator ==(const PolynomialBase &other) const {
85
115
return other.terms == terms;
86
116
}
87
- bool operator !=(const Polynomial &other) const {
117
+ bool operator !=(const PolynomialBase &other) const {
88
118
return !(other.terms == terms);
89
119
}
90
120
91
- // Prints polynomial to 'os'.
92
- void print (raw_ostream &os) const ;
93
121
void print (raw_ostream &os, ::llvm::StringRef separator,
94
- ::llvm::StringRef exponentiation) const ;
122
+ ::llvm::StringRef exponentiation) const {
123
+ bool first = true ;
124
+ for (const Monomial &term : getTerms ()) {
125
+ if (first) {
126
+ first = false ;
127
+ } else {
128
+ os << separator;
129
+ }
130
+ std::string coeffToPrint;
131
+ if (term.isMonic () && term.getExponent ().uge (1 )) {
132
+ coeffToPrint = " " ;
133
+ } else {
134
+ llvm::SmallString<16 > coeffString;
135
+ term.coefficientToString (coeffString);
136
+ coeffToPrint = coeffString.str ();
137
+ }
138
+
139
+ if (term.getExponent () == 0 ) {
140
+ os << coeffToPrint;
141
+ } else if (term.getExponent () == 1 ) {
142
+ os << coeffToPrint << " x" ;
143
+ } else {
144
+ llvm::SmallString<16 > expString;
145
+ term.getExponent ().toStringSigned (expString);
146
+ os << coeffToPrint << " x" << exponentiation << expString;
147
+ }
148
+ }
149
+ }
150
+
151
+ // Prints polynomial to 'os'.
152
+ void print (raw_ostream &os) const { print (os, " + " , " **" ); }
153
+
95
154
void dump () const ;
96
155
97
156
// Prints polynomial so that it can be used as a valid identifier
98
- std::string toIdentifier () const ;
157
+ std::string toIdentifier () const {
158
+ std::string result;
159
+ llvm::raw_string_ostream os (result);
160
+ print (os, " _" , " " );
161
+ return os.str ();
162
+ }
99
163
100
- unsigned getDegree () const ;
164
+ unsigned getDegree () const {
165
+ return terms.back ().getExponent ().getZExtValue ();
166
+ }
101
167
102
168
ArrayRef<Monomial> getTerms () const { return terms; }
103
169
104
- friend ::llvm::hash_code hash_value (const Polynomial &arg);
170
+ template <typename T>
171
+ friend ::llvm::hash_code hash_value (const PolynomialBase<T> &arg);
105
172
106
173
private:
107
174
// The monomial terms for this polynomial.
108
175
SmallVector<Monomial> terms;
109
176
};
110
177
111
- // Make Polynomial hashable.
112
- inline ::llvm::hash_code hash_value (const Polynomial &arg) {
178
+ // / A single-variable polynomial with integer coefficients.
179
+ // /
180
+ // / Eg: x^1024 + x + 1
181
+ class IntPolynomial : public PolynomialBase <IntMonomial> {
182
+ public:
183
+ explicit IntPolynomial (ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
184
+
185
+ // Returns a Polynomial from a list of monomials.
186
+ // Fails if two monomials have the same exponent.
187
+ static FailureOr<IntPolynomial>
188
+ fromMonomials (ArrayRef<IntMonomial> monomials);
189
+
190
+ // / Returns a polynomial with coefficients given by `coeffs`. The value
191
+ // / coeffs[i] is converted to a monomial with exponent i.
192
+ static IntPolynomial fromCoefficients (ArrayRef<int64_t > coeffs);
193
+ };
194
+
195
+ // / A single-variable polynomial with double coefficients.
196
+ // /
197
+ // / Eg: 1.0 x^1024 + 3.5 x + 1e-05
198
+ class FloatPolynomial : public PolynomialBase <FloatMonomial> {
199
+ public:
200
+ explicit FloatPolynomial (ArrayRef<FloatMonomial> terms)
201
+ : PolynomialBase(terms) {}
202
+
203
+ // Returns a Polynomial from a list of monomials.
204
+ // Fails if two monomials have the same exponent.
205
+ static FailureOr<FloatPolynomial>
206
+ fromMonomials (ArrayRef<FloatMonomial> monomials);
207
+
208
+ // / Returns a polynomial with coefficients given by `coeffs`. The value
209
+ // / coeffs[i] is converted to a monomial with exponent i.
210
+ static FloatPolynomial fromCoefficients (ArrayRef<double > coeffs);
211
+ };
212
+
213
+ // Make Polynomials hashable.
214
+ template <typename T>
215
+ inline ::llvm::hash_code hash_value (const PolynomialBase<T> &arg) {
113
216
return ::llvm::hash_combine_range (arg.terms .begin (), arg.terms .end ());
114
217
}
115
218
116
- inline ::llvm::hash_code hash_value (const Monomial &arg) {
219
+ template <typename T>
220
+ inline ::llvm::hash_code hash_value (const MonomialBase<T> &arg) {
117
221
return llvm::hash_combine (::llvm::hash_value (arg.coefficient ),
118
222
::llvm::hash_value (arg.exponent));
119
223
}
120
224
121
- inline raw_ostream &operator <<(raw_ostream &os, const Polynomial &polynomial) {
225
+ template <typename T>
226
+ inline raw_ostream &operator <<(raw_ostream &os,
227
+ const PolynomialBase<T> &polynomial) {
122
228
polynomial.print (os);
123
229
return os;
124
230
}
0 commit comments