
Commit 64f67f8

dhernandez0 and kuhar authored
[mlir][AMDGPU] Enable emulating vector buffer_atomic_fadd for bf16 on gfx942 (#129029)
- Change to make sure architectures < gfx950 emulate bf16 buffer_atomic_fadd
- Add tests for bf16 buffer_atomic_fadd and architectures: gfx12, gfx942 and gfx950

Co-authored-by: Jakub Kuderski <[email protected]>
1 parent b36bf47 commit 64f67f8
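
For context, the op affected here is the vector bf16 flavor of the buffer atomic add. A minimal standalone reproducer, adapted from the new test case added below (the RUN invocation mirrors the test's RUN lines):

// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx942 %s | FileCheck %s --check-prefixes=CHECK,GFX942
func.func @atomic_fadd_v2bf16(%val: vector<2xbf16>, %buffer: memref<?xbf16>, %idx: i32) {
  amdgpu.raw_buffer_atomic_fadd %val -> %buffer[%idx] : vector<2xbf16> -> memref<?xbf16>, i32
  func.return
}

Per the new CHECK lines, on gfx90a, gfx942, gfx10, and gfx11 this op is now rewritten into a raw_buffer_load plus a raw_buffer_atomic_cmpswap retry loop, while gfx12 and gfx950 keep the native buffer atomic fadd.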

File tree

2 files changed: +72, -14 lines


mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp

Lines changed: 9 additions & 0 deletions
@@ -189,6 +189,15 @@ void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
     } else {
       target.addIllegalOp<RawBufferAtomicFmaxOp>();
     }
+    // TODO(https://github.com/llvm/llvm-project/issues/129206): Refactor
+    // this to avoid hardcoding ISA version: gfx950 has bf16 atomics.
+    if (chipset < Chipset(9, 5, 0)) {
+      target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
+          [](RawBufferAtomicFaddOp op) -> bool {
+            Type elemType = getElementTypeOrSelf(op.getValue().getType());
+            return !isa<BFloat16Type>(elemType);
+          });
+    }
   }
   patterns.add<
       RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
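
With the new dynamic-legality rule above, a bf16 raw_buffer_atomic_fadd on chipsets below gfx950 is marked illegal and lowered through the existing RawBufferAtomicByCasPattern. A rough sketch of the shape of the resulting IR, written as comments and based on the CHECK lines in the updated test below (SSA names and exact type annotations are illustrative, not verbatim pass output):

// %ld  = amdgpu.raw_buffer_load %buffer[%idx]            // read the current value
// cf.br ^loop(%ld)
// ^loop(%old):                                           // CAS retry loop
//   %new = arith.addf %val, %old                         // do the fadd in generic IR
//   %res = amdgpu.raw_buffer_atomic_cmpswap %new, %old -> %buffer[%idx]
//   // compare old vs. returned value as i32 bits (vectors are bitcast and extracted first)
//   %ok  = arith.cmpi eq, <bits of %res>, <bits of %old>
//   cf.cond_br %ok, ^post, ^loop(%res)                   // retry until the swap succeeds
// ^post: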

mlir/test/Dialect/AMDGPU/amdgpu-emulate-atomics.mlir

Lines changed: 63 additions & 14 deletions
@@ -1,6 +1,9 @@
-// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx90a %s | FileCheck %s --check-prefixes=CHECK,GFX9
+// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx90a %s | FileCheck %s --check-prefixes=CHECK,GFX90A
 // RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx1030 %s | FileCheck %s --check-prefixes=CHECK,GFX10
 // RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx1100 %s | FileCheck %s --check-prefixes=CHECK,GFX11
+// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx1200 %s | FileCheck %s --check-prefixes=CHECK,GFX12
+// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx942 %s | FileCheck %s --check-prefixes=CHECK,GFX942
+// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx950 %s | FileCheck %s --check-prefixes=CHECK,GFX950
 
 // -----
 
@@ -10,16 +13,37 @@ func.func @atomic_fmax(%val: f32, %buffer: memref<?xf32>, %idx: i32) {
 // CHECK: gpu.printf "Begin\0A"
 // GFX10: amdgpu.raw_buffer_atomic_fmax {foo, indexOffset = 4 : i32} [[val]] -> [[buffer]][[[idx]]]
 // GFX11: amdgpu.raw_buffer_atomic_fmax {foo, indexOffset = 4 : i32} [[val]] -> [[buffer]][[[idx]]]
-// GFX9: [[ld:%.+]] = amdgpu.raw_buffer_load {foo, indexOffset = 4 : i32} [[buffer]][[[idx]]]
-// GFX9: cf.br [[loop:\^.+]]([[ld]] : f32)
-// GFX9: [[loop]]([[arg:%.+]]: f32):
-// GFX9: [[operated:%.+]] = arith.maximumf [[val]], [[arg]]
-// GFX9: [[atomicRes:%.+]] = amdgpu.raw_buffer_atomic_cmpswap {foo, indexOffset = 4 : i32} [[operated]], [[arg]] -> [[buffer]][[[idx]]]
-// GFX9: [[argCast:%.+]] = arith.bitcast [[arg]] : f32 to i32
-// GFX9: [[resCast:%.+]] = arith.bitcast [[atomicRes]] : f32 to i32
-// GFX9: [[test:%.+]] = arith.cmpi eq, [[resCast]], [[argCast]]
-// GFX9: cf.cond_br [[test]], [[post:\^.+]], [[loop]]([[atomicRes]] : f32)
-// GFX9: [[post]]:
+// GFX12: amdgpu.raw_buffer_atomic_fmax {foo, indexOffset = 4 : i32} [[val]] -> [[buffer]][[[idx]]]
+// GFX90A: [[ld:%.+]] = amdgpu.raw_buffer_load {foo, indexOffset = 4 : i32} [[buffer]][[[idx]]]
+// GFX90A: cf.br [[loop:\^.+]]([[ld]] : f32)
+// GFX90A: [[loop]]([[arg:%.+]]: f32):
+// GFX90A: [[operated:%.+]] = arith.maximumf [[val]], [[arg]]
+// GFX90A: [[atomicRes:%.+]] = amdgpu.raw_buffer_atomic_cmpswap {foo, indexOffset = 4 : i32} [[operated]], [[arg]] -> [[buffer]][[[idx]]]
+// GFX90A: [[argCast:%.+]] = arith.bitcast [[arg]] : f32 to i32
+// GFX90A: [[resCast:%.+]] = arith.bitcast [[atomicRes]] : f32 to i32
+// GFX90A: [[test:%.+]] = arith.cmpi eq, [[resCast]], [[argCast]]
+// GFX90A: cf.cond_br [[test]], [[post:\^.+]], [[loop]]([[atomicRes]] : f32)
+// GFX90A: [[post]]:
+// GFX942: [[ld:%.+]] = amdgpu.raw_buffer_load {foo, indexOffset = 4 : i32} [[buffer]][[[idx]]]
+// GFX942: cf.br [[loop:\^.+]]([[ld]] : f32)
+// GFX942: [[loop]]([[arg:%.+]]: f32):
+// GFX942: [[operated:%.+]] = arith.maximumf [[val]], [[arg]]
+// GFX942: [[atomicRes:%.+]] = amdgpu.raw_buffer_atomic_cmpswap {foo, indexOffset = 4 : i32} [[operated]], [[arg]] -> [[buffer]][[[idx]]]
+// GFX942: [[argCast:%.+]] = arith.bitcast [[arg]] : f32 to i32
+// GFX942: [[resCast:%.+]] = arith.bitcast [[atomicRes]] : f32 to i32
+// GFX942: [[test:%.+]] = arith.cmpi eq, [[resCast]], [[argCast]]
+// GFX942: cf.cond_br [[test]], [[post:\^.+]], [[loop]]([[atomicRes]] : f32)
+// GFX942: [[post]]:
+// GFX950: [[ld:%.+]] = amdgpu.raw_buffer_load {foo, indexOffset = 4 : i32} [[buffer]][[[idx]]]
+// GFX950: cf.br [[loop:\^.+]]([[ld]] : f32)
+// GFX950: [[loop]]([[arg:%.+]]: f32):
+// GFX950: [[operated:%.+]] = arith.maximumf [[val]], [[arg]]
+// GFX950: [[atomicRes:%.+]] = amdgpu.raw_buffer_atomic_cmpswap {foo, indexOffset = 4 : i32} [[operated]], [[arg]] -> [[buffer]][[[idx]]]
+// GFX950: [[argCast:%.+]] = arith.bitcast [[arg]] : f32 to i32
+// GFX950: [[resCast:%.+]] = arith.bitcast [[atomicRes]] : f32 to i32
+// GFX950: [[test:%.+]] = arith.cmpi eq, [[resCast]], [[argCast]]
+// GFX950: cf.cond_br [[test]], [[post:\^.+]], [[loop]]([[atomicRes]] : f32)
+// GFX950: [[post]]:
 // CHECK-NEXT: gpu.printf "End\0A"
   gpu.printf "Begin\n"
   amdgpu.raw_buffer_atomic_fmax {foo, indexOffset = 4 : i32} %val -> %buffer[%idx] : f32 -> memref<?xf32>, i32
@@ -33,9 +57,12 @@ func.func @atomic_fmax_f64(%val: f64, %buffer: memref<?xf64>, %idx: i32) {
 // CHECK: func @atomic_fmax_f64
 // CHECK-SAME: ([[val:%.+]]: f64, [[buffer:%.+]]: memref<?xf64>, [[idx:%.+]]: i32)
 // CHECK: gpu.printf "Begin\0A"
-// GFX9: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
+// GFX90A: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
 // GFX10: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
 // GFX11: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
+// GFX12: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
+// GFX942: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
+// GFX950: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
 // CHECK-NEXT: gpu.printf "End\0A"
   gpu.printf "Begin\n"
   amdgpu.raw_buffer_atomic_fmax %val -> %buffer[%idx] : f64 -> memref<?xf64>, i32
@@ -47,17 +74,20 @@ func.func @atomic_fmax_f64(%val: f64, %buffer: memref<?xf64>, %idx: i32) {
 
 func.func @atomic_fadd(%val: f32, %buffer: memref<?xf32>, %idx: i32) {
 // CHECK: func @atomic_fadd
-// GFX9: amdgpu.raw_buffer_atomic_fadd
+// GFX90A: amdgpu.raw_buffer_atomic_fadd
 // GFX10: amdgpu.raw_buffer_load
 // GFX10: amdgpu.raw_buffer_atomic_cmpswap
 // GFX11: amdgpu.raw_buffer_atomic_fadd
+// GFX12: amdgpu.raw_buffer_atomic_fadd
+// GFX942: amdgpu.raw_buffer_atomic_fadd
+// GFX950: amdgpu.raw_buffer_atomic_fadd
   amdgpu.raw_buffer_atomic_fadd %val -> %buffer[%idx] : f32 -> memref<?xf32>, i32
   func.return
 }
 
 // CHECK: func @atomic_fadd_v2f16
 func.func @atomic_fadd_v2f16(%val: vector<2xf16>, %buffer: memref<?xf16>, %idx: i32) {
-// GFX9: amdgpu.raw_buffer_atomic_fadd
+// GFX90A: amdgpu.raw_buffer_atomic_fadd
 // GFX10: amdgpu.raw_buffer_load
 // GFX10: amdgpu.raw_buffer_atomic_cmpswap
 // Note: the atomic operation itself will be done over i32, and then we use bitcasts
@@ -69,6 +99,25 @@ func.func @atomic_fadd_v2f16(%val: vector<2xf16>, %buffer: memref<?xf16>, %idx:
 // GFX11: %[[vecCastOld:.+]] = vector.bitcast %[[old]] : vector<2xf16> to vector<1xi32>
 // GFX11: %[[scalarOld:.+]] = vector.extract %[[vecCastOld]][0]
 // GFX11: arith.cmpi eq, %[[scalarOld]], %[[scalarExpected]]
+// GFX942: amdgpu.raw_buffer_atomic_fadd
+// GFX12: amdgpu.raw_buffer_atomic_fadd
+// GFX950: amdgpu.raw_buffer_atomic_fadd
   amdgpu.raw_buffer_atomic_fadd %val -> %buffer[%idx] : vector<2xf16> -> memref<?xf16>, i32
   func.return
 }
+
+// CHECK: func @atomic_fadd_v2bf16
+func.func @atomic_fadd_v2bf16(%val: vector<2xbf16>, %buffer: memref<?xbf16>, %idx: i32) {
+// GFX90A: amdgpu.raw_buffer_load
+// GFX90A: amdgpu.raw_buffer_atomic_cmpswap
+// GFX10: amdgpu.raw_buffer_load
+// GFX10: amdgpu.raw_buffer_atomic_cmpswap
+// GFX11: amdgpu.raw_buffer_load
+// GFX11: amdgpu.raw_buffer_atomic_cmpswap
+// GFX942: amdgpu.raw_buffer_load
+// GFX942: amdgpu.raw_buffer_atomic_cmpswap
+// GFX12: amdgpu.raw_buffer_atomic_fadd
+// GFX950: amdgpu.raw_buffer_atomic_fadd
+  amdgpu.raw_buffer_atomic_fadd %val -> %buffer[%idx] : vector<2xbf16> -> memref<?xbf16>, i32
+  func.return
+}
