Skip to content

Commit 3da843b

Browse files
Max191 and Max Dawkins authored
[mlir] Add ValueBoundsOpInterfaceImpl for scf.forall (#118817)
Adds a ValueBoundsOpInterface implementation for scf.forall ops. The implementation supports bounding the induction variables, results, and block args of the forall op. Induction variables are bounded below and above by the loop's lower and upper bounds, and dimensions of the results and init block arguments are constrained to be equal to the matching dims of the shared_outs operand. Signed-off-by: Max Dawkins <[email protected]> Co-authored-by: Max Dawkins <[email protected]>
1 parent cc46d0b commit 3da843b

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,47 @@ struct ForOpInterface
9595
}
9696
};
9797

98+
struct ForallOpInterface
99+
: public ValueBoundsOpInterface::ExternalModel<ForallOpInterface,
100+
ForallOp> {
101+
102+
void populateBoundsForIndexValue(Operation *op, Value value,
103+
ValueBoundsConstraintSet &cstr) const {
104+
auto forallOp = cast<ForallOp>(op);
105+
106+
// Index values should be induction variables, since the semantics of
107+
// tensor::ParallelInsertSliceOp requires forall outputs to be ranked
108+
// tensors.
109+
auto blockArg = cast<BlockArgument>(value);
110+
assert(blockArg.getArgNumber() < forallOp.getInductionVars().size() &&
111+
"expected index value to be an induction var");
112+
int64_t idx = blockArg.getArgNumber();
113+
// TODO: Take into account step size.
114+
AffineExpr lb = cstr.getExpr(forallOp.getMixedLowerBound()[idx]);
115+
AffineExpr ub = cstr.getExpr(forallOp.getMixedUpperBound()[idx]);
116+
cstr.bound(value) >= lb;
117+
cstr.bound(value) < ub;
118+
}
119+
120+
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
121+
ValueBoundsConstraintSet &cstr) const {
122+
auto forallOp = cast<ForallOp>(op);
123+
124+
// `value` is an iter_arg or an OpResult.
125+
int64_t iterArgIdx;
126+
if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
127+
iterArgIdx = iterArg.getArgNumber() - forallOp.getInductionVars().size();
128+
} else {
129+
iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
130+
}
131+
132+
// The forall results and output arguments have the same sizes as the output
133+
// operands.
134+
Value outputOperand = forallOp.getOutputs()[iterArgIdx];
135+
cstr.bound(value)[dim] == cstr.getExpr(outputOperand, dim);
136+
}
137+
};
138+
98139
struct IfOpInterface
99140
: public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
100141

@@ -161,6 +202,7 @@ void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
161202
DialectRegistry &registry) {
162203
registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
163204
scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
205+
scf::ForallOp::attachInterface<scf::ForallOpInterface>(*ctx);
164206
scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
165207
});
166208
}

mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,42 @@ func.func @scf_for_swapping_yield(%t1: tensor<?xf32>, %t2: tensor<?xf32>, %a: in
107107

108108
// -----
109109

110+
// Reifies the lower and upper bound of an scf.forall induction variable and
// expects them to be the loop's lower bound %a and upper bound %b. The step
// %c does not appear in the expected bounds.
// CHECK-LABEL: func @scf_forall(
111+
// CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
112+
// CHECK: "test.some_use"(%[[a]], %[[b]])
113+
func.func @scf_forall(%a: index, %b: index, %c: index) {
114+
scf.forall (%iv) = (%a) to (%b) step (%c) {
115+
// %0 is the reified lower bound of %iv, %1 the reified upper bound.
%0 = "test.reify_bound"(%iv) {type = "LB"} : (index) -> (index)
116+
%1 = "test.reify_bound"(%iv) {type = "UB"} : (index) -> (index)
117+
"test.some_use"(%0, %1) : (index, index) -> ()
118+
}
119+
return
120+
}
121+
122+
// -----
123+
124+
// Reifies dim 0 of both the shared_outs block argument and the result of an
// scf.forall whose output operand is tensor.empty(%size), and expects each
// reified size to be %size.
// CHECK-LABEL: func @scf_forall_tensor_result(
125+
// CHECK-SAME: %[[size:.*]]: index, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
126+
// CHECK: "test.some_use"(%[[size]])
127+
// CHECK: "test.some_use"(%[[size]])
128+
func.func @scf_forall_tensor_result(%size: index, %a: index, %b: index, %c: index) {
129+
%cst = arith.constant 5.0 : f32
130+
%empty = tensor.empty(%size) : tensor<?xf32>
131+
%0 = scf.forall (%iv) = (%a) to (%b) step (%c) shared_outs(%arg = %empty) -> tensor<?xf32> {
132+
%filled = linalg.fill ins(%cst : f32) outs(%arg : tensor<?xf32>) -> tensor<?xf32>
133+
// Query the size of the output block argument inside the loop body.
%1 = "test.reify_bound"(%arg) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
134+
"test.some_use"(%1) : (index) -> ()
135+
scf.forall.in_parallel {
136+
tensor.parallel_insert_slice %filled into %arg[0][%size][1] : tensor<?xf32> into tensor<?xf32>
137+
}
138+
}
139+
// Query the size of the loop result after the loop.
%2 = "test.reify_bound"(%0) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
140+
"test.some_use"(%2) : (index) -> ()
141+
return
142+
}
143+
144+
// -----
145+
110146
// CHECK-LABEL: func @scf_if_constant(
111147
func.func @scf_if_constant(%c : i1) {
112148
// CHECK: arith.constant 4 : index

0 commit comments

Comments
 (0)