@@ -55,13 +55,16 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
   return ss;
}

- // inferring the relationship of two indexing map
- // j -> i, means j is represented as the same symbol as i
- // we don't allow duplicate in symbols
- // e.g. if 2 j corresponding to 1 i, then return failure
+ // infer the relation between two indexing maps
+ // returns target dim -> base dim, i.e. the target dim is represented by the same dim expr as the base dim
+ // duplicates are rejected, e.g. 2 target dims corresponding to 1 base dim
static FailureOr<DenseMap<int64_t, int64_t>>
inferIndexingMapRelation(AffineMap indexingMapBase,
                         AffineMap indexingMapTarget) {
+   // maps with symbols are not supported
+   if (indexingMapBase.getNumSymbols() != 0 ||
+       indexingMapTarget.getNumSymbols() != 0)
+     return failure();
   DenseMap<int64_t, int64_t> res;
   ArrayRef<AffineExpr> resultsBase = indexingMapBase.getResults();
   ArrayRef<AffineExpr> resultsTarget = indexingMapTarget.getResults();
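
As a sketch of the relation this helper computes (illustration only, not part of the patch; it assumes an MLIRContext and the helper above are in scope):

  // base map (d0, d1, d2) -> (d0, d2), target map (d0, d1, d2) -> (d2, d0):
  // base result 0 (d0) matches target result 1, and base result 1 (d2)
  // matches target result 0, so the expected relation is {0 -> 1, 1 -> 0}.
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d2 = getAffineDimExpr(2, &ctx);
  AffineMap base = AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {d0, d2}, &ctx);
  AffineMap target = AffineMap::get(3, 0, {d2, d0}, &ctx);
  FailureOr<DenseMap<int64_t, int64_t>> rel = inferIndexingMapRelation(base, target);
  // expect succeeded(rel) with (*rel)[0] == 1 and (*rel)[1] == 0
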
@@ -70,6 +73,7 @@ inferIndexingMapRelation(AffineMap indexingMapBase,
      auto base = dyn_cast<AffineDimExpr>(resultsBase[i]);
      auto target = dyn_cast<AffineDimExpr>(resultsTarget[j]);
      if (base && target && base.getPosition() == target.getPosition()) {
+         // fail if dim j is already mapped to some dim i
        if (res.find(j) != res.end())
          return failure();
        res[j] = i;
@@ -91,7 +95,7 @@ inferIndexingMapRelation(AffineMap indexingMapBase,
   return res;
}

- // given j --> i and max rank of i, return i --> j
+ // given target --> base and the max rank of base, return base --> target
static DenseMap<int64_t, int64_t>
getReversedIndexMap(const DenseMap<int64_t, int64_t> &indexMap,
                    size_t maxRank) {
@@ -109,7 +113,7 @@ getReversedIndexMap(const DenseMap<int64_t, int64_t> &indexMap,
   return res;
}

- static FailureOr<TensorLayout>
+ static TensorLayout
inferTargetLayout(TensorLayout layoutBase,
                  const DenseMap<int64_t, int64_t> &indexMap) {
   SmallVector<int64_t> baseOuterAxis = layoutBase.getOuterAxis();
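
A small usage sketch for the reversal helper (values hypothetical; how unmatched base dims up to maxRank are filled is outside the visible hunks):

  // reverse the target --> base relation {0 -> 1, 1 -> 0} over a rank-2 base
  DenseMap<int64_t, int64_t> targetToBase;
  targetToBase[0] = 1;
  targetToBase[1] = 0;
  DenseMap<int64_t, int64_t> baseToTarget =
      getReversedIndexMap(targetToBase, /*maxRank=*/2);
  // expect baseToTarget[0] == 1 and baseToTarget[1] == 0
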
@@ -177,6 +181,39 @@ getPackingAxis(int64_t numRank, bool transposed) {
   return std::make_pair(outerAxisPerm, innerAxisPos);
}

+ // copied from mlir
+ static SmallVector<int64_t>
+ projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
+                                  ArrayRef<ReassociationIndices> reassocIndices,
+                                  ArrayRef<int64_t> targetShape) {
+   SmallVector<int64_t> projectedDimsPos;
+   for (auto pos : dimsPos) {
+     // In the case all dims are unit, this will return the inner-most one.
+     int64_t projectedPos = reassocIndices[pos].back();
+     for (auto i : llvm::reverse(reassocIndices[pos])) {
+       int64_t dim = targetShape[i];
+       if (dim > 1 || ShapedType::isDynamic(dim)) {
+         projectedPos = i;
+         break;
+       }
+     }
+     projectedDimsPos.push_back(projectedPos);
+   }
+   return projectedDimsPos;
+ }
+
+ /// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
+ static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
+                                        ArrayRef<int64_t> shape,
+                                        ArrayRef<int64_t> tileSizes) {
+   for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
+     int64_t dim = shape[pos];
+     if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
+       return false;
+   }
+   return true;
+ }
+
GlobalAnalysis::GlobalAnalysis(Operation *root) {
   root->walk([&](Operation *op) {
     if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
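
A worked trace of the two new helpers on made-up shapes (illustration only):

  // Dim 0 of the collapsed source maps to output dims {0, 1, 2} with output
  // shape {1, 16, 4}; scanning the group from the back, dim 2 is the first
  // non-unit dim, so position 0 projects to position 2.
  SmallVector<ReassociationIndices> reassoc = {{0, 1, 2}};
  SmallVector<int64_t> projected =
      projectToInnerMostNonUnitDimsPos(/*dimsPos=*/{0}, reassoc,
                                       /*targetShape=*/{1, 16, 4});
  // projected == {2}; output dim 2 has size 4, which a tile size of 4 divides
  bool ok = isDimsDivisibleByTileSizes(projected, /*shape=*/{1, 16, 4},
                                       /*tileSizes=*/{4});
  // ok == true
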
@@ -198,9 +235,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
      }
      // ------ Get Current Op's Suggested Layout & Do Propagation ------
      IRRewriter rewriter(linalgOp);
-      // TODO: extend to packed/vnni matmul ops
      if (supportedContractionNamedOpList(linalgOp)) {
-        // get input and output rank
+        // infer layout for linalg contraction named ops
        auto ARank = cast<ShapedType>(linalgOp.getDpsInputs()[0].getType())
                         .getShape()
                         .size();
@@ -242,29 +278,36 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
      OperatorLayout suggestedLayout({ALayout, BLayout}, {CLayout});
      layoutCache[linalgOp] = suggestedLayout;
    } else if (!mlir::linalg::isaContractionOpInterface(linalgOp) &&
+               !mlir::linalg::isaConvolutionOpInterface(linalgOp) &&
               !supportedContractionNamedOpList(linalgOp)) {
+      // infer layout for non-contraction/non-convolution linalg named ops
+      // and linalg generic ops
      SmallVector<TensorLayout> inputLayouts, outputLayouts;
      size_t targetIdx = getTargetInputIdx(curInputLayouts);
-      // TODO(yifei): wisely choose the input format basis
-      // Let's only refer to input[0] for now
      for (size_t i = 0; i < curInputs.size(); ++i) {
        // getMatchingIndexingMap
        if (i != targetIdx) {
-          auto res = inferIndexingMapRelation(
+          auto indexRelation = inferIndexingMapRelation(
              linalgOp.getMatchingIndexingMap(curInputs[targetIdx]),
              linalgOp.getMatchingIndexingMap(curInputs[i]));
+          if (failed(indexRelation)) {
+            return WalkResult::skip();
+          }
          TensorLayout inputLayout =
-              *inferTargetLayout(curInputLayouts[targetIdx], *res);
+              inferTargetLayout(curInputLayouts[targetIdx], *indexRelation);
          inputLayouts.push_back(inputLayout);
        } else {
          inputLayouts.push_back(curInputLayouts[targetIdx]);
        }
      }
-      auto res_out = inferIndexingMapRelation(
+      auto indexRelation = inferIndexingMapRelation(
          linalgOp.getMatchingIndexingMap(curInputs[targetIdx]),
          linalgOp.getIndexingMapMatchingResult(curResults[0]));
+      if (failed(indexRelation)) {
+        return WalkResult::skip();
+      }
      TensorLayout outputLayout =
-          *inferTargetLayout(curInputLayouts[targetIdx], *res_out);
+          inferTargetLayout(curInputLayouts[targetIdx], *indexRelation);
      outputLayouts.push_back(outputLayout);
      OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
      layoutCache[linalgOp] = suggestedLayout;
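
One concrete way the new failure path triggers (a hypothetical sketch): the symbol check added to inferIndexingMapRelation rejects maps containing symbols, and the walk above now skips such ops instead of dereferencing a failed result:

  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr s0 = getAffineSymbolExpr(0, &ctx);
  AffineMap plain = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {d0}, &ctx);
  AffineMap withSym = AffineMap::get(1, 1, {d0 + s0}, &ctx);
  // expect failed(inferIndexingMapRelation(plain, withSym)), so the op is skipped
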
@@ -283,52 +326,44 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
      OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
      layoutCache[padOp] = suggestedLayout;
    } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
-      auto reassociation = expandShapeOp.getReassociation();
+      SmallVector<ReassociationIndices> reassocIndices =
+          expandShapeOp.getReassociationIndices();
      auto staticOutputShape = expandShapeOp.getStaticOutputShape();
      auto parent = expandShapeOp.getSrc().getDefiningOp();
      auto inputShape = expandShapeOp.getSrcType().getShape();
      TensorLayout curInputLayout =
          layoutCache.find(parent) != layoutCache.end()
              ? layoutCache[parent].getOutputLayout(0)
              : TensorLayout::createPlainLayout(inputShape.size());
-      DenseMap<int64_t, int64_t> outputInputIdxMapping, inputOutputIndexMapping;
-      int64_t accumulationOffset = 0;
-      for (int64_t i = 0; i < static_cast<int64_t>(reassociation.size()); ++i) {
-        auto subReassociation = llvm::cast<ArrayAttr>(reassociation[i]);
-        for (int64_t j = 0; j < static_cast<int64_t>(subReassociation.size());
-             ++j) {
-          if (staticOutputShape[accumulationOffset + j] == inputShape[i]) {
-            outputInputIdxMapping[accumulationOffset + j] = i;
-            inputOutputIndexMapping[i] = accumulationOffset + j;
-          }
-        }
-        accumulationOffset += subReassociation.size();
+      SmallVector<int64_t> innerTileSizes;
+      auto intTileSizes = getConstantIntValues(curInputLayout.getTileSizes());
+      if (intTileSizes) {
+        innerTileSizes = *intTileSizes;
      }
-      auto inputOuterAxis = curInputLayout.getOuterAxis();
-      auto inputInnerAxis = curInputLayout.getInnerAxis();
-      int64_t diffDifference = staticOutputShape.size() - inputShape.size();
-      int64_t startIdx = 0;
-      SmallVector<int64_t> outputOuterAxis, outputInnerAxis;
-      for (int64_t i = 0; i < static_cast<int64_t>(staticOutputShape.size());
-           ++i) {
-        if (outputInputIdxMapping.find(i) != outputInputIdxMapping.end()) {
-          outputOuterAxis.push_back(inputOuterAxis[outputInputIdxMapping[i]] +
-                                    diffDifference);
-        } else {
-          outputOuterAxis.push_back(startIdx++);
-        }
+      SmallVector<int64_t> innerDimsPos = curInputLayout.getInnerAxis();
+      SmallVector<int64_t> outerDimsPerm = curInputLayout.getOuterAxis();
+      SmallVector<int64_t> projectedInnerDimsPos =
+          projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices,
+                                           staticOutputShape);
+
+      if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, staticOutputShape,
+                                      innerTileSizes)) {
+        return WalkResult::skip();
      }
-      for (int64_t i = 0; i < static_cast<int64_t>(inputInnerAxis.size());
-           ++i) {
-        outputInnerAxis.push_back(inputOutputIndexMapping[inputInnerAxis[i]]);
+      SmallVector<int64_t> newOuterDimsPerm;
+      for (auto outerPos : outerDimsPerm) {
+        newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+                                reassocIndices[outerPos].begin(),
+                                reassocIndices[outerPos].end());
      }
-      TensorLayout outputLayout(outputOuterAxis, outputInnerAxis,
+      TensorLayout outputLayout(newOuterDimsPerm, projectedInnerDimsPos,
                                curInputLayout.getTileSizes());
      SmallVector<TensorLayout> inputLayouts{curInputLayout},
          outputLayouts{outputLayout};
      OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
      layoutCache[expandShapeOp] = suggestedLayout;
    }
+    return WalkResult::advance();
  });
}
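
To make the new expand_shape propagation concrete, a hedged trace on made-up values (tensor<32x64xf32> expanded to tensor<2x16x64xf32>, incoming layout outerAxis {0, 1}, innerAxis {1}, tileSizes {16}):

  SmallVector<ReassociationIndices> reassoc = {{0, 1}, {2}};
  SmallVector<int64_t> outShape = {2, 16, 64};
  // innerAxis {1} projects through its group {2} to output dim 2
  SmallVector<int64_t> projected =
      projectToInnerMostNonUnitDimsPos({1}, reassoc, outShape);
  // output dim 2 has size 64, which the tile size 16 divides
  bool divisible = isDimsDivisibleByTileSizes(projected, outShape, {16});
  // each outer pos expands to its whole group: 0 -> {0, 1}, 1 -> {2},
  // so the suggested output layout is
  // outerAxis {0, 1, 2}, innerAxis {2}, tileSizes {16}
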