Skip to content

Commit 654e8aa

Browse files
kumasento authored and bondhugula committed
[MLIR] Consider AffineIfOp when getting the index set of an Op wrapped in nested loops
This diff attempts to resolve the TODO in `getOpIndexSet` (formerly known as `getInstIndexSet`), which states "Add support to handle IfInsts surrounding `op`". Major changes in this diff: 1. Overload `getIndexSet`. The overloaded version considers both `AffineForOp` and `AffineIfOp`. 2. The `getInstIndexSet` is updated accordingly: its name is changed to `getOpIndexSet` and its implementation is based on a new API `getIVs` (named `getEnclosingAffineForAndIfOps` in the committed code) instead of `getLoopIVs`. 3. Add `addAffineIfOpDomain` to `FlatAffineConstraints`, which extracts new constraints from the integer set of `AffineIfOp` and merges them into the current constraint system. 4. Update how a `Value` is determined as dim or symbol for `ValuePositionMap` in `buildDimAndSymbolPositionMaps`. Differential Revision: https://reviews.llvm.org/D84698
1 parent d3153b5 commit 654e8aa

File tree

9 files changed

+365
-53
lines changed

9 files changed

+365
-53
lines changed

mlir/include/mlir/Analysis/AffineAnalysis.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@ void getReachableAffineApplyOps(ArrayRef<Value> operands,
3434
SmallVectorImpl<Operation *> &affineApplyOps);
3535

