Skip to content

Commit a8d506b

Browse files
Zain Jaffal authored and asavonic committed
[AArch64] Fix crash caused by performExtractVectorEltCombine on DUP nodes with float operands.
This is a cherry-pick of 872924f.
Fixes issue llvm#65422 "[Autodiff] Assertion failure SIMD AArch64"

Reviewed By: dmgreen

Differential Revision: https://reviews.llvm.org/D148705
1 parent 66a493d commit a8d506b

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16152,7 +16152,8 @@ performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
1615216152

1615316153
// extract(dup x) -> x
1615416154
if (N0.getOpcode() == AArch64ISD::DUP)
16155-
return DAG.getZExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
16155+
return VT.isInteger() ? DAG.getZExtOrTrunc(N0.getOperand(0), SDLoc(N), VT)
16156+
: N0.getOperand(0);
1615616157

1615716158
// Rewrite for pairwise fadd pattern
1615816159
// (f32 (extract_vector_elt
Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,39 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
2+
; RUN: llc -mtriple=arm64-unknown-unknown < %s -o -| FileCheck %s
3+
4+
; This test covers a case where extract_vector_elt is selected when DUP is
5+
; generated: the combine tried to generate a ZExtOrTrunc node with a
6+
; floating-point type, resulting in a crash.
7+
; See https://reviews.llvm.org/D128144#4280024 for context
8+
; @dot_product: reduces a constant <3 x double> vector with fadd (using %a as
; the scalar start operand), takes the square root of the result, splats it
; across a <3 x double>, multiplies by a constant vector, and extracts lanes 0
; and 1. The function returns void; the only consumer of the computed values is
; the fcmp/branch on lane 1.
define void @dot_product(double %a) {
9+
; CHECK-LABEL: dot_product:
10+
; CHECK: // %bb.0: // %entry
11+
; CHECK-NEXT: fmov d1, #1.00000000
12+
; CHECK-NEXT: fadd d0, d0, d1
13+
; CHECK-NEXT: fadd d0, d0, d1
14+
; CHECK-NEXT: movi d1, #0000000000000000
15+
; CHECK-NEXT: fadd d0, d0, d1
16+
; CHECK-NEXT: fsqrt d0, d0
17+
; CHECK-NEXT: fcmp d0, #0.0
18+
; CHECK-NEXT: ret
19+
entry:
20+
%fadd = call double @llvm.vector.reduce.fadd.v3f64(double %a, <3 x double> <double 1.000000e+00, double 1.000000e+00, double 0.000000e+00>)
21+
%sqrt = call double @llvm.sqrt.f64(double %fadd)
22+
; Put %sqrt in lane 0 of an otherwise all-zero vector.
%insert = insertelement <3 x double> zeroinitializer, double %sqrt, i64 0
23+
; All-zero shuffle mask: every result lane reads lane 0 of %insert, i.e. this
; is a splat of %sqrt.
; NOTE(review): presumably this splat is what lowers to the AArch64 DUP node
; that the extracts below are combined with -- confirm against the DAG dump.
%shuffle = shufflevector <3 x double> %insert, <3 x double> zeroinitializer, <3 x i32> zeroinitializer
24+
%mul = fmul <3 x double> %shuffle, <double 1.000000e+00, double 1.000000e+00, double 0.000000e+00>
25+
; Floating-point element extracts from the multiplied splat: lane 0 is used
; only in bb.1, lane 1 feeds the comparison below.
%shuffle.1 = extractelement <3 x double> %mul, i64 0
26+
%shuffle.2 = extractelement <3 x double> %mul, i64 1
27+
%cmp = fcmp ogt double %shuffle.2, 0.000000e+00
28+
br i1 %cmp, label %exit, label %bb.1
29+
30+
bb.1:
31+
; %mul.2 has no users in this function.
%mul.2 = fmul double %shuffle.1, 0.000000e+00
32+
br label %exit
33+
34+
exit:
35+
ret void
36+
}
37+
38+
declare double @llvm.sqrt.f64(double)
39+
declare double @llvm.vector.reduce.fadd.v3f64(double, <3 x double>)

0 commit comments

Comments
 (0)