Skip to content

Commit bafbe39

Browse files
authored
[NVPTX] Add support for atomic add for bf16 type (#89586)
atom.add.noftz.bf16 is supported since SM 9.0 and PTX 7.8
1 parent 81d3045 commit bafbe39

File tree

3 files changed

+164
-1
lines changed

3 files changed

+164
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6125,6 +6125,9 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
61256125
if (Ty->isHalfTy() && STI.getSmVersion() >= 70 &&
61266126
STI.getPTXVersion() >= 63)
61276127
return AtomicExpansionKind::None;
6128+
if (Ty->isBFloatTy() && STI.getSmVersion() >= 90 &&
6129+
STI.getPTXVersion() >= 78)
6130+
return AtomicExpansionKind::None;
61286131
if (Ty->isFloatTy())
61296132
return AtomicExpansionKind::None;
61306133
if (Ty->isDoubleTy() && STI.hasAtomAddF64())

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1545,7 +1545,7 @@ multiclass F_ATOMIC_2_imp<ValueType ptrT, NVPTXRegClass ptrclass,
15451545
def imm : NVPTXInst<(outs regclass:$dst), (ins ptrclass:$addr, IMMType:$b),
15461546
!strconcat("atom", SpaceStr, OpcStr, TypeStr, " \t$dst, [$addr], $b;", ""),
15471547
[(set (regT regclass:$dst), (IntOp (ptrT ptrclass:$addr), IMM:$b))]>,
1548-
Requires<!if(!eq(TypeStr, ".f16"), [Predicate<"false">], Pred)>;
1548+
Requires<!if(!or(!eq(TypeStr, ".f16"), !eq(TypeStr, ".bf16")), [Predicate<"false">], Pred)>;
15491549
}
15501550
multiclass F_ATOMIC_2<ValueType regT, NVPTXRegClass regclass, string SpaceStr, string TypeStr,
15511551
string OpcStr, PatFrag IntOp, Operand IMMType, SDNode IMM,
@@ -1662,6 +1662,13 @@ defm INT_PTX_ATOM_ADD_S_F16 : F_ATOMIC_2<f16, Int16Regs, ".shared", ".f16", ".ad
16621662
defm INT_PTX_ATOM_ADD_GEN_F16 : F_ATOMIC_2<f16, Int16Regs, "", ".f16", ".add.noftz",
16631663
atomic_load_add_gen, f16imm, fpimm, [hasSM<70>, hasPTX<63>]>;
16641664

