题目链接: https://loj.ac/problem/3164 。
设$\Sigma$为字符集大小,显然可以得到一个$O(n+\Sigma ^8)$的暴力,枚举每个角是啥,拿个桶记录下$a$开始$b$结束的本质不同串的个数,乘起来就好了。
然后我想了个$O(n+\Sigma^6)$的暴力$\rm dp$,首先给点标个号,底下一圈叫$1,2,3,4$,上面对应的一圈为$5,6,7,8$,那么枚举$1,2$是什么,然后$dp$转移到$5,6$,然后是$3,4$,$1,2$,最后算一下就好了,转移复杂度为$O(\Sigma^4)$。
题解给出了个$O(n+\Sigma^5)$的做法,先枚举$1,2,3,4$是什么,然后我们考虑$5$是什么,那么这个时候$1$其实是没有用的,那么不需要记录$1$这个状态,同理继续考虑后面的是什么,每次都只在状态里记录了$4$个点,状态为$O(\Sigma^4)$的,每次转移$O(\Sigma)$。
正解是这样的,注意到这个图是一个二分图,左右各四个点,那么预处理出$f_{a,b,c}$表示三个串,由任意一个点开头,分别以$a,b,c$结尾的方案数,那么枚举二分图左边为$a,b,c,d$,这种情况的方案数就是$f_{a,b,c}\cdot f_{a,b,d}\cdot f_{a,c,d}\cdot f_{b,c,d}$,复杂度$O(n+\Sigma^4)$,常数有点大,需要一点小优化才能过。
别人写string好方便啊。。。我为啥要用char啊啊啊啊啊难写死了
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
| #include<bits/stdc++.h> using namespace std;
void read(int &x) { x=0;int f=1;char ch=getchar(); for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f; for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f; }
void print(int x) { if(x<0) putchar('-'),x=-x; if(!x) return ;print(x/10),putchar(x%10+48); } void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}
#define lf double #define ll long long
#define pii pair<int,int > #define vec vector<int >
#define pb push_back #define mp make_pair #define fr first #define sc second
#define FOR(i,l,r) for(int i=l,i##_r=r;i<=i##_r;i++)
const int maxn = 2e5+10; const int inf = 1e9; const lf eps = 1e-8; const int mod = 998244353;
vector<char* > t[12]; char s[maxn][12]; int n,ans,f[63][63][63],c[63][63];
const int v[2][2][2]={{{24,12},{12,4}},{{12,6},{4,1}}};
int get(char x) { if(x>='a'&&x<='z') return x-'a'; if(x>='A'&&x<='Z') return x-'A'+26; return x-'0'+52; }
void add(int &x,int y) {x+=y;if(x>=mod) x-=mod;}
int len;
int cmp(char *a,char *b) { for(int i=1;i<=len;i++) if(a[i]>b[i]) return 1; else if(a[i]<b[i]) return 0; return 0; }
int eq(char *a,char *b) { for(int i=1;i<=len;i++) if(a[i]!=b[i]) return 0; return 1; }
void solve(int x) { memset(f,0,sizeof f); memset(c,0,sizeof c);len=x; sort(t[x].begin(),t[x].end(),cmp); for(int i=0;i<t[x].size();i++) { if(i&&eq(t[x][i],t[x][i-1])) continue; int p=get(t[x][i][1]),q=get(t[x][i][x]); c[p][q]++; } for(int i=0;i<=61;i++) for(int j=i;j<=61;j++) for(int k=j;k<=61;k++) for(int r=0;r<=61;r++) add(f[i][j][k],1ll*c[r][i]*c[r][j]%mod*c[r][k]%mod); for(int i=0;i<=61;i++) for(int j=i;j<=61;j++) for(int k=j;k<=61;k++) for(int r=k;r<=61;r++) add(ans,1ll*v[i==j][j==k][k==r]*f[i][j][k]%mod*f[i][j][r]%mod*f[i][k][r]%mod*f[j][k][r]%mod); }
int main() { read(n); for(int i=1;i<=n;i++) { scanf("%s",s[i]+1);int x; t[x=strlen(s[i]+1)].pb(s[i]); for(int j=1;j<=x;j++) s[n+i][j]=s[i][j]; reverse(s[n+i]+1,s[n+i]+x+1); t[x].pb(s[n+i]); } for(int i=3;i<=10;i++) solve(i); write(ans); return 0; }
|