Skip to content

Commit 66f8a9b

Browse files
authored
Merge pull request #13007 from xwu/fused-multiply-add-stride
[stdlib] Eliminate intermediate rounding error in floating-point strides (and related gardening)
2 parents e86202e + 2a5b0d4 commit 66f8a9b

File tree

3 files changed

+80
-178
lines changed

3 files changed

+80
-178
lines changed

stdlib/public/core/Range.swift.gyb

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,6 @@ extension RangeExpression {
8080
}
8181
}
8282

83-
// FIXME(ABI)#55 (Statically Unavailable/Dynamically Available): remove this
84-
// type, it creates an ABI burden on the library.
85-
//
86-
// A dummy type that we can use when we /don't/ want to create an
87-
// ambiguity indexing CountableRange<T> outside a generic context.
88-
public enum _DisabledRangeIndex_ {}
89-
9083
/// A half-open range that forms a collection of consecutive values.
9184
///
9285
/// You create a `CountableRange` instance by using the half-open range

stdlib/public/core/Stride.swift.gyb

Lines changed: 68 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public protocol Strideable : Comparable {
4343
from start: Self, by distance: Self.Stride
4444
) -> (index: Int?, value: Self)
4545

46-
associatedtype _DisabledRangeIndex = _DisabledRangeIndex_
46+
associatedtype _DisabledRangeIndex = Never
4747
}
4848

4949
extension Strideable {
@@ -137,13 +137,27 @@ extension Strideable where Stride : FloatingPoint {
137137
from start: Self, by distance: Self.Stride
138138
) -> (index: Int?, value: Self) {
139139
if let i = current.index {
140+
// When Stride is a floating-point type, we should avoid accumulating
141+
// rounding error from repeated addition.
140142
return (i + 1, start.advanced(by: Stride(i + 1) * distance))
141143
}
142-
// If current.index == nil, either we're just starting out (in which case
143-
// the next index is 1), or we should proceed without an index just as
144-
// though this floating point specialization doesn't exist.
145-
return (current.value == start ? 1 : nil,
146-
current.value.advanced(by: distance))
144+
return (nil, current.value.advanced(by: distance))
145+
}
146+
}
147+
148+
extension Strideable where Self : FloatingPoint, Self == Stride {
149+
@_inlineable
150+
public static func _step(
151+
after current: (index: Int?, value: Self),
152+
from start: Self, by distance: Self.Stride
153+
) -> (index: Int?, value: Self) {
154+
if let i = current.index {
155+
// When both Self and Stride are the same floating-point type, we should
156+
// take advantage of fused multiply-add (where supported) to eliminate
157+
// intermediate rounding error.
158+
return (i + 1, start.addingProduct(Stride(i + 1), distance))
159+
}
160+
return (nil, current.value.advanced(by: distance))
147161
}
148162
}
149163

@@ -168,7 +182,7 @@ public struct StrideToIterator<Element : Strideable> : IteratorProtocol {
168182
self._start = _start
169183
_end = end
170184
_stride = stride
171-
_current = (nil, _start)
185+
_current = (0, _start)
172186
}
173187

174188
/// Advances to the next element and returns it, or `nil` if no next element
@@ -255,111 +269,52 @@ public struct StrideTo<Element : Strideable> : Sequence, CustomReflectable {
255269
}
256270
}
257271

