Disjoint Set Algorithm
The disjoint set algorithm, also known as union-find or merge-find set, is a data structure that efficiently manages a collection of non-overlapping sets. It is primarily used to keep track of connected components in an undirected graph or to detect whether an undirected graph contains a cycle. It supports two main operations: union, which merges two sets into one, and find, which determines the set to which a particular element belongs. With the optimizations described below, a sequence of m operations on n elements runs in O(m α(n)) time, where α is the inverse Ackermann function. Because α(n) is at most 4 for any input size that arises in practice, each operation takes nearly constant amortized time, which makes the structure well suited to large data sets.
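To make the two core operations concrete, here is a minimal sketch of cycle detection and component counting in an undirected graph. It assumes vertices are numbered 0..n-1 and edges are given as (u, v) tuples; the names find, has_cycle and count_components are hypothetical and are not part of the Node-based code later in this document. It deliberately omits the efficiency heuristics, so it shows only the plain union and find steps.

```python
def find(parent: list[int], x: int) -> int:
    # Follow parent pointers up to the root, which identifies x's set.
    while parent[x] != x:
        x = parent[x]
    return x


def has_cycle(n: int, edges: list[tuple[int, int]]) -> bool:
    """Return True if the undirected graph on vertices 0..n-1 has a cycle."""
    parent = list(range(n))  # every vertex starts in its own singleton set
    for u, v in edges:
        ru, rv = find(parent, u), find(parent, v)
        if ru == rv:
            # u and v are already connected, so edge (u, v) closes a cycle
            return True
        parent[ru] = rv  # union: merge the two components
    return False


def count_components(n: int, edges: list[tuple[int, int]]) -> int:
    """Return the number of connected components of the graph."""
    parent = list(range(n))
    for u, v in edges:
        ru, rv = find(parent, u), find(parent, v)
        if ru != rv:
            parent[ru] = rv
    # Each remaining root represents exactly one component.
    return sum(1 for x in range(n) if parent[x] == x)


if __name__ == "__main__":
    print(has_cycle(3, [(0, 1), (1, 2)]))          # False
    print(has_cycle(3, [(0, 1), (1, 2), (2, 0)]))  # True
    print(count_components(5, [(0, 1), (3, 4)]))   # 3
```

Without path compression or union by rank, a single find can take time linear in n on a degenerate chain, which is what the heuristics in the next paragraph address.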
The disjoint set algorithm stores each set as a tree: every element has a parent pointer, and the root of each tree serves as the unique identifier (representative) of its set. The find operation follows parent pointers until it reaches the root. Two heuristics improve efficiency. Path compression flattens the tree by making every node visited during a find point directly to the root. Union by rank attaches the root of the tree with the smaller rank under the root of the tree with the larger rank, where rank is an upper bound on the tree's height; this keeps the height of every tree at most logarithmic in the number of its elements. Together, these optimizations let operations run in near-constant amortized time, which makes the disjoint set an effective building block for applications such as Kruskal's algorithm for finding a minimum spanning tree and percolation threshold estimation.
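Both heuristics, and the Kruskal application, can be sketched with plain integer arrays instead of node objects. This is a minimal sketch, assuming elements are the integers 0..n-1 and weighted edges are (weight, u, v) tuples; the class DisjointSet and the function kruskal are hypothetical names, independent of the Node-based code that follows. The find method uses an iterative two-pass form of path compression.

```python
class DisjointSet:
    """Array-based union-find with path compression and union by rank."""

    def __init__(self, n: int) -> None:
        self.parent = list(range(n))  # each element starts as its own root
        self.rank = [0] * n           # upper bound on each root's tree height

    def find(self, x: int) -> int:
        # First pass: locate the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass (path compression): point every node on the path at the root.
        while x != root:
            next_node = self.parent[x]
            self.parent[x] = root
            x = next_node
        return root

    def union(self, a: int, b: int) -> bool:
        """Merge the sets of a and b; return False if they were already joined."""
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        # Union by rank: hang the lower-rank tree under the higher-rank root.
        if self.rank[ra] < self.rank[rb]:
            ra, rb = rb, ra
        self.parent[rb] = ra
        if self.rank[ra] == self.rank[rb]:
            self.rank[ra] += 1  # equal ranks: the merged tree may be one level taller
        return True


def kruskal(n: int, edges: list[tuple[int, int, int]]) -> tuple[int, list[tuple[int, int, int]]]:
    """Return (total weight, chosen edges) of a minimum spanning forest."""
    ds = DisjointSet(n)
    total, chosen = 0, []
    for w, u, v in sorted(edges):  # consider edges in order of increasing weight
        # Keep an edge only if it joins two different components.
        if ds.union(u, v):
            total += w
            chosen.append((w, u, v))
    return total, chosen


if __name__ == "__main__":
    edges = [(1, 0, 1), (4, 0, 2), (3, 1, 2), (2, 2, 3)]
    print(kruskal(4, edges))  # (6, [(1, 0, 1), (2, 2, 3), (3, 1, 2)])
```

Here union returns a boolean so that Kruskal's loop can decide in one call whether an edge connects two separate components; the edge (4, 0, 2) is rejected because vertices 0 and 2 are already joined by that point.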
"""
disjoint set
Reference: https://en.wikipedia.org/wiki/Disjoint-set_data_structure
"""


class Node:
    def __init__(self, data):
        self.data = data
        # `parent` and `rank` are attached later by make_set()
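

def make_set(x: Node) -> None:
    """
    Make x a singleton set: x becomes the root of its own tree.
    """
    # rank is an upper bound on the height of the tree rooted at x;
    # a freshly created singleton tree has rank 0
    x.rank = 0
    x.parent = x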
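

def union_set(x: Node, y: Node) -> None:
    """
    Union the sets containing x and y (union by rank).
    The root with the larger rank becomes the parent, so that the
    disjoint set tree stays flat.
    """
    x, y = find_set(x), find_set(y)
    if x is y:
        # already in the same set; linking again would wrongly bump the rank
        return
    if x.rank > y.rank:
        y.parent = x
    else:
        x.parent = y
        # equal ranks: the merged tree may grow by one level
        if x.rank == y.rank:
            y.rank += 1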
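

def find_set(x: Node) -> Node:
    """
    Return the root (representative) of the set containing x, compressing
    the path so that every node visited points directly at the root.
    """
    if x is not x.parent:
        x.parent = find_set(x.parent)
    return x.parent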


def find_python_set(node: Node) -> set:
    """
    Return the Python standard-library set that contains node.data
    (the expected partition used by test_disjoint_set).
    """
    sets = ({0, 1, 2}, {3, 4, 5})
    for s in sets:
        if node.data in s:
            return s
    raise ValueError(f"{node.data} is not in {sets}")


def test_disjoint_set():
    """
    >>> test_disjoint_set()
    """
    vertex = [Node(i) for i in range(6)]
    for v in vertex:
        make_set(v)

    union_set(vertex[0], vertex[1])
    union_set(vertex[1], vertex[2])
    union_set(vertex[3], vertex[4])
    union_set(vertex[3], vertex[5])

    for node0 in vertex:
        for node1 in vertex:
            if find_python_set(node0).isdisjoint(find_python_set(node1)):
                assert find_set(node0) != find_set(node1)
            else:
                assert find_set(node0) == find_set(node1)


if __name__ == "__main__":
    test_disjoint_set()