Skip to content

Commit 801dd88

Browse files
committed
[X86][BF16] Fix 2 crashes with vector broadcast
Reviewed By: RKSimon. Differential Revision: https://reviews.llvm.org/D151808
1 parent ae5c472 commit 801dd88

File tree

3 files changed

+74
-5
lines changed

3 files changed

+74
-5
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2269,6 +2269,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
22692269
setOperationAction(ISD::FMUL, VT, Expand);
22702270
setOperationAction(ISD::FDIV, VT, Expand);
22712271
setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
2272+
setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
22722273
}
22732274
addLegalFPImmediate(APFloat::getZero(APFloat::BFloat()));
22742275
}
@@ -2281,6 +2282,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
22812282
setOperationAction(ISD::FMUL, MVT::v32bf16, Expand);
22822283
setOperationAction(ISD::FDIV, MVT::v32bf16, Expand);
22832284
setOperationAction(ISD::BUILD_VECTOR, MVT::v32bf16, Custom);
2285+
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v32bf16, Custom);
22842286
}
22852287

22862288
if (!Subtarget.useSoftFloat() && Subtarget.hasVLX()) {
@@ -19099,11 +19101,11 @@ static SDValue lower256BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
1909919101
return DAG.getBitcast(VT, DAG.getVectorShuffle(FpVT, DL, V1, V2, Mask));
1910019102
}
1910119103

19102-
if (VT == MVT::v16f16) {
19103-
V1 = DAG.getBitcast(MVT::v16i16, V1);
19104-
V2 = DAG.getBitcast(MVT::v16i16, V2);
19105-
return DAG.getBitcast(MVT::v16f16,
19106-
DAG.getVectorShuffle(MVT::v16i16, DL, V1, V2, Mask));
19104+
if (VT == MVT::v16f16 || VT.getVectorElementType() == MVT::bf16) {
19105+
MVT IVT = VT.changeVectorElementTypeToInteger();
19106+
V1 = DAG.getBitcast(IVT, V1);
19107+
V2 = DAG.getBitcast(IVT, V2);
19108+
return DAG.getBitcast(VT, DAG.getVectorShuffle(IVT, DL, V1, V2, Mask));
1910719109
}
1910819110

1910919111
switch (VT.SimpleTy) {

llvm/lib/Target/X86/X86InstrAVX512.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12965,6 +12965,27 @@ let Predicates = [HasBF16, HasVLX] in {
1296512965
(VCVTNEPS2BF16Z256rr VR256X:$src)>;
1296612966
def : Pat<(v8bf16 (int_x86_vcvtneps2bf16256 (loadv8f32 addr:$src))),
1296712967
(VCVTNEPS2BF16Z256rm addr:$src)>;
12968+
12969+
def : Pat<(v8bf16 (X86VBroadcastld16 addr:$src)),
12970+
(VPBROADCASTWZ128rm addr:$src)>;
12971+
def : Pat<(v16bf16 (X86VBroadcastld16 addr:$src)),
12972+
(VPBROADCASTWZ256rm addr:$src)>;
12973+
12974+
def : Pat<(v8bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
12975+
(VPBROADCASTWZ128rr VR128X:$src)>;
12976+
def : Pat<(v16bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
12977+
(VPBROADCASTWZ256rr VR128X:$src)>;
12978+
12979+
// TODO: No scalar broadcast patterns yet, because scalar bf16 is not supported as a legal type so far.
12980+
}
12981+
12982+
let Predicates = [HasBF16] in {
12983+
def : Pat<(v32bf16 (X86VBroadcastld16 addr:$src)),
12984+
(VPBROADCASTWZrm addr:$src)>;
12985+
12986+
def : Pat<(v32bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
12987+
(VPBROADCASTWZrr VR128X:$src)>;
12988+
// TODO: No scalar broadcast patterns yet, because scalar bf16 is not supported as a legal type so far.
1296812989
}
1296912990

1297012991
let Constraints = "$src1 = $dst" in {

llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics.ll

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,3 +356,49 @@ entry:
356356
%2 = select <4 x i1> %1, <4 x float> %0, <4 x float> %E
357357
ret <4 x float> %2
358358
}
359+
360+
define <16 x i16> @test_no_vbroadcast1() {
361+
; CHECK-LABEL: test_no_vbroadcast1:
362+
; CHECK: # %bb.0: # %entry
363+
; CHECK-NEXT: vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc0]
364+
; CHECK-NEXT: vpbroadcastw %xmm0, %ymm0 # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7d,0x79,0xc0]
365+
; CHECK-NEXT: ret{{[l|q]}} # encoding: [0xc3]
366+
entry:
367+
%0 = tail call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> poison, <8 x bfloat> zeroinitializer, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
368+
%1 = bitcast <8 x bfloat> %0 to <8 x i16>
369+
%2 = shufflevector <8 x i16> %1, <8 x i16> undef, <16 x i32> zeroinitializer
370+
ret <16 x i16> %2
371+
}
372+
373+
;; FIXME: This should generate the same output as above, but let's fix the crash first.
374+
define <16 x bfloat> @test_no_vbroadcast2() nounwind {
375+
; X86-LABEL: test_no_vbroadcast2:
376+
; X86: # %bb.0: # %entry
377+
; X86-NEXT: pushl %ebp # encoding: [0x55]
378+
; X86-NEXT: movl %esp, %ebp # encoding: [0x89,0xe5]
379+
; X86-NEXT: andl $-32, %esp # encoding: [0x83,0xe4,0xe0]
380+
; X86-NEXT: subl $64, %esp # encoding: [0x83,0xec,0x40]
381+
; X86-NEXT: vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc0]
382+
; X86-NEXT: vmovaps %xmm0, (%esp) # EVEX TO VEX Compression encoding: [0xc5,0xf8,0x29,0x04,0x24]
383+
; X86-NEXT: vpbroadcastw (%esp), %ymm0 # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7d,0x79,0x04,0x24]
384+
; X86-NEXT: movl %ebp, %esp # encoding: [0x89,0xec]
385+
; X86-NEXT: popl %ebp # encoding: [0x5d]
386+
; X86-NEXT: retl # encoding: [0xc3]
387+
;
388+
; X64-LABEL: test_no_vbroadcast2:
389+
; X64: # %bb.0: # %entry
390+
; X64-NEXT: pushq %rbp # encoding: [0x55]
391+
; X64-NEXT: movq %rsp, %rbp # encoding: [0x48,0x89,0xe5]
392+
; X64-NEXT: andq $-32, %rsp # encoding: [0x48,0x83,0xe4,0xe0]
393+
; X64-NEXT: subq $64, %rsp # encoding: [0x48,0x83,0xec,0x40]
394+
; X64-NEXT: vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc0]
395+
; X64-NEXT: vmovaps %xmm0, (%rsp) # EVEX TO VEX Compression encoding: [0xc5,0xf8,0x29,0x04,0x24]
396+
; X64-NEXT: vpbroadcastw (%rsp), %ymm0 # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7d,0x79,0x04,0x24]
397+
; X64-NEXT: movq %rbp, %rsp # encoding: [0x48,0x89,0xec]
398+
; X64-NEXT: popq %rbp # encoding: [0x5d]
399+
; X64-NEXT: retq # encoding: [0xc3]
400+
entry:
401+
%0 = tail call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> poison, <8 x bfloat> zeroinitializer, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
402+
%1 = shufflevector <8 x bfloat> %0, <8 x bfloat> undef, <16 x i32> zeroinitializer
403+
ret <16 x bfloat> %1
404+
}

0 commit comments

Comments
 (0)