@@ -318,8 +318,8 @@ public struct DoubleWidth<Base : FixedWidthInteger> :
318
318
let initialOffset = q. leadingZeroBitCount +
319
319
( DoubleWidth . bitWidth - rhs. leadingZeroBitCount) - 1
320
320
321
- // TODO(performance): Use &>> instead here?
322
321
// Start with remainder capturing the high bits of q.
322
+ // (These need to be smart shifts, as initialOffset can be > q.bitWidth)
323
323
var r = q >> Magnitude ( DoubleWidth . bitWidth - initialOffset)
324
324
q <<= Magnitude ( initialOffset)
325
325
@@ -401,8 +401,9 @@ public struct DoubleWidth<Base : FixedWidthInteger> :
401
401
let mid2 = sum ( b. carry, c. carry, d. partial)
402
402
403
403
let low = DoubleWidth < Low > ( ( mid1. partial, a. partial) )
404
- let high = DoubleWidth (
405
- ( High ( mid2. carry + d. carry) , mid1. carry + mid2. partial) )
404
+ let high = DoubleWidth ( (
405
+ High ( mid2. carry + d. carry) , mid1. carry + mid2. partial
406
+ ) )
406
407
407
408
if isNegative {
408
409
let ( lowComplement, overflow) = ( ~ low) . addingReportingOverflow ( 1 )
@@ -438,6 +439,7 @@ public struct DoubleWidth<Base : FixedWidthInteger> :
438
439
return
439
440
}
440
441
442
+ // Shift is larger than this type's bit width.
441
443
if rhs. _storage. high != ( 0 as High ) ||
442
444
rhs. _storage. low >= DoubleWidth . bitWidth
443
445
{
@@ -481,34 +483,30 @@ public struct DoubleWidth<Base : FixedWidthInteger> :
481
483
}
482
484
483
485
public static func &<<= ( lhs: inout DoubleWidth , rhs: DoubleWidth ) {
486
+ // Need to use smart shifts here, since rhs can be > Base.bitWidth
484
487
let rhs = rhs & DoubleWidth ( DoubleWidth . bitWidth - 1 )
485
488
486
489
lhs. _storage. high <<= High ( rhs. _storage. low)
487
- if Base . bitWidth > rhs. _storage. low {
488
- let h = lhs. _storage. low >>
489
- ( numericCast ( Base . bitWidth) - rhs. _storage. low)
490
- lhs. _storage. high |= High ( extendingOrTruncating: h)
491
- } else {
492
- let h = lhs. _storage. low <<
493
- ( rhs. _storage. low - numericCast( Base . bitWidth) )
494
- lhs. _storage. high |= High ( extendingOrTruncating: h)
495
- }
490
+
491
+ let t = Base . bitWidth > rhs. _storage. low
492
+ ? lhs. _storage. low >> ( numericCast ( Base . bitWidth) - rhs. _storage. low)
493
+ : lhs. _storage. low << ( rhs. _storage. low - numericCast( Base . bitWidth) )
494
+ lhs. _storage. high |= High ( extendingOrTruncating: t)
495
+
496
496
lhs. _storage. low <<= rhs. _storage. low
497
497
}
498
498
499
499
public static func &>>= ( lhs: inout DoubleWidth , rhs: DoubleWidth ) {
500
+ // Need to use smart shifts here, since rhs can be > Base.bitWidth
500
501
let rhs = rhs & DoubleWidth ( DoubleWidth . bitWidth - 1 )
501
502
502
503
lhs. _storage. low >>= rhs. _storage. low
503
- if Base . bitWidth > rhs. _storage. low {
504
- let l = lhs. _storage. high <<
505
- numericCast ( numericCast ( Base . bitWidth) - rhs. _storage. low)
506
- lhs. _storage. low |= Low ( extendingOrTruncating: l)
507
- } else {
508
- let l = lhs. _storage. high >>
509
- numericCast ( rhs. _storage. low - numericCast( Base . bitWidth) )
510
- lhs. _storage. low |= Low ( extendingOrTruncating: l)
511
- }
504
+
505
+ let t = Base . bitWidth > rhs. _storage. low
506
+ ? lhs. _storage. high << ( numericCast ( Base . bitWidth) - rhs. _storage. low)
507
+ : lhs. _storage. high >> ( rhs. _storage. low - numericCast( Base . bitWidth) )
508
+ lhs. _storage. low |= Low ( extendingOrTruncating: t)
509
+
512
510
lhs. _storage. high >>= High ( extendingOrTruncating: rhs. _storage. low)
513
511
}
514
512
0 commit comments