Skip to content

Commit 7f312f6

Browse files
committed
[mlir] Avoid folding in OpBuilder::tryFold when types change
This was missed when tightening fold restrictions in https://reviews.llvm.org/D95991. Differential Revision: https://reviews.llvm.org/D113138
1 parent a55c4ec commit 7f312f6

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

mlir/lib/IR/Builders.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -392,9 +392,11 @@ Operation *OpBuilder::createOperation(const OperationState &state) {
392392
/// Note: This function does not erase the operation on a successful fold.
393393
LogicalResult OpBuilder::tryFold(Operation *op,
394394
SmallVectorImpl<Value> &results) {
395-
results.reserve(op->getNumResults());
395+
ResultRange opResults = op->getResults();
396+
397+
results.reserve(opResults.size());
396398
auto cleanupFailure = [&] {
397-
results.assign(op->result_begin(), op->result_end());
399+
results.assign(opResults.begin(), opResults.end());
398400
return failure();
399401
};
400402

@@ -405,7 +407,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
405407
// Check to see if any operands to the operation is constant and whether
406408
// the operation knows how to constant fold itself.
407409
SmallVector<Attribute, 4> constOperands(op->getNumOperands());
408-
for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
410+
for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
409411
matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));
410412

411413
// Try to fold the operation.
@@ -419,9 +421,14 @@ LogicalResult OpBuilder::tryFold(Operation *op,
419421

420422
// Populate the results with the folded results.
421423
Dialect *dialect = op->getDialect();
422-
for (auto &it : llvm::enumerate(foldResults)) {
424+
for (auto it : llvm::zip(foldResults, opResults.getTypes())) {
425+
Type expectedType = std::get<1>(it);
426+
423427
// Normal values get pushed back directly.
424-
if (auto value = it.value().dyn_cast<Value>()) {
428+
if (auto value = std::get<0>(it).dyn_cast<Value>()) {
429+
if (value.getType() != expectedType)
430+
return cleanupFailure();
431+
425432
results.push_back(value);
426433
continue;
427434
}
@@ -431,9 +438,9 @@ LogicalResult OpBuilder::tryFold(Operation *op,
431438
return cleanupFailure();
432439

433440
// Ask the dialect to materialize a constant operation for this value.
434-
Attribute attr = it.value().get<Attribute>();
435-
auto *constOp = dialect->materializeConstant(
436-
cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
441+
Attribute attr = std::get<0>(it).get<Attribute>();
442+
auto *constOp = dialect->materializeConstant(cstBuilder, attr, expectedType,
443+
op->getLoc());
437444
if (!constOp) {
438445
// Erase any generated constants.
439446
for (Operation *cst : generatedConstants)

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,3 +307,13 @@ builtin.module {
307307
}
308308

309309
}
310+
311+
// -----
312+
313+
// The "passthrough_fold" folder will naively return its operand, but we don't
314+
// want to fold here because of the type mismatch.
315+
func @typemismatch(%arg: f32) -> i32 {
316+
// expected-remark@+1 {{op 'test.passthrough_fold' is not legalizable}}
317+
%0 = "test.passthrough_fold"(%arg) : (f32) -> (i32)
318+
"test.return"(%0) : (i32) -> ()
319+
}

0 commit comments

Comments
 (0)