「题解 P5072」Ynoi2015 盼君勿忘

【幸好你在失去一切之前, 回到了这里】

【喜悦和悲伤】

【还有喜欢某个人的情绪】

【现在依旧还残存着一些吧?】

这题是一道偏向思维的莫队题, 最主要的问题是将传统莫队模板题中的“区间 $[l,r]$ 中不同的数之和”换为了“区间中每个子序列不同的数之和”, 显然, 我们不能对每个子序列分别处理, 那么我们则考虑从整个区间对子序列的贡献方面着手

若整个区间中某个数 $x$ 出现了 $k$ 次, 那么它产生的贡献为:

$\qquad ans=x\times$ (子序列个数-不含该数的子序列数)

$\qquad \qquad = x\times(2^{len}-2^{len-k})$

如果对于每个不同的 $x$ 进行计算, 那么时间复杂度至少是 $O(mn)$ 级别的, 但如果对于不同的 $k$ 计算, 即使在最坏情况下, 即: $n=\frac{k(k-1)}{2}$, 在使用双向链表的情况下, 我们会发现每次查询的总时间复杂度约为 $O(m\sqrt{n})$, 再套上莫队的 $O(n\sqrt{n})$, 是可以接受 $10^5$ 的数据量的.

但是在具体代码实现的过程中, 我们会发现由于对于每一次询问, $p$ 各不相同, 所以无法一次预处理, 而根号的复杂度说明即使再套上一个快速幂的 $O(logn)$ 也将无法接受, 所以可以考虑使用我看题解才知道的光速幂, 我们注意到:

$\qquad 2^k = 2^{\frac{k}{\sqrt{n}}\times \sqrt{n}}=2^{\lfloor \frac{k}{\sqrt{n}}\rfloor \times \sqrt{n}}\times 2^{k\ mod\sqrt{n}}$

因此, 可以对每次询问预处理出两个数组 $pow1,pow2$, 分别储存 $2^1,2^2,2^3,…,2^{\sqrt{n}}$ 与 $2^{\sqrt{n}}, 2^{2\sqrt{n}},…,2^n$, 那么在枚举 $k$ 计算时可以做到 $O(1)$ 得到2的整数次幂取模结果

当然, 最后还是要记得使用一些防毒瘤卡常 $trick$: 1.fread读入 (其实并没有优化太多) 2.手写双向链表3.各种位运算优化莫队 …

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int MAXN = 4e5 + 5;

class QuickRead
{

public:
    inline char gc()
    {
        static const int IN_LEN = 1 << 18 | 1;
        static char buf[IN_LEN], *s, *t;
        return (s == t) && (t = (s = buf) + fread(buf, 1, IN_LEN, stdin)), s == t ? -1 : *s++;
    }

    template <typename _Tp>
    inline QuickRead & operator >> (_Tp & num)
    {
        static char ch, sgn;
        ch = gc();
        while (!isdigit(ch))
        {
            if (ch == -1)
                return *this;
            sgn |= (ch == '-');
            ch = gc();
        }
        num = 0;
        while (isdigit(ch))
        {
            num = (num << 1) + (num << 3) + (ch ^ 48);
            ch = gc();
        }
        sgn && (num = -num);
        return *this;
    }
} qin;

class LinkList
{

public:
    struct Node {
        int pre, nxt;
    } node[MAXN];
    int las;

    LinkList() { las = 0; }
    inline void insert(int x)
    {
        node[las].nxt = x;
        node[x].pre = las;
        las = x;
    }

    inline void erase(int x)
    {
        if (x ^ las)
        {
            node[node[x].pre].nxt = node[x].nxt;
            node[node[x].nxt].pre = node[x].pre;
        }
        else las = node[x].pre;
        node[x].pre = node[x].nxt = 0;
    }
} lis;

struct Ask {
    int l, r, pos, num;
    ll p;
    bool operator < (const Ask & x) const
    {
        if (pos != x.pos)
            return pos < x.pos;
        else {
            if (pos & 1)
                return r < x.r;
            else return r > x.r;
        }
    }
} query[MAXN];

inline void inc(ll & x, ll y, ll p) { x = (x + y) % p; }

int n, q, len;
ll data[MAXN], ans[MAXN];
int a[MAXN], cnt[MAXN];
ll sum[MAXN];
pair<int, int> b[MAXN];

inline void add(int num)
{
    sum[cnt[num]] -= data[num];
    if (!sum[cnt[num]])
        lis.erase(cnt[num]);
    cnt[num]++;
    if (!sum[cnt[num]])
        lis.insert(cnt[num]);
    sum[cnt[num]] += data[num];
}
inline void erase(int num)
{
    sum[cnt[num]] -= data[num];
    if (!sum[cnt[num]])
        lis.erase(cnt[num]);
    cnt[num]--;
    if (!sum[cnt[num]])
        lis.insert(cnt[num]);
    sum[cnt[num]] += data[num];
}

ll pow1[MAXN], pow2[MAXN];
/*  pow1[i] refers to 2^i,
    pow2[i] refers to 2^(i*sqrt(x)),
    so that 2^k = 2^[(k/sqrt(x))*sqrt(x)]*2^(k mod sqrt(x))
*/
inline void pre_pow(ll p)
{
    pow1[0] = 1;
    for (int i = 1; i <= len; i++)
        pow1[i] = 1LL * pow1[i - 1] * 2 % p;
    pow2[0] = 1;
    for (int i = 1; i <= len; i++)
        pow2[i] = 1LL * pow2[i - 1] * pow1[len] % p;
}
inline ll get_pow(ll k, ll p) { return 1LL * pow2[k / len] * pow1[k % len] % p; }

int main()
{
    qin >> n >> q;
    len = sqrt(n);
    for (int i = 1; i <= n; i++)
    {
        qin >> b[i].first;
        b[i].second = i;
    }
    sort(b + 1, b + n + 1);
    for (int i = 1, j = 0; i <= n; i++)
    {
        if (b[i].first != b[i - 1].first)
            j++;
        a[b[i].second] = j;
        data[j] = (ll)b[i].first;
    }
    for (int i = 1; i <= q; i++)
    {
        qin >> query[i].l >> query[i].r >> query[i].p;
        query[i].pos = query[i].l / len;
        query[i].num = i;
    }
    sort(query + 1, query + q + 1);
    int l = 1, r = 0;
    for (int i = 1; i <= q; i++)
    {
        pre_pow(query[i].p);
        while (l < query[i].l)
            erase(a[l++]);
        while (l > query[i].l)
            add(a[--l]);
        while (r < query[i].r)
            add(a[++r]);
        while (r > query[i].r)
            erase(a[r--]);
        int siz = query[i].r - query[i].l + 1;
        ll MOD = query[i].p;
        for (int j = lis.node[0].nxt; j; j = lis.node[j].nxt)
            inc(ans[query[i].num], sum[j] * (get_pow(siz, MOD) - get_pow(siz - j, MOD) + MOD) % MOD, MOD);
    }
    for (int i = 1; i <= q; i++)
        printf("%lld\n", ans[i]);
    return 0;
}