N番目の素数を求める

SNSなどで話題になっていたので調べてみたら勉強になったのでメモ。

環境

手元のMacBook Pro 13-inchの開発機で実験した。

Pythonでの実装例

例1

最も単純に「2以上p未満のすべての数で割ってみて余りが0にならなかったら素数」とする、brute force 的なアプローチ。

import cProfile
import io
import pstats
import sys


def main(n: int) -> int:
    i = 0
    for p in range(2, 1000000):
        for q in range(2, p):
            if p % q == 0:
                break
        else:
            i += 1
        if i == n:
            return p
    raise ValueError


if __name__ == "__main__":
    n = int(sys.argv[1])
    with cProfile.Profile() as pr:
        for _ in range(10):
            result = main(n)
    print(result)

    s = io.StringIO()
    ps = pstats.Stats(pr, stream=s)
    ps.print_stats("main")
    print(s.getvalue())

Python3.9.1 で実行してみると

$ python3.9 1.py 1000
7919

...

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       10    2.237    0.224    2.237    0.224 /Users/sugyan/dev/sugyan/nth-prime-benchmark/python/1.py:7(main)

1,000番目のものを求めるのにも 224 msほどかかっていて遅い。10,000番目すら求めるのが大変。

例2

求めた素数をlistに入れていき、それらで割れるかどうかだけ確認していく。いわゆる「試し割り法」(trial division)というらしい。

試し割り法 - Wikipedia

def is_prime(num: int, primes: List[int]) -> bool:
    for p in primes:
        if num % p == 0:
            return False

    primes.append(num)
    return True


def main(n: int) -> int:
    i = 0
    primes: List[int] = []
    for p in range(2, 1000000):
        if is_prime(p, primes):
            i += 1
        if i == n:
            return p
    raise ValueError

これだと10倍ほど速くなって1,000番目も 27 ms程度で出てくる。

$ python3.9 2.py 1000
7919

...

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       10    0.015    0.002    0.275    0.027 /Users/sugyan/dev/sugyan/nth-prime-benchmark/python/2.py:17(main)

10,000番目だと 2,607 msくらい。

$ python3.9 2.py 10000
104729

...

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       10    0.205    0.021   26.067    2.607 /Users/sugyan/dev/sugyan/nth-prime-benchmark/python/2.py:17(main)

例3

「エラトステネスのふるい」と呼ばれるものを疑似したもの。limit までの数に含まれる素数を篩にかけて列挙していって、N個以上あればN番目を返す、無ければlimitを倍にしていく。

def faked_eratosthenes(limit: int) -> List[int]:
    nums = [i for i in range(2, limit + 1)]
    primes = []
    while True:
        p = min(nums)
        if p > math.sqrt(limit):
            break
        primes.append(p)
        i = 0
        while i < len(nums):
            if nums[i] % p == 0:
                nums.pop(i)
                continue
            i += 1
    return primes + nums


def main(n: int) -> int:
    limit = 1000
    while True:
        primes = faked_eratosthenes(limit)
        if len(primes) > n:
            return primes[n - 1]
        limit *= 2

例2のものより多少速くなるが、それほど大きくは変わらない。

$ python3.9 3.py 10000
104729

...

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       10    0.004    0.000   19.244    1.924 /Users/sugyan/dev/sugyan/nth-prime-benchmark/python/3.py:26(main)

エラトステネスの篩

下記記事が詳しいが、前述の例は「似非エラトステネスの篩」として言及されている。

zenn.dev

正しくは値をリストから削除するのではなくフラグで管理していく、とのこと。

def list_primes(limit: int) -> List[int]:
    primes = []
    is_prime = [True] * (limit + 1)
    is_prime[0] = False
    is_prime[1] = False

    for p in range(0, limit + 1):
        if not is_prime[p]:
            continue
        primes.append(p)
        for i in range(p * p, limit + 1, p):
            is_prime[i] = False

    return primes


def main(n: int) -> int:
    limit = 1000
    while True:
        primes = list_primes(limit)
        if len(primes) > n:
            return primes[n - 1]
        limit *= 2

こうすると(1回の list_primes の中では)リストのサイズ変更がなくなり領域の再確保やコピーもなくなり、倍数を篩によって除外するのも速くなる、ということ。

$ python3.9 eratosthenes.py 10000
104729

...

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       10    0.004    0.000    0.440    0.044 /Users/sugyan/dev/sugyan/nth-prime-benchmark/python/eratosthenes.py:24(main)

