@@ -50,8 +50,6 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
-    std::optional<uint32_t> clusterSize = op.getClusterSize();
-
     auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
       return rewriter.notifyMatchFailure(op, "not a multi-element reduction");
@@ -97,7 +95,8 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
     }
 
     Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
-        loc, extracted, op.getOp(), op.getUniform(), clusterSize);
+        loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
+        op.getClusterStride());
     if (numElems == 1) {
       res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
       continue;
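
Note: each per-element reduction emitted above now forwards the op's cluster attributes unchanged. As a standalone illustration of the builder call this pattern relies on, here is a minimal sketch; the helper name `buildClusteredReduce` and the concrete attribute values are hypothetical, and the overload is assumed to match the call in the hunk above.

```cpp
// Hypothetical helper illustrating the builder overload used above:
// (loc, value, reduction kind, uniform flag, cluster size, cluster stride).
static Value buildClusteredReduce(PatternRewriter &rewriter, Location loc,
                                  Value scalar) {
  // Clusters of 8 lanes with stride 2: e.g., lanes {0, 2, 4, ..., 14} form
  // one cluster and lanes {1, 3, 5, ..., 15} another.
  return rewriter.create<gpu::SubgroupReduceOp>(
      loc, scalar, gpu::AllReduceOperation::ADD, /*uniform=*/false,
      /*clusterSize=*/8, /*clusterStride=*/2);
}
```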
@@ -129,8 +128,6 @@ struct ScalarizeSingleElementReduce final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
-    std::optional<uint32_t> clusterSize = op.getClusterSize();
-
     auto vecTy = dyn_cast<VectorType>(op.getType());
     if (!vecTy || vecTy.getNumElements() != 1)
       return rewriter.notifyMatchFailure(op, "not a single-element reduction");
@@ -140,34 +137,64 @@ struct ScalarizeSingleElementReduce final
     Location loc = op.getLoc();
     Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
     Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
-        loc, extracted, op.getOp(), op.getUniform(), clusterSize);
+        loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
+        op.getClusterStride());
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
     return success();
   }
 };
 
+struct ClusterInfo {
+  unsigned clusterStride;
+  unsigned clusterSize;
+  unsigned subgroupSize;
+};
+
+static FailureOr<ClusterInfo>
+getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
+  assert(llvm::isPowerOf2_32(subgroupSize));
+
+  std::optional<uint32_t> clusterSize = op.getClusterSize();
+  assert(!clusterSize ||
+         llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
+  if (clusterSize && *clusterSize > subgroupSize)
+    return op.emitOpError()
+           << "cluster size " << *clusterSize
+           << " is greater than subgroup size " << subgroupSize;
+  unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+
+  auto clusterStride = op.getClusterStride();
+  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
+  if (clusterStride >= subgroupSize)
+    return op.emitOpError()
+           << "cluster stride " << clusterStride
+           << " is not less than subgroup size " << subgroupSize;
+
+  return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
+}
+
 /// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
 /// and `unpackFn` to convert to the native shuffle type and to the reduction
 /// type, respectively. For example, with `input` of type `f16`, `packFn` could
 /// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
 /// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
 /// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
-/// `clusterSize` lanes, reducing all lanes in each cluster in parallel.
-static Value createSubgroupShuffleReduction(
-    OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
-    unsigned clusterSize, unsigned subgroupSize,
-    function_ref<Value(Value)> packFn, function_ref<Value(Value)> unpackFn) {
-  assert(llvm::isPowerOf2_32(clusterSize));
-  assert(llvm::isPowerOf2_32(subgroupSize));
-  assert(clusterSize <= subgroupSize);
+/// `clusterSize` lanes starting at lane 0 with a stride of `clusterStride` for
+/// lanes within a cluster, reducing all lanes in each cluster in parallel.
+Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
+                                     Value input, gpu::AllReduceOperation mode,
+                                     const ClusterInfo &ci,
+                                     function_ref<Value(Value)> packFn,
+                                     function_ref<Value(Value)> unpackFn) {
   // Lane value always stays in the original type. We use it to perform arith
   // reductions.
   Value laneVal = input;
   // Parallel reduction using butterfly shuffles.
-  for (unsigned i = 1; i < clusterSize; i <<= 1) {
+  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
+       i <<= 1) {
     Value shuffled = builder
                          .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
-                                                 /*width=*/subgroupSize,
+                                                 /*width=*/ci.subgroupSize,
                                                  /*mode=*/gpu::ShuffleMode::XOR)
                          .getShuffleResult();
     laneVal = vector::makeArithReduction(builder, loc,
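
To make the new loop bounds concrete: the butterfly offsets are now `clusterStride * 2^k` for `k = 0 .. log2(clusterSize) - 1`, so a strided cluster takes the same number of shuffle steps as a contiguous one, just with scaled offsets. A self-contained sketch (plain C++, no MLIR dependencies; names are illustrative only) that enumerates the XOR offsets the loop above would emit:

```cpp
#include <cstdio>
#include <vector>

// Mirrors the loop bounds in createSubgroupShuffleReduction: XOR offsets are
// clusterStride * 2^k for k = 0 .. log2(clusterSize) - 1.
static std::vector<unsigned> butterflyOffsets(unsigned clusterStride,
                                              unsigned clusterSize) {
  std::vector<unsigned> offsets;
  for (unsigned i = clusterStride; i < clusterStride * clusterSize; i <<= 1)
    offsets.push_back(i);
  return offsets;
}

int main() {
  // stride = 4, size = 4: offsets 4 and 8, so lane 0 combines with lanes
  // 4, 8, and 12 -- the cluster {0, 4, 8, 12}.
  for (unsigned off : butterflyOffsets(/*clusterStride=*/4, /*clusterSize=*/4))
    std::printf("xor offset: %u\n", off);
  return 0;
}
```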
@@ -190,12 +217,9 @@ struct ScalarSubgroupReduceToShuffles final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
-    std::optional<uint32_t> clusterSize = op.getClusterSize();
-    if (clusterSize && *clusterSize > subgroupSize)
-      return op.emitOpError()
-             << "cluster size " << *clusterSize
-             << " is greater than subgroup size " << subgroupSize;
-    unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+    auto ci = getAndValidateClusterInfo(op, subgroupSize);
+    if (failed(ci))
+      return failure();
 
     Type valueTy = op.getType();
     unsigned elemBitwidth =
@@ -209,9 +233,8 @@ struct ScalarSubgroupReduceToShuffles final
     if (elemBitwidth == shuffleBitwidth) {
       auto identityFn = [](Value v) { return v; };
       rewriter.replaceOp(op, createSubgroupShuffleReduction(
-                                 rewriter, loc, op.getValue(), op.getOp(),
-                                 effectiveClusterSize, subgroupSize, identityFn,
-                                 identityFn));
+                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
+                                 identityFn, identityFn));
       return success();
     }
 
@@ -232,8 +255,7 @@ struct ScalarSubgroupReduceToShuffles final
 
     rewriter.replaceOp(
         op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
-                                           op.getOp(), effectiveClusterSize,
-                                           subgroupSize, packFn, unpackFn));
+                                           op.getOp(), *ci, packFn, unpackFn));
     return success();
   }
 
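
The `packFn`/`unpackFn` pair used here is defined earlier in the file and elided from this hunk. Based on the doc comment on `createSubgroupShuffleReduction` (a 16-bit element shuffled through `i32`), here is a sketch of what such a pair could look like; the exact types and ops are assumptions, not the file's actual lambdas, and `rewriter`, `loc`, and `op` are assumed in scope as in the pattern above.

```cpp
// Sketch, assuming a 16-bit element type shuffled through i32 lanes.
Type i16Ty = rewriter.getIntegerType(16);
Type i32Ty = rewriter.getIntegerType(32);
auto packFn = [&](Value v) -> Value {
  // Bitcast f16 -> i16, then zero-extend to the i32 shuffle width.
  Value asInt = rewriter.create<arith::BitcastOp>(loc, i16Ty, v);
  return rewriter.create<arith::ExtUIOp>(loc, i32Ty, asInt);
};
auto unpackFn = [&](Value v) -> Value {
  // Truncate back to i16, then bitcast to the original f16 for arith.
  Value asInt = rewriter.create<arith::TruncIOp>(loc, i16Ty, v);
  return rewriter.create<arith::BitcastOp>(loc, op.getType(), asInt);
};
```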
@@ -253,12 +275,9 @@ struct VectorSubgroupReduceToShuffles final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
-    std::optional<uint32_t> clusterSize = op.getClusterSize();
-    if (clusterSize && *clusterSize > subgroupSize)
-      return op.emitOpError()
-             << "cluster size " << *clusterSize
-             << " is greater than subgroup size " << subgroupSize;
-    unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+    auto ci = getAndValidateClusterInfo(op, subgroupSize);
+    if (failed(ci))
+      return failure();
 
     auto vecTy = dyn_cast<VectorType>(op.getType());
     if (!vecTy)
@@ -308,9 +327,8 @@ struct VectorSubgroupReduceToShuffles final
       return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
     };
 
-    Value res = createSubgroupShuffleReduction(rewriter, loc, extendedInput,
-                                               op.getOp(), effectiveClusterSize,
-                                               subgroupSize, packFn, unpackFn);
+    Value res = createSubgroupShuffleReduction(
+        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);
 
     if (vecBitwidth < shuffleBitwidth) {
       res = rewriter.create<vector::ExtractStridedSliceOp>(