Skip to content

Commit 235f1f4

Browse files
committed
[KnownBits] Make nuw and nsw support in computeForAddSub optimal
Use the nuw and nsw flags in KnownBits::computeForAddSub to deduce additional known bits (leading ones/zeros and the sign bit), which should strengthen analysis.
1 parent 92620a1 commit 235f1f4

File tree

8 files changed

+237
-61
lines changed

8 files changed

+237
-61
lines changed

llvm/lib/Support/KnownBits.cpp

Lines changed: 162 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -54,32 +54,178 @@ KnownBits KnownBits::computeForAddCarry(
5454
LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
5555
}
5656

57-
KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool /*NUW*/,
57+
KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
5858
const KnownBits &LHS, KnownBits RHS) {
5959
KnownBits KnownOut;
6060
if (Add) {
6161
// Sum = LHS + RHS + 0
62-
KnownOut = ::computeForAddCarry(
63-
LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
62+
KnownOut =
63+
::computeForAddCarry(LHS, RHS, /*CarryZero*/ true, /*CarryOne*/ false);
6464
} else {
6565
// Sum = LHS + ~RHS + 1
66-
std::swap(RHS.Zero, RHS.One);
67-
KnownOut = ::computeForAddCarry(
68-
LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
66+
KnownBits NotRHS = RHS;
67+
std::swap(NotRHS.Zero, NotRHS.One);
68+
KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero*/ false,
69+
/*CarryOne*/ true);
6970
}
71+
if (!NSW && !NUW)
72+
return KnownOut;
7073

71-
// Are we still trying to solve for the sign bit?
72-
if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
74+
auto GetMinMaxVal = [Add](bool ForNSW, bool ForMax, const KnownBits &L,
75+
const KnownBits &R, bool &OV) {
76+
APInt LVal = ForMax ? L.getMaxValue() : L.getMinValue();
77+
APInt RVal = Add == ForMax ? R.getMaxValue() : R.getMinValue();
78+
79+
if (ForNSW) {
80+
LVal.clearSignBit();
81+
RVal.clearSignBit();
82+
}
83+
APInt Res = Add ? LVal.uadd_ov(RVal, OV) : LVal.usub_ov(RVal, OV);
84+
if (ForNSW) {
85+
OV = Res.isSignBitSet();
86+
Res.clearSignBit();
87+
if (Res.getBitWidth() > 1 && Res[Res.getBitWidth() - 2])
88+
Res.setSignBit();
89+
}
90+
return Res;
91+
};
92+
93+
auto GetMaxVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
94+
const KnownBits &R, bool &OV) {
95+
return GetMinMaxVal(ForNSW, /*ForMax=*/true, L, R, OV);
96+
};
97+
98+
auto GetMinVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
99+
const KnownBits &R, bool &OV) {
100+
return GetMinMaxVal(ForNSW, /*ForMax=*/false, L, R, OV);
101+
};
102+
103+
std::optional<bool> Negative;
104+
bool Poison = false;
105+
// Handle add/sub given nsw and/or nuw.
106+
//
107+
// Possible TODO: Add/Sub implementations mirror one another in many ways.
108+
// They could probably be compressed into a single implementation of roughly
109+
// half the total LOC. Leaving separate for now to increase clarity.
110+
// NB: We handle NSW by essentially treating as nuw of bitwidth - 1 then
111+
// deducing bits based on the known sign result.
112+
if (Add) {
73113
if (NSW) {
74-
// Adding two non-negative numbers, or subtracting a negative number from
75-
// a non-negative one, can't wrap into negative.
76-
if (LHS.isNonNegative() && RHS.isNonNegative())
77-
KnownOut.makeNonNegative();
78-
// Adding two negative numbers, or subtracting a non-negative number from
79-
// a negative one, can't wrap into non-negative.
80-
else if (LHS.isNegative() && RHS.isNegative())
81-
KnownOut.makeNegative();
114+
bool OverflowMax, OverflowMin;
115+
APInt MaxVal = GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
116+
APInt MinVal = GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
117+
118+
if (NUW || (LHS.isNonNegative() && RHS.isNonNegative())) {
119+
// (add nuw) or (add nsw PosX, PosY)
120+
121+
// None of the adds can end up overflowing, so min consecutive highbits
122+
// in minimum possible of X + Y must all remain set.
123+
KnownOut.One.setHighBits(MinVal.countLeadingOnes());
124+
125+
// NSW and Positive arguments leads to positive result.
126+
if (LHS.isNonNegative() && RHS.isNonNegative())
127+
Negative = false;
128+
else
129+
KnownOut.One.clearSignBit();
130+
131+
Poison = OverflowMin;
132+
} else if (LHS.isNegative() && RHS.isNegative()) {
133+
// (add nsw NegX, NegY)
134+
135+
// We need to re-overflow the signbit, so we are looking for sequence of
136+
// 0s from consecutive overflows.
137+
KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
138+
Negative = true;
139+
Poison = !OverflowMax;
140+
} else if (LHS.isNonNegative() || RHS.isNonNegative()) {
141+
// (add nsw PosX, ?Y)
142+
143+
// If the minimal possible of X + Y overflows the signbit, then Y must
144+
// have been signed (which will cause unsigned overflow otherwise nsw
145+
// will be violated) leading to a non-negative result.
146+
if (OverflowMin)
147+
Negative = false;
148+
} else if (LHS.isNegative() || RHS.isNegative()) {
149+
// (add nsw NegX, ?Y)
150+
151+
// If the maximum possible of X + Y doesn't overflow the signbit, then
152+
// Y must have been unsigned (otherwise nsw violated) so NegX + PosY w.o
153+
// overflowing the signbit results in Negative.
154+
if (!OverflowMax)
155+
Negative = true;
156+
}
157+
}
158+
if (NUW) {
159+
// (add nuw X, Y)
160+
bool OverflowMax, OverflowMin;
161+
APInt MinVal = GetMinVal(/*ForNSW=*/false, LHS, RHS, OverflowMin);
162+
// Same as (add nsw PosX, PosY), basically since we can't overflow, the
163+
// high bits of minimum possible X + Y must remain set.
164+
KnownOut.One.setHighBits(MinVal.countLeadingOnes());
165+
Poison = OverflowMin;
82166
}
167+
} else {
168+
if (NSW) {
169+
bool OverflowMax, OverflowMin;
170+
APInt MaxVal = GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
171+
APInt MinVal = GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
172+
if (NUW || (LHS.isNegative() && RHS.isNonNegative())) {
173+
// (sub nuw) or (sub nsw NegX, PosY)
174+
175+
// None of the subs can overflow at any point, so any common high bits
176+
// will subtract away and result in zeros.
177+
KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
178+
if (LHS.isNegative() && RHS.isNonNegative())
179+
Negative = true;
180+
else
181+
KnownOut.Zero.clearSignBit();
182+
183+
Poison = OverflowMax;
184+
} else if (LHS.isNonNegative() && RHS.isNegative()) {
185+
// (sub nsw PosX, NegY)
186+
Negative = false;
187+
188+
// Opposite case of above, we must "re-overflow" the signbit, so minimal
189+
// set of high bits will be fixed.
190+
KnownOut.One.setHighBits(MinVal.countLeadingOnes());
191+
Poison = !OverflowMin;
192+
} else if (LHS.isNegative() || RHS.isNonNegative()) {
193+
// (sub nsw NegX/?X, ?Y/PosY)
194+
if (OverflowMax)
195+
Negative = true;
196+
} else if (LHS.isNonNegative() || RHS.isNegative()) {
197+
// (sub nsw PosX/?X, ?Y/NegY)
198+
if (!OverflowMin)
199+
Negative = false;
200+
}
201+
}
202+
if (NUW) {
203+
// (sub nuw X, Y)
204+
bool OverflowMax, OverflowMin;
205+
APInt MaxVal = GetMaxVal(/*ForNSW=*/false, LHS, RHS, OverflowMax);
206+
207+
// Basically all common high bits between X/Y will cancel out as leading
208+
// zeros.
209+
KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
210+
Poison = OverflowMax;
211+
}
212+
}
213+
214+
// Handle any proven sign bit.
215+
if (Negative.has_value()) {
216+
KnownOut.One.clearSignBit();
217+
KnownOut.Zero.clearSignBit();
218+
219+
if (*Negative)
220+
KnownOut.makeNegative();
221+
else
222+
KnownOut.makeNonNegative();
223+
}
224+
225+
// Just return 0 if the nsw/nuw is violated and we have poison.
226+
if (Poison || KnownOut.hasConflict()) {
227+
KnownOut.setAllZero();
228+
return KnownOut;
83229
}
84230

