树状数组详解


树状数组也就是 Binary Indexed Tree 或以作者名字命名为 Fenwick Tree,最早由 Peter M. Fenwick 在 1994 年发表的 A New Data Structure for Cumulative Frequency Tables,起初是为了解决 Cumulative Frequency 的计算,现在也用于高效计算数列的前缀和、区间和。

这里详细介绍其原理以及实现方式。

简介

如果通过数组实现 前缀和 以及 区间和 的计算,那么查询可以达到 O(1) ,而更新的话时间复杂度就是 O(n) 了,例如更新第一个元素,那么后续所有的值都需要进行更新。

而通过树状数组,可以达到 O(logN) 的更新以及查询,除了这里介绍的树状数组之外,还可以使用线段树求解。

基本原理

树状数组的树状显然是修饰词,实际上仍然是一个数组,只是利用下标以及二进制运算来维护元素间的父子关系,也就是 $parent = son + 2^k$ ,其中 k 是子节点下标对应二进制末尾 0 的个数。

Binary Index Tree

如上图所示,A 和 B 都是数组,其中 A 是正常存储数据,而 B 则是树状数组,B4、B6、B7 都是 B8 的子节点,可以通过如下方式计算父节点。

4(0100) --> 4 + 2^2 = 8
6(0110) --> 6 + 2^1 = 8
7(0111) --> 7 + 2^0 = 8

所以,仅通过下标以及简单的位运算,那么就可以维护父子关系。

节点含义

仍以上图为例,其中奇数节点都是叶子节点,保存的是原数组相同下标保存的值,而偶数节点都是父节点,保存的是区间和,例如 B4 保存的是 B1 + B2 + B3 + A4 = A1 + A2 + A3 + A4 ;而 B5 只保存了 A5 的值。

所以,每个节点表示为某个区间的和,而区间的左边界是该节点最左侧叶子节点对应下标,而右边界则是自己的下标,例如 B8 表示 [1, 8] 的区间和;B6 表示 [5, 6] 的区间和。

所有节点表示的含义列举如下。

B1 = A1
B2 = B1 + A2 = A1 + A2
B3 = A3
B4 = B2 + B3 + A4 = A1 + A2 + A3 + A4
B5 = A5
B6 = B5 + A6 = A5 + A6
B7 = A7
B8 = B4 + B6 + B7 + A8 = A1 + A2 + A3 + A4 + A5 + A6 + A7 + A8

所以可以归纳出,左边界的下标一定是 $i - 2^k + 1$ ,其中 $i$ 为节点下标,右边界下标为 $i$ ,也就是。

8(1000) --> 8 - 2^3 + 1 = 1
7(0111) --> 7 - 2^0 + 1 = 7
6(0110) --> 6 - 2^1 + 1 = 5

lowbit

那么这里的关键就是如何计算二进制默认 0 的个数,通过 lowbit(n) 来表示非负整数 n 表示最低位 1 及其后面所有 0 组成的二进制数值,其实现如下。

static int low_bit(int x)
{
    return x & -x;
}

计算机中数值通过补码表示,这样可以将符号和数值统一处理,其中,正数的补码与原码相同,而负数的补码等于正数的原码每位取反再加 1 ,这样会使得正负数末尾的 0 数相同,那么相与后就是所需的结果。

这个实际上就是一个规则,可以简单那个数值验证下,例如 34 的二进制是 00100010 ,其中 -34 对应的补码为 11011110 ,所以相与后就是 2 。

实现

常用于对一个数组进行更新以及求前缀和,需要支持如下操作:

  • bit_add(pos, delta) 对第 pos 元素进行更新。
  • bit_query(pos) 查询前缀和,也就是区间 [0, to] 的元素进行累加。

另外,还有一些扩展接口,常见如下。

  • bit_get(pos) 查询第 pos 的元素值。
  • bit_query_range(from, to) 查询区间 [from, to] 的区间和。

如下是针对 C 的实现,上述的公式是从下标为 1 开始计算,对于类似 C 的数组从 0 开始,那么,计算时只需要在位置上加 1 。

struct bit {
	int size;
	int *data;
};

static int low_bit(int x)
{
	return x & -x;
}

void bit_destroy(struct bit *bit)
{
	if (bit == NULL)
		return;
	if (bit->data != NULL)
		free(bit->data);
	free(bit);
}

struct bit *bit_create(int size)
{
	struct bit *bit;

	bit = malloc(sizeof(*bit));
	if (bit == NULL)
		return NULL;

	bit->size = size;
	bit->data = calloc(size, sizeof(int));
	if (bit->data == NULL) {
		free(bit);
		return NULL;
	}

	return bit;
}

void bit_init(struct bit *bit, const int *array, int size)
{
	int i, j;

	for (i = 0; i < size; i++) {
		bit->data[i] = array[i];
		for (j = i - low_bit(i + 1) + 1; j < i; j++) {
			bit->data[i] += array[j];
		}
	}
}

void bit_add(struct bit *bit, int pos, int delta)
{
	while (pos < bit->size) {
		bit->data[pos] += delta;
		pos += low_bit(pos + 1);
	}
}

// get sum of range [0, pos]
int bit_query(struct bit *bit, int pos)
{
	int sum = 0;
	while (pos >= 0) {
		sum += bit->data[pos];
		pos -= low_bit(pos + 1);
	}
	return sum;
}

int bit_query_range(struct bit *bit, int start, int end)
{
    // [start, end]
    return bit_query(bit, end) - bit_query(bit, start - 1);
    // (start, end]
    //return bit_query(bit, end) - bit_query(bit, start);
    // [start, end)
    //return bit_query(bit, end - 1) - bit_query(bit, start - 1);
    // (start, end)
    //return bit_query(bit, end - 1) - bit_query(bit, start);
}

int bit_get(struct bit *bit, int pos)
{
	int ret = bit->data[pos];
	int end = pos - low_bit(pos + 1);

	pos--;
	while (pos > end) {
		ret -= bit->data[pos];
		pos -= low_bit(pos + 1);
	}
	return ret;
}