P4916 [MtOI2018]魔力环 题解

Zesty_Fox

2021-07-23 17:59:49

Solution

也许更好的阅读体验:[我的博客](https://www.cnblogs.com/acceptedzhs/p/15050023.html) 模拟考考到了这道题,现在写一发题解。 首先,环不好处理,考虑断环为链,相当于将 $m$​ 个黑色魔力珠插入白色魔力珠的间隙中。注意由于是环,所以第一个白球前面与最后一个白球后面是相通的,所以只有 $n-m$​ 个间隙,可强制最后一个白球后不能放球。 不妨记一种方案为序列 $a$,$a_i$ 表示第 $i$ 个白球前的间隙有多少个黑球。如:$\blacksquare\Box\blacksquare\blacksquare\Box\blacksquare\Box\Box \rightarrow 1,2,1,0$。显然有 $\sum \limits_{i=1}^{n-m} a_i=m$。 接下来,不考虑一种方案是否被重复统计,考虑如何求方案数。 如果不考虑任何限制,将黑球随意分配(序列中允许值为 $0$​ 的项),根据插板法,结果为 $\dbinom{n-1}{n-m-1}$​​。 类似[硬币购物](https://www.luogu.com.cn/problem/P1450)的做法,我们可以枚举序列中有多少项超过了 $k$​ ,先预先给它们每项都分配 $k+1$​ 个黑球,然后将剩下的黑球随意分配(分配到之前预分配的项也没关系),用容斥原理进行计算。那么,最终答案为: $$ \sum\limits_{i=0}^{\lfloor \frac{m}{k+1} \rfloor} (-1)^i \dbinom{n-m}{i}\dbinom{n-i\times (k+1)-1}{n-m-1} $$ 现在再来考虑如何解决重复统计的问题。 手玩一下可发现一种方案被重复统计的次数是其序列最短循环节的长度(如 $2,1,2,1$​​​​​​ 被重复统计了 $2$​​​​​​ 次,为 $2,1,2,1$​​​​​​ 与 $1,2,1,2$​​​​​​​​),所以我们可以枚举循环节的长度(注意不一定是最小循环节)。当循环节个数为 $i$​​​​​​ 时(即循环节长度为$\dfrac{n-m}{i}$​​​),相当于将 $m$​​​ 个黑球、$n-m$​​​ 个白球变成了 $\dfrac{m}{i}$​​​ 个黑球、$\dfrac{n-m}{i}$​ 个白球,而且 $i$​ 满足 $i \mid {\gcd(m,n-m)}$ ,对于每个 $i$ 都可以用上文方法容斥计算,可将结果记为 $f(i)$​​​​​​。 最后,考虑如何统计答案。前面说不一定是最小循环节,这是因为对某个循环节长度的情况计算时可能包括循环节更小的结果,如循环节为 $4$​​ 时包含了 $1,2,1,2,1,2,1,2$​​,而它实际最小循环节为 $2$​​。考虑再次容斥,记 $g(i)$​​​​ 为循环节**个数恰好**为 $i$​​ 的方案数,则有: $$ f(i)=\sum\limits_{i \mid d} g(d) $$ 莫反一下,则有: $$ g(i)=\sum\limits_{i|d}\mu(\frac{d}{i})\times f(d) $$ 最后答案便是: $$ Ans=\sum\limits_{d\mid \gcd(m,n-m)} g(d) \times \dfrac{d}{n-m} $$ Code: (考场代码,有点丑,请见谅) ```cpp #include <bits/stdc++.h> using namespace std; typedef long long ll; typedef vector<int> vi; typedef pair<int,int> pii; template<typename T> inline T read(){ T x=0,f=0;char ch=getchar(); while(ch<'0'||ch>'9') f|=(ch=='-'),ch=getchar(); while(ch>='0'&&ch<='9') x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return f?-x:x; } #define rdi read<int> #define rdll read<ll> #define fi first #define se second #define pb push_back #define mp make_pair const int MOD=998244353,N=200010; int n,m,k; inline ll qpow(ll x,ll y=MOD-2){ ll res=1; while(y){ if(y&1) (res*=x)%=MOD; (x*=x)%=MOD;y>>=1; } return res; } ll fac[N],faci[N]; inline void init(){ fac[0]=1; for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i%MOD; faci[n]=qpow(fac[n]); for(int i=n-1;i>=0;i--) faci[i]=faci[i+1]*(i+1)%MOD; } int mu[N],v[N],pr[N],cnt; inline void init_mu(){ mu[1]=1; for(int i=2;i<=n;i++){ if(!v[i]){v[i]=i,pr[++cnt]=i,mu[i]=-1;} for(int j=1;j<=cnt;j++){ if(pr[j]>v[i]||pr[j]>n/i) break; v[pr[j]*i]=pr[j]; if(i%pr[j]) mu[pr[j]*i]=-mu[i]; else mu[pr[j]*i]=0; } } } inline ll C(int n,int m){ if(n<0||m<0||n<m) return 0; return fac[n]*faci[m]%MOD*faci[n-m]%MOD; } int gcd(int x,int y){return !x?y:gcd(y%x,x);} ll f[N],ans; int main(){ n=rdi(),m=rdi(),k=rdi(); if(n==m){ puts(m<=k?"1":"0"); return 0; } if(!m){ puts("1"); return 0; } init();init_mu();int gc=gcd(n-m,m); for(int x=gc;x>=1;x--){ if(gc%x) continue; int n1=n/x,m1=m/x; for(int i=0,c=1;i<=m1/(k+1);i++,c=-c){ f[x]+=c*C(n1-m1,i)*C(n1-i*(k+1)-1,n1-m1-1); f[x]=(f[x]%MOD+MOD)%MOD; } } for(int x=1;x<=gc;x++){ if(gc%x) continue; ll sum=0; for(int i=x;i<=gc;i+=x){ sum+=mu[i/x]*f[i]%MOD;sum=(sum%MOD+MOD)%MOD; } ans=(ans+sum*x)%MOD; } cout<<ans*qpow(n-m)%MOD<<endl; return 0; } ```