1665+
defm INT_PTX_ATOM_ADD_G_BF16 : F_ATOMIC_2<bf16, Int16Regs, ".global", ".bf16", ".add.noftz",
1666+
atomic_load_add_g, bf16imm, fpimm, [hasSM<90>, hasPTX<78>]>;
1667+
defm INT_PTX_ATOM_ADD_S_BF16 : F_ATOMIC_2<bf16, Int16Regs, ".shared", ".bf16", ".add.noftz",
1668+
atomic_load_add_s, bf16imm, fpimm, [hasSM<90>, hasPTX<78>]>;
1669+
defm INT_PTX_ATOM_ADD_GEN_BF16 : F_ATOMIC_2<bf16, Int16Regs, "", ".bf16", ".add.noftz",
1670+
atomic_load_add_gen, bf16imm, fpimm, [hasSM<90>, hasPTX<78>]>;
1671+
16651672
defm INT_PTX_ATOM_ADD_G_F32 : F_ATOMIC_2<f32, Float32Regs, ".global", ".f32", ".add",
16661673
atomic_load_add_g, f32imm, fpimm>;
16671674
defm INT_PTX_ATOM_ADD_S_F32 : F_ATOMIC_2<f32, Float32Regs, ".shared", ".f32", ".add",
@@ -2174,6 +2181,8 @@ multiclass ATOM2_add_impl<string OpStr> {
21742181
defm _s32 : ATOM2S_impl<OpStr, "i", "s32", i32, Int32Regs, i32imm, imm, i32, []>;
21752182
defm _u32 : ATOM2S_impl<OpStr, "i", "u32", i32, Int32Regs, i32imm, imm, i32, []>;
21762183
defm _u64 : ATOM2S_impl<OpStr, "i", "u64", i64, Int64Regs, i64imm, imm, i64, []>;
2184+
defm _bf16 : ATOM2S_impl<OpStr, "f", "bf16", bf16, Int16Regs, bf16imm, fpimm, bf16,
2185+
[hasSM<90>, hasPTX<78>]>;
21772186
defm _f16 : ATOM2S_impl<OpStr, "f", "f16", f16, Int16Regs, f16imm, fpimm, f16,
21782187
[hasSM<70>, hasPTX<63>]>;
21792188
defm _f32 : ATOM2S_impl<OpStr, "f", "f32", f32, Float32Regs, f32imm, fpimm, f32,
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
2+
; RUN: llc < %s -march=nvptx -mcpu=sm_90 -mattr=+ptx78 | FileCheck %s --check-prefixes=CHECK
3+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 -mattr=+ptx78 | FileCheck %s --check-prefixes=CHECK64
4+
; RUN: llc < %s -march=nvptx -mcpu=sm_86 -mattr=+ptx71 | FileCheck %s --check-prefixes=CHECKPTX71
5+
; RUN: %if ptxas && !ptxas-12.0 %{ llc < %s -march=nvptx -mcpu=sm_90 -mattr=+ptx78 | %ptxas-verify -arch=sm_90 %}
6+
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 -mattr=+ptx78 | %ptxas-verify -arch=sm_90 %}
7+
; RUN: %if ptxas && !ptxas-12.0 %{ llc < %s -march=nvptx -mcpu=sm_86 -mattr=+ptx71 | %ptxas-verify -arch=sm_86 %}
8+
9+
target triple = "nvptx64-nvidia-cuda"
10+
11+
define void @test(ptr %dp0, ptr addrspace(1) %dp1, ptr addrspace(3) %dp3, bfloat %val) {
12+
; CHECK-LABEL: test(
13+
; CHECK: {
14+
; CHECK-NEXT: .reg .b16 %rs<7>;
15+
; CHECK-NEXT: .reg .b32 %r<4>;
16+
; CHECK-EMPTY:
17+
; CHECK-NEXT: // %bb.0:
18+
; CHECK-NEXT: ld.param.u32 %r1, [test_param_0];
19+
; CHECK-NEXT: ld.param.b16 %rs1, [test_param_3];
20+
; CHECK-NEXT: atom.add.noftz.bf16 %rs2, [%r1], %rs1;
21+
; CHECK-NEXT: ld.param.u32 %r2, [test_param_1];
22+
; CHECK-NEXT: mov.b16 %rs3, 0x3F80;
23+
; CHECK-NEXT: atom.add.noftz.bf16 %rs4, [%r1], %rs3;
24+
; CHECK-NEXT: ld.param.u32 %r3, [test_param_2];
25+
; CHECK-NEXT: atom.global.add.noftz.bf16 %rs5, [%r2], %rs1;
26+
; CHECK-NEXT: atom.shared.add.noftz.bf16 %rs6, [%r3], %rs1;
27+
; CHECK-NEXT: ret;
28+
;
29+
; CHECK64-LABEL: test(
30+
; CHECK64: {
31+
; CHECK64-NEXT: .reg .b16 %rs<7>;
32+
; CHECK64-NEXT: .reg .b64 %rd<4>;
33+
; CHECK64-EMPTY:
34+
; CHECK64-NEXT: // %bb.0:
35+
; CHECK64-NEXT: ld.param.u64 %rd1, [test_param_0];
36+
; CHECK64-NEXT: ld.param.b16 %rs1, [test_param_3];
37+
; CHECK64-NEXT: atom.add.noftz.bf16 %rs2, [%rd1], %rs1;
38+
; CHECK64-NEXT: ld.param.u64 %rd2, [test_param_1];
39+
; CHECK64-NEXT: mov.b16 %rs3, 0x3F80;
40+
; CHECK64-NEXT: atom.add.noftz.bf16 %rs4, [%rd1], %rs3;
41+
; CHECK64-NEXT: ld.param.u64 %rd3, [test_param_2];
42+
; CHECK64-NEXT: atom.global.add.noftz.bf16 %rs5, [%rd2], %rs1;
43+
; CHECK64-NEXT: atom.shared.add.noftz.bf16 %rs6, [%rd3], %rs1;
44+
; CHECK64-NEXT: ret;
45+
;
46+
; CHECKPTX71-LABEL: test(
47+
; CHECKPTX71: {
48+
; CHECKPTX71-NEXT: .reg .pred %p<5>;
49+
; CHECKPTX71-NEXT: .reg .b16 %rs<18>;
50+
; CHECKPTX71-NEXT: .reg .b32 %r<58>;
51+
; CHECKPTX71-NEXT: .reg .f32 %f<12>;
52+
; CHECKPTX71-EMPTY:
53+
; CHECKPTX71-NEXT: // %bb.0:
54+
; CHECKPTX71-NEXT: ld.param.b16 %rs1, [test_param_3];
55+
; CHECKPTX71-NEXT: ld.param.u32 %r23, [test_param_2];
56+
; CHECKPTX71-NEXT: ld.param.u32 %r22, [test_param_1];
57+
; CHECKPTX71-NEXT: ld.param.u32 %r24, [test_param_0];
58+
; CHECKPTX71-NEXT: and.b32 %r1, %r24, -4;
59+
; CHECKPTX71-NEXT: and.b32 %r25, %r24, 3;
60+
; CHECKPTX71-NEXT: shl.b32 %r2, %r25, 3;
61+
; CHECKPTX71-NEXT: mov.b32 %r26, 65535;
62+
; CHECKPTX71-NEXT: shl.b32 %r27, %r26, %r2;
63+
; CHECKPTX71-NEXT: not.b32 %r3, %r27;
64+
; CHECKPTX71-NEXT: ld.u32 %r54, [%r1];
65+
; CHECKPTX71-NEXT: cvt.f32.bf16 %f2, %rs1;
66+
; CHECKPTX71-NEXT: $L__BB0_1: // %atomicrmw.start
67+
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
68+
; CHECKPTX71-NEXT: shr.u32 %r28, %r54, %r2;
69+
; CHECKPTX71-NEXT: cvt.u16.u32 %rs2, %r28;
70+
; CHECKPTX71-NEXT: cvt.f32.bf16 %f1, %rs2;
71+
; CHECKPTX71-NEXT: add.rn.f32 %f3, %f1, %f2;
72+
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs4, %f3;
73+
; CHECKPTX71-NEXT: cvt.u32.u16 %r29, %rs4;
74+
; CHECKPTX71-NEXT: shl.b32 %r30, %r29, %r2;
75+
; CHECKPTX71-NEXT: and.b32 %r31, %r54, %r3;
76+
; CHECKPTX71-NEXT: or.b32 %r32, %r31, %r30;
77+
; CHECKPTX71-NEXT: atom.cas.b32 %r6, [%r1], %r54, %r32;
78+
; CHECKPTX71-NEXT: setp.ne.s32 %p1, %r6, %r54;
79+
; CHECKPTX71-NEXT: mov.u32 %r54, %r6;
80+
; CHECKPTX71-NEXT: @%p1 bra $L__BB0_1;
81+
; CHECKPTX71-NEXT: // %bb.2: // %atomicrmw.end
82+
; CHECKPTX71-NEXT: ld.u32 %r55, [%r1];
83+
; CHECKPTX71-NEXT: $L__BB0_3: // %atomicrmw.start9
84+
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
85+
; CHECKPTX71-NEXT: shr.u32 %r33, %r55, %r2;
86+
; CHECKPTX71-NEXT: cvt.u16.u32 %rs6, %r33;
87+
; CHECKPTX71-NEXT: cvt.f32.bf16 %f4, %rs6;
88+
; CHECKPTX71-NEXT: add.rn.f32 %f5, %f4, 0f3F800000;
89+
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs8, %f5;
90+
; CHECKPTX71-NEXT: cvt.u32.u16 %r34, %rs8;
91+
; CHECKPTX71-NEXT: shl.b32 %r35, %r34, %r2;
92+
; CHECKPTX71-NEXT: and.b32 %r36, %r55, %r3;
93+
; CHECKPTX71-NEXT: or.b32 %r37, %r36, %r35;
94+
; CHECKPTX71-NEXT: atom.cas.b32 %r9, [%r1], %r55, %r37;
95+
; CHECKPTX71-NEXT: setp.ne.s32 %p2, %r9, %r55;
96+
; CHECKPTX71-NEXT: mov.u32 %r55, %r9;
97+
; CHECKPTX71-NEXT: @%p2 bra $L__BB0_3;
98+
; CHECKPTX71-NEXT: // %bb.4: // %atomicrmw.end8
99+
; CHECKPTX71-NEXT: and.b32 %r10, %r22, -4;
100+
; CHECKPTX71-NEXT: shl.b32 %r38, %r22, 3;
101+
; CHECKPTX71-NEXT: and.b32 %r11, %r38, 24;
102+
; CHECKPTX71-NEXT: shl.b32 %r40, %r26, %r11;
103+
; CHECKPTX71-NEXT: not.b32 %r12, %r40;
104+
; CHECKPTX71-NEXT: ld.global.u32 %r56, [%r10];
105+
; CHECKPTX71-NEXT: $L__BB0_5: // %atomicrmw.start27
106+
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
107+
; CHECKPTX71-NEXT: shr.u32 %r41, %r56, %r11;
108+
; CHECKPTX71-NEXT: cvt.u16.u32 %rs10, %r41;
109+
; CHECKPTX71-NEXT: cvt.f32.bf16 %f6, %rs10;
110+
; CHECKPTX71-NEXT: add.rn.f32 %f8, %f6, %f2;
111+
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs12, %f8;
112+
; CHECKPTX71-NEXT: cvt.u32.u16 %r42, %rs12;
113+
; CHECKPTX71-NEXT: shl.b32 %r43, %r42, %r11;
114+
; CHECKPTX71-NEXT: and.b32 %r44, %r56, %r12;
115+
; CHECKPTX71-NEXT: or.b32 %r45, %r44, %r43;
116+
; CHECKPTX71-NEXT: atom.global.cas.b32 %r15, [%r10], %r56, %r45;
117+
; CHECKPTX71-NEXT: setp.ne.s32 %p3, %r15, %r56;
118+
; CHECKPTX71-NEXT: mov.u32 %r56, %r15;
119+
; CHECKPTX71-NEXT: @%p3 bra $L__BB0_5;
120+
; CHECKPTX71-NEXT: // %bb.6: // %atomicrmw.end26
121+
; CHECKPTX71-NEXT: and.b32 %r16, %r23, -4;
122+
; CHECKPTX71-NEXT: shl.b32 %r46, %r23, 3;
123+
; CHECKPTX71-NEXT: and.b32 %r17, %r46, 24;
124+
; CHECKPTX71-NEXT: shl.b32 %r48, %r26, %r17;
125+
; CHECKPTX71-NEXT: not.b32 %r18, %r48;
126+
; CHECKPTX71-NEXT: ld.shared.u32 %r57, [%r16];
127+
; CHECKPTX71-NEXT: $L__BB0_7: // %atomicrmw.start45
128+
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
129+
; CHECKPTX71-NEXT: shr.u32 %r49, %r57, %r17;
130+
; CHECKPTX71-NEXT: cvt.u16.u32 %rs14, %r49;
131+
; CHECKPTX71-NEXT: cvt.f32.bf16 %f9, %rs14;
132+
; CHECKPTX71-NEXT: add.rn.f32 %f11, %f9, %f2;
133+
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs16, %f11;
134+
; CHECKPTX71-NEXT: cvt.u32.u16 %r50, %rs16;
135+
; CHECKPTX71-NEXT: shl.b32 %r51, %r50, %r17;
136+
; CHECKPTX71-NEXT: and.b32 %r52, %r57, %r18;
137+
; CHECKPTX71-NEXT: or.b32 %r53, %r52, %r51;
138+
; CHECKPTX71-NEXT: atom.shared.cas.b32 %r21, [%r16], %r57, %r53;
139+
; CHECKPTX71-NEXT: setp.ne.s32 %p4, %r21, %r57;
140+
; CHECKPTX71-NEXT: mov.u32 %r57, %r21;
141+
; CHECKPTX71-NEXT: @%p4 bra $L__BB0_7;
142+
; CHECKPTX71-NEXT: // %bb.8: // %atomicrmw.end44
143+
; CHECKPTX71-NEXT: ret;
144+
%r1 = atomicrmw fadd ptr %dp0, bfloat %val seq_cst
145+
%r2 = atomicrmw fadd ptr %dp0, bfloat 1.0 seq_cst
146+
%r3 = atomicrmw fadd ptr addrspace(1) %dp1, bfloat %val seq_cst
147+
%r4 = atomicrmw fadd ptr addrspace(3) %dp3, bfloat %val seq_cst
148+
ret void
149+
}
150+
151+
attributes #1 = { argmemonly nounwind }

0 commit comments

Comments
 (0)