Skip to content

Commit d069031

Browse files
[mlir][AMDGPU] "Added support for 64-bit operands in
ROCDL::DPPUpdateOp operation."
1 parent ce8792a commit d069031

File tree

4 files changed

+60
-61
lines changed

4 files changed

+60
-61
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
609609
builder.getInt32(op.getRowMask()),
610610
builder.getInt32(op.getBankMask()),
611611
builder.getInt1(op.getBoundCtrl())
612-
};
612+
};
613613
$res = createIntrinsicCall(builder,
614614
llvm::Intrinsic::amdgcn_update_dpp, args, {vdataType});
615615
}];

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -845,25 +845,34 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
845845
Value old = adaptor.getOld();
846846
Type srcType = src.getType();
847847
Type oldType = old.getType();
848-
auto llvmI32Type = typeConverter->convertType(rewriter.getI32Type());
848+
Type llvmType = nullptr;
849+
if (srcType.getIntOrFloatBitWidth() < 32) {
850+
llvmType = rewriter.getI32Type();
851+
} else if (isa<FloatType>(srcType)) {
852+
llvmType = (srcType.getIntOrFloatBitWidth() == 32)
853+
? rewriter.getF32Type()
854+
: rewriter.getF64Type();
855+
} else if (isa<IntegerType>(srcType)) {
856+
llvmType = (srcType.getIntOrFloatBitWidth() == 32)
857+
? rewriter.getI32Type()
858+
: rewriter.getI64Type();
859+
}
849860
auto llvmSrcIntType = typeConverter->convertType(
850861
rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
851862

852-
// If the source type is less or equal to i32 or f32, use bitcast to convert
853-
// it to i32.
863+
// If the source type is less of 32, use bitcast to convert it to i32.
854864
auto convertOperand = [&](Value operand, Type operandType) {
855-
if (llvm::isa<FloatType>(operandType)) {
856-
operand =
857-
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
858-
}
859-
860-
if (operandType.getIntOrFloatBitWidth() < 32) {
865+
if (operandType.getIntOrFloatBitWidth() <= 16) {
866+
if (llvm::isa<FloatType>(operandType)) {
867+
operand =
868+
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
869+
}
861870
auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
862871
32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
863872
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
864873
operand = rewriter.create<LLVM::InsertElementOp>(
865874
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
866-
operand = rewriter.create<LLVM::BitcastOp>(loc, llvmI32Type, operand);
875+
operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
867876
}
868877
return operand;
869878
};
@@ -951,15 +960,14 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
951960

952961
// create a ROCDL_DPPMovOp instruction with the appropriate attributes
953962
auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
954-
loc, llvmI32Type, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
963+
loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
955964

956965
Value result = dppMovOp.getRes();
957966
if (srcType.getIntOrFloatBitWidth() < 32) {
958967
result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
959-
}
960-
961-
if (!llvm::isa<IntegerType>(srcType)) {
962-
result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
968+
if (!llvm::isa<IntegerType>(srcType)) {
969+
result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
970+
}
963971
}
964972

965973
// We are replacing the AMDGPU_DPPOp instruction with the new

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,8 @@ LogicalResult MFMAOp::verify() {
331331
//===----------------------------------------------------------------------===//
332332
LogicalResult DPPOp::verify() {
333333
Type srcType = getSrc().getType();
334-
if (srcType.getIntOrFloatBitWidth() > 32) {
335-
return emitOpError("integer and floating point types larger than 32 bits "
334+
if (srcType.getIntOrFloatBitWidth() > 64) {
335+
return emitOpError("integer and floating point types larger than 64 bits "
336336
"are not supported");
337337
}
338338

mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir

Lines changed: 34 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,6 @@ func.func @quad_dpp(%arg0: i32, %arg1: i32) -> i32 {
1818
return %0 : i32
1919
}
2020

21-
func.func @quad_perm_dpp(%arg0: i32, %arg1: i32) -> i32 {
22-
// CHECK-LABEL: func @quad_perm_dpp
23-
// CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 15, false : i32
24-
// CHECK: return %0 : i32
25-
%0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) : i32
26-
return %0 : i32
27-
}
28-
2921
func.func @wave_shr_dpp(%arg0: i32, %arg1: i32) -> i32 {
3022
// CHECK-LABEL: func @wave_shr_dpp
3123
// CHECK: rocdl.update.dpp %arg0, %arg1 with 312, 10, 1, true : i32
@@ -34,25 +26,6 @@ func.func @wave_shr_dpp(%arg0: i32, %arg1: i32) -> i32 {
3426
return %0 : i32
3527
}
3628

37-
func.func @row_bcast_dpp(%arg0: i32, %arg1: i32) -> i32 {
38-
// CHECK-LABEL: func @row_bcast_dpp
39-
// CHECK: rocdl.update.dpp %arg0, %arg1 with 323, 4, 1, false : i32
40-
// CHECK: return %0 : i32
41-
%0 = amdgpu.dpp %arg0 %arg1 row_bcast_31 { row_mask = 0x4 : i32, bank_mask = 0x1 : i32} : i32
42-
return %0 : i32
43-
}
44-
45-
func.func @row_bcast_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
46-
// CHECK-LABEL: func @row_bcast_dpp_f32
47-
// CHECK: llvm.bitcast %arg1 : f32 to i32
48-
// CHECK: llvm.bitcast %arg0 : f32 to i32
49-
// CHECK: rocdl.update.dpp %1, %0 with 322, 15, 15, true : i32
50-
// CHECK: llvm.bitcast %2 : i32 to f32
51-
// CHECK: return %3 : f32
52-
%0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f32
53-
return %0 : f32
54-
}
55-
5629
func.func @row_half_mirror_update_dpp(%arg0: i32, %arg1: i32) -> i32 {
5730
// CHECK-LABEL: func @row_half_mirror_update_dpp
5831
// CHECK: rocdl.update.dpp %arg0, %arg1 with 321, 15, 1, false : i32
@@ -69,17 +42,46 @@ func.func @wave_rol_update_dpp(%arg0: i32, %arg1: i32) -> i32 {
6942
return %0 : i32
7043
}
7144

45+
func.func @row_bcast_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
46+
// CHECK-LABEL: func @row_bcast_dpp_f32
47+
// CHECK: rocdl.update.dpp %arg0, %arg1 with 322, 15, 15, true : f32
48+
// CHECK: return %0 : f32
49+
%0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f32
50+
return %0 : f32
51+
}
52+
7253
func.func @test_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
7354
// CHECK-LABEL: func @test_dpp_f32
74-
// CHECK: llvm.bitcast %arg1 : f32 to i32
75-
// CHECK: llvm.bitcast %arg0 : f32 to i32
76-
// CHECK: rocdl.update.dpp %1, %0 with 320, 1, 4, true : i32
77-
// CHECK: llvm.bitcast %2 : i32 to f32
78-
// CHECK: return %3 : f32
55+
// CHECK: rocdl.update.dpp %arg0, %arg1 with 320, 1, 4, true : f32
56+
// CHECK: return %0 : f32
7957
%0 = amdgpu.dpp %arg0 %arg1 row_mirror { row_mask = 0x1 : i32, bank_mask = 0x4 : i32, bound_ctrl = true } : f32
8058
return %0 : f32
8159
}
8260

61+
func.func @quad_perm_update_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
62+
// CHECK-LABEL: func @quad_perm_update_dpp_f32
63+
// CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 1, false : f32
64+
// CHECK: return %0 : f32
65+
%0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) { bank_mask = 0x1 : i32 } : f32
66+
return %0 : f32
67+
}
68+
69+
func.func @quad_perm_dpp(%arg0: i64, %arg1: i64) -> i64 {
70+
// CHECK-LABEL: func @quad_perm_dpp
71+
// CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 15, false : i64
72+
// CHECK: return %0 : i64
73+
%0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) : i64
74+
return %0 : i64
75+
}
76+
77+
func.func @row_bcast_dpp(%arg0: f64, %arg1: f64) -> f64 {
78+
// CHECK-LABEL: func @row_bcast_dpp
79+
// CHECK: rocdl.update.dpp %arg0, %arg1 with 323, 4, 1, false : f64
80+
// CHECK: return %0 : f64
81+
%0 = amdgpu.dpp %arg0 %arg1 row_bcast_31 { row_mask = 0x4 : i32, bank_mask = 0x1 : i32} : f64
82+
return %0 : f64
83+
}
84+
8385
func.func @test_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
8486
// CHECK-LABEL: func @test_dpp_f16
8587
// CHECK: llvm.bitcast %arg1 : f16 to i16
@@ -117,17 +119,6 @@ func.func @row_shl_dpp_i16(%arg0: i16, %arg1: i16) -> i16 {
117119
return %0 : i16
118120
}
119121

120-
func.func @quad_perm_update_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
121-
// CHECK-LABEL: func @quad_perm_update_dpp_f32
122-
// CHECK: llvm.bitcast %arg1 : f32 to i32
123-
// CHECK: llvm.bitcast %arg0 : f32 to i32
124-
// CHECK: rocdl.update.dpp %1, %0 with 88, 15, 1, false : i32
125-
// CHECK: llvm.bitcast %2 : i32 to f32
126-
// CHECK: return %3 : f32
127-
%0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) { bank_mask = 0x1 : i32 } : f32
128-
return %0 : f32
129-
}
130-
131122
func.func @row_bcast_update_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
132123
// CHECK-LABEL: func @row_bcast_update_dpp_f16
133124
// CHECK: llvm.bitcast %arg1 : f16 to i16

0 commit comments

Comments
 (0)