Skip to content

Commit 2c5cc19

Browse files
[mlir][Transforms][NFC] Dialect Conversion: Move argument materialization logic
This commit moves the argument materialization logic from `legalizeConvertedArgumentTypes` to `legalizeUnresolvedMaterializations`. Before this change: - Argument materializations were created in `legalizeConvertedArgumentTypes` (which used to call `materializeLiveConversions`). After this change: - `legalizeConvertedArgumentTypes` creates a "placeholder" `unrealized_conversion_cast`. - The placeholder `unrealized_conversion_cast` is replaced with an argument materialization (using the type converter) in `legalizeUnresolvedMaterializations`. - All argument and target materializations now take place in the same location (`legalizeUnresolvedMaterializations`). This commit brings us closer towards creating all source/target/argument materializations in one central step, which can then be made optional (and delegated to the user) in the future. (There is one more source materialization step that has not been moved yet.) This commit also consolidates all `build*UnresolvedMaterialization` functions into a single `buildUnresolvedMaterialization` function. This is a re-upload of #96329.
1 parent b4f3a96 commit 2c5cc19

File tree

1 file changed

+54
-84
lines changed

1 file changed

+54
-84
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 54 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
5353
});
5454
}
5555

56+
/// Helper function that computes an insertion point where the given value is
57+
/// defined and can be used without a dominance violation.
58+
static OpBuilder::InsertPoint computeInsertPoint(Value value) {
59+
Block *insertBlock = value.getParentBlock();
60+
Block::iterator insertPt = insertBlock->begin();
61+
if (OpResult inputRes = dyn_cast<OpResult>(value))
62+
insertPt = ++inputRes.getOwner()->getIterator();
63+
return OpBuilder::InsertPoint(insertBlock, insertPt);
64+
}
65+
5666
//===----------------------------------------------------------------------===//
5767
// ConversionValueMapping
5868
//===----------------------------------------------------------------------===//
@@ -445,11 +455,9 @@ class BlockTypeConversionRewrite : public BlockRewrite {
445455
return rewrite->getKind() == Kind::BlockTypeConversion;
446456
}
447457

448-
/// Materialize any necessary conversions for converted arguments that have
449-
/// live users, using the provided `findLiveUser` to search for a user that
450-
/// survives the conversion process.
451-
LogicalResult
452-
materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
458+
Block *getOrigBlock() const { return origBlock; }
459+
460+
const TypeConverter *getConverter() const { return converter; }
453461

454462
void commit(RewriterBase &rewriter) override;
455463

@@ -830,15 +838,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
830838
/// Build an unresolved materialization operation given an output type and set
831839
/// of input operands.
832840
Value buildUnresolvedMaterialization(MaterializationKind kind,
833-
Block *insertBlock,
834-
Block::iterator insertPt, Location loc,
841+
OpBuilder::InsertPoint ip, Location loc,
835842
ValueRange inputs, Type outputType,
836843
const TypeConverter *converter);
837844

838-
Value buildUnresolvedTargetMaterialization(Location loc, Value input,
839-
Type outputType,
840-
const TypeConverter *converter);
841-
842845
//===--------------------------------------------------------------------===//
843846
// Rewriter Notification Hooks
844847
//===--------------------------------------------------------------------===//
@@ -970,49 +973,6 @@ void BlockTypeConversionRewrite::rollback() {
970973
block->replaceAllUsesWith(origBlock);
971974
}
972975

973-
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
974-
function_ref<Operation *(Value)> findLiveUser) {
975-
// Process the remapping for each of the original arguments.
976-
for (auto it : llvm::enumerate(origBlock->getArguments())) {
977-
BlockArgument origArg = it.value();
978-
// Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
979-
OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
980-
builder.setInsertionPointToStart(block);
981-
982-
// If the type of this argument changed and the argument is still live, we
983-
// need to materialize a conversion.
984-
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
985-
continue;
986-
Operation *liveUser = findLiveUser(origArg);
987-
if (!liveUser)
988-
continue;
989-
990-
Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
991-
assert(replacementValue && "replacement value not found");
992-
Value newArg;
993-
if (converter) {
994-
builder.setInsertionPointAfterValue(replacementValue);
995-
newArg = converter->materializeSourceConversion(
996-
builder, origArg.getLoc(), origArg.getType(), replacementValue);
997-
assert((!newArg || newArg.getType() == origArg.getType()) &&
998-
"materialization hook did not provide a value of the expected "
999-
"type");
1000-
}
1001-
if (!newArg) {
1002-
InFlightDiagnostic diag =
1003-
emitError(origArg.getLoc())
1004-
<< "failed to materialize conversion for block argument #"
1005-
<< it.index() << " that remained live after conversion, type was "
1006-
<< origArg.getType();
1007-
diag.attachNote(liveUser->getLoc())
1008-
<< "see existing live user here: " << *liveUser;
1009-
return failure();
1010-
}
1011-
rewriterImpl.mapping.map(origArg, newArg);
1012-
}
1013-
return success();
1014-
}
1015-
1016976
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1017977
Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
1018978
if (!repl)
@@ -1185,8 +1145,10 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11851145
Type newOperandType = newOperand.getType();
11861146
if (currentTypeConverter && desiredType && newOperandType != desiredType) {
11871147
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1188-
Value castValue = buildUnresolvedTargetMaterialization(
1189-
operandLoc, newOperand, desiredType, currentTypeConverter);
1148+
Value castValue = buildUnresolvedMaterialization(
1149+
MaterializationKind::Target, computeInsertPoint(newOperand),
1150+
operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
1151+
currentTypeConverter);
11901152
mapping.map(mapping.lookupOrDefault(newOperand), castValue);
11911153
newOperand = castValue;
11921154
}
@@ -1299,8 +1261,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12991261
// This block argument was dropped and no replacement value was provided.
13001262
// Materialize a replacement value "out of thin air".
13011263
Value repl = buildUnresolvedMaterialization(
1302-
MaterializationKind::Source, newBlock, newBlock->begin(),
1303-
origArg.getLoc(), /*inputs=*/ValueRange(),
1264+
MaterializationKind::Source,
1265+
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1266+
/*inputs=*/ValueRange(),
13041267
/*outputType=*/origArgType, converter);
13051268
mapping.map(origArg, repl);
13061269
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1324,8 +1287,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13241287
auto replArgs =
13251288
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
13261289
Value argMat = buildUnresolvedMaterialization(
1327-
MaterializationKind::Argument, newBlock, newBlock->begin(),
1328-
origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
1290+
MaterializationKind::Argument,
1291+
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1292+
/*inputs=*/replArgs, origArgType, converter);
13291293
mapping.map(origArg, argMat);
13301294
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
13311295

