Segment Trees, and how you can crack range queries problems


A quick refresher on segment trees

A segment tree is a data structure that answers range queries, such as the sum of a subarray a[l...r] or the minimum element in that range, in O(log n) time. It requires a linear amount of memory and also supports modifying the array, either one element at a time or a whole subarray at once.

How a segment tree works

Structure

Consider an array a[0...n-1] on which we need range sum queries. We construct the segment tree by splitting the array in half each time: the root stores the sum of the whole array, its left child stores the sum of the left half, and its right child stores the sum of the right half. Each of these halves is split in turn, and the process continues until every segment has size 1.
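To make the halving concrete, here is a small sketch that lists the segment covered by each node, indented by depth. The names collectSegments and segmentsFor are made up for this illustration, and the array is assumed to be non-empty.

```cpp
#include <cassert>
#include <string>
#include <vector>
using namespace std;

// Records the segment [tl, tr] of every node: parent first, then the
// left half, then the right half. Indentation reflects the depth.
void collectSegments(int tl, int tr, int depth, vector<string>& out) {
    out.push_back(string(2 * depth, ' ') + "[" + to_string(tl) + ", " + to_string(tr) + "]");
    if (tl == tr) return;        // a size-1 segment is not split further
    int tm = (tl + tr) / 2;      // midpoint used for the halving
    collectSegments(tl, tm, depth + 1, out);
    collectSegments(tm + 1, tr, depth + 1, out);
}

// Hypothetical convenience wrapper for an array of n >= 1 elements.
vector<string> segmentsFor(int n) {
    vector<string> out;
    collectSegments(0, n - 1, 0, out);
    return out;
}
```

For n = 4 this produces the seven segments [0, 3], [0, 1], [0, 0], [1, 1], [2, 3], [2, 2], [3, 3]. In general a tree over n elements has 2n - 1 nodes.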

[Figure: segment tree for range sum queries. Source: "Segment Tree | Set 1 (Sum of given range)", GeeksforGeeks]

Implementation

Although a segment tree is, well... a tree, practical implementations usually store it in a flat array to avoid keeping redundant information such as child pointers. The sums are stored starting from index 1, and for any index i >= 1 the left and right children are at indices 2i and 2i+1 respectively.

An array of size 4n is always enough. If h is the height of the tree, the number of nodes is at most N = 1 + 2 + 4 + ... + 2^h = 2^(h+1) - 1. We start with n elements and halve the segment at each level until a single element is left, so h = ceil(log2(n)) and 2^h < 2n. Hence N = 2^(h+1) - 1 < 4n.
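The bound can also be checked empirically. The sketch below walks the same recursive layout (root at index 1, children at 2i and 2i+1) and reports the largest index that would be touched. The helper names maxIndexUsed and fitsIn4n are invented for this check.

```cpp
#include <algorithm>
#include <cassert>
using namespace std;

// Largest array index used by a tree over [tl, tr] whose root sits at index ti.
int maxIndexUsed(int ti, int tl, int tr) {
    if (tl == tr) return ti;     // a leaf occupies only its own slot
    int tm = (tl + tr) / 2;
    return max(maxIndexUsed(2 * ti, tl, tm),
               maxIndexUsed(2 * ti + 1, tm + 1, tr));
}

// True when every index of a tree over n >= 1 elements fits in an array of size 4n.
bool fitsIn4n(int n) {
    return maxIndexUsed(1, 0, n - 1) < 4 * n;
}
```

For n = 5, for instance, the largest index used is 9, comfortably below 4n = 20.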

vector<int> t(4 * n); // t is the segment tree; n is the number of input elements

To build the segment tree, we use a recursive function that gives the tree leaves the values from the original array and recursively computes the sums.

void build(const vector<int>& input, vector<int>& t, int ti, int tl, int tr) {
    if (tl == tr) {
        // leaf: copy the array element
        t[ti] = input[tl];
    } else {
        int tm = (tl + tr) / 2;
        build(input, t, 2 * ti, tl, tm);          // left half [tl, tm]
        build(input, t, 2 * ti + 1, tm + 1, tr);  // right half [tm+1, tr]
        t[ti] = t[2 * ti] + t[2 * ti + 1];
    }
}

To update an element in the tree:

// Sets a[i] = val and refreshes the sums on the path from the leaf to the root.
void update(vector<int>& t, int ti, int i, int val, int tl, int tr) {
    if (tl == tr) {
        t[ti] = val;
    } else {
        int tm = (tl + tr) / 2;
        if (i <= tm)
            update(t, 2 * ti, i, val, tl, tm);
        else
            update(t, 2 * ti + 1, i, val, tm + 1, tr);
        t[ti] = t[2 * ti] + t[2 * ti + 1];
    }
}

Finally, for the sum queries:

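// The tree is passed by const reference: passing it by value would copy the
// whole array on every recursive call and ruin the O(log n) bound.
int rangeSum(const vector<int>& t, int ti, int tl, int tr, int l, int r) {
    if (l > r)
        return 0;

    if (tl == l && tr == r) {
        return t[ti];
    }

    int tm = (tl + tr) / 2;
    return rangeSum(t, 2 * ti, tl, tm, l, min(r, tm))
         + rangeSum(t, 2 * ti + 1, tm + 1, tr, max(l, tm + 1), r);
}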
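The three operations can be exercised together. The sketch below repeats them in compact form so that it compiles on its own and wraps them in a small struct. SumTree, assign and query are names chosen for this example, and the input array is assumed to be non-empty.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

void build(const vector<int>& a, vector<int>& t, int ti, int tl, int tr) {
    if (tl == tr) { t[ti] = a[tl]; return; }
    int tm = (tl + tr) / 2;
    build(a, t, 2 * ti, tl, tm);
    build(a, t, 2 * ti + 1, tm + 1, tr);
    t[ti] = t[2 * ti] + t[2 * ti + 1];
}

void update(vector<int>& t, int ti, int i, int val, int tl, int tr) {
    if (tl == tr) { t[ti] = val; return; }
    int tm = (tl + tr) / 2;
    if (i <= tm) update(t, 2 * ti, i, val, tl, tm);
    else update(t, 2 * ti + 1, i, val, tm + 1, tr);
    t[ti] = t[2 * ti] + t[2 * ti + 1];
}

int rangeSum(const vector<int>& t, int ti, int tl, int tr, int l, int r) {
    if (l > r) return 0;
    if (tl == l && tr == r) return t[ti];
    int tm = (tl + tr) / 2;
    return rangeSum(t, 2 * ti, tl, tm, l, min(r, tm))
         + rangeSum(t, 2 * ti + 1, tm + 1, tr, max(l, tm + 1), r);
}

// Hypothetical wrapper that hides the root index and the [0, n-1] bounds.
struct SumTree {
    int n;
    vector<int> t;

    explicit SumTree(const vector<int>& a)
        : n(static_cast<int>(a.size())), t(4 * a.size()) {
        build(a, t, 1, 0, n - 1);
    }
    void assign(int i, int val) { update(t, 1, i, val, 0, n - 1); }  // a[i] = val
    int query(int l, int r) const { return rangeSum(t, 1, 0, n - 1, l, r); }
};
```

With the array {5, 3, 8, 6, 2}, query(1, 3) gives 3 + 8 + 6 = 17. After assign(2, 1) the same query gives 3 + 1 + 6 = 10.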

Applying a segment tree to real-world scenarios

You are given an array of integers, memory, consisting of 0s and 1s that indicate whether the corresponding memory unit is free or occupied: memory[i] = 0 means that the i-th memory unit is free, and memory[i] = 1 means it is occupied.

Your task is to perform two types of queries:
alloc x: Find the leftmost memory block of x consecutive free memory units and mark these units as occupied. If there is no block with x consecutive free units, return -1; otherwise, return the index of the first unit of the allocated block.

erase index: If there exists an allocated memory block starting at position index, free all its memory units. If the memory cell at position index was already occupied before the first operation, free this cell only. Return the length of the deleted memory block. If there is no such allocated block starting at position index, return -1.


The queries are given in the following format:
queries is an array of 2-element arrays:

• if queries[i][0] = 0, then this is an alloc query with length x = queries[i][1]

• if queries[i][0] = 1, then this is an erase query with index = queries[i][1]

Return an array containing the results of all the queries.
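Before reaching for a segment tree, it can help to pin down the expected behaviour with a brute-force simulator. The sketch below is one reading of the statement above, not a reference solution: it scans the array linearly for every alloc, and it assumes that a block or initial cell can be erased only once. The name simulateNaive is made up for this example.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <vector>
using namespace std;

vector<int> simulateNaive(vector<int> memory, const vector<vector<int>>& queries) {
    int n = static_cast<int>(memory.size());
    set<int> initial;        // cells occupied before the first query
    for (int i = 0; i < n; i++)
        if (memory[i] == 1) initial.insert(i);
    map<int, int> blocks;    // start index -> length of a live allocated block
    vector<int> res;

    for (const auto& q : queries) {
        int answer = -1;
        if (q[0] == 0) {                      // alloc x
            int x = q[1], run = 0;            // run = current streak of free cells
            for (int i = 0; i < n && answer == -1; i++) {
                run = (memory[i] == 0) ? run + 1 : 0;
                if (x > 0 && run == x) answer = i - x + 1;
            }
            if (answer != -1) {
                blocks[answer] = x;
                for (int i = answer; i < answer + x; i++) memory[i] = 1;
            }
        } else {                              // erase index
            int idx = q[1];
            auto it = blocks.find(idx);
            if (it != blocks.end()) {         // an allocated block starts here
                answer = it->second;
                for (int i = idx; i < idx + answer; i++) memory[i] = 0;
                blocks.erase(it);
            } else if (initial.count(idx) > 0) {  // initially occupied cell
                answer = 1;
                memory[idx] = 0;
                initial.erase(idx);
            }
        }
        res.push_back(answer);
    }
    return res;
}
```

For memory = {0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0} and queries {0, 2}, {0, 1}, {0, 1}, {1, 0}, {1, 1}, {1, 3}, {0, 4}, this returns 2, 0, 4, 1, 1, -1, -1. A simulator like this is slow, but it is handy for cross-checking a faster solution on random inputs.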

#include<bits/stdc++.h>
using namespace std;

void update(vector<int>& t, int v, int tl, int tr, int pos, int new_val) {
	if(tl == tr) {
		t[v] = new_val;	
	} else {
		int tm = (tl + tr) / 2;
		if(pos <= tm)
			update(t, 2*v, tl, tm, pos, new_val);
		else
			update(t, 2*v+1, tm+1, tr, pos, new_val);
		t[v] = t[2*v] + t[2*v+1];
	}	
}

int sum(const vector<int>& t, int v, int tl, int tr, int l, int r) {
	if(l > r)
		return 0;

	if(l == tl && r == tr)
		return t[v]; 

	int tm = (tl + tr)/2;
	
	return sum(t, 2*v, tl, tm, l, min(r, tm)) + sum(t, 2*v+1, tm+1, tr, max(tm+1, l), r);
}

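// Returns {start index or -1, number of occupied cells in [tl, tr]}.
// Caveat: this search only recognises a single free cell (x == 1) or a tree
// segment that is completely free and at least x cells long. A free run that
// is not aligned with a tree segment is missed or reported too far right:
// memory {1, 0, 0, 1} with x = 2 yields -1 although cells 1..2 are free.
// It also visits every node in the worst case, so one call costs O(n).
pair<int, int> alloc(const vector<int>& t, int x, int v, int tl, int tr) {
	if(tl == tr)
		return {x == 1 && t[v] == 0 ? tl : -1, t[v]};

	int tm = (tl + tr)/2;
	auto bl = alloc(t, x, 2*v, tl, tm);
	if(bl.first != -1)
		return bl;
	auto br = alloc(t, x, 2*v+1, tm+1, tr);
	if(br.first != -1)
		return br;

	// neither half contains a block: use the whole segment if it is free
	if(bl.second + br.second == 0 && x <= tr-tl+1)
		return {tl, bl.second+br.second};

	return {-1, bl.second + br.second};
}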

void buildTree(vector<int>& t, const vector<int>& a, int v, int tl, int tr) {
	if(tl == tr) {
		t[v] = a[tl];	
	} else {
		int tm = (tl + tr)/2;
		buildTree(t, a, 2*v, tl, tm);
		buildTree(t, a, 2*v+1, tm+1, tr);
		t[v] = t[2*v] + t[2*v+1];
	}
}


int main() {
	vector<int> input = {0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0};
	int n = input.size();
	vector<vector<int>> queries = {{0, 2}, {0, 1}, {0, 1}, {1, 0}, {1, 1}, {1, 3}, {0, 4}}; 

	set<int> initial;
	for(int i = 0; i < n; i++) {
		if(input[i] == 1) {
			initial.insert(i);
		}
	}

	vector<int> t;
	t.resize(4*n);
	buildTree(t, input, 1, 0, n-1);

	map<int, int>  blocks;

	vector<int> res;
	for(const auto& q: queries) {
		int tmp = -1;
		if(q[0] == 0) {
			auto p = alloc(t, q[1], 1, 0, n-1); 
			if(p.first != -1) {
				tmp = p.first;
				blocks[p.first] = q[1];
				for(int pos = p.first; pos < p.first + q[1]; pos++)
					update(t, 1, 0, n-1, pos, 1); 
			}
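		} else {
			int l = q[1], r = q[1];
			if(initial.count(q[1]) > 0) {
				tmp = 1;
				initial.erase(q[1]);	// the cell can be freed only once
			} else if(blocks.find(q[1]) != blocks.end()) {
				tmp = blocks[q[1]];
				r = l + tmp - 1;	// last cell of the block (inclusive)
				blocks.erase(q[1]);	// the block no longer exists
			}

			if(tmp != -1) {
				for(int i = l; i <= r; i++)
					update(t, 1, 0, n-1, i, 0);
			}
		}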
		res.push_back(tmp);
	}

	for(int answer: res)
		cout << answer << " ";
	cout << endl;
}
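One way to make the leftmost-block search handle free runs that cross segment boundaries is to store, for every node, the length of the free run touching its left edge, the one touching its right edge, and the longest free run inside it. This is a different technique from the sum tree used in the program above. The sketch below is a minimal version under that assumption, with names (FreeRunTree, setCell, findFirst) invented for the example and a non-empty memory array assumed.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

struct Node { int pref = 0, suf = 0, best = 0; };  // free-run lengths

struct FreeRunTree {
    int n;
    vector<Node> t;

    explicit FreeRunTree(const vector<int>& memory)
        : n(static_cast<int>(memory.size())), t(4 * memory.size()) {
        build(memory, 1, 0, n - 1);
    }

    // Recomputes node v from its children covering [tl, tm] and [tm+1, tr].
    void pull(int v, int tl, int tm, int tr) {
        const Node& a = t[2 * v];
        const Node& b = t[2 * v + 1];
        int lenA = tm - tl + 1, lenB = tr - tm;
        t[v].pref = (a.pref == lenA) ? lenA + b.pref : a.pref;
        t[v].suf = (b.suf == lenB) ? lenB + a.suf : b.suf;
        t[v].best = max({a.best, b.best, a.suf + b.pref});
    }

    void setLeaf(int v, int occupied) {
        t[v].pref = t[v].suf = t[v].best = (occupied == 0) ? 1 : 0;
    }

    void build(const vector<int>& memory, int v, int tl, int tr) {
        if (tl == tr) { setLeaf(v, memory[tl]); return; }
        int tm = (tl + tr) / 2;
        build(memory, 2 * v, tl, tm);
        build(memory, 2 * v + 1, tm + 1, tr);
        pull(v, tl, tm, tr);
    }

    void setCell(int v, int tl, int tr, int pos, int occupied) {
        if (tl == tr) { setLeaf(v, occupied); return; }
        int tm = (tl + tr) / 2;
        if (pos <= tm) setCell(2 * v, tl, tm, pos, occupied);
        else setCell(2 * v + 1, tm + 1, tr, pos, occupied);
        pull(v, tl, tm, tr);
    }
    void setCell(int pos, int occupied) { setCell(1, 0, n - 1, pos, occupied); }

    // Leftmost start of a free run of at least x cells inside [tl, tr], or -1.
    int findFirst(int v, int tl, int tr, int x) const {
        if (t[v].best < x) return -1;
        if (tl == tr) return tl;
        int tm = (tl + tr) / 2;
        if (t[2 * v].best >= x) return findFirst(2 * v, tl, tm, x);
        if (t[2 * v].suf + t[2 * v + 1].pref >= x)   // run crossing the midpoint
            return tm - t[2 * v].suf + 1;
        return findFirst(2 * v + 1, tm + 1, tr, x);
    }

    // Allocates x cells at the leftmost fitting position; returns its start or -1.
    int alloc(int x) {
        if (x <= 0) return -1;
        int start = findFirst(1, 0, n - 1, x);
        if (start != -1)
            for (int i = start; i < start + x; i++) setCell(i, 1);
        return start;
    }
};
```

The search itself is O(log n). Marking the cells still costs O(x log n) because it uses point updates, and a range update with lazy propagation would be the usual next step. The erase bookkeeping (the initial set and the blocks map) would stay as in the program above, with setCell(i, 0) used to free cells.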