Skip to content

Commit aa60a3e

Browse files
krzysz00giuseros
andauthored
[mlir][AMDGPU] Support vector<2xf16> inputs to buffer atomic fadd (#108286)
Extend the lowering of atomic.fadd to support the v2f16 variant avaliable on some AMDGPU chips. Re-lands #108238 (and addresses review comments from there) Co-authored-by: Giuseppe Rossini <[email protected]>
1 parent 7910812 commit aa60a3e

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def AMDGPU_RawBufferAtomicCmpswapOp :
254254
def AMDGPU_RawBufferAtomicFaddOp :
255255
AMDGPU_Op<"raw_buffer_atomic_fadd", [AllElementTypesMatch<["value", "memref"]>,
256256
AttrSizedOperandSegments]>,
257-
Arguments<(ins F32:$value,
257+
Arguments<(ins AnyTypeOf<[F32, VectorOfLengthAndType<[2], [F16]>]>:$value,
258258
Arg<AnyMemRef, "buffer to operate on", [MemRead, MemWrite]>:$memref,
259259
Variadic<I32>:$indices,
260260
DefaultValuedAttr<BoolAttr, "true">:$boundsCheck,
@@ -405,7 +405,7 @@ def AMDGPU_RawBufferAtomicUminOp :
405405

406406
def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
407407
"The possible permutations for a DPP operation",
408-
[
408+
[
409409
I32EnumAttrCase<"quad_perm", 0>,
410410
I32EnumAttrCase<"row_shl", 1>,
411411
I32EnumAttrCase<"row_shr", 2>,
@@ -419,7 +419,7 @@ def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
419419
I32EnumAttrCase<"row_bcast_15", 10>,
420420
I32EnumAttrCase<"row_bcast_31", 11>
421421
]> {
422-
let genSpecializedAttr = 0;
422+
let genSpecializedAttr = 0;
423423
let cppNamespace = "::mlir::amdgpu";
424424
}
425425

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,18 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
115115
rewriter.getIntegerType(floatType.getWidth()));
116116
}
117117
if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
118+
uint32_t vecLen = dataVector.getNumElements();
118119
uint32_t elemBits = dataVector.getElementTypeBitWidth();
119-
uint32_t totalBits = elemBits * dataVector.getNumElements();
120+
uint32_t totalBits = elemBits * vecLen;
121+
bool usePackedFp16 =
122+
isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
120123
if (totalBits > maxVectorOpWidth)
121124
return gpuOp.emitOpError(
122125
"Total width of loads or stores must be no more than " +
123126
Twine(maxVectorOpWidth) + " bits, but we call for " +
124127
Twine(totalBits) +
125128
" bits. This should've been caught in validation");
126-
if (elemBits < 32) {
129+
if (!usePackedFp16 && elemBits < 32) {
127130
if (totalBits > 32) {
128131
if (totalBits % 32 != 0)
129132
return gpuOp.emitOpError("Load or store of more than 32-bits that "

mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,17 @@ func.func @gpu_gcn_raw_buffer_atomic_fadd_f32(%value: f32, %buf: memref<64xf32>,
151151
func.return
152152
}
153153

154+
// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fadd_v2f16
155+
func.func @gpu_gcn_raw_buffer_atomic_fadd_v2f16(%value: vector<2xf16>, %buf: memref<64xf16>, %idx: i32) {
156+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(128 : i32)
157+
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
158+
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
159+
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
160+
// CHECK: rocdl.raw.ptr.buffer.atomic.fadd %{{.*}}, %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : vector<2xf16>
161+
amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %value -> %buf[%idx] : vector<2xf16> -> memref<64xf16>, i32
162+
func.return
163+
}
164+
154165
// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fmax_f32
155166
func.func @gpu_gcn_raw_buffer_atomic_fmax_f32(%value: f32, %buf: memref<64xf32>, %idx: i32) {
156167
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)

0 commit comments

Comments
 (0)