ELI5 (Explain Like I’m 5)
You have a row of 1000 numbers and someone keeps asking “what’s the sum from index 3 to index 500?” and also keeps changing values. Recalculating from scratch each time is slow. A segment tree precomputes partial sums in a tree structure — each node covers a range. Answering a range query takes O(log n) instead of O(n).
Explanation
- A segment tree is a binary tree over an array where each node stores an aggregate (sum, min, max) for a contiguous range
- Leaf nodes represent individual elements; internal nodes represent merged ranges
- Supports two operations in O(log n): range query and point update
- Built in O(n); the array-based representation stores the tree in an array of size 4n
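Any associative merge works as long as it has an identity value to return for out-of-range nodes. Here is a minimal sketch of that idea. It assumes a hypothetical `GenericSegTree` class that takes the merge function and its identity: sum uses 0, min uses `math.inf`, and max uses `-math.inf`.

```python
import math
from typing import Callable


class GenericSegTree:
    """Recursive segment tree over an associative `combine` with an `identity`.

    Hypothetical helper for illustration only.
    """

    def __init__(self, nums: list[float], combine: Callable[[float, float], float], identity: float):
        self.n = len(nums)
        self.combine = combine
        self.identity = identity
        self.tree = [identity] * (4 * max(self.n, 1))
        if self.n:
            self._build(nums, 1, 0, self.n - 1)

    def _build(self, nums: list[float], node: int, start: int, end: int) -> None:
        if start == end:
            self.tree[node] = nums[start]  # leaf holds one element
            return
        mid = (start + end) // 2
        self._build(nums, 2 * node, start, mid)
        self._build(nums, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.combine(self.tree[2 * node], self.tree[2 * node + 1])

    def query(self, l: int, r: int, node: int = 1, start: int = 0, end: int = -1) -> float:
        if end == -1:
            end = self.n - 1
        if r < start or end < l:
            return self.identity  # no overlap: neutral element
        if l <= start and end <= r:
            return self.tree[node]  # full overlap: stored aggregate
        mid = (start + end) // 2
        return self.combine(
            self.query(l, r, 2 * node, start, mid),
            self.query(l, r, 2 * node + 1, mid + 1, end),
        )


nums = [5, 2, 8, 1, 9, 3]
mins = GenericSegTree(nums, min, math.inf)
maxs = GenericSegTree(nums, max, -math.inf)
print(mins.query(1, 4), maxs.query(1, 4))  # 1 9
```

Only `combine` and `identity` change between sum, min and max; the build and query recursion stay identical.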
Keyword trigger: “range sum query with updates”, “range min/max with updates”, “count inversions” → Segment Tree (or Fenwick Tree for simpler cases).
When to use?
- Range queries (sum, min, max) with point updates in between
- “How many elements in range [l, r] satisfy condition X?”
- When you need O(log n) updates AND O(log n) queries simultaneously
- If there are no updates, prefix sums answer each query in O(1) after O(n) preprocessing; if there are updates but no range queries, plain array assignment is enough
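For the no-update case, a prefix-sum array is the simpler tool. Below is a minimal sketch; the helper names `build_prefix` and `range_sum` are illustrative.

```python
from itertools import accumulate


def build_prefix(nums: list[int]) -> list[int]:
    # prefix[i] == sum(nums[:i]); the leading 0 lets ranges start at index 0
    return list(accumulate(nums, initial=0))


def range_sum(prefix: list[int], l: int, r: int) -> int:
    # Inclusive [l, r]: everything up to r minus everything before l
    return prefix[r + 1] - prefix[l]


prefix = build_prefix([3, 1, 4, 1, 5, 9])
print(range_sum(prefix, 1, 4))  # 11
```

The catch is that changing one value forces up to n prefix entries to be rewritten. That O(n) update cost is the gap a segment tree closes.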
Approach
Array-based Segment Tree
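For an array of size n, the recursive layout has 2n - 1 nodes. With 1-based indexing their positions can approach 4n, so an array of size 4n is always enough.
Node i (1-indexed, root at index 1):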
- left child: 2i
- right child: 2i + 1
- parent: i // 2
Build: fill from leaves up
Query: split range into O(log n) tree nodes, merge their values
Update: update leaf, propagate changes up
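To check why 4n slots suffice for the 1-indexed layout, one can record the largest node index the recursive build touches. Here is a small sketch; `max_node_index` is a hypothetical helper and not part of the tree itself.

```python
def max_node_index(node: int, start: int, end: int) -> int:
    # Same recursion shape as build(): children of node are 2*node and 2*node + 1
    if start == end:
        return node
    mid = (start + end) // 2
    return max(
        max_node_index(2 * node, start, mid),
        max_node_index(2 * node + 1, mid + 1, end),
    )


for n in (1, 5, 6, 17, 1000):
    used = max_node_index(1, 0, n - 1)
    print(n, used, 4 * n)  # the largest index used stays below 4n
```

For n = 5 the deepest leaf lands at index 9, well under 20. The 4n figure is a safe upper bound, not the exact node count.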
Build
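```python
# Assumes a module-level list allocated as: tree = [0] * (4 * len(arr))
# Initial call: build(arr, 1, 0, len(arr) - 1)
def build(arr, node, start, end):
    if start == end:
        tree[node] = arr[start]  # leaf: a single element
    else:
        mid = (start + end) // 2
        build(arr, 2 * node, start, mid)  # left child covers [start, mid]
        build(arr, 2 * node + 1, mid + 1, end)  # right child covers [mid + 1, end]
        tree[node] = tree[2 * node] + tree[2 * node + 1]  # merge (sum example)
```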
Query
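```python
# Initial call: query(1, 0, len(arr) - 1, l, r) for an inclusive range [l, r]
def query(node, start, end, l, r):
    if r < start or end < l:
        return 0  # no overlap: contributes nothing to the sum
    if l <= start and end <= r:
        return tree[node]  # node's range lies completely inside [l, r]
    mid = (start + end) // 2  # partial overlap: split and merge both halves
    return query(2 * node, start, mid, l, r) + query(2 * node + 1, mid + 1, end, l, r)
```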
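A query splits [l, r] into O(log n) tree nodes. You can check this claim by collecting the fully covered nodes instead of summing them. The sketch below uses a hypothetical `cover_nodes` helper that mirrors the three cases of `query`.

```python
def cover_nodes(node: int, start: int, end: int, l: int, r: int,
                out: list[tuple[int, int, int]]) -> None:
    # Same three cases as query(): no overlap, full overlap, partial overlap
    if r < start or end < l:
        return  # contributes nothing
    if l <= start and end <= r:
        out.append((node, start, end))  # this node's stored value would be used
        return
    mid = (start + end) // 2
    cover_nodes(2 * node, start, mid, l, r, out)
    cover_nodes(2 * node + 1, mid + 1, end, l, r, out)


parts: list[tuple[int, int, int]] = []
cover_nodes(1, 0, 7, 1, 6, parts)  # n = 8, query [1, 6]
print(parts)  # [(9, 1, 1), (5, 2, 3), (6, 4, 5), (14, 6, 6)]
```

The four ranges tile [1, 6] exactly. In general at most two nodes per tree level are fully covered, which gives the O(log n) bound.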
Notes
- Time Complexity: O(n) build, O(log n) query, O(log n) update
- Space Complexity: O(n); typically allocate 4 * n slots to be safe
- For lazy propagation (range updates, not just point updates), store a lazy tag per node and push it down to the children before recursing into them
- Recursive segment trees are slow in Python for large inputs; for plain range-sum problems, a Fenwick Tree (Binary Indexed Tree) is shorter and has smaller constant factors
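For the plain range-sum case, a Fenwick Tree needs only n + 1 slots. Below is a minimal sketch supporting point-add and inclusive range sums; the class and method names are illustrative. To handle #307's assign-style `update`, you would add the difference `val - nums[index]` and keep a copy of `nums` around.

```python
class FenwickTree:
    def __init__(self, nums: list[int]):
        self.n = len(nums)
        self.bit = [0] * (self.n + 1)  # 1-indexed internally; bit[0] unused
        for i, x in enumerate(nums):  # O(n log n) build, kept simple
            self.add(i, x)

    def add(self, i: int, delta: int) -> None:
        # nums[i] += delta: climb to every node whose range covers position i
        i += 1
        while i <= self.n:
            self.bit[i] += delta
            i += i & -i  # jump to the next responsible node

    def prefix(self, i: int) -> int:
        # Sum of nums[0..i-1], stripping the lowest set bit each step
        total = 0
        while i > 0:
            total += self.bit[i]
            i -= i & -i
        return total

    def range_sum(self, l: int, r: int) -> int:
        # Inclusive [l, r]
        return self.prefix(r + 1) - self.prefix(l)


ft = FenwickTree([3, 1, 4, 1, 5, 9])
print(ft.range_sum(1, 4))  # 11
ft.add(2, 6)  # nums[2] goes from 4 to 10
print(ft.range_sum(1, 4))  # 17
```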
Data structures
- Array of size 4n — the segment tree itself (1-indexed, root at index 1)
- Lazy array — optional, for range updates with lazy propagation
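Here is a minimal sketch of lazy propagation for range add plus range sum. It assumes the same 1-indexed layout plus a `lazy` array holding a pending addition per node. The class name `LazySegTree` and its `add` method are illustrative.

```python
class LazySegTree:
    def __init__(self, nums: list[int]):
        self.n = len(nums)
        self.tree = [0] * (4 * max(self.n, 1))  # range sums
        self.lazy = [0] * (4 * max(self.n, 1))  # pending add for every element under a node
        if self.n:
            self._build(nums, 1, 0, self.n - 1)

    def _build(self, nums: list[int], node: int, start: int, end: int) -> None:
        if start == end:
            self.tree[node] = nums[start]
            return
        mid = (start + end) // 2
        self._build(nums, 2 * node, start, mid)
        self._build(nums, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def _apply(self, node: int, start: int, end: int, val: int) -> None:
        # Add val to all of [start, end] in O(1); children are updated later
        self.tree[node] += val * (end - start + 1)
        self.lazy[node] += val

    def _push(self, node: int, start: int, end: int) -> None:
        # Hand a pending tag down to both children before recursing into them
        if self.lazy[node]:
            mid = (start + end) // 2
            self._apply(2 * node, start, mid, self.lazy[node])
            self._apply(2 * node + 1, mid + 1, end, self.lazy[node])
            self.lazy[node] = 0

    def add(self, l: int, r: int, val: int, node: int = 1, start: int = 0, end: int = -1) -> None:
        if end == -1:
            end = self.n - 1
        if r < start or end < l:
            return
        if l <= start and end <= r:
            self._apply(node, start, end, val)  # stop here; tag covers the subtree
            return
        self._push(node, start, end)
        mid = (start + end) // 2
        self.add(l, r, val, 2 * node, start, mid)
        self.add(l, r, val, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def query(self, l: int, r: int, node: int = 1, start: int = 0, end: int = -1) -> int:
        if end == -1:
            end = self.n - 1
        if r < start or end < l:
            return 0
        if l <= start and end <= r:
            return self.tree[node]
        self._push(node, start, end)
        mid = (start + end) // 2
        return (self.query(l, r, 2 * node, start, mid)
                + self.query(l, r, 2 * node + 1, mid + 1, end))


st = LazySegTree([1, 2, 3, 4, 5])
st.add(1, 3, 10)  # values become [1, 12, 13, 14, 5]
print(st.query(0, 4), st.query(2, 4))  # 45 32
```

A push only happens on partial overlap, which never occurs at a leaf. Child indices therefore stay inside the 4n bound.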
How to Master This
Step-by-step
- Implement a range sum segment tree (build, query, update) from scratch
- Solve #307 (range sum query mutable) — the canonical segment tree problem
- Solve #315 (count of smaller numbers after self) — segment tree on coordinate-compressed values
- Learn lazy propagation for range update queries
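For #315, one common approach is sketched below; a Fenwick Tree would work the same way. It coordinate-compresses the values, then scans from the right. For each number it counts how many already-seen values have a strictly smaller rank, then records that number. The function name follows LeetCode's `countSmaller`; the inner helpers are illustrative.

```python
def countSmaller(nums: list[int]) -> list[int]:
    # Coordinate compression: map each distinct value to a rank 0..m-1
    ranks = {v: i for i, v in enumerate(sorted(set(nums)))}
    m = len(ranks)
    tree = [0] * (4 * max(m, 1))  # tree[node] = count of seen values in node's rank range

    def update(node: int, start: int, end: int, pos: int) -> None:
        # Record one more occurrence of rank `pos`
        if start == end:
            tree[node] += 1
            return
        mid = (start + end) // 2
        if pos <= mid:
            update(2 * node, start, mid, pos)
        else:
            update(2 * node + 1, mid + 1, end, pos)
        tree[node] = tree[2 * node] + tree[2 * node + 1]

    def query(node: int, start: int, end: int, l: int, r: int) -> int:
        # Count seen values whose rank lies in [l, r]; empty when r < l
        if r < start or end < l:
            return 0
        if l <= start and end <= r:
            return tree[node]
        mid = (start + end) // 2
        return query(2 * node, start, mid, l, r) + query(2 * node + 1, mid + 1, end, l, r)

    result = [0] * len(nums)
    for i in range(len(nums) - 1, -1, -1):  # scan right to left
        rank = ranks[nums[i]]
        result[i] = query(1, 0, m - 1, 0, rank - 1)  # strictly smaller ranks only
        update(1, 0, m - 1, rank)
    return result


print(countSmaller([5, 2, 6, 1]))  # [2, 1, 1, 0]
```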
Key Resources
- 📹 William Fiset — Segment Tree
- 📹 Errichto — Segment Trees
- 📝 CP-Algorithms — Segment Tree
- 🔁 Drill: #307 → #315 → #493
Sample LeetCode Problems
| # | Problem | Difficulty | Interview Frequency | Must-Do |
|---|---|---|---|---|
| 307 | Range Sum Query - Mutable | Medium | High | ✅ |
| 315 | Count of Smaller Numbers After Self | Hard | Medium | ⚡ |
| 493 | Reverse Pairs | Hard | Low | 📖 |
| 218 | The Skyline Problem | Hard | Medium | ⚡ |
Code Samples
- Range Sum Segment Tree — build, query, update
```python
class SegmentTree:
    def __init__(self, nums: list[int]):
        self.n = len(nums)
        self.tree = [0] * (4 * self.n)  # 1-indexed, root at index 1
        if self.n:  # guard: an empty input has no range to build
            self._build(nums, 1, 0, self.n - 1)

    def _build(self, nums: list[int], node: int, start: int, end: int) -> None:
        if start == end:
            self.tree[node] = nums[start]  # leaf
            return
        mid = (start + end) // 2
        self._build(nums, 2 * node, start, mid)
        self._build(nums, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def update(self, idx: int, val: int, node: int = 1, start: int = 0, end: int = -1) -> None:
        # Point assignment nums[idx] = val; walks a single root-to-leaf path
        if end == -1:  # sentinel: top-level call covers the whole array
            end = self.n - 1
        if start == end:
            self.tree[node] = val
            return
        mid = (start + end) // 2
        if idx <= mid:
            self.update(idx, val, 2 * node, start, mid)
        else:
            self.update(idx, val, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]  # recompute on the way up

    def query(self, l: int, r: int, node: int = 1, start: int = 0, end: int = -1) -> int:
        # Sum over the inclusive range [l, r]
        if end == -1:
            end = self.n - 1
        if r < start or end < l:
            return 0  # out of range
        if l <= start and end <= r:
            return self.tree[node]  # fully inside range
        mid = (start + end) // 2
        left_sum = self.query(l, r, 2 * node, start, mid)
        right_sum = self.query(l, r, 2 * node + 1, mid + 1, end)
        return left_sum + right_sum


# Usage (LeetCode #307)
class NumArray:
    def __init__(self, nums: list[int]):
        self.st = SegmentTree(nums)

    def update(self, index: int, val: int) -> None:
        self.st.update(index, val)

    def sumRange(self, left: int, right: int) -> int:
        return self.st.query(left, right)
```