Skip to content

Commit ef17f2a

Browse files
committed
show broken attempt
1 parent f6276fe commit ef17f2a

File tree

5 files changed

+111
-30
lines changed

5 files changed

+111
-30
lines changed

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
295295
}];
296296
let arguments = (ins Polynomial_AnyTypedPolynomialAttr:$value);
297297
let results = (outs Polynomial_PolynomialType:$output);
298-
let assemblyFormat = "attr-dict $value";
298+
let hasCustomAssemblyFormat = 1;
299299
}
300300

301301
def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {

mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom
3838
}];
3939
let parameters = (ins "::mlir::polynomial::IntPolynomial":$polynomial);
4040
let hasCustomAssemblyFormat = 1;
41+
let extraClassDeclaration = [{
42+
/// A parser which, upon failure to parse, does not emit errors and just returns
43+
/// a null attribute.
44+
static Attribute parse(AsmParser &parser, Type type, bool optional);
45+
}];
4146
}
4247

4348
def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
@@ -60,6 +65,11 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p
6065
}];
6166
let parameters = (ins "FloatPolynomial":$polynomial);
6267
let hasCustomAssemblyFormat = 1;
68+
let extraClassDeclaration = [{
69+
/// A parser which, upon failure to parse, does not emit errors and just returns
70+
/// a null attribute.
71+
static Attribute parse(AsmParser &parser, Type type, bool optional);
72+
}];
6373
}
6474

6575
def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
@@ -81,7 +91,6 @@ def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
8191
}]>
8292
];
8393
let extraClassDeclaration = [{
84-
// used for constFoldBinaryOp
8594
using ValueType = ::mlir::Attribute;
8695
}];
8796
}
@@ -105,7 +114,6 @@ def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
105114
}]>
106115
];
107116
let extraClassDeclaration = [{
108-
// used for constFoldBinaryOp
109117
using ValueType = ::mlir::Attribute;
110118
}];
111119
}

mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,11 @@ using ParseCoefficientFn = std::function<OptionalParseResult(MonomialType &)>;
3838
/// a '+'.
3939
///
4040
template <typename Monomial>
41-
ParseResult
42-
parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
43-
bool &isConstantTerm, bool &shouldParseMore,
44-
ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
41+
ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
42+
llvm::StringRef &variable, bool &isConstantTerm,
43+
bool &shouldParseMore,
44+
ParseCoefficientFn<Monomial> parseAndStoreCoefficient,
45+
bool optional) {
4546
OptionalParseResult parsedCoeffResult = parseAndStoreCoefficient(monomial);
4647

4748
isConstantTerm = false;
@@ -85,8 +86,9 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
8586
// If there's a **, then the integer exponent is required.
8687
APInt parsedExponent(apintBitWidth, 0);
8788
if (failed(parser.parseInteger(parsedExponent))) {
88-
parser.emitError(parser.getCurrentLocation(),
89-
"found invalid integer exponent");
89+
if (!optional)
90+
parser.emitError(parser.getCurrentLocation(),
91+
"found invalid integer exponent");
9092
return failure();
9193
}
9294

@@ -101,20 +103,22 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
101103
return success();
102104
}
103105

104-
template <typename PolynoimalAttrTy, typename Monomial>
106+
template <typename Monomial>
105107
LogicalResult
106108
parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
107109
llvm::StringSet<> &variables,
108-
ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
110+
ParseCoefficientFn<Monomial> parseAndStoreCoefficient,
111+
bool optional) {
109112
while (true) {
110113
Monomial parsedMonomial;
111114
llvm::StringRef parsedVariableRef;
112115
bool isConstantTerm;
113116
bool shouldParseMore;
114117
if (failed(parseMonomial<Monomial>(
115118
parser, parsedMonomial, parsedVariableRef, isConstantTerm,
116-
shouldParseMore, parseAndStoreCoefficient))) {
117-
parser.emitError(parser.getCurrentLocation(), "expected a monomial");
119+
shouldParseMore, parseAndStoreCoefficient, optional))) {
120+
if (!optional)
121+
parser.emitError(parser.getCurrentLocation(), "expected a monomial");
118122
return failure();
119123
}
120124

@@ -130,53 +134,67 @@ parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
130134
if (succeeded(parser.parseOptionalGreater())) {
131135
break;
132136
}
133-
parser.emitError(
134-
parser.getCurrentLocation(),
135-
"expected + and more monomials, or > to end polynomial attribute");
137+
if (!optional)
138+
parser.emitError(
139+
parser.getCurrentLocation(),
140+
"expected + and more monomials, or > to end polynomial attribute");
136141
return failure();
137142
}
138143

139144
if (variables.size() > 1) {
140145
std::string vars = llvm::join(variables.keys(), ", ");
141-
parser.emitError(
142-
parser.getCurrentLocation(),
143-
"polynomials must have one indeterminate, but there were multiple: " +
144-
vars);
146+
if (!optional)
147+
parser.emitError(
148+
parser.getCurrentLocation(),
149+
"polynomials must have one indeterminate, but there were multiple: " +
150+
vars);
145151
return failure();
146152
}
147153

148154
return success();
149155
}
150156