@@ -1339,7 +1303,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13391303
if (converter)
13401304
legalOutputType = converter->convertType(origArgType);
13411305
if (legalOutputType && legalOutputType != origArgType) {
1342-
Value targetMat = buildUnresolvedTargetMaterialization(
1306+
Value targetMat = buildUnresolvedMaterialization(
1307+
MaterializationKind::Target, computeInsertPoint(argMat),
13431308
origArg.getLoc(), argMat, legalOutputType, converter);
13441309
mapping.map(argMat, targetMat);
13451310
}
@@ -1362,33 +1327,20 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13621327
/// Build an unresolved materialization operation given an output type and set
13631328
/// of input operands.
13641329
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1365-
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
1366-
Location loc, ValueRange inputs, Type outputType,
1367-
const TypeConverter *converter) {
1330+
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1331+
ValueRange inputs, Type outputType, const TypeConverter *converter) {
13681332
// Avoid materializing an unnecessary cast.
13691333
if (inputs.size() == 1 && inputs.front().getType() == outputType)
13701334
return inputs.front();
13711335

13721336
// Create an unresolved materialization. We use a new OpBuilder to avoid
13731337
// tracking the materialization like we do for other operations.
1374-
OpBuilder builder(insertBlock, insertPt);
1338+
OpBuilder builder(ip.getBlock(), ip.getPoint());
13751339
auto convertOp =
13761340
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
13771341
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
13781342
return convertOp.getResult(0);
13791343
}
1380-
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
1381-
Location loc, Value input, Type outputType,
1382-
const TypeConverter *converter) {
1383-
Block *insertBlock = input.getParentBlock();
1384-
Block::iterator insertPt = insertBlock->begin();
1385-
if (OpResult inputRes = dyn_cast<OpResult>(input))
1386-
insertPt = ++inputRes.getOwner()->getIterator();
1387-
1388-
return buildUnresolvedMaterialization(MaterializationKind::Target,
1389-
insertBlock, insertPt, loc, input,
1390-
outputType, converter);
1391-
}
13921344

13931345
//===----------------------------------------------------------------------===//
13941346
// Rewriter Notification Hooks
@@ -2502,9 +2454,9 @@ LogicalResult
25022454
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
25032455
std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
25042456
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2505-
if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2506-
inverseMapping)) ||
2507-
failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2457+
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
2458+
failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2459+
inverseMapping)))
25082460
return failure();
25092461

25102462
// Process requested operation replacements.
@@ -2560,10 +2512,28 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
25602512
++i) {
25612513
auto &rewrite = rewriterImpl.rewrites[i];
25622514
if (auto *blockTypeConversionRewrite =
2563-
dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
2564-
if (failed(blockTypeConversionRewrite->materializeLiveConversions(
2565-
findLiveUser)))
2566-
return failure();
2515+
dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
2516+
// Process the remapping for each of the original arguments.
2517+
for (Value origArg :
2518+
blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
2519+
// If the type of this argument changed and the argument is still live,
2520+
// we need to materialize a conversion.
2521+
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
2522+
continue;
2523+
Operation *liveUser = findLiveUser(origArg);
2524+
if (!liveUser)
2525+
continue;
2526+
2527+
Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
2528+
assert(replacementValue && "replacement value not found");
2529+
Value repl = rewriterImpl.buildUnresolvedMaterialization(
2530+
MaterializationKind::Source, computeInsertPoint(replacementValue),
2531+
origArg.getLoc(), /*inputs=*/replacementValue,
2532+
/*outputType=*/origArg.getType(),
2533+
blockTypeConversionRewrite->getConverter());
2534+
rewriterImpl.mapping.map(origArg, repl);
2535+
}
2536+
}
25672537
}
25682538
return success();
25692539
}

0 commit comments

Comments
 (0)