Skip to content

Refactor Heap Sort Implementation #705

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 85 additions & 119 deletions src/sorting/heap_sort.rs
Original file line number Diff line number Diff line change
@@ -1,147 +1,113 @@
/// Sort a mutable slice using heap sort.
///
/// Heap sort is an in-place O(n log n) sorting algorithm. It is based on a
/// max heap, a binary tree data structure whose main feature is that
/// parent nodes are always greater or equal to their child nodes.
///
/// # Max Heap Implementation
//! This module provides functions for heap sort algorithm.

use std::cmp::Ordering;

/// Builds a heap from the provided array.
///
/// A max heap can be efficiently implemented with an array.
/// For example, the binary tree:
/// ```text
/// 1
/// 2 3
/// 4 5 6 7
/// ```
/// This function builds either a max heap or a min heap based on the `is_max_heap` parameter.
///
/// ... is represented by the following array:
/// ```text
/// 1 23 4567
/// ```
/// # Arguments
///
/// Given the index `i` of a node, parent and child indices can be calculated
/// as follows:
/// ```text
/// parent(i) = (i-1) / 2
/// left_child(i) = 2*i + 1
/// right_child(i) = 2*i + 2
/// ```
/// * `arr` - A mutable reference to the array to be sorted.
/// * `is_max_heap` - A boolean indicating whether to build a max heap (`true`) or a min heap (`false`).
fn build_heap<T: Ord>(arr: &mut [T], is_max_heap: bool) {
let mut i = (arr.len() - 1) / 2;
while i > 0 {
heapify(arr, i, is_max_heap);
i -= 1;
}
heapify(arr, 0, is_max_heap);
}

/// # Algorithm
/// Fixes a heap violation starting at the given index.
///
/// Heap sort has two steps:
/// 1. Convert the input array to a max heap.
/// 2. Partition the array into heap part and sorted part. Initially the
/// heap consists of the whole array and the sorted part is empty:
/// ```text
/// arr: [ heap |]
/// ```
/// This function adjusts the heap rooted at index `i` to fix the heap property violation.
/// It assumes that the subtrees rooted at left and right children of `i` are already heaps.
///
/// Repeatedly swap the root (i.e. the largest) element of the heap with
/// the last element of the heap and increase the sorted part by one:
/// ```text
/// arr: [ root ... last | sorted ]
/// --> [ last ... | root sorted ]
/// ```
/// # Arguments
///
/// After each swap, fix the heap to make it a valid max heap again.
/// Once the heap is empty, `arr` is completely sorted.
pub fn heap_sort<T: Ord>(arr: &mut [T]) {
if arr.len() <= 1 {
return; // already sorted
/// * `arr` - A mutable reference to the array representing the heap.
/// * `i` - The index to start fixing the heap violation.
/// * `is_max_heap` - A boolean indicating whether to maintain a max heap or a min heap.
fn heapify<T: Ord>(arr: &mut [T], i: usize, is_max_heap: bool) {
let mut comparator: fn(&T, &T) -> Ordering = |a, b| a.cmp(b);
if !is_max_heap {
comparator = |a, b| b.cmp(a);
}

heapify(arr);
let mut idx = i;
let l = 2 * i + 1;
let r = 2 * i + 2;

for end in (1..arr.len()).rev() {
arr.swap(0, end);
move_down(&mut arr[..end], 0);
if l < arr.len() && comparator(&arr[l], &arr[idx]) == Ordering::Greater {
idx = l;
}

if r < arr.len() && comparator(&arr[r], &arr[idx]) == Ordering::Greater {
idx = r;
}
}

/// Convert `arr` into a max heap.
fn heapify<T: Ord>(arr: &mut [T]) {
let last_parent = (arr.len() - 2) / 2;
for i in (0..=last_parent).rev() {
move_down(arr, i);
if idx != i {
arr.swap(i, idx);
heapify(arr, idx, is_max_heap);
}
}

/// Move the element at `root` down until `arr` is a max heap again.
/// Sorts the given array using heap sort algorithm.
///
/// This assumes that the subtrees under `root` are valid max heaps already.
fn move_down<T: Ord>(arr: &mut [T], mut root: usize) {
let last = arr.len() - 1;
loop {
let left = 2 * root + 1;
if left > last {
break;
}
let right = left + 1;
let max = if right <= last && arr[right] > arr[left] {
right
} else {
left
};
/// This function sorts the array either in ascending or descending order based on the `ascending` parameter.
///
/// # Arguments
///
/// * `arr` - A mutable reference to the array to be sorted.
/// * `ascending` - A boolean indicating whether to sort in ascending order (`true`) or descending order (`false`).
pub fn heap_sort<T: Ord>(arr: &mut [T], ascending: bool) {
if arr.len() <= 1 {
return;
}

if arr[max] > arr[root] {
arr.swap(root, max);
}
root = max;
// Build heap based on the order
build_heap(arr, ascending);

let mut end = arr.len() - 1;
while end > 0 {
arr.swap(0, end);
heapify(&mut arr[..end], 0, ascending);
end -= 1;
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::sorting::have_same_elements;
use crate::sorting::is_sorted;
use crate::sorting::{have_same_elements, heap_sort, is_descending_sorted, is_sorted};

#[test]
fn empty() {
let mut arr: Vec<i32> = Vec::new();
let cloned = arr.clone();
heap_sort(&mut arr);
assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned));
}
macro_rules! test_heap_sort {
($($name:ident: $input:expr,)*) => {
$(
#[test]
fn $name() {
let input_array = $input;
let mut arr_asc = input_array.clone();
heap_sort(&mut arr_asc, true);
assert!(is_sorted(&arr_asc) && have_same_elements(&arr_asc, &input_array));

#[test]
fn single_element() {
let mut arr = vec![1];
let cloned = arr.clone();
heap_sort(&mut arr);
assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned));
}

#[test]
fn sorted_array() {
let mut arr = vec![1, 2, 3, 4];
let cloned = arr.clone();
heap_sort(&mut arr);
assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned));
}

#[test]
fn unsorted_array() {
let mut arr = vec![3, 4, 2, 1];
let cloned = arr.clone();
heap_sort(&mut arr);
assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned));
}

#[test]
fn odd_number_of_elements() {
let mut arr = vec![3, 4, 2, 1, 7];
let cloned = arr.clone();
heap_sort(&mut arr);
assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned));
let mut arr_dsc = input_array.clone();
heap_sort(&mut arr_dsc, false);
assert!(is_descending_sorted(&arr_dsc) && have_same_elements(&arr_dsc, &input_array));
}
)*
}
}

#[test]
fn repeated_elements() {
let mut arr = vec![542, 542, 542, 542];
let cloned = arr.clone();
heap_sort(&mut arr);
assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned));
test_heap_sort! {
empty_array: Vec::<i32>::new(),
single_element_array: vec![5],
sorted: vec![1, 2, 3, 4, 5],
sorted_desc: vec![5, 4, 3, 2, 1, 0],
basic_0: vec![9, 8, 7, 6, 5],
basic_1: vec![8, 3, 1, 5, 7],
basic_2: vec![4, 5, 7, 1, 2, 3, 2, 8, 5, 4, 9, 9, 100, 1, 2, 3, 6, 4, 3],
duplicated_elements: vec![5, 5, 5, 5, 5],
strings: vec!["aa", "a", "ba", "ab"],
}
}