parent[] array with find(x) (path compression) + union(a, b) (by rank/size)

| Operation | Naive | Path Compression Only | Union by Rank Only | Both Optimizations | Amortized (Both) |
|---|---|---|---|---|---|
| make_set(x) | O(1) | O(1) | O(1) | O(1) | O(1) |
| find(x) | O(n) | O(log n)* | O(log n) | O(α(n)) | ~O(1) |
| union(a, b) | O(n) | O(log n)* | O(log n) | O(α(n)) | ~O(1) |
| connected(a, b) | O(n) | O(log n)* | O(log n) | O(α(n)) | ~O(1) |
| Space | O(n) | O(n) | O(n) | O(n) | O(n) |
| Component count | O(1) w/ counter | O(1) w/ counter | O(1) w/ counter | O(1) w/ counter | O(1) |
| Component size | O(n) | O(n) | O(1) w/ size[] | O(1) w/ size[] | O(1) |
| Undo last union | N/A | N/A | N/A | N/A | Use rollback DSU |
*Amortized bound. With both optimizations, m operations take O(m · α(n)) total, where α is the inverse Ackermann function (effectively ≤ 4 for any practical n < 10^600). [src1] [src3]
START: Do you need to track connectivity between elements?
|
+-- YES: Are edges added online (one at a time)?
|   |
|   +-- YES: Do you need to undo/remove edges?
|   |   |
|   |   +-- YES --> Use Link-Cut Trees or Offline DSU with rollback
|   |   +-- NO --> Union-Find (this unit) -- optimal choice
|   |
|   +-- NO (all edges known upfront): Do you need more than connectivity?
|       |
|       +-- YES (shortest path, all neighbors, etc.) --> BFS/DFS on adjacency list
|       +-- NO (just connected components) --> Union-Find OR BFS/DFS (both work)
|
+-- NO: Do you need to find minimum spanning tree?
    |
    +-- YES: Dense graph (E >> V)?
    |   |
    |   +-- YES --> Prim's algorithm with priority queue
    |   +-- NO --> Kruskal's algorithm + Union-Find (this unit)
    |
    +-- NO --> This unit is not applicable
