@@ -210,13 +210,24 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
 struct ScalarSubgroupReduceToShuffles final
     : OpRewritePattern<gpu::SubgroupReduceOp> {
   ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
-                                 unsigned shuffleBitwidth,
+                                 unsigned shuffleBitwidth, bool matchClustered,
                                  PatternBenefit benefit)
       : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
-        shuffleBitwidth(shuffleBitwidth) {}
+        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      if (matchClustered)
+        return rewriter.notifyMatchFailure(
+            op, "op is non-clustered but pattern is configured to only match "
+                "clustered ops");
+      else
+        return rewriter.notifyMatchFailure(
+            op, "op is clustered but pattern is configured to only match "
+                "non-clustered ops");
+    }
+
     auto ci = getAndValidateClusterInfo(op, subgroupSize);
     if (failed(ci))
       return failure();
@@ -262,19 +273,31 @@ struct ScalarSubgroupReduceToShuffles final
 private:
   unsigned subgroupSize = 0;
   unsigned shuffleBitwidth = 0;
+  bool matchClustered;
 };
 
 /// Lowers vector gpu subgroup reductions to a series of shuffles.
 struct VectorSubgroupReduceToShuffles final
     : OpRewritePattern<gpu::SubgroupReduceOp> {
   VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
-                                 unsigned shuffleBitwidth,
+                                 unsigned shuffleBitwidth, bool matchClustered,
                                  PatternBenefit benefit)
       : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
-        shuffleBitwidth(shuffleBitwidth) {}
+        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      if (matchClustered)
+        return rewriter.notifyMatchFailure(
+            op, "op is non-clustered but pattern is configured to only match "
+                "clustered ops");
+      else
+        return rewriter.notifyMatchFailure(
+            op, "op is clustered but pattern is configured to only match "
+                "non-clustered ops");
+    }
+
     auto ci = getAndValidateClusterInfo(op, subgroupSize);
     if (failed(ci))
       return failure();
@@ -343,6 +366,7 @@ struct VectorSubgroupReduceToShuffles final
 private:
   unsigned subgroupSize = 0;
   unsigned shuffleBitwidth = 0;
+  bool matchClustered;
 };
 } // namespace
 
@@ -358,5 +382,14 @@ void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
     RewritePatternSet &patterns, unsigned subgroupSize,
     unsigned shuffleBitwidth, PatternBenefit benefit) {
   patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
-      patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
+      patterns.getContext(), subgroupSize, shuffleBitwidth,
+      /*matchClustered=*/false, benefit);
+}
+
+void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize,
+    unsigned shuffleBitwidth, PatternBenefit benefit) {
+  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
+      patterns.getContext(), subgroupSize, shuffleBitwidth,
+      /*matchClustered=*/true, benefit);
 }
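
For context, a minimal sketch of how the new clustered entry point might be driven from a pass. This is illustrative only: the wrapper function name, the subgroup size of 32, and the shuffle bitwidth of 32 are assumptions, and the greedy driver is just one way to apply these patterns.

```cpp
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: collects the clustered subgroup-reduce lowering
// patterns and applies them greedily to everything under `root`.
static LogicalResult lowerClusteredSubgroupReduces(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // subgroupSize/shuffleBitwidth are placeholder values here; a real pass
  // would derive them from the target.
  populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
      patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32, /*benefit=*/1);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```

Because the scalar and vector patterns now only match when `op.getClusterSize().has_value() == matchClustered`, the clustered and non-clustered populate functions can be registered independently without the two pattern sets stepping on each other.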