@@ -54,32 +54,178 @@ KnownBits KnownBits::computeForAddCarry(
54
54
LHS, RHS, Carry.Zero .getBoolValue (), Carry.One .getBoolValue ());
55
55
}
56
56
57
- KnownBits KnownBits::computeForAddSub (bool Add, bool NSW, bool /* NUW*/ ,
57
+ KnownBits KnownBits::computeForAddSub (bool Add, bool NSW, bool NUW,
58
58
const KnownBits &LHS, KnownBits RHS) {
59
59
KnownBits KnownOut;
60
60
if (Add) {
61
61
// Sum = LHS + RHS + 0
62
- KnownOut = :: computeForAddCarry (
63
- LHS, RHS, /* CarryZero*/ true , /* CarryOne*/ false );
62
+ KnownOut =
63
+ ::computeForAddCarry ( LHS, RHS, /* CarryZero*/ true , /* CarryOne*/ false );
64
64
} else {
65
65
// Sum = LHS + ~RHS + 1
66
- std::swap (RHS.Zero , RHS.One );
67
- KnownOut = ::computeForAddCarry (
68
- LHS, RHS, /* CarryZero*/ false , /* CarryOne*/ true );
66
+ KnownBits NotRHS = RHS;
67
+ std::swap (NotRHS.Zero , NotRHS.One );
68
+ KnownOut = ::computeForAddCarry (LHS, NotRHS, /* CarryZero*/ false ,
69
+ /* CarryOne*/ true );
69
70
}
71
+ if (!NSW && !NUW)
72
+ return KnownOut;
70
73
71
- // Are we still trying to solve for the sign bit?
72
- if (!KnownOut.isNegative () && !KnownOut.isNonNegative ()) {
74
+ auto GetMinMaxVal = [Add](bool ForNSW, bool ForMax, const KnownBits &L,
75
+ const KnownBits &R, bool &OV) {
76
+ APInt LVal = ForMax ? L.getMaxValue () : L.getMinValue ();
77
+ APInt RVal = Add == ForMax ? R.getMaxValue () : R.getMinValue ();
78
+
79
+ if (ForNSW) {
80
+ LVal.clearSignBit ();
81
+ RVal.clearSignBit ();
82
+ }
83
+ APInt Res = Add ? LVal.uadd_ov (RVal, OV) : LVal.usub_ov (RVal, OV);
84
+ if (ForNSW) {
85
+ OV = Res.isSignBitSet ();
86
+ Res.clearSignBit ();
87
+ if (Res.getBitWidth () > 1 && Res[Res.getBitWidth () - 2 ])
88
+ Res.setSignBit ();
89
+ }
90
+ return Res;
91
+ };
92
+
93
+ auto GetMaxVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
94
+ const KnownBits &R, bool &OV) {
95
+ return GetMinMaxVal (ForNSW, /* ForMax=*/ true , L, R, OV);
96
+ };
97
+
98
+ auto GetMinVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
99
+ const KnownBits &R, bool &OV) {
100
+ return GetMinMaxVal (ForNSW, /* ForMax=*/ false , L, R, OV);
101
+ };
102
+
103
+ std::optional<bool > Negative;
104
+ bool Poison = false ;
105
+ // Handle add/sub given nsw and/or nuw.
106
+ //
107
+ // Possible TODO: Add/Sub implementations mirror one another in many ways.
108
+ // They could probably be compressed into a single implementation of roughly
109
+ // half the total LOC. Leaving seperate for now to increase clarity.
110
+ // NB: We handle NSW by essentially treating as nuw of bitwidth - 1 then
111
+ // deducing bits based on the known sign result.
112
+ if (Add) {
73
113
if (NSW) {
74
- // Adding two non-negative numbers, or subtracting a negative number from
75
- // a non-negative one, can't wrap into negative.
76
- if (LHS.isNonNegative () && RHS.isNonNegative ())
77
- KnownOut.makeNonNegative ();
78
- // Adding two negative numbers, or subtracting a non-negative number from
79
- // a negative one, can't wrap into non-negative.
80
- else if (LHS.isNegative () && RHS.isNegative ())
81
- KnownOut.makeNegative ();
114
+ bool OverflowMax, OverflowMin;
115
+ APInt MaxVal = GetMaxVal (/* ForNSW=*/ true , LHS, RHS, OverflowMax);
116
+ APInt MinVal = GetMinVal (/* ForNSW=*/ true , LHS, RHS, OverflowMin);
117
+
118
+ if (NUW || (LHS.isNonNegative () && RHS.isNonNegative ())) {
119
+ // (add nuw) or (add nsw PosX, PosY)
120
+
121
+ // None of the adds can end up overflowing, so min consecutive highbits
122
+ // in minimum possible of X + Y must all remain set.
123
+ KnownOut.One .setHighBits (MinVal.countLeadingOnes ());
124
+
125
+ // NSW and Positive arguments leads to positive result.
126
+ if (LHS.isNonNegative () && RHS.isNonNegative ())
127
+ Negative = false ;
128
+ else
129
+ KnownOut.One .clearSignBit ();
130
+
131
+ Poison = OverflowMin;
132
+ } else if (LHS.isNegative () && RHS.isNegative ()) {
133
+ // (add nsw NegX, NegY)
134
+
135
+ // We need to re-overflow the signbit, so we are looking for sequence of
136
+ // 0s from consecutive overflows.
137
+ KnownOut.Zero .setHighBits (MaxVal.countLeadingZeros ());
138
+ Negative = true ;
139
+ Poison = !OverflowMax;
140
+ } else if (LHS.isNonNegative () || RHS.isNonNegative ()) {
141
+ // (add nsw PosX, ?Y)
142
+
143
+ // If the minimal possible of X + Y overflows the signbit, then Y must
144
+ // have been signed (which will cause unsigned overflow otherwise nsw
145
+ // will be violated) leading to unsigned result.
146
+ if (OverflowMin)
147
+ Negative = false ;
148
+ } else if (LHS.isNegative () || RHS.isNegative ()) {
149
+ // (add nsw NegX, ?Y)
150
+
151
+ // If the maximum possible of X + Y doesn't overflows the signbit, then
152
+ // Y must have been unsigned (otherwise nsw violated) so NegX + PosY w.o
153
+ // overflowing the signbit results in Negative.
154
+ if (!OverflowMax)
155
+ Negative = true ;
156
+ }
157
+ }
158
+ if (NUW) {
159
+ // (add nuw X, Y)
160
+ bool OverflowMax, OverflowMin;
161
+ APInt MinVal = GetMinVal (/* ForNSW=*/ false , LHS, RHS, OverflowMin);
162
+ // Same as (add nsw PosX, PosY), basically since we can't overflow, the
163
+ // high bits of minimum possible X + Y must remain set.
164
+ KnownOut.One .setHighBits (MinVal.countLeadingOnes ());
165
+ Poison = OverflowMin;
82
166
}
167
+ } else {
168
+ if (NSW) {
169
+ bool OverflowMax, OverflowMin;
170
+ APInt MaxVal = GetMaxVal (/* ForNSW=*/ true , LHS, RHS, OverflowMax);
171
+ APInt MinVal = GetMinVal (/* ForNSW=*/ true , LHS, RHS, OverflowMin);
172
+ if (NUW || (LHS.isNegative () && RHS.isNonNegative ())) {
173
+ // (sub nuw) or (sub nsw NegX, PosY)
174
+
175
+ // None of the subs can overflow at any point, so any common high bits
176
+ // will subtract away and result in zeros.
177
+ KnownOut.Zero .setHighBits (MaxVal.countLeadingZeros ());
178
+ if (LHS.isNegative () && RHS.isNonNegative ())
179
+ Negative = true ;
180
+ else
181
+ KnownOut.Zero .clearSignBit ();
182
+
183
+ Poison = OverflowMax;
184
+ } else if (LHS.isNonNegative () && RHS.isNegative ()) {
185
+ // (sub nsw PosX, NegY)
186
+ Negative = false ;
187
+
188
+ // Opposite case of above, we must "re-overflow" the signbit, so minimal
189
+ // set of high bits will be fixed.
190
+ KnownOut.One .setHighBits (MinVal.countLeadingOnes ());
191
+ Poison = !OverflowMin;
192
+ } else if (LHS.isNegative () || RHS.isNonNegative ()) {
193
+ // (sub nsw NegX/?X, ?Y/PosY)
194
+ if (OverflowMax)
195
+ Negative = true ;
196
+ } else if (LHS.isNonNegative () || RHS.isNegative ()) {
197
+ // (sub nsw PosX/?X, ?Y/NegY)
198
+ if (!OverflowMin)
199
+ Negative = false ;
200
+ }
201
+ }
202
+ if (NUW) {
203
+ // (sub nuw X, Y)
204
+ bool OverflowMax, OverflowMin;
205
+ APInt MaxVal = GetMaxVal (/* ForNSW=*/ false , LHS, RHS, OverflowMax);
206
+
207
+ // Basically all common high bits between X/Y will cancel out as leading
208
+ // zeros.
209
+ KnownOut.Zero .setHighBits (MaxVal.countLeadingZeros ());
210
+ Poison = OverflowMax;
211
+ }
212
+ }
213
+
214
+ // Handle any proven sign bit.
215
+ if (Negative.has_value ()) {
216
+ KnownOut.One .clearSignBit ();
217
+ KnownOut.Zero .clearSignBit ();
218
+
219
+ if (*Negative)
220
+ KnownOut.makeNegative ();
221
+ else
222
+ KnownOut.makeNonNegative ();
223
+ }
224
+
225
+ // Just return 0 if the nsw/nuw is violated and we have poison.
226
+ if (Poison || KnownOut.hasConflict ()) {
227
+ KnownOut.setAllZero ();
228
+ return KnownOut;
83
229
}
84
230
85
231
return KnownOut;
0 commit comments