$ python3.9 eratosthenes.py 100000
1299709

...

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       10    0.092    0.009    7.978    0.798 /Users/sugyan/dev/sugyan/nth-prime-benchmark/python/eratosthenes.py:24(main)

10,000番目を求めるのも 40倍ほど速くなり、 100,000番目くらいでも 798 ms 程度で求められる。

Rustでの実装例

Pythonで満足したので次はRustで書いてみる。

試し割り法

pub fn trial_division(n: usize) -> u32 {
    let mut primes = Vec::with_capacity(n);
    primes.push(2u32);
    while primes.len() <= n {
        if let Some(prime) = (primes[primes.len() - 1] + 1..).find(|&m| {
            primes
                .iter()
                .take_while(|&e| e * e <= m)
                .all(|&e| m % e != 0)
        }) {
            primes.push(prime);
        }
    }
    primes[n - 1]
}

エラトステネスの篩

pub fn eratosthenes(n: usize) -> u32 {
    fn list_primes(limit: usize) -> Vec<u32> {
        let mut primes = Vec::new();
        let mut is_prime = vec![true; limit + 1];
        is_prime[0] = false;
        is_prime[1] = false;
        for p in 0..=limit {
            if !is_prime[p] {
                continue;
            }
            primes.push(p as u32);
            for i in (p * p..=limit).step_by(p) {
                is_prime[i] = false;
            }
        }
        primes
    }

    let mut limit = 1000;
    loop {
        let primes = list_primes(limit);
        if primes.len() > n {
            return primes[n - 1];
        }
        limit *= 2;
    }
}

アトキンの篩

全然知らなかったのだけど、エラトステネスの篩よりも速いアルゴリズムとして「アトキンの篩」というものがあるらしい。

Sieve of Atkin - Wikipedia

エラトステネスの篩の計算量 O(N \log \log N) に対し こちらは O(\frac{N}{\log \log N}) になる、とのこと。

原理は正直全然わからないけど、

  • 4 * x * x + y * y == n となる nmod 60{1, 13, 17, 29, 37, 41, 49, 53} に含まれる
  • 3 * x * x + y * y == n となる nmod 60{7, 19, 31, 43} に含まれる
  • x > y かつ 3 * x * x - y * y == n となる nmod 60{11, 23, 47, 59} に含まれる

の3つの式において xy の組み合わせの数が合計で 奇数個存在している 場合に、n素数の候補とすることが出来て、それらから平方数を除いたものが素数として列挙できるようだ。 (ちょっと解釈が間違ってるかもしれない)

とりあえず効率はあまり考えずに見様見真似で実装してみた。

pub fn atkin(n: usize) -> u32 {
    fn list_primes(limit: usize) -> Vec<u32> {
        let mut primes = Vec::new();
        if limit > 2 {
            primes.push(2);
        }
        if limit > 3 {
            primes.push(3);
        }
        let mut sieve = vec![false; limit];
        for x in (1..).take_while(|&x| x * x < limit) {
            for y in (1..).take_while(|&y| y * y < limit) {
                {
                    let n = (4 * x * x) + (y * y);
                    if n <= limit && (n % 12 == 1 || n % 12 == 5) {
                        sieve[n] ^= true;
                    }
                }
                {
                    let n = (3 * x * x) + (y * y);
                    if n <= limit && n % 12 == 7 {
                        sieve[n] ^= true;
                    }
                }
                if x > y {
                    let n = (3 * x * x) - (y * y);
                    if n <= limit && n % 12 == 11 {
                        sieve[n] ^= true;
                    }
                }
            }
        }
        for r in (5..).take_while(|&r| r * r < limit) {
            if sieve[r] {
                for i in (1..).map(|i| i * r * r).take_while(|&i| i < limit) {
                    sieve[i] = false;
                }
            }
        }
        primes.extend(
            sieve
                .iter()
                .enumerate()
                .filter_map(|(i, &b)| if b { Some(i as u32) } else { None }),
        );
        primes
    }

    let mut limit = 1000;
    loop {
        let primes = list_primes(limit);
        if primes.len() >= n {
            return primes[n - 1];
        }
        limit *= 2;
    }
}

おまけ: GMP

多倍長整数の算術ライブラリとしてGMPがあり、これのRust bindingがある。

https://crates.io/crates/rust-gmp

Mpz::nextprime() というのを呼ぶと「次の素数」を求められるらしいので、N回実行すればN番目の素数が求められそうだ。

