Skip to content

Commit 84bf381

Browse files
committed
Merge branch 'feature/fused-ops' into jrickert.bump_integration
2 parents e9a12bb + 1134785 commit 84bf381

File tree

9 files changed

+195
-44
lines changed

9 files changed

+195
-44
lines changed

mlir/cmake/modules/AddMLIRPython.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
# grouping. Source groupings form a DAG.
2424
# SOURCES: List of specific source files relative to ROOT_DIR to include.
2525
# SOURCES_GLOB: List of glob patterns relative to ROOT_DIR to include.
26+
27+
if (POLICY CMP0175)
28+
cmake_policy(SET CMP0175 OLD)
29+
endif()
30+
2631
function(declare_mlir_python_sources name)
2732
cmake_parse_arguments(ARG
2833
""

mlir/include/mlir/Dialect/PDL/IR/Builtins.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ enum class UnaryOpKind {
4343
LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter,
4444
PDLResultList &results,
4545
ArrayRef<PDLValue> args);
46-
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
47-
Attribute element);
46+
LogicalResult addElemToArrayAttr(PatternRewriter &rewriter,
47+
PDLResultList &results,
48+
ArrayRef<PDLValue> args);
4849
LogicalResult mul(PatternRewriter &rewriter, PDLResultList &results,
4950
llvm::ArrayRef<PDLValue> args);
5051
LogicalResult div(PatternRewriter &rewriter, PDLResultList &results,

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
8484
matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
8585
ConversionPatternRewriter &rewriter) const override {
8686

87-
if (!isa<FloatType>(adaptor.getRhs().getType())) {
87+
if (!emitc::isFloatOrOpaqueType(adaptor.getRhs().getType())) {
8888
return rewriter.notifyMatchFailure(op.getLoc(),
8989
"cmpf currently only supported on "
9090
"floats, not tensors/vectors thereof");

mlir/lib/Dialect/PDL/IR/Builtins.cpp

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,19 @@ LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter,
3838
return success();
3939
}
4040

41-
mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
42-
mlir::Attribute attr,
43-
mlir::Attribute element) {
44-
assert(isa<ArrayAttr>(attr));
45-
auto values = cast<ArrayAttr>(attr).getValue().vec();
46-
values.push_back(element);
47-
return rewriter.getArrayAttr(values);
41+
LogicalResult addElemToArrayAttr(PatternRewriter &rewriter,
42+
PDLResultList &results,
43+
ArrayRef<PDLValue> args) {
44+
45+
assert(args.size() == 2 &&
46+
"Expected two arguments, one ArrayAttr and one Attr");
47+
auto arrayAttr = cast<ArrayAttr>(args[0].cast<Attribute>());
48+
auto attrElement = args[1].cast<Attribute>();
49+
llvm::SmallVector<Attribute> values(arrayAttr.getValue());
50+
values.push_back(attrElement);
51+
52+
results.push_back(rewriter.getArrayAttr(values));
53+
return success();
4854
}
4955

5056
template <UnaryOpKind T>
@@ -344,11 +350,15 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
344350
// See Parser::defineBuiltins()
345351
pdlPattern.registerRewriteFunction(
346352
"__builtin_addEntryToDictionaryAttr_rewrite", addEntryToDictionaryAttr);
347-
pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr",
348-
addElemToArrayAttr);
349353
pdlPattern.registerConstraintFunction(
350354
"__builtin_addEntryToDictionaryAttr_constraint",
351355
addEntryToDictionaryAttr);
356+
357+
pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttrRewriter",
358+
addElemToArrayAttr);
359+
pdlPattern.registerConstraintFunction(
360+
"__builtin_addElemToArrayAttrConstraint", addElemToArrayAttr);
361+
352362
pdlPattern.registerRewriteFunction("__builtin_mulRewrite", mul);
353363
pdlPattern.registerRewriteFunction("__builtin_divRewrite", div);
354364
pdlPattern.registerRewriteFunction("__builtin_modRewrite", mod);
@@ -357,22 +367,14 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
357367
pdlPattern.registerRewriteFunction("__builtin_log2Rewrite", log2);
358368
pdlPattern.registerRewriteFunction("__builtin_exp2Rewrite", exp2);
359369
pdlPattern.registerRewriteFunction("__builtin_absRewrite", abs);
360-
pdlPattern.registerConstraintFunction("__builtin_mulConstraint",
361-
mul);
362-
pdlPattern.registerConstraintFunction("__builtin_divConstraint",
363-
div);
364-
pdlPattern.registerConstraintFunction("__builtin_modConstraint",
365-
mod);
366-
pdlPattern.registerConstraintFunction("__builtin_addConstraint",
367-
add);
368-
pdlPattern.registerConstraintFunction("__builtin_subConstraint",
369-
sub);
370-
pdlPattern.registerConstraintFunction("__builtin_log2Constraint",
371-
log2);
372-
pdlPattern.registerConstraintFunction("__builtin_exp2Constraint",
373-
exp2);
374-
pdlPattern.registerConstraintFunction("__builtin_absConstraint",
375-
abs);
370+
pdlPattern.registerConstraintFunction("__builtin_mulConstraint", mul);
371+
pdlPattern.registerConstraintFunction("__builtin_divConstraint", div);
372+
pdlPattern.registerConstraintFunction("__builtin_modConstraint", mod);
373+
pdlPattern.registerConstraintFunction("__builtin_addConstraint", add);
374+
pdlPattern.registerConstraintFunction("__builtin_subConstraint", sub);
375+
pdlPattern.registerConstraintFunction("__builtin_log2Constraint", log2);
376+
pdlPattern.registerConstraintFunction("__builtin_exp2Constraint", exp2);
377+
pdlPattern.registerConstraintFunction("__builtin_absConstraint", abs);
376378
pdlPattern.registerConstraintFunction("__builtin_equals", equals);
377379
}
378380
} // namespace mlir::pdl

