Product Modulo (AGC047-C,800 points)

AtCoder解説記事は初投稿です

あと、pythonで書きます。

問題

atcoder.jp

ステップ0

入出力を書いてみる。現在はこうなっている。200003という数は意味ありげなので最初に書いておく。

N=int(input())
P=200003
a=[int(i) for i in input().split()]

#なんやかんやすることで答えを求める

print(ans)

ステップ1

問題では以下の値を求めることになっている

$$\sum_{1\leq i\lt j\leq N}(A_iA_j \mod{P})$$

ところで、これは以下の値を求めれればOKということになる。

$$X=\sum_{i=1}^{N}\sum_{j=1}^{N}(A_iA_j \mod{P})$$

シグマの対象範囲を{i\lt j}に限定していないことに注意

このとき、出力するべき値は

$$\frac{X-\sum_{i=1}^{N}(A_i^2\mod{P})}{2}$$

となっている。

ここで、{\sum_{i=1}^{N}(A_i^2\mod{P})}については愚直に計算しても{O(N)}で計算することができるため、{X}さえ計算できれば答えを計算することが可能であることがわかる。

注:もし入力例1の結果が1105216になったならば、ここの処理を戻すことを忘れて{X}そのものを出力している可能性が高いです。

現在はこうなっている。

N=int(input())
P=200003
a=[int(i) for i in input().split()]
ans=0

#なんやかんやすることでXを求める
#最終的にansという変数にXの値が格納される

for i in range(N):
    ans-=(a[i]*a[i])%P
print(ans//2)

ステップ2

{A_i}の取りうる値の範囲は{0\leq A_i\lt P}である。

ここで、以下のようなリストを作成してみる。

$$b_k=\#\{i|1\leq i\leq N,A_i=k\}$$

このリスト自体は{O(N)}で作成することができる。*1

すると

$$X=\sum_{k=0}^{P-1}k\sum_{nm\mod{P}=k}b_nb_m $$

となる。(注:ここで、{n,m}{0\leq n,m\lt P}かつ{nm\mod{P}=k}の範囲を動いている)

ようは各{k}において{(A_iA_j\mod{P})=k}となる{(i,j)}の個数を求めればいいことになる。これは{nm\mod{P}=k}となるような{(n,m)}を固定したときの{(A_i,A_j)=(n,m)}となるような{i,j}の組を数えることになっている。ステップ1でやったことから{i,j}の選び方は独立でいいので、単に
{A_i=n}となる{i}の個数)×({A_j=m}となる{j}の個数)
となっているため、これは先程定義した{b}を使うと{b_nb_m}という表記ができる。あとは{n,m}を条件のもと動かしてその中で総和を取れば良い。

現在はこうなっている。

N=int(input())
P=200003
a=[int(i) for i in input().split()]
b=[0 for i in range(P)]
for i in range(N):
    b[a[i]]+=1
ans=0

#なんやかんやすることでXを求める
#最終的にansという変数にXの値が格納される

for i in range(N):
    ans-=(a[i]*a[i])%P
