Skip to content

Refactor UnionFind #729

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 143 additions & 134 deletions src/data_structures/union_find.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
//! A Union-Find (Disjoint Set) data structure implementation in Rust.
//!
//! The Union-Find data structure keeps track of elements partitioned into
//! disjoint (non-overlapping) sets.
//! It provides near-constant-time operations to add new sets, to find the
//! representative of a set, and to merge sets.

use std::cmp::Ordering;
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;

/// A Union-Find (disjoint-set) data structure.
///
/// It holds an array of parent pointers together with the size of each subset.
#[derive(Debug)]
pub struct UnionFind<T: Debug + Eq + Hash> {
payloads: HashMap<T, usize>, // we are going to manipulate indices to parent, thus `usize`. We need a map to associate a value to its index in the parent links array
parent_links: Vec<usize>, // holds the relationship between an item and its parent. The root of a set is denoted by parent_links[i] == i
sizes: Vec<usize>, // holds the size
count: usize,
payloads: HashMap<T, usize>, // Maps values to their indices in the parent_links array.
parent_links: Vec<usize>, // Holds the parent pointers; root elements are their own parents.
sizes: Vec<usize>, // Holds the sizes of the sets.
count: usize, // Number of disjoint sets.
}

impl<T: Debug + Eq + Hash> UnionFind<T> {
/// Creates an empty Union Find structure with capacity n
///
/// # Examples
///
/// ```
/// use the_algorithms_rust::data_structures::UnionFind;
/// let uf = UnionFind::<&str>::with_capacity(5);
/// assert_eq!(0, uf.count())
/// ```
/// Creates an empty Union-Find structure with a specified capacity.
pub fn with_capacity(capacity: usize) -> Self {
Self {
parent_links: Vec::with_capacity(capacity),
Expand All @@ -31,7 +29,7 @@ impl<T: Debug + Eq + Hash> UnionFind<T> {
}
}

/// Inserts a new item (disjoint) in the data structure
/// Inserts a new item (disjoint set) into the data structure.
pub fn insert(&mut self, item: T) {
let key = self.payloads.len();
self.parent_links.push(key);
Expand All @@ -40,107 +38,63 @@ impl<T: Debug + Eq + Hash> UnionFind<T> {
self.count += 1;
}

pub fn id(&self, value: &T) -> Option<usize> {
self.payloads.get(value).copied()
/// Looks up `value` and returns the index of the root of the set it belongs to.
///
/// Returns `None` when `value` has no entry in `payloads`. Takes `&mut self`
/// because the root lookup is delegated to `find_by_key`, which needs mutable access.
pub fn find(&mut self, value: &T) -> Option<usize> {
    // `?` short-circuits with `None` for values that were never registered.
    let key = *self.payloads.get(value)?;
    Some(self.find_by_key(key))
}

/// Returns the key of an item stored in the data structure or None if it doesn't exist
fn find(&self, value: &T) -> Option<usize> {
self.id(value).map(|id| self.find_by_key(id))
}

/// Creates a link between value_1 and value_2
/// returns None if either value_1 or value_2 hasn't been inserted in the data structure first
/// returns Some(true) if two disjoint sets have been merged
/// returns Some(false) if both elements already were belonging to the same set
///
/// #_Examples:
///
/// ```
/// use the_algorithms_rust::data_structures::UnionFind;
/// let mut uf = UnionFind::with_capacity(2);
/// uf.insert("A");
/// uf.insert("B");
///
/// assert_eq!(None, uf.union(&"A", &"C"));
///
/// assert_eq!(2, uf.count());
/// assert_eq!(Some(true), uf.union(&"A", &"B"));
/// assert_eq!(1, uf.count());
///
/// assert_eq!(Some(false), uf.union(&"A", &"B"));
/// ```
pub fn union(&mut self, item1: &T, item2: &T) -> Option<bool> {
match (self.find(item1), self.find(item2)) {
(Some(k1), Some(k2)) => Some(self.union_by_key(k1, k2)),
/// Merges the set containing `first_item` with the set containing `sec_item`.
///
/// Returns:
/// - `None` if either item cannot be found,
/// - `Some(true)` if two distinct sets were joined,
/// - `Some(false)` if both items were already in the same set.
pub fn union(&mut self, first_item: &T, sec_item: &T) -> Option<bool> {
    // Resolve both roots up front: both lookups run even when the first one fails.
    let first_root = self.find(first_item);
    let sec_root = self.find(sec_item);
    Some(self.union_by_key(first_root?, sec_root?))
}

/// Returns the parent of the element given its id
fn find_by_key(&self, key: usize) -> usize {
let mut id = key;
while id != self.parent_links[id] {
id = self.parent_links[id];
/// Finds the root of the set containing the element with the given index.
fn find_by_key(&mut self, key: usize) -> usize {
if self.parent_links[key] != key {
self.parent_links[key] = self.find_by_key(self.parent_links[key]);
}
id
self.parent_links[key]
}

/// Unions the sets containing id1 and id2
fn union_by_key(&mut self, key1: usize, key2: usize) -> bool {
let root1 = self.find_by_key(key1);
let root2 = self.find_by_key(key2);
if root1 == root2 {
return false; // they belong to the same set already, no-op
/// Unites the sets containing the two elements identified by their indices.
fn union_by_key(&mut self, first_key: usize, sec_key: usize) -> bool {
let (first_root, sec_root) = (self.find_by_key(first_key), self.find_by_key(sec_key));

if first_root == sec_root {
return false;
}
// Attach the smaller set to the larger one
if self.sizes[root1] < self.sizes[root2] {
self.parent_links[root1] = root2;
self.sizes[root2] += self.sizes[root1];
} else {
self.parent_links[root2] = root1;
self.sizes[root1] += self.sizes[root2];

match self.sizes[first_root].cmp(&self.sizes[sec_root]) {
Ordering::Less => {
self.parent_links[first_root] = sec_root;
self.sizes[sec_root] += self.sizes[first_root];
}
_ => {
self.parent_links[sec_root] = first_root;
self.sizes[first_root] += self.sizes[sec_root];
}
}
self.count -= 1; // we had 2 disjoint sets, now merged as one

self.count -= 1;
true
}

/// Checks if two items belong to the same set
///
/// #_Examples:
///
/// ```
/// use the_algorithms_rust::data_structures::UnionFind;
/// let mut uf = UnionFind::from_iter(["A", "B"]);
/// assert!(!uf.is_same_set(&"A", &"B"));
///
/// uf.union(&"A", &"B");
/// assert!(uf.is_same_set(&"A", &"B"));
///
/// assert!(!uf.is_same_set(&"A", &"C"));
/// ```
pub fn is_same_set(&self, item1: &T, item2: &T) -> bool {
matches!((self.find(item1), self.find(item2)), (Some(root1), Some(root2)) if root1 == root2)
/// Reports whether `first_item` and `sec_item` are members of the same set.
///
/// Returns `false` when either item cannot be found.
pub fn is_same_set(&mut self, first_item: &T, sec_item: &T) -> bool {
    let first_root = self.find(first_item);
    let sec_root = self.find(sec_item);
    // Two `None`s would compare equal, so insist on an actual root first.
    first_root.is_some() && first_root == sec_root
}

/// Returns the number of disjoint sets
///
/// # Examples
///
/// ```
/// use the_algorithms_rust::data_structures::UnionFind;
/// let mut uf = UnionFind::with_capacity(5);
/// assert_eq!(0, uf.count());
///
/// uf.insert("A");
/// assert_eq!(1, uf.count());
///
/// uf.insert("B");
/// assert_eq!(2, uf.count());
///
/// uf.union(&"A", &"B");
/// assert_eq!(1, uf.count())
/// ```
/// Returns the number of disjoint sets currently tracked.
///
/// This is the value of the `count` field, which `insert` increments and a
/// successful merge in `union_by_key` decrements.
pub fn count(&self) -> usize {
self.count
}
Expand All @@ -158,11 +112,11 @@ impl<T: Debug + Eq + Hash> Default for UnionFind<T> {
}

impl<T: Debug + Eq + Hash> FromIterator<T> for UnionFind<T> {
/// Creates a new UnionFind data structure from an iterable of disjoint elements
/// Creates a new UnionFind data structure from an iterable of disjoint elements.
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
let mut uf = UnionFind::default();
for i in iter {
uf.insert(i);
for item in iter {
uf.insert(item);
}
uf
}
Expand All @@ -175,45 +129,100 @@ mod tests {
#[test]
fn test_union_find() {
let mut uf = UnionFind::from_iter(0..10);
assert_eq!(uf.find_by_key(0), 0);
assert_eq!(uf.find_by_key(1), 1);
assert_eq!(uf.find_by_key(2), 2);
assert_eq!(uf.find_by_key(3), 3);
assert_eq!(uf.find_by_key(4), 4);
assert_eq!(uf.find_by_key(5), 5);
assert_eq!(uf.find_by_key(6), 6);
assert_eq!(uf.find_by_key(7), 7);
assert_eq!(uf.find_by_key(8), 8);
assert_eq!(uf.find_by_key(9), 9);

assert_eq!(Some(true), uf.union(&0, &1));
assert_eq!(Some(true), uf.union(&1, &2));
assert_eq!(Some(true), uf.union(&2, &3));
assert_eq!(uf.find(&0), Some(0));
assert_eq!(uf.find(&1), Some(1));
assert_eq!(uf.find(&2), Some(2));
assert_eq!(uf.find(&3), Some(3));
assert_eq!(uf.find(&4), Some(4));
assert_eq!(uf.find(&5), Some(5));
assert_eq!(uf.find(&6), Some(6));
assert_eq!(uf.find(&7), Some(7));
assert_eq!(uf.find(&8), Some(8));
assert_eq!(uf.find(&9), Some(9));

assert!(!uf.is_same_set(&0, &1));
assert!(!uf.is_same_set(&2, &9));
assert_eq!(uf.count(), 10);

assert_eq!(uf.union(&0, &1), Some(true));
assert_eq!(uf.union(&1, &2), Some(true));
assert_eq!(uf.union(&2, &3), Some(true));
assert_eq!(uf.union(&0, &2), Some(false));
assert_eq!(uf.union(&4, &5), Some(true));
assert_eq!(uf.union(&5, &6), Some(true));
assert_eq!(uf.union(&6, &7), Some(true));
assert_eq!(uf.union(&7, &8), Some(true));
assert_eq!(uf.union(&8, &9), Some(true));
assert_eq!(uf.union(&7, &9), Some(false));

assert_ne!(uf.find(&0), uf.find(&9));
assert_eq!(uf.find(&0), uf.find(&3));
assert_eq!(uf.find(&4), uf.find(&9));
assert!(uf.is_same_set(&0, &3));
assert!(uf.is_same_set(&4, &9));
assert!(!uf.is_same_set(&0, &9));
assert_eq!(uf.count(), 2);

assert_eq!(Some(true), uf.union(&3, &4));
assert_eq!(Some(true), uf.union(&4, &5));
assert_eq!(Some(true), uf.union(&5, &6));
assert_eq!(Some(true), uf.union(&6, &7));
assert_eq!(Some(true), uf.union(&7, &8));
assert_eq!(Some(true), uf.union(&8, &9));
assert_eq!(Some(false), uf.union(&9, &0));

assert_eq!(1, uf.count());
assert_eq!(uf.find(&0), uf.find(&9));
assert_eq!(uf.count(), 1);
assert!(uf.is_same_set(&0, &9));

assert_eq!(None, uf.union(&0, &11));
}

#[test]
fn test_spanning_tree() {
// Let's imagine the following topology:
// A <-> B
// B <-> C
// A <-> D
// E
// F <-> G
// We have 3 disjoint sets: {A, B, C, D}, {E}, {F, G}
let mut uf = UnionFind::from_iter(["A", "B", "C", "D", "E", "F", "G"]);
uf.union(&"A", &"B");
uf.union(&"B", &"C");
uf.union(&"A", &"D");
uf.union(&"F", &"G");
assert_eq!(3, uf.count());

assert_eq!(None, uf.union(&"A", &"W"));

assert_eq!(uf.find(&"A"), uf.find(&"B"));
assert_eq!(uf.find(&"A"), uf.find(&"C"));
assert_eq!(uf.find(&"B"), uf.find(&"D"));
assert_ne!(uf.find(&"A"), uf.find(&"E"));
assert_ne!(uf.find(&"A"), uf.find(&"F"));
assert_eq!(uf.find(&"G"), uf.find(&"F"));
assert_ne!(uf.find(&"G"), uf.find(&"E"));

assert!(uf.is_same_set(&"A", &"B"));
assert!(uf.is_same_set(&"A", &"C"));
assert!(uf.is_same_set(&"B", &"D"));
assert!(!uf.is_same_set(&"B", &"F"));
assert!(!uf.is_same_set(&"E", &"A"));
assert!(!uf.is_same_set(&"E", &"G"));
assert_eq!(uf.count(), 3);
}

#[test]
fn test_with_capacity() {
    // Start from a pre-sized, empty structure and register five singletons.
    let mut uf: UnionFind<i32> = UnionFind::with_capacity(5);
    for value in 0..5 {
        uf.insert(value);
    }
    assert_eq!(uf.count(), 5);

    // Merging {0} and {1} removes one set.
    let merged = uf.union(&0, &1);
    assert_eq!(merged, Some(true));
    assert!(uf.is_same_set(&0, &1));
    assert_eq!(uf.count(), 4);

    // Merging {2} and {3} removes another.
    let merged = uf.union(&2, &3);
    assert_eq!(merged, Some(true));
    assert!(uf.is_same_set(&2, &3));
    assert_eq!(uf.count(), 3);

    // Joining the two pairs leaves {0, 1, 2, 3} and {4}.
    let merged = uf.union(&0, &2);
    assert_eq!(merged, Some(true));
    assert!(uf.is_same_set(&0, &1));
    assert!(uf.is_same_set(&2, &3));
    assert!(uf.is_same_set(&0, &3));
    assert_eq!(uf.count(), 2);

    // An item that was never inserted cannot be united with anything.
    assert_eq!(None, uf.union(&0, &10));
}
}