@@ -210,13 +210,21 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
 struct ScalarSubgroupReduceToShuffles final
     : OpRewritePattern<gpu::SubgroupReduceOp> {
   ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
-                                 unsigned shuffleBitwidth,
+                                 unsigned shuffleBitwidth, bool matchClustered,
                                  PatternBenefit benefit)
       : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
-        shuffleBitwidth(shuffleBitwidth) {}
+        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("op is {0}clustered but pattern is configured to "
+                            "only match {1}clustered ops",
+                            matchClustered ? "non-" : "",
+                            matchClustered ? "" : "non-"));
+    }
+
     auto ci = getAndValidateClusterInfo(op, subgroupSize);
     if (failed(ci))
       return failure();
@@ -262,19 +270,28 @@ struct ScalarSubgroupReduceToShuffles final
 private:
   unsigned subgroupSize = 0;
   unsigned shuffleBitwidth = 0;
+  bool matchClustered = false;
 };
 
 /// Lowers vector gpu subgroup reductions to a series of shuffles.
 struct VectorSubgroupReduceToShuffles final
     : OpRewritePattern<gpu::SubgroupReduceOp> {
   VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
-                                 unsigned shuffleBitwidth,
+                                 unsigned shuffleBitwidth, bool matchClustered,
                                  PatternBenefit benefit)
       : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
-        shuffleBitwidth(shuffleBitwidth) {}
+        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("op is {0}clustered but pattern is configured to "
+                            "only match {1}clustered ops",
+                            matchClustered ? "non-" : "",
+                            matchClustered ? "" : "non-"));
+    }
+
     auto ci = getAndValidateClusterInfo(op, subgroupSize);
     if (failed(ci))
       return failure();
@@ -343,6 +360,7 @@ struct VectorSubgroupReduceToShuffles final
 private:
   unsigned subgroupSize = 0;
   unsigned shuffleBitwidth = 0;
+  bool matchClustered = false;
 };
 } // namespace
 
@@ -358,5 +376,14 @@ void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
     RewritePatternSet &patterns, unsigned subgroupSize,
     unsigned shuffleBitwidth, PatternBenefit benefit) {
   patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
-      patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
+      patterns.getContext(), subgroupSize, shuffleBitwidth,
+      /*matchClustered=*/false, benefit);
+}
+
+void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize,
+    unsigned shuffleBitwidth, PatternBenefit benefit) {
+  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
+      patterns.getContext(), subgroupSize, shuffleBitwidth,
+      /*matchClustered=*/true, benefit);
 }
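
After this change, each pattern instance matches either clustered or non-clustered gpu.subgroup_reduce ops exclusively, so a pipeline that wants to lower both kinds must call both populate functions. A minimal usage sketch, not part of this commit; the helper function name and the 32-lane subgroup / 32-bit shuffle parameters are illustrative assumptions:

// Illustrative sketch: register both clustered and non-clustered lowerings
// into one pattern set and apply them greedily.
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper; assumes a 32-lane subgroup and 32-bit shuffles.
static LogicalResult lowerAllSubgroupReduces(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Non-clustered reductions (/*matchClustered=*/false internally).
  populateGpuLowerSubgroupReduceToShufflePatterns(
      patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32, PatternBenefit(1));
  // Clustered reductions (/*matchClustered=*/true internally).
  populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
      patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32, PatternBenefit(1));
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}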