Skip to content

Commit 63e90d8

Browse files
address comments
1 parent 0c7f2c5 commit 63e90d8

File tree

5 files changed

+106
-26
lines changed

5 files changed

+106
-26
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,9 @@ class ConversionPattern : public RewritePattern {
543543
}
544544

545545
/// Hook for derived classes to implement combined matching and rewriting.
546+
/// This overload supports only 1:1 replacements. The 1:N overload is called
547+
/// by the driver. By default, it calls this 1:1 overload or reports a fatal
548+
/// error if 1:N replacements were found.
546549
virtual LogicalResult
547550
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
548551
ConversionPatternRewriter &rewriter) const {
@@ -551,6 +554,9 @@ class ConversionPattern : public RewritePattern {
551554
rewrite(op, operands, rewriter);
552555
return success();
553556
}
557+
558+
/// Hook for derived classes to implement combined matching and rewriting.
559+
/// This overload supports 1:N replacements.
554560
virtual LogicalResult
555561
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
556562
ConversionPatternRewriter &rewriter) const {

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,11 +1152,18 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11521152
Type origType = operand.getType();
11531153
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
11541154

1155+
// Find the most recently mapped value. Unpack all temporary N:1
1156+
// materializations. Such conversions are a workaround around missing
1157+
// 1:N support in the ConversionValueMapping. (The conversion patterns
1158+
// already support 1:N replacements.)
1159+
Value repl = mapping.lookupOrDefault(operand);
1160+
SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
1161+
11551162
if (!currentTypeConverter) {
11561163
// The current pattern does not have a type converter. I.e., it does not
11571164
// distinguish between legal and illegal types. For each operand, simply
11581165
// pass through the most recently mapped value.
1159-
remapped.push_back({mapping.lookupOrDefault(operand)});
1166+
remapped.push_back(std::move(unpacked));
11601167
continue;
11611168
}
11621169

@@ -1178,12 +1185,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11781185

11791186
if (legalTypes.size() != 1) {
11801187
// TODO: This is a 1:N conversion. The conversion value mapping does not
1181-
// support such conversions yet. It stores the result of an argument
1182-
// materialization (i.e., a conversion back into a single SSA value)
1183-
// instead. Unpack such "workaround" materializations and hand the
1184-
// original replacement values to the adaptor.
1185-
Value repl = mapping.lookupOrDefault(operand);
1186-
SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
1188+
// store such materializations yet. If the types of the most recently
1189+
// mapped values do not match, build a target materialization.
11871190
if (TypeRange(unpacked) == legalTypes) {
11881191
remapped.push_back(std::move(unpacked));
11891192
continue;
@@ -1193,7 +1196,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11931196
// different legalized types.
11941197
ValueRange targetMat = buildUnresolvedMaterialization(
11951198
MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1196-
/*inputs=*/repl, /*outputType=*/legalTypes,
1199+
/*inputs=*/unpacked, /*outputType=*/legalTypes,
11971200
/*originalType=*/origType, currentTypeConverter);
11981201
remapped.push_back(targetMat);
11991202
continue;
@@ -1211,7 +1214,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
12111214
Value castValue = buildUnresolvedMaterialization(
12121215
MaterializationKind::Target, computeInsertPoint(newOperand),
12131216
operandLoc,
1214-
/*inputs=*/newOperand, /*outputType=*/desiredType,
1217+
/*inputs=*/unpacked, /*outputType=*/desiredType,
12151218
/*originalType=*/origType, currentTypeConverter);
12161219
mapping.map(newOperand, castValue);
12171220
newOperand = castValue;
@@ -1447,7 +1450,10 @@ ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) {
14471450

14481451
SmallVector<Value> result;
14491452
for (Value v : castOp.getOperands()) {
1450-
// Keep unpacking if possible.
1453+
// Keep unpacking if possible. This is needed because during block
1454+
// signature conversions and 1:N op replacements, the driver may have
1455+
// inserted two materializations back-to-back: first an argument
1456+
// materialization, then a target materialization.
14511457
llvm::append_range(result, unpackNTo1Materialization(v));
14521458
}
14531459
return result;

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,3 +463,14 @@ func.func @circular_mapping() {
463463
%0 = "test.erase_op"() : () -> (i64)
464464
"test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
465465
}
466+
467+
// -----
468+
469+
func.func @test_1_to_n_block_signature_conversion() {
470+
"test.duplicate_block_args"() ({
471+
^bb0(%arg0: i64):
472+
"test.repetitive_1_to_n_consumer"(%arg0) : (i64) -> ()
473+
}) {} : () -> ()
474+
"test.return"() : () -> ()
475+
}
476+

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1886,6 +1886,11 @@ def LegalOpC : TEST_Op<"legal_op_c">,
18861886
Arguments<(ins I32)>, Results<(outs I32)>;
18871887
def LegalOpD : TEST_Op<"legal_op_d">, Arguments<(ins AnyType)>;
18881888

1889+
def DuplicateBlockArgsOp : TEST_Op<"duplicate_block_args", [SingleBlock]> {
1890+
let arguments = (ins UnitAttr:$is_legal);
1891+
let regions = (region SizedRegion<1>:$body);
1892+
}
1893+
18891894
// Check that the conversion infrastructure can properly undo the creation of
18901895
// operations where an operation was created before its parent, in this case,
18911896
// in the parent's builder.

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -982,9 +982,25 @@ struct TestPassthroughInvalidOp : public ConversionPattern {
982982
TestPassthroughInvalidOp(MLIRContext *ctx)
983983
: ConversionPattern("test.invalid", 1, ctx) {}
984984
LogicalResult
985-
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
985+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
986986
ConversionPatternRewriter &rewriter) const final {
987-
rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands,
987+
SmallVector<Value> flattened;
988+
for (auto it : llvm::enumerate(operands)) {
989+
ValueRange range = it.value();
990+
if (range.size() == 1) {
991+
flattened.push_back(range.front());
992+
continue;
993+
}
994+
995+
// This is a 1:N replacement. Insert a test.cast op. (That's what the
996+
// argument materialization used to do.)
997+
flattened.push_back(
998+
rewriter
999+
.create<TestCastOp>(op->getLoc(),
1000+
op->getOperand(it.index()).getType(), range)
1001+
.getResult());
1002+
}
1003+
rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened,
9881004
std::nullopt);
9891005
return success();
9901006
}
@@ -1010,23 +1026,13 @@ struct TestSplitReturnType : public ConversionPattern {
10101026
TestSplitReturnType(MLIRContext *ctx)
10111027
: ConversionPattern("test.return", 1, ctx) {}
10121028
LogicalResult
1013-
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1029+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
10141030
ConversionPatternRewriter &rewriter) const final {
10151031
// Check for a return of F32.
10161032
if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
10171033
return failure();
1018-
1019-
// Check if the first operation is a cast operation, if it is we use the
1020-
// results directly.
1021-
auto *defOp = operands[0].getDefiningOp();
1022-
if (auto packerOp =
1023-
llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
1024-
rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
1025-
return success();
1026-
}
1027-
1028-
// Otherwise, fail to match.
1029-
return failure();
1034+
rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]);
1035+
return success();
10301036
}
10311037
};
10321038

