@@ -48,6 +48,34 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
48
48
return success ();
49
49
}
50
50
51
+ // If the `linalgOp` represents a transpose, return the permutation vector for
52
+ // the transpose. Otherwise, return failure.
53
+ static FailureOr<SmallVector<int64_t >>
54
+ getTransposeOpPermutation (linalg::LinalgOp linalgOp) {
55
+ if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation ()))
56
+ return SmallVector<int64_t >(transposeOp.getPermutation ());
57
+ if (linalgOp.getNumParallelLoops () != linalgOp.getNumLoops ())
58
+ return failure ();
59
+
60
+ if (linalgOp.getNumDpsInputs () != 1 || linalgOp.getNumDpsInits () != 1 )
61
+ return failure ();
62
+ auto mapRange = linalgOp.getIndexingMapsArray ();
63
+ if (!mapRange.front ().isPermutation () || !mapRange.back ().isPermutation () ||
64
+ mapRange.front () == mapRange.back ()) {
65
+ return failure ();
66
+ }
67
+ if (!llvm::hasSingleElement (linalgOp.getBlock ()->getOperations ()))
68
+ return failure ();
69
+ AffineMap outMap = mapRange.back ();
70
+ AffineMap inMap = mapRange.front ();
71
+ // To get the permutation, look at each output index and find which
72
+ // dimension in the input we're reading from for that index.
73
+ return llvm::map_to_vector (outMap.getResults (),
74
+ [&](AffineExpr expr) -> int64_t {
75
+ return *inMap.getResultPosition (expr);
76
+ });
77
+ }
78
+
51
79
// / Packing one-dimensional tensor can be expressed as an expand shape op.
52
80
struct SimplifyPackToExpandShape : public OpRewritePattern <PackOp> {
53
81
using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -246,14 +274,10 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
246
274
247
275
for (unsigned int i = 0 ; i < rank; ++i) {
248
276
int64_t remappedPosition = permutation[i];
249
-
250
- if (!inVec.empty ()) {
251
- if (remappedPosition >= rank) {
252
- return false ;
253
- }
277
+ if (remappedPosition >= rank)
278
+ return false ;
279
+ if (!inVec.empty ())
254
280
remappedPosition = inVec[remappedPosition];
255
- }
256
-
257
281
resVec.push_back (remappedPosition);
258
282
}
259
283
@@ -263,20 +287,25 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
263
287
// / Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
264
288
// / semantics.
265
289
struct FoldProducerPackWithConsumerLinalgTransposeOp
266
- : public OpRewritePattern <linalg::TransposeOp > {
267
- using OpRewritePattern <linalg::TransposeOp >::OpRewritePattern ;
290
+ : public OpInterfaceRewritePattern <linalg::LinalgOp > {
291
+ using OpInterfaceRewritePattern <linalg::LinalgOp >::OpInterfaceRewritePattern ;
268
292
269
- LogicalResult matchAndRewrite (linalg::TransposeOp transposeOp ,
293
+ LogicalResult matchAndRewrite (linalg::LinalgOp linalgOp ,
270
294
PatternRewriter &rewriter) const override {
271
- auto packOp = transposeOp. getOperand (0 ).getDefiningOp <PackOp>();
295
+ auto packOp = linalgOp-> getOperand (0 ).getDefiningOp <PackOp>();
272
296
273
297
if (!packOp)
274
298
return failure ();
275
299
300
+ FailureOr<SmallVector<int64_t >> maybePerm =
301
+ getTransposeOpPermutation (linalgOp);
302
+ if (failed (maybePerm))
303
+ return failure ();
304
+
276
305
auto innerDimsPos = packOp.getInnerDimsPos ();
277
306
auto mixedInnerTiles = packOp.getMixedTiles ();
278
307
auto outerDimsPerm = packOp.getOuterDimsPerm ();
279
- auto transposePerm = transposeOp. getPermutation ();
308
+ auto transposePerm = maybePerm. value ();
280
309
SmallVector<int64_t > newOuterDimsPermVec;
281
310
SmallVector<int64_t > newInnerDimsPosVec;
282
311
SmallVector<OpFoldResult> newMixedInnerTilesVec;
@@ -285,7 +314,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
285
314
if (!checkAndPermute (transposePerm, outerDimsPerm, newOuterDimsPermVec,
286
315
srcRank))
287
316
return rewriter.notifyMatchFailure (
288
- transposeOp ,
317
+ linalgOp ,
289
318
" Cannot fold in tensor.pack if a tile dimension was transposed "
290
319
" with a non-tile dimension in linalg.transpose." );
291
320
@@ -297,11 +326,11 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
297
326
}
298
327
299
328
Value output = packOp.createDestinationTensor (
300
- rewriter, transposeOp .getLoc (), packOp.getSource (),
301
- newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
329
+ rewriter, linalgOp .getLoc (), packOp.getSource (), newMixedInnerTilesVec ,
330
+ newInnerDimsPosVec, newOuterDimsPermVec);
302
331
303
332
rewriter.replaceOpWithNewOp <PackOp>(
304
- transposeOp , packOp.getSource (), output, newInnerDimsPosVec,
333
+ linalgOp , packOp.getSource (), output, newInnerDimsPosVec,
305
334
newMixedInnerTilesVec, packOp.getPaddingValue (), newOuterDimsPermVec);
306
335
307
336
return success ();
@@ -316,12 +345,16 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
316
345
317
346
LogicalResult matchAndRewrite (PackOp packOp,
318
347
PatternRewriter &rewriter) const override {
319
- auto transposeOp = packOp.getSource ().getDefiningOp <linalg::TransposeOp>();
348
+ auto linalgOp = packOp.getSource ().getDefiningOp <linalg::LinalgOp>();
349
+ if (!linalgOp)
350
+ return failure ();
320
351
321
- if (!transposeOp)
352
+ FailureOr<SmallVector<int64_t >> maybePerm =
353
+ getTransposeOpPermutation (linalgOp);
354
+ if (failed (maybePerm))
322
355
return failure ();
323
356
324
- auto transposePermutation = transposeOp. getPermutation ();
357
+ auto transposePermutation = maybePerm. value ();
325
358
auto outerDimsPerm = packOp.getOuterDimsPerm ();
326
359
auto innerDimsPos = packOp.getInnerDimsPos ();
327
360
SmallVector<int64_t > newInnerDimsPosVec;
@@ -337,11 +370,11 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
337
370
newInnerDimsPosVec.push_back (transposePermutation[dim]);
338
371
339
372
Value output = packOp.createDestinationTensor (
340
- rewriter, packOp.getLoc (), transposeOp. getOperand (0 ),
373
+ rewriter, packOp.getLoc (), linalgOp-> getOperand (0 ),
341
374
packOp.getMixedTiles (), newInnerDimsPosVec, newOuterDimsPermVec);
342
375
343
376
rewriter.replaceOpWithNewOp <PackOp>(
344
- packOp, transposeOp. getOperand (0 ), output, newInnerDimsPosVec,
377
+ packOp, linalgOp-> getOperand (0 ), output, newInnerDimsPosVec,
345
378
packOp.getMixedTiles (), packOp.getPaddingValue (), newOuterDimsPermVec);
346
379
347
380
return success ();
@@ -351,34 +384,38 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
351
384
// / Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
352
385
// / transpose semantics.
353
386
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
354
- : public OpRewritePattern <linalg::TransposeOp > {
355
- using OpRewritePattern <linalg::TransposeOp >::OpRewritePattern ;
387
+ : public OpInterfaceRewritePattern <linalg::LinalgOp > {
388
+ using OpInterfaceRewritePattern <linalg::LinalgOp >::OpInterfaceRewritePattern ;
356
389
357
- LogicalResult matchAndRewrite (linalg::TransposeOp transposeOp ,
390
+ LogicalResult matchAndRewrite (linalg::LinalgOp linalgOp ,
358
391
PatternRewriter &rewriter) const override {
359
- auto unPackOp = transposeOp. getOperand (0 ).getDefiningOp <UnPackOp>();
392
+ auto unPackOp = linalgOp-> getOperand (0 ).getDefiningOp <UnPackOp>();
360
393
361
394
if (!unPackOp)
362
395
return failure ();
363
396
364
- auto transposePermutation = transposeOp.getPermutation ();
397
+ FailureOr<SmallVector<int64_t >> maybePerm =
398
+ getTransposeOpPermutation (linalgOp);
399
+ if (failed (maybePerm))
400
+ return failure ();
401
+
365
402
auto outerDimsPerm = unPackOp.getOuterDimsPerm ();
366
403
auto innerDimsPos = unPackOp.getInnerDimsPos ();
367
404
SmallVector<int64_t > newInnerDimsPosVec;
368
405
SmallVector<int64_t > newOuterDimsPermVec =
369
- llvm::to_vector (transposePermutation);
370
-
371
- if (!outerDimsPerm.empty ())
372
- applyPermutationToVector (newOuterDimsPermVec, outerDimsPerm);
406
+ invertPermutationVector (maybePerm.value ());
373
407
374
408
// Can't use applyPermutationToVector for newInnerDimsPosVec since input and
375
409
// permutation rank won't necessarily be equal in all cases.
376
410
for (auto dim : innerDimsPos)
377
- newInnerDimsPosVec.push_back (transposePermutation[dim]);
411
+ newInnerDimsPosVec.push_back (newOuterDimsPermVec[dim]);
412
+
413
+ if (!outerDimsPerm.empty ())
414
+ applyPermutationToVector (newOuterDimsPermVec, outerDimsPerm);
378
415
379
416
// Reuse the destination of the transpose op.
380
417
rewriter.replaceOpWithNewOp <UnPackOp>(
381
- transposeOp , unPackOp.getSource (), transposeOp .getDpsInits ()[0 ],
418
+ linalgOp , unPackOp.getSource (), linalgOp .getDpsInits ()[0 ],
382
419
newInnerDimsPosVec, unPackOp.getMixedTiles (), newOuterDimsPermVec);
383
420
384
421
return success ();
@@ -393,13 +430,17 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
393
430
394
431
LogicalResult matchAndRewrite (UnPackOp unPackOp,
395
432
PatternRewriter &rewriter) const override {
396
- auto transposeOp =
397
- unPackOp.getSource ().getDefiningOp <linalg::TransposeOp>();
433
+ auto linalgOp = unPackOp.getSource ().getDefiningOp <linalg::LinalgOp>();
434
+ if (!linalgOp)
435
+ return failure ();
398
436
399
- if (!transposeOp)
437
+ FailureOr<SmallVector<int64_t >> maybePerm =
438
+ getTransposeOpPermutation (linalgOp);
439
+ if (failed (maybePerm))
400
440
return failure ();
401
441
402
- auto transposePermutation = transposeOp.getPermutation ();
442
+ SmallVector<int64_t > inverseTransposePerm =
443
+ invertPermutationVector (maybePerm.value ());
403
444
auto outerDimsPerm = unPackOp.getOuterDimsPerm ();
404
445
auto innerDimsPos = unPackOp.getInnerDimsPos ();
405
446
int64_t destRank = unPackOp.getSourceRank () - innerDimsPos.size ();
@@ -408,26 +449,26 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
408
449
SmallVector<int64_t > newInnerDimsPosVec;
409
450
SmallVector<OpFoldResult> newMixedInnerTilesVec;
410
451
411
- if (!checkAndPermute (transposePermutation , outerDimsPerm,
452
+ if (!checkAndPermute (inverseTransposePerm , outerDimsPerm,
412
453
newOuterDimsPermVec, destRank))
413
454
return rewriter.notifyMatchFailure (
414
455
unPackOp,
415
456
" Cannot fold in tensor.unpack if a tile dimension was transposed "
416
457
" with a non-tile dimension in linalg.transpose." );
417
458
418
459
// Process transpose operation for tiled inner dimensions
419
- for (unsigned int i = destRank; i < transposePermutation .size (); ++i) {
420
- int64_t remappedPosition = transposePermutation [i] - destRank;
460
+ for (unsigned int i = destRank; i < inverseTransposePerm .size (); ++i) {
461
+ int64_t remappedPosition = inverseTransposePerm [i] - destRank;
421
462
newMixedInnerTilesVec.push_back (mixedInnerTilesVec[remappedPosition]);
422
463
newInnerDimsPosVec.push_back (innerDimsPos[remappedPosition]);
423
464
}
424
465
425
466
Value output = unPackOp.createDestinationTensor (
426
- rewriter, unPackOp.getLoc (), transposeOp. getOperand (0 ),
467
+ rewriter, unPackOp.getLoc (), linalgOp-> getOperand (0 ),
427
468
newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
428
469
429
470
rewriter.replaceOpWithNewOp <UnPackOp>(
430
- unPackOp, transposeOp. getOperand (0 ), output, newInnerDimsPosVec,
471
+ unPackOp, linalgOp-> getOperand (0 ), output, newInnerDimsPosVec,
431
472
newMixedInnerTilesVec, newOuterDimsPermVec);
432
473
433
474
return success ();
0 commit comments