@@ -369,7 +369,7 @@ struct Strategy<TransferReadOp> {
369
369
// / Retrieve the indices of the current StoreOp that stores into the buffer.
370
370
static void getBufferIndices (TransferReadOp xferOp,
371
371
SmallVector<Value, 8 > &indices) {
372
- memref::StoreOp storeOp = getStoreOp (xferOp);
372
+ auto storeOp = getStoreOp (xferOp);
373
373
auto prevIndices = memref::StoreOpAdaptor (storeOp).getIndices ();
374
374
indices.append (prevIndices.begin (), prevIndices.end ());
375
375
}
@@ -591,8 +591,8 @@ struct PrepareTransferReadConversion
591
591
if (checkPrepareXferOp (xferOp, options).failed ())
592
592
return failure ();
593
593
594
- BufferAllocs buffers = allocBuffers (rewriter, xferOp);
595
- Operation *newXfer = rewriter.clone (*xferOp.getOperation ());
594
+ auto buffers = allocBuffers (rewriter, xferOp);
595
+ auto *newXfer = rewriter.clone (*xferOp.getOperation ());
596
596
newXfer->setAttr (kPassLabel , rewriter.getUnitAttr ());
597
597
if (xferOp.getMask ()) {
598
598
dyn_cast<TransferReadOp>(newXfer).getMaskMutable ().assign (
@@ -885,7 +885,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
885
885
// If the xferOp has a mask: Find and cast mask buffer.
886
886
Value castedMaskBuffer;
887
887
if (xferOp.getMask ()) {
888
- Value maskBuffer = getMaskBuffer (xferOp);
888
+ auto maskBuffer = getMaskBuffer (xferOp);
889
+ auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType ());
889
890
if (xferOp.isBroadcastDim (0 ) || xferOp.getMaskType ().getRank () == 1 ) {
890
891
// Do not unpack a dimension of the mask, if:
891
892
// * To-be-unpacked transfer op dimension is a broadcast.
@@ -896,8 +897,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
896
897
} else {
897
898
// It's safe to assume the mask buffer can be unpacked if the data
898
899
// buffer was unpacked.
899
- auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType ());
900
- MemRefType castedMaskType = *unpackOneDim (maskBufferType);
900
+ auto castedMaskType = *unpackOneDim (maskBufferType);
901
901
castedMaskBuffer =
902
902
locB.create <vector::TypeCastOp>(castedMaskType, maskBuffer);
903
903
}
@@ -938,18 +938,11 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
938
938
b.setInsertionPoint (newXfer); // Insert load before newXfer.
939
939
940
940
SmallVector<Value, 8 > loadIndices;
941
- if ( auto memrefType =
942
- castedMaskBuffer. getType (). dyn_cast <MemRefType>()) {
943
- // If castedMaskBuffer is a memref, then one dim was
944
- // unpacked; see above.
941
+ Strategy<OpTy>:: getBufferIndices (xferOp, loadIndices);
942
+ // In case of broadcast: Use same indices to load from memref
943
+ // as before.
944
+ if (!xferOp. isBroadcastDim ( 0 ))
945
945
loadIndices.push_back (iv);
946
- } else {
947
- Strategy<OpTy>::getBufferIndices (xferOp, loadIndices);
948
- // In case of broadcast: Use same indices to load from
949
- // memref as before.
950
- if (!xferOp.isBroadcastDim (0 ))
951
- loadIndices.push_back (iv);
952
- }
953
946
954
947
auto mask = b.create <memref::LoadOp>(loc, castedMaskBuffer,
955
948
loadIndices);
0 commit comments