Skip to content

[mlir] Require folders to produce Values of same type #75887

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -625,11 +625,13 @@ void fir::BoxAddrOp::build(mlir::OpBuilder &builder,
mlir::OpFoldResult fir::BoxAddrOp::fold(FoldAdaptor adaptor) {
if (auto *v = getVal().getDefiningOp()) {
if (auto box = mlir::dyn_cast<fir::EmboxOp>(v)) {
if (!box.getSlice()) // Fold only if not sliced
// Fold only if not sliced
if (!box.getSlice() && box.getMemref().getType() == getType())
return box.getMemref();
}
if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v))
return box.getMemref();
if (box.getMemref().getType() == getType())
return box.getMemref();
}
return {};
}
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1352,9 +1352,11 @@ OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
setOperand(src);
return getResult();
}

// trunci(zexti(a)) -> a
// trunci(sexti(a)) -> a
return src;
if (srcType == dstType)
return src;
}

// trunci(trunci(a)) -> trunci(a))
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,8 @@ OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
if (!inputTy.hasRank()) \
return {}; \
if (inputTy != getType()) \
return {}; \
if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
return getInput(); \
return {}; \
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1602,9 +1602,10 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
: 0;
};

// If splat or broadcast from a scalar, just return the source scalar.
unsigned broadcastSrcRank = getRank(source.getType());
if (broadcastSrcRank == 0)
if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
return source;

unsigned extractResultRank = getRank(extractOp.getType());
Expand Down
5 changes: 1 addition & 4 deletions mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,14 +486,11 @@ LogicalResult OpBuilder::tryFold(Operation *op,

// Populate the results with the folded results.
Dialect *dialect = op->getDialect();
for (auto it : llvm::zip(foldResults, opResults.getTypes())) {
for (auto it : llvm::zip_equal(foldResults, opResults.getTypes())) {
Type expectedType = std::get<1>(it);

// Normal values get pushed back directly.
if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
if (value.getType() != expectedType)
return cleanupFailure();

results.push_back(value);
continue;
}
Expand Down
26 changes: 24 additions & 2 deletions mlir/lib/IR/Operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -606,13 +606,30 @@ void Operation::setSuccessor(Block *block, unsigned index) {
getBlockOperands()[index].set(block);
}

#ifndef NDEBUG
/// Assert that the folded results (in case of values) have the same type as
/// the results of the given op.
static void checkFoldResultTypes(Operation *op,
SmallVectorImpl<OpFoldResult> &results) {
if (!results.empty())
for (auto [ofr, opResult] : llvm::zip_equal(results, op->getResults()))
if (auto value = ofr.dyn_cast<Value>())
assert(value.getType() == opResult.getType() &&
"folder produced value of incorrect type");
}
#endif // NDEBUG

/// Attempt to fold this operation using the Op's registered foldHook.
LogicalResult Operation::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
// If we have a registered operation definition matching this one, use it to
// try to constant fold the operation.
if (succeeded(name.foldHook(this, operands, results)))
if (succeeded(name.foldHook(this, operands, results))) {
#ifndef NDEBUG
checkFoldResultTypes(this, results);
#endif // NDEBUG
return success();
}

// Otherwise, fall back on the dialect hook to handle it.
Dialect *dialect = getDialect();
Expand All @@ -623,7 +640,12 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
if (!interface)
return failure();

return interface->fold(this, operands, results);
LogicalResult status = interface->fold(this, operands, results);
#ifndef NDEBUG
if (succeeded(status))
checkFoldResultTypes(this, results);
#endif // NDEBUG
return status;
}

LogicalResult Operation::fold(SmallVectorImpl<OpFoldResult> &results) {
Expand Down
4 changes: 0 additions & 4 deletions mlir/lib/Transforms/Utils/FoldUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,6 @@ OperationFolder::processFoldResults(Operation *op,

// Check if the result was an SSA value.
if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
if (repl.getType() != op->getResult(i).getType()) {
results.clear();
return failure();
}
results.emplace_back(repl);
continue;
}
Expand Down
13 changes: 0 additions & 13 deletions mlir/test/Transforms/test-canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,6 @@ func.func @test_commutative_multi_cst(%arg0: i32, %arg1: i32) -> (i32, i32) {
return %y, %z: i32, i32
}

// CHECK-LABEL: func @typemismatch

func.func @typemismatch() -> i32 {
%c42 = arith.constant 42.0 : f32

// The "passthrough_fold" folder will naively return its operand, but we don't
// want to fold here because of the type mismatch.

// CHECK: "test.passthrough_fold"
%0 = "test.passthrough_fold"(%c42) : (f32) -> (i32)
return %0 : i32
}

// CHECK-LABEL: test_dialect_canonicalizer
func.func @test_dialect_canonicalizer() -> (i32) {
%0 = "test.dialect_canonicalizable"() : () -> (i32)
Expand Down
10 changes: 0 additions & 10 deletions mlir/test/Transforms/test-legalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -310,16 +310,6 @@ builtin.module {

// -----

// The "passthrough_fold" folder will naively return its operand, but we don't
// want to fold here because of the type mismatch.
func.func @typemismatch(%arg: f32) -> i32 {
// expected-remark@+1 {{op 'test.passthrough_fold' is not legalizable}}
%0 = "test.passthrough_fold"(%arg) : (f32) -> (i32)
"test.return"(%0) : (i32) -> ()
}

// -----

// expected-remark @below {{applyPartialConversion failed}}
module {
func.func private @callee(%0 : f32) -> f32
Expand Down
4 changes: 0 additions & 4 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -542,10 +542,6 @@ OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
return {};
}

OpFoldResult TestPassthroughFold::fold(FoldAdaptor adaptor) {
return getOperand();
}

OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
int64_t sum = 0;
if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
Expand Down
7 changes: 0 additions & 7 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1363,13 +1363,6 @@ def TestOpFoldWithFoldAdaptor
let hasFolder = 1;
}

// An op that always fold itself.
def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
let arguments = (ins AnyType:$op);
let results = (outs AnyType);
let hasFolder = 1;
}

def TestDialectCanonicalizerOp : TEST_Op<"dialect_canonicalizable"> {
let arguments = (ins);
let results = (outs I32);
Expand Down