BZOJ3684 大朋友和多叉树

文章目录

题目链接

题解

这题做的我心累
一开始怎么想都不知道这个怎么用生成函数搞
后来去看题解。。。好吧是我太菜了
设 $F(x)$ 为方案数的 $OGF$
于是有转移
$$ F(x)=x + \sum_{i\in D} F^i(x)$$
解释一下,前面的 $x$ 是指叶子节点,也就是权值为 $1$ 时的方案数(显然只有一种),后面代表枚举儿子个数,用乘法原理将每个儿子的选择方案合起来,然后用加法原理把不同儿子数之间的方案合起来

然而就算有人告诉我这个我也不会做
鬼才想得到接下来用拉格朗日反演

$$ A(x)=x-\sum_{i\in D}x^i $$
于是有
$$ A(F(x))=x+\sum_{i\in D}F^i(x)-\sum_{i\in D}F^i(x) = x $$
然后愉快的使用拉格朗日反演
$$ [x^n]F(x)=\frac{1}{n}[x^{n-1}]\bigg{(}\frac{x}{A(x)}\bigg{)}^n $$
然后变化一下后面的式子
$$ \bigg{(}\frac{x}{A(x)}\bigg{)}^n = e^{-n\ln{(\frac{A(x)}{x})}}$$
然后经过一系列数学变换就能在最终多项式的第 $n-1$ 位得到 $F(x)$ 的第 $n$ 位,问题就解决了

代码

我写的太慢了需要大力卡常,顺便发现 bzoj 原来是 32 位机
在 32 位机上 long long 比 int 慢了好多
对此做了个实验
64 位
64.png
32 位(编译时开启 m32 选项)
32.png

由于差别太大,所以建议使用 int 然后在做可能爆的运算时转 long long 再转回来

#define MAX_N 300007
#define G 7LL
#define MOD 950009857
#define register
typedef long long ll;

int N, T;
int d[MAX_N];
int g[MAX_N], f[MAX_N], a[MAX_N], b[MAX_N], c[MAX_N], tmp[MAX_N], tmp2[MAX_N];
int r[MAX_N];

inline void input () {
    read(N), read(T);
    for (int i = 1; i <= T; ++i)
        read(d[i]);
}

inline int quick_pow (int x, int p) {
    int ans = 1;
    while (p) {
        if (p & 1) ans = (int)((ll)ans * (ll)x % MOD);
        p >>= 1;
        x = (int)((ll)x * (ll)x % MOD);
    }
    return ans % MOD;
}

inline void NTT (int *A, int M, int type) {
    for (int i = 0; i < M; ++i)
        if (i < r[i])
            std::swap(A[i], A[r[i]]);
    for (int i = 1; i < M; i <<= 1) {
        int gn = quick_pow(G, (MOD - 1) / (i << 1));
        if (type == -1)
            gn = quick_pow(gn, MOD - 2);
        for (int j = 0, L = i << 1; j < M; j += L) {
            int g = 1, x, y;
            for (int k = 0; k < i; ++k) {
                x = A[j + k], y = (int)((ll)A[j + i + k] * (ll)g % MOD);
                A[j + k] = (x + y) % MOD, A[j + i + k] = (x - y + MOD) % MOD;
                g = (int)((ll)g * (ll)gn % MOD);
            }
        }
    }
    if (type == 1) return;
    for (int i = 0; i < M; ++i)
        A[i] = (int)((ll)A[i] * (ll)quick_pow(M, MOD - 2) % MOD);
}

void trans (int deg, int *a, int *b) {
    if (deg == 1) {
        b[0] = quick_pow(a[0], MOD - 2);
        return;
    }
    trans((deg + 1) >> 1, a, b);
    int M = 1, len = 0;
    while (M < (deg << 1)) M <<= 1, ++len;
    for (int i = 0; i < M; ++i)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << len - 1);
    for (int i = 0; i < deg; ++i)
        tmp[i] = a[i];
    for (int i = deg; i < M; ++i)
        tmp[i] = 0;
    NTT(tmp, M, 1), NTT(b, M, 1);
    for (int i = 0; i < M; ++i)
        b[i] = (int)((2LL - (ll)b[i] * (ll)tmp[i] % MOD + MOD) % MOD * (ll)b[i] % MOD);
    NTT(b, M, -1);
    for (int i = deg; i < M; ++i)
        b[i] = 0;
}

inline void transln (int N, int *a, int *b) {
    memset(tmp, 0, sizeof(tmp));
    trans(N, a, b);
    for (int i = 1; i < N; ++i)
        a[i - 1] = (int)((ll)a[i] * (ll)i % MOD);
    int M = 1, len = 0;
    while (M < (N << 1)) M <<= 1, ++len;
    for (int i = 0; i < M; ++i)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << len - 1);
    for (int i = N; i < M; ++i)
        b[i] = 0;
    for (int i = N - 1; i < M; ++i)
        a[i] = 0;
    NTT(b, M, 1), NTT(a, M, 1);
    for (int i = 0; i < M; ++i)
        a[i] = (int)((ll)a[i] * (ll)b[i] % MOD);
    NTT(a, M, -1);
    for (int i = N; i < M; ++i)
        a[i] = 0;
    for (int i = N - 1; i; --i)
        a[i] = (int)((ll)a[i - 1] * (ll)quick_pow(i, MOD - 2) % MOD);
    a[0] = 0;
}

void transe (int deg, int *a, int *b, int *c) {
    if (deg == 1) {
        b[0] = 1;
        return;
    }
    transe((deg + 1) >> 1, a, b, c);
    memset(tmp2, 0, sizeof(tmp2));
    for (int i = 0; i < deg; ++i)
        c[i] = b[i];
    transln(deg, c, tmp2); //trans c -> ln b
    for (int i = 0; i < deg; ++i)
        c[i] = (c[i] - a[i] + MOD) % MOD; //c -> ln b - a
    int len = 0, M = 1;
    while (M < (deg << 1)) M <<= 1, ++len;
    for (int i = deg; i < M; ++i)
        b[i] = 0;
    NTT(c, M, 1), NTT(b, M, 1);
    for (int i = 0; i < M; ++i)
        b[i] = (int)((1LL - (ll)c[i] + MOD) % MOD * (ll)b[i] % MOD);
    NTT(b, M, -1);
    for (int i = deg; i < M; ++i)
        b[i] = 0;
}

inline void solve () {
    memset(g, 0, sizeof(g));
    g[0] = 1;
    for (int i = 1; i <= T; ++i)
        g[d[i] - 1] = -1;
    transln(N, g, b);
    for (int i = 0; i < N; ++i)
        g[i] = (int)(((ll)g[i] * -1LL * (ll)N % MOD + MOD) % MOD);
    transe(N, g, b, c);
    printf("%d\n", (int)((ll)b[N - 1] * (ll)quick_pow(N, MOD - 2) % MOD));
}

int main () {
    input();
    solve();
    return 0;
}

sys_con

高一 OIer,常规 / 竞赛都渣得不行

发表评论

电子邮件地址不会被公开。 必填项已用*标注