Skip to content

Commit 547f345

Browse files
committed
refactor and support Float polynomials
1 parent 7e59223 commit 547f345

File tree

4 files changed

+307
-137
lines changed

4 files changed

+307
-137
lines changed

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h

Lines changed: 149 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111

1212
#include "mlir/Support/LLVM.h"
1313
#include "mlir/Support/LogicalResult.h"
14+
#include "llvm/ADT/APFloat.h"
1415
#include "llvm/ADT/APInt.h"
1516
#include "llvm/ADT/ArrayRef.h"
1617
#include "llvm/ADT/Hashing.h"
17-
#include "llvm/ADT/SmallVector.h"
18+
#include "llvm/ADT/SmallString.h"
19+
#include "llvm/ADT/Twine.h"
20+
#include "llvm/Support/raw_ostream.h"
1821

1922
namespace mlir {
2023

@@ -27,98 +30,201 @@ namespace polynomial {
2730
/// would want to specify 128-bit polynomials statically in the source code.
2831
constexpr unsigned apintBitWidth = 64;
2932

30-
/// A class representing a monomial of a single-variable polynomial with integer
31-
/// coefficients.
32-
class Monomial {
33+
template <typename CoefficientType>
34+
class MonomialBase {
3335
public:
34-
Monomial(int64_t coeff, uint64_t expo)
35-
: coefficient(apintBitWidth, coeff), exponent(apintBitWidth, expo) {}
36-
37-
Monomial(const APInt &coeff, const APInt &expo)
36+
MonomialBase(const CoefficientType &coeff, const APInt &expo)
3837
: coefficient(coeff), exponent(expo) {}
38+
virtual ~MonomialBase() = 0;
3939

40-
Monomial() : coefficient(apintBitWidth, 0), exponent(apintBitWidth, 0) {}
40+
const CoefficientType &getCoefficient() const { return coefficient; }
41+
CoefficientType &getMutableCoefficient() { return coefficient; }
42+
const APInt &getExponent() const { return exponent; }
43+
void setCoefficient(const CoefficientType &coeff) { coefficient = coeff; }
44+
void setExponent(const APInt &exp) { exponent = exp; }
4145

42-
bool operator==(const Monomial &other) const {
46+
bool operator==(const MonomialBase &other) const {
4347
return other.coefficient == coefficient && other.exponent == exponent;
4448
}
45-
bool operator!=(const Monomial &other) const {
49+
bool operator!=(const MonomialBase &other) const {
4650
return other.coefficient != coefficient || other.exponent != exponent;
4751
}
4852

4953
/// Monomials are ordered by exponent.
50-
bool operator<(const Monomial &other) const {
54+
bool operator<(const MonomialBase &other) const {
5155
return (exponent.ult(other.exponent));
5256
}
5357

54-
friend ::llvm::hash_code hash_value(const Monomial &arg);
58+
virtual bool isMonic() const = 0;
59+
virtual void coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
5560

56-
public:
57-
APInt coefficient;
61+
template <typename T>
62+
friend ::llvm::hash_code hash_value(const MonomialBase<T> &arg);
5863

59-
// Always unsigned
64+
protected:
65+
CoefficientType coefficient;
6066
APInt exponent;
6167
};
6268

63-
/// A single-variable polynomial with integer coefficients.
64-
///
65-
/// Eg: x^1024 + x + 1
66-
///
67-
/// The symbols used as the polynomial's indeterminate don't matter, so long as
68-
/// it is used consistently throughout the polynomial.
69-
class Polynomial {
69+
/// A class representing a monomial of a single-variable polynomial with integer
70+
/// coefficients.
71+
class IntMonomial : public MonomialBase<APInt> {
7072
public:
71-
Polynomial() = delete;
73+
IntMonomial(int64_t coeff, uint64_t expo)
74+
: MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
7275

73-
explicit Polynomial(ArrayRef<Monomial> terms) : terms(terms){};
76+
IntMonomial()
77+
: MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}
7478

75-
// Returns a Polynomial from a list of monomials.
76-
// Fails if two monomials have the same exponent.
77-
static FailureOr<Polynomial> fromMonomials(ArrayRef<Monomial> monomials);
79+
~IntMonomial() = default;
7880

79-
/// Returns a polynomial with coefficients given by `coeffs`. The value
80-
/// coeffs[i] is converted to a monomial with exponent i.
81-
static Polynomial fromCoefficients(ArrayRef<int64_t> coeffs);
81+
bool isMonic() const override { return coefficient == 1; }
82+
83+
void coefficientToString(llvm::SmallString<16> &coeffString) const override {
84+
coefficient.toStringSigned(coeffString);
85+
}
86+
};
87+
88+
/// A class representing a monomial of a single-variable polynomial with integer
89+
/// coefficients.
90+
class FloatMonomial : public MonomialBase<APFloat> {
91+
public:
92+
FloatMonomial(double coeff, uint64_t expo)
93+
: MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
94+
95+
FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}
96+
97+
~FloatMonomial() = default;
98+
99+
bool isMonic() const override { return coefficient == APFloat(1.0); }
100+
101+
void coefficientToString(llvm::SmallString<16> &coeffString) const override {
102+
coefficient.toString(coeffString);
103+
}
104+
};
105+
106+
template <typename Monomial>
107+
class PolynomialBase {
108+
public:
109+
PolynomialBase() = delete;
110+
111+
explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms){};
82112

83113
explicit operator bool() const { return !terms.empty(); }
84-
bool operator==(const Polynomial &other) const {
114+
bool operator==(const PolynomialBase &other) const {
85115
return other.terms == terms;
86116
}
87-
bool operator!=(const Polynomial &other) const {
117+
bool operator!=(const PolynomialBase &other) const {
88118
return !(other.terms == terms);
89119
}
90120

91-
// Prints polynomial to 'os'.
92-
void print(raw_ostream &os) const;
93121
void print(raw_ostream &os, ::llvm::StringRef separator,
94-
::llvm::StringRef exponentiation) const;
122+
::llvm::StringRef exponentiation) const {
123+
bool first = true;
124+
for (const Monomial &term : getTerms()) {
125+
if (first) {
126+
first = false;
127+
} else {
128+
os << separator;
129+
}
130+
std::string coeffToPrint;
131+
if (term.isMonic() && term.getExponent().uge(1)) {
132+
coeffToPrint = "";
133+
} else {
134+
llvm::SmallString<16> coeffString;
135+
term.coefficientToString(coeffString);
136+
coeffToPrint = coeffString.str();
137+
}
138+
139+
if (term.getExponent() == 0) {
140+
os << coeffToPrint;
141+
} else if (term.getExponent() == 1) {
142+
os << coeffToPrint << "x";
143+
} else {
144+
llvm::SmallString<16> expString;
145+
term.getExponent().toStringSigned(expString);
146+
os << coeffToPrint << "x" << exponentiation << expString;
147+
}
148+
}
149+
}
150+
151+
// Prints polynomial to 'os'.
152+
void print(raw_ostream &os) const { print(os, " + ", "**"); }
153+
95154
void dump() const;
96155

97156
// Prints polynomial so that it can be used as a valid identifier
98-
std::string toIdentifier() const;
157+
std::string toIdentifier() const {
158+
std::string result;
159+
llvm::raw_string_ostream os(result);
160+
print(os, "_", "");
161+
return os.str();
162+
}
99163

100-
unsigned getDegree() const;
164+
unsigned getDegree() const {
165+
return terms.back().getExponent().getZExtValue();
166+
}
101167

102168
ArrayRef<Monomial> getTerms() const { return terms; }
103169

104-
friend ::llvm::hash_code hash_value(const Polynomial &arg);
170+
template <typename T>
171+
friend ::llvm::hash_code hash_value(const PolynomialBase<T> &arg);
105172

106173
private:
107174
// The monomial terms for this polynomial.
108175
SmallVector<Monomial> terms;
109176
};
110177

111-
// Make Polynomial hashable.
112-
inline ::llvm::hash_code hash_value(const Polynomial &arg) {
178+
/// A single-variable polynomial with integer coefficients.
179+
///
180+
/// Eg: x^1024 + x + 1
181+
class IntPolynomial : public PolynomialBase<IntMonomial> {
182+
public:
183+
explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
184+
185+
// Returns a Polynomial from a list of monomials.
186+
// Fails if two monomials have the same exponent.
187+
static FailureOr<IntPolynomial>
188+
fromMonomials(ArrayRef<IntMonomial> monomials);
189+
190+
/// Returns a polynomial with coefficients given by `coeffs`. The value
191+
/// coeffs[i] is converted to a monomial with exponent i.
192+
static IntPolynomial fromCoefficients(ArrayRef<int64_t> coeffs);
193+
};
194+
195+
/// A single-variable polynomial with double coefficients.
196+
///
197+
/// Eg: 1.0 x^1024 + 3.5 x + 1e-05
198+
class FloatPolynomial : public PolynomialBase<FloatMonomial> {
199+
public:
200+
explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
201+
: PolynomialBase(terms) {}
202+
203+
// Returns a Polynomial from a list of monomials.
204+
// Fails if two monomials have the same exponent.
205+
static FailureOr<FloatPolynomial>
206+
fromMonomials(ArrayRef<FloatMonomial> monomials);
207+
208+
/// Returns a polynomial with coefficients given by `coeffs`. The value
209+
/// coeffs[i] is converted to a monomial with exponent i.
210+
static FloatPolynomial fromCoefficients(ArrayRef<double> coeffs);
211+
};
212+
213+
// Make Polynomials hashable.
214+
template <typename T>
215+
inline ::llvm::hash_code hash_value(const PolynomialBase<T> &arg) {
113216
return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
114217
}
115218

116-
inline ::llvm::hash_code hash_value(const Monomial &arg) {
219+
template <typename T>
220+
inline ::llvm::hash_code hash_value(const MonomialBase<T> &arg) {
117221
return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
118222
::llvm::hash_value(arg.exponent));
119223
}
120224

121-
inline raw_ostream &operator<<(raw_ostream &os, const Polynomial &polynomial) {
225+
template <typename T>
226+
inline raw_ostream &operator<<(raw_ostream &os,
227+
const PolynomialBase<T> &polynomial) {
122228
polynomial.print(os);
123229
return os;
124230
}

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
6060
let mnemonic = attrMnemonic;
6161
}
6262

63-
def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
64-
let summary = "An attribute containing a single-variable polynomial.";
63+
def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynomial"> {
64+
let summary = "An attribute containing a single-variable polynomial with integer coefficients.";
6565
let description = [{
66-
A polynomial attribute represents a single-variable polynomial, which
67-
is used to define the modulus of a `RingAttr`, as well as to define constants
68-
and perform constant folding for `polynomial` ops.
66+
A polynomial attribute represents a single-variable polynomial with integer
67+
coefficients, which is used to define the modulus of a `RingAttr`, as well
68+
as to define constants and perform constant folding for `polynomial` ops.
6969

7070
The polynomial must be expressed as a list of monomial terms, with addition
7171
or subtraction between them. The choice of variable name is arbitrary, but
@@ -76,10 +76,32 @@ def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
7676
Example:
7777

7878
```mlir
79-
#poly = #polynomial.polynomial<x**1024 + 1>
79+
#poly = #polynomial.int_polynomial<x**1024 + 1>
8080
```
8181
}];
82-
let parameters = (ins "::mlir::polynomial::Polynomial":$polynomial);
82+
let parameters = (ins "::mlir::polynomial::IntPolynomial":$polynomial);
83+
let hasCustomAssemblyFormat = 1;
84+
}
85+
86+
def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
87+
let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
88+
let description = [{
89+
A polynomial attribute represents a single-variable polynomial with double
90+
precision floating point coefficients.
91+
92+
The polynomial must be expressed as a list of monomial terms, with addition
93+
or subtraction between them. The choice of variable name is arbitrary, but
94+
must be consistent across all the monomials used to define a single
95+
attribute. The order of monomial terms is arbitrary, each monomial degree
96+
must occur at most once.
97+
98+
Example:
99+
100+
```mlir
101+
#poly = #polynomial.float_polynomial<0.5 x**7 + 1.5>
102+
```
103+
}];
104+
let parameters = (ins "FloatPolynomial":$polynomial);
83105
let hasCustomAssemblyFormat = 1;
84106
}
85107

@@ -123,15 +145,15 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
123145
let parameters = (ins
124146
"Type": $coefficientType,
125147
OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
126-
OptionalParameter<"::mlir::polynomial::PolynomialAttr">: $polynomialModulus,
148+
OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
127149
OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
128150
);
129151

130152
let builders = [
131153
AttrBuilder<
132154
(ins "::mlir::Type":$coefficientTy,
133155
"::mlir::IntegerAttr":$coefficientModulusAttr,
134-
"::mlir::polynomial::PolynomialAttr":$polynomialModulusAttr), [{
156+
"::mlir::polynomial::IntPolynomialAttr":$polynomialModulusAttr), [{
135157
return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, nullptr);
136158
}]>
137159
];
@@ -405,10 +427,14 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
405427
let arguments = (ins Polynomial_PolynomialType:$input);
406428
let results = (outs RankedTensorOf<[AnyInteger]>:$output);
407429
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
408-
409430
let hasVerifier = 1;
410431
}
411432

433+
def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
434+
Polynomial_FloatPolynomialAttr,
435+
Polynomial_IntPolynomialAttr
436+
]>;
437+
412438
def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
413439
let summary = "Define a constant polynomial via an attribute.";
414440
let description = [{
@@ -420,9 +446,9 @@ def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
420446
%0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
421447
```
422448
}];
423-
let arguments = (ins Polynomial_PolynomialAttr:$input);
449+
let arguments = (ins Polynomial_AnyPolynomialAttr:$input);
424450
let results = (outs Polynomial_PolynomialType:$output);
425-
let assemblyFormat = "$input attr-dict `:` type($output)";
451+
let assemblyFormat = "operands attr-dict `:` type($output)";
426452
}
427453

428454
def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {

0 commit comments

Comments
 (0)