Skip to content

Commit 6368857

Browse files
authored
Merge pull request #6382 from natecook1000/nc-sort-median
[stdlib] Modify sort to pivot on median of 3
2 parents 2fe8ef8 + b582683 commit 6368857

File tree

2 files changed

+130
-24
lines changed

2 files changed

+130
-24
lines changed

stdlib/public/core/Sort.swift.gyb

Lines changed: 99 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -72,58 +72,131 @@ func _insertionSort<C>(
7272
}
7373
}
7474

75+
/// Sorts the elements at `elements[a]`, `elements[b]`, and `elements[c]`.
76+
/// Stable.
77+
///
78+
/// The indices passed as `a`, `b`, and `c` do not need to be consecutive, but
79+
/// must be in strict increasing order.
80+
///
81+
/// - Precondition: `a < b && b < c`
82+
/// - Postcondition: `elements[a] <= elements[b] && elements[b] <= elements[c]`
83+
public // @testable
84+
func _sort3<C>(
85+
_ elements: inout C,
86+
_ a: C.Index, _ b: C.Index, _ c: C.Index
87+
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) -> Bool" if p else ""}
88+
)
89+
where
90+
C : MutableCollection & RandomAccessCollection
91+
${"" if p else ", C.Iterator.Element : Comparable"}
92+
{
93+
// There are thirteen possible permutations for the original ordering of
94+
// the elements at indices `a`, `b`, and `c`. The comments in the code below
95+
// show the relative ordering of the three elements using a three-digit
96+
// number as shorthand for the position and comparative relationship of
97+
// each element. For example, "312" indicates that the element at `a` is the
98+
// largest of the three, the element at `b` is the smallest, and the element
99+
// at `c` is the median. This hypothetical input array has a 312 ordering for
100+
// `a`, `b`, and `c`:
101+
//
102+
// [ 7, 4, 3, 9, 2, 0, 3, 7, 6, 5 ]
103+
// ^ ^ ^
104+
// a b c
105+
//
106+
// - If each of the three elements is distinct, they could be ordered as any
107+
// of the permutations of 1, 2, and 3: 123, 132, 213, 231, 312, or 321.
108+
// - If two elements are equivalent and one is distinct, they could be
109+
// ordered as any permutation of 1, 1, and 2 or 1, 2, and 2: 112, 121, 211,
110+
// 122, 212, or 221.
111+
// - If all three elements are equivalent, they are already in order: 111.
112+
113+
switch (${cmp("elements[b]", "elements[a]", p)},
114+
${cmp("elements[c]", "elements[b]", p)}) {
115+
case (false, false):
116+
// 0 swaps: 123, 112, 122, 111
117+
break
118+
119+
case (true, true):
120+
// 1 swap: 321
121+
// swap(a, c): 312->123
122+
swap(&elements[a], &elements[c])
123+
124+
case (true, false):
125+
// 1 swap: 213, 212 --- 2 swaps: 312, 211
126+
// swap(a, b): 213->123, 212->122, 312->132, 211->121
127+
swap(&elements[a], &elements[b])
128+
129+
if ${cmp("elements[c]", "elements[b]", p)} {
130+
// 132 (started as 312), 121 (started as 211)
131+
// swap(b, c): 132->123, 121->112
132+
swap(&elements[b], &elements[c])
133+
}
134+
135+
case (false, true):
136+
// 1 swap: 132, 121 --- 2 swaps: 231, 221
137+
// swap(b, c): 132->123, 121->112, 231->213, 221->212
138+
swap(&elements[b], &elements[c])
139+
140+
if ${cmp("elements[b]", "elements[a]", p)} {
141+
// 213 (started as 231), 212 (started as 221)
142+
// swap(a, b): 213->123, 212->122
143+
swap(&elements[a], &elements[b])
144+
}
145+
}
146+
}
147+
148+
/// Reorders `elements` and returns an index `p` such that every element in
149+
/// `elements[range.lowerBound..<p]` is less than every element in
150+
/// `elements[p..<range.upperBound]`.
151+
///
152+
/// - Precondition: The count of `range` must be >= 3:
153+
/// `elements.distance(from: range.lowerBound, to: range.upperBound) >= 3`
75154
func _partition<C>(
76155
_ elements: inout C,
77156
subRange range: Range<C.Index>
78157
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) -> Bool" if p else ""}
79158
) -> C.Index
80159
where
81160
C : MutableCollection & RandomAccessCollection
82-
${"" if p else ", C.Iterator.Element : Comparable"} {
83-
161+
${"" if p else ", C.Iterator.Element : Comparable"}
162+
{
84163
var lo = range.lowerBound
85-
var hi = range.upperBound
86-
87-
if lo == hi {
88-
return lo
89-
}
164+
var hi = elements.index(before: range.upperBound)
90165

91-
// The first element is the pivot.
92-
let pivot = elements[range.lowerBound]
166+
// Sort the first, middle, and last elements, then use the middle value
167+
// as the pivot for the partition.
168+
let half = numericCast(elements.distance(from: lo, to: hi)) as UInt / 2
169+
let mid = elements.index(lo, offsetBy: numericCast(half))
170+
_sort3(&elements, lo, mid, hi
171+
${", by: areInIncreasingOrder" if p else ""})
172+
let pivot = elements[mid]
93173

94174
// Loop invariants:
95175
// * lo < hi
96-
// * elements[i] < pivot, for i in range.lowerBound+1..lo
97-
// * pivot <= elements[i] for i in hi..range.upperBound
98-
99-
Loop: while true {
100-
FindLo: repeat {
176+
// * elements[i] < pivot, for i in range.lowerBound..<lo
177+
// * pivot <= elements[i] for i in hi..<range.upperBound
178+
Loop: while true {
179+
FindLo: do {
101180
elements.formIndex(after: &lo)
102181
while lo != hi {
103182
if !${cmp("elements[lo]", "pivot", p)} { break FindLo }
104183
elements.formIndex(after: &lo)
105184
}
106185
break Loop
107-
} while false
186+
}
108187

109-
FindHi: repeat {
188+
FindHi: do {
110189
elements.formIndex(before: &hi)
111190
while hi != lo {
112191
if ${cmp("elements[hi]", "pivot", p)} { break FindHi }
113192
elements.formIndex(before: &hi)
114193
}
115194
break Loop
116-
} while false
195+
}
117196

118197
swap(&elements[lo], &elements[hi])
119198
}
120199

121-
elements.formIndex(before: &lo)
122-
if lo != range.lowerBound {
123-
// swap the pivot into place
124-
swap(&elements[lo], &elements[range.lowerBound])
125-
}
126-
127200
return lo
128201
}
129202

@@ -190,7 +263,7 @@ func _introSortImpl<C>(
190263
depthLimit: depthLimit &- 1)
191264
_introSortImpl(
192265
&elements,
193-
subRange: (elements.index(after: partIdx))..<range.upperBound,
266+
subRange: partIdx..<range.upperBound,
194267
${"by: areInIncreasingOrder, " if p else ""}
195268
depthLimit: depthLimit &- 1)
196269
}
@@ -234,6 +307,7 @@ func _siftDown<C>(
234307
${", by: areInIncreasingOrder" if p else ""})
235308
}
236309
}
310+
237311
func _heapify<C>(
238312
_ elements: inout C,
239313
subRange range: Range<C.Index>
@@ -262,6 +336,7 @@ func _heapify<C>(
262336
${", by: areInIncreasingOrder" if p else ""})
263337
}
264338
}
339+
265340
func _heapSort<C>(
266341
_ elements: inout C,
267342
subRange range: Range<C.Index>

validation-test/stdlib/Algorithm.swift

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,5 +224,36 @@ Algorithm.test("sorted/return type") {
224224
let x: Array = ([5, 4, 3, 2, 1] as ArraySlice).sorted()
225225
}
226226

227+
Algorithm.test("sort3/simple")
228+
.forEach(in: [
229+
[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 1, 2], [3, 2, 1]
230+
]) {
231+
var input = $0
232+
_sort3(&input, 0, 1, 2)
233+
expectEqual([1, 2, 3], input)
234+
}
235+
236+
func isSorted<T>(_ a: [T], by areInIncreasingOrder: (T, T) -> Bool) -> Bool {
237+
return !a.dropFirst().enumerated().contains(where: { (offset, element) in
238+
areInIncreasingOrder(element, a[offset])
239+
})
240+
}
241+
242+
Algorithm.test("sort3/stable")
243+
.forEach(in: [
244+
[1, 1, 2], [1, 2, 1], [2, 1, 1], [1, 2, 2], [2, 1, 2], [2, 2, 1], [1, 1, 1]
245+
]) {
246+
// decorate with offset, but sort by value
247+
var input = Array($0.enumerated())
248+
_sort3(&input, 0, 1, 2) { $0.element < $1.element }
249+
// offsets should still be ordered for equal values
250+
expectTrue(isSorted(input) {
251+
if $0.element == $1.element {
252+
return $0.offset < $1.offset
253+
}
254+
return $0.element < $1.element
255+
})
256+
}
257+
227258
runAllTests()
228259

0 commit comments

Comments
 (0)