Skip to content

Commit 7e5c78c

Browse files
committed
[stdlib] DoubleWidth formatting and bit shift audit
1 parent e3e36f7 commit 7e5c78c

File tree

1 file changed

+19
-21
lines changed

1 file changed

+19
-21
lines changed

stdlib/public/core/DoubleWidth.swift.gyb

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,8 @@ public struct DoubleWidth<Base : FixedWidthInteger> :
318318
let initialOffset = q.leadingZeroBitCount +
319319
(DoubleWidth.bitWidth - rhs.leadingZeroBitCount) - 1
320320

321-
// TODO(performance): Use &>> instead here?
322321
// Start with remainder capturing the high bits of q.
322+
// (These need to be smart shifts, as initialOffset can be > q.bitWidth)
323323
var r = q >> Magnitude(DoubleWidth.bitWidth - initialOffset)
324324
q <<= Magnitude(initialOffset)
325325

@@ -401,8 +401,9 @@ public struct DoubleWidth<Base : FixedWidthInteger> :
401401
let mid2 = sum(b.carry, c.carry, d.partial)
402402

403403
let low = DoubleWidth<Low>((mid1.partial, a.partial))
404-
let high = DoubleWidth(
405-
(High(mid2.carry + d.carry), mid1.carry + mid2.partial))
404+
let high = DoubleWidth((
405+
High(mid2.carry + d.carry), mid1.carry + mid2.partial
406+
))
406407

407408
if isNegative {
408409
let (lowComplement, overflow) = (~low).addingReportingOverflow(1)
@@ -438,6 +439,7 @@ public struct DoubleWidth<Base : FixedWidthInteger> :
438439
return
439440
}
440441

442+
// Shift is larger than this type's bit width.
441443
if rhs._storage.high != (0 as High) ||
442444
rhs._storage.low >= DoubleWidth.bitWidth
443445
{
@@ -481,34 +483,30 @@ public struct DoubleWidth<Base : FixedWidthInteger> :
481483
}
482484

483485
public static func &<<=(lhs: inout DoubleWidth, rhs: DoubleWidth) {
486+
// Need to use smart shifts here, since rhs can be > Base.bitWidth
484487
let rhs = rhs & DoubleWidth(DoubleWidth.bitWidth - 1)
485488

486489
lhs._storage.high <<= High(rhs._storage.low)
487-
if Base.bitWidth > rhs._storage.low {
488-
let h = lhs._storage.low >>
489-
(numericCast(Base.bitWidth) - rhs._storage.low)
490-
lhs._storage.high |= High(extendingOrTruncating: h)
491-
} else {
492-
let h = lhs._storage.low <<
493-
(rhs._storage.low - numericCast(Base.bitWidth))
494-
lhs._storage.high |= High(extendingOrTruncating: h)
495-
}
490+
491+
let t = Base.bitWidth > rhs._storage.low
492+
? lhs._storage.low >> (numericCast(Base.bitWidth) - rhs._storage.low)
493+
: lhs._storage.low << (rhs._storage.low - numericCast(Base.bitWidth))
494+
lhs._storage.high |= High(extendingOrTruncating: t)
495+
496496
lhs._storage.low <<= rhs._storage.low
497497
}
498498

499499
public static func &>>=(lhs: inout DoubleWidth, rhs: DoubleWidth) {
500+
// Need to use smart shifts here, since rhs can be > Base.bitWidth
500501
let rhs = rhs & DoubleWidth(DoubleWidth.bitWidth - 1)
501502

502503
lhs._storage.low >>= rhs._storage.low
503-
if Base.bitWidth > rhs._storage.low {
504-
let l = lhs._storage.high <<
505-
numericCast(numericCast(Base.bitWidth) - rhs._storage.low)
506-
lhs._storage.low |= Low(extendingOrTruncating: l)
507-
} else {
508-
let l = lhs._storage.high >>
509-
numericCast(rhs._storage.low - numericCast(Base.bitWidth))
510-
lhs._storage.low |= Low(extendingOrTruncating: l)
511-
}
504+
505+
let t = Base.bitWidth > rhs._storage.low
506+
? lhs._storage.high << (numericCast(Base.bitWidth) - rhs._storage.low)
507+
: lhs._storage.high >> (rhs._storage.low - numericCast(Base.bitWidth))
508+
lhs._storage.low |= Low(extendingOrTruncating: t)
509+
512510
lhs._storage.high >>= High(extendingOrTruncating: rhs._storage.low)
513511
}
514512

0 commit comments

Comments
 (0)