题目链接:https://loj.ac/problem/3119。
考虑容斥,设$t_i$表示至少有$i$个极大的点的方案数。
那么这个东西可以分成几个部分乘起来:选$i$个不相干的点的方案数,然后填上所有与这些点相关的点,最后把剩下的部分随便填一下。
我们先设几个记号方便描述,设$N=nml$。
设$c_i$表示$i$个点影响到的点,那么$c_i=N-(n-i)(m-i)(l-i)$。
设$f_i$表示选$i$个点的方案,那么很快可以得出:
因为每次选之后剩下来的子问题都是一样的。
注意现在我们只是枚举出了影响到的位置,因为数是两两不同的,所以怎么选都一样,这里要乘一个组合数$\displaystyle\binom{N}{c_i}$。
现在考虑这些被影响到的位置应该怎么填,如果极大点$x$比$y$后处理,那么我们保证$x>y$就可以不受影响了。
这启发我们每次把最大的填到极大值的位置上,然后这个位置影响到的点随便填,设$h_i$表示$i$个点影响到的那些点的填法,可以得到:
前面组合数是选出当前需要的数,后面是随便摆放,化简一下就是:
填完$i$个极大点后剩下的点随便填就是$(N-c_i)!$。
所以$t_i$可以表示成:
化简一下:
设恰好有$i$个极大点的方案为$ans_i$,设$w=\min(n,m,l)$,可得:
反演一下:
注意题目最后让我们算的是概率,也就是说要除掉$N!$,正好这个玩意我们算不出来,所以忽略他就好了。
最后还有一个问题就是怎么求$c_i$的逆元,$\log$求就$\rm T$了,我们可以仿照求阶乘逆元的方法,先把$\prod_{i=1}^{w}c_i$的逆元暴力求出来,然后每次乘$w_i$就好了。
复杂度是$O(n)$的。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
| #include<bits/stdc++.h> using namespace std;
void read(int &x) { x=0;int f=1;char ch=getchar(); for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f; for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f; }
void print(int x) { if(x<0) putchar('-'),x=-x; if(!x) return ;print(x/10),putchar(x%10+48); } void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}
#define lf double #define ll long long
#define pii pair<int,int > #define vec vector<int >
#define pb push_back #define mp make_pair #define fr first #define sc second
#define FOR(i,l,r) for(int i=l,i##_r=r;i<=i##_r;i++)
const int maxn = 5e6+10; const int inf = 1e9; const lf eps = 1e-8; const int mod = 998244353;
int n,m,l,k,c[maxn],f[maxn],g[maxn],invc[maxn],fac[maxn],ifac[maxn];
int add(int x,int y) {x+=y;if(x>=mod) x-=mod;return x;} int del(int x,int y) {x-=y;x+=x>>31&mod;return x;}
int qpow(int a,int x) { int res=1; for(;x;x>>=1,a=1ll*a*a%mod) if(x&1) res=1ll*res*a%mod; return res; }
int C(int a,int b) {return 1ll*fac[a]*ifac[b]%mod*ifac[a-b]%mod;}
void solve() { read(n),read(m),read(l),read(k); int N=1ll*n*m%mod*l%mod,w=min(n,min(m,l)),ss=1;f[1]=N; for(int i=1;i<=w;i++) { int x=1ll*(n-i)*(m-i)%mod*(l-i)%mod; f[i+1]=1ll*f[i]*x%mod;c[i]=del(N,x); ss=1ll*ss*c[i]%mod;if(!x) f[i+1]=1; }invc[w]=qpow(ss,mod-2); for(int i=w-1;~i;i--) invc[i]=1ll*invc[i+1]*c[i+1]%mod; int ans=0; for(int i=k;i<=w;i++) ans=(((i-k)&1)?del:add)(ans,1ll*f[i]*invc[i]%mod*C(i,k)%mod); write(ans); }
void gen() { fac[0]=ifac[0]=1;int N=5e6; for(int i=1;i<=N;i++) fac[i]=1ll*fac[i-1]*i%mod; ifac[N]=qpow(fac[N],mod-2); for(int i=N-1;i;i--) ifac[i]=1ll*ifac[i+1]*(i+1)%mod; }
int main() { gen();int t;read(t);while(t--) solve(); return 0; }
|