树状数组也就是 Binary Indexed Tree 或以作者名字命名为 Fenwick Tree,最早由 Peter M. Fenwick 在 1994 年发表的 A New Data Structure for Cumulative Frequency Tables,起初是为了解决 Cumulative Frequency 的计算,现在也用于高效计算数列的前缀和、区间和。
这里详细介绍其原理以及实现方式。
简介
如果通过数组实现 前缀和
以及 区间和
的计算,那么查询可以达到 O(1)
,而更新的话时间复杂度就是 O(n)
了,例如更新第一个元素,那么后续所有的值都需要进行更新。
而通过树状数组,可以达到 O(logN)
的更新以及查询,除了这里介绍的树状数组之外,还可以使用线段树求解。
基本原理
树状数组的树状显然是修饰词,实际上仍然是一个数组,只是利用下标以及二进制运算来维护元素间的父子关系,也就是 $parent = son + 2^k$ ,其中 k 是子节点下标对应二进制末尾 0 的个数。
如上图所示,A 和 B 都是数组,其中 A 是正常存储数据,而 B 则是树状数组,B4、B6、B7 都是 B8 的子节点,可以通过如下方式计算父节点。
4(0100) --> 4 + 2^2 = 8
6(0110) --> 6 + 2^1 = 8
7(0111) --> 7 + 2^0 = 8
所以,仅通过下标以及简单的位运算,那么就可以维护父子关系。
节点含义
仍以上图为例,其中奇数节点都是叶子节点,保存的是原数组相同下标保存的值,而偶数节点都是父节点,保存的是区间和,例如 B4 保存的是 B1 + B2 + B3 + A4 = A1 + A2 + A3 + A4 ;而 B5 只保存了 A5 的值。
所以,每个节点表示为某个区间的和,而区间的左边界是该节点最左侧叶子节点对应下标,而右边界则是自己的下标,例如 B8 表示 [1, 8]
的区间和;B6 表示 [5, 6]
的区间和。
所有节点表示的含义列举如下。
B1 = A1
B2 = B1 + A2 = A1 + A2
B3 = A3
B4 = B2 + B3 + A4 = A1 + A2 + A3 + A4
B5 = A5
B6 = B5 + A6 = A5 + A6
B7 = A7
B8 = B4 + B6 + B7 + A8 = A1 + A2 + A3 + A4 + A5 + A6 + A7 + A8
所以可以归纳出,左边界的下标一定是 $i - 2^k + 1$ ,其中 $i$ 为节点下标,右边界下标为 $i$ ,也就是。
8(1000) --> 8 - 2^3 + 1 = 1
7(0111) --> 7 - 2^0 + 1 = 7
6(0110) --> 6 - 2^1 + 1 = 5
lowbit
那么这里的关键就是如何计算二进制默认 0 的个数,通过 lowbit(n)
来表示非负整数 n 表示最低位 1 及其后面所有 0 组成的二进制数值,其实现如下。
static int low_bit(int x)
{
return x & -x;
}
计算机中数值通过补码表示,这样可以将符号和数值统一处理,其中,正数的补码与原码相同,而负数的补码等于正数的原码每位取反再加 1 ,这样会使得正负数末尾的 0 数相同,那么相与后就是所需的结果。
这个实际上就是一个规则,可以简单那个数值验证下,例如 34 的二进制是 00100010
,其中 -34
对应的补码为 11011110
,所以相与后就是 2 。
实现
常用于对一个数组进行更新以及求前缀和,需要支持如下操作:
bit_add(pos, delta)
对第pos
元素进行更新。bit_query(pos)
查询前缀和,也就是区间[0, to]
的元素进行累加。
另外,还有一些扩展接口,常见如下。
bit_get(pos)
查询第pos
的元素值。bit_query_range(from, to)
查询区间[from, to]
的区间和。
如下是针对 C 的实现,上述的公式是从下标为 1 开始计算,对于类似 C 的数组从 0 开始,那么,计算时只需要在位置上加 1 。
struct bit {
int size;
int *data;
};
static int low_bit(int x)
{
return x & -x;
}
void bit_destroy(struct bit *bit)
{
if (bit == NULL)
return;
if (bit->data != NULL)
free(bit->data);
free(bit);
}
struct bit *bit_create(int size)
{
struct bit *bit;
bit = malloc(sizeof(*bit));
if (bit == NULL)
return NULL;
bit->size = size;
bit->data = calloc(size, sizeof(int));
if (bit->data == NULL) {
free(bit);
return NULL;
}
return bit;
}
void bit_init(struct bit *bit, const int *array, int size)
{
int i, j;
for (i = 0; i < size; i++) {
bit->data[i] = array[i];
for (j = i - low_bit(i + 1) + 1; j < i; j++) {
bit->data[i] += array[j];
}
}
}
void bit_add(struct bit *bit, int pos, int delta)
{
while (pos < bit->size) {
bit->data[pos] += delta;
pos += low_bit(pos + 1);
}
}
// get sum of range [0, pos]
int bit_query(struct bit *bit, int pos)
{
int sum = 0;
while (pos >= 0) {
sum += bit->data[pos];
pos -= low_bit(pos + 1);
}
return sum;
}
int bit_query_range(struct bit *bit, int start, int end)
{
// [start, end]
return bit_query(bit, end) - bit_query(bit, start - 1);
// (start, end]
//return bit_query(bit, end) - bit_query(bit, start);
// [start, end)
//return bit_query(bit, end - 1) - bit_query(bit, start - 1);
// (start, end)
//return bit_query(bit, end - 1) - bit_query(bit, start);
}
int bit_get(struct bit *bit, int pos)
{
int ret = bit->data[pos];
int end = pos - low_bit(pos + 1);
pos--;
while (pos > end) {
ret -= bit->data[pos];
pos -= low_bit(pos + 1);
}
return ret;
}