BZOJ5104 Fib数列

题意

在 $mod\ \ 10^9+\color{red}9$ 的意义下求出数字 $n$ 在 Fib 数列中出现在哪个位置
由于没有 spj如果有多个,输出最靠前的那个位置)

题解

之前 Picks 讲课的时候提到了这样一道题,但没说来源,所以我一直不知道网上有,后来乱逛时发现了这道题,做了一下
题解放在这篇文章的 T1 那里了
这里稍微讲一下模数为奇素数的二次剩余,使用 Cipolla 算法求解
现在要解决一个问题,给定 $n$ ,要求出模意义下的 $x$ 使
$$ x^2\equiv n\pmod{p} $$
或者说求 $\sqrt{n}\pmod{p}$
这里只是简要讲解,更详细的请到引用资料里看

由欧拉定理
$$
\begin{eqnarray}
x^{p-1}\equiv 1\pmod{p} \\
n^{p-1}\equiv 1\pmod{p} \\
\end{eqnarray}
$$
由于 $1$ 的二次剩余只有 $1,-1$ ,所以当 $n$ 存在二次剩余时
$$ x^{p-1}\equiv n^{\frac{p-1}{2}}\equiv 1\pmod{p} $$
否则
$$ n^{\frac{p-1}{2}}\equiv -1\pmod{p} $$
此时引进勒让德符号
$$
x^\frac{p-1}{2}\equiv
\begin{cases}
1& (x存在二次剩余)\\
-1& (x不存在二次剩余)\\
0& (x是0)
\end{cases}
$$
但是只用勒让德符号只能做判断,而无法求出某个数的二次剩余
引用一个结论
如果 $a^2-n$ 没有二次剩余。那么 $n$ 有二次剩余,且 $n$ 的二次剩余为
$$\big{(}a+\sqrt{a^2-n}\big{)}^{p+1} $$
证明在引用资料里

这时就可以得出 Cipolla 算法的大致过程了:
首先随便选一个 $a$ ,判断 $a^2-n$ 有没有二次剩余。
如果有,继续找 $a$ ,直到找到为止。
如果没有有,直接根据上面的结论求出 $n$ 的二次剩余
由于约有一半的数没有二次剩余,我们这样做的尝试次数期望是 $2$,因此算法总复杂度为$\mathcal{O(\log(n))}$

需要注意的是,由于过程中带有根号,所以需要实现一个类似于复数的数域 $a+b\sqrt{w}$
这个数域中的乘法定义为
$$ (a,b\sqrt{w})\times (c,d\sqrt{w})=(ac+bdw,(ad+bd)\sqrt{w})$$

代码

好想不一定好写
一道数论题还这么长代码,吓得我赶快打了几个 namespace
细节比较多,需要注意一下,尽量把式子写工整把思路理清楚了再写
理性愉悦一下吧

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <map>

#define reg register
#define MAX_N 100006
#define MOD 1000000009
#define sqrt5 383008016
#define X 691504013
#define INV2 500000005
typedef long long ll;

template <typename _T> inline void read (_T& x) {
    x = 0;
    reg _T y = 1;
    reg char c = getchar();
    while (c < '0' || '9' < c) {
        if (c == '-') y = -1;
        c = getchar();
    }
    while ('0' <= c && c <= '9') x = x * 10 + c - '0', c = getchar();
    x *= y;
}

int N;

namespace math {
    inline int quick_pow (int x, int p) {
        int ans = 1;
        while (p) {
            if (p & 1) ans = 1LL * ans * x % MOD;
            x = 1LL * x * x % MOD;
            p >>= 1;
        }
        return ans;
    }
};

namespace cipolla {
    int w;
    struct complex {
        int a, b;

        complex (const int& x = 0, const int& y = 0) : a(x), b(y) {}
        complex operator * (const complex x) {
            return complex(
                ((1LL * a * x.a % MOD) + (1LL * b * x.b % MOD * w % MOD)) % MOD,
                ((1LL * a * x.b % MOD) + (1LL * b * x.a % MOD)) % MOD);
        }
        complex operator % (const complex x) {return complex(a % MOD, b % MOD);}
    };

    inline complex quick_pow (complex x, int p) {
        complex ans(1);
        while (p) {
            if (p & 1) ans = ans * x;
            x = x * x;
            p >>= 1;
        }
        return ans;
    }

    inline int legend (int x) {return math::quick_pow(x, (MOD - 1) >> 1);}

    inline int getSqrt (int x) {
        if (x == 0) return 0;
        if (legend(x) == MOD - 1) return -1;
        while (true) {
            int a = rand() % (MOD - 2) + 1; w = (1LL * a * a % MOD - x + MOD) % MOD;
            if (legend(w) == MOD - 1)
                return quick_pow(complex(a, 1), (MOD + 1) >> 1).a;
        }
    }
};

namespace BSGS {
    int l;
    std::map<int, int> has;

    inline void init () {
        has.clear();
        int z = 1, x = X;
        l = 1;
        for (int i = 1; i * i <= MOD; ++i, ++l) {
            z = 1LL * z * x % MOD; //z=x^i
            if (!has[z]) has[z] = i;
        }
    }

    inline int getRes (int y) {
        if (has[y]) return has[y];
        int z = 1, x = math::quick_pow(X, l), res = 0;
        for (int i = 1; i * i <= MOD; ++i) {
            z = 1LL * z * x % MOD;
            if ((res = has[1LL * math::quick_pow(z, MOD - 2) * y % MOD]) != 0 && 1LL * (1LL * i * l + res) <= MOD)
                return i * l + res;
        }
        return -1;
    }
};

int mn = -1;
inline void solve (int y, int a, int b) {
    int sqrtDelta = cipolla::getSqrt(((1LL * b * b % MOD) + 1LL * 4LL * a + MOD) % MOD);
    if (sqrtDelta == -1) return;
    int ans = 1LL * (1LL * sqrt5 * y % MOD - sqrtDelta + MOD) % MOD * INV2 % MOD;
    int res = BSGS::getRes(ans); //X^res\equiv ans
    if (res != -1) mn = (mn == -1 ? res : std::min(res, mn));
    ans = 1LL * (1LL * sqrt5 * y % MOD + sqrtDelta + MOD) % MOD * INV2 % MOD;
    res = BSGS::getRes(ans);
    if (res != -1) mn = (mn == -1 ? res : std::min(res, mn));
}

int main () {
    read(N);
    BSGS::init();
    int x;
    solve(N, -1, 1LL * sqrt5 * N % MOD);
    solve(N, 1, 1LL * sqrt5 * N % MOD);
    printf("%d\n", mn);
    return 0;
}

引用资料

sys_con

OIer,常规 / 竞赛都渣得不行

发表评论

电子邮件地址不会被公开。 必填项已用*标注