Skip to content

Commit 8063622

Browse files
authored
[mlir][vector] Allow vector distribution with multiple written elements (#75122)
Add a configuration option to allow vector distribution with multiple elements written by a single lane. This is so that we can perform vector multi-reduction with multiple results per workgroup.
1 parent 42e4967 commit 8063622

File tree

4 files changed

+123
-17
lines changed

4 files changed

+123
-17
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
4343
using DistributionMapFn = std::function<AffineMap(Value)>;
4444

4545
/// Distribute transfer_write ops based on the affine map returned by
46-
/// `distributionMapFn`.
46+
/// `distributionMapFn`. Writes of size more than `maxNumElementsToExtract`
47+
/// will not be distributed (this limit should be less than the warp size).
48+
///
4749
/// Example:
4850
/// ```
4951
/// %0 = vector.warp_execute_on_lane_0(%id){
@@ -67,7 +69,7 @@ using DistributionMapFn = std::function<AffineMap(Value)>;
6769
/// distribute, meaning writes should propagate first.
6870
void populateDistributeTransferWriteOpPatterns(
6971
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
70-
PatternBenefit benefit = 2);
72+
unsigned maxNumElementsToExtract, PatternBenefit benefit = 2);
7173

7274
/// Move scalar operations with no dependency on the warp op outside of the
7375
/// region.

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Interfaces/SideEffectInterfaces.h"
1717
#include "mlir/Transforms/RegionUtils.h"
1818
#include "llvm/ADT/SetVector.h"
19+
#include "llvm/Support/FormatVariadic.h"
1920
#include <numeric>
2021
#include <utility>
2122

@@ -458,7 +459,9 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
458459
}
459460

460461
/// Distribute transfer_write ops based on the affine map returned by
461-
/// `distributionMapFn`.
462+
/// `distributionMapFn`. Writes of size more than `maxNumElementsToExtract`
463+
/// will not be distributed (this limit should be less than the warp size).
464+
///
462465
/// Example:
463466
/// ```
464467
/// %0 = vector.warp_execute_on_lane_0(%id){
@@ -476,9 +479,10 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
476479
/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
477480
struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
478481
WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
479-
PatternBenefit b = 1)
482+
unsigned maxNumElementsToExtract, PatternBenefit b = 1)
480483
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
481-
distributionMapFn(std::move(fn)) {}
484+
distributionMapFn(std::move(fn)),
485+
maxNumElementsToExtract(maxNumElementsToExtract) {}
482486

483487
/// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
484488
/// are multiples of the distribution ratio are supported at the moment.
@@ -553,10 +557,13 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
553557
Location loc = writeOp.getLoc();
554558
VectorType vecType = writeOp.getVectorType();
555559

556-
// Only sink out vector of 1 element for now to not serialize large vector
557-
// store. This can later be controlled by user.
558-
if (vecType.getNumElements() != 1)
559-
return failure();
560+
if (vecType.getNumElements() > maxNumElementsToExtract) {
561+
return rewriter.notifyMatchFailure(
562+
warpOp,
563+
llvm::formatv(
564+
"writes more elements ({0}) than allowed to extract ({1})",
565+
vecType.getNumElements(), maxNumElementsToExtract));
566+
}
560567

561568
// Do not process warp ops that contain only TransferWriteOps.
562569
if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
@@ -616,6 +623,7 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
616623

617624
private:
618625
DistributionMapFn distributionMapFn;
626+
unsigned maxNumElementsToExtract = 1;
619627
};
620628

621629
/// Sink out elementwise op feeding into a warp op yield.
@@ -1833,9 +1841,9 @@ void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
18331841

18341842
void mlir::vector::populateDistributeTransferWriteOpPatterns(
18351843
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1836-
PatternBenefit benefit) {
1844+
unsigned maxNumElementsToExtract, PatternBenefit benefit) {
18371845
patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
1838-
benefit);
1846+
maxNumElementsToExtract, benefit);
18391847
}
18401848

18411849
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 95 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,20 @@
1-
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
2-
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
3-
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" | FileCheck --check-prefixes=CHECK-D %s
4-
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=propagate-distribution -canonicalize | FileCheck --check-prefixes=CHECK-PROP %s
5-
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" -canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
1+
// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
2+
// RUN: --test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
3+
4+
// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
5+
// RUN: --test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
6+
7+
// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
8+
// RUN: --test-vector-warp-distribute="hoist-uniform distribute-transfer-write max-transfer-write-elements=4" \
9+
// RUN: | FileCheck --check-prefixes=CHECK-D %s
10+
11+
// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
12+
// RUN: --test-vector-warp-distribute=propagate-distribution --canonicalize \
13+
// RUN: | FileCheck --check-prefixes=CHECK-PROP %s
14+
15+
// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
16+
// RUN: --test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" \
17+
// RUN: --canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
618

