+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -362,6 +364,106 @@ struct VectorSubgroupReduceToShuffles final
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
+
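+/// Builds a subgroup reduction of `input` from AMDGPU DPP lane-permutation
+/// ops: lanes 1, 2, 4, and 8 apart are combined within each 16-lane row, and
+/// rows are then combined via row broadcasts.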
+Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
+                                 gpu::AllReduceOperation mode,
+                                 const ClusterInfo &ci) {
+  Value result = input;
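+  // Combine lanes 1 apart within each 16-lane row (DPP row shift by 1).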
+  if (ci.clusterSize >= 2) {
+    auto permArg = b.getIntegerAttr(b.getIntegerType(32), 1);
+    Value dppResult =
+        b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
+                                amdgpu::DPPPerm::row_shl, permArg);
+    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+                                        result, dppResult);
+  }
+
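+  // Combine lanes 2 apart within each row (DPP row shift by 2).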
+  if (ci.clusterSize >= 4) {
+    auto permArg = b.getIntegerAttr(b.getIntegerType(32), 2);
+    Value dppResult =
+        b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
+                                amdgpu::DPPPerm::row_shl, permArg);
+    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+                                        result, dppResult);
+  }
+
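+  // Mirror lanes within each 8-lane half-row (lane i pairs with lane 7 - i),
+  // combining lanes 4 apart.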
+  if (ci.clusterSize >= 8) {
+    Value dppResult = b.create<amdgpu::DPPOp>(
+        loc, result.getType(), result, result, amdgpu::DPPPerm::row_half_mirror,
+        b.getUnitAttr());
+    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+                                        result, dppResult);
+  }
+
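+  // Mirror lanes within each full 16-lane row (lane i pairs with lane 15 - i),
+  // combining lanes 8 apart.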
+  if (ci.clusterSize >= 16) {
+    Value dppResult =
+        b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
+                                amdgpu::DPPPerm::row_mirror, b.getUnitAttr());
+    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+                                        result, dppResult);
+  }
+
+  const int allRows = 0xf;
+  const int allBanks = 0xf;
+  auto int32Type = IntegerType::get(b.getContext(), 32);
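+  // Broadcast each row's lane 15 into the next row (row mask 0xa writes only
+  // rows 1 and 3) and combine, yielding 32-lane partial results.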
+  if (ci.clusterSize >= 32) {
+    Value dppResult = b.create<amdgpu::DPPOp>(
+        loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_15,
+        b.getUnitAttr(), 0xa, allBanks, false);
+    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+                                        result, dppResult);
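+    // For a 32-wide subgroup this completes the reduction; read the value
+    // back from a single lane to make it uniform across the subgroup.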
+    if (ci.subgroupSize == 32) {
+      Value lane01 = b.create<LLVM::ConstantOp>(loc, int32Type, 1);
+      result =
+          b.create<ROCDL::ReadlaneOp>(loc, input.getType(), result, lane01);
+    }
+  }
+
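+  // Broadcast lane 31 into rows 2 and 3 and combine, then read the full
+  // 64-lane result back from lane 63 to make it subgroup-uniform.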
+  if (ci.clusterSize == 64) {
+    Value dppResult = b.create<amdgpu::DPPOp>(
+        loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_31,
+        b.getUnitAttr(), allRows, allBanks, false);
+    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+                                        result, dppResult);
+    Value lane63 = b.create<LLVM::ConstantOp>(loc, int32Type, 63);
+    result = b.create<ROCDL::ReadlaneOp>(loc, input.getType(), result, lane63);
+  }
+
+  assert(result.getType() == input.getType());
+  return result;
+}
+
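+/// Lowers scalar gpu.subgroup_reduce ops to a sequence of amdgpu.dpp ops via
+/// createSubgroupDPPReduction.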
+struct ScalarSubgroupReduceToDPP final
+    : OpRewritePattern<gpu::SubgroupReduceOp> {
+  ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize,
+                            bool matchClustered, PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+        matchClustered(matchClustered) {}
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("op is {0}clustered but pattern is configured to "
+                            "only match {1}clustered ops",
+                            matchClustered ? "non-" : "",
+                            matchClustered ? "" : "non-"));
+    }
+    auto ci = getAndValidateClusterInfo(op, subgroupSize);
+    if (failed(ci))
+      return failure();
+    Location loc = op.getLoc();
+    rewriter.replaceOp(op, createSubgroupDPPReduction(
+                               rewriter, loc, op.getValue(), op.getOp(), *ci));
+    return success();
+  }
+
+private:
+  unsigned subgroupSize = 0;
+  bool matchClustered = false;
+};
} // namespace

void mlir::populateGpuBreakDownSubgroupReducePatterns(
@@ -372,6 +474,13 @@ void mlir::populateGpuBreakDownSubgroupReducePatterns(
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

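+/// Populates `patterns` with the ScalarSubgroupReduceToDPP pattern defined
+/// above.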
+void mlir::populateGpuLowerSubgroupReduceToDPPPatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize,
+    PatternBenefit benefit) {
+  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
+                                          /*matchClustered=*/true, benefit);
+}
+
void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {