@@ -48,6 +48,34 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
48
48
return success ();
49
49
}
50
50
51
+ // If the `linalgOp` represents a transpose, return the permutation vector for
52
+ // the transpose. Otherwise, return failure.
53
+ static FailureOr<SmallVector<int64_t >>
54
+ getTransposeOpPermutation (linalg::LinalgOp linalgOp) {
55
+ if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation ()))
56
+ return SmallVector<int64_t >(transposeOp.getPermutation ());
57
+ if (linalgOp.getNumParallelLoops () != linalgOp.getNumLoops ())
58
+ return failure ();
59
+
60
+ if (linalgOp.getNumDpsInputs () != 1 || linalgOp.getNumDpsInits () != 1 )
61
+ return failure ();
62
+ auto mapRange = linalgOp.getIndexingMapsArray ();
63
+ if (mapRange.size () != 2 || !mapRange.front ().isPermutation () ||
64
+ !mapRange.back ().isPermutation () || mapRange.front () == mapRange.back ()) {
65
+ return failure ();
66
+ }
67
+ if (!llvm::hasSingleElement (linalgOp.getBlock ()->getOperations ()))
68
+ return failure ();
69
+ AffineMap outMap = mapRange.back ();
70
+ AffineMap inMap = mapRange.front ();
71
+ // To get the permutation, look at each output index and find which
72
+ // dimension in the input we're reading from for that index.
73
+ return llvm::map_to_vector (outMap.getResults (),
74
+ [&](AffineExpr expr) -> int64_t {
75
+ return *inMap.getResultPosition (expr);
76
+ });
77
+ }
78
+
51
79
/// Packing one-dimensional tensor can be expressed as an expand shape op.
52
80
struct SimplifyPackToExpandShape : public OpRewritePattern <PackOp> {
53
81
using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -246,14 +274,10 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
246
274
247
275
for (unsigned int i = 0 ; i < rank; ++i) {
248
276
int64_t remappedPosition = permutation[i];
249
-
250
- if (!inVec.empty ()) {
251
- if (remappedPosition >= rank) {
252
- return false ;
253
- }
277
+ if (remappedPosition >= rank)
278
+ return false ;
279
+ if (!inVec.empty ())
254
280
remappedPosition = inVec[remappedPosition];
255
- }
256
-
257
281
resVec.push_back (remappedPosition);
258
282
}
259
283
@@ -263,20 +287,26 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
263
287
/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
264
288
/// semantics.
265
289
struct FoldProducerPackWithConsumerLinalgTransposeOp
266
- : public OpRewritePattern <linalg::TransposeOp > {
267
- using OpRewritePattern <linalg::TransposeOp >::OpRewritePattern ;
290
+ : public OpInterfaceRewritePattern <linalg::LinalgOp > {
291
+ using OpInterfaceRewritePattern <linalg::LinalgOp >::OpInterfaceRewritePattern ;
268
292
269
- LogicalResult matchAndRewrite (linalg::TransposeOp transposeOp ,
293
+ LogicalResult matchAndRewrite (linalg::LinalgOp linalgOp ,
270
294
PatternRewriter &rewriter) const override {
271
- auto packOp = transposeOp. getOperand (0 ).getDefiningOp <PackOp>();
295
+ auto packOp = linalgOp-> getOperand (0 ).getDefiningOp <PackOp>();
272
296
273
297
if (!packOp)
274
298
return failure ();
275
299
300
+ FailureOr<SmallVector<int64_t >> maybePerm =
301
+ getTransposeOpPermutation (linalgOp);
302
+ if (failed (maybePerm)) {
303
+ return failure ();
304
+ }
305
+
276
306
auto innerDimsPos = packOp.getInnerDimsPos ();
277
307
auto mixedInnerTiles = packOp.getMixedTiles ();
278
308
auto outerDimsPerm = packOp.getOuterDimsPerm ();
279
- auto transposePerm = transposeOp. getPermutation ();
309
+ auto transposePerm = maybePerm. value ();
280
310
SmallVector<int64_t > newOuterDimsPermVec;
281
311
SmallVector<int64_t > newInnerDimsPosVec;
282
312
SmallVector<OpFoldResult> newMixedInnerTilesVec;
@@ -285,7 +315,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
285
315
if (!checkAndPermute (transposePerm, outerDimsPerm, newOuterDimsPermVec,
286
316
srcRank))
287
317
return rewriter.notifyMatchFailure (
288
- transposeOp ,
318
+ linalgOp ,
289
319
" Cannot fold in tensor.pack if a tile dimension was transposed "
290
320
" with a non-tile dimension in linalg.transpose." );
291
321
@@ -297,11 +327,11 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
297
327
}
298
328
299
329
Value output = packOp.createDestinationTensor (
300
- rewriter, transposeOp .getLoc (), packOp.getSource (),
301
- newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
330
+ rewriter, linalgOp .getLoc (), packOp.getSource (), newMixedInnerTilesVec ,
331
+ newInnerDimsPosVec, newOuterDimsPermVec);
302
332
303
333
rewriter.replaceOpWithNewOp <PackOp>(
304
- transposeOp , packOp.getSource (), output, newInnerDimsPosVec,
334
+ linalgOp , packOp.getSource (), output, newInnerDimsPosVec,
305
335
newMixedInnerTilesVec, packOp.getPaddingValue (), newOuterDimsPermVec);
306
336
307
337
return success ();
@@ -316,12 +346,17 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
316
346
317
347
LogicalResult matchAndRewrite (PackOp packOp,
318
348
PatternRewriter &rewriter) const override {
319
- auto transposeOp = packOp.getSource ().getDefiningOp <linalg::TransposeOp>();
349
+ auto linalgOp = packOp.getSource ().getDefiningOp <linalg::LinalgOp>();
350
+ if (!linalgOp)
351
+ return failure ();
320
352
321
- if (!transposeOp)
353
+ FailureOr<SmallVector<int64_t >> maybePerm =
354
+ getTransposeOpPermutation (linalgOp);
355
+ if (failed (maybePerm)) {
322
356
return failure ();
357
+ }
323
358
324
- auto transposePermutation = transposeOp. getPermutation ();
359
+ auto transposePermutation = maybePerm. value ();
325
360
auto outerDimsPerm = packOp.getOuterDimsPerm ();
326
361
auto innerDimsPos = packOp.getInnerDimsPos ();
327
362
SmallVector<int64_t > newInnerDimsPosVec;
@@ -337,11 +372,11 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
337
372
newInnerDimsPosVec.push_back (transposePermutation[dim]);
338
373
339
374
Value output = packOp.createDestinationTensor (
340
- rewriter, packOp.getLoc (), transposeOp. getOperand (0 ),
375
+ rewriter, packOp.getLoc (), linalgOp-> getOperand (0 ),
341
376
packOp.getMixedTiles (), newInnerDimsPosVec, newOuterDimsPermVec);
342
377
343
378
rewriter.replaceOpWithNewOp <PackOp>(
344
- packOp, transposeOp. getOperand (0 ), output, newInnerDimsPosVec,
379
+ packOp, linalgOp-> getOperand (0 ), output, newInnerDimsPosVec,
345
380
packOp.getMixedTiles (), packOp.getPaddingValue (), newOuterDimsPermVec);
346
381
347
382
return success ();
@@ -351,34 +386,41 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
351
386
/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
352
387
/// transpose semantics.
353
388
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
354
- : public OpRewritePattern <linalg::TransposeOp > {
355
- using OpRewritePattern <linalg::TransposeOp >::OpRewritePattern ;
389
+ : public OpInterfaceRewritePattern <linalg::LinalgOp > {
390
+ using OpInterfaceRewritePattern <linalg::LinalgOp >::OpInterfaceRewritePattern ;
356
391
357
- LogicalResult matchAndRewrite (linalg::TransposeOp transposeOp ,
392
+ LogicalResult matchAndRewrite (linalg::LinalgOp linalgOp ,
358
393
PatternRewriter &rewriter) const override {
359
- auto unPackOp = transposeOp. getOperand (0 ).getDefiningOp <UnPackOp>();
394
+ auto unPackOp = linalgOp-> getOperand (0 ).getDefiningOp <UnPackOp>();
360
395
361
396
if (!unPackOp)
362
397
return failure ();
363
398
364
- auto transposePermutation = transposeOp.getPermutation ();
399
+ FailureOr<SmallVector<int64_t >> maybePerm =
400
+ getTransposeOpPermutation (linalgOp);
401
+ if (failed (maybePerm)) {
402
+ return failure ();
403
+ }
404
+
405
+ auto transposePermutation = maybePerm.value ();
406
+ SmallVector<int64_t > inverseTransposePerm =
407
+ invertPermutationVector (transposePermutation);
365
408
auto outerDimsPerm = unPackOp.getOuterDimsPerm ();
366
409
auto innerDimsPos = unPackOp.getInnerDimsPos ();
367
410
SmallVector<int64_t > newInnerDimsPosVec;
368
- SmallVector<int64_t > newOuterDimsPermVec =
369
- llvm::to_vector (transposePermutation);
411
+ SmallVector<int64_t > newOuterDimsPermVec = inverseTransposePerm;
370
412
371
413
if (!outerDimsPerm.empty ())
372
414
applyPermutationToVector (newOuterDimsPermVec, outerDimsPerm);
373
415
374
416
// Can't use applyPermutationToVector for newInnerDimsPosVec since input and
375
417
// permutation rank won't necessarily be equal in all cases.
376
418
for (auto dim : innerDimsPos)
377
- newInnerDimsPosVec.push_back (transposePermutation [dim]);
419
+ newInnerDimsPosVec.push_back (inverseTransposePerm [dim]);
378
420
379
421
// Reuse the destination of the transpose op.
380
422
rewriter.replaceOpWithNewOp <UnPackOp>(
381
- transposeOp , unPackOp.getSource (), transposeOp .getDpsInits ()[0 ],
423
+ linalgOp , unPackOp.getSource (), linalgOp .getDpsInits ()[0 ],
382
424
newInnerDimsPosVec, unPackOp.getMixedTiles (), newOuterDimsPermVec);
383
425
384
426
return success ();
@@ -393,13 +435,19 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
393
435
394
436
LogicalResult matchAndRewrite (UnPackOp unPackOp,
395
437
PatternRewriter &rewriter) const override {
396
- auto transposeOp =
397
- unPackOp.getSource ().getDefiningOp <linalg::TransposeOp>();
438
+ auto linalgOp = unPackOp.getSource ().getDefiningOp <linalg::LinalgOp>();
439
+ if (!linalgOp)
440
+ return failure ();
398
441
399
- if (!transposeOp)
442
+ FailureOr<SmallVector<int64_t >> maybePerm =
443
+ getTransposeOpPermutation (linalgOp);
444
+ if (failed (maybePerm)) {
400
445
return failure ();
446
+ }
401
447
402
- auto transposePermutation = transposeOp.getPermutation ();
448
+ auto transposePermutation = maybePerm.value ();
449
+ SmallVector<int64_t > inverseTransposePerm =
450
+ invertPermutationVector (transposePermutation);
403
451
auto outerDimsPerm = unPackOp.getOuterDimsPerm ();
404
452
auto innerDimsPos = unPackOp.getInnerDimsPos ();
405
453
int64_t destRank = unPackOp.getSourceRank () - innerDimsPos.size ();
@@ -408,26 +456,26 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
408
456
SmallVector<int64_t > newInnerDimsPosVec;
409
457
SmallVector<OpFoldResult> newMixedInnerTilesVec;
410
458
411
- if (!checkAndPermute (transposePermutation , outerDimsPerm,
459
+ if (!checkAndPermute (inverseTransposePerm , outerDimsPerm,
412
460
newOuterDimsPermVec, destRank))
413
461
return rewriter.notifyMatchFailure (
414
462
unPackOp,
415
463
" Cannot fold in tensor.unpack if a tile dimension was transposed "
416
464
" with a non-tile dimension in linalg.transpose." );
417
465
418
466
// Process transpose operation for tiled inner dimensions
419
- for (unsigned int i = destRank; i < transposePermutation .size (); ++i) {
420
- int64_t remappedPosition = transposePermutation [i] - destRank;
467
+ for (unsigned int i = destRank; i < inverseTransposePerm .size (); ++i) {
468
+ int64_t remappedPosition = inverseTransposePerm [i] - destRank;
421
469
newMixedInnerTilesVec.push_back (mixedInnerTilesVec[remappedPosition]);
422
470
newInnerDimsPosVec.push_back (innerDimsPos[remappedPosition]);
423
471
}
424
472
425
473
Value output = unPackOp.createDestinationTensor (
426
- rewriter, unPackOp.getLoc (), transposeOp. getOperand (0 ),
474
+ rewriter, unPackOp.getLoc (), linalgOp-> getOperand (0 ),
427
475
newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
428
476
429
477
rewriter.replaceOpWithNewOp <UnPackOp>(
430
- unPackOp, transposeOp. getOperand (0 ), output, newInnerDimsPosVec,
478
+ unPackOp, linalgOp-> getOperand (0 ), output, newInnerDimsPosVec,
431
479
newMixedInnerTilesVec, newOuterDimsPermVec);
432
480
433
481
return success ();
0 commit comments