|
10 | 10 | #include "mlir/Dialect/SCF/IR/SCF.h"
|
11 | 11 | #include "mlir/IR/Diagnostics.h"
|
12 | 12 | #include "mlir/IR/MLIRContext.h"
|
| 13 | +#include "mlir/IR/OwningOpRef.h" |
13 | 14 | #include "gtest/gtest.h"
|
14 | 15 |
|
15 | 16 | using namespace mlir;
|
@@ -55,35 +56,50 @@ class SCFLoopLikeTest : public ::testing::Test {
|
55 | 56 | };
|
56 | 57 |
|
57 | 58 | TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) {
|
58 |
| - Value lb = b.create<arith::ConstantIndexOp>(loc, 0); |
59 |
| - Value ub = b.create<arith::ConstantIndexOp>(loc, 10); |
60 |
| - Value step = b.create<arith::ConstantIndexOp>(loc, 2); |
| 59 | + OwningOpRef<arith::ConstantIndexOp> lb = |
| 60 | + b.create<arith::ConstantIndexOp>(loc, 0); |
| 61 | + OwningOpRef<arith::ConstantIndexOp> ub = |
| 62 | + b.create<arith::ConstantIndexOp>(loc, 10); |
| 63 | + OwningOpRef<arith::ConstantIndexOp> step = |
| 64 | + b.create<arith::ConstantIndexOp>(loc, 2); |
61 | 65 |
|
62 |
| - auto forOp = b.create<scf::ForOp>(loc, lb, ub, step); |
63 |
| - checkUnidimensional(forOp); |
| 66 | + OwningOpRef<scf::ForOp> forOp = |
| 67 | + b.create<scf::ForOp>(loc, lb.get(), ub.get(), step.get()); |
| 68 | + checkUnidimensional(forOp.get()); |
64 | 69 |
|
65 |
| - auto forallOp = b.create<scf::ForallOp>( |
66 |
| - loc, ArrayRef<OpFoldResult>(lb), ArrayRef<OpFoldResult>(ub), |
67 |
| - ArrayRef<OpFoldResult>(step), ValueRange(), std::nullopt); |
68 |
| - checkUnidimensional(forallOp); |
| 70 | + OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>( |
| 71 | + loc, ArrayRef<OpFoldResult>(static_cast<Value>(lb.get())), |
| 72 | + ArrayRef<OpFoldResult>(static_cast<Value>(ub.get())), |
| 73 | + ArrayRef<OpFoldResult>(static_cast<Value>(step.get())), ValueRange(), |
| 74 | + std::nullopt); |
| 75 | + checkUnidimensional(forallOp.get()); |
69 | 76 |
|
70 |
| - auto parallelOp = b.create<scf::ParallelOp>( |
71 |
| - loc, ValueRange(lb), ValueRange(ub), ValueRange(step), ValueRange()); |
72 |
| - checkUnidimensional(parallelOp); |
| 77 | + OwningOpRef<scf::ParallelOp> parallelOp = |
| 78 | + b.create<scf::ParallelOp>(loc, ValueRange(lb.get()), ValueRange(ub.get()), |
| 79 | + ValueRange(step.get()), ValueRange()); |
| 80 | + checkUnidimensional(parallelOp.get()); |
73 | 81 | }
|
74 | 82 |
|
75 | 83 | TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
|
76 |
| - Value lb = b.create<arith::ConstantIndexOp>(loc, 0); |
77 |
| - Value ub = b.create<arith::ConstantIndexOp>(loc, 10); |
78 |
| - Value step = b.create<arith::ConstantIndexOp>(loc, 2); |
| 84 | + OwningOpRef<arith::ConstantIndexOp> lb = |
| 85 | + b.create<arith::ConstantIndexOp>(loc, 0); |
| 86 | + OwningOpRef<arith::ConstantIndexOp> ub = |
| 87 | + b.create<arith::ConstantIndexOp>(loc, 10); |
| 88 | + OwningOpRef<arith::ConstantIndexOp> step = |
| 89 | + b.create<arith::ConstantIndexOp>(loc, 2); |
| 90 | + auto lbValue = static_cast<Value>(lb.get()); |
| 91 | + auto ubValue = static_cast<Value>(ub.get()); |
| 92 | + auto stepValue = static_cast<Value>(step.get()); |
79 | 93 |
|
80 |
| - auto forallOp = b.create<scf::ForallOp>( |
81 |
| - loc, ArrayRef<OpFoldResult>({lb, lb}), ArrayRef<OpFoldResult>({ub, ub}), |
82 |
| - ArrayRef<OpFoldResult>({step, step}), ValueRange(), std::nullopt); |
83 |
| - checkMultidimensional(forallOp); |
| 94 | + OwningOpRef<scf::ForallOp> forallOp = |
| 95 | + b.create<scf::ForallOp>(loc, ArrayRef<OpFoldResult>({lbValue, lbValue}), |
| 96 | + ArrayRef<OpFoldResult>({ubValue, ubValue}), |
| 97 | + ArrayRef<OpFoldResult>({stepValue, stepValue}), |
| 98 | + ValueRange(), std::nullopt); |
| 99 | + checkMultidimensional(forallOp.get()); |
84 | 100 |
|
85 |
| - auto parallelOp = |
86 |
| - b.create<scf::ParallelOp>(loc, ValueRange({lb, lb}), ValueRange({ub, ub}), |
87 |
| - ValueRange({step, step}), ValueRange()); |
88 |
| - checkMultidimensional(parallelOp); |
| 101 | + OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>( |
| 102 | + loc, ValueRange({lbValue, lbValue}), ValueRange({ubValue, ubValue}), |
| 103 | + ValueRange({stepValue, stepValue}), ValueRange()); |
| 104 | + checkMultidimensional(parallelOp.get()); |
89 | 105 | }
|
0 commit comments