引入 FWT
快速沃尔什变换(FWT)是解决这样一类卷积问题:
$$ c_i=\sum_{i=j\odot k}a_jb_k $$
其中,$\odot$ 是位运算的一种。举个例子,给定数列 $a,b$,求:
$$ c_i=\sum_{j\oplus k=i} a_jb_k $$
FWT 的思想
看到 FWT 的名字,我们可以联想到之前学过的 FFT(很可惜,我没有写过 FFT 的笔记,所以没有链接),先看看 FFT 的原理:
- 把 $a,b$ 变换为 $A,B$,$O(n\log n)$;
- 通过 $C_i=A_iB_i$ 计算,$O(n)$;
- 把 $C$ 变换回 $c$,$O(n\log n)$。
综上,时间复杂度是 $O(n\log n)$ 的。
在 FFT 中,我们构造了 $A,B$ 为 $a,b$ 的点值表示法,这么做满足 $C_i=A_iB_i$ 且容易变换。
其实 FWT 的思想也是一样的,主要也是需要构造 $A,B$,使得其满足 $C_i=A_iB_i$ 且可以快速变换。下面我们举 $\cup$(按位或)、$\cap$(按位与)和 $\oplus$(按位异或)为例。
因为数列长度是 $2$ 的幂会更好处理,所以下文认为数列长度为 $2^n$。
按位或
$$ c_i=\sum_{j\cup k=i} a_jb_k $$
我们可以构造 $A_i=\sum_{i\cup j=i} a_i$。看看为什么需要这么构造。
首先,它满足 $C_i=A_iB_i$:
$$ \begin{align} A_iB_i&=(\sum_{i\cup j=i} a_j)(\sum_{i\cup k=i} b_k) \\\ &=\sum_{i\cup j=i}\sum_{i\cup k=i}a_jb_k \\\ &=\sum_{i\cup j=i}\sum_{i\cup k=i}a_jb_k \\\ &=\sum_{i\cup(j\cup k)=i}a_jb_k \\\ &= C_i \end{align} $$
其次,它可以快速变换。举顺变换的例子。类比 FFT 的步骤,我们采用分治的方法来处理它。假设目前考虑到第 $i$ 位,其中 $A_0$ 和 $A_1$ 是 $i-1$ 位分治的结果:
$$ A=\text{merge}(A_0, A_0+A_1) $$
其中,$A_0$ 是数列 $A$ 的左半部分,$A_1$ 是 $A$ 的右半部分。$\text{merge}$ 函数就是把两个数列像拼接字符串一样拼接起来。$+$ 则是将两个数列对应相加。
这么做为什么是正确的呢?容易发现,$A_0$ 恰好是当前处理到的二进制位为 $0$ 的子数列,$A_1$ 则是当前处理到的二进制位为 $1$ 的子数列。若当前位为 $0$,则只能取二进制位为 $0$ 的子数列 $A_0$ 才能使得 $i\cup j=i$。而若当前位为 $1$,则两种序列都能取。
考虑逆变换,则是将加上的 $A_0$ 减回去:
$$ a=\text{merge}(a_0, a_1-a_0) $$
下面我们给出代码实现。容易发现顺变换和逆变换可以合并为一个函数,顺变换时 $\text{type}=1$,逆变换时 $\text{type}=-1$。
void OR(int a[], int n, int tp) {
for (int len = 2; len <= n; len <<= 1) {
for (int i = 0; i < n; i += len) {
for (int j = i; j < i + (len >> 1); j ++ ) {
add(a[j + (len >> 1)], a[j] * tp);
}
}
}
}
按位与
$$ c_i=\sum_{j\cap k=i} a_jb_k $$
同理构造 $A_i=\sum_{i\cap j=i} a_i$。$C_i=A_iB_i$ 的正确性不证了。
容易发现,$A_0$ 恰好是当前处理到的二进制位为 $0$ 的子数列,$A_1$ 则是当前处理到的二进制位为 $1$ 的子数列。若当前位为 $1$,则只能取二进制位为 $1$ 的子数列 $A_0$ 才能使得 $i\cap j=i$。而若当前位为 $0$,则两种序列都能取。
$$ A=\text{merge}(A_0+A_1, A_1) $$
$$ a=\text{merge}(a_0 - a_1, a_1) $$
下面我们给出代码实现。顺变换时 $\text{type}=1$,逆变换时 $\text{type}=-1$。
void AND(int a[], int n, int tp) {
for (int len = 2; len <= n; len <<= 1) {
for (int i = 0; i < n; i += len) {
for (int j = i; j < i + (len >> 1); j ++ ) {
add(a[j], a[j + (len >> 1)] * tp);
}
}
}
}
按位异或
发现异或有点难搞,对此我们引入一个新的运算符 $\circ$。定义 $x\circ y=\text{popcnt}(x\cap y)\bmod 2$,其中 $\text{popcnt}$ 表示二进制下 $1$ 的个数,并重申一下 $\cap$ 表示按位与。
不用慌,我们也不需要你真正实现一个 $\text{popcnt}$,它仅仅只是作为一个理解的辅助罢了。
我们发现它满足 $(x\circ y)\oplus (x\circ z)=x\circ(y\oplus z)$。(重申一下 $\oplus$ 表示按位异或)
感性证明:发现这个新的运算符 $\circ$ 其实就是 $x$ 与 $y$ 相同位数的奇偶性。若 $(x\circ y)\oplus (x\circ z)=0$,则 $x$ 与 $y$、$x$ 与 $z$ 相同位数个数奇偶性相同,所以 $y\oplus z$ 和 $x$ 相同位数个数奇偶性也是相同的 ;若 $(x\circ y)\oplus (x\circ z)=1$,则 $x$ 与 $y$、$x$ 与 $z$ 相同位数个数奇偶性不同,所以 $y\oplus z$ 和 $x$ 相同位数个数奇偶性也是不同的。
设 $A_i=\sum_{i\circ j=0}a_j-\sum_{i\circ j=1}a_j$。我们来证一下 $C_i=A_iB_i$ 的正确性:
$$ \begin{align} A_iB_i&=(\sum_{i\circ j=0}a_j-\sum_{i\circ j=1}a_j)(\sum_{i\circ k=0}b_k-\sum_{i\circ k=1}b_k) \\\ &=(\sum_{i\circ j=0}a_j\sum_{i\circ k=0}b_k+\sum_{i\circ j=1}a_j\sum_{i\circ k=1}b_k)-(\sum_{i\circ j=0}a_j\sum_{i\circ k=1}b_k+\sum_{i\circ j=1}a_j\sum_{i\circ k=0}b_k) \\\ &=\sum_{(j\oplus k)\circ i=0}a_jb_k-\sum_{(j\oplus k)\circ i=1}a_jb_k \\\ &=C_i \end{align} $$
来看看怎么快速计算 $A,B$ 的值,依旧是分治:
对于 $i$ 在当前位为 $0$ 的子数列 $A_0$,进行 $\circ$ 运算时发现它和 $0$ 计算或和 $1$ 计算结果都不会变(因为 $0\cap 0=0,0\cap1=0$),所以 $A_i=\sum_{i\circ j=0}a_j-\sum_{i\circ j=1}a_j$ 中的 $\sum_{i\circ j=1}a_j=0$。
对于 $i$ 在当前位为 $1$ 的子数列 $A_1$,进行 $\circ$ 运算时发现它和 $0$ 计算结果是 $0$,和 $1$ 计算结果是 $1$(因为 $1\cap 0=0,1\cap1=1$)。
综上,有:
$$ A=\text{merge}((A_0+A_1)-0, A_0-A_1) $$
也就是:
$$ A=\text{merge}(A_0+A_1, A_0-A_1) $$
逆变换易得:
$$ a=\text{merge}(\frac{a_0+a_1}{2}, \frac{a_0-a_1}{2}) $$
给出代码,顺变换时 $\text{type}=1$,逆变换时 $\text{type}=\frac{1}{2}$。
void XOR(int a[], int n, int tp) {
for (int len = 2; len <= n; len <<= 1) {
int k = (len >> 1);
for (int i = 0; i < n; i += len) {
for (int j = i; j < i + k; j ++ ) {
int ls = a[j], rs = a[j + k];
a[j] = 1ll * (ls + rs) % mod * tp % mod;
a[j + k] = 1ll * (ls - rs + mod) % mod * tp % mod;
}
}
}
}
性质
类似 DFT,我们构造出的 FWT 也是一个线性变换。所以我们能得出如下性质:
$$ \begin{aligned} \mathrm{FWT}(A+B) & =\mathrm{FWT}(A)+\mathrm{FWT}(B) \\ \mathrm{FWT}(\lambda \cdot A) & =\lambda \cdot \mathrm{FWT}(A) \end{aligned} $$
「洛谷 P4717」 【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT)
求 $\cup$、$\cap$、$\oplus$ 的三种卷积。$n\le17$。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
#define fir first
#define sec second
#define ep emplace
#define eb emplace_back
#define lowbit(x) ((x) & (-(x)))
#define add(a, b) ((a) = ((a) + (b) >= mod))
inline int read() {
int x = 0, f = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
return x * f;
}
const int N = 131075;
const int mod = 998244353;
int lgn, n, a[N], b[N], c0[N], c1[N];
#define add(a, b) { \
if (((a) = (a) + (b)) >= mod) (a) -= mod; \
if (a < 0) a += mod; \
}
void OR(int a[], int n, int tp) {
for (int len = 2; len <= n; len <<= 1) {
for (int i = 0; i < n; i += len) {
for (int j = i; j < i + (len >> 1); j ++ ) {
add(a[j + (len >> 1)], a[j] * tp);
}
}
}
}
void AND(int a[], int n, int tp) {
for (int len = 2; len <= n; len <<= 1) {
for (int i = 0; i < n; i += len) {
for (int j = i; j < i + (len >> 1); j ++ ) {
add(a[j], a[j + (len >> 1)] * tp);
}
}
}
}
void XOR(int a[], int n, int tp) {
for (int len = 2; len <= n; len <<= 1) {
int k = (len >> 1);
for (int i = 0; i < n; i += len) {
for (int j = i; j < i + k; j ++ ) {
int ls = a[j], rs = a[j + k];
a[j] = 1ll * (ls + rs) % mod * tp % mod;
a[j + k] = 1ll * (ls - rs + mod) % mod * tp % mod;
}
}
}
}
int main() {
lgn = read(), n = 1 << lgn;
for (int i = 0; i < n; i ++ ) a[i] = c0[i] = read();
for (int i = 0; i < n; i ++ ) b[i] = c1[i] = read();
OR(c0, n, 1), OR(c1, n, 1);
for (int i = 0; i < n; i ++ ) c0[i] = 1ll * c0[i] * c1[i] % mod;
OR(c0, n, -1);
for (int i = 0; i < n; i ++ ) {
printf("%d ", c0[i]);
c0[i] = a[i], c1[i] = b[i];
}
puts(""), AND(c0, n, 1), AND(c1, n, 1);
for (int i = 0; i < n; i ++ ) c0[i] = 1ll * c0[i] * c1[i] % mod;
AND(c0, n, -1);
for (int i = 0; i < n; i ++ ) {
printf("%d ", c0[i]);
c0[i] = a[i], c1[i] = b[i];
}
puts(""), XOR(c0, n, 1), XOR(c1, n, 1);
for (int i = 0; i < n; i ++ ) c0[i] = 1ll * c0[i] * c1[i] % mod;
XOR(c0, n, 499122177);
for (int i = 0; i < n; i ++ ) printf("%d ", c0[i]);
return 0;
}
子集卷积
给定两个长度为 $2^n$($n \le 20$)的序列 $f_0,f_1,\cdots,f_{2^n-1}$ 和 $b_0,b_1,\cdots,b_{2^n-1}$,你需要求出一个序列 $g_0,g_1,\cdots,g_{2^n-1}$,其中 $h_k$ 满足:
$$ h_k=\sum_{{i\cap j=0},~{i \cup j=k}}f_ig_j $$
对于 $i\cup j=k$ 这一限制可以直接用 FWT/FMT 解决。
容易发现 $i\cap j=0$ 等价于 $|i|+|j|=|i\cup j|$,我们引入 占位多项式:
$$ F_{i, j}= \begin{cases}f_j & |j|=i \\ 0 & |j| \neq i\end{cases} $$
求 $H_{i, j}$ 时,可以枚举拼成 $j$ 的两个集合大小 $i_1$ 和 $i_2$ ,我们有:
$$ H_{i, j}=\sum_{i_1+i_2=i} \sum_{a \text { or } b} F_{i_1, a} G_{i_2, b} $$
也就是:
$$ H_i=\sum_{i_1+i_2=i} F_{i_1} * G_{i_2} $$
其中 $*$ 为或卷积,由于 FWT 是线性变换,因此一个 $F*G$ 在 $\sum$ 后的结果 $H$ 仍然是满足 FWT 性质的。
因此我们再做 IFWT 还原即可。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
#define fir first
#define sec second
#define ep emplace
#define eb emplace_back
#define lowbit(x) ((x) & (-(x)))
#define add(a, b) { \
if (((a) = (a) + (b)) >= mod) (a) -= mod; \
if ((a) < 0) (a) += mod; \
}
inline int read() {
int x = 0, f = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
return x * f;
}
const int N = 1048580;
const int mod = 1e9 + 9;
int n, m, a[21][N], b[21][N], h[21][N];
void OR(int a[], int n, int tp) {
for (int len = 2; len <= n; len <<= 1) {
for (int i = 0; i < n; i += len) {
for (int j = i; j < i + (len >> 1); j ++ ) {
add(a[j + (len >> 1)], a[j] * tp);
}
}
}
}
int main() {
n = read(), m = 1 << n;
for (int i = 0; i < m; i ++ ) a[__builtin_popcount(i)][i] = read();
for (int i = 0; i < m; i ++ ) b[__builtin_popcount(i)][i] = read();
for (int i = 0; i <= n; i ++ ) OR(a[i], m, 1), OR(b[i], m, 1);
for (int i = 0; i <= n; i ++ ) {
for (int j = 0; j <= i; j ++ ) {
for (int k = 0; k < m; k ++ ) {
add(h[i][k], 1ll * a[j][k] * b[i - j][k] % mod);
}
}
}
for (int i = 0; i <= n; i ++ ) OR(h[i], m, -1);
for (int i = 0; i < m; i ++ ) printf("%d ", h[__builtin_popcount(i)][i]);
return 0;
}
参考资料: