@@ -432,44 +432,50 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
432
432
/* bound_ctrl=*/ false );
433
433
res = vector::makeArithReduction (
434
434
rewriter, loc, gpu::convertReductionKind (mode), res, dpp);
435
- if (ci.subgroupSize == 32 ) {
436
- Value lane0 = rewriter.create <arith::ConstantOp>(
437
- loc, rewriter.getI32Type (), rewriter.getI32IntegerAttr (0 ));
438
- res =
439
- rewriter.create <ROCDL::ReadlaneOp>(loc, res.getType (), res, lane0);
440
- }
441
435
} else {
442
436
return rewriter.notifyMatchFailure (
443
437
op, " Subgroup reduce lowering to DPP not currently supported for "
444
438
" this device." );
445
439
}
440
+ if (ci.subgroupSize == 32 ) {
441
+ Value lane31 = rewriter.create <arith::ConstantOp>(
442
+ loc, rewriter.getI32Type (), rewriter.getI32IntegerAttr (31 ));
443
+ res = rewriter.create <ROCDL::ReadlaneOp>(loc, res.getType (), res, lane31);
444
+ }
446
445
}
447
446
if (ci.clusterSize >= 64 ) {
448
447
if (chipset.majorVersion <= 9 ) {
449
448
// Broadcast 31st lane value to rows 2 and 3.
450
- // Use row mask to avoid polluting rows 0 and 1.
451
449
dpp = rewriter.create <amdgpu::DPPOp>(
452
450
loc, res.getType (), res, res, amdgpu::DPPPerm::row_bcast_31,
453
- rewriter.getUnitAttr (), 0xc , allBanks,
454
- /* bound_ctrl*/ false );
451
+ rewriter.getUnitAttr (), 0xf , allBanks,
452
+ /* bound_ctrl*/ true );
453
+ res = vector::makeArithReduction (
454
+ rewriter, loc, gpu::convertReductionKind (mode), dpp, res);
455
+ // Obtain reduction from last rows, the previous rows are polluted.
456
+ Value lane63 = rewriter.create <arith::ConstantOp>(
457
+ loc, rewriter.getI32Type (), rewriter.getI32IntegerAttr (63 ));
458
+ res = rewriter.create <ROCDL::ReadlaneOp>(loc, res.getType (), res, lane63);
455
459
456
460
} else if (chipset.majorVersion <= 12 ) {
457
461
// Assume reduction across 32 lanes has been done.
458
462
// Perform final reduction manually by combining values in lane 31 and
459
463
// lane 63.
460
- Value lane0 = rewriter.create <arith::ConstantOp>(
461
- loc, rewriter.getI32Type (), rewriter.getI32IntegerAttr (0 ));
462
- Value lane32 = rewriter.create <arith::ConstantOp>(
463
- loc, rewriter.getI32Type (), rewriter.getI32IntegerAttr (32 ));
464
- dpp = rewriter.create <ROCDL::ReadlaneOp>(loc, res.getType (), res, lane32);
465
- res = rewriter.create <ROCDL::ReadlaneOp>(loc, res.getType (), res, lane0);
464
+ Value lane31 = rewriter.create <arith::ConstantOp>(
465
+ loc, rewriter.getI32Type (), rewriter.getI32IntegerAttr (31 ));
466
+ Value lane63 = rewriter.create <arith::ConstantOp>(
467
+ loc, rewriter.getI32Type (), rewriter.getI32IntegerAttr (63 ));
468
+ lane31 =
469
+ rewriter.create <ROCDL::ReadlaneOp>(loc, res.getType (), res, lane31);
470
+ lane63 =
471
+ rewriter.create <ROCDL::ReadlaneOp>(loc, res.getType (), res, lane63);
472
+ res = vector::makeArithReduction (
473
+ rewriter, loc, gpu::convertReductionKind (mode), lane31, lane63);
466
474
} else {
467
475
return rewriter.notifyMatchFailure (
468
476
op, " Subgroup reduce lowering to DPP not currently supported for "
469
477
" this device." );
470
478
}
471
- res = vector::makeArithReduction (rewriter, loc,
472
- gpu::convertReductionKind (mode), res, dpp);
473
479
}
474
480
assert (res.getType () == input.getType ());
475
481
return res;
0 commit comments