题目大意 给定 $n$($2 \le n \le 3 \times 10^5$)个数 $a_i$($0 \le a_i \le 2^{30}-1$)和 $k$($0 \le k \le 2^{30}-1$),挑出尽可能多的数,并且使得它们的任意两两异或和不小于 $k$,输出数量和挑选的数。
题目链接
思路 首先需要推出一个非常重要的结论:
从 $n$ 个数中挑出两个数,使得它们的异或和最小,那么这两个数一定是把这 $n$ 个数排序后相邻的数。
证明:假设有三个数 $a < b < c$:
$a \oplus c < b \oplus c$:那么我们比较一下 $b, c$,在二进制下从最高位往低位找,出现的第一个不同的位数,因为 $a < b < c$,所以这一位肯定是 $c$ 这一位为 $1$,$a, b$ 这一位为 $0$。所以可以推出 $a \oplus b < a \oplus c$。
$a \oplus c > b \oplus c$:我们只需要比较一下 $a \oplus b, b \oplus c$ 即可,最小值还是出现在相邻的数中的。
有了这个结论后,我们把 $n$ 个数从小到大排序后,从小到大遍历。假如我们已经找到了一个集合 $s$,遍历到了新的数 $a_i$,想要判断它能否放入这个集合中,只需判断 $a_i \oplus \max(s) \ge k$。
因此,我们设 $dp_i$ 为以 $a_i$ 为最大值的集合中最多有多少个数,可以得到状态转移方程:
同时题目让我们输出每个被选中的数,所以我们在更新 $dp_j$ 时记录 $pre[j] = i$。
但是,直接进行状态转移的时间复杂度是 $O(n^2)$,会超时。由于是异或操作,我们想到用 trie 树来进行加速。
从最高位开始往树中插入。在 trie 中查询与 $a_i$ 异或值 $\ge k$ 的数时,设当前 bit 位和 root 为当前节点:
如果这一位 $k$ 为 $1$,那么我们只能往 $trie[root][bit \oplus 1]$ 走(XOR这一位才能是1)。
如果这一位 $k$ 为 $0$,那么有两种选择:
选择 $trie[root][bit \oplus 1]$:那么异或和就一定大于 $k$ 了,直接更新答案。
选择 $trie[root][bit]$:继续往下寻找更新答案。
计算完 $dp_i$ 后,将 $a_i$ 插入树中,同时更新 $f[root]$,表示经过这个点的数中,集合最大的数是多少。
代码 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 #include <cstdio> #include <iostream> #include <algorithm> using namespace std;const int maxN = 3e5 + 5 ;struct Num { int val, index; }a[maxN]; bool cmp (Num x, Num y) { return x.val < y.val; } int n, k, pre[maxN * 31 ], trie[maxN * 31 ][2 ], f[maxN * 31 ], d[maxN * 31 ], cnt = 1 ;int main () { scanf ("%d%d" , &n, &k); for (int i = 1 ; i <= n; ++i) { scanf ("%d" , &a[i].val); a[i].index = i; } sort (a + 1 , a + 1 + n, cmp); for (int i = 1 ; i <= n; ++i) { int ans = 0 , root = 1 ; for (int j = 30 ; j >= 0 && root; --j) { int bit = (a[i].val >> j) & 1 , rev = bit ^ 1 ;; if ((k >> j) & 1 ) root = trie[root][rev]; else { if (d[ans] < d[f[trie[root][rev]]]) ans = f[trie[root][rev]]; root = trie[root][bit]; } if (j == 0 && d[ans] < d[f[root]]) ans = f[root]; } d[i] = d[ans] + 1 ; pre[i] = ans; root = 1 ; for (int j = 30 ; j >= 0 ; j--) { int bit = (a[i].val >> j) & 1 ; if (!trie[root][bit]) trie[root][bit] = ++cnt; root = trie[root][bit]; if (d[f[root]] < d[i]) f[root] = i; } } int ans = 1 ; for (int i = 2 ; i <= n; ++i) if (d[ans] < d[i]) ans = i; if (d[ans] < 2 ) printf ("-1\n" ); else { printf ("%d\n" , d[ans]); for (int i = ans; i != 0 ; i = pre[i]) printf ("%d " , a[i].index); } return 0 ; }