258-
// FIXME(conditional-conformances): these extra types can easily be turned into
259-
// conditional extensions to StrideTo type
260-
% for Self, ElementConstraint, Where in [
261-
% ('IntegerStrideToCollection', 'BinaryInteger', 'Element.Stride : BinaryInteger'),
262-
% ('FloatingPointStrideToCollection', 'BinaryFloatingPoint', 'Element.Stride == Element'),
263-
% ]:
264-
% ElementIsInteger = ElementConstraint == 'BinaryInteger'
265-
266-
internal struct ${Self}<
267-
Element : ${ElementConstraint}
268-
> : RandomAccessCollection, CustomReflectable
269-
where ${Where} {
270-
271-
//===----------------------------------------------------------------------===//
272-
// This block is copied from StrideTo struct definition //
273-
//===----------------------------------------------------------------------===//
274-
@_inlineable
275-
public func makeIterator() -> StrideToIterator<Element> {
276-
return StrideToIterator(_start: _start, end: _end, stride: _stride)
277-
}
278-
279-
@_inlineable
280-
public func _customContainsEquatableElement(
281-
_ element: Element
282-
) -> Bool? {
283-
if element < _start || _end <= element {
284-
return false
285-
}
286-
return nil
287-
}
288-
289-
@_inlineable
290-
@_versioned
291-
internal init(_start: Element, end: Element, stride: Element.Stride) {
292-
_precondition(stride != 0, "Stride size must not be zero")
293-
// At start, striding away from end is allowed; it just makes for an
294-
// already-empty Sequence.
295-
self._start = _start
296-
self._end = end
297-
self._stride = stride
298-
}
299-
300-
@_versioned
301-
internal let _start: Element
302-
303-
@_versioned
304-
internal let _end: Element
305-
306-
@_versioned
307-
internal let _stride: Element.Stride
308-
309-
@_inlineable // FIXME(sil-serialize-all)
310-
public var customMirror: Mirror {
311-
return Mirror(self, children: ["from": _start, "to": _end, "by": _stride])
312-
}
313-
//===----------------------------------------------------------------------===//
314-
// The end of the copied block
315-
//===----------------------------------------------------------------------===//
316-
317-
// RandomAccessCollection conformance
272+
// FIXME(conditional-conformances): This does not yet compile (SR-6474).
273+
#if false
274+
extension StrideTo : RandomAccessCollection
275+
where Element.Stride : BinaryInteger {
318276
public typealias Index = Int
319-
public typealias SubSequence = RandomAccessSlice<${Self}>
277+
public typealias SubSequence = RandomAccessSlice<StrideTo<Element>>
320278
public typealias Indices = CountableRange<Int>
321279

280+
@_inlineable
322281
public var startIndex: Index { return 0 }
282+
283+
@_inlineable
323284
public var endIndex: Index { return count }
324285

286+
@_inlineable
325287
public var count: Int {
326-
let (start, end, stride) =
327-
(_stride > 0) ? (_start, _end, _stride) : (_end, _start, -_stride)
328-
% if ElementIsInteger:
329-
return Int((start.distance(to: end) - 1) / stride) + 1
330-
% else:
331-
let nonExactCount = (start.distance(to: end)) / stride
332-
return Int(nonExactCount.rounded(.toNearestOrAwayFromZero))
333-
% end
288+
let distance = _start.distance(to: _end)
289+
guard distance != 0 && (distance < 0) == (_stride < 0) else { return 0 }
290+
return Int((distance - 1) / _stride) + 1
334291
}
335292

336293
public subscript(position: Index) -> Element {
337-
_failEarlyRangeCheck(position, bounds: startIndex ..< endIndex)
338-
return _indexToElement(position)
294+
_failEarlyRangeCheck(position, bounds: startIndex..<endIndex)
295+
return _start.advanced(by: Element.Stride(position) * _stride)
339296
}
340297

341-
public subscript(bounds: Range<Index>) -> RandomAccessSlice<${Self}> {
342-
_failEarlyRangeCheck(bounds, bounds: startIndex ..< endIndex)
298+
public subscript(
299+
bounds: Range<Index>
300+
) -> RandomAccessSlice<StrideTo<Element>> {
301+
_failEarlyRangeCheck(bounds, bounds: startIndex..<endIndex)
343302
return RandomAccessSlice(base: self, bounds: bounds)
344303
}
345304

346-
public func index(after i: Index) -> Index {
347-
_failEarlyRangeCheck(i, bounds: startIndex-1 ..< endIndex)
348-
return i+1
349-
}
350-
305+
@_inlineable
351306
public func index(before i: Index) -> Index {
352-
_failEarlyRangeCheck(i, bounds: startIndex+1 ... endIndex)
353-
return i-1
307+
_failEarlyRangeCheck(i, bounds: startIndex + 1...endIndex)
308+
return i - 1
354309
}
355310

356-
@inline(__always)
357-
internal func _indexToElement(_ i: Index) -> Element {
358-
return _start.advanced(by: Element.Stride(i) * _stride)
311+
@_inlineable
312+
public func index(after i: Index) -> Index {
313+
_failEarlyRangeCheck(i, bounds: startIndex - 1..<endIndex)
314+
return i + 1
359315
}
360316
}
361-
362-
% end
317+
#endif
363318

364319
/// Returns the sequence of values (`self`, `self + stride`, `self +
365320
/// 2 * stride`, ... *last*) where *last* is the last value in the
@@ -395,7 +350,7 @@ public struct StrideThroughIterator<Element : Strideable> : IteratorProtocol {
395350
self._start = _start
396351
_end = end
397352
_stride = stride
398-
_current = (nil, _start)
353+
_current = (0, _start)
399354
}
400355

401356
/// Advances to the next element and returns it, or `nil` if no next element
@@ -487,93 +442,41 @@ public struct StrideThrough<
487442
}
488443
}
489444

490-
// FIXME(conditional-conformances): these extra types can easily be turned into
491-
// conditional extensions to StrideThrough type
492-
% for Self, ElementConstraint, Where in [
493-
% ('IntegerStrideThroughCollection', 'BinaryInteger', 'Element.Stride : BinaryInteger'),
494-
% ('FloatingPointStrideThroughCollection', 'BinaryFloatingPoint', 'Element.Stride == Element'),
495-
% ]:
496-
% ElementIsInteger = ElementConstraint == 'BinaryInteger'
497-
498-
internal struct ${Self}<
499-
Element : ${ElementConstraint}
500-
> : RandomAccessCollection, CustomReflectable
501-
where ${Where} {
502-
503-
//===----------------------------------------------------------------------===//
504-
// This block is copied from StrideThrough struct definition //
505-
//===----------------------------------------------------------------------===//
506-
/// Returns an iterator over the elements of this sequence.
507-
///
508-
/// - Complexity: O(1).
509-
@_inlineable
510-
public func makeIterator() -> StrideThroughIterator<Element> {
511-
return StrideThroughIterator(_start: _start, end: _end, stride: _stride)
512-
}
513-
514-
@_inlineable
515-
public func _customContainsEquatableElement(
516-
_ element: Element
517-
) -> Bool? {
518-
if element < _start || _end < element {
519-
return false
520-
}
521-
return nil
522-
}
523-
524-
@_inlineable
525-
@_versioned
526-
internal init(_start: Element, end: Element, stride: Element.Stride) {
527-
_precondition(stride != 0, "Stride size must not be zero")
528-
self._start = _start
529-
self._end = end
530-
self._stride = stride
531-
}
532-
533-
@_versioned
534-
internal let _start: Element
535-
@_versioned
536-
internal let _end: Element
537-
@_versioned
538-
internal let _stride: Element.Stride
539-
540-
@_inlineable // FIXME(sil-serialize-all)
541-
public var customMirror: Mirror {
542-
return Mirror(self,
543-
children: ["from": _start, "through": _end, "by": _stride])
544-
}
545-
//===----------------------------------------------------------------------===//
546-
// The end of the copied block
547-
//===----------------------------------------------------------------------===//
548-
549-
// RandomAccessCollection conformance
445+
// FIXME(conditional-conformances): This does not yet compile (SR-6474).
446+
#if false
447+
extension StrideThrough : RandomAccessCollection
448+
where Element.Stride : BinaryInteger {
550449
public typealias Index = ClosedRangeIndex<Int>
551450
public typealias IndexDistance = Int
552-
public typealias SubSequence = RandomAccessSlice<${Self}>
451+
public typealias SubSequence = RandomAccessSlice<StrideThrough<Element>>
553452

554453
@_inlineable
555-
public var startIndex: Index { return ClosedRangeIndex(0) }
454+
public var startIndex: Index {
455+
let distance = _start.distance(to: _end)
456+
return distance == 0 || (distance < 0) == (_stride < 0)
457+
? ClosedRangeIndex(0)
458+
: ClosedRangeIndex()
459+
}
460+
556461
@_inlineable
557462
public var endIndex: Index { return ClosedRangeIndex() }
558463

559464
@_inlineable
560465
public var count: Int {
561-
let (start, end, stride) =
562-
(_stride > 0) ? (_start, _end, _stride) : (_end, _start, -_stride)
563-
% if ElementIsInteger:
564-
return Int(start.distance(to: end) / stride) + 1
565-
% else:
566-
let nonExactCount = start.distance(to: end) / stride
567-
return Int(nonExactCount.rounded(.toNearestOrAwayFromZero)) + 1
568-
% end
466+
let distance = _start.distance(to: _end)
467+
guard distance != 0 else { return 1 }
468+
guard (distance < 0) == (_stride < 0) else { return 0 }
469+
return Int(distance / _stride) + 1
569470
}
570471

571472
public subscript(position: Index) -> Element {
572473
let offset = Element.Stride(position._dereferenced) * _stride
573474
return _start.advanced(by: offset)
574475
}
575476

576-
public subscript(bounds: Range<Index>) -> RandomAccessSlice<${Self}> {
477+
public subscript(
478+
bounds: Range<Index>
479+
) -> RandomAccessSlice<StrideThrough<Element>> {
577480
return RandomAccessSlice(base: self, bounds: bounds)
578481
}
579482

@@ -601,8 +504,7 @@ where ${Where} {
601504
}
602505
}
603506
}
604-
605-
% end
507+
#endif
606508

607509
/// Returns the sequence of values (`self`, `self + stride`, `self +
608510
/// 2 * stride`, ... *last*) where *last* is the last value in the
@@ -615,4 +517,3 @@ public func stride<T>(
615517
) -> StrideThrough<T> {
616518
return StrideThrough(_start: start, end: end, stride: stride)
617519
}
618-

test/stdlib/Strideable.swift

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,12 +188,20 @@ StrideTestSuite.test("FloatingPointStride") {
188188
expectEqual([ 1.4, 2.4, 3.4 ], result)
189189
}
190190

191-
StrideTestSuite.test("ErrorAccumulation") {
192-
let a = Array(stride(from: Float(1.0), through: Float(2.0), by: Float(0.1)))
191+
StrideTestSuite.test("FloatingPointStride/rounding error") {
192+
// Ensure that there is no error accumulation
193+
let a = Array(stride(from: 1 as Float, through: 2, by: 0.1))
193194
expectEqual(11, a.count)
194-
expectEqual(Float(2.0), a.last)
195-
let b = Array(stride(from: Float(1.0), to: Float(10.0), by: Float(0.9)))
195+
expectEqual(2 as Float, a.last)
196+
let b = Array(stride(from: 1 as Float, to: 10, by: 0.9))
196197
expectEqual(10, b.count)
198+
199+
// Ensure that there is no intermediate rounding error on supported platforms
200+
if (-0.2).addingProduct(0.2, 6) == 1 {
201+
let c = Array(stride(from: -0.2, through: 1, by: 0.2))
202+
expectEqual(7, c.count)
203+
expectEqual(1 as Double, c.last)
204+
}
197205
}
198206

199207
func strideIteratorTest<

0 commit comments

Comments
 (0)