「HNOI2019」白兔之舞

题目链接:https://loj.ac/problem/3058

咕咕咕了几个月终于来写这个东西了

先考虑$n=1$的情况,假设读入的种类数是$w$,设$f_i$表示走$i$步的方案数,容易知道:

把这个东西写成生成函数:

根据单位根反演的套路,假设我们要求$\bmod k=0$的项之和,可以这样算:

$\bmod k$为其他值的项可以通过把$F(x)$的项移位得到。

设答案为$g(i)$,即$\mod k=i$的项之和,那么可以写成:

此时复杂度为$O(k^2)$,好像一分都拿不到

接下来就是我不会的神仙操作。。有一个组合意义很明显的式子:

暴力展开也能验证。

然后把这个式子带入$\omega_k$的指数:

注意到前面和$a$有关,后面和$i+a$有关,随便把前面还是后面翻转一下就是一个卷积的形式,因为模数是读入的,所以上任意模数$\rm FFT$就行了。

不过好像$8$次或者$7$次$\rm FFT$的做法被卡常了,我写的是$4$次$\rm FFT$的做法。

至于$n\ne 1$的情况也是一样,只需要把$v$换成读入的矩阵,$1$换成单位矩阵,然后暴力处理出前面那个$L$次方项,其他的都是一样的。

复杂度$O(k\log k)$。

可以看看我代码$\rm 120$行左右的一段注释,我被那里坑了好久。。。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#include<bits/stdc++.h>
using namespace std;

void read(int &x) {
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}

void print(int x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}

#define lf double
#define ll long long

#define pii pair<int,int >
#define vec vector<int >

#define pb push_back
#define mp make_pair
#define fr first
#define sc second

#define FOR(i,l,r) for(int i=l,i##_r=r;i<=i##_r;i++)

const int maxn = 1e6+10;
const int inf = 1e9;
const lf eps = 1e-8;
const lf pi = acos(-1);

int n,mod,g,k,L;

int add(int a,int b) {a+=b;if(a>=mod) a-=mod;return a;}
int del(int a,int b) {a-=b;a+=a>>31&mod;return a;}

struct matrix {
int a[3][3];

matrix () {memset(a,0,sizeof a);}

int* operator [] (int x) {return a[x];}

matrix operator * (matrix x) const {
matrix r;
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
for(int k=0;k<n;k++)
r[i][j]=add(r[i][j],1ll*a[i][k]*x[k][j]%mod);
return r;
}

matrix operator * (int x) const {
matrix r=*this;
for(int i=0;i<n;i++) for(int j=0;j<n;j++) r[i][j]=1ll*r[i][j]*x%mod;
return r;
}

matrix operator + (matrix x) const {
matrix r;
for(int i=0;i<n;i++) for(int j=0;j<n;j++) r[i][j]=add(a[i][j],x[i][j]);
return r;
}

matrix operator ^ (int x) const {
matrix r,a=*this;r[0][0]=r[1][1]=r[2][2]=1;
for(;x;x>>=1,a=a*a) if(x&1) r=r*a;
return r;
}
}tr,I;

int a[maxn],b[maxn],w,c[maxn];
int t[maxn],cnt;

int qpow(int a,int x) {
int res=1;
for(;x;x>>=1,a=1ll*a*a%mod) if(x&1) res=1ll*res*a%mod;
return res;
}

void get_root() {
int e=mod-1;
for(int i=2;i*i<=e;i++) {
if(e%i) continue;t[++cnt]=i;
while(e%i==0) e/=i;
}if(e!=1) t[++cnt]=e;
for(g=2;;g++) {
int bo=1;
for(int i=1;i<=cnt;i++) if(qpow(g,(mod-1)/t[i])==1) {bo=0;break;}
if(bo) break;
}w=qpow(g,(mod-1)/k);
}

namespace MTT {
struct cp {
lf r,i;
cp () {r=i=0;}
cp (lf x,lf y) {r=x,i=y;}
cp operator * (cp x) {return cp(r*x.r-i*x.i,r*x.i+i*x.r);}
cp operator - (cp x) {return cp(r-x.r,i-x.i);}
cp operator + (cp x) {return cp(r+x.r,i+x.i);}
cp operator / (int x) {return cp(r/x,i/x);}
}ww[maxn],g[4][maxn];

int pos[maxn],N,bit;

void dft(cp *r) {
for(int i=0;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);
for(int i=1,d=N>>1;i<N;i<<=1,d>>=1)
for(int j=0;j<N;j+=i<<1)
for(int k=0;k<i;k++) {
cp x=r[j+k],y=ww[k*d]*r[i+j+k];
r[j+k]=x+y,r[i+j+k]=x-y;
}
}

void init(int len) {
for(N=1,bit=-1;N<len;N<<=1,bit++);
for(int i=1;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<bit);
// ww[0]=cp(1,0),ww[1]=cp(cos(2*pi/N),sin(2*pi/N));
// for(int i=2;i<N;i++) ww[i]=ww[i-1]*ww[1];
// 千万别学我写上面那个破玩意。。。我就因为上面这两句话调了一个小时
// 上面这种写法如果是FFT会带来巨大的误差导致全部WA掉,但是如果是NTT这种写法就是对的
for(int i=0;i<N;i++) ww[i]=cp(cos(2*pi/N*i),sin(2*pi/N*i));
}

cp conj(cp x) {return cp(x.r,-x.i);}

void mul(int *r,int *s,int *t,int len) {
init(len);int all=32767;
for(int i=0;i<N;i++) g[0][i]=cp(r[i]>>15,r[i]&all),g[1][i]=cp(s[i]>>15,s[i]&all);
dft(g[0]),dft(g[1]);
for(int i=0;i<N;i++) {
int j=N-i;if(!i) j=0;
g[2][j]=(g[0][i]+conj(g[0][j]))*cp(0.5,0)*g[1][i];
g[3][j]=(g[0][i]-conj(g[0][j]))*cp(0,-0.5)*g[1][i];
}dft(g[2]),dft(g[3]);
for(int i=0;i<N;i++) g[2][i].r/=N,g[2][i].i/=N,g[3][i].i/=N,g[3][i].r/=N;
for(int i=0;i<N;i++) {
ll a=g[2][i].r+0.5,b=g[2][i].i+0.5,c=g[3][i].r+0.5,d=g[3][i].i+0.5;
t[i]=(((a%mod)<<30)+(((b+c)%mod)<<15)+d)%mod;
}
}
}

using MTT :: mul;

int main() {
int x,y;
read(n),read(k),read(L),read(x),read(y),read(mod);x--,y--;
for(int i=0;i<n;i++) for(int j=0;j<n;j++) read(tr[i][j]);
get_root();I[0][0]=I[1][1]=I[2][2]=1;
for(int i=0;i<k;i++) a[i]=1ll*((tr*qpow(w,i)+I)^L)[x][y]*qpow(w,1ll*i*(i-1)/2%(mod-1))%mod;
for(int i=0;i<k*2-1;i++) b[i]=qpow(qpow(w,1ll*i*(i-1)/2%(mod-1)),mod-2);
reverse(b,b+k*2-1);
mul(a,b,c,k*3);
for(int i=0;i<k;i++) write(1ll*qpow(w,1ll*i*(i-1)/2%(mod-1))*qpow(k,mod-2)%mod*c[k*2-2-i]%mod);
return 0;
}