Skip to content

Commit 1244bca

Browse files
committed
[mlir][vector] Support distributing transfer op with permutation map
Differential Revision: https://reviews.llvm.org/D104263
1 parent c2e01ee commit 1244bca

File tree

2 files changed

+68
-13
lines changed

2 files changed

+68
-13
lines changed

mlir/lib/Dialect/Vector/VectorTransforms.cpp

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2842,6 +2842,20 @@ Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
28422842
return ops;
28432843
}
28442844

2845+
/// Converts TransferRead op used by ExtractMap op into a smaller dimension
2846+
/// TransferRead.
2847+
/// Example:
2848+
/// ```
2849+
/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
2850+
/// memref<64x64x64xf32>, vector<64x4x32xf32>
2851+
/// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
2852+
/// ```
2853+
/// to:
2854+
/// ```
2855+
/// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
2856+
/// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
2857+
/// memref<64x64x64xf32>, vector<2x4x1xf32>
2858+
/// ```
28452859
struct TransferReadExtractPattern
28462860
: public OpRewritePattern<vector::TransferReadOp> {
28472861
TransferReadExtractPattern(MLIRContext *context)
@@ -2858,18 +2872,23 @@ struct TransferReadExtractPattern
28582872
return failure();
28592873

28602874
SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
2861-
AffineMap map = extract.map();
2875+
AffineMap indexMap = extract.map().compose(read.permutation_map());
28622876
unsigned idCount = 0;
28632877
ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
2864-
for (auto expr : map.getResults()) {
2878+
for (auto it :
2879+
llvm::zip(indexMap.getResults(), extract.map().getResults())) {
28652880
AffineExpr d0, d1;
28662881
bindDims(read.getContext(), d0, d1);
2867-
unsigned pos = expr.cast<AffineDimExpr>().getPosition();
2882+
auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
2883+
if (!indexExpr)
2884+
continue;
2885+
unsigned indexPos = indexExpr.getPosition();
2886+
unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
28682887
auto scale = getAffineConstantExpr(
2869-
extract.getResultType().getDimSize(pos), read.getContext());
2870-
indices[pos] =
2871-
makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
2872-
{indices[pos], extract.ids()[idCount++]});
2888+
extract.getResultType().getDimSize(vectorPos), read.getContext());
2889+
indices[indexPos] = makeComposedAffineApply(
2890+
rewriter, read.getLoc(), d0 + scale * d1,
2891+
{indices[indexPos], extract.ids()[idCount++]});
28732892
}
28742893
Value newRead = lb.create<vector::TransferReadOp>(
28752894
extract.getType(), read.source(), indices, read.permutation_map(),
@@ -2895,18 +2914,24 @@ struct TransferWriteInsertPattern
28952914
return failure();
28962915
SmallVector<Value, 4> indices(write.indices().begin(),
28972916
write.indices().end());
2898-
AffineMap map = insert.map();
2917+
AffineMap indexMap = insert.map().compose(write.permutation_map());
28992918
unsigned idCount = 0;
29002919
Location loc = write.getLoc();
2901-
for (auto expr : map.getResults()) {
2920+
for (auto it :
2921+
llvm::zip(indexMap.getResults(), insert.map().getResults())) {
29022922
AffineExpr d0, d1;
29032923
bindDims(write.getContext(), d0, d1);
2904-
unsigned pos = expr.cast<AffineDimExpr>().getPosition();
2924+
auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
2925+
if (!indexExpr)
2926+
continue;
2927+
unsigned indexPos = indexExpr.getPosition();
2928+
unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
29052929
auto scale = getAffineConstantExpr(
2906-
insert.getSourceVectorType().getDimSize(pos), write.getContext());
2907-
indices[pos] =
2930+
insert.getSourceVectorType().getDimSize(vectorPos),
2931+
write.getContext());
2932+
indices[indexPos] =
29082933
makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
2909-
{indices[pos], insert.ids()[idCount++]});
2934+
{indices[indexPos], insert.ids()[idCount++]});
29102935
}
29112936
rewriter.create<vector::TransferWriteOp>(
29122937
loc, insert.vector(), write.source(), indices, write.permutation_map(),

mlir/test/Dialect/Vector/vector-distribution.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,4 +123,34 @@ func @vector_add_transfer_3d(%id0 : index, %id1 : index, %A: memref<64x64x64xf32
123123
return
124124
}
125125

126+
// -----
127+
128+
#map0 = affine_map<(d0, d1, d2, d3) -> (d3, 0, 0)>
129+
#map1 = affine_map<(d0, d1, d2, d3) -> (0, d3, d0)>
130+
#map2 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
131+
132+
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>
133+
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, 0, 0)>
134+
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (0, d3, d0)>
135+
// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
126136

137+
// CHECK: func @vector_add_transfer_permutation
138+
// CHECK-SAME: (%[[ID_0:.*]]: index, %[[ID_1:.*]]: index
139+
// CHECK: %[[C0:.*]] = constant 0 : index
140+
// CHECK: %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
141+
// CHECK-NEXT: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[ID2]]], %{{.*}} {permutation_map = #[[MAP1]]} : memref<?x?x?x?xf32>, vector<2x4x1xf32>
142+
// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID_0]], %[[C0]], %[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP2]]} : memref<?x?x?x?xf32>, vector<2x4x1xf32>
143+
// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32>
144+
// CHECK-NEXT: %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
145+
// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[C0]], %[[ID_1]], %[[C0]], %[[ID3]]] {permutation_map = #[[MAP3]]} : vector<2x4x1xf32>, memref<?x?x?x?xf32>
146+
// CHECK-NEXT: return
147+
func @vector_add_transfer_permutation(%id0 : index, %id1 : index, %A: memref<?x?x?x?xf32>,
148+
%B: memref<?x?x?x?xf32>, %C: memref<?x?x?x?xf32>) {
149+
%c0 = constant 0 : index
150+
%cf0 = constant 0.0 : f32
151+
%a = vector.transfer_read %A[%c0, %c0, %c0, %c0], %cf0 {permutation_map = #map0} : memref<?x?x?x?xf32>, vector<64x4x32xf32>
152+
%b = vector.transfer_read %B[%c0, %c0, %c0, %c0], %cf0 {permutation_map = #map1}: memref<?x?x?x?xf32>, vector<64x4x32xf32>
153+
%acc = addf %a, %b: vector<64x4x32xf32>
154+
vector.transfer_write %acc, %C[%c0, %c0, %c0, %c0] {permutation_map = #map2}: vector<64x4x32xf32>, memref<?x?x?x?xf32>
155+
return
156+
}

0 commit comments

Comments
 (0)