Skip to content

Commit 3d01f0a

Browse files
authored
[mlir][gpu] Add 'cluster_stride' attribute to gpu.subgroup_reduce (#107142)
Follow-up to 7aa22f0, adding an additional attribute needed in some applications.
1 parent 2c3da17 commit 3d01f0a

File tree

6 files changed

+144
-61
lines changed

6 files changed

+144
-61
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,10 +1200,12 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
12001200
The `subgroup_reduce` op reduces the values of lanes (work items) across a
12011201
subgroup.
12021202

1203-
The subgroup is divided into clusters of `cluster_size` contiguous lanes
1204-
each, and a reduction is done for every lane of each cluster (in parallel).
1205-
The result is equal for all lanes in a cluster. When `cluster_size` is
1206-
omitted, there is a single cluster covering the entire subgroup.
1203+
The subgroup is divided into clusters starting at lane index 0. Within each
1204+
cluster, there are `size` lanes, and the lane index advances by `stride`.
1205+
A reduction is done for each cluster in parallel: every lane in the cluster
1206+
is reduced, and the result is equal for all lanes in the cluster. If `size`
1207+
is omitted, there is a single cluster covering the entire subgroup. If
1208+
`stride` is omitted, the stride is 1 (the cluster's lanes are contiguous).
12071209

12081210
When the reduced value is of a vector type, each vector element is reduced
12091211
independently. Only 1-d vector types are allowed.
@@ -1213,7 +1215,8 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
12131215
```mlir
12141216
%1 = gpu.subgroup_reduce add %a : (f32) -> f32
12151217
%2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> vector<4xf16>
1216-
%3 = gpu.subgroup_reduce add %c cluster_size(4) : (f32) -> f32
1218+
%3 = gpu.subgroup_reduce add %c cluster(size = 4) : (f32) -> f32
1219+
%4 = gpu.subgroup_reduce add %c cluster(size = 4, stride = 2) : (f32) -> f32
12171220
```
12181221

12191222
If `uniform` flag is set either none or all lanes of a subgroup need to execute
@@ -1230,27 +1233,38 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
12301233
AnyIntegerOrFloatOr1DVector:$value,
12311234
GPU_AllReduceOperationAttr:$op,
12321235
UnitAttr:$uniform,
1233-
OptionalAttr<I32Attr>:$cluster_size
1236+
OptionalAttr<I32Attr>:$cluster_size,
1237+
DefaultValuedAttr<I32Attr,"1">:$cluster_stride
12341238
);
12351239
let results = (outs AnyIntegerOrFloatOr1DVector:$result);
12361240

12371241
let builders = [
12381242
OpBuilder<(ins "Value":$value,
12391243
"::mlir::gpu::AllReduceOperation":$op,
12401244
"bool":$uniform), [{
1241-
build($_builder, $_state, value, op, uniform, /*cluster_size=*/ nullptr);
1245+
build($_builder, $_state, value, op, uniform, std::nullopt);
12421246
}]>,
12431247
OpBuilder<(ins "Value":$value,
12441248
"::mlir::gpu::AllReduceOperation":$op,
12451249
"bool":$uniform,
12461250
"std::optional<uint32_t>":$cluster_size), [{
1247-
build($_builder, $_state, value, op, uniform, cluster_size ? $_builder.getI32IntegerAttr(*cluster_size) : nullptr);
1251+
build($_builder, $_state, value, op, uniform,
1252+
cluster_size ? $_builder.getI32IntegerAttr(*cluster_size) : nullptr);
1253+
}]>,
1254+
OpBuilder<(ins "Value":$value,
1255+
"::mlir::gpu::AllReduceOperation":$op,
1256+
"bool":$uniform,
1257+
"std::optional<uint32_t>":$cluster_size,
1258+
"uint32_t":$cluster_stride), [{
1259+
build($_builder, $_state, value, op, uniform,
1260+
cluster_size ? $_builder.getI32IntegerAttr(*cluster_size) : nullptr,
1261+
cluster_stride);
12481262
}]>
12491263
];
12501264

12511265
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
12521266
(`uniform` $uniform^)?
1253-
(`cluster_size` `(` $cluster_size^ `)`)?
1267+
(`cluster` `(` `size` `=` $cluster_size^ (`,` `stride` `=` $cluster_stride^)? `)`)?
12541268
attr-dict
12551269
`:` functional-type(operands, results) }];
12561270

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -621,14 +621,25 @@ LogicalResult gpu::SubgroupReduceOp::verify() {
621621
<< getType();
622622
}
623623

624-
if (auto clusterSize = getClusterSize()) {
624+
auto clusterSize = getClusterSize();
625+
if (clusterSize) {
625626
uint32_t size = *clusterSize;
626627
if (!llvm::isPowerOf2_32(size)) {
627628
return emitOpError() << "cluster size " << size
628629
<< " is not a power of two";
629630
}
630631
}
631632

633+
uint32_t stride = getClusterStride();
634+
if (stride != 1 && !clusterSize) {
635+
return emitOpError() << "cluster stride can only be specified if cluster "
636+
"size is specified";
637+
}
638+
if (!llvm::isPowerOf2_32(stride)) {
639+
return emitOpError() << "cluster stride " << stride
640+
<< " is not a power of two";
641+
}
642+
632643
return success();
633644
}
634645

mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
5050

5151
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
5252
PatternRewriter &rewriter) const override {
53-
std::optional<uint32_t> clusterSize = op.getClusterSize();
54-
5553
auto vecTy = dyn_cast<VectorType>(op.getType());
5654
if (!vecTy || vecTy.getNumElements() < 2)
5755
return rewriter.notifyMatchFailure(op, "not a multi-element reduction");
@@ -97,7 +95,8 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
9795
}
9896

9997
Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
100-
loc, extracted, op.getOp(), op.getUniform(), clusterSize);
98+
loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
99+
op.getClusterStride());
101100
if (numElems == 1) {
102101
res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
103102
continue;
@@ -129,8 +128,6 @@ struct ScalarizeSingleElementReduce final
129128

130129
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
131130
PatternRewriter &rewriter) const override {
132-
std::optional<uint32_t> clusterSize = op.getClusterSize();
133-
134131
auto vecTy = dyn_cast<VectorType>(op.getType());
135132
if (!vecTy || vecTy.getNumElements() != 1)
136133
return rewriter.notifyMatchFailure(op, "not a single-element reduction");
@@ -140,34 +137,64 @@ struct ScalarizeSingleElementReduce final
140137
Location loc = op.getLoc();
141138
Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
142139
Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
143-
loc, extracted, op.getOp(), op.getUniform(), clusterSize);
140+
loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
141+
op.getClusterStride());
144142
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
145143
return success();
146144
}
147145
};
148146

147+
struct ClusterInfo {
148+
unsigned clusterStride;
149+
unsigned clusterSize;
150+
unsigned subgroupSize;
151+
};
152+
153+
static FailureOr<ClusterInfo>
154+
getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
155+
assert(llvm::isPowerOf2_32(subgroupSize));
156+
157+
std::optional<uint32_t> clusterSize = op.getClusterSize();
158+
assert(!clusterSize ||
159+
llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
160+
if (clusterSize && *clusterSize > subgroupSize)
161+
return op.emitOpError()
162+
<< "cluster size " << *clusterSize
163+
<< " is greater than subgroup size " << subgroupSize;
164+
unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
165+
166+
auto clusterStride = op.getClusterStride();
167+
assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
168+
if (clusterStride >= subgroupSize)
169+
return op.emitOpError()
170+
<< "cluster stride " << clusterStride
171+
<< " is not less than subgroup size " << subgroupSize;
172+
173+
return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
174+
}
175+
149176
/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
150177
/// and `unpackFn` to convert to the native shuffle type and to the reduction
151178
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
152179
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
153180
/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
154181
/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
155-
/// `clusterSize` lanes, reducing all lanes in each cluster in parallel.
156-
static Value createSubgroupShuffleReduction(
157-
OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
158-
unsigned clusterSize, unsigned subgroupSize,
159-
function_ref<Value(Value)> packFn, function_ref<Value(Value)> unpackFn) {
160-
assert(llvm::isPowerOf2_32(clusterSize));
161-
assert(llvm::isPowerOf2_32(subgroupSize));
162-
assert(clusterSize <= subgroupSize);
182+
/// `clusterSize` lanes starting at lane 0 with a stride of `clusterStride` for
183+
/// lanes within a cluster, reducing all lanes in each cluster in parallel.
184+
Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
185+
Value input, gpu::AllReduceOperation mode,
186+
const ClusterInfo &ci,
187+
function_ref<Value(Value)> packFn,
188+
function_ref<Value(Value)> unpackFn) {
163189
// Lane value always stays in the original type. We use it to perform arith
164190
// reductions.
165191
Value laneVal = input;
166192
// Parallel reduction using butterfly shuffles.
167-
for (unsigned i = 1; i < clusterSize; i <<= 1) {
193+
for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
194+
i <<= 1) {
168195
Value shuffled = builder
169196
.create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
170-
/*width=*/subgroupSize,
197+
/*width=*/ci.subgroupSize,
171198
/*mode=*/gpu::ShuffleMode::XOR)
172199
.getShuffleResult();
173200
laneVal = vector::makeArithReduction(builder, loc,
@@ -190,12 +217,9 @@ struct ScalarSubgroupReduceToShuffles final
190217

191218
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
192219
PatternRewriter &rewriter) const override {
193-
std::optional<uint32_t> clusterSize = op.getClusterSize();
194-
if (clusterSize && *clusterSize > subgroupSize)
195-
return op.emitOpError()
196-
<< "cluster size " << *clusterSize
197-
<< " is greater than subgroup size " << subgroupSize;
198-
unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
220+
auto ci = getAndValidateClusterInfo(op, subgroupSize);
221+
if (failed(ci))
222+
return failure();
199223

200224
Type valueTy = op.getType();
201225
unsigned elemBitwidth =
@@ -209,9 +233,8 @@ struct ScalarSubgroupReduceToShuffles final
209233
if (elemBitwidth == shuffleBitwidth) {
210234
auto identityFn = [](Value v) { return v; };
211235
rewriter.replaceOp(op, createSubgroupShuffleReduction(
212-
rewriter, loc, op.getValue(), op.getOp(),
213-
effectiveClusterSize, subgroupSize, identityFn,
214-
identityFn));
236+
rewriter, loc, op.getValue(), op.getOp(), *ci,
237+
identityFn, identityFn));
215238
return success();
216239
}
217240