3636
/// Builds a system of constraints with dimensional identifiers corresponding to
37-
/// the loop IVs of the forOps appearing in that order. Bounds of the loop are
38-
/// used to add appropriate inequalities. Any symbols founds in the bound
39-
/// operands are added as symbols in the system. Returns failure for the yet
40-
/// unimplemented cases.
37+
/// the loop IVs of the forOps and AffineIfOp's operands appearing in
38+
/// that order. Bounds of the loop are used to add appropriate inequalities.
39+
/// Constraints from the index sets of AffineIfOp are also added. Any symbols
40+
/// found in the bound operands are added as symbols in the system. Returns
41+
/// failure for the yet unimplemented cases. `ops` accepts both AffineForOp and
42+
/// AffineIfOp.
4143
// TODO: handle non-unit strides.
42-
LogicalResult getIndexSet(MutableArrayRef<AffineForOp> forOps,
44+
LogicalResult getIndexSet(MutableArrayRef<Operation *> ops,
4345
FlatAffineConstraints *domain);
4446

4547
/// Encapsulates a memref load or store access information.

mlir/include/mlir/Analysis/AffineStructures.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace mlir {
2121

2222
class AffineCondition;
2323
class AffineForOp;
24+
class AffineIfOp;
2425
class AffineMap;
2526
class AffineValueMap;
2627
class IntegerSet;
@@ -215,6 +216,15 @@ class FlatAffineConstraints {
215216
// TODO: add support for non-unit strides.
216217
LogicalResult addAffineForOpDomain(AffineForOp forOp);
217218

219+
/// Adds constraints imposed by the `affine.if` operation. These constraints
220+
/// are collected from the IntegerSet attached to the given `affine.if`
221+
/// instance argument (`ifOp`). It is asserted that:
222+
/// 1) The IntegerSet of the given `affine.if` instance should not contain
223+
/// semi-affine expressions,
224+
/// 2) The columns of the constraint system created from `ifOp` should match
225+
/// the columns in the current one regarding numbers and values.
226+
void addAffineIfOpDomain(AffineIfOp ifOp);
227+
218228
/// Adds a lower or an upper bound for the identifier at the specified
219229
/// position with constraints being drawn from the specified bound map and
220230
/// operands. If `eq` is true, add a single equality equal to the bound map's

mlir/include/mlir/Analysis/Utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ class Value;
3939
// TODO: handle 'affine.if' ops.
4040
void getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops);
4141

42+
/// Populates 'ops' with the loops surrounding `op`, along with
43+
/// `affine.if` operations interleaved between these loops, ordered from the
44+
/// outermost `affine.for` or `affine.if` operation to the innermost one.
45+
void getEnclosingAffineForAndIfOps(Operation &op,
46+
SmallVectorImpl<Operation *> *ops);
47+
4248
/// Returns the nesting depth of this operation, i.e., the number of loops
4349
/// surrounding this operation.
4450
unsigned getNestingDepth(Operation *op);

mlir/lib/Analysis/AffineAnalysis.cpp

Lines changed: 141 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -85,33 +85,42 @@ void mlir::getReachableAffineApplyOps(
8585
// FlatAffineConstraints. (For eg., by using iv - lb % step = 0 and/or by
8686
// introducing a method in FlatAffineConstraints setExprStride(ArrayRef<int64_t>
8787
// expr, int64_t stride)
88-
LogicalResult mlir::getIndexSet(MutableArrayRef<AffineForOp> forOps,
88+
LogicalResult mlir::getIndexSet(MutableArrayRef<Operation *> ops,
8989
FlatAffineConstraints *domain) {
9090
SmallVector<Value, 4> indices;
91+
SmallVector<AffineForOp, 8> forOps;
92+
93+
for (Operation *op : ops) {
94+
assert((isa<AffineForOp, AffineIfOp>(op)) &&
95+
"ops should have either AffineForOp or AffineIfOp");
96+
if (AffineForOp forOp = dyn_cast<AffineForOp>(op))
97+
forOps.push_back(forOp);
98+
}
9199
extractForInductionVars(forOps, &indices);
92100
// Reset while associating Values in 'indices' to the domain.
93101
domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
94-
for (auto forOp : forOps) {
102+
for (Operation *op : ops) {
95103
// Add constraints from forOp's bounds.
96-
if (failed(domain->addAffineForOpDomain(forOp)))
97-
return failure();
104+
if (AffineForOp forOp = dyn_cast<AffineForOp>(op)) {
105+
if (failed(domain->addAffineForOpDomain(forOp)))
106+
return failure();
107+
} else if (AffineIfOp ifOp = dyn_cast<AffineIfOp>(op)) {
108+
domain->addAffineIfOpDomain(ifOp);
109+
}
98110
}
99111
return success();
100112
}
101113

102-
// Computes the iteration domain for 'opInst' and populates 'indexSet', which
103-
// encapsulates the constraints involving loops surrounding 'opInst' and
104-
// potentially involving any Function symbols. The dimensional identifiers in
105-
// 'indexSet' correspond to the loops surrounding 'op' from outermost to
106-
// innermost.
107-
// TODO: Add support to handle IfInsts surrounding 'op'.
108-
static LogicalResult getInstIndexSet(Operation *op,
109-
FlatAffineConstraints *indexSet) {
110-
// TODO: Extend this to gather enclosing IfInsts and consider
111-
// factoring it out into a utility function.
112-
SmallVector<AffineForOp, 4> loops;
113-
getLoopIVs(*op, &loops);
114-
return getIndexSet(loops, indexSet);
114+
/// Computes the iteration domain for 'op' and populates 'indexSet', which
115+
/// encapsulates the constraints involving loops surrounding 'op' and
116+
/// potentially involving any Function symbols. The dimensional identifiers in
117+
/// 'indexSet' correspond to the loops surrounding 'op' from outermost to
118+
/// innermost.
119+
static LogicalResult getOpIndexSet(Operation *op,
120+
FlatAffineConstraints *indexSet) {
121+
SmallVector<Operation *, 4> ops;
122+
getEnclosingAffineForAndIfOps(*op, &ops);
123+
return getIndexSet(ops, indexSet);
115124
}
116125

117126
namespace {
@@ -209,32 +218,83 @@ static void buildDimAndSymbolPositionMaps(
209218
const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
210219
const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap,
211220
FlatAffineConstraints *dependenceConstraints) {
212-
auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc) {
221+
222+
// IsDimState is a tri-state boolean. It is used to distinguish three
223+
// different cases of the values passed to updateValuePosMap.
224+
// - When it is TRUE, we are certain that all values are dim values.
225+
// - When it is FALSE, we are certain that all values are symbol values.
226+
// - When it is UNKNOWN, we need to further check whether the value is from a
227+
// loop IV to determine its type (dim or symbol).
228+
229+
// We need this enumeration because sometimes we cannot determine whether a
230+
// Value is a symbol or a dim by the information from the Value itself. If a
231+
// Value appears in an affine map of a loop, we can determine whether it is a
232+
// dim or not by the function `isForInductionVar`. But when a Value is in the
233+
// affine set of an if-statement, there is no way to identify its category
234+
// (dim/symbol) by itself. Fortunately, the Values to be inserted into
235+
// `valuePosMap` come from `srcDomain` and `dstDomain`, and they hold such
236+
// information of Value category: `srcDomain` and `dstDomain` organize Values
237+
// by their category, such that the position of each Value stored in
238+
// `srcDomain` and `dstDomain` marks which category that a Value belongs to.
239+
// Therefore, we can separate Values into dim and symbol groups before passing
240+
// them to the function `updateValuePosMap`. Specifically, when passing the
241+
// dim group, we set IsDimState to TRUE; otherwise, we set it to FALSE.
242+
// However, Values from the operands of `srcAccessMap` and `dstAccessMap` are
243+
// not explicitly categorized into dim or symbol, and we have to rely on
244+
// `isForInductionVar` to make the decision. IsDimState is set to UNKNOWN in
245+
// this case.
246+
enum IsDimState { TRUE, FALSE, UNKNOWN };
247+
248+
// This function places each given Value (in `values`) under a respective
249+
// category in `valuePosMap`. Specifically, the placement rules are:
250+
// 1) If `isDim` is FALSE, then all values in `values` are inserted into
251+
// `valuePosMap` as symbols.
252+
// 2) If `isDim` is UNKNOWN and the value of the current iteration is NOT an
253+
// induction variable of a for-loop, we treat it as symbol as well.
254+
// 3) For other cases, we decide whether to add a value to the `src` or the
255+
// `dst` section of the dim category simply by the boolean value `isSrc`.
256+
auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc,
257+
IsDimState isDim) {
213258
for (unsigned i = 0, e = values.size(); i < e; ++i) {
214259
auto value = values[i];
215-
if (!isForInductionVar(values[i])) {
216-
assert(isValidSymbol(values[i]) &&
260+
if (isDim == FALSE || (isDim == UNKNOWN && !isForInductionVar(value))) {
261+
assert(isValidSymbol(value) &&
217262
"access operand has to be either a loop IV or a symbol");
218263
valuePosMap->addSymbolValue(value);
219-
} else if (isSrc) {
220-
valuePosMap->addSrcValue(value);
221264
} else {
222-
valuePosMap->addDstValue(value);
265+
if (isSrc)
266+
valuePosMap->addSrcValue(value);
267+
else
268+
valuePosMap->addDstValue(value);
223269
}
224270
}
225271
};
226272

227-
SmallVector<Value, 4> srcValues, destValues;
228-
srcDomain.getIdValues(0, srcDomain.getNumDimAndSymbolIds(), &srcValues);
229-
dstDomain.getIdValues(0, dstDomain.getNumDimAndSymbolIds(), &destValues);
230-
// Update value position map with identifiers from src iteration domain.
231-
updateValuePosMap(srcValues, /*isSrc=*/true);
232-
// Update value position map with identifiers from dst iteration domain.
233-
updateValuePosMap(destValues, /*isSrc=*/false);
273+
// Collect values from the src and dst domains. For each domain, we separate
274+
// the collected values into dim and symbol parts.
275+
SmallVector<Value, 4> srcDimValues, dstDimValues, srcSymbolValues,
276+
dstSymbolValues;
277+
srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcDimValues);
278+
dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstDimValues);
279+
srcDomain.getIdValues(srcDomain.getNumDimIds(),
280+
srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
281+
dstDomain.getIdValues(dstDomain.getNumDimIds(),
282+
dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
283+
284+
// Update value position map with dim values from src iteration domain.
285+
updateValuePosMap(srcDimValues, /*isSrc=*/true, /*isDim=*/TRUE);
286+
// Update value position map with dim values from dst iteration domain.
287+
updateValuePosMap(dstDimValues, /*isSrc=*/false, /*isDim=*/TRUE);
288+
// Update value position map with symbols from src iteration domain.
289+
updateValuePosMap(srcSymbolValues, /*isSrc=*/true, /*isDim=*/FALSE);
290+
// Update value position map with symbols from dst iteration domain.
291+
updateValuePosMap(dstSymbolValues, /*isSrc=*/false, /*isDim=*/FALSE);
234292
// Update value position map with identifiers from src access function.
235-
updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true);
293+
updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true,
294+
/*isDim=*/UNKNOWN);
236295
// Update value position map with identifiers from dst access function.
237-
updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false);
296+
updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false,
297+
/*isDim=*/UNKNOWN);
238298
}
239299

240300
// Sets up dependence constraints columns appropriately, in the format:
@@ -270,24 +330,33 @@ static void initDependenceConstraints(
270330
dependenceConstraints->setIdValues(
271331
srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs);
272332

273-
// Set values for the symbolic identifier dimensions.
274-
auto setSymbolIds = [&](ArrayRef<Value> values) {
333+
// Set values for the symbolic identifier dimensions. `isSymbolDetermined`
334+
// indicates whether we are certain that the `values` passed in are all
335+
// symbols. If `isSymbolDetermined` is true, then we treat every Value in
336+
// `values` as a symbol; otherwise, we let the function `isForInductionVar`
337+
// distinguish whether a Value in `values` is a symbol or not.
338+
auto setSymbolIds = [&](ArrayRef<Value> values,
339+
bool isSymbolDetermined = true) {
275340
for (auto value : values) {
276-
if (!isForInductionVar(value)) {
341+
if (isSymbolDetermined || !isForInductionVar(value)) {
277342
assert(isValidSymbol(value) && "expected symbol");
278343
dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value);
279344
}
280345
}
281346
};
282347

283-
setSymbolIds(srcAccessMap.getOperands());
284-
setSymbolIds(dstAccessMap.getOperands());
348+
// We are uncertain about whether all operands in `srcAccessMap` and
349+
// `dstAccessMap` are symbols, so we set `isSymbolDetermined` to false.
350+
setSymbolIds(srcAccessMap.getOperands(), /*isSymbolDetermined=*/false);
351+
setSymbolIds(dstAccessMap.getOperands(), /*isSymbolDetermined=*/false);
285352

286353
SmallVector<Value, 8> srcSymbolValues, dstSymbolValues;
287354
srcDomain.getIdValues(srcDomain.getNumDimIds(),
288355
srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
289356
dstDomain.getIdValues(dstDomain.getNumDimIds(),
290357
dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
358+
// Since we only take symbol Values out of `srcDomain` and `dstDomain`,
359+
// `isSymbolDetermined` is kept at its default value: true.
291360
setSymbolIds(srcSymbolValues);
292361
setSymbolIds(dstSymbolValues);
293362

@@ -530,22 +599,50 @@ getNumCommonLoops(const FlatAffineConstraints &srcDomain,
530599
return numCommonLoops;
531600
}
532601

533-
// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
602+
/// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
534603
static Block *getCommonBlock(const MemRefAccess &srcAccess,
535604
const MemRefAccess &dstAccess,
536605
const FlatAffineConstraints &srcDomain,
537606
unsigned numCommonLoops) {
607+
// Get the chain of ancestor blocks to the given `MemRefAccess` instance. The
608+
// search terminates when either an op with the `AffineScope` trait or
609+
// `endBlock` is reached.
610+
auto getChainOfAncestorBlocks = [&](const MemRefAccess &access,
611+
SmallVector<Block *, 4> &ancestorBlocks,
612+
Block *endBlock = nullptr) {
613+
Block *currBlock = access.opInst->getBlock();
614+
// Loop terminates when the currBlock is nullptr or equals the endBlock,
615+
// or its parent operation holds an affine scope.
616+
while (currBlock && currBlock != endBlock &&
617+
!currBlock->getParentOp()->hasTrait<OpTrait::AffineScope>()) {
618+
ancestorBlocks.push_back(currBlock);
619+
currBlock = currBlock->getParentOp()->getBlock();
620+
}
621+
};
622+
538623
if (numCommonLoops == 0) {
539-
auto *block = srcAccess.opInst->getBlock();
624+
Block *block = srcAccess.opInst->getBlock();
540625
while (!llvm::isa<FuncOp>(block->getParentOp())) {
541626
block = block->getParentOp()->getBlock();
542627
}
543628
return block;
544629
}
545-
auto commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
546-
auto forOp = getForInductionVarOwner(commonForValue);
630+
Value commonForIV = srcDomain.getIdValue(numCommonLoops - 1);
631+
AffineForOp forOp = getForInductionVarOwner(commonForIV);
547632
assert(forOp && "commonForValue was not an induction variable");
548-
return forOp.getBody();
633+
634+
// Find the closest common block including those in AffineIf.
635+
SmallVector<Block *, 4> srcAncestorBlocks, dstAncestorBlocks;
636+
getChainOfAncestorBlocks(srcAccess, srcAncestorBlocks, forOp.getBody());
637+
getChainOfAncestorBlocks(dstAccess, dstAncestorBlocks, forOp.getBody());
638+
639+
Block *commonBlock = forOp.getBody();
640+
for (int i = srcAncestorBlocks.size() - 1, j = dstAncestorBlocks.size() - 1;
641+
i >= 0 && j >= 0 && srcAncestorBlocks[i] == dstAncestorBlocks[j];
642+
i--, j--)
643+
commonBlock = srcAncestorBlocks[i];
644+
645+
return commonBlock;
549646
}
550647

551648
// Returns true if the ancestor operation of 'srcAccess' appears before the
@@ -788,12 +885,12 @@ DependenceResult mlir::checkMemrefAccessDependence(
788885

789886
// Get iteration domain for the 'srcAccess' operation.
790887
FlatAffineConstraints srcDomain;
791-
if (failed(getInstIndexSet(srcAccess.opInst, &srcDomain)))
888+
if (failed(getOpIndexSet(srcAccess.opInst, &srcDomain)))
792889
return DependenceResult::Failure;
793890

794891
// Get iteration domain for 'dstAccess' operation.
795892
FlatAffineConstraints dstDomain;
796-
if (failed(getInstIndexSet(dstAccess.opInst, &dstDomain)))
893+
if (failed(getOpIndexSet(dstAccess.opInst, &dstDomain)))
797894
return DependenceResult::Failure;
798895

799896
// Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor
@@ -814,7 +911,6 @@ DependenceResult mlir::checkMemrefAccessDependence(
814911
buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
815912
dstAccessMap, &valuePosMap,
816913
dependenceConstraints);
817-
818914
initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap,
819915
valuePosMap, dependenceConstraints);
820916

mlir/lib/Analysis/AffineStructures.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,18 @@ LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
723723
/*eq=*/false, /*lower=*/false);
724724
}
725725

726+
void FlatAffineConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
727+
// Create the base constraints from the integer set attached to ifOp.
728+
FlatAffineConstraints cst(ifOp.getIntegerSet());
729+
730+
// Bind ids in the constraints to ifOp operands.
731+
SmallVector<Value, 4> operands = ifOp.getOperands();
732+
cst.setIdValues(0, cst.getNumDimAndSymbolIds(), operands);
733+
734+
// Merge the constraints from ifOp to the current domain.
735+
mergeAndAlignIdsWithOther(0, &cst);
736+
}
737+
726738
// Searches for a constraint with a non-zero coefficient at 'colIdx' in
727739
// equality (isEq=true) or inequality (isEq=false) constraints.
728740
// Returns true and sets row found in search in 'rowIdx'.

mlir/lib/Analysis/Utils.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,23 @@ void mlir::getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops) {
4444
std::reverse(loops->begin(), loops->end());
4545
}
4646

47+
/// Populates 'ops' with the loops surrounding `op`, along with
48+
/// `affine.if` operations interleaved between these loops, ordered from the
49+
/// outermost `affine.for` or `affine.if` operation to the innermost one.
50+
void mlir::getEnclosingAffineForAndIfOps(Operation &op,
51+
SmallVectorImpl<Operation *> *ops) {
52+
ops->clear();
53+
Operation *currOp = op.getParentOp();
54+
55+
// Traverse up the hierarchy collecting all `affine.for` and `affine.if`
56+
// operations.
57+
while (currOp && (isa<AffineIfOp, AffineForOp>(currOp))) {
58+
ops->push_back(currOp);
59+
currOp = currOp->getParentOp();
60+
}
61+
std::reverse(ops->begin(), ops->end());
62+
}
63+
4764
// Populates 'cst' with FlatAffineConstraints which represent slice bounds.
4865
LogicalResult
4966
ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {

mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,11 @@ mlir::tilePerfectlyNested(MutableArrayRef<AffineForOp> input,
289289
extractForInductionVars(input, &origLoopIVs);
290290

291291
FlatAffineConstraints cst;
292-
getIndexSet(input, &cst);
292+
SmallVector<Operation *, 8> ops;
293+
ops.reserve(input.size());
294+
for (AffineForOp forOp : input)
295+
ops.push_back(forOp);
296+
getIndexSet(ops, &cst);
293297
if (!cst.isHyperRectangular(0, width)) {
294298
rootAffineForOp.emitError("tiled code generation unimplemented for the "
295299
"non-hyperrectangular case");

0 commit comments

Comments
 (0)