151157
Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) {
158+
return IntPolynomialAttr::parse(parser, type, /*optional=*/false);
159+
}
160+
161+
Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type,
162+
bool optional) {
152163
if (failed(parser.parseLess()))
153164
return {};
154165

155166
llvm::SmallVector<IntMonomial> monomials;
156167
llvm::StringSet<> variables;
157168

158-
if (failed(parsePolynomialAttr<IntPolynomialAttr, IntMonomial>(
169+
if (failed(parsePolynomialAttr<IntMonomial>(
159170
parser, monomials, variables,
160171
[&](IntMonomial &monomial) -> OptionalParseResult {
161172
APInt parsedCoeff(apintBitWidth, 1);
162173
OptionalParseResult result =
163174
parser.parseOptionalInteger(parsedCoeff);
164175
monomial.setCoefficient(parsedCoeff);
165176
return result;
166-
}))) {
177+
},
178+
optional))) {
167179
return {};
168180
}
169181

170182
auto result = IntPolynomial::fromMonomials(monomials);
171183
if (failed(result)) {
172-
parser.emitError(parser.getCurrentLocation())
173-
<< "parsed polynomial must have unique exponents among monomials";
184+
if (!optional)
185+
parser.emitError(parser.getCurrentLocation())
186+
<< "parsed polynomial must have unique exponents among monomials";
174187
return {};
175188
}
176189
return IntPolynomialAttr::get(parser.getContext(), result.value());
177190
}
178191

179192
Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
193+
return FloatPolynomialAttr::parse(parser, type, /*optional=*/false);
194+
}
195+
196+
Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type,
197+
bool optional) {
180198
if (failed(parser.parseLess()))
181199
return {};
182200

@@ -191,8 +209,8 @@ Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
191209
return OptionalParseResult(result);
192210
};
193211

194-
if (failed(parsePolynomialAttr<FloatPolynomialAttr, FloatMonomial>(
195-
parser, monomials, variables, parseAndStoreCoefficient))) {
212+
if (failed(parsePolynomialAttr<FloatMonomial>(
213+
parser, monomials, variables, parseAndStoreCoefficient, optional))) {
196214
return {};
197215
}
198216

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,60 @@ LogicalResult INTTOp::verify() {
186186
return verifyNTTOp(this->getOperation(), ring, tensorType);
187187
}
188188

189+
ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
190+
// Using the built-in parser.parseAttribute requires the full
191+
// #polynomial.typed_int_polynomial syntax, which is excessive.
192+
// Instead we manually parse the components.
193+
Type type;
194+
parser.parseOptionalAttribute();
195+
196+
IntPolynomialAttr intPolyAttr;
197+
parser.parseOptionalAttribute(intPolyAttr);
198+
if (intPolyAttr) {
199+
if (parser.parseColon() || parser.parseType(type))
200+
return failure();
201+
202+
result.addAttribute("value",
203+
TypedIntPolynomialAttr::get(type, intPolyAttr));
204+
result.addTypes(type);
205+
return success();
206+
}
207+
208+
Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr, /*optional=*/true);
209+
if (floatPolyAttr) {
210+
if (parser.parseColon() || parser.parseType(type))
211+
return failure();
212+
result.addAttribute("value",
213+
TypedFloatPolynomialAttr::get(type, intPolyAttr));
214+
result.addTypes(type);
215+
return success();
216+
}
217+
218+
// In the worst case, still accept the verbose versions.
219+
TypedIntPolynomialAttr typedIntPolyAttr;
220+
ParseResult res = parser.parseAttribute<TypedIntPolynomialAttr>(
221+
typedIntPolyAttr, "value", result.attributes);
222+
if (succeeded(res)) {
223+
result.addTypes(typedIntPolyAttr.getType());
224+
return success();
225+
}
226+
227+
TypedFloatPolynomialAttr typedFloatPolyAttr;
228+
res = parser.parseAttribute<TypedFloatPolynomialAttr>(
229+
typedFloatPolyAttr, "value", result.attributes);
230+
if (succeeded(res)) {
231+
result.addTypes(typedFloatPolyAttr.getType());
232+
return success();
233+
}
234+
235+
return failure();
236+
}
237+
238+
void ConstantOp::print(OpAsmPrinter &p) {
239+
p << " ";
240+
p.printAttribute(getValue());
241+
}
242+
189243
LogicalResult ConstantOp::inferReturnTypes(
190244
MLIRContext *context, std::optional<mlir::Location> location,
191245
ConstantOp::Adaptor adaptor,
@@ -196,6 +250,7 @@ LogicalResult ConstantOp::inferReturnTypes(
196250
} else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(operand)) {
197251
inferredReturnTypes.push_back(floatPoly.getType());
198252
} else {
253+
assert(false && "unexpected attribute type");
199254
return failure();
200255
}
201256
return success();

mlir/test/Dialect/Polynomial/ops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,15 @@ module {
7474

7575
func.func @test_monic_monomial_mul() {
7676
%five = arith.constant 5 : index
77-
%0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<ring=#ring1>
77+
%0 = polynomial.constant <1 + x**2> : !polynomial.polynomial<ring=#ring1>
7878
%1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<ring=#ring1>, index) -> !polynomial.polynomial<ring=#ring1>
7979
return
8080
}
8181

8282
func.func @test_constant() {
83-
%0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<ring=#ring1>
84-
%1 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<ring=#ring1>
85-
%2 = polynomial.constant #polynomial.float_polynomial<1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
83+
%0 = polynomial.constant <1 + x**2> : !polynomial.polynomial<ring=#ring1>
84+
%1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<ring=#ring1>
85+
%2 = polynomial.constant <1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
8686
return
8787
}
8888

0 commit comments

Comments
 (0)