Skip to content

Commit ddaf040

Browse files
[mlir][Transforms][NFC] Make signature conversion more efficient (#83922)
During block signature conversion, a new block is inserted and ops are moved from the old block to the new block. This commit changes the implementation such that ops are moved in bulk (`splice`) instead of one-by-one; that's what `splitBlock` is doing. This also makes it possible to pass the new block argument types directly to `createBlock` instead of using `addArgument` (which bypasses the rewriter). This doesn't change anything from a technical point of view (there is no rewriter API for adding arguments at the moment), but the implementation reads a bit nicer.
1 parent fb582b6 commit ddaf040

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,22 +1281,17 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12811281
ConversionPatternRewriter &rewriter, Block *block,
12821282
const TypeConverter *converter,
12831283
TypeConverter::SignatureConversion &signatureConversion) {
1284-
MLIRContext *ctx = rewriter.getContext();
1284+
OpBuilder::InsertionGuard g(rewriter);
12851285

12861286
// If no arguments are being changed or added, there is nothing to do.
12871287
unsigned origArgCount = block->getNumArguments();
12881288
auto convertedTypes = signatureConversion.getConvertedTypes();
12891289
if (llvm::equal(block->getArgumentTypes(), convertedTypes))
12901290
return block;
12911291

1292-
// Split the block at the beginning to get a new block to use for the updated
1293-
// signature.
1294-
Block *newBlock = rewriter.splitBlock(block, block->begin());
1295-
block->replaceAllUsesWith(newBlock);
1296-
1297-
// Map all new arguments to the location of the argument they originate from.
1292+
// Compute the locations of all block arguments in the new block.
12981293
SmallVector<Location> newLocs(convertedTypes.size(),
1299-
Builder(ctx).getUnknownLoc());
1294+
rewriter.getUnknownLoc());
13001295
for (unsigned i = 0; i < origArgCount; ++i) {
13011296
auto inputMap = signatureConversion.getInputMapping(i);
13021297
if (!inputMap || inputMap->replacementValue)
@@ -1306,9 +1301,16 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13061301
newLocs[inputMap->inputNo + j] = origLoc;
13071302
}
13081303

1309-
SmallVector<Value, 4> newArgRange(
1310-
newBlock->addArguments(convertedTypes, newLocs));
1311-
ArrayRef<Value> newArgs(newArgRange);
1304+
// Insert a new block with the converted block argument types and move all ops
1305+
// from the old block to the new block.
1306+
Block *newBlock =
1307+
rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1308+
convertedTypes, newLocs);
1309+
appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1310+
newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1311+
1312+
// Replace all uses of the old block with the new block.
1313+
block->replaceAllUsesWith(newBlock);
13121314

13131315
// Remap each of the original arguments as determined by the signature
13141316
// conversion.
@@ -1333,7 +1335,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13331335
}
13341336

13351337
// Otherwise, this is a 1->1+ mapping.
1336-
auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
1338+
auto replArgs =
1339+
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
13371340
Value newArg;
13381341

13391342
// If this is a 1->1 mapping and the types of new and replacement arguments

0 commit comments

Comments
 (0)