Skip to content

Commit dc53715

Browse files
Vinayaka Bandishtibondhugula
authored andcommitted
[MLIR][Affine] Add utility to check if the slice is valid
Fixes a bug in affine fusion pipeline where an incorrect slice is computed. After the slice computation is done, original domain of the the source is compared with the new domain that will result if the fusion succeeds. If the new domain must be a subset of the original domain for the slice to be valid. If the slice computed is incorrect, fusion based on such a slice is avoided. Relevant test cases are added/edited. Fixes https://bugs.llvm.org/show_bug.cgi?id=49203 Differential Revision: https://reviews.llvm.org/D98239
1 parent b468f0e commit dc53715

File tree

8 files changed

+281
-61
lines changed

8 files changed

+281
-61
lines changed

mlir/include/mlir/Analysis/Utils.h

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,18 @@ unsigned getNestingDepth(Operation *op);
5454
void getSequentialLoops(AffineForOp forOp,
5555
llvm::SmallDenseSet<Value, 8> *sequentialLoops);
5656

57+
/// Enumerates different result statuses of slice computation by
58+
/// `computeSliceUnion`
59+
// TODO: Identify and add different kinds of failures during slice computation.
60+
struct SliceComputationResult {
61+
enum ResultEnum {
62+
Success,
63+
IncorrectSliceFailure, // Slice is computed, but it is incorrect.
64+
GenericFailure, // Unable to compute src loop computation slice.
65+
} value;
66+
SliceComputationResult(ResultEnum v) : value(v) {}
67+
};
68+
5769
/// ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their
5870
/// associated operands for a set of loops within a loop nest (typically the
5971
/// set of loops surrounding a store operation). Loop bound AffineMaps which
@@ -80,6 +92,12 @@ struct ComputationSliceState {
8092
// Returns failure if we cannot add loop bounds because of unsupported cases.
8193
LogicalResult getAsConstraints(FlatAffineConstraints *cst);
8294

95+
/// Adds to 'cst' constraints which represent the original loop bounds on
96+
/// 'ivs' in 'this'. This corresponds to the original domain of the loop nest
97+
/// from which the slice is being computed. Returns failure if we cannot add
98+
/// loop bounds because of unsupported cases.
99+
LogicalResult getSourceAsConstraints(FlatAffineConstraints &cst);
100+
83101
// Clears all bounds and operands in slice state.
84102
void clearBounds();
85103

@@ -93,6 +111,22 @@ struct ComputationSliceState {
93111
// information hasn't changed.
94112
Optional<bool> isMaximal() const;
95113

114+
/// Checks the validity of the slice computed. This is done using the
115+
/// following steps:
116+
/// 1. Get the new domain of the slice that would be created if fusion
117+
/// succeeds. This domain gets constructed with source loop IVS and
118+
/// destination loop IVS as dimensions.
119+
/// 2. Project out the dimensions of the destination loop from the domain
120+
/// above calculated in step(1) to express it purely in terms of the source
121+
/// loop IVs.
122+
/// 3. Calculate a set difference between the iterations of the new domain and
123+
/// the original domain of the source loop.
124+
/// If this difference is empty, the slice is declared to be valid. Otherwise,
125+
/// return false as it implies that the effective fusion results in at least
126+
/// one iteration of the slice that was not originally in the source's domain.
127+
/// If the validity cannot be determined, returns llvm:None.
128+
Optional<bool> isSliceValid();
129+
96130
void dump() const;
97131

98132
private:
@@ -151,21 +185,21 @@ void getComputationSliceState(Operation *depSourceOp, Operation *depSinkOp,
151185
ComputationSliceState *sliceState);
152186

153187
/// Computes in 'sliceUnion' the union of all slice bounds computed at
154-
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'.
155-
/// The parameter 'numCommonLoops' is the number of loops common to the
156-
/// operations in 'opsA' and 'opsB'.
157-
/// If 'isBackwardSlice' is true, computes slice bounds for loop nest
158-
/// surrounding ops in 'opsA', as a function of IVs and symbols of loop nest
159-
/// surrounding ops in 'opsB' at 'loopDepth'.
160-
/// If 'isBackwardSlice' is false, computes slice bounds for loop nest
161-
/// surrounding ops in 'opsB', as a function of IVs and symbols of loop nest
162-
/// surrounding ops in 'opsA' at 'loopDepth'.
163-
/// Returns 'success' if union was computed, 'failure' otherwise.
188+
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
189+
/// then verifies if it is valid. The parameter 'numCommonLoops' is the number
190+
/// of loops common to the operations in 'opsA' and 'opsB'. If 'isBackwardSlice'
191+
/// is true, computes slice bounds for loop nest surrounding ops in 'opsA', as a
192+
/// function of IVs and symbols of loop nest surrounding ops in 'opsB' at
193+
/// 'loopDepth'. If 'isBackwardSlice' is false, computes slice bounds for loop
194+
/// nest surrounding ops in 'opsB', as a function of IVs and symbols of loop
195+
/// nest surrounding ops in 'opsA' at 'loopDepth'. Returns
196+
/// 'SliceComputationResult::Success' if union was computed correctly, an
197+
/// appropriate 'failure' otherwise.
164198
// TODO: Change this API to take 'forOpA'/'forOpB'.
165-
LogicalResult computeSliceUnion(ArrayRef<Operation *> opsA,
166-
ArrayRef<Operation *> opsB, unsigned loopDepth,
167-
unsigned numCommonLoops, bool isBackwardSlice,
168-
ComputationSliceState *sliceUnion);
199+
SliceComputationResult
200+
computeSliceUnion(ArrayRef<Operation *> opsA, ArrayRef<Operation *> opsB,
201+
unsigned loopDepth, unsigned numCommonLoops,
202+
bool isBackwardSlice, ComputationSliceState *sliceUnion);
169203

170204
/// Creates a clone of the computation contained in the loop nest surrounding
171205
/// 'srcOpInst', slices the iteration space of src loop based on slice bounds

mlir/include/mlir/Transforms/LoopFusionUtils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ struct FusionResult {
3535
FailBlockDependence, // Fusion would violate another dependence in block.
3636
FailFusionDependence, // Fusion would reverse dependences between loops.
3737
FailComputationSlice, // Unable to compute src loop computation slice.
38+
FailIncorrectSlice, // Slice is computed, but it is incorrect.
3839
} value;
3940
FusionResult(ResultEnum v) : value(v) {}
4041
};

mlir/lib/Analysis/AffineStructures.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2128,13 +2128,22 @@ LogicalResult FlatAffineConstraints::addSliceBounds(ArrayRef<Value> values,
21282128
continue;
21292129
}
21302130

2131-
if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
2132-
/*lower=*/true)))
2133-
return failure();
2134-
2135-
if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
2136-
/*lower=*/false)))
2137-
return failure();
2131+
// If lower or upper bound maps are null or provide no results, it implies
2132+
// that the source loop was not at all sliced, and the entire loop will be a
2133+
// part of the slice.
2134+
if (lbMap && lbMap.getNumResults() != 0 && ubMap &&
2135+
ubMap.getNumResults() != 0) {
2136+
if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
2137+
/*lower=*/true)))
2138+
return failure();
2139+
if (failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
2140+
/*lower=*/false)))
2141+
return failure();
2142+
} else {
2143+
auto loop = getForInductionVarOwner(values[i]);
2144+
if (failed(this->addAffineForOpDomain(loop)))
2145+
return failure();
2146+
}
21382147
}
21392148
return success();
21402149
}

