Skip to content

Commit 2bdb364

Browse files
author
Jerry Wu
committed
Add tests
1 parent 29f0357 commit 2bdb364

File tree

4 files changed

+102
-0
lines changed

4 files changed

+102
-0
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: mlir-opt %s -test-scf-while-zero-trip-check -split-input-file | FileCheck %s
2+
3+
func.func @replace_scf_while_with_zero_trip_check(%bound : i32) -> i32 {
4+
%cst0 = arith.constant 0 : i32
5+
%cst5 = arith.constant 5 : i32
6+
%res:2 = scf.while (%iter = %cst0) : (i32) -> (i32, i32) {
7+
%cond = arith.cmpi slt, %iter, %bound : i32
8+
%inv = arith.addi %bound, %cst5 : i32
9+
scf.condition(%cond) %iter, %inv : i32, i32
10+
} do {
11+
^bb0(%arg1: i32, %arg2: i32):
12+
%next = arith.addi %arg1, %arg2 : i32
13+
scf.yield %next : i32
14+
}
15+
return %res#0 : i32
16+
}
17+
18+
// CHECK-LABEL: func.func @replace_scf_while_with_zero_trip_check(
19+
// CHECK-SAME: %[[ARG0:.*]]: i32) -> i32 {
20+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
21+
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
22+
// CHECK-DAG: %[[PRE_COND:.*]] = arith.cmpi slt, %[[C0]], %[[ARG0]] : i32
23+
// CHECK-DAG: %[[PRE_INV:.*]] = arith.addi %[[ARG0]], %[[C5]] : i32
24+
// CHECK: %[[IF:.*]]:2 = scf.if %[[PRE_COND]] -> (i32, i32) {
25+
// CHECK: %[[WHILE:.*]]:2 = scf.while (
26+
// CHECK-SAME: %[[ARG1:.*]] = %[[C0]], %[[ARG2:.*]] = %[[PRE_INV]]
27+
// CHECK-SAME: ) : (i32, i32) -> (i32, i32) {
28+
// CHECK: %[[NEXT:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
29+
// CHECK: %[[COND:.*]] = arith.cmpi slt, %[[NEXT]], %[[ARG0]] : i32
30+
// CHECK: %[[INV:.*]] = arith.addi %[[ARG0]], %[[C5]] : i32
31+
// CHECK: scf.condition(%[[COND]]) %[[NEXT]], %[[INV]] : i32, i32
32+
// CHECK: } do {
33+
// CHECK: ^bb0(%[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
34+
// CHECK: scf.yield %[[ARG3]], %[[ARG4]] : i32, i32
35+
// CHECK: }
36+
// CHECK: scf.yield %[[WHILE]]#0, %[[WHILE]]#1 : i32, i32
37+
// CHECK: } else {
38+
// CHECK: scf.yield %[[C0]], %[[PRE_INV]] : i32, i32
39+
// CHECK: }
40+
// CHECK: return %[[IF]]#0 : i32

mlir/test/lib/Dialect/SCF/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
add_mlir_library(MLIRSCFTestPasses
33
TestLoopParametricTiling.cpp
44
TestLoopUnrolling.cpp
5+
TestLoopZeroTripCheck.cpp
56
TestSCFUtils.cpp
67
TestWhileOpBuilder.cpp
78

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
//===- TestLoopZeroTripCheck.cpp -- Pass to test replaceWithZeroTripCheck -===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements the passes to test replaceWithZeroTripCheck for SCF
10+
// dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Func/IR/FuncOps.h"
15+
#include "mlir/Dialect/SCF/IR/SCF.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
#include "mlir/Pass/Pass.h"
18+
19+
using namespace mlir;
20+
21+
namespace {
22+
23+
struct TestSCFWhileZeroTripCheckPass
24+
: public PassWrapper<TestSCFWhileZeroTripCheckPass,
25+
OperationPass<func::FuncOp>> {
26+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFWhileZeroTripCheckPass)
27+
28+
StringRef getArgument() const final {
29+
return "test-scf-while-zero-trip-check";
30+
}
31+
StringRef getDescription() const final {
32+
return "test replaceWithZeroTripCheck of scf.while";
33+
}
34+
explicit TestSCFWhileZeroTripCheckPass() = default;
35+
TestSCFWhileZeroTripCheckPass(const TestSCFWhileZeroTripCheckPass &pass)
36+
: PassWrapper(pass) {}
37+
38+
void runOnOperation() override {
39+
func::FuncOp func = getOperation();
40+
MLIRContext *context = &getContext();
41+
IRRewriter rewriter(context);
42+
func.walk([&](scf::WhileOp op) {
43+
auto result = op.replaceWithZeroTripCheck(rewriter);
44+
if (failed(result)) {
45+
signalPassFailure();
46+
}
47+
});
48+
}
49+
};
50+
51+
} // namespace
52+
53+
namespace mlir {
54+
namespace test {
55+
void registerTestLoopZeroTripCheckPass() {
56+
PassRegistration<TestSCFWhileZeroTripCheckPass>();
57+
}
58+
} // namespace test
59+
} // namespace mlir

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ void registerTestLoopFusion();
110110
void registerTestCFGLoopInfoPass();
111111
void registerTestLoopMappingPass();
112112
void registerTestLoopUnrollingPass();
113+
void registerTestLoopZeroTripCheckPass();
113114
void registerTestLowerToLLVM();
114115
void registerTestLowerToNVVM();
115116
void registerTestMakeIsolatedFromAbovePass();
@@ -234,6 +235,7 @@ void registerTestPasses() {
234235
mlir::test::registerTestCFGLoopInfoPass();
235236
mlir::test::registerTestLoopMappingPass();
236237
mlir::test::registerTestLoopUnrollingPass();
238+
mlir::test::registerTestLoopZeroTripCheckPass();
237239
mlir::test::registerTestLowerToLLVM();
238240
mlir::test::registerTestMakeIsolatedFromAbovePass();
239241
mlir::test::registerTestMatchReductionPass();

0 commit comments

Comments
 (0)