Skip to content

Commit 13e55b4

Browse files
Fixing permlanex16 intrinsic failure
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 25eed29 commit 13e55b4

File tree

2 files changed

+3
-8
lines changed

2 files changed

+3
-8
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -670,13 +670,13 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
670670

671671
// PermLaneX16 intrinsic operation
672672
def ROCDL_PermlaneX16Op : ROCDL_IntrOp<"permlanex16", [], [0],
673-
[AllTypesMatch<["res", "old", "src0", "src1", "src2"]>], 1, 0, 0,
673+
[AllTypesMatch<["res", "old", "src0"]>, AllTypesMatch<["src1", "src2"]>], 1, 0, 0,
674674
[4, 5], ["fi", "boundControl"]>,
675675
Arguments<(ins LLVM_Type:$old, LLVM_Type:$src0, LLVM_Type:$src1, LLVM_Type:$src2,
676676
I1Attr:$fi, I1Attr:$boundControl)> {
677677
let results = (outs LLVM_Type:$res);
678678
let assemblyFormat = [{
679-
attr-dict $old `,` $src0 `,` $src1 `,` $src2 `,` $fi `,` $boundControl `:` type($src0)
679+
attr-dict $old `,` $src0 `,` $src1 `,` $src2 `,` $fi `,` $boundControl `:` type($src0) `,` type($src1)
680680
}];
681681
let description = [{
682682
Performs a `permlanex16` operation with the given operands, applying the

mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -405,14 +405,9 @@ Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
405405

406406
const int allRows = 0xf;
407407
const int allBanks = 0xf;
408-
auto uint32Type = b.getIntegerType(32, false);
409408
if (ci.clusterSize >= 32) {
410-
// auto permArg = b.getI32IntegerAttr(15);
411-
// Value dppResult = b.create<amdgpu::DPPOp>(
412-
// loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_15,
413-
// b.getUnitAttr(), 0xa, allBanks, false);
414409
auto uIntMax = llvm::APInt::getMaxValue(32u);
415-
Value uIntMaxConst = b.create<LLVM::ConstantOp>(loc, uint32Type, uIntMax);
410+
Value uIntMaxConst = b.create<LLVM::ConstantOp>(loc, b.getI32Type(), uIntMax);
416411
Value dppResult = b.create<ROCDL::PermlaneX16Op>(loc, input.getType(), result, result, uIntMaxConst, uIntMaxConst, true, false);
417412
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
418413
result, dppResult);

0 commit comments

Comments
 (0)