mlir/lib/Analysis/Utils.cpp

Lines changed: 118 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,21 @@ void mlir::getEnclosingAffineForAndIfOps(Operation &op,
6161
std::reverse(ops->begin(), ops->end());
6262
}
6363

64+
// Populates 'cst' with FlatAffineConstraints which represent original domain of
65+
// the loop bounds that define 'ivs'.
66+
LogicalResult
67+
ComputationSliceState::getSourceAsConstraints(FlatAffineConstraints &cst) {
68+
assert(!ivs.empty() && "Cannot have a slice without its IVs");
69+
cst.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0, /*numLocals=*/0, ivs);
70+
for (Value iv : ivs) {
71+
AffineForOp loop = getForInductionVarOwner(iv);
72+
assert(loop && "Expected affine for");
73+
if (failed(cst.addAffineForOpDomain(loop)))
74+
return failure();
75+
}
76+
return success();
77+
}
78+
6479
// Populates 'cst' with FlatAffineConstraints which represent slice bounds.
6580
LogicalResult
6681
ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
@@ -75,9 +90,10 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
7590
values.append(lbOperands[0].begin(), lbOperands[0].end());
7691
cst->reset(numDims, numSymbols, 0, values);
7792

78-
// Add loop bound constraints for values which are loop IVs and equality
79-
// constraints for symbols which are constants.
80-
for (const auto &value : values) {
93+
// Add loop bound constraints for values which are loop IVs of the destination
94+
// of fusion and equality constraints for symbols which are constants.
95+
for (unsigned i = numDims, end = values.size(); i < end; ++i) {
96+
Value value = values[i];
8197
assert(cst->containsId(value) && "value expected to be present");
8298
if (isValidSymbol(value)) {
8399
// Check if the symbol is a constant.
@@ -196,6 +212,76 @@ Optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
196212
return true;
197213
}
198214

215+
/// Returns true if it is deterministically verified that the original iteration
216+
/// space of the slice is contained within the new iteration space that is
217+
/// created after fusing 'this' slice into its destination.
218+
Optional<bool> ComputationSliceState::isSliceValid() {
219+
// Fast check to determine if the slice is valid. If the following conditions
220+
// are verified to be true, slice is declared valid by the fast check:
221+
// 1. Each slice loop is a single iteration loop bound in terms of a single
222+
// destination loop IV.
223+
// 2. Loop bounds of the destination loop IV (from above) and those of the
224+
// source loop IV are exactly the same.
225+
// If the fast check is inconclusive or false, we proceed with a more
226+
// expensive analysis.
227+
// TODO: Store the result of the fast check, as it might be used again in
228+
// `canRemoveSrcNodeAfterFusion`.
229+
Optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
230+
if (isValidFastCheck.hasValue() && isValidFastCheck.getValue())
231+
return true;
232+
233+
// Create constraints for the source loop nest using which slice is computed.
234+
FlatAffineConstraints srcConstraints;
235+
// TODO: Store the source's domain to avoid computation at each depth.
236+
if (failed(getSourceAsConstraints(srcConstraints))) {
237+
LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
238+
return llvm::None;
239+
}
240+
// As the set difference utility currently cannot handle symbols in its
241+
// operands, validity of the slice cannot be determined.
242+
if (srcConstraints.getNumSymbolIds() > 0) {
243+
LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
244+
return llvm::None;
245+
}
246+
// TODO: Handle local ids in the source domains while using the 'projectOut'
247+
// utility below. Currently, aligning is not done assuming that there will be
248+
// no local ids in the source domain.
249+
if (srcConstraints.getNumLocalIds() != 0) {
250+
LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
251+
return llvm::None;
252+
}
253+
254+
// Create constraints for the slice loop nest that would be created if the
255+
// fusion succeeds.
256+
FlatAffineConstraints sliceConstraints;
257+
if (failed(getAsConstraints(&sliceConstraints))) {
258+
LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
259+
return llvm::None;
260+
}
261+
262+
// Projecting out every dimension other than the 'ivs' to express slice's
263+
// domain completely in terms of source's IVs.
264+
sliceConstraints.projectOut(ivs.size(),
265+
sliceConstraints.getNumIds() - ivs.size());
266+
267+
LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
268+
LLVM_DEBUG(srcConstraints.dump());
269+
LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
270+
"(expressed in terms of its source's IVs):\n");
271+
LLVM_DEBUG(sliceConstraints.dump());
272+
273+
// TODO: Store 'srcSet' to avoid recalculating for each depth.
274+
PresburgerSet srcSet(srcConstraints);
275+
PresburgerSet sliceSet(sliceConstraints);
276+
PresburgerSet diffSet = sliceSet.subtract(srcSet);
277+
278+
if (!diffSet.isIntegerEmpty()) {
279+
LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
280+
return false;
281+
}
282+
return true;
283+
}
284+
199285
/// Returns true if the computation slice encloses all the iterations of the
200286
/// sliced loop nest. Returns false if it does not. Returns llvm::None if it
201287
/// cannot determine if the slice is maximal or not.
@@ -715,14 +801,14 @@ unsigned mlir::getInnermostCommonLoopDepth(
715801
}
716802

717803
/// Computes in 'sliceUnion' the union of all slice bounds computed at
718-
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'.
719-
/// Returns 'Success' if union was computed, 'failure' otherwise.
720-
LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
721-
ArrayRef<Operation *> opsB,
722-
unsigned loopDepth,
723-
unsigned numCommonLoops,
724-
bool isBackwardSlice,
725-
ComputationSliceState *sliceUnion) {
804+
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
805+
/// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
806+
/// union was computed correctly, an appropriate failure otherwise.
807+
SliceComputationResult
808+
mlir::computeSliceUnion(ArrayRef<Operation *> opsA, ArrayRef<Operation *> opsB,
809+
unsigned loopDepth, unsigned numCommonLoops,
810+
bool isBackwardSlice,
811+
ComputationSliceState *sliceUnion) {
726812
// Compute the union of slice bounds between all pairs in 'opsA' and
727813
// 'opsB' in 'sliceUnionCst'.
728814
FlatAffineConstraints sliceUnionCst;
@@ -738,7 +824,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
738824
if ((!isBackwardSlice && loopDepth > getNestingDepth(opsA[i])) ||
739825
(isBackwardSlice && loopDepth > getNestingDepth(opsB[j]))) {
740826
LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
741-
return failure();
827+
return SliceComputationResult::GenericFailure;
742828
}
743829

744830
bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
@@ -751,7 +837,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
751837
/*allowRAR=*/readReadAccesses);
752838
if (result.value == DependenceResult::Failure) {
753839
LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
754-
return failure();
840+
return SliceComputationResult::GenericFailure;
755841
}
756842
if (result.value == DependenceResult::NoDependence)
757843
continue;
@@ -768,7 +854,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
768854
if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
769855
LLVM_DEBUG(llvm::dbgs()
770856
<< "Unable to compute slice bound constraints\n");
771-
return failure();
857+
return SliceComputationResult::GenericFailure;
772858
}
773859
assert(sliceUnionCst.getNumDimAndSymbolIds() > 0);
774860
continue;
@@ -779,7 +865,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
779865
if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
780866
LLVM_DEBUG(llvm::dbgs()
781867
<< "Unable to compute slice bound constraints\n");
782-
return failure();
868+
return SliceComputationResult::GenericFailure;
783869
}
784870

