题目链接: https://loj.ac/problem/3164 。
设$\Sigma$为字符集大小,显然可以得到一个$O(n+\Sigma ^8)$的暴力,枚举每个角是啥,拿个桶记录下$a$开始$b$结束的本质不同串的个数,乘起来就好了。
然后我想了个$O(n+\Sigma^6)$的暴力$\rm dp$,首先给点标个号,底下一圈叫$1,2,3,4$,上面对应的一圈为$5,6,7,8$,那么枚举$1,2$是什么,然后$dp$转移到$5,6$,然后是$3,4$,$1,2$,最后算一下就好了,转移复杂度为$O(\Sigma^4)$。
题解给出了个$O(n+\Sigma^5)$的做法,先枚举$1,2,3,4$是什么,然后我们考虑$5$是什么,那么这个时候$1$其实是没有用的,那么不需要记录$1$这个状态,同理继续考虑后面的是什么,每次都只在状态里记录了$4$个点,状态为$O(\Sigma^4)$的,每次转移$O(\Sigma)$。
正解是这样的,注意到这个图是一个二分图,左右各四个点,那么预处理出$f_{a,b,c}$表示三个串,由任意一个点开头,分别以$a,b,c$结尾的方案数,那么枚举二分图左边为$a,b,c,d$,这种情况的方案数就是$f_{a,b,c}\cdot f_{a,b,d}\cdot f_{a,c,d}\cdot f_{b,c,d}$,复杂度$O(n+\Sigma^4)$,常数有点大,需要一点小优化才能过。
别人写string好方便啊。。。我为啥要用char啊啊啊啊啊难写死了
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
   | #include<bits/stdc++.h> using namespace std;
  void read(int &x) {     x=0;int f=1;char ch=getchar();     for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;     for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f; }
  void print(int x) {     if(x<0) putchar('-'),x=-x;     if(!x) return ;print(x/10),putchar(x%10+48); } void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}
  #define lf double #define ll long long 
  #define pii pair<int,int > #define vec vector<int >
  #define pb push_back #define mp make_pair #define fr first #define sc second
  #define FOR(i,l,r) for(int i=l,i##_r=r;i<=i##_r;i++)
  const int maxn = 2e5+10; const int inf = 1e9; const lf eps = 1e-8; const int mod = 998244353;
  vector<char* > t[12]; char s[maxn][12]; int n,ans,f[63][63][63],c[63][63];
  const int v[2][2][2]={{{24,12},{12,4}},{{12,6},{4,1}}};
  int get(char x) {     if(x>='a'&&x<='z') return x-'a';     if(x>='A'&&x<='Z') return x-'A'+26;     return x-'0'+52; }
  void add(int &x,int y) {x+=y;if(x>=mod) x-=mod;}
  int len;
  int cmp(char *a,char *b) {     for(int i=1;i<=len;i++)         if(a[i]>b[i]) return 1;         else if(a[i]<b[i]) return 0;     return 0; }
  int eq(char *a,char *b) {     for(int i=1;i<=len;i++) if(a[i]!=b[i]) return 0;     return 1; }
  void solve(int x) {     memset(f,0,sizeof f);     memset(c,0,sizeof c);len=x;     sort(t[x].begin(),t[x].end(),cmp);     for(int i=0;i<t[x].size();i++) {         if(i&&eq(t[x][i],t[x][i-1])) continue;         int p=get(t[x][i][1]),q=get(t[x][i][x]);         c[p][q]++;     }     for(int i=0;i<=61;i++)         for(int j=i;j<=61;j++)             for(int k=j;k<=61;k++)                 for(int r=0;r<=61;r++)                     add(f[i][j][k],1ll*c[r][i]*c[r][j]%mod*c[r][k]%mod);          for(int i=0;i<=61;i++)         for(int j=i;j<=61;j++)             for(int k=j;k<=61;k++)                 for(int r=k;r<=61;r++)                            add(ans,1ll*v[i==j][j==k][k==r]*f[i][j][k]%mod*f[i][j][r]%mod*f[i][k][r]%mod*f[j][k][r]%mod); }
  int main() {          read(n);     for(int i=1;i<=n;i++) {         scanf("%s",s[i]+1);int x;         t[x=strlen(s[i]+1)].pb(s[i]);         for(int j=1;j<=x;j++) s[n+i][j]=s[i][j];         reverse(s[n+i]+1,s[n+i]+x+1);         t[x].pb(s[n+i]);     }     for(int i=3;i<=10;i++) solve(i);     write(ans);          return 0; }
   |