题目链接:http://uoj.ac/problem/449。
学到了多项式新操作。
先用$\min-\max$容斥搞一下,那么答案就是:
$f$表示集合$s$期望多少步之后某一只鸽子最先被喂饱,此时其他鸽子都没饱,因为每个鸽子是一样的,所以等价于任意$i$个鸽子。
假设我们现在要算$f(c)$,我们硬点集合内的某只鸽子最先被喂饱,那么最后把答案乘$c$就好了。
假设某只鸽子是一号鸽子,那么喂鸽子的序列一定满足:一号鸽子出现了$k$次且最后一次是一号,剩下的鸽子出现次数小于$k$。
我们可以搞一个$\rm EGF$把这种序列的方案数算出来:
那么恰好$i$次能喂饱第一只鸽子的概率就是$\dfrac{i!}{n^i}[x^i]g_c(x)$,中括号是取系数。
而因为每期望$\dfrac{n}{c}$次才会出现一次在当前集合里的鸽子,所以长度为$i$的序列的贡献就是$\dfrac{ni}{c}$,所以答案就是:
复杂度$O(n^2k^2)$,用$\rm NTT$可以优化到$O(n^2k\log k)$。
但是由于突然不想写ntt并且正好从别人博客学到了一个做法,考虑能不能更快的算出$g$。
设$a(x)$:
和上面的东西是一样的,考虑对它求导:
可以对照系数递推:设$f_{c,i}$表示$c$次方的第$i$项系数,可得:
复杂度就被优化成了$O(n^2k)$。
不用写ntt了,好!
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
| #include<bits/stdc++.h> using namespace std;
void read(int &x) { x=0;int f=1;char ch=getchar(); for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f; for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f; }
void print(int x) { if(x<0) putchar('-'),x=-x; if(!x) return ;print(x/10),putchar(x%10+48); } void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}
#define lf double #define ll long long
#define pii pair<int,int > #define vec vector<int >
#define pb push_back #define mp make_pair #define fr first #define sc second
#define FOR(i,l,r) for(int i=l,i##_r=r;i<=i##_r;i++)
const int maxn = 1e6+10; const int inf = 1e9; const lf eps = 1e-8; const int mod = 998244353;
int n,k,f[52][50002],fac[maxn],ifac[maxn],inv[maxn];
int qpow(int a,int x) { int res=1; for(;x;x>>=1,a=1ll*a*a%mod) if(x&1) res=1ll*res*a%mod; return res; }
void gen() { fac[0]=ifac[0]=inv[0]=inv[1]=1;int m=n*k; for(int i=1;i<=m;i++) fac[i]=1ll*fac[i-1]*i%mod; for(int i=2;i<=m;i++) inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod; ifac[m]=qpow(fac[m],mod-2); for(int i=m-1;i;i--) ifac[i]=1ll*ifac[i+1]*(i+1)%mod; }
void dp() { f[0][0]=1; for(int c=1;c<=n;c++) { f[c][0]=1; for(int i=1;i<=c*(k-1);i++) { f[c][i]=f[c][i-1]; if(i>=k) f[c][i]=(f[c][i]-1ll*ifac[k-1]*f[c-1][i-(k-1)-1]%mod+mod)%mod; f[c][i]=1ll*f[c][i]*c%mod*inv[i]%mod; } } }
int main() { read(n),read(k); gen();dp();int ans=0; for(int x=1,p=1;x<=n;x++,p=-p) { int res=0; for(int i=k,iv=qpow(inv[x],k);i<=x*k-x+1;i++,iv=1ll*iv*inv[x]%mod) res=(res+1ll*fac[i-1]*iv%mod*i%mod*n%mod*inv[x]%mod*f[x-1][i-k]%mod*ifac[k-1]%mod)%mod; res=1ll*res*x%mod; ans=(ans+1ll*p*res*fac[n]%mod*ifac[n-x]%mod*ifac[x]%mod)%mod; }write((ans+mod)%mod); return 0; }
|