@@ -724,59 +724,6 @@ struct LiftIllegalVectorTransposeToMemory
724
724
}
725
725
};
726
726
727
- // / A rewrite to turn unit dim transpose-like vector.shape_casts into
728
- // / vector.transposes. The shape_cast has to be from an illegal vector type to a
729
- // / legal one (as defined by isLegalVectorType).
730
- // /
731
- // / The reasoning for this is if we've got to this pass and we still have
732
- // / shape_casts of illegal types, then they likely will not cancel out. Turning
733
- // / them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
734
- // / eliminate them.
735
- // /
736
- // / Example:
737
- // /
738
- // / BEFORE:
739
- // / ```mlir
740
- // / %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
741
- // / ```
742
- // /
743
- // / AFTER:
744
- // / ```mlir
745
- // / %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
746
- // / ```
747
- struct ConvertIllegalShapeCastOpsToTransposes
748
- : public OpRewritePattern<vector::ShapeCastOp> {
749
- using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
750
-
751
- LogicalResult matchAndRewrite (vector::ShapeCastOp shapeCastOp,
752
- PatternRewriter &rewriter) const override {
753
- auto sourceType = shapeCastOp.getSourceVectorType ();
754
- auto resultType = shapeCastOp.getResultVectorType ();
755
- if (isLegalVectorType (sourceType) || !isLegalVectorType (resultType))
756
- return rewriter.notifyMatchFailure (shapeCastOp,
757
- kMatchFailureNotIllegalToLegal );
758
-
759
- // Note: If we know that `sourceType` is an illegal vector type (and 2D)
760
- // then dim 0 is scalable and dim 1 is fixed.
761
- if (sourceType.getRank () != 2 || sourceType.getDimSize (1 ) != 1 )
762
- return rewriter.notifyMatchFailure (
763
- shapeCastOp, " expected source to be a 2D scalable vector with a "
764
- " trailing unit dim" );
765
-
766
- auto loc = shapeCastOp.getLoc ();
767
- auto transpose = rewriter.create <vector::TransposeOp>(
768
- loc, shapeCastOp.getSource (), ArrayRef<int64_t >{1 , 0 });
769
-
770
- if (resultType.getRank () == 1 )
771
- rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(shapeCastOp, resultType,
772
- transpose);
773
- else
774
- rewriter.replaceOp (shapeCastOp, transpose);
775
-
776
- return success ();
777
- }
778
- };
779
-
780
727
// / Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
781
728
// / the ZA state. This workaround rewrite to support these transposes when ZA is
782
729
// / available.
@@ -920,6 +867,116 @@ struct LowerIllegalTransposeStoreViaZA
920
867
}
921
868
};
922
869
870
+ // / Lower `vector.transfer_read` of a scalable column to `scf::for`
871
+ // /
872
+ // / Lowers a "read" of a scalable column from a MemRef for which there is no
873
+ // / hardware pperation that we could use to a loop over the rows to read and
874
+ // / loads one element at a time.
875
+ // /
876
+ // / BEFORE:
877
+ // / ```
878
+ // / %res = vector.transfer_read %mem[%a, %b] (...)
879
+ // / : memref<?x?xf32>, vector<[4]x1xf32>
880
+ // / ```
881
+ // /
882
+ // / AFTER:
883
+ // / ```
884
+ // / %cst = arith.constant (...) : vector<[4]xf32>
885
+ // / %vscale = vector.vscale
886
+ // / %c4_vscale = arith.muli %vscale, %c4 : index
887
+ // / %scf = scf.for %lb = %c0 to %c4_vscale step %c1 iter_args(%arg4 = %cst)
888
+ // / -> (vector<[4]xf32>) {
889
+ // /
890
+ // / %load = memref.load %mem[%arg3 + %a, %b] : memref<?x?xf32>
891
+ // / %vec = vector.insert %load, %cst [%arg3] : f32 into vector<[4]xf32>
892
+ // / scf.yield %vec : vector<[4]xf32>
893
+ // / }
894
+ // / %res = vector.shape_cast %scf : vector<[4]xf32> to vector<[4]x1xf32>
895
+ // / ```
896
+ // /
897
+ // / TODO: This transformation isn't specific to SME - move it to the SVE
898
+ // / dialect.
899
+ // / TODO: Check the in_bounds attribute and generate vector.maskedload if
900
+ // / required.
901
+ struct LowerColumnTransferReadToLoops
902
+ : public OpRewritePattern<vector::TransferReadOp> {
903
+ using OpRewritePattern::OpRewritePattern;
904
+
905
+ LogicalResult matchAndRewrite (vector::TransferReadOp readOp,
906
+ PatternRewriter &rewriter) const override {
907
+ // NOTE: This is a fairly low-level transformation, so we shouldn't be
908
+ // adding support for Tensors without good rationale.
909
+ if (readOp.hasPureTensorSemantics ())
910
+ return rewriter.notifyMatchFailure (
911
+ readOp, " Tensor semantics are unsupported (either bufferize or "
912
+ " extend this pattern)" );
913
+
914
+ auto resType = readOp.getVectorType ();
915
+
916
+ if (resType.getRank () != 2 )
917
+ return rewriter.notifyMatchFailure (readOp,
918
+ " Only 2D vectors are supported!" );
919
+
920
+ if (resType.getShape ()[1 ] != 1 )
921
+ return rewriter.notifyMatchFailure (
922
+ readOp, " The trailing output dim is != 1 (not supported ATM)" );
923
+
924
+ if (!resType.getScalableDims ()[0 ] || resType.getScalableDims ()[1 ])
925
+ return rewriter.notifyMatchFailure (
926
+ readOp, " Expected the leading dim to be scalable and the trailing "
927
+ " dim to be fixed." );
928
+
929
+ // Create new result type - similar to the original vector with the
930
+ // trailing unit dim collapsed.
931
+ int64_t numRows = resType.getShape ()[0 ];
932
+ VectorType newResType = VectorType::get (numRows, resType.getElementType (),
933
+ /* scalableDims=*/ {true });
934
+
935
+ // Create a loop over all rows and load one element at a time.
936
+ auto loc = readOp.getLoc ();
937
+ auto lowerBound = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
938
+ auto createVscaleMultiple =
939
+ vector::makeVscaleConstantBuilder (rewriter, loc);
940
+ auto upperBound = createVscaleMultiple (numRows);
941
+ auto step = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
942
+ Value init = rewriter.create <arith::ConstantOp>(
943
+ loc, newResType, DenseElementsAttr::get (newResType, 0 .0f ));
944
+
945
+ scf::ForOp loadLoop;
946
+ {
947
+ OpBuilder::InsertionGuard g (rewriter);
948
+ loadLoop = rewriter.create <scf::ForOp>(loc, lowerBound, upperBound, step,
949
+ ValueRange{init});
950
+ rewriter.setInsertionPointToStart (loadLoop.getBody ());
951
+
952
+ auto tileSliceIndex = loadLoop.getInductionVar ();
953
+
954
+ auto idx0 = rewriter.create <arith::AddIOp>(loc, tileSliceIndex,
955
+ readOp.getIndices ()[0 ]);
956
+ auto idx1 = readOp.getIndices ()[1 ];
957
+
958
+ Value scalar = rewriter.create <memref::LoadOp>(
959
+ loc, readOp.getBase (), SmallVector<Value>({idx0, idx1}));
960
+
961
+ Operation *updateInit = rewriter.create <vector::InsertOp>(
962
+ loc, scalar, loadLoop.getRegionIterArg (0 ), tileSliceIndex);
963
+
964
+ rewriter.create <scf::YieldOp>(loc, updateInit->getResult (0 ));
965
+ }
966
+
967
+ // The read operation has been "legalized", but since the original result
968
+ // type was a 2D vector, we need to cast before returning the result. This
969
+ // ShapeCast should cancel-out with some other ShapeCast (i.e. it's a
970
+ // no-op).
971
+ auto sc = rewriter.create <vector::ShapeCastOp>(
972
+ loc, readOp.getResult ().getType (), loadLoop.getResult (0 ));
973
+
974
+ rewriter.replaceOp (readOp, sc);
975
+
976
+ return success ();
977
+ }
978
+ };
979
+
923
980
struct VectorLegalizationPass
924
981
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
925
982
void runOnOperation () override {
@@ -941,10 +998,10 @@ struct VectorLegalizationPass
941
998
942
999
// Apply preprocessing patterns.
943
1000
RewritePatternSet rewritePatterns (context);
944
- rewritePatterns. add <FoldExtractFromVectorOfSMELikeCreateMasks,
945
- LiftIllegalVectorTransposeToMemory ,
946
- ConvertIllegalShapeCastOpsToTransposes ,
947
- LowerIllegalTransposeStoreViaZA>(context);
1001
+ rewritePatterns
1002
+ . add <FoldExtractFromVectorOfSMELikeCreateMasks ,
1003
+ LowerColumnTransferReadToLoops, LiftIllegalVectorTransposeToMemory ,
1004
+ LowerIllegalTransposeStoreViaZA>(context);
948
1005
if (failed (
949
1006
applyPatternsGreedily (getOperation (), std::move (rewritePatterns))))
950
1007
return signalPassFailure ();
0 commit comments