Skip to content

Commit 74470e3

Browse files
committed
use int/float keywords
1 parent ef17f2a commit 74470e3

File tree

2 files changed

+44
-29
lines changed

2 files changed

+44
-29
lines changed

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -189,45 +189,46 @@ LogicalResult INTTOp::verify() {
189189
ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
190190
// Using the built-in parser.parseAttribute requires the full
191191
// #polynomial.typed_int_polynomial syntax, which is excessive.
192-
// Instead we manually parse the components.
192+
// Instead we parse a keyword int to signal it's an integer polynomial
193193
Type type;
194-
parser.parseOptionalAttribute();
195-
196-
IntPolynomialAttr intPolyAttr;
197-
parser.parseOptionalAttribute(intPolyAttr);
198-
if (intPolyAttr) {
199-
if (parser.parseColon() || parser.parseType(type))
200-
return failure();
201-
202-
result.addAttribute("value",
203-
TypedIntPolynomialAttr::get(type, intPolyAttr));
204-
result.addTypes(type);
205-
return success();
194+
if (succeeded(parser.parseOptionalKeyword("float"))) {
195+
Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr);
196+
if (floatPolyAttr) {
197+
if (parser.parseColon() || parser.parseType(type))
198+
return failure();
199+
result.addAttribute("value",
200+
TypedFloatPolynomialAttr::get(type, floatPolyAttr));
201+
result.addTypes(type);
202+
return success();
203+
}
206204
}
207205

208-
Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr, /*optional=*/true);
209-
if (floatPolyAttr) {
210-
if (parser.parseColon() || parser.parseType(type))
211-
return failure();
212-
result.addAttribute("value",
213-
TypedFloatPolynomialAttr::get(type, intPolyAttr));
214-
result.addTypes(type);
215-
return success();
206+
if (succeeded(parser.parseOptionalKeyword("int"))) {
207+
Attribute intPolyAttr = IntPolynomialAttr::parse(parser, nullptr);
208+
if (intPolyAttr) {
209+
if (parser.parseColon() || parser.parseType(type))
210+
return failure();
211+
212+
result.addAttribute("value",
213+
TypedIntPolynomialAttr::get(type, intPolyAttr));
214+
result.addTypes(type);
215+
return success();
216+
}
216217
}
217218

218219
// In the worst case, still accept the verbose versions.
219220
TypedIntPolynomialAttr typedIntPolyAttr;
220-
ParseResult res = parser.parseAttribute<TypedIntPolynomialAttr>(
221+
OptionalParseResult res = parser.parseOptionalAttribute<TypedIntPolynomialAttr>(
221222
typedIntPolyAttr, "value", result.attributes);
222-
if (succeeded(res)) {
223+
if (res.has_value() && succeeded(res.value())) {
223224
result.addTypes(typedIntPolyAttr.getType());
224225
return success();
225226
}
226227

227228
TypedFloatPolynomialAttr typedFloatPolyAttr;
228229
res = parser.parseAttribute<TypedFloatPolynomialAttr>(
229230
typedFloatPolyAttr, "value", result.attributes);
230-
if (succeeded(res)) {
231+
if (res.has_value() && succeeded(res.value())) {
231232
result.addTypes(typedFloatPolyAttr.getType());
232233
return success();
233234
}
@@ -237,7 +238,17 @@ ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
237238

238239
void ConstantOp::print(OpAsmPrinter &p) {
239240
p << " ";
240-
p.printAttribute(getValue());
241+
if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(getValue())) {
242+
p << "int";
243+
intPoly.getValue().print(p);
244+
} else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(getValue())) {
245+
p << "float";
246+
floatPoly.getValue().print(p);
247+
} else {
248+
assert(false && "unexpected attribute type");
249+
}
250+
p << " : ";
251+
p.printType(getOutput().getType());
241252
}
242253

243254
LogicalResult ConstantOp::inferReturnTypes(

mlir/test/Dialect/Polynomial/ops.mlir

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,19 @@ module {
7474

7575
func.func @test_monic_monomial_mul() {
7676
%five = arith.constant 5 : index
77-
%0 = polynomial.constant <1 + x**2> : !polynomial.polynomial<ring=#ring1>
77+
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<ring=#ring1>
7878
%1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<ring=#ring1>, index) -> !polynomial.polynomial<ring=#ring1>
7979
return
8080
}
8181

8282
func.func @test_constant() {
83-
%0 = polynomial.constant <1 + x**2> : !polynomial.polynomial<ring=#ring1>
84-
%1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<ring=#ring1>
85-
%2 = polynomial.constant <1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
83+
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<ring=#ring1>
84+
%1 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<ring=#ring1>
85+
%2 = polynomial.constant float<1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
86+
87+
// Test verbose fallbacks
88+
%verb0 = polynomial.constant #polynomial.typed_int_polynomial<1 + x**2> : !polynomial.polynomial<ring=#ring1>
89+
%verb2 = polynomial.constant #polynomial.typed_float_polynomial<1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
8690
return
8791
}
8892

0 commit comments

Comments
 (0)