719
// CHECK-SCF-IF-DAG: #[[$TIMES2:.*]] = affine_map<()[s0] -> (s0 * 2)>
820
// CHECK-SCF-IF-DAG: #[[$TIMES4:.*]] = affine_map<()[s0] -> (s0 * 4)>
@@ -134,6 +146,84 @@ func.func @warp_extract(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : ind
134146

135147
// -----
136148

149+
// Check that we can distribute writes of the maximum allowed number of elements.
150+
151+
// CHECK-D-LABEL: func @warp_extract_4_elems(
152+
// CHECK-D: %[[WARPOP:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4x1xf32>)
153+
// CHECK-D: "test.dummy_op"
154+
// CHECK-D: "test.dummy_op"
155+
// CHECK-D: vector.yield %{{.*}}, %{{.*}} : vector<4xf32>, vector<4x1xf32>
156+
// CHECK-D: }
157+
// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
158+
// CHECK-D: vector.transfer_write %[[WARPOP]]#1, %{{.*}}[%{{.*}}] {{.*}} : vector<4x1xf32>
159+
// CHECK-D: }
160+
// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
161+
// CHECK-D: vector.transfer_write %[[WARPOP]]#0, %{{.*}}[%{{.*}}] {{.*}} : vector<4xf32>
162+
// CHECK-D: }
163+
164+
func.func @warp_extract_4_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
165+
vector.warp_execute_on_lane_0(%laneid)[32] {
166+
%c0 = arith.constant 0 : index
167+
%v = "test.dummy_op"() : () -> (vector<4xf32>)
168+
%v1 = "test.dummy_op"() : () -> (vector<4x1xf32>)
169+
vector.transfer_write %v1, %arg1[%c0, %c0] : vector<4x1xf32>, memref<1024x1024xf32>
170+
vector.transfer_write %v, %arg1[%c0, %c0] : vector<4xf32>, memref<1024x1024xf32>
171+
}
172+
return
173+
}
174+
175+
// -----
176+
177+
// Check that we do not distribute writes larger than the maximum allowed
178+
// number of elements.
179+
180+
// CHECK-D-LABEL: func @warp_extract_5_elems(
181+
// CHECK-D: arith.constant 0 : index
182+
// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
183+
// CHECK-D: %[[V:.+]] = "test.dummy_op"
184+
// CHECK-D: %[[V1:.+]] = "test.dummy_op"
185+
// CHECK-D: vector.transfer_write %[[V1]], %{{.*}}[%{{.*}}] {{.*}} : vector<5x1xf32>
186+
// CHECK-D: vector.transfer_write %[[V]], %{{.*}}[%{{.*}}] {{.*}} : vector<5xf32>
187+
// CHECK-D: }
188+
189+
func.func @warp_extract_5_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
190+
vector.warp_execute_on_lane_0(%laneid)[32] {
191+
%c0 = arith.constant 0 : index
192+
%v = "test.dummy_op"() : () -> (vector<5xf32>)
193+
%v1 = "test.dummy_op"() : () -> (vector<5x1xf32>)
194+
vector.transfer_write %v1, %arg1[%c0, %c0] : vector<5x1xf32>, memref<1024x1024xf32>
195+
vector.transfer_write %v, %arg1[%c0, %c0] : vector<5xf32>, memref<1024x1024xf32>
196+
}
197+
return
198+
}
199+
200+
// -----
201+
202+
// Check that we do not distribute writes larger than the maximum allowed
203+
// number of elements, even when the size is an exact multiple of that maximum.
204+
205+
// CHECK-D-LABEL: func @warp_extract_8_elems(
206+
// CHECK-D: arith.constant 0 : index
207+
// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
208+
// CHECK-D: %[[V:.+]] = "test.dummy_op"
209+
// CHECK-D: %[[V1:.+]] = "test.dummy_op"
210+
// CHECK-D: vector.transfer_write %[[V1]], %{{.*}}[%{{.*}}] {{.*}} : vector<8x1xf32>
211+
// CHECK-D: vector.transfer_write %[[V]], %{{.*}}[%{{.*}}] {{.*}} : vector<8xf32>
212+
// CHECK-D: }
213+
214+
func.func @warp_extract_8_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
215+
vector.warp_execute_on_lane_0(%laneid)[32] {
216+
%c0 = arith.constant 0 : index
217+
%v = "test.dummy_op"() : () -> (vector<8xf32>)
218+
%v1 = "test.dummy_op"() : () -> (vector<8x1xf32>)
219+
vector.transfer_write %v1, %arg1[%c0, %c0] : vector<8x1xf32>, memref<1024x1024xf32>
220+
vector.transfer_write %v, %arg1[%c0, %c0] : vector<8xf32>, memref<1024x1024xf32>
221+
}
222+
return
223+
}
224+
225+
// -----
226+
137227
// CHECK-PROP-LABEL: func @warp_dead_result(
138228
func.func @warp_dead_result(%laneid: index) -> (vector<1xf32>) {
139229
// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>)

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,11 @@ struct TestVectorDistribution
568568
llvm::cl::desc("Test distribution of transfer write"),
569569
llvm::cl::init(false)};
570570

571+
Option<unsigned> maxTransferWriteElements{
572+
*this, "max-transfer-write-elements",
573+
llvm::cl::desc("Maximum number of transfer write elements to distribute"),
574+
llvm::cl::init(1)};
575+
571576
Option<bool> hoistUniform{*this, "hoist-uniform",
572577
llvm::cl::desc("Test hoist uniform"),
573578
llvm::cl::init(false)};
@@ -624,7 +629,8 @@ struct TestVectorDistribution
624629
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
625630
} else if (distributeTransferWriteOps) {
626631
RewritePatternSet patterns(ctx);
627-
populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
632+
populateDistributeTransferWriteOpPatterns(patterns, distributionFn,
633+
maxTransferWriteElements);
628634
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
629635
} else if (propagateDistribution) {
630636
RewritePatternSet patterns(ctx);

0 commit comments

Comments
 (0)