单位根反演

先扔个题引入吧(好像大家学这玩意第一题都写的是这个)

给定$n,s,a_0,a_1,a_2,a_3$,求:

其中$n\leqslant 1e18$。

题目链接

简单变换一下,我们需要算这个:

设生成函数$f(x)=\sum_{i=0}^{n}\binom{n}{i}s^ix^i=(sx+1)^n$。

那么我们现在要分别求$f$的$\bmod 4=i$的项之和。

我们引入单位根反演,就是说在$\rm ntt$结束的时候证明一般是用到了这个式子:

当$n|k$的时候显然,每一项都是$1$,否则可以利用等比数列求和得到。

我们把单位根带进前面的生成函数,假设生成函数系数为$a_i$:

那么我们就得到了$f$的$\bmod 4=0$的项之和,至于$\bmod 4$为其他值的和我们可以通过把$f$右移若干位得到。

那么最终答案就是:

代码对着式子抄一遍就行了,我就不放了


在看看这个题

给定$n,k$,求:

其中$F_i$为斐波那契数列,$n\leqslant 1e18,k\leqslant 2e4$。

我们把这玩意搞成矩阵形式:

其中$A$是斐波那契的转移矩阵。

那么生成函数:

$I$是单位矩阵。

所以套一下上面的式子就写完了:

code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#pragma GCC optimize(2)
#include<bits/stdc++.h>
using namespace std;

void read(int &x) {
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}

void print(int x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}

#define lf double
#define ll long long

#define pii pair<int,int >
#define vec vector<int >

#define pb push_back
#define mp make_pair
#define fr first
#define sc second

#define FOR(i,l,r) for(int i=l,i##_r=r;i<=i##_r;i++)

const int maxn = 1e6+10;
const int inf = 1e9;
const lf eps = 1e-8;

int k,p,g,w[maxn];ll n;
vector<int > r;

int qpow(int a,int x) {
int res=1;
for(;x;x>>=1,a=1ll*a*a%p) if(x&1) res=1ll*res*a%p;
return res;
}

void gen() {
int e=p-1;r.clear();
for(int i=2;i*i<=e;i++) {
if(e%i) continue;
r.pb(i);while(!(e%i)) e/=i;
}if(e!=1) r.pb(e);
for(g=2;;g++) {
int bo=1;
for(int x=0;x<r.size();x++)
if(qpow(g,(p-1)/r[x])==1) {bo=0;break;}
if(bo) break;
}
}

int add(int x,int y) {x+=y;return x>=p?x-p:x;}

struct matrix {
int a[2][2];

void clear() {memset(a,0,sizeof a);}

matrix operator * (const matrix &r) const {
matrix res;res.clear();
// for(int i=0;i<=1;i++)
// for(int j=0;j<=1;j++)
// for(int k=0;k<=1;k++)
// res.a[i][j]=add(res.a[i][j],1ll*a[i][k]*r.a[k][j]%p);
res.a[0][0]=add(1ll*a[0][0]*r.a[0][0]%p,1ll*a[0][1]*r.a[1][0]%p);
res.a[0][1]=add(1ll*a[0][0]*r.a[0][1]%p,1ll*a[0][1]*r.a[1][1]%p);
res.a[1][0]=add(1ll*a[1][0]*r.a[0][0]%p,1ll*a[1][1]*r.a[1][0]%p);
res.a[1][1]=add(1ll*a[1][0]*r.a[0][1]%p,1ll*a[1][1]*r.a[1][1]%p);
return res;
}

matrix operator * (const int &r) const {
matrix res;res.clear();
// for(int i=0;i<=1;i++)
// for(int j=0;j<=1;j++)
// res.a[i][j]=1ll*a[i][j]*r%p;
res.a[0][0]=1ll*a[0][0]*r%p;
res.a[0][1]=1ll*a[0][1]*r%p;
res.a[1][0]=1ll*a[1][0]*r%p;
res.a[1][1]=1ll*a[1][1]*r%p;
return res;
}

matrix operator + (const matrix &r) const {
matrix res;res.clear();
// for(int i=0;i<=1;i++)
// for(int j=0;j<=1;j++)
// res.a[i][j]=add(a[i][j],r.a[i][j]);
res.a[0][0]=add(a[0][0],r.a[0][0]);
res.a[0][1]=add(a[0][1],r.a[0][1]);
res.a[1][0]=add(a[1][0],r.a[1][0]);
res.a[1][1]=add(a[1][1],r.a[1][1]);
return res;
}
}I,A;

matrix qpow(matrix a,ll x) {
matrix res=I;
for(;x;x>>=1,a=a*a) if(x&1) res=res*a;
return res;
}

void solve() {
// int st=clock();
scanf("%lld",&n),read(k),read(p);gen();
I.a[0][0]=I.a[1][1]=1;w[0]=1,w[1]=qpow(g,(p-1)/k);
A.a[0][0]=A.a[0][1]=A.a[1][0]=1;
for(int i=2;i<k;i++) w[i]=1ll*w[i-1]*w[1]%p;
matrix ans;ans.clear();
for(int i=0;i<k;i++) ans=ans+qpow(A*w[i]+I,n);
write(1ll*ans.a[0][0]*qpow(k,p-2)%p);
// cerr<<(lf)(clock()-st)/1e3<<endl;
}

int main() {
int t;read(t);while(t--) solve();
return 0;
}