@@ -232,8 +255,7 @@ struct ScalarSubgroupReduceToShuffles final
232255

233256
rewriter.replaceOp(
234257
op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
235-
op.getOp(), effectiveClusterSize,
236-
subgroupSize, packFn, unpackFn));
258+
op.getOp(), *ci, packFn, unpackFn));
237259
return success();
238260
}
239261

@@ -253,12 +275,9 @@ struct VectorSubgroupReduceToShuffles final
253275

254276
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
255277
PatternRewriter &rewriter) const override {
256-
std::optional<uint32_t> clusterSize = op.getClusterSize();
257-
if (clusterSize && *clusterSize > subgroupSize)
258-
return op.emitOpError()
259-
<< "cluster size " << *clusterSize
260-
<< " is greater than subgroup size " << subgroupSize;
261-
unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
278+
auto ci = getAndValidateClusterInfo(op, subgroupSize);
279+
if (failed(ci))
280+
return failure();
262281

263282
auto vecTy = dyn_cast<VectorType>(op.getType());
264283
if (!vecTy)
@@ -308,9 +327,8 @@ struct VectorSubgroupReduceToShuffles final
308327
return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
309328
};
310329

311-
Value res = createSubgroupShuffleReduction(rewriter, loc, extendedInput,
312-
op.getOp(), effectiveClusterSize,
313-
subgroupSize, packFn, unpackFn);
330+
Value res = createSubgroupShuffleReduction(
331+
rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);
314332

315333
if (vecBitwidth < shuffleBitwidth) {
316334
res = rewriter.create<vector::ExtractStridedSliceOp>(

mlir/test/Dialect/GPU/canonicalize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ func.func @subgroup_reduce_cluster_size_1() {
255255
gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %0#0, %arg7 = %0#1, %arg8 = %0#2)
256256
threads(%arg3, %arg4, %arg5) in (%arg9 = %0#3, %arg10 = %0#4, %arg11 = %0#5) {
257257
%1 = "test.test2"() : () -> i32
258-
%2 = gpu.subgroup_reduce add %1 cluster_size(1) : (i32) -> (i32)
258+
%2 = gpu.subgroup_reduce add %1 cluster(size=1) : (i32) -> (i32)
259259
"test.test3"(%2) : (i32) -> ()
260260
gpu.terminator
261261
}

mlir/test/Dialect/GPU/invalid.mlir

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,18 +335,35 @@ func.func @reduce_invalid_op_type_maximumf(%arg0 : i32) {
335335

336336
func.func @subgroup_reduce_zero_cluster_size(%arg0 : vector<4xf32>) {
337337
// expected-error@+1 {{cluster size 0 is not a power of two}}
338-
%res = gpu.subgroup_reduce add %arg0 cluster_size(0) : (vector<4xf32>) -> vector<4xf32>
338+
%res = gpu.subgroup_reduce add %arg0 cluster(size = 0) : (vector<4xf32>) -> vector<4xf32>
339339
return
340340
}
341341

342342
// -----
343343

344344
func.func @subgroup_reduce_npot_cluster_size(%arg0 : vector<4xf32>) {
345345
// expected-error@+1 {{cluster size 3 is not a power of two}}
346-
%res = gpu.subgroup_reduce add %arg0 cluster_size(3) : (vector<4xf32>) -> vector<4xf32>
346+
%res = gpu.subgroup_reduce add %arg0 cluster(size = 3) : (vector<4xf32>) -> vector<4xf32>
347347
return
348348
}
349349

350+
// -----
351+
352+
func.func @subgroup_reduce_zero_cluster_stride(%arg0 : vector<4xf32>) {
353+
// expected-error@+1 {{cluster stride 0 is not a power of two}}
354+
%res = gpu.subgroup_reduce add %arg0 cluster(size = 4, stride = 0) : (vector<4xf32>) -> vector<4xf32>
355+
return
356+
}
357+
358+
// -----
359+
360+
func.func @subgroup_reduce_cluster_stride_without_size(%arg0 : vector<4xf32>) {
361+
// expected-error@+1 {{cluster stride can only be specified if cluster size is specified}}
362+
%res = gpu.subgroup_reduce add %arg0 { cluster_stride = 2 : i32 } : (vector<4xf32>) -> vector<4xf32>
363+
return
364+
}
365+
366+
350367
// -----
351368

352369
func.func @subgroup_reduce_bad_type(%arg0 : vector<2x2xf32>) {

0 commit comments

Comments
 (0)