print(ans//2)

ステップ3

ここで知識があればFFTというものを思い出すはずである。FFTとは高速フーリエ変換と呼ばれるものである。
百聞は一見に如かず。このリンクを見れば早い。

atcoder.jp

このリンクでは、それぞれの{k}に対して「{i+j=k}の範囲内での{A_iB_j}の総和」を高速に求めている。*2
一方、今この問題でやろうとしているのは、それぞれの{k}に対して「{ij=k\pmod{P}}の範囲内での{b_ib_j}の総和」を高速に求めようとしているのである。

ATC001のリンクでは和に関しての総和をしているのに、ここでは積に関しての総和を取っている。

なんとかして積の状況を和の状況に置き換えたい…
ここで、和と積を変換するといえば指数・対数が思いつくので、多少ガバガバな議論だが以下のことを考えることができる。

まず最初に{a}という整数を取ってくる。ここで{2\leq a\leq P-1}を満たしている必要がある。いわばこれは「指数の底」みたいな役割を満たす存在である。

今の状況では「{nm\mod{P}=k}となるような{n,m}に対しての{b_nb_m}の総和」を求めるのが目標であった。ここで、{a^x\mod{P}=n,a^y\mod{P}=m}となるような整数{x,y}が存在したとき、{x+y=k}となるような{x,y}における{b_{a^x}b_{a^y}}の総和」を求めることになる。すると無事に和の状況に帰着できたため、FFTが使えるように帰着できるということになる。

雰囲気としてはこの通りの方針でやるべきだろう。実は、この議論は数学的に正当化することができる。

ところで、さっき考えた議論のガバを指摘するとすれば、例えば「任意の{n}に対して{a^x\mod{P}=n}となるような整数{x}は存在するか?」というようなところだろう。

そこで登場してくるものが原始根である。

「原始根」でググると詳しい説明が出てくるが一応説明すると、
『「{a^x=1\pmod{P}}となるような最小の正の整数{x}」が{P-1}と等しくなる』
ような整数{a(2\leq a\leq P-1)}を原始根と呼ぶ。
このような{a}が存在した場合、{n}をどのように取ってきても、{a^x=n\pmod{P}}となるような{x}が存在する。そしてこのような{x}{\mod{P-1}}で一意となる。
証明は演習問題(便利な言葉だ…)

実は任意の素数について、このような原始根は存在する。

ここで、フェルマーの小定理を思い出すと、{a^{P-1}\mod{P}=1}であるため、「{a^x}{P}で割ったあまり」は周期{P-1}で巡回しているのである。

よって、今までは{P}で割ったあまりでものを考えていたが、指数の肩に乗っている{x,y}みたいな数については、{P-1}で割ったあまりでものを考えることになる。

よってこのことを踏まえて{X}を式変形してみる。ステップ2を思い出すと

$$X=\sum_{k=0}^{P-1}k\sum_{nm\mod{P}=k}b_nb_m $$


であった。ここで{k=0}のことは考えなくてもいいため*3

$$X=\sum_{k=1}^{P-1}k\sum_{nm\mod{P}=k}b_nb_m $$
となる。ここで、{a^z=k\pmod{P},a^x=n\pmod{P},a^y=m\pmod{P}}となるような{x,y,z}が存在するということなので、

$$X=\sum_{z=0}^{P-2}a^z\sum_{x+y=z\mod{P-1}}b_{a^x}b_{a^y}$$

ここで、{\sum_{x+y=z}\cdots}を各{z}について計算するときにFFTを使えればいいということになる。

さて、ここで問題になるのが「原始根は存在するとして、具体的に何が原始根になるのだ?」といったところである。「任意の素数{P}に対して原始根を1つ求めるアルゴリズム」を作るのは難しそうだが、この状況では{P=200003}と固定されている。よって特殊な状況で1つ求めればよい。

しかし与えられた数に対して原始根であるかを判定するのは簡単そうである。
AtCoderのコンテストに出場しているなら目の前にプログラミングを実行する環境があるはずである。そこの別窓を開いて以下のコードを打ち込んでみる*4

P=200003
a=int(input())
S=set()
for i in range(P-1):
    S.add(pow(a,i,P))
print(len(S))

入力{a}に対して、出力が{200002=P-1}となればこのとき{a}は原始根となるはずである。*5
ためしにこのプログラムに対して2を標準入力に入れてみると、200002という出力がされる。
よって2はこのとき原始根となる。よって以降では{a=2}として考えてみることにする。

FFTをするためのリストを作成することにする。{c}という長さ{P-1}というリストを用意して、
$$c_x=b_{2^x\mod{P}}(0\leq x\lt P-1)$$
となるようにする。

N=int(input())
P=200003
a=[int(i) for i in input().split()]
b=[0 for i in range(P)]
for i in range(N):
    b[a[i]]+=1
c=[0 for i in range(P-1)]
for i in range(P-1):
    c[i]=b[pow(2,i,P)]
ans=0

#なんやかんやすることでXを求める
#最終的にansという変数にXの値が格納される

for i in range(N):
    ans-=(a[i]*a[i])%P
print(ans//2)

ステップ4

$$X=\sum_{z=0}^{P-2}a^z\sum_{x+y=z\mod{P-1}}b_{a^x}b_{a^y}$$
から、
$$X=\sum_{z=0}^{P-2}2^z\mod{P}\sum_{x+y=z\mod{P-1}}c_{x}c_{y}$$
となっている。

ここで、
$$f(t)=c_0+c_1t+c_2t^2+\cdots+c_{P-2}t^{P-2}$$
とする。

このとき、
$$\sum_{x+y=z}c_xc_y=f(t)^2のt^zの係数$$
となる。

FFTをすると、「{f(t)^2}{t^z}の係数」を高速に求めることができる。今求めようとしているのは「{z}{P-1}で割った余りを固定して{z}を動かしたときの『{f(t)^2}{t^z}の係数』の総和」なので、単にFFTをするとより情報量の多いものを求められるためこの方針で問題ないことになる。

pythonFFTをする際にどうすればいいかというと、実はnumpyにFFT機能があるのである。
np.fft.fft(c)でフーリエ変換ができて、np.fft.ifft(c)で逆フーリエ変換ができる。
またFFTをする際には配列のサイズを2の累乗にしないといけない。ここで、
$$131072=2^{17}\lt 200002=P-1\lt 2^{18}=262144$$
となっている。ここで、逆変換したら長さ{2(P-1)-1=2P-3}のリストが必要であるため({f(t)^2}の係数が{2P-3}以下なので)配列のサイズはこれよりも2倍大きい{2^{19}=524288}とするというのが最適解である。

ここで、無理やりFFTのサイズを2のべき乗にした場合、余白の部分のな値をどうすればいいかという疑問が湧くが、結論から言うと、余白の部分は全部0にすればよい。
なぜならば、FFTをするための配列は
$$f(t)=c_0+c_1t+c_2t^2+\cdots+c_{P-2}t^{P-2}+c_{P-1}t^{P-1}+\cdots+c_{2^{19}-1}t^{2^{19}-1}$$
というような多項式に対応しているが、{c_{P-1}}以降の値を0とすれば
$$f(t)=c_0+c_1t+c_2t^2+\cdots+c_{P-2}t^{P-2}$$
とまったく同じ多項式と見なせるからである。

結局fftを使うと{f(t)^2}

np.fft.ifft(np.fft.fft(c)*np.fft.fft(c))

というような表記になるはずである。今までの状況を整理すると以下のようになって、これが正解するソースコードとなる。

import numpy as np
N=int(input())
P=200003
a=[int(i) for i in input().split()]
b=[0 for i in range(P)]
for i in range(N):
    b[a[i]]+=1
M=1<<19
c=np.zeros(M)
for i in range(P-1):
    c[i]=b[pow(2,i,P)]
X=[int(i+0.1) for i in np.real(np.fft.ifft(np.fft.fft(c)*np.fft.fft(c)))]
ans=0
for k in range(M):
    ans+=X[k]*pow(2,k,P)
print(ans)
for i in range(N):
    ans-=(a[i]*a[i])%P
print(ans//2)

atcoder.jp

注意点が1つある。

X=[int(i+0.1) for i in np.real(np.fft.ifft(np.fft.fft(c)*np.fft.fft(c)))]

についてである。np.ifftの結果の配列の中にある値はシステムの上では複素数の扱いになっている。正確には実質的に「ほとんど」整数の値を取っているのだが、計算誤差のおかげでわずかなずれが存在している。
これを無理やりintにすると警告が出るためまずはrealという関数を使うことで実数値に直している。
最後にintに変換しているときにあえて0.1を足してから変換しているのがポイントである。
たとえば1が実質的な答えとしたとき、誤差があるとするならば0.999~1.001のような値になっている。intという関数は小数点以下切り捨てであるため、もし0.999~1.000の値を取った場合、1を出力するべきなのに0を出力してしまう。よってこのときWAとなってしまう。

0.1を足していないことでWAとなっている例↓
atcoder.jp

それの対処法として0.1を足しているのである。0.1を足しているならば、0.999~1.001の値を取っている場合、1.099~1.101となるため、どっちにしろ1を出力してくれる。これが破綻するときは計算誤差が0.1より大きくなるような場合だが、そのようなことはあまり無いため、0.1を足すことである程度の誤差があったとしても無事に正しい値を出力してくれるのである。*6

最終的にX[k]というリストは{f(t)^2}{t^k}の係数となっているため、
{f(t)}の全ての係数について前述の通り総和を計算してくれればOKとなる。

最後に

numpyを使っている場合はpypyで提出してはいけない(戒め)

*1:{i}{1}から{N}まで動かして各処理で{b[a_i]}に1を加えればよい

*2:どれくらい高速かというと、計算量は{O(N\log{N})}である

*3:{n,m}についてのシグマの前に{k}が掛かっているため、{\sum_{n,m}\cdots}がどのような値だったとしても「0に何を掛けても0」なので全体に影響が出ない。

*4:累乗のmodPについてはpowという組み込み関数を使うと良い。pow(a,i,p)はaのi乗をpで割った余りとなっている。

*5:もし出力が{P-1}ならば、{a^i\mod{P}}{0\leq i\leq P-2}ですべて異なるということを表していて、結局{a^i=1\pmod{P}}となるような{i}{0\leq i\leq P-2}で1つしか存在しない⇒存在するならば{i=0}だけということになるためである。

*6:説明のため便宜上0.999や1.001という値で書いているが、実際の誤差がそのくらいであるとは限らない。実際はもっと誤差は小さいと思われる。