Skip to content

Commit 2e4e04c

Browse files
authored
[X86][BF16] Do not lower to VCVTNEPS2BF16 without AVX512VL (#86395)
Fixes: #86305
1 parent 7d2d8e2 commit 2e4e04c

File tree

2 files changed: 79 additions (+79), 2 deletions (−2)

2 files changed: 79 additions (+79), 2 deletions (−2)

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21512,7 +21512,9 @@ SDValue X86TargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
2151221512
}
2151321513

2151421514
if (VT.getScalarType() == MVT::bf16) {
21515-
if (SVT.getScalarType() == MVT::f32 && isTypeLegal(VT))
21515+
if (SVT.getScalarType() == MVT::f32 &&
21516+
((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
21517+
Subtarget.hasAVXNECONVERT()))
2151621518
return Op;
2151721519
return SDValue();
2151821520
}
@@ -21619,7 +21621,8 @@ SDValue X86TargetLowering::LowerFP_TO_BF16(SDValue Op,
2161921621
SDLoc DL(Op);
2162021622

2162121623
MVT SVT = Op.getOperand(0).getSimpleValueType();
21622-
if (SVT == MVT::f32 && (Subtarget.hasBF16() || Subtarget.hasAVXNECONVERT())) {
21624+
if (SVT == MVT::f32 && ((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
21625+
Subtarget.hasAVXNECONVERT())) {
2162321626
SDValue Res;
2162421627
Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4f32, Op.getOperand(0));
2162521628
Res = DAG.getNode(X86ISD::CVTNEPS2BF16, DL, MVT::v8bf16, Res);

llvm/test/CodeGen/X86/pr86305.ll

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
2+
; RUN: llc < %s -mtriple=x86_64-linux-gnu -mattr=avx512bf16 | FileCheck %s
3+
4+
define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
5+
; CHECK-LABEL: add:
6+
; CHECK: # %bb.0:
7+
; CHECK-NEXT: pushq %rbx
8+
; CHECK-NEXT: movq %rdx, %rbx
9+
; CHECK-NEXT: movzwl (%rsi), %eax
10+
; CHECK-NEXT: shll $16, %eax
11+
; CHECK-NEXT: vmovd %eax, %xmm0
12+
; CHECK-NEXT: movzwl (%rdi), %eax
13+
; CHECK-NEXT: shll $16, %eax
14+
; CHECK-NEXT: vmovd %eax, %xmm1
15+
; CHECK-NEXT: vaddss %xmm0, %xmm1, %xmm0
16+
; CHECK-NEXT: callq __truncsfbf2@PLT
17+
; CHECK-NEXT: vpextrw $0, %xmm0, (%rbx)
18+
; CHECK-NEXT: popq %rbx
19+
; CHECK-NEXT: retq
20+
%a = load bfloat, ptr %pa
21+
%b = load bfloat, ptr %pb
22+
%add = fadd bfloat %a, %b
23+
store bfloat %add, ptr %pc
24+
ret void
25+
}
26+
27+
define <4 x bfloat> @fptrunc_v4f32(<4 x float> %a) nounwind {
28+
; CHECK-LABEL: fptrunc_v4f32:
29+
; CHECK: # %bb.0:
30+
; CHECK-NEXT: pushq %rbp
31+
; CHECK-NEXT: pushq %r15
32+
; CHECK-NEXT: pushq %r14
33+
; CHECK-NEXT: pushq %rbx
34+
; CHECK-NEXT: subq $72, %rsp
35+
; CHECK-NEXT: vmovaps %xmm0, (%rsp) # 16-byte Spill
36+
; CHECK-NEXT: callq __truncsfbf2@PLT
37+
; CHECK-NEXT: vmovaps %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
38+
; CHECK-NEXT: vpermilpd $1, (%rsp), %xmm0 # 16-byte Folded Reload
39+
; CHECK-NEXT: # xmm0 = mem[1,0]
40+
; CHECK-NEXT: callq __truncsfbf2@PLT
41+
; CHECK-NEXT: vmovapd %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
42+
; CHECK-NEXT: vpshufd $255, (%rsp), %xmm0 # 16-byte Folded Reload
43+
; CHECK-NEXT: # xmm0 = mem[3,3,3,3]
44+
; CHECK-NEXT: callq __truncsfbf2@PLT
45+
; CHECK-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
46+
; CHECK-NEXT: callq __truncsfbf2@PLT
47+
; CHECK-NEXT: vpextrw $0, %xmm0, %ebx
48+
; CHECK-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload
49+
; CHECK-NEXT: vpextrw $0, %xmm0, %ebp
50+
; CHECK-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload
51+
; CHECK-NEXT: vpextrw $0, %xmm0, %r14d
52+
; CHECK-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload
53+
; CHECK-NEXT: vpextrw $0, %xmm0, %r15d
54+
; CHECK-NEXT: vmovshdup (%rsp), %xmm0 # 16-byte Folded Reload
55+
; CHECK-NEXT: # xmm0 = mem[1,1,3,3]
56+
; CHECK-NEXT: callq __truncsfbf2@PLT
57+
; CHECK-NEXT: vpextrw $0, %xmm0, %eax
58+
; CHECK-NEXT: vmovd %r15d, %xmm0
59+
; CHECK-NEXT: vpinsrw $1, %eax, %xmm0, %xmm0
60+
; CHECK-NEXT: vpinsrw $2, %r14d, %xmm0, %xmm0
61+
; CHECK-NEXT: vpinsrw $3, %ebp, %xmm0, %xmm0
62+
; CHECK-NEXT: vpinsrw $4, %ebx, %xmm0, %xmm0
63+
; CHECK-NEXT: vpinsrw $5, %ebx, %xmm0, %xmm0
64+
; CHECK-NEXT: vpinsrw $6, %ebx, %xmm0, %xmm0
65+
; CHECK-NEXT: vpinsrw $7, %ebx, %xmm0, %xmm0
66+
; CHECK-NEXT: addq $72, %rsp
67+
; CHECK-NEXT: popq %rbx
68+
; CHECK-NEXT: popq %r14
69+
; CHECK-NEXT: popq %r15
70+
; CHECK-NEXT: popq %rbp
71+
; CHECK-NEXT: retq
72+
%b = fptrunc <4 x float> %a to <4 x bfloat>
73+
ret <4 x bfloat> %b
74+
}

0 commit comments

Comments
 (0)