Loading... ## 快速沃尔什变换解决的卷积问题 快速沃尔什变换(FWT)是解决这样一类卷积问题: $$ c_i=\sum_{i=j\odot k}a_jb_k $$ 其中,$\odot$ 是位运算的一种。举个例子,给定数列 $a,b$,求: $$ c_i=\sum_{j\oplus k=i} a_jb_k $$ ## FWT 的思想 看到 FWT 的名字,我们可以联想到之前学过的 FFT(很可惜,我没有写过 FFT 的笔记,所以没有链接),先看看 FFT 的原理: 1. 把 $a,b$ 变换为 $A,B$,$O(n\log n)$; 2. 通过 $C_i=A_iB_i$ 计算,$O(n)$; 3. 把 $C$ 变换回 $c$,$O(n\log n)$。 综上,时间复杂度是 $O(n\log n)$ 的。 在 FFT 中,我们构造了 $A,B$ 为 $a,b$ 的点值表示法,这么做满足 $C_i=A_iB_i$ 且容易变换。 其实 FWT 的思想也是一样的,主要也是需要构造 $A,B$,使得其满足 $C_i=A_iB_i$ 且可以快速变换。下面我们举 $\cup$(按位或)、$\cap$(按位与)和 $\oplus$(按位异或)为例。 因为数列长度是 $2$ 的幂会更好处理,所以下文认为数列长度为 $2^n$。 ### 按位或 $$ c_i=\sum_{j\cup k=i} a_jb_k $$ 我们可以构造 $A_i=\sum_{i\cup j=i} a_i$。看看为什么需要这么构造。 首先,它满足 $C_i=A_iB_i$: $$ \begin{align} A_iB_i&=(\sum_{i\cup j=i} a_j)(\sum_{i\cup k=i} b_k) \\\ &=\sum_{i\cup j=i}\sum_{i\cup k=i}a_jb_k \\\ &=\sum_{i\cup j=i}\sum_{i\cup k=i}a_jb_k \\\ &=\sum_{i\cup(j\cup k)=i}a_jb_k \\\ &= C_i \end{align} $$ 其次,它可以快速变换。举顺变换的例子。类比 FFT 的步骤,我们采用分治的方法来处理它。假设目前考虑到第 $i$ 位,其中 $A_0$ 和 $A_1$ 是 $i-1$ 位分治的结果: $$ A=\text{merge}(A_0, A_0+A_1) $$ 其中,$A_0$ 是数列 $A$ 的左半部分,$A_1$ 是 $A$ 的右半部分。$\text{merge}$ 函数就是把两个数列像拼接字符串一样拼接起来。$+$ 则是将两个数列对应相加。 这么做为什么是正确的呢?容易发现,$A_0$ 恰好是当前处理到的二进制位为 $0$ 的子数列,$A_1$ 则是当前处理到的二进制位为 $1$ 的子数列。若当前位为 $0$,则只能取二进制位为 $0$ 的子数列 $A_0$ 才能使得 $i\cup j=i$。而若当前位为 $1$,则两种序列都能取。 --- 考虑逆变换,则是将加上的 $A_0$ 减回去: $$ a=\text{merge}(a_0, a_1-a_0) $$ 下面我们给出代码实现。容易发现顺变换和逆变换可以合并为一个函数,顺变换时 $\text{type}=1$,逆变换时 $\text{type}=-1$。 ```cpp void OR(int a[], int n, int tp) { for (int len = 2; len <= n; len <<= 1) { for (int i = 0; i < n; i += len) { for (int j = i; j < i + (len >> 1); j ++ ) { add(a[j + (len >> 1)], a[j] * tp); } } } } ``` ### 按位与 $$ c_i=\sum_{j\cap k=i} a_jb_k $$ 同理构造 $A_i=\sum_{i\cap j=i} a_i$。$C_i=A_iB_i$ 的正确性不证了。 容易发现,$A_0$ 恰好是当前处理到的二进制位为 $0$ 的子数列,$A_1$ 则是当前处理到的二进制位为 $1$ 的子数列。若当前位为 $1$,则只能取二进制位为 $1$ 的子数列 $A_0$ 才能使得 $i\cap j=i$。而若当前位为 $0$,则两种序列都能取。 $$ A=\text{merge}(A_0+A_1, A_1) $$ $$ a=\text{merge}(a_0 - a_1, a_1) $$ --- 下面我们给出代码实现。顺变换时 $\text{type}=1$,逆变换时 $\text{type}=-1$。 ```cpp void AND(int a[], int n, int tp) { for (int len = 2; len <= n; len <<= 1) { for (int i = 0; i < n; i += len) { for (int j = i; j < i + (len >> 1); j ++ ) { add(a[j], a[j + (len >> 1)] * tp); } } } } ``` ### 按位异或 发现异或有点难搞,对此我们引入一个新的运算符 $\circ$。定义 $x\circ y=\text{popcnt}(x\cap y)\bmod 2$,其中 $\text{popcnt}$ 表示二进制下 $1$ 的个数,并重申一下 $\cap$ 表示按位与。 不用慌,我们也不需要你真正实现一个 $\text{popcnt}$,它仅仅只是作为一个理解的辅助罢了。 我们发现它满足 $(x\circ y)\oplus (x\circ z)=x\circ(y\oplus z)$。(重申一下 $\oplus$ 表示按位异或) 感性证明:发现这个新的运算符 $\circ$ 其实就是 $x$ 与 $y$ 相同位数的奇偶性。若 $(x\circ y)\oplus (x\circ z)=0$,则 $x$ 与 $y$、$x$ 与 $z$ 相同位数个数奇偶性相同,所以 $y\oplus z$ 和 $x$ 相同位数个数奇偶性也是相同的 ;若 $(x\circ y)\oplus (x\circ z)=1$,则 $x$ 与 $y$、$x$ 与 $z$ 相同位数个数奇偶性不同,所以 $y\oplus z$ 和 $x$ 相同位数个数奇偶性也是不同的。 设 $A_i=\sum_{i\circ j=0}a_j-\sum_{i\circ j=1}a_j$。我们来证一下 $C_i=A_iB_i$ 的正确性: $$ \begin{align} A_iB_i&=(\sum_{i\circ j=0}a_j-\sum_{i\circ j=1}a_j)(\sum_{i\circ k=0}b_k-\sum_{i\circ k=1}b_k) \\\ &=(\sum_{i\circ j=0}a_j\sum_{i\circ k=0}b_k+\sum_{i\circ j=1}a_j\sum_{i\circ k=1}b_k)-(\sum_{i\circ j=0}a_j\sum_{i\circ k=1}b_k+\sum_{i\circ j=1}a_j\sum_{i\circ k=0}b_k) \\\ &=\sum_{(j\oplus k)\circ i=0}a_jb_k-\sum_{(j\oplus k)\circ i=1}a_jb_k \\\ &=C_i \end{align} $$ 来看看怎么快速计算 $A,B$ 的值,依旧是分治: 对于 $i$ 在当前位为 $0$ 的子数列 $A_0$,进行 $\circ$ 运算时发现它和 $0$ 计算或和 $1$ 计算结果都不会变(因为 $0\cap 0=0,0\cap1=0$),所以 $A_i=\sum_{i\circ j=0}a_j-\sum_{i\circ j=1}a_j$ 中的 $\sum_{i\circ j=1}a_j=0$。 对于 $i$ 在当前位为 $1$ 的子数列 $A_1$,进行 $\circ$ 运算时发现它和 $0$ 计算结果是 $0$,和 $1$ 计算结果是 $1$(因为 $1\cap 0=0,1\cap1=1$)。 综上,有: $$ A=\text{merge}((A_0+A_1)-0, A_0-A_1) $$ 也就是: $$ A=\text{merge}(A_0+A_1, A_0-A_1) $$ 逆变换易得: $$ a=\text{merge}(\frac{a_0+a_1}{2}, \frac{a_0-a_1}{2}) $$ 给出代码,顺变换时 $\text{type}=1$,逆变换时 $\text{type}=\frac{1}{2}$。 ```cpp void XOR(int a[], int n, int tp) { for (int len = 2; len <= n; len <<= 1) { int k = (len >> 1); for (int i = 0; i < n; i += len) { for (int j = i; j < i + k; j ++ ) { int ls = a[j], rs = a[j + k]; a[j] = 1ll * (ls + rs) % mod * tp % mod; a[j + k] = 1ll * (ls - rs + mod) % mod * tp % mod; } } } } ``` ### 性质 类似 DFT,我们构造出的 FWT 也是一个线性变换。所以我们能得出如下性质: $$ \begin{aligned} \mathrm{FWT}(A+B) & =\mathrm{FWT}(A)+\mathrm{FWT}(B) \\ \mathrm{FWT}(\lambda \cdot A) & =\lambda \cdot \mathrm{FWT}(A) \end{aligned} $$ **[「洛谷 P4717」 【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT)](https://www.luogu.com.cn/problem/P4717)** 求 $\cup$、$\cap$、$\oplus$ 的三种卷积。$n\le17$。 ```cpp #include <bits/stdc++.h> using namespace std; typedef long long ll; typedef pair<int, int> PII; #define fir first #define sec second #define ep emplace #define eb emplace_back #define lowbit(x) ((x) & (-(x))) #define add(a, b) ((a) = ((a) + (b) >= mod)) inline int read() { int x = 0, f = 1; char ch = getchar(); for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -1; for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48); return x * f; } const int N = 131075; const int mod = 998244353; int lgn, n, a[N], b[N], c0[N], c1[N]; #define add(a, b) { \ if (((a) = (a) + (b)) >= mod) (a) -= mod; \ if (a < 0) a += mod; \ } void OR(int a[], int n, int tp) { for (int len = 2; len <= n; len <<= 1) { for (int i = 0; i < n; i += len) { for (int j = i; j < i + (len >> 1); j ++ ) { add(a[j + (len >> 1)], a[j] * tp); } } } } void AND(int a[], int n, int tp) { for (int len = 2; len <= n; len <<= 1) { for (int i = 0; i < n; i += len) { for (int j = i; j < i + (len >> 1); j ++ ) { add(a[j], a[j + (len >> 1)] * tp); } } } } void XOR(int a[], int n, int tp) { for (int len = 2; len <= n; len <<= 1) { int k = (len >> 1); for (int i = 0; i < n; i += len) { for (int j = i; j < i + k; j ++ ) { int ls = a[j], rs = a[j + k]; a[j] = 1ll * (ls + rs) % mod * tp % mod; a[j + k] = 1ll * (ls - rs + mod) % mod * tp % mod; } } } } int main() { lgn = read(), n = 1 << lgn; for (int i = 0; i < n; i ++ ) a[i] = c0[i] = read(); for (int i = 0; i < n; i ++ ) b[i] = c1[i] = read(); OR(c0, n, 1), OR(c1, n, 1); for (int i = 0; i < n; i ++ ) c0[i] = 1ll * c0[i] * c1[i] % mod; OR(c0, n, -1); for (int i = 0; i < n; i ++ ) { printf("%d ", c0[i]); c0[i] = a[i], c1[i] = b[i]; } puts(""), AND(c0, n, 1), AND(c1, n, 1); for (int i = 0; i < n; i ++ ) c0[i] = 1ll * c0[i] * c1[i] % mod; AND(c0, n, -1); for (int i = 0; i < n; i ++ ) { printf("%d ", c0[i]); c0[i] = a[i], c1[i] = b[i]; } puts(""), XOR(c0, n, 1), XOR(c1, n, 1); for (int i = 0; i < n; i ++ ) c0[i] = 1ll * c0[i] * c1[i] % mod; XOR(c0, n, 499122177); for (int i = 0; i < n; i ++ ) printf("%d ", c0[i]); return 0; } ``` ## 子集卷积 给定两个长度为 $2^n$($n \le 20$)的序列 $f_0,f_1,\cdots,f_{2^n-1}$ 和 $b_0,b_1,\cdots,b_{2^n-1}$,你需要求出一个序列 $g_0,g_1,\cdots,g_{2^n-1}$,其中 $h_k$ 满足: $$ h_k=\sum_{{i\cap j=0},~{i \cup j=k}} f_i g_j $$ 对于 $i\cup j=k$ 这一限制可以直接用 FWT/FMT 解决。 容易发现 $i\cap j=0$ 等价于 $|i|+|j|=|i\cup j|$,我们引入 **占位多项式**: $$ F_{i, j}= \begin{cases}f_j & |j|=i \\ 0 & |j| \neq i\end{cases} $$ 求 $H_{i, j}$ 时,可以枚举拼成 $j$ 的两个集合大小 $i_1$ 和 $i_2$ ,我们有: $$ H_{i, j}=\sum_{i_1+i_2=i} \sum_{a \text { or } b} F_{i_1, a} G_{i_2, b} $$ 也就是: $$ H_i=\sum_{i_1+i_2=i} F_{i_1} * G_{i_2} $$ 其中 $*$ 为或卷积,由于 FWT 是线性变换,因此一个 $F*G$ 在 $\sum$ 后的结果 $H$ 仍然是满足 FWT 性质的。 因此我们再做 IFWT 还原即可。 **[「洛谷 P6097」【模板】子集卷积](https://www.luogu.com.cn/problem/P6097)** ```cpp #include <bits/stdc++.h> using namespace std; typedef long long ll; typedef pair<int, int> PII; #define fir first #define sec second #define ep emplace #define eb emplace_back #define lowbit(x) ((x) & (-(x))) #define add(a, b) { \ if (((a) = (a) + (b)) >= mod) (a) -= mod; \ if ((a) < 0) (a) += mod; \ } inline int read() { int x = 0, f = 1; char ch = getchar(); for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -1; for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48); return x * f; } const int N = 1048580; const int mod = 1e9 + 9; int n, m, a[21][N], b[21][N], h[21][N]; void OR(int a[], int n, int tp) { for (int len = 2; len <= n; len <<= 1) { for (int i = 0; i < n; i += len) { for (int j = i; j < i + (len >> 1); j ++ ) { add(a[j + (len >> 1)], a[j] * tp); } } } } int main() { n = read(), m = 1 << n; for (int i = 0; i < m; i ++ ) a[__builtin_popcount(i)][i] = read(); for (int i = 0; i < m; i ++ ) b[__builtin_popcount(i)][i] = read(); for (int i = 0; i <= n; i ++ ) OR(a[i], m, 1), OR(b[i], m, 1); for (int i = 0; i <= n; i ++ ) { for (int j = 0; j <= i; j ++ ) { for (int k = 0; k < m; k ++ ) { add(h[i][k], 1ll * a[j][k] * b[i - j][k] % mod); } } } for (int i = 0; i <= n; i ++ ) OR(h[i], m, -1); for (int i = 0; i < m; i ++ ) printf("%d ", h[__builtin_popcount(i)][i]); return 0; } ``` 参考资料: - ZnPdCo [博客园](https://www.cnblogs.com/znpdco/p/18172429) 最后修改:2025 年 08 月 15 日 © 允许规范转载 打赏 赞赏作者 支付宝微信 赞 2 赠人玫瑰,手有余香。您的赞赏是对我最大的支持!