@@ -61,6 +61,21 @@ void mlir::getEnclosingAffineForAndIfOps(Operation &op,
61
61
std::reverse (ops->begin (), ops->end ());
62
62
}
63
63
64
+ // Populates 'cst' with FlatAffineConstraints which represent original domain of
65
+ // the loop bounds that define 'ivs'.
66
+ LogicalResult
67
+ ComputationSliceState::getSourceAsConstraints (FlatAffineConstraints &cst) {
68
+ assert (!ivs.empty () && " Cannot have a slice without its IVs" );
69
+ cst.reset (/* numDims=*/ ivs.size (), /* numSymbols=*/ 0 , /* numLocals=*/ 0 , ivs);
70
+ for (Value iv : ivs) {
71
+ AffineForOp loop = getForInductionVarOwner (iv);
72
+ assert (loop && " Expected affine for" );
73
+ if (failed (cst.addAffineForOpDomain (loop)))
74
+ return failure ();
75
+ }
76
+ return success ();
77
+ }
78
+
64
79
// Populates 'cst' with FlatAffineConstraints which represent slice bounds.
65
80
LogicalResult
66
81
ComputationSliceState::getAsConstraints (FlatAffineConstraints *cst) {
@@ -75,9 +90,10 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
75
90
values.append (lbOperands[0 ].begin (), lbOperands[0 ].end ());
76
91
cst->reset (numDims, numSymbols, 0 , values);
77
92
78
- // Add loop bound constraints for values which are loop IVs and equality
79
- // constraints for symbols which are constants.
80
- for (const auto &value : values) {
93
+ // Add loop bound constraints for values which are loop IVs of the destination
94
+ // of fusion and equality constraints for symbols which are constants.
95
+ for (unsigned i = numDims, end = values.size (); i < end; ++i) {
96
+ Value value = values[i];
81
97
assert (cst->containsId (value) && " value expected to be present" );
82
98
if (isValidSymbol (value)) {
83
99
// Check if the symbol is a constant.
@@ -196,6 +212,76 @@ Optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
196
212
return true ;
197
213
}
198
214
215
+ // / Returns true if it is deterministically verified that the original iteration
216
+ // / space of the slice is contained within the new iteration space that is
217
+ // / created after fusing 'this' slice into its destination.
218
+ Optional<bool > ComputationSliceState::isSliceValid () {
219
+ // Fast check to determine if the slice is valid. If the following conditions
220
+ // are verified to be true, slice is declared valid by the fast check:
221
+ // 1. Each slice loop is a single iteration loop bound in terms of a single
222
+ // destination loop IV.
223
+ // 2. Loop bounds of the destination loop IV (from above) and those of the
224
+ // source loop IV are exactly the same.
225
+ // If the fast check is inconclusive or false, we proceed with a more
226
+ // expensive analysis.
227
+ // TODO: Store the result of the fast check, as it might be used again in
228
+ // `canRemoveSrcNodeAfterFusion`.
229
+ Optional<bool > isValidFastCheck = isSliceMaximalFastCheck ();
230
+ if (isValidFastCheck.hasValue () && isValidFastCheck.getValue ())
231
+ return true ;
232
+
233
+ // Create constraints for the source loop nest using which slice is computed.
234
+ FlatAffineConstraints srcConstraints;
235
+ // TODO: Store the source's domain to avoid computation at each depth.
236
+ if (failed (getSourceAsConstraints (srcConstraints))) {
237
+ LLVM_DEBUG (llvm::dbgs () << " Unable to compute source's domain\n " );
238
+ return llvm::None;
239
+ }
240
+ // As the set difference utility currently cannot handle symbols in its
241
+ // operands, validity of the slice cannot be determined.
242
+ if (srcConstraints.getNumSymbolIds () > 0 ) {
243
+ LLVM_DEBUG (llvm::dbgs () << " Cannot handle symbols in source domain\n " );
244
+ return llvm::None;
245
+ }
246
+ // TODO: Handle local ids in the source domains while using the 'projectOut'
247
+ // utility below. Currently, aligning is not done assuming that there will be
248
+ // no local ids in the source domain.
249
+ if (srcConstraints.getNumLocalIds () != 0 ) {
250
+ LLVM_DEBUG (llvm::dbgs () << " Cannot handle locals in source domain\n " );
251
+ return llvm::None;
252
+ }
253
+
254
+ // Create constraints for the slice loop nest that would be created if the
255
+ // fusion succeeds.
256
+ FlatAffineConstraints sliceConstraints;
257
+ if (failed (getAsConstraints (&sliceConstraints))) {
258
+ LLVM_DEBUG (llvm::dbgs () << " Unable to compute slice's domain\n " );
259
+ return llvm::None;
260
+ }
261
+
262
+ // Projecting out every dimension other than the 'ivs' to express slice's
263
+ // domain completely in terms of source's IVs.
264
+ sliceConstraints.projectOut (ivs.size (),
265
+ sliceConstraints.getNumIds () - ivs.size ());
266
+
267
+ LLVM_DEBUG (llvm::dbgs () << " Domain of the source of the slice:\n " );
268
+ LLVM_DEBUG (srcConstraints.dump ());
269
+ LLVM_DEBUG (llvm::dbgs () << " Domain of the slice if this fusion succeeds "
270
+ " (expressed in terms of its source's IVs):\n " );
271
+ LLVM_DEBUG (sliceConstraints.dump ());
272
+
273
+ // TODO: Store 'srcSet' to avoid recalculating for each depth.
274
+ PresburgerSet srcSet (srcConstraints);
275
+ PresburgerSet sliceSet (sliceConstraints);
276
+ PresburgerSet diffSet = sliceSet.subtract (srcSet);
277
+
278
+ if (!diffSet.isIntegerEmpty ()) {
279
+ LLVM_DEBUG (llvm::dbgs () << " Incorrect slice\n " );
280
+ return false ;
281
+ }
282
+ return true ;
283
+ }
284
+
199
285
// / Returns true if the computation slice encloses all the iterations of the
200
286
// / sliced loop nest. Returns false if it does not. Returns llvm::None if it
201
287
// / cannot determine if the slice is maximal or not.
@@ -715,14 +801,14 @@ unsigned mlir::getInnermostCommonLoopDepth(
715
801
}
716
802
717
803
// / Computes in 'sliceUnion' the union of all slice bounds computed at
718
- // / 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'.
719
- // / Returns 'Success' if union was computed, 'failure' otherwise.
720
- LogicalResult mlir::computeSliceUnion (ArrayRef<Operation *> opsA,
721
- ArrayRef<Operation *> opsB,
722
- unsigned loopDepth ,
723
- unsigned numCommonLoops,
724
- bool isBackwardSlice,
725
- ComputationSliceState *sliceUnion) {
804
+ // / 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
805
+ // / then verifies if it is valid. Returns 'SliceComputationResult::Success' if
806
+ // / union was computed correctly, an appropriate failure otherwise.
807
+ SliceComputationResult
808
+ mlir::computeSliceUnion (ArrayRef<Operation *> opsA, ArrayRef<Operation *> opsB ,
809
+ unsigned loopDepth, unsigned numCommonLoops,
810
+ bool isBackwardSlice,
811
+ ComputationSliceState *sliceUnion) {
726
812
// Compute the union of slice bounds between all pairs in 'opsA' and
727
813
// 'opsB' in 'sliceUnionCst'.
728
814
FlatAffineConstraints sliceUnionCst;
@@ -738,7 +824,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
738
824
if ((!isBackwardSlice && loopDepth > getNestingDepth (opsA[i])) ||
739
825
(isBackwardSlice && loopDepth > getNestingDepth (opsB[j]))) {
740
826
LLVM_DEBUG (llvm::dbgs () << " Invalid loop depth\n " );
741
- return failure () ;
827
+ return SliceComputationResult::GenericFailure ;
742
828
}
743
829
744
830
bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst ) &&
@@ -751,7 +837,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
751
837
/* allowRAR=*/ readReadAccesses);
752
838
if (result.value == DependenceResult::Failure) {
753
839
LLVM_DEBUG (llvm::dbgs () << " Dependence check failed\n " );
754
- return failure () ;
840
+ return SliceComputationResult::GenericFailure ;
755
841
}
756
842
if (result.value == DependenceResult::NoDependence)
757
843
continue ;
@@ -768,7 +854,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
768
854
if (failed (tmpSliceState.getAsConstraints (&sliceUnionCst))) {
769
855
LLVM_DEBUG (llvm::dbgs ()
770
856
<< " Unable to compute slice bound constraints\n " );
771
- return failure () ;
857
+ return SliceComputationResult::GenericFailure ;
772
858
}
773
859
assert (sliceUnionCst.getNumDimAndSymbolIds () > 0 );
774
860
continue ;
@@ -779,7 +865,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
779
865
if (failed (tmpSliceState.getAsConstraints (&tmpSliceCst))) {
780
866
LLVM_DEBUG (llvm::dbgs ()
781
867
<< " Unable to compute slice bound constraints\n " );
782
- return failure () ;
868
+ return SliceComputationResult::GenericFailure ;
783
869
}
784
870
785
871
// Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
@@ -802,24 +888,24 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
802
888
// to unionBoundingBox below expects constraints for each Loop IV, even
803
889
// if they are the unsliced full loop bounds added here.
804
890
if (failed (addMissingLoopIVBounds (sliceUnionIVs, &sliceUnionCst)))
805
- return failure () ;
891
+ return SliceComputationResult::GenericFailure ;
806
892
if (failed (addMissingLoopIVBounds (tmpSliceIVs, &tmpSliceCst)))
807
- return failure () ;
893
+ return SliceComputationResult::GenericFailure ;
808
894
}
809
895
// Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
810
896
if (sliceUnionCst.getNumLocalIds () > 0 ||
811
897
tmpSliceCst.getNumLocalIds () > 0 ||
812
898
failed (sliceUnionCst.unionBoundingBox (tmpSliceCst))) {
813
899
LLVM_DEBUG (llvm::dbgs ()
814
900
<< " Unable to compute union bounding box of slice bounds\n " );
815
- return failure () ;
901
+ return SliceComputationResult::GenericFailure ;
816
902
}
817
903
}
818
904
}
819
905
820
906
// Empty union.
821
907
if (sliceUnionCst.getNumDimAndSymbolIds () == 0 )
822
- return failure () ;
908
+ return SliceComputationResult::GenericFailure ;
823
909
824
910
// Gather loops surrounding ops from loop nest where slice will be inserted.
825
911
SmallVector<Operation *, 4 > ops;
@@ -831,7 +917,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
831
917
getInnermostCommonLoopDepth (ops, &surroundingLoops);
832
918
if (loopDepth > innermostCommonLoopDepth) {
833
919
LLVM_DEBUG (llvm::dbgs () << " Exceeds max loop depth\n " );
834
- return failure () ;
920
+ return SliceComputationResult::GenericFailure ;
835
921
}
836
922
837
923
// Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
@@ -868,7 +954,18 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
868
954
// canonicalization.
869
955
sliceUnion->lbOperands .resize (numSliceLoopIVs, sliceBoundOperands);
870
956
sliceUnion->ubOperands .resize (numSliceLoopIVs, sliceBoundOperands);
871
- return success ();
957
+
958
+ // Check if the slice computed is valid. Return success only if it is verified
959
+ // that the slice is valid, otherwise return appropriate failure status.
960
+ Optional<bool > isSliceValid = sliceUnion->isSliceValid ();
961
+ if (!isSliceValid.hasValue ()) {
962
+ LLVM_DEBUG (llvm::dbgs () << " Cannot determine if the slice is valid\n " );
963
+ return SliceComputationResult::GenericFailure;
964
+ }
965
+ if (!isSliceValid.getValue ())
966
+ return SliceComputationResult::IncorrectSliceFailure;
967
+
968
+ return SliceComputationResult::Success;
872
969
}
873
970
874
971
const char *const kSliceFusionBarrierAttrName = " slice_fusion_barrier" ;
0 commit comments