Skip to content

Commit f6276fe

Browse files
committed
add typed variants for polynomial.constant op
1 parent 29b4231 commit f6276fe

File tree

4 files changed

+77
-13
lines changed

4 files changed

+77
-13
lines changed

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -272,13 +272,14 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
272272
let hasVerifier = 1;
273273
}
274274

275-
def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
276-
Polynomial_FloatPolynomialAttr,
277-
Polynomial_IntPolynomialAttr
275+
def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
276+
Polynomial_TypedFloatPolynomialAttr,
277+
Polynomial_TypedIntPolynomialAttr
278278
]>;
279279

280280
// Not deriving from Polynomial_Op due to need for custom assembly format
281-
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
281+
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
282+
[Pure, InferTypeOpAdaptor]> {
282283
let summary = "Define a constant polynomial via an attribute.";
283284
let description = [{
284285
Example:
@@ -292,9 +293,9 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
292293
%0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring>
293294
```
294295
}];
295-
let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
296+
let arguments = (ins Polynomial_AnyTypedPolynomialAttr:$value);
296297
let results = (outs Polynomial_PolynomialType:$output);
297-
let assemblyFormat = "attr-dict `:` type($output)";
298+
let assemblyFormat = "attr-dict $value";
298299
}
299300

300301
def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {

mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
1818
}
1919

2020
def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynomial"> {
21-
let summary = "An attribute containing a single-variable polynomial with integer coefficients.";
21+
let summary = "an attribute containing a single-variable polynomial with integer coefficients";
2222
let description = [{
2323
A polynomial attribute represents a single-variable polynomial with integer
2424
coefficients, which is used to define the modulus of a `RingAttr`, as well
@@ -41,7 +41,7 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom
4141
}
4242

4343
def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
44-
let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
44+
let summary = "an attribute containing a single-variable polynomial with double precision floating point coefficients";
4545
let description = [{
4646
A polynomial attribute represents a single-variable polynomial with double
4747
precision floating point coefficients.
@@ -62,8 +62,56 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p
6262
let hasCustomAssemblyFormat = 1;
6363
}
6464

65+
def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
66+
"TypedIntPolynomial", "typed_int_polynomial", [TypedAttrInterface]> {
67+
let summary = "a typed int_polynomial";
68+
let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::IntPolynomialAttr":$value);
69+
let assemblyFormat = "$value `:` $type";
70+
let builders = [
71+
AttrBuilderWithInferredContext<(ins "Type":$type,
72+
"const IntPolynomial &":$value), [{
73+
return $_get(
74+
type.getContext(),
75+
type,
76+
IntPolynomialAttr::get(type.getContext(), value));
77+
}]>,
78+
AttrBuilderWithInferredContext<(ins "Type":$type,
79+
"const Attribute &":$value), [{
80+
return $_get(type.getContext(), type, ::llvm::cast<IntPolynomialAttr>(value));
81+
}]>
82+
];
83+
let extraClassDeclaration = [{
84+
// used for constFoldBinaryOp
85+
using ValueType = ::mlir::Attribute;
86+
}];
87+
}
88+
89+
def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
90+
"TypedFloatPolynomial", "typed_float_polynomial", [TypedAttrInterface]> {
91+
let summary = "a typed float_polynomial";
92+
let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::FloatPolynomialAttr":$value);
93+
let assemblyFormat = "$value `:` $type";
94+
let builders = [
95+
AttrBuilderWithInferredContext<(ins "Type":$type,
96+
"const FloatPolynomial &":$value), [{
97+
return $_get(
98+
type.getContext(),
99+
type,
100+
FloatPolynomialAttr::get(type.getContext(), value));
101+
}]>,
102+
AttrBuilderWithInferredContext<(ins "Type":$type,
103+
"const Attribute &":$value), [{
104+
return $_get(type.getContext(), type, ::llvm::cast<FloatPolynomialAttr>(value));
105+
}]>
106+
];
107+
let extraClassDeclaration = [{
108+
// used for constFoldBinaryOp
109+
using ValueType = ::mlir::Attribute;
110+
}];
111+
}
112+
65113
def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
66-
let summary = "An attribute specifying a polynomial ring.";
114+
let summary = "an attribute specifying a polynomial ring";
67115
let description = [{
68116
A ring describes the domain in which polynomial arithmetic occurs. The ring
69117
attribute in `polynomial` represents the more specific case of polynomials

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,21 @@ LogicalResult INTTOp::verify() {
186186
return verifyNTTOp(this->getOperation(), ring, tensorType);
187187
}
188188

189+
LogicalResult ConstantOp::inferReturnTypes(
190+
MLIRContext *context, std::optional<mlir::Location> location,
191+
ConstantOp::Adaptor adaptor,
192+
llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
193+
Attribute operand = adaptor.getValue();
194+
if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(operand)) {
195+
inferredReturnTypes.push_back(intPoly.getType());
196+
} else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(operand)) {
197+
inferredReturnTypes.push_back(floatPoly.getType());
198+
} else {
199+
return failure();
200+
}
201+
return success();
202+
}
203+
189204
//===----------------------------------------------------------------------===//
190205
// TableGen'd canonicalization patterns
191206
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Polynomial/ops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,15 @@ module {
7474

7575
func.func @test_monic_monomial_mul() {
7676
%five = arith.constant 5 : index
77-
%0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
77+
%0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<ring=#ring1>
7878
%1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<ring=#ring1>, index) -> !polynomial.polynomial<ring=#ring1>
7979
return
8080
}
8181

8282
func.func @test_constant() {
83-
%0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
84-
%1 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<ring=#ring1>
85-
%2 = polynomial.constant {value=#polynomial.float_polynomial<1.5 + 0.5 x**2>} : !polynomial.polynomial<ring=#ring2>
83+
%0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<ring=#ring1>
84+
%1 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<ring=#ring1>
85+
%2 = polynomial.constant #polynomial.float_polynomial<1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
8686
return
8787
}
8888

0 commit comments

Comments
 (0)