Skip to content

Commit d9c848d

Browse files
[mlir][AMDGPU] "Added support for 64-bit operands in
ROCDL::DPPUpdateOp operation."
1 parent 5275a29 commit d9c848d

File tree

4 files changed

+27
-22
lines changed

4 files changed

+27
-22
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
609609
builder.getInt32(op.getRowMask()),
610610
builder.getInt32(op.getBankMask()),
611611
builder.getInt1(op.getBoundCtrl())
612-
};
612+
};
613613
$res = createIntrinsicCall(builder,
614614
llvm::Intrinsic::amdgcn_update_dpp, args, {vdataType});
615615
}];

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
845845
Value old = adaptor.getOld();
846846
Type srcType = src.getType();
847847
Type oldType = old.getType();
848-
auto llvmI32Type = typeConverter->convertType(rewriter.getI32Type());
848+
auto llvmType =
849+
(srcType.getIntOrFloatBitWidth() <= 32 ? rewriter.getI32Type()
850+
: rewriter.getI64Type());
849851
auto llvmSrcIntType = typeConverter->convertType(
850852
rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
851853

@@ -863,7 +865,7 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
863865
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
864866
operand = rewriter.create<LLVM::InsertElementOp>(
865867
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
866-
operand = rewriter.create<LLVM::BitcastOp>(loc, llvmI32Type, operand);
868+
operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
867869
}
868870
return operand;
869871
};
@@ -951,7 +953,7 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
951953

952954
// create a ROCDL_DPPMovOp instruction with the appropriate attributes
953955
auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
954-
loc, llvmI32Type, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
956+
loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
955957

956958
Value result = dppMovOp.getRes();
957959
if (srcType.getIntOrFloatBitWidth() < 32) {

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,8 @@ LogicalResult MFMAOp::verify() {
331331
//===----------------------------------------------------------------------===//
332332
LogicalResult DPPOp::verify() {
333333
Type srcType = getSrc().getType();
334-
if (srcType.getIntOrFloatBitWidth() > 32) {
335-
return emitOpError("integer and floating point types larger than 32 bits "
334+
if (srcType.getIntOrFloatBitWidth() > 64) {
335+
return emitOpError("integer and floating point types larger than 64 bits "
336336
"are not supported");
337337
}
338338

mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,6 @@ func.func @quad_dpp(%arg0: i32, %arg1: i32) -> i32 {
1818
return %0 : i32
1919
}
2020

21-
func.func @quad_perm_dpp(%arg0: i32, %arg1: i32) -> i32 {
22-
// CHECK-LABEL: func @quad_perm_dpp
23-
// CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 15, false : i32
24-
// CHECK: return %0 : i32
25-
%0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) : i32
26-
return %0 : i32
27-
}
28-
2921
func.func @wave_shr_dpp(%arg0: i32, %arg1: i32) -> i32 {
3022
// CHECK-LABEL: func @wave_shr_dpp
3123
// CHECK: rocdl.update.dpp %arg0, %arg1 with 312, 10, 1, true : i32
@@ -34,14 +26,6 @@ func.func @wave_shr_dpp(%arg0: i32, %arg1: i32) -> i32 {
3426
return %0 : i32
3527
}
3628

37-
func.func @row_bcast_dpp(%arg0: i32, %arg1: i32) -> i32 {
38-
// CHECK-LABEL: func @row_bcast_dpp
39-
// CHECK: rocdl.update.dpp %arg0, %arg1 with 323, 4, 1, false : i32
40-
// CHECK: return %0 : i32
41-
%0 = amdgpu.dpp %arg0 %arg1 row_bcast_31 { row_mask = 0x4 : i32, bank_mask = 0x1 : i32} : i32
42-
return %0 : i32
43-
}
44-
4529
func.func @row_bcast_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
4630
// CHECK-LABEL: func @row_bcast_dpp_f32
4731
// CHECK: llvm.bitcast %arg1 : f32 to i32
@@ -146,3 +130,22 @@ func.func @row_bcast_update_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
146130
%0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f16
147131
return %0 : f16
148132
}
133+
134+
func.func @quad_perm_dpp(%arg0: i64, %arg1: i64) -> i64 {
135+
// CHECK-LABEL: func @quad_perm_dpp
136+
// CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 15, false : i64
137+
// CHECK: return %0 : i64
138+
%0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) : i64
139+
return %0 : i64
140+
}
141+
142+
func.func @row_bcast_dpp(%arg0: f64, %arg1: f64) -> f64 {
143+
// CHECK-LABEL: func @row_bcast_dpp
144+
// CHECK: llvm.bitcast %arg1 : f64 to i64
145+
// CHECK: llvm.bitcast %arg0 : f64 to i64
146+
// CHECK: rocdl.update.dpp %1, %0 with 323, 4, 1, false : i64
147+
// CHECK: llvm.bitcast %2 : i64 to f64
148+
// CHECK: return %3 : f64
149+
%0 = amdgpu.dpp %arg0 %arg1 row_bcast_31 { row_mask = 0x4 : i32, bank_mask = 0x1 : i32} : f64
150+
return %0 : f64
151+
}

0 commit comments

Comments
 (0)