Create a parent array where each element is its own root (self-loop). Optionally track rank and component count. [src1]
class UnionFind:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    State is kept in plain lists indexed by element:

    * ``parent[i]`` -- parent of ``i``; a root is its own parent.
    * ``rank[i]``   -- rank of the tree rooted at ``i`` (used by union by rank).
    * ``count``     -- current number of disjoint sets.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets.

        Args:
            n: Number of elements; must be non-negative.

        Raises:
            ValueError: If ``n`` is negative (it would otherwise produce
                empty arrays together with a negative component count).
        """
        if n < 0:
            raise ValueError("n must be non-negative")
        self.parent = list(range(n))  # Each element is its own root
        self.rank = [0] * n           # Rank for union by rank
        self.count = n                # Number of connected components
Verify: uf = UnionFind(5) → expected: uf.parent == [0, 1, 2, 3, 4]
Path compression flattens the tree by making every node on the find path point directly to the root. [src1] [src4]
def find(self, x):
    """Return the root of the set containing ``x``, compressing the path.

    After the call every node that was on the path from ``x`` to the root
    points directly at the root. The walk is iterative, so a long parent
    chain cannot raise ``RecursionError``.

    Args:
        x: Index of the element to look up in ``self.parent``.

    Returns:
        The index of the root of ``x``'s set.
    """
    # Pass 1: locate the root without modifying any links.
    root = x
    while self.parent[root] != root:
        root = self.parent[root]
    # Pass 2: path compression -- re-point every node on the path at the root.
    while self.parent[x] != root:
        next_up = self.parent[x]
        self.parent[x] = root
        x = next_up
    return root
Verify: After find(x), every node on the path has root as direct parent.
Attach the shorter tree under the taller tree to keep height logarithmic. [src1] [src4]
def union(self, x, y):
    """Merge the sets containing ``x`` and ``y`` using union by rank.

    Args:
        x: First element.
        y: Second element.

    Returns:
        ``True`` if two different sets were merged, ``False`` if ``x`` and
        ``y`` were already in the same set (nothing is changed in that case).
    """
    root_x = self.find(x)
    root_y = self.find(y)
    if root_x == root_y:
        return False
    # Make root_x the root with the larger (or equal) rank so the
    # lower-rank tree is always the one that gets attached.
    if self.rank[root_x] < self.rank[root_y]:
        root_x, root_y = root_y, root_x
    self.parent[root_y] = root_x
    # The merged tree only gets taller when both trees had equal rank.
    if self.rank[root_x] == self.rank[root_y]:
        self.rank[root_x] += 1
    self.count -= 1
    return True
Verify: uf.union(0, 1) → expected: uf.count == 4, uf.find(0) == uf.find(1)
def connected(self, x, y):
    """Return ``True`` when ``x`` and ``y`` have the same root.

    Roots are compared via ``find`` rather than ``parent[...]`` directly,
    because a node's parent is not necessarily the root of its set.
    """
    return self.find(x) == self.find(y)
def component_count(self):
    """Return the current number of disjoint sets (``self.count``)."""
    return self.count
Verify: uf.union(0, 1); uf.union(1, 2) → expected: uf.connected(0, 2) == True
Replace rank with size for applications needing component sizes. [src2]
class UnionFindBySize:
    """Disjoint-set forest over ``0 .. n-1`` using union by size.

    ``size[r]`` is the number of elements in the set whose root is ``r``;
    entries for non-root nodes are stale and must not be read directly --
    use ``get_size`` instead.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets.

        Raises:
            ValueError: If ``n`` is negative.
        """
        if n < 0:
            raise ValueError("n must be non-negative")
        self.parent = list(range(n))  # Each element starts as its own root
        self.size = [1] * n           # Every set starts with one element
        self.count = n                # Number of disjoint sets

    def find(self, x):
        """Return the root of ``x``'s set, compressing the path on the way.

        Implemented iteratively so a long parent chain cannot raise
        ``RecursionError``.
        """
        # Pass 1: locate the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Pass 2: re-point every node on the path directly at the root.
        while self.parent[x] != root:
            next_up = self.parent[x]
            self.parent[x] = root
            x = next_up
        return root

    def union(self, x, y):
        """Merge the sets of ``x`` and ``y``; the smaller set joins the larger.

        Returns:
            ``True`` if a merge happened, ``False`` if both elements were
            already in the same set.
        """
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False
        # Keep root_x as the root of the larger (or equal-sized) set.
        if self.size[root_x] < self.size[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x
        self.size[root_x] += self.size[root_y]
        self.count -= 1
        return True

    def get_size(self, x):
        """Return the number of elements in the set containing ``x``."""
        return self.size[self.find(x)]
# Input: n (vertices), edges as [(weight, u, v), ...]
# Output: MST edges and total weight
def kruskal_mst(n, edges):
    """Build a minimum spanning tree (or forest) with Kruskal's algorithm.

    Args:
        n: Number of vertices, passed to ``UnionFind(n)``.
        edges: Iterable of ``(weight, u, v)`` tuples. It is not modified.

    Returns:
        ``(mst, total_weight)`` where ``mst`` is a list of ``(u, v, weight)``
        tuples in the order they were accepted and ``total_weight`` is the
        sum of their weights. If the graph is disconnected, fewer than
        ``n - 1`` edges are returned.
    """
    uf = UnionFind(n)
    mst = []
    total_weight = 0
    # sorted() orders a copy by weight (ties broken by u, then v), so the
    # caller's edge list keeps its original order.
    for weight, u, v in sorted(edges):
        if uf.union(u, v):  # Only add if no cycle
            mst.append((u, v, weight))
            total_weight += weight
        if len(mst) == n - 1:  # MST complete
            break
    return mst, total_weight
// Input: n (number of elements)
// Output: UnionFind object with find, union, connected
class UnionFind {
  /** Build n singleton sets labelled 0..n-1. */
  constructor(n) {
    this.parent = Array.from({ length: n }, (_unused, index) => index);
    this.rank = new Array(n).fill(0);
    this.count = n;
  }

  /** Return the root of x's set, re-pointing the visited path at the root. */
  find(x) {
    // First walk: locate the root.
    let root = x;
    while (this.parent[root] !== root) {
      root = this.parent[root];
    }
    // Second walk: path compression.
    let node = x;
    while (node !== root) {
      const next = this.parent[node];
      this.parent[node] = root;
      node = next;
    }
    return root;
  }

  /** Merge the sets of x and y by rank; returns false if already merged. */
  union(x, y) {
    const rootX = this.find(x);
    const rootY = this.find(y);
    if (rootX === rootY) {
      return false;
    }
    // The root with the lower rank is attached under the other one.
    const xIsLower = this.rank[rootX] < this.rank[rootY];
    const winner = xIsLower ? rootY : rootX;
    const loser = xIsLower ? rootX : rootY;
    this.parent[loser] = winner;
    if (this.rank[winner] === this.rank[loser]) {
      this.rank[winner] += 1;
    }
    this.count -= 1;
    return true;
  }

  /** True when x and y share a root. */
  connected(x, y) {
    const rootX = this.find(x);
    const rootY = this.find(y);
    return rootX === rootY;
  }
}
// Input: n vertices, int[][] edges (each [u, v])
// Output: true if graph contains a cycle
public class UnionFind {
    private int[] parent;
    private int[] rank;

    /** Creates n singleton sets labelled 0..n-1. */
    public UnionFind(int n) {
        parent = new int[n];
        rank = new int[n];
        int i = 0;
        while (i < n) {
            parent[i] = i;
            i++;
        }
    }

    /** Returns the root of x's set and compresses the path that was walked. */
    public int find(int x) {
        // First walk: locate the root.
        int root = x;
        while (parent[root] != root) {
            root = parent[root];
        }
        // Second walk: point every visited node straight at the root.
        int node = x;
        while (node != root) {
            int next = parent[node];
            parent[node] = root;
            node = next;
        }
        return root;
    }

    /** Merges the sets of x and y by rank; returns false if already joined. */
    public boolean union(int x, int y) {
        int rootX = find(x);
        int rootY = find(y);
        if (rootX == rootY) {
            return false;
        }
        // The root with the lower rank is attached under the other one.
        boolean xIsLower = rank[rootX] < rank[rootY];
        int winner = xIsLower ? rootY : rootX;
        int loser = xIsLower ? rootX : rootY;
        parent[loser] = winner;
        if (rank[winner] == rank[loser]) {
            rank[winner] += 1;
        }
        return true;
    }

    /** Returns true as soon as an edge joins two already-connected vertices. */
    public static boolean hasCycle(int n, int[][] edges) {
        UnionFind sets = new UnionFind(n);
        for (int[] edge : edges) {
            boolean merged = sets.union(edge[0], edge[1]);
            if (!merged) {
                return true;
            }
        }
        return false;
    }
}
// Input: n (number of elements)
// Output: DSU struct with find, unite, connected, get_size
struct DSU {
    vector<int> parent, sz;
    int components;

    // Start with n singleton sets: each node is its own root with size 1.
    DSU(int n) : parent(n), sz(n, 1), components(n) {
        for (int i = 0; i < n; ++i) {
            parent[i] = i;
        }
    }

    // Root of x's set; every node visited is re-pointed at that root.
    int find(int x) {
        int root = x;
        while (parent[root] != root) {
            root = parent[root];
        }
        while (x != root) {
            int next = parent[x];
            parent[x] = root;
            x = next;
        }
        return root;
    }

    // Union by size: the smaller set is attached under the larger one.
    // Returns false when a and b were already in the same set.
    bool unite(int a, int b) {
        int big = find(a);
        int small = find(b);
        if (big == small) {
            return false;
        }
        if (sz[big] < sz[small]) {
            int tmp = big;
            big = small;
            small = tmp;
        }
        parent[small] = big;
        sz[big] += sz[small];
        --components;
        return true;
    }

    // True when a and b share a root.
    bool connected(int a, int b) {
        int ra = find(a);
        int rb = find(b);
        return ra == rb;
    }

    // Number of elements in the set containing a.
    int get_size(int a) {
        int root = find(a);
        return sz[root];
    }
};
# BAD -- O(n) worst case, tree becomes a linked list
def find(self, x):
    """Anti-pattern: walk to the root without path compression.

    No parent link is updated, so a long chain stays long and every later
    lookup on it pays the full walk again.
    """
    while self.parent[x] != x:
        x = self.parent[x]
    return x
# GOOD -- path compression keeps the tree flat; amortized O(alpha(n))
# when combined with union by rank/size
def find(self, x):
    """Return the root of ``x``'s set and compress the path that was walked.

    Written iteratively so a long parent chain cannot raise
    ``RecursionError``.
    """
    # Pass 1: locate the root.
    root = x
    while self.parent[root] != root:
        root = self.parent[root]
    # Pass 2: re-point every node on the path directly at the root.
    while self.parent[x] != root:
        next_up = self.parent[x]
        self.parent[x] = root
        x = next_up
    return root
# BAD -- always attaches under same root, creates O(n) chains
def union(self, x, y):
    """Anti-pattern: always hang ``y``'s root under ``x``'s root.

    Tree height is never taken into account, and nothing is returned, so
    the caller cannot tell whether a merge actually happened.
    """
    root_x, root_y = self.find(x), self.find(y)
    if root_x != root_y:
        self.parent[root_y] = root_x  # No rank/size check
# GOOD -- keeps tree balanced, O(log n) height guarantee
def union(self, x, y):
    """Merge the sets of ``x`` and ``y`` using union by rank.

    Returns:
        ``True`` if two different sets were merged, ``False`` if ``x`` and
        ``y`` already shared a root.
    """
    root_x, root_y = self.find(x), self.find(y)
    if root_x == root_y:
        return False
    # Make root_x the higher-(or equal-)rank root so the lower-rank tree
    # is the one that gets attached.
    if self.rank[root_x] < self.rank[root_y]:
        root_x, root_y = root_y, root_x
    self.parent[root_y] = root_x
    # Height only grows when two trees of equal rank are joined.
    if self.rank[root_x] == self.rank[root_y]:
        self.rank[root_x] += 1
    return True
# BAD -- updating rank as if it were size corrupts the invariant
self.rank[root_x] += self.rank[root_y]  # Wrong! Rank != size
# GOOD -- rank: only increment when ranks are equal
if self.rank[root_x] == self.rank[root_y]:
    self.rank[root_x] += 1
# GOOD -- size: always add sizes together
self.size[root_x] += self.size[root_y]
- parent[x] may not be the root after path compression. Fix: always compare find(x) == find(y). [src1]
- True/False from union(). [src5]
- parent = [0] * n makes every element's root 0. Fix: parent = list(range(n)). [src6]
- self.count -= 1 after merge. [src7]

| Use When | Don't Use When | Use Instead |
|---|---|---|
| Checking if two nodes share a connected component | You need the actual path between nodes | BFS/DFS traversal |
| Building MST with Kruskal's algorithm | Graph is dense (E close to V²) | Prim's algorithm |
| Detecting cycles in undirected graphs | Detecting cycles in directed graphs | DFS with back-edge detection |
| Grouping equivalent items (accounts merge, synonyms) | You need to split groups or undo unions | Link-Cut Trees / rollback DSU |
| Online connectivity (edges added incrementally) | All edges known upfront + need full traversal | BFS/DFS (simpler) |
| Percolation simulation / grid connectivity | You need weighted shortest path | Dijkstra / Bellman-Ford |
| Very large graphs (10^6+ nodes) | Graph fits in memory and BFS/DFS suffices | BFS/DFS (simpler to code) |
RecursionError for large inputs with recursive find. Use sys.setrecursionlimit() or iterative find.