@@ -85,33 +85,42 @@ void mlir::getReachableAffineApplyOps(
85
85
// FlatAffineConstraints. (E.g., by using iv - lb % step = 0 and/or by
86
86
// introducing a method in FlatAffineConstraints setExprStride(ArrayRef<int64_t>
87
87
// expr, int64_t stride).)
88
- LogicalResult mlir::getIndexSet (MutableArrayRef<AffineForOp> forOps ,
88
+ LogicalResult mlir::getIndexSet (MutableArrayRef<Operation *> ops ,
89
89
FlatAffineConstraints *domain) {
90
90
SmallVector<Value, 4 > indices;
91
+ SmallVector<AffineForOp, 8 > forOps;
92
+
93
+ for (Operation *op : ops) {
94
+ assert ((isa<AffineForOp, AffineIfOp>(op)) &&
95
+ " ops should have either AffineForOp or AffineIfOp" );
96
+ if (AffineForOp forOp = dyn_cast<AffineForOp>(op))
97
+ forOps.push_back (forOp);
98
+ }
91
99
extractForInductionVars (forOps, &indices);
92
100
// Reset the domain and associate the Values in 'indices' with it.
93
101
domain->reset (forOps.size (), /* numSymbols=*/ 0 , /* numLocals=*/ 0 , indices);
94
- for (auto forOp : forOps ) {
102
+ for (Operation *op : ops ) {
95
103
// Add constraints from forOp's bounds.
96
- if (failed (domain->addAffineForOpDomain (forOp)))
97
- return failure ();
104
+ if (AffineForOp forOp = dyn_cast<AffineForOp>(op)) {
105
+ if (failed (domain->addAffineForOpDomain (forOp)))
106
+ return failure ();
107
+ } else if (AffineIfOp ifOp = dyn_cast<AffineIfOp>(op)) {
108
+ domain->addAffineIfOpDomain (ifOp);
109
+ }
98
110
}
99
111
return success ();
100
112
}
101
113
102
- // Computes the iteration domain for 'opInst' and populates 'indexSet', which
103
- // encapsulates the constraints involving loops surrounding 'opInst' and
104
- // potentially involving any Function symbols. The dimensional identifiers in
105
- // 'indexSet' correspond to the loops surrounding 'op' from outermost to
106
- // innermost.
107
- // TODO: Add support to handle IfInsts surrounding 'op'.
108
- static LogicalResult getInstIndexSet (Operation *op,
109
- FlatAffineConstraints *indexSet) {
110
- // TODO: Extend this to gather enclosing IfInsts and consider
111
- // factoring it out into a utility function.
112
- SmallVector<AffineForOp, 4 > loops;
113
- getLoopIVs (*op, &loops);
114
- return getIndexSet (loops, indexSet);
114
+ /// Computes the iteration domain for 'op' and populates 'indexSet', which
115
+ /// encapsulates the constraints involving loops surrounding 'op' and
116
+ /// potentially involving any Function symbols. The dimensional identifiers in
117
+ /// 'indexSet' correspond to the loops surrounding 'op' from outermost to
118
+ /// innermost.
119
+ static LogicalResult getOpIndexSet (Operation *op,
120
+ FlatAffineConstraints *indexSet) {
121
+ SmallVector<Operation *, 4 > ops;
122
+ getEnclosingAffineForAndIfOps (*op, &ops);
123
+ return getIndexSet (ops, indexSet);
115
124
}
116
125
117
126
namespace {
@@ -209,32 +218,83 @@ static void buildDimAndSymbolPositionMaps(
209
218
const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
210
219
const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap,
211
220
FlatAffineConstraints *dependenceConstraints) {
212
- auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc) {
221
+
222
+ // IsDimState is a tri-state boolean. It is used to distinguish three
223
+ // different cases of the values passed to updateValuePosMap.
224
+ // - When it is TRUE, we are certain that all values are dim values.
225
+ // - When it is FALSE, we are certain that all values are symbol values.
226
+ // - When it is UNKNOWN, we need to further check whether the value is from a
227
+ // loop IV to determine its type (dim or symbol).
228
+
229
+ // We need this enumeration because sometimes we cannot determine whether a
230
+ // Value is a symbol or a dim by the information from the Value itself. If a
231
+ // Value appears in an affine map of a loop, we can determine whether it is a
232
+ // dim or not by the function `isForInductionVar`. But when a Value is in the
233
+ // affine set of an if-statement, there is no way to identify its category
234
+ // (dim/symbol) by itself. Fortunately, the Values to be inserted into
235
+ // `valuePosMap` come from `srcDomain` and `dstDomain`, and they hold such
236
+ // information of Value category: `srcDomain` and `dstDomain` organize Values
237
+ // by their category, such that the position of each Value stored in
238
+ // `srcDomain` and `dstDomain` marks which category that a Value belongs to.
239
+ // Therefore, we can separate Values into dim and symbol groups before passing
240
+ // them to the function `updateValuePosMap`. Specifically, when passing the
241
+ // dim group, we set IsDimState to TRUE; otherwise, we set it to FALSE.
242
+ // However, Values from the operands of `srcAccessMap` and `dstAccessMap` are
243
+ // not explicitly categorized into dim or symbol, and we have to rely on
244
+ // `isForInductionVar` to make the decision. IsDimState is set to UNKNOWN in
245
+ // this case.
246
+ enum IsDimState { TRUE , FALSE , UNKNOWN };
247
+
248
+ // This function places each given Value (in `values`) under a respective
249
+ // category in `valuePosMap`. Specifically, the placement rules are:
250
+ // 1) If `isDim` is FALSE, then every value in `values` is inserted into
251
+ // `valuePosMap` as a symbol.
252
+ // 2) If `isDim` is UNKNOWN and the value of the current iteration is NOT an
253
+ // induction variable of a for-loop, we treat it as symbol as well.
254
+ // 3) For other cases, we decide whether to add a value to the `src` or the
255
+ // `dst` section of the dim category simply by the boolean value `isSrc`.
256
+ auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc,
257
+ IsDimState isDim) {
213
258
for (unsigned i = 0 , e = values.size (); i < e; ++i) {
214
259
auto value = values[i];
215
- if (!isForInductionVar (values[i] )) {
216
- assert (isValidSymbol (values[i] ) &&
260
+ if (isDim == FALSE || (isDim == UNKNOWN && !isForInductionVar (value) )) {
261
+ assert (isValidSymbol (value ) &&
217
262
" access operand has to be either a loop IV or a symbol" );
218
263
valuePosMap->addSymbolValue (value);
219
- } else if (isSrc) {
220
- valuePosMap->addSrcValue (value);
221
264
} else {
222
- valuePosMap->addDstValue (value);
265
+ if (isSrc)
266
+ valuePosMap->addSrcValue (value);
267
+ else
268
+ valuePosMap->addDstValue (value);
223
269
}
224
270
}
225
271
};
226
272
227
- SmallVector<Value, 4 > srcValues, destValues;
228
- srcDomain.getIdValues (0 , srcDomain.getNumDimAndSymbolIds (), &srcValues);
229
- dstDomain.getIdValues (0 , dstDomain.getNumDimAndSymbolIds (), &destValues);
230
- // Update value position map with identifiers from src iteration domain.
231
- updateValuePosMap (srcValues, /* isSrc=*/ true );
232
- // Update value position map with identifiers from dst iteration domain.
233
- updateValuePosMap (destValues, /* isSrc=*/ false );
273
+ // Collect values from the src and dst domains. For each domain, we separate
274
+ // the collected values into dim and symbol parts.
275
+ SmallVector<Value, 4 > srcDimValues, dstDimValues, srcSymbolValues,
276
+ dstSymbolValues;
277
+ srcDomain.getIdValues (0 , srcDomain.getNumDimIds (), &srcDimValues);
278
+ dstDomain.getIdValues (0 , dstDomain.getNumDimIds (), &dstDimValues);
279
+ srcDomain.getIdValues (srcDomain.getNumDimIds (),
280
+ srcDomain.getNumDimAndSymbolIds (), &srcSymbolValues);
281
+ dstDomain.getIdValues (dstDomain.getNumDimIds (),
282
+ dstDomain.getNumDimAndSymbolIds (), &dstSymbolValues);
283
+
284
+ // Update value position map with dim values from src iteration domain.
285
+ updateValuePosMap (srcDimValues, /* isSrc=*/ true , /* isDim=*/ TRUE );
286
+ // Update value position map with dim values from dst iteration domain.
287
+ updateValuePosMap (dstDimValues, /* isSrc=*/ false , /* isDim=*/ TRUE );
288
+ // Update value position map with symbols from src iteration domain.
289
+ updateValuePosMap (srcSymbolValues, /* isSrc=*/ true , /* isDim=*/ FALSE );
290
+ // Update value position map with symbols from dst iteration domain.
291
+ updateValuePosMap (dstSymbolValues, /* isSrc=*/ false , /* isDim=*/ FALSE );
234
292
// Update value position map with identifiers from src access function.
235
- updateValuePosMap (srcAccessMap.getOperands (), /* isSrc=*/ true );
293
+ updateValuePosMap (srcAccessMap.getOperands (), /* isSrc=*/ true ,
294
+ /* isDim=*/ UNKNOWN);
236
295
// Update value position map with identifiers from dst access function.
237
- updateValuePosMap (dstAccessMap.getOperands (), /* isSrc=*/ false );
296
+ updateValuePosMap (dstAccessMap.getOperands (), /* isSrc=*/ false ,
297
+ /* isDim=*/ UNKNOWN);
238
298
}
239
299
240
300
// Sets up dependence constraints columns appropriately, in the format:
@@ -270,24 +330,33 @@ static void initDependenceConstraints(
270
330
dependenceConstraints->setIdValues (
271
331
srcLoopIVs.size (), srcLoopIVs.size () + dstLoopIVs.size (), dstLoopIVs);
272
332
273
- // Set values for the symbolic identifier dimensions.
274
- auto setSymbolIds = [&](ArrayRef<Value> values) {
333
+ // Set values for the symbolic identifier dimensions. `isSymbolDetermined`
334
+ // indicates whether we are certain that the `values` passed in are all
335
+ // symbols. If `isSymbolDetermined` is true, then we treat every Value in
336
+ // `values` as a symbol; otherwise, we let the function `isForInductionVar`
337
+ // distinguish whether a Value in `values` is a symbol or not.
338
+ auto setSymbolIds = [&](ArrayRef<Value> values,
339
+ bool isSymbolDetermined = true ) {
275
340
for (auto value : values) {
276
- if (!isForInductionVar (value)) {
341
+ if (isSymbolDetermined || !isForInductionVar (value)) {
277
342
assert (isValidSymbol (value) && " expected symbol" );
278
343
dependenceConstraints->setIdValue (valuePosMap.getSymPos (value), value);
279
344
}
280
345
}
281
346
};
282
347
283
- setSymbolIds (srcAccessMap.getOperands ());
284
- setSymbolIds (dstAccessMap.getOperands ());
348
+ // We are uncertain about whether all operands in `srcAccessMap` and
349
+ // `dstAccessMap` are symbols, so we set `isSymbolDetermined` to false.
350
+ setSymbolIds (srcAccessMap.getOperands (), /* isSymbolDetermined=*/ false );
351
+ setSymbolIds (dstAccessMap.getOperands (), /* isSymbolDetermined=*/ false );
285
352
286
353
SmallVector<Value, 8 > srcSymbolValues, dstSymbolValues;
287
354
srcDomain.getIdValues (srcDomain.getNumDimIds (),
288
355
srcDomain.getNumDimAndSymbolIds (), &srcSymbolValues);
289
356
dstDomain.getIdValues (dstDomain.getNumDimIds (),
290
357
dstDomain.getNumDimAndSymbolIds (), &dstSymbolValues);
358
+ // Since we only take symbol Values out of `srcDomain` and `dstDomain`,
359
+ // `isSymbolDetermined` is kept at its default value: true.
291
360
setSymbolIds (srcSymbolValues);
292
361
setSymbolIds (dstSymbolValues);
293
362
@@ -530,22 +599,50 @@ getNumCommonLoops(const FlatAffineConstraints &srcDomain,
530
599
return numCommonLoops;
531
600
}
532
601
533
- // Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
602
+ /// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
534
603
static Block *getCommonBlock (const MemRefAccess &srcAccess,
535
604
const MemRefAccess &dstAccess,
536
605
const FlatAffineConstraints &srcDomain,
537
606
unsigned numCommonLoops) {
607
+ // Get the chain of ancestor blocks to the given `MemRefAccess` instance. The
608
+ // search terminates when either an op with the `AffineScope` trait or
609
+ // `endBlock` is reached.
610
+ auto getChainOfAncestorBlocks = [&](const MemRefAccess &access,
611
+ SmallVector<Block *, 4 > &ancestorBlocks,
612
+ Block *endBlock = nullptr ) {
613
+ Block *currBlock = access.opInst ->getBlock ();
614
+ // Loop terminates when the currBlock is nullptr or equal to the endBlock,
615
+ // or its parent operation holds an affine scope.
616
+ while (currBlock && currBlock != endBlock &&
617
+ !currBlock->getParentOp ()->hasTrait <OpTrait::AffineScope>()) {
618
+ ancestorBlocks.push_back (currBlock);
619
+ currBlock = currBlock->getParentOp ()->getBlock ();
620
+ }
621
+ };
622
+
538
623
if (numCommonLoops == 0 ) {
539
- auto *block = srcAccess.opInst ->getBlock ();
624
+ Block *block = srcAccess.opInst ->getBlock ();
540
625
while (!llvm::isa<FuncOp>(block->getParentOp ())) {
541
626
block = block->getParentOp ()->getBlock ();
542
627
}
543
628
return block;
544
629
}
545
- auto commonForValue = srcDomain.getIdValue (numCommonLoops - 1 );
546
- auto forOp = getForInductionVarOwner (commonForValue );
630
+ Value commonForIV = srcDomain.getIdValue (numCommonLoops - 1 );
631
+ AffineForOp forOp = getForInductionVarOwner (commonForIV );
547
632
assert (forOp && " commonForValue was not an induction variable" );
548
- return forOp.getBody ();
633
+
634
+ // Find the closest common block including those in AffineIf.
635
+ SmallVector<Block *, 4 > srcAncestorBlocks, dstAncestorBlocks;
636
+ getChainOfAncestorBlocks (srcAccess, srcAncestorBlocks, forOp.getBody ());
637
+ getChainOfAncestorBlocks (dstAccess, dstAncestorBlocks, forOp.getBody ());
638
+
639
+ Block *commonBlock = forOp.getBody ();
640
+ for (int i = srcAncestorBlocks.size () - 1 , j = dstAncestorBlocks.size () - 1 ;
641
+ i >= 0 && j >= 0 && srcAncestorBlocks[i] == dstAncestorBlocks[j];
642
+ i--, j--)
643
+ commonBlock = srcAncestorBlocks[i];
644
+
645
+ return commonBlock;
549
646
}
550
647
551
648
// Returns true if the ancestor operation of 'srcAccess' appears before the
@@ -788,12 +885,12 @@ DependenceResult mlir::checkMemrefAccessDependence(
788
885
789
886
// Get iteration domain for the 'srcAccess' operation.
790
887
FlatAffineConstraints srcDomain;
791
- if (failed (getInstIndexSet (srcAccess.opInst , &srcDomain)))
888
+ if (failed (getOpIndexSet (srcAccess.opInst , &srcDomain)))
792
889
return DependenceResult::Failure;
793
890
794
891
// Get iteration domain for 'dstAccess' operation.
795
892
FlatAffineConstraints dstDomain;
796
- if (failed (getInstIndexSet (dstAccess.opInst , &dstDomain)))
893
+ if (failed (getOpIndexSet (dstAccess.opInst , &dstDomain)))
797
894
return DependenceResult::Failure;
798
895
799
896
// Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor
@@ -814,7 +911,6 @@ DependenceResult mlir::checkMemrefAccessDependence(
814
911
buildDimAndSymbolPositionMaps (srcDomain, dstDomain, srcAccessMap,
815
912
dstAccessMap, &valuePosMap,
816
913
dependenceConstraints);
817
-
818
914
initDependenceConstraints (srcDomain, dstDomain, srcAccessMap, dstAccessMap,
819
915
valuePosMap, dependenceConstraints);
820
916
0 commit comments