@@ -867,6 +867,116 @@ struct LowerIllegalTransposeStoreViaZA
867
867
}
868
868
};
869
869
870
+ // / Lower `vector.transfer_read` of a scalable column to `scf::for`
871
+ // /
872
+ // / Lowers a "read" of a scalable column from a MemRef for which there is no
873
+ // / hardware operation that we could use, to a loop over the rows to read that
874
+ // / loads one element at a time.
875
+ // /
876
+ // / BEFORE:
877
+ // / ```
878
+ // / %res = vector.transfer_read %mem[%a, %b] (...)
879
+ // / : memref<?x?xf32>, vector<[4]x1xf32>
880
+ // / ```
881
+ // /
882
+ // / AFTER:
883
+ // / ```
884
+ // / %cst = arith.constant (...) : vector<[4]xf32>
885
+ // / %vscale = vector.vscale
886
+ // / %c4_vscale = arith.muli %vscale, %c4 : index
887
+ // / %scf = scf.for %arg3 = %c0 to %c4_vscale step %c1 iter_args(%arg4 = %cst)
888
+ // / -> (vector<[4]xf32>) {
889
+ // /
890
+ // / %load = memref.load %mem[%arg3 + %a, %b] : memref<?x?xf32>
891
+ // / %vec = vector.insert %load, %arg4 [%arg3] : f32 into vector<[4]xf32>
892
+ // / scf.yield %vec : vector<[4]xf32>
893
+ // / }
894
+ // / %res = vector.shape_cast %scf : vector<[4]xf32> to vector<[4]x1xf32>
895
+ // / ```
896
+ // /
897
+ // / TODO: This transformation isn't specific to SME - move it to the SVE
898
+ // / dialect.
899
+ // / TODO: Check the in_bounds attribute and generate vector.maskedload if
900
+ // / required.
901
+ struct LowerColumnTransferReadToLoops
902
+ : public OpRewritePattern<vector::TransferReadOp> {
903
+ using OpRewritePattern::OpRewritePattern;
904
+
905
+ LogicalResult matchAndRewrite (vector::TransferReadOp readOp,
906
+ PatternRewriter &rewriter) const override {
907
+ // NOTE: This is a fairly low-level transformation, so we shouldn't be
908
+ // adding support for Tensors without good rationale.
909
+ if (readOp.hasPureTensorSemantics ())
910
+ return rewriter.notifyMatchFailure (
911
+ readOp, " Tensor semantics are unsupported (either bufferize or "
912
+ " extend this pattern)" );
913
+
914
+ auto resType = readOp.getVectorType ();
915
+
916
+ if (resType.getRank () != 2 )
917
+ return rewriter.notifyMatchFailure (readOp,
918
+ " Only 2D vectors are supported!" );
919
+
920
+ if (resType.getShape ()[1 ] != 1 )
921
+ return rewriter.notifyMatchFailure (
922
+ readOp, " The trailing output dim is != 1 (not supported ATM)" );
923
+
924
+ if (!resType.getScalableDims ()[0 ] || resType.getScalableDims ()[1 ])
925
+ return rewriter.notifyMatchFailure (
926
+ readOp, " Expected the leading dim to be scalable and the trailing "
927
+ " dim to be fixed." );
928
+
929
+ // Create new result type - similar to the original vector with the
930
+ // trailing unit dim collapsed.
931
+ int64_t numRows = resType.getShape ()[0 ];
932
+ VectorType newResType = VectorType::get (numRows, resType.getElementType (),
933
+ /* scalableDims=*/ {true });
934
+
935
+ // Create a loop over all rows and load one element at a time.
936
+ auto loc = readOp.getLoc ();
937
+ auto lowerBound = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
938
+ auto createVscaleMultiple =
939
+ vector::makeVscaleConstantBuilder (rewriter, loc);
940
+ auto upperBound = createVscaleMultiple (numRows);
941
+ auto step = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
942
+ Value init = rewriter.create <arith::ConstantOp>(
943
+ loc, newResType, DenseElementsAttr::get (newResType, 0 .0f ));
944
+
945
+ scf::ForOp loadLoop;
946
+ {
947
+ OpBuilder::InsertionGuard g (rewriter);
948
+ loadLoop = rewriter.create <scf::ForOp>(loc, lowerBound, upperBound, step,
949
+ ValueRange{init});
950
+ rewriter.setInsertionPointToStart (loadLoop.getBody ());
951
+
952
+ auto tileSliceIndex = loadLoop.getInductionVar ();
953
+
954
+ auto idx0 = rewriter.create <arith::AddIOp>(loc, tileSliceIndex,
955
+ readOp.getIndices ()[0 ]);
956
+ auto idx1 = readOp.getIndices ()[1 ];
957
+
958
+ Value scalar = rewriter.create <memref::LoadOp>(
959
+ loc, readOp.getBase (), SmallVector<Value>({idx0, idx1}));
960
+
961
+ Operation *updateInit = rewriter.create <vector::InsertOp>(
962
+ loc, scalar, loadLoop.getRegionIterArg (0 ), tileSliceIndex);
963
+
964
+ rewriter.create <scf::YieldOp>(loc, updateInit->getResult (0 ));
965
+ }
966
+
967
+ // The read operation has been "legalized", but since the original result
968
+ // type was a 2D vector, we need to cast before returning the result. This
969
+ // ShapeCast should cancel-out with some other ShapeCast (i.e. it's a
970
+ // no-op).
971
+ auto sc = rewriter.create <vector::ShapeCastOp>(
972
+ loc, readOp.getResult ().getType (), loadLoop.getResult (0 ));
973
+
974
+ rewriter.replaceOp (readOp, sc);
975
+
976
+ return success ();
977
+ }
978
+ };
979
+
870
980
struct VectorLegalizationPass
871
981
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
872
982
void runOnOperation () override {
@@ -888,9 +998,10 @@ struct VectorLegalizationPass
888
998
889
999
// Apply preprocessing patterns.
890
1000
RewritePatternSet rewritePatterns (context);
891
- rewritePatterns.add <FoldExtractFromVectorOfSMELikeCreateMasks,
892
- LiftIllegalVectorTransposeToMemory,
893
- LowerIllegalTransposeStoreViaZA>(context);
1001
+ rewritePatterns
1002
+ .add <FoldExtractFromVectorOfSMELikeCreateMasks,
1003
+ LowerColumnTransferReadToLoops, LiftIllegalVectorTransposeToMemory,
1004
+ LowerIllegalTransposeStoreViaZA>(context);
894
1005
if (failed (
895
1006
applyPatternsGreedily (getOperation (), std::move (rewritePatterns))))
896
1007
return signalPassFailure ();
0 commit comments