mlir/lib/Tools/PDLL/Parser/Parser.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,8 @@ class Parser {
625625
struct {
626626
ast::UserRewriteDecl *addEntryToDictionaryAttr_Rewrite;
627627
ast::UserConstraintDecl *addEntryToDictionaryAttr_Constraint;
628-
ast::UserRewriteDecl *addElemToArrayAttr;
628+
ast::UserRewriteDecl *addElemToArrayAttrRewrite;
629+
ast::UserConstraintDecl *addElemToArrayAttrConstraint;
629630
ast::UserRewriteDecl *mulRewrite;
630631
ast::UserRewriteDecl *divRewrite;
631632
ast::UserRewriteDecl *modRewrite;
@@ -691,9 +692,13 @@ void Parser::declareBuiltins() {
691692
"__builtin_addEntryToDictionaryAttr_constraint",
692693
{"attr", "attrName", "attrEntry"},
693694
/*returnsAttr=*/true);
694-
builtins.addElemToArrayAttr = declareBuiltin<ast::UserRewriteDecl>(
695-
"__builtin_addElemToArrayAttr", {"attr", "element"},
695+
builtins.addElemToArrayAttrRewrite = declareBuiltin<ast::UserRewriteDecl>(
696+
"__builtin_addElemToArrayAttrRewriter", {"attr", "element"},
696697
/*returnsAttr=*/true);
698+
builtins.addElemToArrayAttrConstraint =
699+
declareBuiltin<ast::UserConstraintDecl>(
700+
"__builtin_addElemToArrayAttrConstraint", {"attr", "element"},
701+
/*returnsAttr=*/true);
697702
builtins.mulRewrite = declareBuiltin<ast::UserRewriteDecl>(
698703
"__builtin_mulRewrite", {"lhs", "rhs"}, true);
699704
builtins.divRewrite = declareBuiltin<ast::UserRewriteDecl>(
@@ -2323,27 +2328,35 @@ FailureOr<ast::Expr *> Parser::parseArrayAttrExpr() {
23232328

23242329
consumeToken(Token::l_square);
23252330

2331+
ast::Decl *builtinFunction = builtins.addElemToArrayAttrRewrite;
23262332
if (parserContext != ParserContext::Rewrite)
2327-
return emitError(
2328-
"Parsing of array attributes as constraint not supported!");
2333+
builtinFunction = builtins.addElemToArrayAttrConstraint;
23292334

2330-
FailureOr<ast::Expr *> arrayAttr = ast::AttributeExpr::create(ctx, curToken.getLoc(), "[]");
2335+
FailureOr<ast::Expr *> arrayAttr =
2336+
ast::AttributeExpr::create(ctx, curToken.getLoc(), "[]");
23312337
if (failed(arrayAttr))
23322338
return failure();
23332339

2340+
// No values inside the array
2341+
if (consumeIf(Token::r_square)) {
2342+
return arrayAttr;
2343+
}
2344+
23342345
do {
23352346
FailureOr<ast::Expr *> attr = parseExpr();
23362347
if (failed(attr))
23372348
return failure();
23382349

23392350
SmallVector<ast::Expr *> arrayAttrArgs{*arrayAttr, *attr};
2340-
auto elemToArrayCall = createBuiltinCall(
2341-
curToken.getLoc(), builtins.addElemToArrayAttr, arrayAttrArgs);
2351+
2352+
auto elemToArrayCall =
2353+
createBuiltinCall(curToken.getLoc(), builtinFunction, arrayAttrArgs);
23422354
if (failed(elemToArrayCall))
23432355
return failure();
23442356

23452357
// Uses the new array for the next element.
23462358
arrayAttr = elemToArrayCall;
2359+
23472360
} while (consumeIf(Token::comma));
23482361

23492362
if (failed(
@@ -2415,7 +2428,8 @@ FailureOr<ast::Expr *> Parser::parseDictAttrExpr() {
24152428
consumeToken(Token::l_brace);
24162429
SMRange loc = curToken.getLoc();
24172430

2418-
FailureOr<ast::Expr *> dictAttrCall = ast::AttributeExpr::create(ctx, loc, "{}");
2431+
FailureOr<ast::Expr *> dictAttrCall =
2432+
ast::AttributeExpr::create(ctx, loc, "{}");
24192433
if (failed(dictAttrCall))
24202434
return failure();
24212435

mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ Pattern RewriteMultipleEntriesDictionary {
218218
// CHECK: %[[VAL_4:.*]] = attribute = "firstAttr"
219219
// CHECK: %[[VAL_5:.*]] = attribute = "test1"
220220
// CHECK: %[[VAL_6:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]
221-
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_2]], %[[VAL_6]]
221+
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttrRewriter"(%[[VAL_2]], %[[VAL_6]]
222222
// CHECK: %[[VAL_8:.*]] = operation "test.success" {"some_array" = %[[VAL_7]]}
223223
// CHECK: replace %[[VAL_1]] with %[[VAL_8]]
224224
Pattern RewriteOneDictionaryArrayAttr {
@@ -229,6 +229,43 @@ Pattern RewriteOneDictionaryArrayAttr {
229229
};
230230
}
231231

232+
// -----
233+
234+
// CHECK-LABEL: pdl.pattern @ConstraintWithArrayAttr
235+
// CHECK: %[[VAL_0:.*]] = attribute = "test1"
236+
// CHECK: %[[VAL_1:.*]] = attribute = "test2"
237+
// CHECK: %[[VAL_2:.*]] = attribute = []
238+
// CHECK: %[[VAL_3:.*]] = apply_native_constraint "__builtin_addElemToArrayAttrConstraint"(%[[VAL_2]], %[[VAL_0]]
239+
// CHECK: %[[VAL_4:.*]] = apply_native_constraint "__builtin_addElemToArrayAttrConstraint"(%[[VAL_3]], %[[VAL_1]]
240+
// CHECK: %[[VAL_5:.*]] = operation "test.op"
241+
// CHECK: rewrite %[[VAL_5]] {
242+
// CHECK: %[[VAL_6:.*]] = operation "test.success" {"some_array" = %[[VAL_4]]}
243+
// CHECK: replace %[[VAL_5]] with %[[VAL_6]]
244+
245+
Pattern ConstraintWithArrayAttr {
246+
let attr1 = attr<"\"test1\"">;
247+
let attr2 = attr<"\"test2\"">;
248+
let array = [attr1, attr2];
249+
let root = op<test.op> -> ();
250+
rewrite root with {
251+
let newRoot = op<test.success>() { some_array = array} -> ();
252+
replace root with newRoot;
253+
};
254+
}
255+
256+
// -----
257+
258+
// CHECK-LABEL: pdl.pattern @ConstraintNotMatchingArrayAttrInAttrType
259+
// CHECK-NOT: apply_native_constraint "__builtin_addElemToArrayAttrConstraint"
260+
261+
262+
Constraint I64Value(value: Value);
263+
Pattern ConstraintNotMatchingArrayAttrInAttrType {
264+
let root = op<my_dialect.foo>(arg: Value, arg2: Value, arg3: [Value, I64Value], arg);
265+
replace root with arg;
266+
}
267+
268+
232269
// -----
233270

234271
// CHECK-LABEL: pdl.pattern @RewriteMultiplyElementsArrayAttr
@@ -240,8 +277,8 @@ Pattern RewriteOneDictionaryArrayAttr {
240277
// CHECK: %[[VAL_5:.*]] = attribute = "firstAttr"
241278
// CHECK: %[[VAL_6:.*]] = attribute = "test1"
242279
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]
243-
// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_3]], %[[VAL_7]]
244-
// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_8]], %[[VAL_2]]
280+
// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttrRewriter"(%[[VAL_3]], %[[VAL_7]]
281+
// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttrRewriter"(%[[VAL_8]], %[[VAL_2]]
245282
// CHECK: %[[VAL_10:.*]] = operation "test.success" {"some_array" = %[[VAL_9]]}
246283
// CHECK: replace %[[VAL_1]] with %[[VAL_10]]
247284
Pattern RewriteMultiplyElementsArrayAttr {

mlir/test/mlir-pdll/Parser/expr-failure.pdll

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,22 @@ Pattern {
134134

135135
// -----
136136

137+
Pattern ConstraintArrayAttrWithAttrAndValue {
138+
let root = op<test.op>(arg: Value) -> ();
139+
let attr1 = attr<"\"test1\"">;
140+
let array = [attr1, arg];
141+
// CHECK: unable to convert expression of type `Value` to the expected type of `Attr`
142+
let root = op<test.op> -> ();
143+
rewrite root with {
144+
let newRoot = op<test.success>() { some_array = array} -> ();
145+
replace root with newRoot;
146+
};
147+
}
148+
149+
// -----
150+
151+
152+
137153
//===----------------------------------------------------------------------===//
138154
// Range Expr
139155
//===----------------------------------------------------------------------===//

mlir/test/mlir-pdll/Parser/expr.pdll

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Pattern {
3434

3535
// CHECK-LABEL: Module
3636
// CHECK: |-NamedAttributeDecl {{.*}} Name<some_array>
37-
// CHECK: `-UserRewriteDecl {{.*}} Name<__builtin_addElemToArrayAttr> ResultType<Attr>
37+
// CHECK: `-UserRewriteDecl {{.*}} Name<__builtin_addElemToArrayAttrRewriter> ResultType<Attr>
3838
// CHECK: `Arguments`
3939
// CHECK: CallExpr {{.*}} Type<Attr>
4040
// CHECK: AttributeExpr {{.*}} Value<"[]">
@@ -87,6 +87,77 @@ Constraint getPopulatedDict() -> Attr {
8787
return dictionary;
8888
}
8989

90+
91+
92+
// -----
93+
94+
// CHECK-LABEL: Module
95+
// CHECK:LetStmt {{.*}}
96+
//CHECK-NEXT:`-VariableDecl {{.*}} Name<array> Type<Attr>
97+
//CHECK-NEXT: `-AttributeExpr {{.*}} Value<"[]">
98+
//CHECK-NEXT:ReturnStmt {{.*}}
99+
100+
Constraint getEmtpyArray() -> Attr {
101+
let array = [];
102+
return array;
103+
}
104+
105+
// -----
106+
107+
// CHECK-LABEL: Module
108+
// CHECK:LetStmt {{.*}}
109+
//CHECK-NEXT:`-VariableDecl {{.*}} Name<array> Type<Attr>
110+
//CHECK-NEXT: `-CallExpr {{.*}} Type<Attr>
111+
//CHECK-NEXT: `-DeclRefExpr {{.*}} Type<Constraint>
112+
//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<__builtin_addElemToArrayAttrConstraint> ResultType<Attr>
113+
// CHECK: `Arguments`
114+
//CHECK-NEXT: |-AttributeExpr {{.*}} Value<"[]">
115+
//CHECK-NEXT: `-AttributeExpr {{.*}} Value<""attr1"">
116+
//CHECK-NEXT:ReturnStmt {{.*}}
117+
118+
Constraint getPopulateArray() -> Attr {
119+
let array = ["attr1"];
120+
return array;
121+
}
122+
123+
124+
// -----
125+
126+
127+
// CHECK-LABEL: Module
128+
// CHECK:LetStmt {{.*}}
129+
//CHECK-NEXT:`-VariableDecl {{.*}} Name<array> Type<Attr>
130+
//CHECK-NEXT: `-CallExpr {{.*}} Type<Attr>
131+
//CHECK-NEXT: `-DeclRefExpr {{.*}} Type<Constraint>
132+
//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<__builtin_addElemToArrayAttrConstraint> ResultType<Attr>
133+
// CHECK-DAG: `Arguments`
134+
//CHECK-NEXT: |-CallExpr {{.*}} Type<Attr>
135+
//CHECK-NEXT: `-DeclRefExpr {{.*}} Type<Constraint>
136+
//CHECK-NEXT: | `-UserConstraintDecl {{.*}} Name<__builtin_addElemToArrayAttrConstraint> ResultType<Attr>
137+
// CHECK-DAG: `Arguments`
138+
//CHECK-NEXT: |-AttributeExpr {{.*}} Value<"[]">
139+
//CHECK-NEXT: `-CallExpr {{.*}} Type<Attr>
140+
//CHECK-NEXT: `-DeclRefExpr {{.*}} Type<Constraint>
141+
//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<getA> ResultType<Attr>
142+
// CHECK: `-CallExpr {{.*}} Type<Attr>
143+
//CHECK-NEXT: `-DeclRefExpr {{.*}} Type<Constraint>
144+
//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<getB> ResultType<Attr>
145+
// CHECK-DAG: -ReturnStmt {{.*}}
146+
147+
Constraint getA() -> Attr {
148+
return "A";
149+
}
150+
151+
Constraint getB() -> Attr {
152+
return "B";
153+
}
154+
155+
Constraint getPopulateArrayFromOtherConstraints() -> Attr {
156+
let array = [getA(), getB()];
157+
return array;
158+
}
159+
160+
90161
// -----
91162

92163
//===----------------------------------------------------------------------===//

mlir/unittests/Dialect/PDL/BuiltinTest.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,17 @@ TEST_F(BuiltinTest, addEntryToDictionaryAttr) {
6666
}
6767

6868
TEST_F(BuiltinTest, addElemToArrayAttr) {
69+
TestPDLResultList results(1);
70+
6971
auto dict = rewriter.getDictionaryAttr(
7072
rewriter.getNamedAttr("key", rewriter.getStringAttr("value")));
7173
rewriter.getArrayAttr({});
7274

7375
auto arrAttr = rewriter.getArrayAttr({});
76+
EXPECT_TRUE(succeeded(
77+
builtin::addElemToArrayAttr(rewriter, results, {arrAttr, dict})));
7478
mlir::Attribute updatedArrAttr =
75-
builtin::addElemToArrayAttr(rewriter, arrAttr, dict);
79+
results.getResults().front().cast<Attribute>();
7680

7781
auto dictInsideArrAttr =
7882
cast<DictionaryAttr>(*cast<ArrayAttr>(updatedArrAttr).begin());
@@ -617,7 +621,7 @@ TEST_F(BuiltinTest, log2) {
617621
cast<FloatAttr>(result.cast<Attribute>()).getValue().convertToFloat(),
618622
2.0);
619623
}
620-
624+
621625
auto threeF16 = rewriter.getF16FloatAttr(3.0);
622626

623627
// check correctness
@@ -626,7 +630,8 @@ TEST_F(BuiltinTest, log2) {
626630
EXPECT_TRUE(builtin::log2(rewriter, results, {threeF16}).succeeded());
627631

628632
PDLValue result = results.getResults()[0];
629-
float resultVal = cast<FloatAttr>(result.cast<Attribute>()).getValue().convertToFloat();
633+
float resultVal =
634+
cast<FloatAttr>(result.cast<Attribute>()).getValue().convertToFloat();
630635
EXPECT_TRUE(resultVal > 1.58 && resultVal < 1.59);
631636
}
632637
}

0 commit comments

Comments
 (0)