@@ -982,9 +982,25 @@ struct TestPassthroughInvalidOp : public ConversionPattern {
982
982
TestPassthroughInvalidOp (MLIRContext *ctx)
983
983
: ConversionPattern(" test.invalid" , 1 , ctx) {}
984
984
LogicalResult
985
- matchAndRewrite (Operation *op, ArrayRef<Value > operands,
985
+ matchAndRewrite (Operation *op, ArrayRef<ValueRange > operands,
986
986
ConversionPatternRewriter &rewriter) const final {
987
- rewriter.replaceOpWithNewOp <TestValidOp>(op, std::nullopt, operands,
987
+ SmallVector<Value> flattened;
988
+ for (auto it : llvm::enumerate (operands)) {
989
+ ValueRange range = it.value ();
990
+ if (range.size () == 1 ) {
991
+ flattened.push_back (range.front ());
992
+ continue ;
993
+ }
994
+
995
+ // This is a 1:N replacement. Insert a test.cast op. (That's what the
996
+ // argument materialization used to do.)
997
+ flattened.push_back (
998
+ rewriter
999
+ .create <TestCastOp>(op->getLoc (),
1000
+ op->getOperand (it.index ()).getType (), range)
1001
+ .getResult ());
1002
+ }
1003
+ rewriter.replaceOpWithNewOp <TestValidOp>(op, std::nullopt, flattened,
988
1004
std::nullopt);
989
1005
return success ();
990
1006
}
@@ -1010,23 +1026,13 @@ struct TestSplitReturnType : public ConversionPattern {
1010
1026
TestSplitReturnType (MLIRContext *ctx)
1011
1027
: ConversionPattern(" test.return" , 1 , ctx) {}
1012
1028
LogicalResult
1013
- matchAndRewrite (Operation *op, ArrayRef<Value > operands,
1029
+ matchAndRewrite (Operation *op, ArrayRef<ValueRange > operands,
1014
1030
ConversionPatternRewriter &rewriter) const final {
1015
1031
// Check for a return of F32.
1016
1032
if (op->getNumOperands () != 1 || !op->getOperand (0 ).getType ().isF32 ())
1017
1033
return failure ();
1018
-
1019
- // Check if the first operation is a cast operation, if it is we use the
1020
- // results directly.
1021
- auto *defOp = operands[0 ].getDefiningOp ();
1022
- if (auto packerOp =
1023
- llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
1024
- rewriter.replaceOpWithNewOp <TestReturnOp>(op, packerOp.getOperands ());
1025
- return success ();
1026
- }
1027
-
1028
- // Otherwise, fail to match.
1029
- return failure ();
1034
+ rewriter.replaceOpWithNewOp <TestReturnOp>(op, operands[0 ]);
1035
+ return success ();
1030
1036
}
1031
1037
};
1032
1038
@@ -1181,6 +1187,47 @@ class TestEraseOp : public ConversionPattern {
1181
1187
}
1182
1188
};
1183
1189
1190
+ // / This pattern matches a test.duplicate_block_args op and duplicates all
1191
+ // / block arguments.
1192
+ class TestDuplicateBlockArgs
1193
+ : public OpConversionPattern<DuplicateBlockArgsOp> {
1194
+ using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern;
1195
+
1196
+ LogicalResult
1197
+ matchAndRewrite (DuplicateBlockArgsOp op, OpAdaptor adaptor,
1198
+ ConversionPatternRewriter &rewriter) const override {
1199
+ if (op.getIsLegal ())
1200
+ return failure ();
1201
+ rewriter.startOpModification (op);
1202
+ Block *body = &op.getBody ().front ();
1203
+ TypeConverter::SignatureConversion result (body->getNumArguments ());
1204
+ for (auto it : llvm::enumerate (body->getArgumentTypes ()))
1205
+ result.addInputs (it.index (), {it.value (), it.value ()});
1206
+ rewriter.applySignatureConversion (body, result, getTypeConverter ());
1207
+ op.setIsLegal (true );
1208
+ rewriter.finalizeOpModification (op);
1209
+ return success ();
1210
+ }
1211
+ };
1212
+
1213
+ // / This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid
1214
+ // / op. The pattern supports 1:N replacements and forwards the replacement
1215
+ // / values of the single operand as test.valid operands.
1216
+ class TestRepetitive1ToNConsumer : public ConversionPattern {
1217
+ public:
1218
+ TestRepetitive1ToNConsumer (MLIRContext *ctx)
1219
+ : ConversionPattern(" test.repetitive_1_to_n_consumer" , 1 , ctx) {}
1220
+ LogicalResult
1221
+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
1222
+ ConversionPatternRewriter &rewriter) const final {
1223
+ // A single operand is expected.
1224
+ if (op->getNumOperands () != 1 )
1225
+ return failure ();
1226
+ rewriter.replaceOpWithNewOp <TestValidOp>(op, operands.front ());
1227
+ return success ();
1228
+ }
1229
+ };
1230
+
1184
1231
} // namespace
1185
1232
1186
1233
namespace {
@@ -1258,9 +1305,11 @@ struct TestLegalizePatternDriver
1258
1305
TestUpdateConsumerType, TestNonRootReplacement,
1259
1306
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
1260
1307
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1261
- TestUndoPropertiesModification, TestEraseOp>(&getContext ());
1308
+ TestUndoPropertiesModification, TestEraseOp,
1309
+ TestRepetitive1ToNConsumer>(&getContext ());
1262
1310
patterns.add <TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
1263
1311
&getContext (), converter);
1312
+ patterns.add <TestDuplicateBlockArgs>(converter, &getContext ());
1264
1313
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern (patterns,
1265
1314
converter);
1266
1315
mlir::populateCallOpTypeConversionPattern (patterns, converter);
@@ -1312,6 +1361,9 @@ struct TestLegalizePatternDriver
1312
1361
target.addDynamicallyLegalOp <TestOpInPlaceSelfFold>(
1313
1362
[](TestOpInPlaceSelfFold op) { return op.getFolded (); });
1314
1363
1364
+ target.addDynamicallyLegalOp <DuplicateBlockArgsOp>(
1365
+ [](DuplicateBlockArgsOp op) { return op.getIsLegal (); });
1366
+
1315
1367
// Handle a partial conversion.
1316
1368
if (mode == ConversionMode::Partial) {
1317
1369
DenseSet<Operation *> unlegalizedOps;
0 commit comments