Skip to content

Commit b716bf8

Browse files
[mlir][scf] Fix builder of WhileOp with region builder arguments.
The overload of WhileOp::build with arguments for builder functions for the regions of the op was broken: It did not compute correctly the types (and locations) of the region arguments, which lead to failed assertions when the result types were different from the operand types. Specifically, it used the result types (and operand locations) for *both* regions, instead of the operand types (and locations) for the 'before' region and the result types (and loecations) for the 'after' region. Reviewed By: Mogball, mehdi_amini Differential Revision: https://reviews.llvm.org/D142952
1 parent 3599cbd commit b716bf8

File tree

5 files changed

+132
-5
lines changed

5 files changed

+132
-5
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2764,19 +2764,24 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
27642764

27652765
OpBuilder::InsertionGuard guard(odsBuilder);
27662766

2767-
SmallVector<Location, 4> blockArgLocs;
2767+
// Build before region.
2768+
SmallVector<Location, 4> beforeArgLocs;
2769+
beforeArgLocs.reserve(operands.size());
27682770
for (Value operand : operands) {
2769-
blockArgLocs.push_back(operand.getLoc());
2771+
beforeArgLocs.push_back(operand.getLoc());
27702772
}
27712773

27722774
Region *beforeRegion = odsState.addRegion();
2773-
Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
2774-
resultTypes, blockArgLocs);
2775+
Block *beforeBlock = odsBuilder.createBlock(
2776+
beforeRegion, /*insertPt=*/{}, operands.getTypes(), beforeArgLocs);
27752777
beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
27762778

2779+
// Build after region.
2780+
SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
2781+
27772782
Region *afterRegion = odsState.addRegion();
27782783
Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
2779-
resultTypes, blockArgLocs);
2784+
resultTypes, afterArgLocs);
27802785
afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
27812786
}
27822787

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: mlir-opt %s -test-scf-while-op-builder | FileCheck %s
2+
3+
// CHECK-LABEL: @testMatchingTypes
4+
func.func @testMatchingTypes(%arg0 : i32) {
5+
%0 = scf.while (%arg1 = %arg0) : (i32) -> (i32) {
6+
%c10 = arith.constant 10 : i32
7+
%1 = arith.cmpi slt, %arg1, %c10 : i32
8+
scf.condition(%1) %arg1 : i32
9+
} do {
10+
^bb0(%arg1: i32):
11+
scf.yield %arg1 : i32
12+
}
13+
// Expect the same loop twice (the dummy added by the test pass and the
14+
// original one).
15+
// CHECK: %[[V0:.*]] = scf.while (%[[arg1:.*]] = %[[arg0:.*]]) : (i32) -> i32 {
16+
// CHECK: %[[V1:.*]] = scf.while (%[[arg2:.*]] = %[[arg0]]) : (i32) -> i32 {
17+
return
18+
}
19+
20+
// CHECK-LABEL: @testNonMatchingTypes
21+
func.func @testNonMatchingTypes(%arg0 : i32) {
22+
%c1 = arith.constant 1 : i32
23+
%c10 = arith.constant 10 : i32
24+
%0:2 = scf.while (%arg1 = %arg0) : (i32) -> (i32, i32) {
25+
%1 = arith.cmpi slt, %arg1, %c10 : i32
26+
scf.condition(%1) %arg1, %c1 : i32, i32
27+
} do {
28+
^bb0(%arg1: i32, %arg2: i32):
29+
%1 = arith.addi %arg1, %arg2 : i32
30+
scf.yield %1 : i32
31+
}
32+
// Expect the same loop twice (the dummy added by the test pass and the
33+
// original one).
34+
// CHECK: %[[V0:.*]] = scf.while (%[[arg1:.*]] = %[[arg0:.*]]) : (i32) -> (i32, i32) {
35+
// CHECK: %[[V1:.*]] = scf.while (%[[arg2:.*]] = %[[arg0]]) : (i32) -> (i32, i32) {
36+
return
37+
}

