Skip to content

Commit 8bd9ade

Browse files
authored
[InstCombine] Fold fcmp pred sqrt(X), 0.0 -> fcmp pred2 X, 0.0 (#101626)
Proof (Please run alive-tv with larger smt-to): https://alive2.llvm.org/ce/z/-aqixk FMF propagation: https://alive2.llvm.org/ce/z/zyKK_p ``` sqrt(X) < 0.0 --> false sqrt(X) u>= 0.0 --> true sqrt(X) u< 0.0 --> X u< 0.0 sqrt(X) u<= 0.0 --> X u<= 0.0 sqrt(X) > 0.0 --> X > 0.0 sqrt(X) >= 0.0 --> X >= 0.0 sqrt(X) == 0.0 --> X == 0.0 sqrt(X) u!= 0.0 --> X u!= 0.0 sqrt(X) <= 0.0 --> X == 0.0 sqrt(X) u> 0.0 --> X u!= 0.0 sqrt(X) u== 0.0 --> X u<= 0.0 sqrt(X) != 0.0 --> X > 0.0 !isnan(sqrt(X)) --> X >= 0.0 isnan(sqrt(X)) --> X u< 0.0 ``` In most cases, `sqrt` cannot be eliminated since it has multiple uses. But this patch will break data dependencies and allow optimizer to sink expensive `sqrt` calls into successor blocks.
1 parent ea18a40 commit 8bd9ade

File tree

3 files changed

+298
-3
lines changed

3 files changed

+298
-3
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7980,6 +7980,67 @@ static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
79807980
}
79817981
}
79827982

