ELI5 (Explain Like I’m 5)
Imagine a school where students form friend groups. At first, everyone is their own group. When two students become friends, you merge their groups. Later, if someone asks “are these two students in the same group?”, you can answer almost instantly. Union Find is exactly this: a data structure for tracking who’s connected to whom, with very fast merge and lookup.
Explanation
- Union Find (Disjoint Set Union / DSU) maintains a collection of disjoint sets
- Two operations: find(x) returns the representative (root) of x’s group; union(x, y) merges x’s and y’s groups
- Two optimizations make both operations nearly O(1):
  - Path compression: while finding the root, flatten the tree so every node on the path points directly to the root
  - Union by rank: always attach the root with the lower rank under the root with the higher rank (rank is an upper bound on tree height)
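To make path compression concrete, here is a minimal sketch on a hand-built chain. The four-node `parent` list is a made-up example, not part of the template; it only shows what a single `find` call does to the pointers.

```python
# Hypothetical starting state: a chain 3 -> 2 -> 1 -> 0, where 0 is the root.
parent = [0, 0, 1, 2]


def find(x):
    # Recursive find with path compression: every node visited on the
    # way up is re-pointed directly at the root.
    if parent[x] != x:
        parent[x] = find(parent[x])
    return parent[x]


root = find(3)
print(root)    # 0
print(parent)  # [0, 0, 0, 0]
```

After one call, nodes 2 and 3 point straight at the root, so later lookups on them take a single step.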
Keyword trigger: “connected components”, “detect cycle in undirected graph”, “minimum spanning tree”, “number of groups” → Union Find.
When to use?
- Dynamically connecting nodes and checking connectivity
- Detecting cycles in undirected graphs
- Minimum spanning tree (Kruskal’s algorithm)
- Counting connected components as edges are added
- Grouping problems where merges happen over time
Approach
Template
class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))  # each node is its own parent
        self.rank = [0] * n           # upper bound on tree height (for union by rank)

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x, y):
        px, py = self.find(x), self.find(y)
        if px == py:
            return False  # already connected
        # union by rank: attach the lower-rank root under the higher-rank root
        if self.rank[px] < self.rank[py]:
            px, py = py, px
        self.parent[py] = px
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1  # equal ranks: the merged tree may be one level taller
        return True
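A common variant of the template swaps rank for size, which also gives you each group's member count for free. The sketch below is one possible version, not the only way to write it: the class name `UnionFindBySize` and the `size` and `count` attributes are illustrative assumptions, and `find` is written iteratively (two passes) instead of recursively.

```python
class UnionFindBySize:
    """Hypothetical variant of the template: union by size plus a set counter."""

    def __init__(self, n: int) -> None:
        self.parent = list(range(n))  # each node starts as its own root
        self.size = [1] * n           # size[r] = nodes in the tree rooted at r
        self.count = n                # current number of disjoint sets

    def find(self, x: int) -> int:
        # Pass 1: walk up to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Pass 2: path compression, re-point every visited node at the root.
        while self.parent[x] != root:
            nxt = self.parent[x]
            self.parent[x] = root
            x = nxt
        return root

    def union(self, x: int, y: int) -> bool:
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False  # already in the same set
        if self.size[rx] < self.size[ry]:
            rx, ry = ry, rx  # make rx the root of the larger tree
        self.parent[ry] = rx
        self.size[rx] += self.size[ry]
        self.count -= 1
        return True


# Usage: five students, two friendships.
uf = UnionFindBySize(5)
uf.union(0, 1)
uf.union(3, 4)
same = uf.find(0) == uf.find(1)
diff = uf.find(1) == uf.find(3)
print(same, diff, uf.count)  # True False 3
```

Union by size keeps trees shallow in the same way union by rank does, so either choice works with path compression.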
Cycle Detection
If union(x, y) returns False, x and y were already in the same component → adding this edge creates a cycle.
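The rule above can be sketched as a small standalone function. The name `has_cycle` and the edge-list input format are assumptions for illustration; the rank array is left out to keep the sketch short, and `find` uses path halving (a lighter form of path compression).

```python
def has_cycle(n: int, edges: list[tuple[int, int]]) -> bool:
    """Return True if the undirected graph on nodes 0..n-1 contains a cycle."""
    parent = list(range(n))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving: skip one level
            x = parent[x]
        return x

    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru == rv:
            return True  # u and v already connected: this edge closes a cycle
        parent[ru] = rv  # otherwise merge the two components
    return False


print(has_cycle(3, [(0, 1), (1, 2)]))          # False
print(has_cycle(3, [(0, 1), (1, 2), (2, 0)]))  # True
```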
Count Components
Start with n components. Each successful union (returns True) decreases count by 1.
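As a minimal sketch of this counting idea (the function name `count_components` and the edge-list input are assumptions, and rank is omitted for brevity):

```python
def count_components(n: int, edges: list[tuple[int, int]]) -> int:
    """Count connected components of an undirected graph on nodes 0..n-1."""
    parent = list(range(n))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    components = n  # every node starts as its own component
    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            components -= 1  # a successful union merges two components
    return components


print(count_components(5, [(0, 1), (1, 2), (3, 4)]))  # 2
```

Edges inside an existing component leave the count unchanged, which is why only successful unions are counted.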
Notes
- Time Complexity: O(α(n)) amortized per operation, where α is the inverse Ackermann function (below 5 for any practical n), so effectively O(1)
- Space Complexity: O(n) for parent and rank arrays
- Path compression + union by rank together give the O(α(n)) bound — use both
- For grid problems, you can map (row, col) → index with row * cols + col, where cols is the number of columns
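The grid mapping can be sketched with an island-counting example. The function name `count_islands` and the 0/1 grid encoding are assumptions chosen for illustration; only the up and left neighbors are checked because cells are visited in row-major order.

```python
def count_islands(grid: list[list[int]]) -> int:
    """Count 4-directionally connected groups of 1s in a 0/1 grid."""
    if not grid or not grid[0]:
        return 0
    rows, cols = len(grid), len(grid[0])
    parent = list(range(rows * cols))  # cell (r, c) lives at index r * cols + c

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    islands = 0
    for r in range(rows):
        for c in range(cols):
            if grid[r][c] != 1:
                continue
            islands += 1  # assume a new island, then merge with neighbors
            for nr, nc in ((r - 1, c), (r, c - 1)):  # up and left were already visited
                if nr >= 0 and nc >= 0 and grid[nr][nc] == 1:
                    a, b = find(r * cols + c), find(nr * cols + nc)
                    if a != b:
                        parent[a] = b
                        islands -= 1
    return islands


print(count_islands([[1, 1, 0],
                     [0, 0, 1],
                     [1, 0, 1]]))  # 3
```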
Data structures
- Parent array: parent[i] = parent of node i (a root points to itself)
- Rank array: approximate tree height, used to keep trees shallow
How to Master This
Step-by-step
- Write the template from memory — it’s short and always the same
- Solve #547 (number of provinces) — direct connected components
- Solve #684 (redundant connection) — cycle detection
- Solve #1584 (min cost to connect all points) — Kruskal’s with Union Find
- Solve #128 (longest consecutive sequence) — Union Find on numbers
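"Union Find on numbers" in the last step may be the least obvious item, so here is one possible sketch: treat each distinct number as a node, union x with x + 1 whenever both are present, and return the largest group size. The function name and the dict-based `parent` and `size` maps are assumptions; a hash-set scan is the more common solution to this problem, and this version is only meant to show the Union Find angle.

```python
def longest_consecutive(nums: list[int]) -> int:
    """Length of the longest run of consecutive integers present in nums."""
    parent = {x: x for x in nums}  # dict keys also deduplicate the input
    size = {x: 1 for x in nums}    # size[r] is accurate only when r is a root

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for x in parent:
        if x + 1 in parent:
            rx, ry = find(x), find(x + 1)
            if rx != ry:
                if size[rx] < size[ry]:
                    rx, ry = ry, rx  # union by size
                parent[ry] = rx
                size[rx] += size[ry]

    # A root's size is at least as large as any size it absorbed,
    # so the overall maximum is the largest group.
    return max(size.values(), default=0)


print(longest_consecutive([100, 4, 200, 1, 3, 2]))  # 4
```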
Key Resources
- 📹 NeetCode — Union Find
- 📹 NeetCode — Redundant Connection
- 📝 NeetCode.io — Advanced Graphs section
- 🔁 Drill: #547 → #684 → #128 → #1584
Sample LeetCode Problems
| # | Problem | Difficulty | Interview Frequency | Must-Do |
|---|---|---|---|---|
| 547 | Number of Provinces | Medium | High | ✅ |
| 684 | Redundant Connection | Medium | High | ✅ |
| 128 | Longest Consecutive Sequence | Medium | Very High | ✅ |
| 1584 | Min Cost to Connect All Points | Medium | Medium | ⚡ |
| 990 | Satisfiability of Equality Equations | Medium | Medium | ⚡ |
| 721 | Accounts Merge | Medium | High | ✅ |
Code Samples
- Number of Provinces — count connected components
def findCircleNum(isConnected: list[list[int]]) -> int:
    n = len(isConnected)
    parent = list(range(n))
    rank = [0] * n

    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])  # path compression
        return parent[x]

    def union(x, y):
        px, py = find(x), find(y)
        if px == py:
            return False
        if rank[px] < rank[py]:
            px, py = py, px
        parent[py] = px
        if rank[px] == rank[py]:
            rank[px] += 1
        return True

    components = n
    for i in range(n):
        for j in range(i + 1, n):
            if isConnected[i][j] == 1:
                if union(i, j):
                    components -= 1  # merged two components
    return components
- Redundant Connection — find edge that creates a cycle
def findRedundantConnection(edges: list[list[int]]) -> list[int]:
    n = len(edges)
    parent = list(range(n + 1))  # nodes are labeled 1..n, so index 0 is unused
    rank = [0] * (n + 1)

    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])  # path compression
        return parent[x]

    def union(x, y):
        px, py = find(x), find(y)
        if px == py:
            return False  # already connected, so this edge is redundant
        if rank[px] < rank[py]:
            px, py = py, px
        parent[py] = px
        if rank[px] == rank[py]:
            rank[px] += 1
        return True

    for u, v in edges:
        if not union(u, v):
            return [u, v]  # this edge created a cycle
    return []
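- Min Cost to Connect All Points (Kruskal's with Union Find)

The notes above mention Kruskal's algorithm without showing it, so here is a minimal sketch under stated assumptions: edge weights are Manhattan distances between points (as in problem #1584), every pair of points is a candidate edge, and the function name `min_cost_connect_points` is illustrative. Edges are sorted by weight, and an edge is kept only when `union` succeeds, meaning it joins two different components.

```python
def min_cost_connect_points(points: list[list[int]]) -> int:
    n = len(points)
    # Build every candidate edge as (weight, i, j) so sorting orders by weight.
    edges = []
    for i in range(n):
        for j in range(i + 1, n):
            dist = abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1])
            edges.append((dist, i, j))
    edges.sort()

    parent = list(range(n))
    rank = [0] * n

    def find(x: int) -> int:
        if parent[x] != x:
            parent[x] = find(parent[x])  # path compression
        return parent[x]

    def union(x: int, y: int) -> bool:
        px, py = find(x), find(y)
        if px == py:
            return False  # same component: this edge would create a cycle
        if rank[px] < rank[py]:
            px, py = py, px
        parent[py] = px
        if rank[px] == rank[py]:
            rank[px] += 1
        return True

    total = 0
    used = 0
    for dist, i, j in edges:
        if union(i, j):
            total += dist
            used += 1
            if used == n - 1:  # a spanning tree on n nodes has n - 1 edges
                break
    return total


print(min_cost_connect_points([[0, 0], [2, 2], [3, 10], [5, 2], [7, 0]]))  # 20
```

Building all pairs costs O(n²) edges, so sorting dominates the running time; the Union Find operations themselves stay near constant.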