Derek Zeng

A programmer

Segment tree and Binary indexed tree (1)

by coderek

Recently, I discovered two interesting data structures: the segment tree and the binary indexed tree (BIT).

Both data structures are good at managing aggregated values over a range. For example, given values over a discrete range of indices, I want to query a summarised value (such as the sum) of some sub-range, and I also want to update the value of a single element. Both the segment tree and the BIT make query and update operations efficient, with worst-case complexity O(log n).

To me, the segment tree is easier to understand. It is a binary tree in which each leaf covers a distinct contiguous range, and each internal node covers the union of its two children's ranges. So it looks like this:

Each node stores some interesting value. For example, it can store the sum of the subarray whose index range matches the node's range, or the max/min value in that range.
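As a sketch of how these stored values could be filled in from an array, here is a version that takes the combining operation as a parameter. The Node class and the build helper are hypothetical names for illustration, and ranges are half-open [start, end).

```python
class Node:
    """Hypothetical node type: covers the half-open index range [start, end)."""
    def __init__(self, start, end):
        self.start, self.end = start, end
        self.val = None
        self.left = self.right = None

def build(data, start, end, combine):
    """Build a tree over data[start:end]; each node stores combine() over its range."""
    node = Node(start, end)
    if end - start == 1:               # leaf: a single element
        node.val = data[start]
        return node
    mid = start + (end - start) // 2
    node.left = build(data, start, mid, combine)
    node.right = build(data, mid, end, combine)
    node.val = combine(node.left.val, node.right.val)  # aggregate the two children
    return node

data = [5, 1, 4, 2, 3]
sum_tree = build(data, 0, len(data), lambda a, b: a + b)
max_tree = build(data, 0, len(data), max)
min_tree = build(data, 0, len(data), min)
print(sum_tree.val, max_tree.val, min_tree.val)  # 15 5 1
```

Passing the operation in keeps a single build routine for sum, max and min; the root ends up holding the aggregate of the whole array.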

When representing a range, we always use two indices: the first is the starting position (inclusive) and the second is the ending position (exclusive), i.e. the half-open range [start, end). The following code can be used to construct the tree.

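class TreeNode:
    def __init__(self, start, end):
        self.start, self.end = start, end  # half-open range [start, end)
        self.val = 0                       # aggregated value (e.g. sum) of the range
        self.left = self.right = None

def segment_tree(root):
    start, end = root.start, root.end
    if end - start <= 1:  # a leaf covers a single index
        return root
    mid = start + (end - start) // 2  # integer division
    root.left = segment_tree(TreeNode(start, mid))
    root.right = segment_tree(TreeNode(mid, end))
    return root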
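To see the shape that this construction produces, the sketch below builds the tree over the ten indices 0~9 and prints the node ranges level by level, in the inclusive a~b notation used in the rest of this post. TreeNode is a minimal assumed class, and label and levels are hypothetical helpers added only for printing.

```python
from collections import deque

class TreeNode:
    def __init__(self, start, end):
        self.start, self.end = start, end  # half-open range [start, end)
        self.val = 0
        self.left = self.right = None

def segment_tree(root):
    if root.end - root.start <= 1:     # leaf covers a single index
        return root
    mid = root.start + (root.end - root.start) // 2
    root.left = segment_tree(TreeNode(root.start, mid))
    root.right = segment_tree(TreeNode(mid, root.end))
    return root

def label(node):
    """Inclusive 'a~b' label; a single index prints as just 'a'."""
    last = node.end - 1
    return str(node.start) if node.start == last else f"{node.start}~{last}"

def levels(root):
    """Breadth-first traversal returning one list of labels per tree level."""
    out, queue = [], deque([(root, 0)])
    while queue:
        node, depth = queue.popleft()
        if depth == len(out):
            out.append([])
        out[depth].append(label(node))
        for child in (node.left, node.right):
            if child is not None:
                queue.append((child, depth + 1))
    return out

for row in levels(segment_tree(TreeNode(0, 10))):
    print("  ".join(row))
```

It prints:

    0~9
    0~4  5~9
    0~1  2~4  5~6  7~9
    0  1  2  3~4  5  6  7  8~9
    3  4  8  9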

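With a segment tree, we can efficiently query data over any sub-range. The presumption here is that the operation used to combine the values of two ranges is associative, that is, (a + b) + c = a + (b + c), where + can be any such operation (sum, max, min, ...). Commutativity (a + b = b + a) is not required, as long as the left child's value is always combined before the right child's.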
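To illustrate why only associativity matters, here is a sketch using string concatenation, which is associative but not commutative ("ab" + "c" == "a" + "bc", while "ab" != "ba"). The Node class and the build and query helpers are hypothetical; query gives the right answer because it always combines the left part before the right part. The same code also works unchanged with max.

```python
class Node:
    def __init__(self, start, end):
        self.start, self.end = start, end   # half-open range [start, end)
        self.val = None
        self.left = self.right = None

def build(data, start, end, combine):
    node = Node(start, end)
    if end - start == 1:                    # leaf: a single element
        node.val = data[start]
    else:
        mid = start + (end - start) // 2
        node.left = build(data, start, mid, combine)
        node.right = build(data, mid, end, combine)
        node.val = combine(node.left.val, node.right.val)
    return node

def query(node, s, e, combine):
    """combine() over data[s:e]; assumes a non-empty query (s < e)."""
    s, e = max(s, node.start), min(e, node.end)
    if s == node.start and e == node.end:   # node range lies fully inside the query
        return node.val
    mid = node.left.end
    if e <= mid:                            # query lies entirely in the left child
        return query(node.left, s, e, combine)
    if s >= mid:                            # query lies entirely in the right child
        return query(node.right, s, e, combine)
    # query straddles mid: left part first, then right part, preserving order
    return combine(query(node.left, s, mid, combine),
                   query(node.right, mid, e, combine))

concat_op = lambda a, b: a + b
letters = list("segmenttree")
concat = build(letters, 0, len(letters), concat_op)
print(query(concat, 2, 7, concat_op))       # gment
largest = build([3, 9, 2, 7, 5], 0, 5, max)
print(query(largest, 2, 5, max))            # 7
```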

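Let's say we are computing sums, the value of each element is equal to its index, and each node stores the sum of all elements in its range as val. In the examples below, a~b denotes the inclusive index range from a to b, so 0~9 corresponds to the half-open range [0, 10) used in the code. So the sum of range 0~9 is 0 + 1 + ... + 9 = 45.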
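As a quick check of these numbers, the sums stored in the root and its two children (which cover 0~4 and 5~9) are:

$$
\mathrm{val}(0\sim 9) = \sum_{i=0}^{9} i = \frac{9 \cdot 10}{2} = 45,
\qquad
\mathrm{val}(0\sim 4) = 0+1+2+3+4 = 10,
\qquad
\mathrm{val}(5\sim 9) = 5+6+7+8+9 = 35,
$$

and the root's value is the sum of its children's values, 10 + 35 = 45.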

If we want the sum of range 2~8, we look at the root node first. The root covers 0~9, which contains the requested range but does not match it exactly. So we break 2~8 down at the boundary between the root's children into 2~4 and 5~8, pass 2~4 to the left child (0~4) and 5~8 to the right child (5~9), and apply the same logic recursively. Each child returns a sum, and the parent adds the two sums to get the result. The code would be:

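def get_value(root, s, e):
    # sum over the half-open query range [s, e)
    s, e = max(s, root.start), min(e, root.end)  # clip the query to this node
    if s >= e: return 0  # out of range
    if s == root.start and e == root.end: return root.val  # exact match: stored sum
    mid = root.left.end  # boundary between the two children
    return get_value(root.left, s, min(e, mid)) + get_value(root.right, max(s, mid), e)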
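A sketch that reproduces the 2~8 walk-through, assuming the element at index i has value i. build is a hypothetical helper, and the optional used list is an addition for illustration only: it records which nodes' stored sums end up being added together. Note the half-open call get_value(root, 2, 9) for the inclusive range 2~8.

```python
class TreeNode:
    def __init__(self, start, end):
        self.start, self.end = start, end  # half-open range [start, end)
        self.val = 0
        self.left = self.right = None

def build(start, end):
    """Hypothetical helper: element i has value i; each node stores its range's sum."""
    node = TreeNode(start, end)
    if end - start == 1:
        node.val = start
    else:
        mid = start + (end - start) // 2
        node.left, node.right = build(start, mid), build(mid, end)
        node.val = node.left.val + node.right.val
    return node

def get_value(root, s, e, used=None):
    """Sum over [s, e); optionally records the labels of nodes whose sums are used."""
    s, e = max(s, root.start), min(e, root.end)   # clip the query to this node
    if s >= e:                                    # no overlap
        return 0
    if s == root.start and e == root.end:         # exact match: reuse the stored sum
        if used is not None:
            last = root.end - 1
            used.append(str(last) if root.start == last else f"{root.start}~{last}")
        return root.val
    mid = root.left.end                           # boundary between the children
    return (get_value(root.left, s, min(e, mid), used)
            + get_value(root.right, max(s, mid), e, used))

root = build(0, 10)
used = []
print(get_value(root, 2, 9, used))  # 35
print(used)                         # ['2~4', '5~6', '7', '8']
```

Only four stored sums are combined (9 + 11 + 7 + 8 = 35) instead of seven individual elements.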

Updates are per element, so let's say we update the value at index 4 from 4 to 8. Looking at the tree, we need to update every node whose range contains index 4, namely 0~9, 0~4, 2~4, 3~4 and 4. The update itself is simple: add the difference (8 - 4 = 4) to each of these nodes. To find them, we apply the same logic as get_value, but instead of returning a sum, we apply the update to each node on the way down.

def update(root, i, diff):
    # add diff to every node whose range [start, end) contains index i
    if root is None or not root.start <= i < root.end: return
    root.val += diff
    update(root.left, i, diff)   # at most one child contains i;
    update(root.right, i, diff)  # the other returns immediately
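A sketch of the index-4 example, again assuming the element at index i starts with value i. build is a hypothetical helper, and the optional touched list is added only to show which nodes receive the difference.

```python
class TreeNode:
    def __init__(self, start, end):
        self.start, self.end = start, end  # half-open range [start, end)
        self.val = 0
        self.left = self.right = None

def build(start, end):
    """Hypothetical helper: element i has value i; each node stores its range's sum."""
    node = TreeNode(start, end)
    if end - start == 1:
        node.val = start
    else:
        mid = start + (end - start) // 2
        node.left, node.right = build(start, mid), build(mid, end)
        node.val = node.left.val + node.right.val
    return node

def update(root, i, diff, touched=None):
    """Add diff to every node whose range contains i; optionally record their labels."""
    if root is None or not root.start <= i < root.end:
        return
    root.val += diff
    if touched is not None:
        last = root.end - 1
        touched.append(str(last) if root.start == last else f"{root.start}~{last}")
    update(root.left, i, diff, touched)
    update(root.right, i, diff, touched)

root = build(0, 10)
old, new = 4, 8                      # index 4 currently holds 4; set it to 8
touched = []
update(root, 4, new - old, touched)
print(touched)    # ['0~9', '0~4', '2~4', '3~4', '4']
print(root.val)   # 49
```

One node per level is touched, which is where the O(log n) update cost comes from.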

There is also a variant, closely related to the interval tree, that goes by the same name, segment tree. Instead of an index range, each leaf holds an elementary segment, and these segments are mutually exclusive (they do not overlap). For example, given the intervals 1~5, 3~6 and 2~4, the endpoints 1, 2, 3, 4, 5, 6 break the line into the segments 1~2, 2~3, 3~4, 4~5 and 5~6, and each segment node stores the aggregated value of all intervals containing that segment.
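A brute-force sketch of what each segment node would hold for this example, assuming the aggregated value is simply the number of intervals that contain the segment. It is not the tree itself, and elementary_segments and coverage are hypothetical helper names.

```python
def elementary_segments(intervals):
    """Split the line at every endpoint; return consecutive (lo, hi) pairs."""
    points = sorted({p for interval in intervals for p in interval})
    return list(zip(points, points[1:]))

def coverage(intervals):
    """For each elementary segment, count the intervals containing it (brute force)."""
    return {
        (lo, hi): sum(1 for a, b in intervals if a <= lo and hi <= b)
        for lo, hi in elementary_segments(intervals)
    }

intervals = [(1, 5), (3, 6), (2, 4)]
for (lo, hi), count in coverage(intervals).items():
    print(f"{lo}~{hi}: {count}")
```

It prints 1~2: 1, 2~3: 2, 3~4: 3, 4~5: 2 and 5~6: 1, one line per segment; a real segment tree would store these values in its nodes instead of rescanning every interval.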

I will discuss the BIT in the next post.

(End of article)