Skip to content

Commit 40fffba

Browse files
authored
[X86][AVX10.2] Fix wrong predicates for BF16 feature (#113800)
Since AVX10.2, we need to enable 128/256-bit vector by default and check for 512 feature for 512-bit vector.
1 parent 1fe8e78 commit 40fffba

File tree

2 files changed

+41
-22
lines changed

2 files changed

+41
-22
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2406,7 +2406,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
24062406
addLegalFPImmediate(APFloat::getZero(APFloat::BFloat()));
24072407
}
24082408

2409-
if (!Subtarget.useSoftFloat() && Subtarget.hasBF16()) {
2409+
if (!Subtarget.useSoftFloat() && Subtarget.hasBF16() &&
2410+
Subtarget.useAVX512Regs()) {
24102411
addRegisterClass(MVT::v32bf16, &X86::VR512RegClass);
24112412
setF16Action(MVT::v32bf16, Expand);
24122413
for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV})
@@ -2419,27 +2420,23 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
24192420
}
24202421

24212422
if (!Subtarget.useSoftFloat() && Subtarget.hasAVX10_2()) {
2422-
addRegisterClass(MVT::v8bf16, &X86::VR128XRegClass);
2423-
addRegisterClass(MVT::v16bf16, &X86::VR256XRegClass);
2424-
addRegisterClass(MVT::v32bf16, &X86::VR512RegClass);
2425-
2426-
setOperationAction(ISD::FADD, MVT::v32bf16, Legal);
2427-
setOperationAction(ISD::FSUB, MVT::v32bf16, Legal);
2428-
setOperationAction(ISD::FMUL, MVT::v32bf16, Legal);
2429-
setOperationAction(ISD::FDIV, MVT::v32bf16, Legal);
2430-
setOperationAction(ISD::FSQRT, MVT::v32bf16, Legal);
2431-
setOperationAction(ISD::FMA, MVT::v32bf16, Legal);
2432-
setOperationAction(ISD::SETCC, MVT::v32bf16, Custom);
2433-
if (Subtarget.hasVLX()) {
2434-
for (auto VT : {MVT::v8bf16, MVT::v16bf16}) {
2435-
setOperationAction(ISD::FADD, VT, Legal);
2436-
setOperationAction(ISD::FSUB, VT, Legal);
2437-
setOperationAction(ISD::FMUL, VT, Legal);
2438-
setOperationAction(ISD::FDIV, VT, Legal);
2439-
setOperationAction(ISD::FSQRT, VT, Legal);
2440-
setOperationAction(ISD::FMA, VT, Legal);
2441-
setOperationAction(ISD::SETCC, VT, Custom);
2442-
}
2423+
for (auto VT : {MVT::v8bf16, MVT::v16bf16}) {
2424+
setOperationAction(ISD::FADD, VT, Legal);
2425+
setOperationAction(ISD::FSUB, VT, Legal);
2426+
setOperationAction(ISD::FMUL, VT, Legal);
2427+
setOperationAction(ISD::FDIV, VT, Legal);
2428+
setOperationAction(ISD::FSQRT, VT, Legal);
2429+
setOperationAction(ISD::FMA, VT, Legal);
2430+
setOperationAction(ISD::SETCC, VT, Custom);
2431+
}
2432+
if (Subtarget.hasAVX10_2_512()) {
2433+
setOperationAction(ISD::FADD, MVT::v32bf16, Legal);
2434+
setOperationAction(ISD::FSUB, MVT::v32bf16, Legal);
2435+
setOperationAction(ISD::FMUL, MVT::v32bf16, Legal);
2436+
setOperationAction(ISD::FDIV, MVT::v32bf16, Legal);
2437+
setOperationAction(ISD::FSQRT, MVT::v32bf16, Legal);
2438+
setOperationAction(ISD::FMA, MVT::v32bf16, Legal);
2439+
setOperationAction(ISD::SETCC, MVT::v32bf16, Custom);
24432440
}
24442441
}
24452442

llvm/test/CodeGen/X86/avx10_2bf16-arith.ll

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,3 +1166,25 @@ entry:
11661166
%2 = select <8 x i1> %1, <8 x bfloat> %0, <8 x bfloat> zeroinitializer
11671167
ret <8 x bfloat> %2
11681168
}
1169+
1170+
define <32 x bfloat> @addv(<32 x bfloat> %a, <32 x bfloat> %b) nounwind {
1171+
; X64-LABEL: addv:
1172+
; X64: # %bb.0:
1173+
; X64-NEXT: vaddnepbf16 %ymm2, %ymm0, %ymm0 # encoding: [0x62,0xf5,0x7d,0x28,0x58,0xc2]
1174+
; X64-NEXT: vaddnepbf16 %ymm3, %ymm1, %ymm1 # encoding: [0x62,0xf5,0x75,0x28,0x58,0xcb]
1175+
; X64-NEXT: retq # encoding: [0xc3]
1176+
;
1177+
; X86-LABEL: addv:
1178+
; X86: # %bb.0:
1179+
; X86-NEXT: pushl %ebp # encoding: [0x55]
1180+
; X86-NEXT: movl %esp, %ebp # encoding: [0x89,0xe5]
1181+
; X86-NEXT: andl $-32, %esp # encoding: [0x83,0xe4,0xe0]
1182+
; X86-NEXT: subl $32, %esp # encoding: [0x83,0xec,0x20]
1183+
; X86-NEXT: vaddnepbf16 %ymm2, %ymm0, %ymm0 # encoding: [0x62,0xf5,0x7d,0x28,0x58,0xc2]
1184+
; X86-NEXT: vaddnepbf16 8(%ebp), %ymm1, %ymm1 # encoding: [0x62,0xf5,0x75,0x28,0x58,0x8d,0x08,0x00,0x00,0x00]
1185+
; X86-NEXT: movl %ebp, %esp # encoding: [0x89,0xec]
1186+
; X86-NEXT: popl %ebp # encoding: [0x5d]
1187+
; X86-NEXT: retl # encoding: [0xc3]
1188+
%add = fadd <32 x bfloat> %a, %b
1189+
ret <32 x bfloat> %add
1190+
}

0 commit comments

Comments
 (0)