BSGS + Exgcd

Table of Contents

  1. Description
  2. Solution

Description

给定质数 $p$、首项 $X_1$ 以及递推式 $X_{i+1} = (aX_i + b) \bmod p$，求最小的正整数 $i$，使得 $X_i = t$；若不存在这样的 $i$，输出 $-1$。

Solution

若 $X_1 = t$，答案直接为 $1$；否则分 $a = 0$、$a = 1$、$a \ge 2$ 三种情况讨论。

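三种情况分别用特判、exgcd、BSGS + exgcd 求解：

- $a = 0$：对所有 $i \ge 2$ 都有 $X_i = b$，故 $b = t$ 时答案为 $2$，否则为 $-1$。
- $a = 1$：$X_i \equiv X_1 + (i-1)b \pmod p$，用 exgcd 解线性同余方程 $(i-1)b \equiv t - X_1 \pmod p$。
- $a \ge 2$：令 $c \equiv b\,(a-1)^{-1} \pmod p$，则 $X_i + c \equiv a^{i-1}(X_1 + c) \pmod p$。用 exgcd 求出 $(X_1 + c)^{-1}$ 后，用 BSGS 解 $a^{i-1} \equiv (t + c)(X_1 + c)^{-1} \pmod p$。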

#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <cctype>
#include <map>
#include <cmath>

using namespace std;

// 64-bit integer type used for all modular arithmetic below.
using ll = long long;

// Per-test-case problem parameters, assigned by main() and read by the solvers:
//   p  - modulus
//   a  - multiplier, b - increment of the recurrence X_{i+1} = (a*X_i + b) mod p
//   X1 - first term of the sequence
//   t  - target value being searched for
ll a, b;
ll X1;
ll p;
ll t;

// Reads the next non-negative decimal integer from stdin.
// Any non-digit characters before the number are skipped (so a leading '-' is ignored).
// Returns 0 if end of input is reached before any digit is seen; the original version
// looped forever in that case because isdigit(EOF) is false.
// `register` was dropped: it is ill-formed since C++17.
inline int rd() {
    int x = 0;
    int c = getchar();
    while (c != EOF && !isdigit(c)) c = getchar();
    while (c != EOF && isdigit(c)) {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x;
}
// Extended Euclidean algorithm, iterative form.
// Computes g = gcd(a, b) and writes Bezout coefficients into x and y such that
// a*x + b*y == g. Returns g. For b == 0 the result is g = a, x = 1, y = 0.
ll exgcd(ll a, ll b, ll &x, ll &y) {
    ll r_prev = a, r_cur = b;   // successive remainders
    ll s_prev = 1, s_cur = 0;   // running coefficient of a
    ll u_prev = 0, u_cur = 1;   // running coefficient of b
    while (r_cur != 0) {
        const ll q = r_prev / r_cur;
        ll next = r_prev - q * r_cur;
        r_prev = r_cur;
        r_cur = next;
        next = s_prev - q * s_cur;
        s_prev = s_cur;
        s_cur = next;
        next = u_prev - q * u_cur;
        u_prev = u_cur;
        u_cur = next;
    }
    x = s_prev;
    y = u_prev;
    return r_prev;
}
// Binary (square-and-multiply) exponentiation modulo the global modulus p.
// Returns base^index mod p, starting from 1 (so index == 0 yields 1).
// NOTE(review): index is expected to be non-negative. With a negative index the loop
// presumably never ends, since right-shifting a negative value typically stays negative.
ll quick_power(ll base, ll index) {
    ll result = 1;
    ll square = base % p;
    for (ll e = index; e != 0; e >>= 1) {
        if (e & 1) result = result * square % p;
        square = square * square % p;
    }
    return result;
}

namespace BSGS {
using std::map;
using std::sqrt;

// Baby-step giant-step discrete logarithm modulo the global modulus p.
// Returns the smallest x >= 0 with a^x ≡ t (mod p), or -1 if no such x exists.
// Assumes p is prime: the inverse of a^lim is computed with Fermat's little theorem.
ll main(ll a, ll t) {
    a %= p;
    if (a < 0) a += p;
    t %= p;
    if (t < 0) t += p;
    if (a == 0) {
        // 0^0 = 1 and 0^x = 0 for every x >= 1.
        if (t == 1 % p) return 0;
        if (t == 0) return 1;
        return -1;
    }
    // lim * lim > p - 1, so x = k * lim + j with 0 <= j, k < lim covers a full period.
    const ll lim = ll(sqrt(double(p))) + 1;

    // Baby steps: map each a^j (0 <= j < lim) to the FIRST exponent producing it.
    map<ll, ll> first_pos;
    ll cur = 1;
    for (ll j = 0; j < lim; ++j) {
        first_pos.insert(map<ll, ll>::value_type(cur, j)); // insert never overwrites
        cur = cur * a % p;
    }

    // Here cur == a^lim, so step == a^{-lim} (mod p). Using exponent p - 2 (>= 0 for
    // p >= 2) avoids the negative exponent p - lim - 1 that hangs quick_power for p = 2.
    const ll step = quick_power(cur, p - 2);

    // Giant steps: look for t * a^{-k*lim} among the baby steps; the first hit is minimal.
    ll probe = t;
    for (ll k = 0; k < lim; ++k) {
        map<ll, ll>::iterator it = first_pos.find(probe); // single lookup per step
        if (it != first_pos.end()) return k * lim + it->second;
        probe = probe * step % p;
    }
    return -1;
}
}

// Case a == 1: X_n ≡ X1 + (n-1)*b (mod p). Solve b*k ≡ t - X1 (mod p) for the
// smallest k >= 0 and return n = k + 1, or -1 if the congruence has no solution.
// Reads the globals p, b, X1, t.
ll solve1() {
    ll C = ((t - X1) % p + p) % p;
    ll xx, yy;
    const ll G = exgcd(b, p, xx, yy);
    if (C % G) return -1;
    // Solutions are unique modulo p / G. Reducing there gives the smallest k even
    // when gcd(b, p) > 1; for prime p and b != 0, G == 1 and mod == p.
    const ll mod = p / G;
    C /= G;
    ll k = (xx % mod) * (C % mod) % mod;
    if (k < 0) k += mod;
    return k + 1;
}
// Case a >= 2. With c = b / (a - 1) (mod p) the recurrence becomes
// X_n + c ≡ a^{n-1} (X1 + c) (mod p), so solve a^k ≡ (t + c) / (X1 + c) (mod p)
// with BSGS and return n = k + 1, or -1 if t is never reached.
// Assumes p is prime (Fermat inverse of a - 1). Reads the globals p, a, b, X1, t.
ll solve2() {
    const ll inva = quick_power(a - 1, p - 2);
    const ll c = b % p * inva % p;
    const ll A = (X1 % p + c) % p;  // X1 + c
    ll C = (t % p + c) % p;         // t + c
    if (A == 0) {
        // X_n + c ≡ 0 for every n, so the sequence is constant at X1.
        return C == 0 ? 1 : -1;
    }
    ll xx, yy;
    const ll G = exgcd(A, p, xx, yy);  // xx ≡ A^{-1} (mod p) when G == 1
    if (C % G) return -1;
    C /= G;
    xx %= p;
    if (xx < 0) xx += p;
    const ll ret = BSGS::main(a, xx * C % p);
    return ret == -1 ? -1 : ret + 1;
}

int main() {
int T = rd();
while (T--) {
p = rd(); a = rd(); b = rd(); X1 = rd(); t = rd();
if (X1 == t) {
puts("1"); continue;
}
if (a == 0) {
if (b == t) puts("2");
else puts("-1");
} else if (a == 1) {
printf("%lld\n", solve1());
} else printf("%lld\n", solve2());
}
return 0;
}