785871
// Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
@@ -802,24 +888,24 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
802888
// to unionBoundingBox below expects constraints for each Loop IV, even
803889
// if they are the unsliced full loop bounds added here.
804890
if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
805-
return failure();
891+
return SliceComputationResult::GenericFailure;
806892
if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
807-
return failure();
893+
return SliceComputationResult::GenericFailure;
808894
}
809895
// Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
810896
if (sliceUnionCst.getNumLocalIds() > 0 ||
811897
tmpSliceCst.getNumLocalIds() > 0 ||
812898
failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
813899
LLVM_DEBUG(llvm::dbgs()
814900
<< "Unable to compute union bounding box of slice bounds\n");
815-
return failure();
901+
return SliceComputationResult::GenericFailure;
816902
}
817903
}
818904
}
819905

820906
// Empty union.
821907
if (sliceUnionCst.getNumDimAndSymbolIds() == 0)
822-
return failure();
908+
return SliceComputationResult::GenericFailure;
823909

824910
// Gather loops surrounding ops from loop nest where slice will be inserted.
825911
SmallVector<Operation *, 4> ops;
@@ -831,7 +917,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
831917
getInnermostCommonLoopDepth(ops, &surroundingLoops);
832918
if (loopDepth > innermostCommonLoopDepth) {
833919
LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
834-
return failure();
920+
return SliceComputationResult::GenericFailure;
835921
}
836922

837923
// Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
@@ -868,7 +954,18 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
868954
// canonicalization.
869955
sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
870956
sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
871-
return success();
957+
958+
// Check if the slice computed is valid. Return success only if it is verified
959+
// that the slice is valid, otherwise return appropriate failure status.
960+
Optional<bool> isSliceValid = sliceUnion->isSliceValid();
961+
if (!isSliceValid.hasValue()) {
962+
LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
963+
return SliceComputationResult::GenericFailure;
964+
}
965+
if (!isSliceValid.getValue())
966+
return SliceComputationResult::IncorrectSliceFailure;
967+
968+
return SliceComputationResult::Success;
872969
}
873970

874971
const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";

0 commit comments

Comments
 (0)