gmp::mpz::Mpz - Rust

pub fn gmp(n: usize) -> u32 {
    let mut mpz = gmp::mpz::Mpz::new();
    for _ in 0..n {
        mpz = mpz.nextprime();
    }
    mpz.to_string().parse().unwrap()
}

Benchmark

というわけでこのあたりでbenchmarkを取ってみると

$ rustup run nightly cargo bench

...

test bench_100000_atkin           ... bench:  13,437,614 ns/iter (+/- 3,614,150)
test bench_100000_eratosthenes    ... bench:  30,768,639 ns/iter (+/- 24,144,191)
test bench_10000_atkin            ... bench:     858,282 ns/iter (+/- 724,131)
test bench_10000_eratosthenes     ... bench:   1,783,792 ns/iter (+/- 269,701)
test bench_10000_gmp              ... bench:  19,331,126 ns/iter (+/- 19,085,347)
test bench_10000_trial_division   ... bench:   2,958,690 ns/iter (+/- 6,219,626)

とりあえず gmp のものは問題外に遅かった。次に遅いのが trial_division 。そして eratosthenes だと 10,000番目は 1.78 ms程度、で 100,000番目だと 30.7 ms程度。これだけでもPython版より20倍くらい速い…

そして atkineratosthenes と比較しても2倍くらい速い!すごい!!

高速化のテクニック

エラトステネスの篩よりアトキンの篩の方が速いのでそれを使おう、で終わりにしても良いかもしれないけど、エラトステネスの篩を使う場合でも色々工夫すれば速くしていける。

例を幾つか

上限個数を見積もる

篩のアルゴリズムの性質上、「limitまでの数を用意して フラグ管理していくことで素数を列挙していく」ということしか出来ず、その結果が何個になるかは最終出力を見てみないと分からない。

「N番目の素数を求めたい」というときに limitを幾つに設定して篩にかけていけばN番目までの素数を導き出せるかが不明なので、前述の実装だと limit = 1000 から始めて素数列を列挙し、N個に満たなければ limit を倍々にしていってN番目が求められるまで繰り返していっている。

10,000番目を求めるためには limit128000になるまで8回、100,000番目を求めるためにはlimit2048000になるまで12回、list_primesを呼び出して毎回同じような篩の操作をしていることになる。

この繰り返しも無駄だし、limitが無駄に大きすぎても N番目以降の数まで篩にかける操作が発生して無駄になる。

これを避けるために、N個の素数を返すギリギリのlimitの値を設定してあげたい。 ある自然数までに含まれる素数の個数を求める研究はたくさんされているようで

素数定理 - Wikipedia

素数計数関数 - Wikipedia

などで既に求められているものを使うとかなり近いものが出せそう。とりあえずは足りなくならない程度に雑に多めに見積もってやってみる。

pub fn eratosthenes_pi(n: usize) -> u32 {
    let n_ = n as f64;
    let lg = n_.ln();
    let limit = std::cmp::max(100, (n_ * lg * 1.2) as usize);

    let mut primes = Vec::new();
    let mut is_prime = vec![true; limit + 1];
    is_prime[0] = false;
    is_prime[1] = false;
    for p in 0..=limit {
        if !is_prime[p] {
            continue;
        }
        primes.push(p as u32);
        if primes.len() == n {
            return primes[n - 1];
        }
        for i in (p * p..=limit).step_by(p) {
            is_prime[i] = false;
        }
    }
    unreachable!();
}

これで従来のeratosthenesなどと比較してみると

test bench_100000_atkin           ... bench:  13,750,106 ns/iter (+/- 5,263,586)
test bench_100000_eratosthenes    ... bench:  30,559,236 ns/iter (+/- 7,994,169)
test bench_100000_eratosthenes_pi ... bench:  10,841,103 ns/iter (+/- 7,613,241)
test bench_10000_atkin            ... bench:     984,568 ns/iter (+/- 331,771)
test bench_10000_eratosthenes     ... bench:   2,210,553 ns/iter (+/- 2,621,658)
test bench_10000_eratosthenes_pi  ... bench:     907,250 ns/iter (+/- 254,367)

これだけで格段に速くなり atkin よりも高速になった。(勿論atkinでも同様の最適化すればもっと速くなるだろうけど)

Wheel factorization

そもそもlimitまでの数を調べていくのに半分は偶数で明らかに素数じゃないし、3の倍数のものも33%ほど含まれていて無駄。5の倍数だってそれなりに存在している… ということで無駄なものを最初から省いて調べていくのが良さそう。