@@ -1181,6 +1187,47 @@ class TestEraseOp : public ConversionPattern {
11811187
}
11821188
};
11831189

1190+
/// This pattern matches a test.duplicate_block_args op and duplicates all
1191+
/// block arguments.
1192+
class TestDuplicateBlockArgs
1193+
: public OpConversionPattern<DuplicateBlockArgsOp> {
1194+
using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern;
1195+
1196+
LogicalResult
1197+
matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor,
1198+
ConversionPatternRewriter &rewriter) const override {
1199+
if (op.getIsLegal())
1200+
return failure();
1201+
rewriter.startOpModification(op);
1202+
Block *body = &op.getBody().front();
1203+
TypeConverter::SignatureConversion result(body->getNumArguments());
1204+
for (auto it : llvm::enumerate(body->getArgumentTypes()))
1205+
result.addInputs(it.index(), {it.value(), it.value()});
1206+
rewriter.applySignatureConversion(body, result, getTypeConverter());
1207+
op.setIsLegal(true);
1208+
rewriter.finalizeOpModification(op);
1209+
return success();
1210+
}
1211+
};
1212+
1213+
/// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid
1214+
/// op. The pattern supports 1:N replacements and forwards the replacement
1215+
/// values of the single operand as test.valid operands.
1216+
class TestRepetitive1ToNConsumer : public ConversionPattern {
1217+
public:
1218+
TestRepetitive1ToNConsumer(MLIRContext *ctx)
1219+
: ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {}
1220+
LogicalResult
1221+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1222+
ConversionPatternRewriter &rewriter) const final {
1223+
// A single operand is expected.
1224+
if (op->getNumOperands() != 1)
1225+
return failure();
1226+
rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front());
1227+
return success();
1228+
}
1229+
};
1230+
11841231
} // namespace
11851232

11861233
namespace {
@@ -1258,9 +1305,11 @@ struct TestLegalizePatternDriver
12581305
TestUpdateConsumerType, TestNonRootReplacement,
12591306
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
12601307
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1261-
TestUndoPropertiesModification, TestEraseOp>(&getContext());
1308+
TestUndoPropertiesModification, TestEraseOp,
1309+
TestRepetitive1ToNConsumer>(&getContext());
12621310
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
12631311
&getContext(), converter);
1312+
patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
12641313
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
12651314
converter);
12661315
mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1312,6 +1361,9 @@ struct TestLegalizePatternDriver
13121361
target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
13131362
[](TestOpInPlaceSelfFold op) { return op.getFolded(); });
13141363

1364+
target.addDynamicallyLegalOp<DuplicateBlockArgsOp>(
1365+
[](DuplicateBlockArgsOp op) { return op.getIsLegal(); });
1366+
13151367
// Handle a partial conversion.
13161368
if (mode == ConversionMode::Partial) {
13171369
DenseSet<Operation *> unlegalizedOps;

0 commit comments

Comments
 (0)