7983+
/// Optimize sqrt(X) compared with zero.
7984+
static Instruction *foldSqrtWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
7985+
Value *X;
7986+
if (!match(I.getOperand(0), m_Sqrt(m_Value(X))))
7987+
return nullptr;
7988+
7989+
if (!match(I.getOperand(1), m_PosZeroFP()))
7990+
return nullptr;
7991+
7992+
auto ReplacePredAndOp0 = [&](FCmpInst::Predicate P) {
7993+
I.setPredicate(P);
7994+
return IC.replaceOperand(I, 0, X);
7995+
};
7996+
7997+
// Clear ninf flag if sqrt doesn't have it.
7998+
if (!cast<Instruction>(I.getOperand(0))->hasNoInfs())
7999+
I.setHasNoInfs(false);
8000+
8001+
switch (I.getPredicate()) {
8002+
case FCmpInst::FCMP_OLT:
8003+
case FCmpInst::FCMP_UGE:
8004+
// sqrt(X) < 0.0 --> false
8005+
// sqrt(X) u>= 0.0 --> true
8006+
llvm_unreachable("fcmp should have simplified");
8007+
case FCmpInst::FCMP_ULT:
8008+
case FCmpInst::FCMP_ULE:
8009+
case FCmpInst::FCMP_OGT:
8010+
case FCmpInst::FCMP_OGE:
8011+
case FCmpInst::FCMP_OEQ:
8012+
case FCmpInst::FCMP_UNE:
8013+
// sqrt(X) u< 0.0 --> X u< 0.0
8014+
// sqrt(X) u<= 0.0 --> X u<= 0.0
8015+
// sqrt(X) > 0.0 --> X > 0.0
8016+
// sqrt(X) >= 0.0 --> X >= 0.0
8017+
// sqrt(X) == 0.0 --> X == 0.0
8018+
// sqrt(X) u!= 0.0 --> X u!= 0.0
8019+
return IC.replaceOperand(I, 0, X);
8020+
8021+
case FCmpInst::FCMP_OLE:
8022+
// sqrt(X) <= 0.0 --> X == 0.0
8023+
return ReplacePredAndOp0(FCmpInst::FCMP_OEQ);
8024+
case FCmpInst::FCMP_UGT:
8025+
// sqrt(X) u> 0.0 --> X u!= 0.0
8026+
return ReplacePredAndOp0(FCmpInst::FCMP_UNE);
8027+
case FCmpInst::FCMP_UEQ:
8028+
// sqrt(X) u== 0.0 --> X u<= 0.0
8029+
return ReplacePredAndOp0(FCmpInst::FCMP_ULE);
8030+
case FCmpInst::FCMP_ONE:
8031+
// sqrt(X) != 0.0 --> X > 0.0
8032+
return ReplacePredAndOp0(FCmpInst::FCMP_OGT);
8033+
case FCmpInst::FCMP_ORD:
8034+
// !isnan(sqrt(X)) --> X >= 0.0
8035+
return ReplacePredAndOp0(FCmpInst::FCMP_OGE);
8036+
case FCmpInst::FCMP_UNO:
8037+
// isnan(sqrt(X)) --> X u< 0.0
8038+
return ReplacePredAndOp0(FCmpInst::FCMP_ULT);
8039+
default:
8040+
llvm_unreachable("Unexpected predicate!");
8041+
}
8042+
}
8043+
79838044
static Instruction *foldFCmpFNegCommonOp(FCmpInst &I) {
79848045
CmpInst::Predicate Pred = I.getPredicate();
79858046
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@@ -8247,6 +8308,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
82478308
if (Instruction *R = foldFabsWithFcmpZero(I, *this))
82488309
return R;
82498310

8311+
if (Instruction *R = foldSqrtWithFcmpZero(I, *this))
8312+
return R;
8313+
82508314
if (match(Op0, m_FNeg(m_Value(X)))) {
82518315
// fcmp pred (fneg X), C --> fcmp swap(pred) X, -C
82528316
Constant *C;

llvm/test/Transforms/InstCombine/fcmp.ll

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2117,3 +2117,236 @@ define <8 x i1> @fcmp_ogt_fsub_const_vec_denormal_preserve-sign(<8 x float> %x,
21172117
%cmp = fcmp ogt <8 x float> %fs, zeroinitializer
21182118
ret <8 x i1> %cmp
21192119
}
2120+
2121+
define i1 @fcmp_sqrt_zero_olt(half %x) {
2122+
; CHECK-LABEL: @fcmp_sqrt_zero_olt(
2123+
; CHECK-NEXT: ret i1 false
2124+
;
2125+
%sqrt = call half @llvm.sqrt.f16(half %x)
2126+
%cmp = fcmp olt half %sqrt, 0.0
2127+
ret i1 %cmp
2128+
}
2129+
2130+
define i1 @fcmp_sqrt_zero_ult(half %x) {
2131+
; CHECK-LABEL: @fcmp_sqrt_zero_ult(
2132+
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
2133+
; CHECK-NEXT: ret i1 [[CMP]]
2134+
;
2135+
%sqrt = call half @llvm.sqrt.f16(half %x)
2136+
%cmp = fcmp ult half %sqrt, 0.0
2137+
ret i1 %cmp
2138+
}
2139+
2140+
define i1 @fcmp_sqrt_zero_ult_fmf(half %x) {
2141+
; CHECK-LABEL: @fcmp_sqrt_zero_ult_fmf(
2142+
; CHECK-NEXT: [[CMP:%.*]] = fcmp nsz ult half [[X:%.*]], 0xH0000
2143+
; CHECK-NEXT: ret i1 [[CMP]]
2144+
;
2145+
%sqrt = call half @llvm.sqrt.f16(half %x)
2146+
%cmp = fcmp ninf nsz ult half %sqrt, 0.0
2147+
ret i1 %cmp
2148+
}
2149+
2150+
define i1 @fcmp_sqrt_zero_ult_fmf_sqrt_ninf(half %x) {
2151+
; CHECK-LABEL: @fcmp_sqrt_zero_ult_fmf_sqrt_ninf(
2152+
; CHECK-NEXT: [[CMP:%.*]] = fcmp ninf nsz ult half [[X:%.*]], 0xH0000
2153+
; CHECK-NEXT: ret i1 [[CMP]]
2154+
;
2155+
%sqrt = call ninf half @llvm.sqrt.f16(half %x)
2156+
%cmp = fcmp ninf nsz ult half %sqrt, 0.0
2157+
ret i1 %cmp
2158+
}
2159+
2160+
define i1 @fcmp_sqrt_zero_ult_nzero(half %x) {
2161+
; CHECK-LABEL: @fcmp_sqrt_zero_ult_nzero(
2162+
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
2163+
; CHECK-NEXT: ret i1 [[CMP]]
2164+
;
2165+
%sqrt = call half @llvm.sqrt.f16(half %x)
2166+
%cmp = fcmp ult half %sqrt, -0.0
2167+
ret i1 %cmp
2168+
}
2169+
2170+
define <2 x i1> @fcmp_sqrt_zero_ult_vec(<2 x half> %x) {
2171+
; CHECK-LABEL: @fcmp_sqrt_zero_ult_vec(
2172+
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult <2 x half> [[X:%.*]], zeroinitializer
2173+
; CHECK-NEXT: ret <2 x i1> [[CMP]]
2174+
;
2175+
%sqrt = call <2 x half> @llvm.sqrt.v2f16(<2 x half> %x)
2176+
%cmp = fcmp ult <2 x half> %sqrt, zeroinitializer
2177+
ret <2 x i1> %cmp
2178+
}
2179+
2180+
define <2 x i1> @fcmp_sqrt_zero_ult_vec_mixed_zero(<2 x half> %x) {
2181+
; CHECK-LABEL: @fcmp_sqrt_zero_ult_vec_mixed_zero(
2182+
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult <2 x half> [[X:%.*]], zeroinitializer
2183+
; CHECK-NEXT: ret <2 x i1> [[CMP]]
2184+
;
2185+
%sqrt = call <2 x half> @llvm.sqrt.v2f16(<2 x half> %x)
2186+
%cmp = fcmp ult <2 x half> %sqrt, <half 0.0, half -0.0>
2187+
ret <2 x i1> %cmp
2188+
}
2189+
2190+
define i1 @fcmp_sqrt_zero_ole(half %x) {
2191+
; CHECK-LABEL: @fcmp_sqrt_zero_ole(
2192+
; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq half [[X:%.*]], 0xH0000
2193+
; CHECK-NEXT: ret i1 [[CMP]]
2194+
;
2195+
%sqrt = call half @llvm.sqrt.f16(half %x)
2196+
%cmp = fcmp ole half %sqrt, 0.0
2197+
ret i1 %cmp
2198+
}
2199+
2200+
define i1 @fcmp_sqrt_zero_ule(half %x) {
2201+
; CHECK-LABEL: @fcmp_sqrt_zero_ule(
2202+
; CHECK-NEXT: [[CMP:%.*]] = fcmp ule half [[X:%.*]], 0xH0000
2203+
; CHECK-NEXT: ret i1 [[CMP]]
2204+
;
2205+
%sqrt = call half @llvm.sqrt.f16(half %x)
2206+
%cmp = fcmp ule half %sqrt, 0.0
2207+
ret i1 %cmp
2208+
}
2209+
2210+
define i1 @fcmp_sqrt_zero_ogt(half %x) {
2211+
; CHECK-LABEL: @fcmp_sqrt_zero_ogt(
2212+
; CHECK-NEXT: [[CMP:%.*]] = fcmp ogt half [[X:%.*]], 0xH0000
2213+
; CHECK-NEXT: ret i1 [[CMP]]
2214+
;
2215+
%sqrt = call half @llvm.sqrt.f16(half %x)
2216+
%cmp = fcmp ogt half %sqrt, 0.0
2217+
ret i1 %cmp
2218+
}
2219+
2220+
define i1 @fcmp_sqrt_zero_ugt(half %x) {
2221+
; CHECK-LABEL: @fcmp_sqrt_zero_ugt(
2222+
; CHECK-NEXT: [[CMP:%.*]] = fcmp une half [[X:%.*]], 0xH0000
2223+
; CHECK-NEXT: ret i1 [[CMP]]
2224+
;
2225+
%sqrt = call half @llvm.sqrt.f16(half %x)
2226+
%cmp = fcmp ugt half %sqrt, 0.0
2227+
ret i1 %cmp
2228+
}
2229+
2230+
define i1 @fcmp_sqrt_zero_oge(half %x) {
2231+
; CHECK-LABEL: @fcmp_sqrt_zero_oge(
2232+
; CHECK-NEXT: [[CMP:%.*]] = fcmp oge half [[X:%.*]], 0xH0000
2233+
; CHECK-NEXT: ret i1 [[CMP]]
2234+
;
2235+
%sqrt = call half @llvm.sqrt.f16(half %x)
2236+
%cmp = fcmp oge half %sqrt, 0.0
2237+
ret i1 %cmp
2238+
}
2239+
2240+
define i1 @fcmp_sqrt_zero_uge(half %x) {
2241+
; CHECK-LABEL: @fcmp_sqrt_zero_uge(
2242+
; CHECK-NEXT: ret i1 true
2243+
;
2244+
%sqrt = call half @llvm.sqrt.f16(half %x)
2245+
%cmp = fcmp uge half %sqrt, 0.0
2246+
ret i1 %cmp
2247+
}
2248+
2249+
define i1 @fcmp_sqrt_zero_oeq(half %x) {
2250+
; CHECK-LABEL: @fcmp_sqrt_zero_oeq(
2251+
; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq half [[X:%.*]], 0xH0000
2252+
; CHECK-NEXT: ret i1 [[CMP]]
2253+
;
2254+
%sqrt = call half @llvm.sqrt.f16(half %x)
2255+
%cmp = fcmp oeq half %sqrt, 0.0
2256+
ret i1 %cmp
2257+
}
2258+
2259+
define i1 @fcmp_sqrt_zero_ueq(half %x) {
2260+
; CHECK-LABEL: @fcmp_sqrt_zero_ueq(
2261+
; CHECK-NEXT: [[CMP:%.*]] = fcmp ule half [[X:%.*]], 0xH0000
2262+
; CHECK-NEXT: ret i1 [[CMP]]
2263+
;
2264+
%sqrt = call half @llvm.sqrt.f16(half %x)
2265+
%cmp = fcmp ueq half %sqrt, 0.0
2266+
ret i1 %cmp
2267+
}
2268+
2269+
define i1 @fcmp_sqrt_zero_one(half %x) {
2270+
; CHECK-LABEL: @fcmp_sqrt_zero_one(
2271+
; CHECK-NEXT: [[CMP:%.*]] = fcmp ogt half [[X:%.*]], 0xH0000
2272+
; CHECK-NEXT: ret i1 [[CMP]]
2273+
;
2274+
%sqrt = call half @llvm.sqrt.f16(half %x)
2275+
%cmp = fcmp one half %sqrt, 0.0
2276+
ret i1 %cmp
2277+
}
2278+
2279+
define i1 @fcmp_sqrt_zero_une(half %x) {
2280+
; CHECK-LABEL: @fcmp_sqrt_zero_une(
2281+
; CHECK-NEXT: [[CMP:%.*]] = fcmp une half [[X:%.*]], 0xH0000
2282+
; CHECK-NEXT: ret i1 [[CMP]]
2283+
;
2284+
%sqrt = call half @llvm.sqrt.f16(half %x)
2285+
%cmp = fcmp une half %sqrt, 0.0
2286+
ret i1 %cmp
2287+
}
2288+
2289+
define i1 @fcmp_sqrt_zero_ord(half %x) {
2290+
; CHECK-LABEL: @fcmp_sqrt_zero_ord(
2291+
; CHECK-NEXT: [[CMP:%.*]] = fcmp oge half [[X:%.*]], 0xH0000
2292+
; CHECK-NEXT: ret i1 [[CMP]]
2293+
;
2294+
%sqrt = call half @llvm.sqrt.f16(half %x)
2295+
%cmp = fcmp ord half %sqrt, 0.0
2296+
ret i1 %cmp
2297+
}
2298+
2299+
define i1 @fcmp_sqrt_zero_uno(half %x) {
2300+
; CHECK-LABEL: @fcmp_sqrt_zero_uno(
2301+
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
2302+
; CHECK-NEXT: ret i1 [[CMP]]
2303+
;
2304+
%sqrt = call half @llvm.sqrt.f16(half %x)
2305+
%cmp = fcmp uno half %sqrt, 0.0
2306+
ret i1 %cmp
2307+
}
2308+
2309+
; Make sure that ninf is cleared.
2310+
define i1 @fcmp_sqrt_zero_uno_fmf(half %x) {
2311+
; CHECK-LABEL: @fcmp_sqrt_zero_uno_fmf(
2312+
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
2313+
; CHECK-NEXT: ret i1 [[CMP]]
2314+
;
2315+
%sqrt = call half @llvm.sqrt.f16(half %x)
2316+
%cmp = fcmp ninf uno half %sqrt, 0.0
2317+
ret i1 %cmp
2318+
}
2319+
2320+
define i1 @fcmp_sqrt_zero_uno_fmf_sqrt_ninf(half %x) {
2321+
; CHECK-LABEL: @fcmp_sqrt_zero_uno_fmf_sqrt_ninf(
2322+
; CHECK-NEXT: [[CMP:%.*]] = fcmp ninf ult half [[X:%.*]], 0xH0000
2323+
; CHECK-NEXT: ret i1 [[CMP]]
2324+
;
2325+
%sqrt = call ninf half @llvm.sqrt.f16(half %x)
2326+
%cmp = fcmp ninf uno half %sqrt, 0.0
2327+
ret i1 %cmp
2328+
}
2329+
2330+
; negative tests
2331+
2332+
define i1 @fcmp_sqrt_zero_ult_var(half %x, half %y) {
2333+
; CHECK-LABEL: @fcmp_sqrt_zero_ult_var(
2334+
; CHECK-NEXT: [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
2335+
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[SQRT]], [[Y:%.*]]
2336+
; CHECK-NEXT: ret i1 [[CMP]]
2337+
;
2338+
%sqrt = call half @llvm.sqrt.f16(half %x)
2339+
%cmp = fcmp ult half %sqrt, %y
2340+
ret i1 %cmp
2341+
}
2342+
2343+
define i1 @fcmp_sqrt_zero_ult_nonzero(half %x) {
2344+
; CHECK-LABEL: @fcmp_sqrt_zero_ult_nonzero(
2345+
; CHECK-NEXT: [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
2346+
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[SQRT]], 0xH3C00
2347+
; CHECK-NEXT: ret i1 [[CMP]]
2348+
;
2349+
%sqrt = call half @llvm.sqrt.f16(half %x)
2350+
%cmp = fcmp ult half %sqrt, 1.000000e+00
2351+
ret i1 %cmp
2352+
}

llvm/test/Transforms/InstCombine/known-never-nan.ll

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99

1010
define i1 @fabs_sqrt_src_maybe_nan(double %arg0, double %arg1) {
1111
; CHECK-LABEL: @fabs_sqrt_src_maybe_nan(
12-
; CHECK-NEXT: [[FABS:%.*]] = call double @llvm.fabs.f64(double [[ARG0:%.*]])
13-
; CHECK-NEXT: [[OP:%.*]] = call double @llvm.sqrt.f64(double [[FABS]])
14-
; CHECK-NEXT: [[TMP:%.*]] = fcmp ord double [[OP]], 0.000000e+00
12+
; CHECK-NEXT: [[TMP:%.*]] = fcmp ord double [[ARG0:%.*]], 0.000000e+00
1513
; CHECK-NEXT: ret i1 [[TMP]]
1614
;
1715
%fabs = call double @llvm.fabs.f64(double %arg0)

0 commit comments

Comments
 (0)