ということを考えると、{2, 3, 5}から始めると そこから続く 7, 11, 13, 17, 19, 23, 29, 31 だけが2の倍数でも3の倍数でも5の倍数でもなく、そこから先は30ずつの周期で同様の間隔で増加させていった数値だけ見ていけば良い。 増加周期は [4, 2, 4, 2, 4, 6, 2, 6] となり、37, 41, 43, 47, 49, 53, 59, 61, ... と増えていく。もちろんこれは2, 3, 5しか見ていないので 7の倍数である49などは残る。これらをエラトステネスの篩で消していけば良い。

こういう手法を Wheel factorization と呼ぶらしい。

Wheel factorization - Wikipedia

とりあえずは単純に for p in 0..=limit で1つずつ順番に見ていたloopの部分だけ変更。

pub fn eratosthenes_wf(n: usize) -> u32 {
    let n_ = n as f64;
    let lg = n_.ln();

    let limit = std::cmp::max(100, (n_ * lg * 1.2) as usize);

    let mut primes = vec![2, 3, 5];
    let mut is_prime = vec![true; limit + 1];
    is_prime[0] = false;
    is_prime[1] = false;

    let inc = [6, 4, 2, 4, 2, 4, 6, 2];
    let mut p = 1;
    for i in 0.. {
        p += inc[i & 7];
        if p >= limit {
            break;
        }
        if !is_prime[p] {
            continue;
        }
        primes.push(p as u32);
        if primes.len() >= n {
            return primes[n - 1];
        }
        for j in (p * p..=limit).step_by(p) {
            is_prime[j] = false;
        }
    }
    unreachable!();
}

Benchmark結果は…

test bench_100000_atkin           ... bench:  25,911,095 ns/iter (+/- 46,670,614)
test bench_100000_eratosthenes    ... bench:  33,172,283 ns/iter (+/- 24,657,454)
test bench_100000_eratosthenes_pi ... bench:  11,062,096 ns/iter (+/- 4,717,035)
test bench_100000_eratosthenes_wf ... bench:   5,971,694 ns/iter (+/- 3,127,972)
test bench_10000_atkin            ... bench:     936,174 ns/iter (+/- 178,170)
test bench_10000_eratosthenes     ... bench:   1,790,384 ns/iter (+/- 711,067)
test bench_10000_eratosthenes_pi  ... bench:     797,356 ns/iter (+/- 171,738)
test bench_10000_eratosthenes_wf  ... bench:     399,302 ns/iter (+/- 48,778)

これだけでeratosthenes_piよりさらに2倍以上速くなった!単純なeratosthenesと比較するともう圧倒的ですね。

オチ

とこのように、エラトステネスの篩を使う手法にも様々な最適化の手段があり、それらを盛り込んでいると思われる primal というcrateがあります。

https://crates.io/crates/primal

これを使って StreamingSieve::nth_prime() でN番目の素数を求められる。

pub fn primal(n: usize) -> u32 {
    primal::StreamingSieve::nth_prime(n) as u32
}

中では 入力値の範囲によってより近似された p(n) (n番目の素数を含む下限/上限値)を見積もっていたり、ビット演算を駆使して高速に篩をかけるようにしているようだ。

Benchmark結果は…

test bench_100000_atkin           ... bench:  25,911,095 ns/iter (+/- 46,670,614)
test bench_100000_eratosthenes    ... bench:  33,172,283 ns/iter (+/- 24,657,454)
test bench_100000_eratosthenes_pi ... bench:  11,062,096 ns/iter (+/- 4,717,035)
test bench_100000_eratosthenes_wf ... bench:   5,971,694 ns/iter (+/- 3,127,972)
test bench_100000_primal          ... bench:     134,676 ns/iter (+/- 11,903)
test bench_10000_atkin            ... bench:     936,174 ns/iter (+/- 178,170)
test bench_10000_eratosthenes     ... bench:   1,790,384 ns/iter (+/- 711,067)
test bench_10000_eratosthenes_pi  ... bench:     797,356 ns/iter (+/- 171,738)
test bench_10000_eratosthenes_wf  ... bench:     399,302 ns/iter (+/- 48,778)
test bench_10000_primal           ... bench:       9,083 ns/iter (+/- 3,915)

はい、さらに数十倍速くなっていて完全に大勝利です。N番目の素数を求めたければこれを使いましょう。

Repository

ツッコミあればご指摘いただけると助かります。

References