@@ -2842,6 +2842,20 @@ Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
2842
2842
return ops;
2843
2843
}
2844
2844
2845
+ /// Converts a TransferRead op used by an ExtractMap op into a TransferRead
2846
+ /// of smaller dimensions.
2847
+ /// Example:
2848
+ /// ```
2849
+ /// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0 :
2850
+ ///   memref<64x64x64xf32>, vector<64x4x32xf32>
2851
+ /// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
2852
+ /// ```
2853
+ /// to:
2854
+ /// ```
2855
+ /// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
2856
+ /// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
2857
+ ///   memref<64x64x64xf32>, vector<2x4x1xf32>
2858
+ /// ```
2845
2859
struct TransferReadExtractPattern
2846
2860
: public OpRewritePattern<vector::TransferReadOp> {
2847
2861
TransferReadExtractPattern (MLIRContext *context)
@@ -2858,18 +2872,23 @@ struct TransferReadExtractPattern
2858
2872
return failure ();
2859
2873
2860
2874
SmallVector<Value, 4 > indices (read.indices ().begin (), read.indices ().end ());
2861
- AffineMap map = extract.map ();
2875
+ AffineMap indexMap = extract.map (). compose (read. permutation_map () );
2862
2876
unsigned idCount = 0 ;
2863
2877
ImplicitLocOpBuilder lb (read.getLoc (), rewriter);
2864
- for (auto expr : map.getResults ()) {
2878
+ for (auto it :
2879
+ llvm::zip (indexMap.getResults (), extract.map ().getResults ())) {
2865
2880
AffineExpr d0, d1;
2866
2881
bindDims (read.getContext (), d0, d1);
2867
- unsigned pos = expr.cast <AffineDimExpr>().getPosition ();
2882
+ auto indexExpr = std::get<0 >(it).dyn_cast <AffineDimExpr>();
2883
+ if (!indexExpr)
2884
+ continue ;
2885
+ unsigned indexPos = indexExpr.getPosition ();
2886
+ unsigned vectorPos = std::get<1 >(it).cast <AffineDimExpr>().getPosition ();
2868
2887
auto scale = getAffineConstantExpr (
2869
- extract.getResultType ().getDimSize (pos ), read.getContext ());
2870
- indices[pos ] =
2871
- makeComposedAffineApply ( rewriter, read.getLoc (), d0 + scale * d1,
2872
- {indices[pos ], extract.ids ()[idCount++]});
2888
+ extract.getResultType ().getDimSize (vectorPos ), read.getContext ());
2889
+ indices[indexPos ] = makeComposedAffineApply (
2890
+ rewriter, read.getLoc (), d0 + scale * d1,
2891
+ {indices[indexPos ], extract.ids ()[idCount++]});
2873
2892
}
2874
2893
Value newRead = lb.create <vector::TransferReadOp>(
2875
2894
extract.getType (), read.source (), indices, read.permutation_map (),
@@ -2895,18 +2914,24 @@ struct TransferWriteInsertPattern
2895
2914
return failure ();
2896
2915
SmallVector<Value, 4 > indices (write.indices ().begin (),
2897
2916
write.indices ().end ());
2898
- AffineMap map = insert.map ();
2917
+ AffineMap indexMap = insert.map (). compose (write. permutation_map () );
2899
2918
unsigned idCount = 0 ;
2900
2919
Location loc = write.getLoc ();
2901
- for (auto expr : map.getResults ()) {
2920
+ for (auto it :
2921
+ llvm::zip (indexMap.getResults (), insert.map ().getResults ())) {
2902
2922
AffineExpr d0, d1;
2903
2923
bindDims (write.getContext (), d0, d1);
2904
- unsigned pos = expr.cast <AffineDimExpr>().getPosition ();
2924
+ auto indexExpr = std::get<0 >(it).dyn_cast <AffineDimExpr>();
2925
+ if (!indexExpr)
2926
+ continue ;
2927
+ unsigned indexPos = indexExpr.getPosition ();
2928
+ unsigned vectorPos = std::get<1 >(it).cast <AffineDimExpr>().getPosition ();
2905
2929
auto scale = getAffineConstantExpr (
2906
- insert.getSourceVectorType ().getDimSize (pos), write.getContext ());
2907
- indices[pos] =
2930
+ insert.getSourceVectorType ().getDimSize (vectorPos),
2931
+ write.getContext ());
2932
+ indices[indexPos] =
2908
2933
makeComposedAffineApply (rewriter, loc, d0 + scale * d1,
2909
- {indices[pos ], insert.ids ()[idCount++]});
2934
+ {indices[indexPos ], insert.ids ()[idCount++]});
2910
2935
}
2911
2936
rewriter.create <vector::TransferWriteOp>(
2912
2937
loc, insert.vector (), write.source (), indices, write.permutation_map (),
0 commit comments