@@ -544,7 +544,7 @@ class FlattenContiguousRowMajorTransferReadPattern
544
544
auto loc = transferReadOp.getLoc ();
545
545
Value vector = transferReadOp.getVector ();
546
546
VectorType vectorType = cast<VectorType>(vector.getType ());
547
- Value source = transferReadOp.getSource ();
547
+ auto source = transferReadOp.getSource ();
548
548
MemRefType sourceType = dyn_cast<MemRefType>(source.getType ());
549
549
550
550
// 0. Check pre-conditions
@@ -602,26 +602,30 @@ class FlattenContiguousRowMajorTransferReadPattern
602
602
//
603
603
// For this example:
604
604
// %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
605
- // memref<1x43x2xi32>, vector<1x2xi32>
605
+ // memref<1x43x2xi32>, vector<1x2xi32>
606
606
// which would be collapsed to:
607
607
// %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
608
- // memref<1x86xi32>, vector<2xi32>
608
+ // memref<1x86xi32>, vector<2xi32>
609
609
// one would get the following offset:
610
610
// %offset = %arg0 * 43
611
+ AffineExpr offsetE, idx;
612
+ bindSymbols (rewriter.getContext (), offsetE, idx);
613
+
611
614
int64_t outputRank = transferReadOp.getIndices ().size ();
612
- Value offset = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
615
+ OpFoldResult offset =
616
+ rewriter.create <arith::ConstantIndexOp>(loc, 0 ).getResult ();
613
617
for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
614
- Value dimIdx = rewriter.create <arith::ConstantIndexOp>(loc, i);
615
- auto sourceDimSize =
616
- rewriter.create <memref::DimOp>(loc, source, dimIdx);
617
-
618
- offset = rewriter.create <arith::AddIOp>(
619
- loc,
620
- rewriter.create <arith::MulIOp>(loc, transferReadOp.getIndices ()[i],
621
- sourceDimSize),
622
- offset);
618
+ int64_t dim = dyn_cast<ShapedType>(source.getType ()).getDimSize (i);
619
+ offset = affine::makeComposedFoldedAffineApply (
620
+ rewriter, loc, offsetE + dim * idx,
621
+ {offset, transferReadOp.getIndices ()[i]});
622
+ }
623
+ if (offset.is <Value>()) {
624
+ collapsedIndices.push_back (offset.get <Value>());
625
+ } else {
626
+ collapsedIndices.push_back (rewriter.create <arith::ConstantIndexOp>(
627
+ loc, *getConstantIntValue (offset)));
623
628
}
624
- collapsedIndices.push_back (offset);
625
629
}
626
630
627
631
// 3. Create new vector.transfer_read that reads from the collapsed memref
0 commit comments