@@ -48,34 +48,6 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
48
48
return success ();
49
49
}
50
50
51
- // If the `linalgOp` represents a transpose, return the permutation vector for
52
- // the transpose. Otherwise, return failure.
53
- static FailureOr<SmallVector<int64_t >>
54
- getTransposeOpPermutation (linalg::LinalgOp linalgOp) {
55
- if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation ()))
56
- return SmallVector<int64_t >(transposeOp.getPermutation ());
57
- if (linalgOp.getNumParallelLoops () != linalgOp.getNumLoops ())
58
- return failure ();
59
-
60
- if (linalgOp.getNumDpsInputs () != 1 || linalgOp.getNumDpsInits () != 1 )
61
- return failure ();
62
- auto mapRange = linalgOp.getIndexingMapsArray ();
63
- if (!mapRange.front ().isPermutation () || !mapRange.back ().isPermutation () ||
64
- mapRange.front () == mapRange.back ()) {
65
- return failure ();
66
- }
67
- if (!llvm::hasSingleElement (linalgOp.getBlock ()->getOperations ()))
68
- return failure ();
69
- AffineMap outMap = mapRange.back ();
70
- AffineMap inMap = mapRange.front ();
71
- // To get the permutation, look at each output index and find which
72
- // dimension in the input we're reading from for that index.
73
- return llvm::map_to_vector (outMap.getResults (),
74
- [&](AffineExpr expr) -> int64_t {
75
- return *inMap.getResultPosition (expr);
76
- });
77
- }
78
-
79
51
// / Packing one-dimensional tensor can be expressed as an expand shape op.
80
52
struct SimplifyPackToExpandShape : public OpRewritePattern <PackOp> {
81
53
using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -274,10 +246,14 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
274
246
275
247
for (unsigned int i = 0 ; i < rank; ++i) {
276
248
int64_t remappedPosition = permutation[i];
277
- if (remappedPosition >= rank)
278
- return false ;
279
- if (!inVec.empty ())
249
+
250
+ if (!inVec.empty ()) {
251
+ if (remappedPosition >= rank) {
252
+ return false ;
253
+ }
280
254
remappedPosition = inVec[remappedPosition];
255
+ }
256
+
281
257
resVec.push_back (remappedPosition);
282
258
}
283
259
@@ -287,25 +263,20 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
287
263
// / Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
288
264
// / semantics.
289
265
struct FoldProducerPackWithConsumerLinalgTransposeOp
290
- : public OpInterfaceRewritePattern <linalg::LinalgOp > {
291
- using OpInterfaceRewritePattern <linalg::LinalgOp >::OpInterfaceRewritePattern ;
266
+ : public OpRewritePattern <linalg::TransposeOp > {
267
+ using OpRewritePattern <linalg::TransposeOp >::OpRewritePattern ;
292
268
293
- LogicalResult matchAndRewrite (linalg::LinalgOp linalgOp ,
269
+ LogicalResult matchAndRewrite (linalg::TransposeOp transposeOp ,
294
270
PatternRewriter &rewriter) const override {
295
- auto packOp = linalgOp-> getOperand (0 ).getDefiningOp <PackOp>();
271
+ auto packOp = transposeOp. getOperand (0 ).getDefiningOp <PackOp>();
296
272
297
273
if (!packOp)
298
274
return failure ();
299
275
300
- FailureOr<SmallVector<int64_t >> maybePerm =
301
- getTransposeOpPermutation (linalgOp);
302
- if (failed (maybePerm))
303
- return failure ();
304
-
305
276
auto innerDimsPos = packOp.getInnerDimsPos ();
306
277
auto mixedInnerTiles = packOp.getMixedTiles ();
307
278
auto outerDimsPerm = packOp.getOuterDimsPerm ();
308
- auto transposePerm = maybePerm. value ();
279
+ auto transposePerm = transposeOp. getPermutation ();
309
280
SmallVector<int64_t > newOuterDimsPermVec;
310
281
SmallVector<int64_t > newInnerDimsPosVec;
311
282
SmallVector<OpFoldResult> newMixedInnerTilesVec;
@@ -314,7 +285,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
314
285
if (!checkAndPermute (transposePerm, outerDimsPerm, newOuterDimsPermVec,
315
286
srcRank))
316
287
return rewriter.notifyMatchFailure (
317
- linalgOp ,
288
+ transposeOp ,
318
289
" Cannot fold in tensor.pack if a tile dimension was transposed "
319
290
" with a non-tile dimension in linalg.transpose." );
320
291
@@ -326,11 +297,11 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
326
297
}
327
298
328
299
Value output = packOp.createDestinationTensor (
329
- rewriter, linalgOp .getLoc (), packOp.getSource (), newMixedInnerTilesVec ,
330
- newInnerDimsPosVec, newOuterDimsPermVec);
300
+ rewriter, transposeOp .getLoc (), packOp.getSource (),
301
+ newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
331
302
332
303
rewriter.replaceOpWithNewOp <PackOp>(
333
- linalgOp , packOp.getSource (), output, newInnerDimsPosVec,
304
+ transposeOp , packOp.getSource (), output, newInnerDimsPosVec,
334
305
newMixedInnerTilesVec, packOp.getPaddingValue (), newOuterDimsPermVec);
335
306
336
307
return success ();
@@ -345,16 +316,12 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
345
316
346
317
LogicalResult matchAndRewrite (PackOp packOp,
347
318
PatternRewriter &rewriter) const override {
348
- auto linalgOp = packOp.getSource ().getDefiningOp <linalg::LinalgOp>();
349
- if (!linalgOp)
350
- return failure ();
319
+ auto transposeOp = packOp.getSource ().getDefiningOp <linalg::TransposeOp>();
351
320
352
- FailureOr<SmallVector<int64_t >> maybePerm =
353
- getTransposeOpPermutation (linalgOp);
354
- if (failed (maybePerm))
321
+ if (!transposeOp)
355
322
return failure ();
356
323
357
- auto transposePermutation = maybePerm. value ();
324
+ auto transposePermutation = transposeOp. getPermutation ();
358
325
auto outerDimsPerm = packOp.getOuterDimsPerm ();
359
326
auto innerDimsPos = packOp.getInnerDimsPos ();
360
327
SmallVector<int64_t > newInnerDimsPosVec;
@@ -370,11 +337,11 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
370
337
newInnerDimsPosVec.push_back (transposePermutation[dim]);
371
338
372
339
Value output = packOp.createDestinationTensor (
373
- rewriter, packOp.getLoc (), linalgOp-> getOperand (0 ),
340
+ rewriter, packOp.getLoc (), transposeOp. getOperand (0 ),
374
341
packOp.getMixedTiles (), newInnerDimsPosVec, newOuterDimsPermVec);
375
342
376
343
rewriter.replaceOpWithNewOp <PackOp>(
377
- packOp, linalgOp-> getOperand (0 ), output, newInnerDimsPosVec,
344
+ packOp, transposeOp. getOperand (0 ), output, newInnerDimsPosVec,
378
345
packOp.getMixedTiles (), packOp.getPaddingValue (), newOuterDimsPermVec);
379
346
380
347
return success ();
@@ -384,38 +351,34 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
384
351
// / Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
385
352
// / transpose semantics.
386
353
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
387
- : public OpInterfaceRewritePattern <linalg::LinalgOp > {
388
- using OpInterfaceRewritePattern <linalg::LinalgOp >::OpInterfaceRewritePattern ;
354
+ : public OpRewritePattern <linalg::TransposeOp > {
355
+ using OpRewritePattern <linalg::TransposeOp >::OpRewritePattern ;
389
356
390
- LogicalResult matchAndRewrite (linalg::LinalgOp linalgOp ,
357
+ LogicalResult matchAndRewrite (linalg::TransposeOp transposeOp ,
391
358
PatternRewriter &rewriter) const override {
392
- auto unPackOp = linalgOp-> getOperand (0 ).getDefiningOp <UnPackOp>();
359
+ auto unPackOp = transposeOp. getOperand (0 ).getDefiningOp <UnPackOp>();
393
360
394
361
if (!unPackOp)
395
362
return failure ();
396
363
397
- FailureOr<SmallVector<int64_t >> maybePerm =
398
- getTransposeOpPermutation (linalgOp);
399
- if (failed (maybePerm))
400
- return failure ();
401
-
364
+ auto transposePermutation = transposeOp.getPermutation ();
402
365
auto outerDimsPerm = unPackOp.getOuterDimsPerm ();
403
366
auto innerDimsPos = unPackOp.getInnerDimsPos ();
404
367
SmallVector<int64_t > newInnerDimsPosVec;
405
368
SmallVector<int64_t > newOuterDimsPermVec =
406
- invertPermutationVector (maybePerm.value ());
369
+ llvm::to_vector (transposePermutation);
370
+
371
+ if (!outerDimsPerm.empty ())
372
+ applyPermutationToVector (newOuterDimsPermVec, outerDimsPerm);
407
373
408
374
// Can't use applyPermutationToVector for newInnerDimsPosVec since input and
409
375
// permutation rank won't necessarily be equal in all cases.
410
376
for (auto dim : innerDimsPos)
411
- newInnerDimsPosVec.push_back (newOuterDimsPermVec[dim]);
412
-
413
- if (!outerDimsPerm.empty ())
414
- applyPermutationToVector (newOuterDimsPermVec, outerDimsPerm);
377
+ newInnerDimsPosVec.push_back (transposePermutation[dim]);
415
378
416
379
// Reuse the destination of the transpose op.
417
380
rewriter.replaceOpWithNewOp <UnPackOp>(
418
- linalgOp , unPackOp.getSource (), linalgOp .getDpsInits ()[0 ],
381
+ transposeOp , unPackOp.getSource (), transposeOp .getDpsInits ()[0 ],
419
382
newInnerDimsPosVec, unPackOp.getMixedTiles (), newOuterDimsPermVec);
420
383
421
384
return success ();
@@ -430,17 +393,13 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
430
393
431
394
LogicalResult matchAndRewrite (UnPackOp unPackOp,
432
395
PatternRewriter &rewriter) const override {
433
- auto linalgOp = unPackOp.getSource ().getDefiningOp <linalg::LinalgOp>();
434
- if (!linalgOp)
435
- return failure ();
396
+ auto transposeOp =
397
+ unPackOp.getSource ().getDefiningOp <linalg::TransposeOp>();
436
398
437
- FailureOr<SmallVector<int64_t >> maybePerm =
438
- getTransposeOpPermutation (linalgOp);
439
- if (failed (maybePerm))
399
+ if (!transposeOp)
440
400
return failure ();
441
401
442
- SmallVector<int64_t > inverseTransposePerm =
443
- invertPermutationVector (maybePerm.value ());
402
+ auto transposePermutation = transposeOp.getPermutation ();
444
403
auto outerDimsPerm = unPackOp.getOuterDimsPerm ();
445
404
auto innerDimsPos = unPackOp.getInnerDimsPos ();
446
405
int64_t destRank = unPackOp.getSourceRank () - innerDimsPos.size ();
@@ -449,26 +408,26 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
449
408
SmallVector<int64_t > newInnerDimsPosVec;
450
409
SmallVector<OpFoldResult> newMixedInnerTilesVec;
451
410
452
- if (!checkAndPermute (inverseTransposePerm , outerDimsPerm,
411
+ if (!checkAndPermute (transposePermutation , outerDimsPerm,
453
412
newOuterDimsPermVec, destRank))
454
413
return rewriter.notifyMatchFailure (
455
414
unPackOp,
456
415
" Cannot fold in tensor.unpack if a tile dimension was transposed "
457
416
" with a non-tile dimension in linalg.transpose." );
458
417
459
418
// Process transpose operation for tiled inner dimensions
460
- for (unsigned int i = destRank; i < inverseTransposePerm .size (); ++i) {
461
- int64_t remappedPosition = inverseTransposePerm [i] - destRank;
419
+ for (unsigned int i = destRank; i < transposePermutation .size (); ++i) {
420
+ int64_t remappedPosition = transposePermutation [i] - destRank;
462
421
newMixedInnerTilesVec.push_back (mixedInnerTilesVec[remappedPosition]);
463
422
newInnerDimsPosVec.push_back (innerDimsPos[remappedPosition]);
464
423
}
465
424
466
425
Value output = unPackOp.createDestinationTensor (
467
- rewriter, unPackOp.getLoc (), linalgOp-> getOperand (0 ),
426
+ rewriter, unPackOp.getLoc (), transposeOp. getOperand (0 ),
468
427
newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
469
428
470
429
rewriter.replaceOpWithNewOp <UnPackOp>(
471
- unPackOp, linalgOp-> getOperand (0 ), output, newInnerDimsPosVec,
430
+ unPackOp, transposeOp. getOperand (0 ), output, newInnerDimsPosVec,
472
431
newMixedInnerTilesVec, newOuterDimsPermVec);
473
432
474
433
return success ();
0 commit comments