mlir/test/lib/Dialect/SCF/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_library(MLIRSCFTestPasses
33
TestLoopParametricTiling.cpp
44
TestLoopUnrolling.cpp
55
TestSCFUtils.cpp
6+
TestWhileOpBuilder.cpp
67

78
EXCLUDE_FROM_LIBMLIR
89

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//===- TestWhileOpBuilder.cpp - Pass to test WhileOp::build ---------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to test some builder functions of WhileOp. It
10+
// tests the regression explained in https://reviews.llvm.org/D142952, where
11+
// a WhileOp::build overload crashed when fed with operands of different types
12+
// than the result types.
13+
//
14+
// To test the build function, the pass copies each WhileOp found in the body
15+
// of a FuncOp and adds an additional WhileOp with the same operands and result
16+
// types (but dummy computations) using the builder in question.
17+
//
18+
//===----------------------------------------------------------------------===//
19+
20+
#include "mlir/Dialect/Arith/IR/Arith.h"
21+
#include "mlir/Dialect/Func/IR/FuncOps.h"
22+
#include "mlir/Dialect/SCF/IR/SCF.h"
23+
#include "mlir/IR/BuiltinOps.h"
24+
#include "mlir/IR/ImplicitLocOpBuilder.h"
25+
#include "mlir/Pass/Pass.h"
26+
27+
using namespace mlir;
28+
using namespace mlir::arith;
29+
using namespace mlir::scf;
30+
31+
namespace {
32+
struct TestSCFWhileOpBuilderPass
33+
: public PassWrapper<TestSCFWhileOpBuilderPass,
34+
OperationPass<func::FuncOp>> {
35+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFWhileOpBuilderPass)
36+
37+
StringRef getArgument() const final { return "test-scf-while-op-builder"; }
38+
StringRef getDescription() const final {
39+
return "test build functions of scf.while";
40+
}
41+
explicit TestSCFWhileOpBuilderPass() = default;
42+
TestSCFWhileOpBuilderPass(const TestSCFWhileOpBuilderPass &pass) = default;
43+
44+
void runOnOperation() override {
45+
func::FuncOp func = getOperation();
46+
func.walk([&](WhileOp whileOp) {
47+
Location loc = whileOp->getLoc();
48+
ImplicitLocOpBuilder builder(loc, whileOp);
49+
50+
// Create a WhileOp with the same operands and result types.
51+
TypeRange resultTypes = whileOp->getResultTypes();
52+
ValueRange operands = whileOp->getOperands();
53+
builder.create<WhileOp>(
54+
loc, resultTypes, operands, /*beforeBuilder=*/
55+
[&](OpBuilder &b, Location loc, ValueRange args) {
56+
// Just cast the before args into the right types for condition.
57+
ImplicitLocOpBuilder builder(loc, b);
58+
auto castOp =
59+
builder.create<UnrealizedConversionCastOp>(resultTypes, args);
60+
auto cmp = builder.create<ConstantIntOp>(/*value=*/1, /*width=*/1);
61+
builder.create<ConditionOp>(cmp, castOp->getResults());
62+
},
63+
/*afterBuilder=*/
64+
[&](OpBuilder &b, Location loc, ValueRange args) {
65+
// Just cast the after args into the right types for yield.
66+
ImplicitLocOpBuilder builder(loc, b);
67+
auto castOp = builder.create<UnrealizedConversionCastOp>(
68+
operands.getTypes(), args);
69+
builder.create<YieldOp>(castOp->getResults());
70+
});
71+
});
72+
}
73+
};
74+
} // namespace
75+
76+
namespace mlir {
77+
namespace test {
78+
void registerTestSCFWhileOpBuilderPass() {
79+
PassRegistration<TestSCFWhileOpBuilderPass>();
80+
}
81+
} // namespace test
82+
} // namespace mlir

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ void registerTestPDLLPasses();
113113
void registerTestPreparationPassWithAllowedMemrefResults();
114114
void registerTestRecursiveTypesPass();
115115
void registerTestSCFUtilsPass();
116+
void registerTestSCFWhileOpBuilderPass();
116117
void registerTestShapeMappingPass();
117118
void registerTestSliceAnalysisPass();
118119
void registerTestTensorCopyInsertionPass();
@@ -220,6 +221,7 @@ void registerTestPasses() {
220221
mlir::test::registerTestPDLLPasses();
221222
mlir::test::registerTestRecursiveTypesPass();
222223
mlir::test::registerTestSCFUtilsPass();
224+
mlir::test::registerTestSCFWhileOpBuilderPass();
223225
mlir::test::registerTestShapeMappingPass();
224226
mlir::test::registerTestSliceAnalysisPass();
225227
mlir::test::registerTestTensorCopyInsertionPass();

0 commit comments

Comments
 (0)