85231
return KnownOut;

llvm/test/CodeGen/AArch64/sve-cmp-folds.ll

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,12 @@ define i1 @foo_last(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
114114
; CHECK-LABEL: foo_last:
115115
; CHECK: // %bb.0:
116116
; CHECK-NEXT: ptrue p0.s
117-
; CHECK-NEXT: fcmeq p1.s, p0/z, z0.s, z1.s
118-
; CHECK-NEXT: ptest p0, p1.b
119-
; CHECK-NEXT: cset w0, lo
117+
; CHECK-NEXT: mov x8, #-1 // =0xffffffffffffffff
118+
; CHECK-NEXT: whilels p1.s, xzr, x8
119+
; CHECK-NEXT: fcmeq p0.s, p0/z, z0.s, z1.s
120+
; CHECK-NEXT: mov z0.s, p0/z, #1 // =0x1
121+
; CHECK-NEXT: lastb w8, p1, z0.s
122+
; CHECK-NEXT: and w0, w8, #0x1
120123
; CHECK-NEXT: ret
121124
%vcond = fcmp oeq <vscale x 4 x float> %a, %b
122125
%vscale = call i64 @llvm.vscale.i64()

llvm/test/CodeGen/AArch64/sve-extract-element.ll

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -614,9 +614,11 @@ define i1 @test_lane9_8xi1(<vscale x 8 x i1> %a) #0 {
614614
define i1 @test_last_8xi1(<vscale x 8 x i1> %a) #0 {
615615
; CHECK-LABEL: test_last_8xi1:
616616
; CHECK: // %bb.0:
617-
; CHECK-NEXT: ptrue p1.h
618-
; CHECK-NEXT: ptest p1, p0.b
619-
; CHECK-NEXT: cset w0, lo
617+
; CHECK-NEXT: mov x8, #-1 // =0xffffffffffffffff
618+
; CHECK-NEXT: mov z0.h, p0/z, #1 // =0x1
619+
; CHECK-NEXT: whilels p1.h, xzr, x8
620+
; CHECK-NEXT: lastb w8, p1, z0.h
621+
; CHECK-NEXT: and w0, w8, #0x1
620622
; CHECK-NEXT: ret
621623
%vscale = call i64 @llvm.vscale.i64()
622624
%shl = shl nuw nsw i64 %vscale, 3

llvm/test/CodeGen/AMDGPU/ds-sub-offset.ll

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -137,49 +137,46 @@ define amdgpu_kernel void @write_ds_sub_max_offset_global_clamp_bit(float %dummy
137137
; CI: ; %bb.0:
138138
; CI-NEXT: s_load_dword s0, s[0:1], 0x0
139139
; CI-NEXT: s_mov_b64 vcc, 0
140-
; CI-NEXT: v_not_b32_e32 v0, v0
141-
; CI-NEXT: v_lshlrev_b32_e32 v0, 2, v0
142-
; CI-NEXT: v_mov_b32_e32 v2, 0x7b
140+
; CI-NEXT: v_mov_b32_e32 v1, 0x7b
141+
; CI-NEXT: v_mov_b32_e32 v2, 0
142+
; CI-NEXT: s_mov_b32 m0, -1
143143
; CI-NEXT: s_waitcnt lgkmcnt(0)
144-
; CI-NEXT: v_mov_b32_e32 v1, s0
145-
; CI-NEXT: v_div_fmas_f32 v1, v1, v1, v1
144+
; CI-NEXT: v_mov_b32_e32 v0, s0
145+
; CI-NEXT: v_div_fmas_f32 v0, v0, v0, v0
146146
; CI-NEXT: s_mov_b32 s0, 0
147-
; CI-NEXT: s_mov_b32 m0, -1
148147
; CI-NEXT: s_mov_b32 s3, 0xf000
149148
; CI-NEXT: s_mov_b32 s2, -1
150149
; CI-NEXT: s_mov_b32 s1, s0
151-
; CI-NEXT: ds_write_b32 v0, v2 offset:65532
152-
; CI-NEXT: buffer_store_dword v1, off, s[0:3], 0
150+
; CI-NEXT: ds_write_b32 v2, v1
151+
; CI-NEXT: buffer_store_dword v0, off, s[0:3], 0
153152
; CI-NEXT: s_waitcnt vmcnt(0)
154153
; CI-NEXT: s_endpgm
155154
;
156155
; GFX9-LABEL: write_ds_sub_max_offset_global_clamp_bit:
157156
; GFX9: ; %bb.0:
158157
; GFX9-NEXT: s_load_dword s0, s[0:1], 0x0
159158
; GFX9-NEXT: s_mov_b64 vcc, 0
160-
; GFX9-NEXT: v_not_b32_e32 v0, v0
161-
; GFX9-NEXT: v_lshlrev_b32_e32 v3, 2, v0
162-
; GFX9-NEXT: v_mov_b32_e32 v4, 0x7b
159+
; GFX9-NEXT: v_mov_b32_e32 v3, 0x7b
160+
; GFX9-NEXT: v_mov_b32_e32 v4, 0
161+
; GFX9-NEXT: ds_write_b32 v4, v3
163162
; GFX9-NEXT: s_waitcnt lgkmcnt(0)
164-
; GFX9-NEXT: v_mov_b32_e32 v1, s0
165-
; GFX9-NEXT: v_div_fmas_f32 v2, v1, v1, v1
163+
; GFX9-NEXT: v_mov_b32_e32 v0, s0
164+
; GFX9-NEXT: v_div_fmas_f32 v2, v0, v0, v0
166165
; GFX9-NEXT: v_mov_b32_e32 v0, 0
167166
; GFX9-NEXT: v_mov_b32_e32 v1, 0
168-
; GFX9-NEXT: ds_write_b32 v3, v4 offset:65532
169167
; GFX9-NEXT: global_store_dword v[0:1], v2, off
170168
; GFX9-NEXT: s_waitcnt vmcnt(0)
171169
; GFX9-NEXT: s_endpgm
172170
;
173171
; GFX10-LABEL: write_ds_sub_max_offset_global_clamp_bit:
174172
; GFX10: ; %bb.0:
175173
; GFX10-NEXT: s_load_dword s0, s[0:1], 0x0
176-
; GFX10-NEXT: v_not_b32_e32 v0, v0
177174
; GFX10-NEXT: s_mov_b32 vcc_lo, 0
178-
; GFX10-NEXT: v_mov_b32_e32 v3, 0x7b
179-
; GFX10-NEXT: v_lshlrev_b32_e32 v2, 2, v0
180175
; GFX10-NEXT: v_mov_b32_e32 v0, 0
176+
; GFX10-NEXT: v_mov_b32_e32 v2, 0x7b
177+
; GFX10-NEXT: v_mov_b32_e32 v3, 0
181178
; GFX10-NEXT: v_mov_b32_e32 v1, 0
182-
; GFX10-NEXT: ds_write_b32 v2, v3 offset:65532
179+
; GFX10-NEXT: ds_write_b32 v3, v2
183180
; GFX10-NEXT: s_waitcnt lgkmcnt(0)
184181
; GFX10-NEXT: v_div_fmas_f32 v4, s0, s0, s0
185182
; GFX10-NEXT: global_store_dword v[0:1], v4, off
@@ -189,13 +186,11 @@ define amdgpu_kernel void @write_ds_sub_max_offset_global_clamp_bit(float %dummy
189186
; GFX11-LABEL: write_ds_sub_max_offset_global_clamp_bit:
190187
; GFX11: ; %bb.0:
191188
; GFX11-NEXT: s_load_b32 s0, s[0:1], 0x0
192-
; GFX11-NEXT: v_not_b32_e32 v0, v0
193189
; GFX11-NEXT: s_mov_b32 vcc_lo, 0
194-
; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_1)
195-
; GFX11-NEXT: v_dual_mov_b32 v3, 0x7b :: v_dual_lshlrev_b32 v2, 2, v0
196190
; GFX11-NEXT: v_mov_b32_e32 v0, 0
191+
; GFX11-NEXT: v_dual_mov_b32 v2, 0x7b :: v_dual_mov_b32 v3, 0
197192
; GFX11-NEXT: v_mov_b32_e32 v1, 0
198-
; GFX11-NEXT: ds_store_b32 v2, v3 offset:65532
193+
; GFX11-NEXT: ds_store_b32 v3, v2
199194
; GFX11-NEXT: s_waitcnt lgkmcnt(0)
200195
; GFX11-NEXT: v_div_fmas_f32 v4, s0, s0, s0
201196
; GFX11-NEXT: global_store_b32 v[0:1], v4, off dlc

llvm/test/Transforms/InstCombine/fold-log2-ceil-idiom.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ define i64 @log2_ceil_idiom_zext(i32 %x) {
4343
; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[X]], -1
4444
; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.ctlz.i32(i32 [[TMP1]], i1 false), !range [[RNG0]]
4545
; CHECK-NEXT: [[TMP3:%.*]] = sub nuw nsw i32 32, [[TMP2]]
46-
; CHECK-NEXT: [[RET:%.*]] = zext i32 [[TMP3]] to i64
46+
; CHECK-NEXT: [[RET:%.*]] = zext nneg i32 [[TMP3]] to i64
4747
; CHECK-NEXT: ret i64 [[RET]]
4848
;
4949
%ctlz = tail call i32 @llvm.ctlz.i32(i32 %x, i1 true)

llvm/test/Transforms/InstCombine/icmp-sub.ll

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ define i1 @test_nuw_nsw_and_unsigned_pred(i64 %x) {
3636

3737
define i1 @test_nuw_nsw_and_signed_pred(i64 %x) {
3838
; CHECK-LABEL: @test_nuw_nsw_and_signed_pred(
39-
; CHECK-NEXT: [[Z:%.*]] = icmp sgt i64 [[X:%.*]], 7
39+
; CHECK-NEXT: [[Z:%.*]] = icmp ugt i64 [[X:%.*]], 7
4040
; CHECK-NEXT: ret i1 [[Z]]
4141
;
4242
%y = sub nuw nsw i64 10, %x
@@ -46,8 +46,7 @@ define i1 @test_nuw_nsw_and_signed_pred(i64 %x) {
4646

4747
define i1 @test_negative_nuw_and_signed_pred(i64 %x) {
4848
; CHECK-LABEL: @test_negative_nuw_and_signed_pred(
49-
; CHECK-NEXT: [[NOTSUB:%.*]] = add nuw i64 [[X:%.*]], -11
50-
; CHECK-NEXT: [[Z:%.*]] = icmp sgt i64 [[NOTSUB]], -4
49+
; CHECK-NEXT: [[Z:%.*]] = icmp ugt i64 [[X:%.*]], 7
5150
; CHECK-NEXT: ret i1 [[Z]]
5251
;
5352
%y = sub nuw i64 10, %x

llvm/test/Transforms/InstCombine/sub.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2367,7 +2367,7 @@ define <2 x i8> @sub_to_and_vector3(<2 x i8> %x) {
23672367
; CHECK-LABEL: @sub_to_and_vector3(
23682368
; CHECK-NEXT: [[SUB:%.*]] = sub nuw <2 x i8> <i8 71, i8 71>, [[X:%.*]]
23692369
; CHECK-NEXT: [[AND:%.*]] = and <2 x i8> [[SUB]], <i8 120, i8 undef>
2370-
; CHECK-NEXT: [[R:%.*]] = sub <2 x i8> <i8 44, i8 44>, [[AND]]
2370+
; CHECK-NEXT: [[R:%.*]] = sub nsw <2 x i8> <i8 44, i8 44>, [[AND]]
23712371
; CHECK-NEXT: ret <2 x i8> [[R]]
23722372
;
23732373
%sub = sub nuw <2 x i8> <i8 71, i8 71>, %x

0 commit comments

Comments
 (0)