Rustでつくる詰将棋Solver その後

memo.sugyan.com

の続き。1ヶ月ほど経ってちょこちょこ更新して進化した。

残課題が幾つかあったが、そのうち幾つかは解決した。

探索無限ループ

f:id:sugyan:20211111001858p:plain

上図のような問題で ▲4四飛成 △3二玉 ▲4一竜 △3三玉 ▲4四竜 △3二玉 ... を無限に探索してしまっていた問題。

結局これはプログラム中のhashに格納する値を間違えていたことに起因していたようで、そこを直すことで解決した。 また合わせて最適解の選択ロジックを修正していくことでその他の問題で起きていた不具合も解消することが出来た。

Issue: 連続王手の繰り返し · Issue #2 · sugyan/tsumeshogi-solver · GitHub

後手から始まる問題、kifファイル入力

やねうら王公式からクリスマスプレゼントに詰将棋500万問を謹呈 | やねうら王 公式サイト の問題を試してみようと思ったら後手番から先手玉を詰ませる問題が半数含まれていて対応が必要だった。 これは主に合法手生成の部分を修正することで解決した。

1行ごとにsfen文字列を読んで 解を出力していく、といったことも出来るようにした。

500万問すべては試せていないが、各ファイル先頭200問ずつの計1000問くらいは最短ではないが詰みを見つけることは出来た。 まだ解けないものもあるようなので一つ一つ調査して原因を潰していきたいところ…

また、kif形式のファイルも読み込んでparseし初期盤面を問題として解かせることが出来るように簡単なparserを書いた。 あまりに自由な形式なので完全には対応できないが…

日本将棋連盟まいにち詰将棋 で使われているkifファイルくらいならだいたい正しくparseできるようにはなっていると思う。

f:id:sugyan:20211210152939p:plain

tsumeshogi-solver/kif_converter.rs at 0.2.0 · sugyan/tsumeshogi-solver · GitHub

証明数二重カウント問題

前述の500万問のものを試していて見つかったもの。例えば以下のような問題

f:id:sugyan:20211210001433p:plain

解としては例えば △7九角成 ▲同玉 △5七角 ▲6九玉 △6八銀 ▲5八玉 △6九銀打 の7手詰になる。

これを自作solverでdfpn探索していると △7九角成 ▲5九玉 △6八角 の手順をまず探索していって、これに対しては ▲4八玉▲5八玉 があり、どちらの場合も △5七角成 ▲5九玉 で同じ局面になる。次は △6八馬引 で詰むのだけど、それを探索する前に △6八馬上 ▲4八玉 △5七馬 という筋を探索しようとする。その優先度を決める際に4手目からの ▲4八玉 △5七角成 ▲5九玉 の経路と ▲5八玉 △5七角成 ▲5九玉 の経路の両方から証明数を足し合わせることになり、両者のノードの証明数が交互に過剰に加算されて更新されていってしまい、いつまでも探索が進まない という現象が起きていた。

これに対する解決策として、元の論文では

我々は,各局面に親局面へのポインタを持たせ,場合によっては親局面をたどることで DAG を検出し, 証明数の二重カウントを回避した.

-- df-pnアルゴリズムの詰将棋を解くプログラムへの応用

といったad-hocな回避方法が本文中に書かれていた。が、これは該当する実装が末尾の疑似コードには記載されておらず 疑似コードを極力忠実に移植した今回の自作solverではそういったことをしていなかった。

1手詰の検出

書かれている通りに親局面へのポインタを持たせるように変更してみるか、、とも思ったが、そもそも上記の問題はまず6手目 ▲5九玉 になった時点で △6八馬 の1手詰を検出できていればその先のループにハマることないよね?と思い 各ORノード(攻方手番)で1手詰めの判定を挟んでみることにした。

     // ノード n の展開
     fn mid(&mut self, hash: P::T, phi: U, delta: U, node: Node) -> (U, U) {

         ...

         // 2. 合法手の生成
         let children = self.generate_legal_moves(node);
+        // 攻方の場合のみ 1手詰を探索する
+        if node == Node::Or {
+            let mut mate = false;
+            for &(m, h) in &children {
+                self.make_move(m).expect("failed to make move");
+                let len = self.generate_legal_moves(Node::And).len();
+                self.unmake_move().expect("failed to unmake move");
+                if len == 0 {
+                    self.put_in_hash(h, (INF, 0));
+                    mate = true;
+                } else {
+                    self.put_in_hash(h, (1, len as U));
+                }
+            }
+            if mate {
+                self.put_in_hash(hash, (0, INF));
+                return (0, INF);
+            }
+        }
         if children.is_empty() {
             ...
         }
         // 3. ハッシュによるサイクル回避
         self.put_in_hash(hash, (delta, phi));
         // 4. 多重反復深化
         loop {
            ...
         }
    }

元の実装では、各局面で合法手を生成した後にループしながら証明数・反証数の更新していくので 合法手の生成時点ではそれらの子ノードが詰んでいるか否かはループ内で選択されて探索されない限り判明しない。 しかしそうすると前述のように「既に詰んでいる子ノードが見えているのに 他の子ノードを無駄に探索してしまう」ということが起こる。

ので、攻方手番の場合の子ノード展開時のみ、それらの子ノードが既に詰んでいるかどうか(子ノードからさらに次の子ノードが生成できるかどうか。できないということは詰んでいる、ということになる)の判定を入れてみた。子ノードから生成される合法手数はそのまま証明数になるので、たとえ詰みが見つからなくてもその後のループでの子ノード選択時にも優先して詰みやすいノードを選択できるようになる効果もある、はず。

ただ勿論これはORノードを探索するたびに子ノードのさらに子ノードまで生成を試みることになるので 非効率にもなり得る。そのぶん早く詰みを見つけて探索を打ち切ることができる可能性もあるし、単純に余計な合法手生成とハッシュテーブルの記憶領域を消費するだけになる可能性もある。どうなるかは「問題による」。 ただ少なくとも前述のような問題はこれによって解けるようにはなるのでアリかな、と思っているが どうなんだろう…。逆にこれによって解けなくなる問題もあるっぽいので ちゃんとDAG検出するとか他の方法をとるべきかもしれない。。

Issue: Cannot solve · Issue #4 · sugyan/tsumeshogi-solver · GitHub

初手に戻るループ

前述のものに近いが、ちょっと違うパターン。下図のような問題が解けなかった。

f:id:sugyan:20211210144912p:plain

正解は ▲1八歩 △同玉 ▲1九歩 △同玉 ▲3九飛 △1八玉 ▲1九香 の7手詰で、5手目の ▲3九飛 に対して持駒の歩も香も合駒として打てない(前に進めない場所には打てない!)ので玉は逃げるしかない、というのが面白い。

それはともかく、探索していると ▲1九香 △1八香 ▲1六飛 △同玉 ▲1八香 △1七飛 ▲同香 △同玉 という手順でまったく同じ局面に戻ってくる。このとき この局面の証明数・反証数がともに ∞-1 という値でハッシュテーブルに保存されているので、ループ内での証明数の最小値と反証数の総和がともに ∞-1 になる。これによっておかしな結果になってしまっているようだったので、反証数の和が ∞-1 以上だった場合は証明数の和を強制的に 0 にすることで解決した。

Issue: Cannot solve · Issue #7 · sugyan/tsumeshogi-solver · GitHub

打ち歩詰めの誤判定

下図のような問題が不詰として判定されてしまっていた。

f:id:sugyan:20211210150309p:plain

正解は ▲2二歩 △1一玉 ▲2三桂 の3手詰なのだけど、この初手 ▲2二歩 が合法手として生成されていなかった。原因を探ってみると、使用していた shogi-rs ライブラリの中で打ち歩詰めの判定に問題があって この ▲2二歩 に対して玉が1一に逃げられないと判断してしまって打ち歩詰めのためこの歩は打てない、という結果を返してしまっていることが分かった。

修正のコードを提案して取り込んでもらって無事に解決した。

Pull Request: Fix Uchifuzume check by sugyan · Pull Request #41 · nozaq/shogi-rs · GitHub

残課題

とりあえずここまでの変更・修正でバージョン 0.2 としてTagを切っておいた。

GitHub - sugyan/tsumeshogi-solver at 0.2.0

試しに日本将棋連盟まいにち詰将棋 の2020年の問題366問を解かせてみたところ、かなり時間かかるものはあったが一応すべてに対して最短の正解ではないにしろ詰みを見つけることは出来ていた。

未着手の課題としては以下。

  • まだ解けていない問題が幾つか発見されているので、可能な限り修正していきたいところ
  • sfen形式ではない、棋譜表記の文字列で回答を出力したい
  • 最短最善解を導出できるように、は勿論していきたい
    • まずは深さ上限を指定しての探索から実装することになる、か
  • パフォーマンスの問題はまだあるので可能であれば高速化は試みたい
  • 長手数の問題はどのレベルなら解けるのか、または残り何手の状態なら現実的な時間で解けるのか、など知っておきたい
  • WASM buildしてWebアプリ上で動かしたい

Rustでつくる詰将棋Solver

f:id:sugyan:20211110195955p:plain

というわけで突然Rustで詰将棋ソルバを作りたくなり、作った。

github.com

現時点ではまだ完成度は低くて6割ほどかな…。

とはいえそこらの素直な詰将棋問題なら普通に解けると思う。 冒頭の画像は看寿賞作品の3手詰「新たなる殺意」を2秒弱で解いたもの。

先行事例

将棋プログラムの多くはC++で書かれていて 最近はRustも増えてきているのかな? しかし「詰将棋を解く」ことに特化しているものはあまり多くはなさそうだった。

なかでもRustで書かれているものはna2hiroさんによるものくらいしか無さそうで、

github.com

これを大いに参考にさせていただいた。

しかしこれはおそらく後述の盤面ハッシュ値を生成するために手元で変更を加えているものを使っているようで、手元ではビルドも出来ない。 また個人的にはsfenではなくCSAや他の形式を入出力に使いたいなどの要望もあり、そういった問題も解決しつつ自分の理解も深めるために自作してみることにした。

df-pnアルゴリズム

過去に、Goでも詰将棋ソルバを書こうとしてdf-pnアルゴリズムを使った探索プログラムを実装していた。今回はこれのリメイクということになる。

memo.sugyan.com

df-pnについては上記記事でも触れているが 他にも詳しい記事があるので貼っておきます

komorinfo.com

qhapaq.hatenablog.com

Rustにおける将棋ライブラリ

盤面の状態保持や合法手生成のために、何らかのライブラリが必要になる。

Goで作ったときは愚直にif文for文まみれで自作して使っていて、当然ながら速度も出なかった。 RustではBitboardベースの優れたライブラリが既に作られ公開されているので、これを使うことにした。

github.com

実装に必要なもの

さて、df-pnアルゴリズムは簡単にいうと「root node (初期盤面) から合法手で辿れる child nodes についてそれぞれ(証明数, 反証数) のペアを計算し保存し、それらの値を元に最良な(詰/不詰を判定しやすそうな) child node を優先的に選択して反復深化し探索していく」というもの。

玉方と攻方でそれぞれ AND/OR Tree を形成して詰か不詰を判定することにはなるのだけど、実装においては木構造を作る必要はない

盤面は各合法手を選んだ際に状態遷移して探索から戻ってきたら盤面も戻すことが出来れば良いだけで、重要なのは各盤面の状態における(証明数, 反証数)を保存・取得する方法さえあれば良い、ということ。

一般的には盤面の状態を何らかの方法でハッシュした値を計算し、そのハッシュ値によってテーブルに格納する、ということになる。

よって、

  • 盤面の状態を合法手で進めたり戻したりしつつ 各盤面の状態における一意な値を得ることが出来る
  • 各状態から得られる値を元に (証明数, 反証数)のペアを保存・取得することが出来る

というものがあれば良い。

実装

ということでここからはRustでの実装の話。

Trait

必要なもの、と前述した通りに2つの Trait を用意することにした。 証明数・反証数に関してはとりあえず u32 にしているが 一応変えやすいようにtype aliasで U としている。

use shogi::{Bitboard, Color, Move, MoveError, Piece, PieceType, Square};
use std::hash::Hash;

type U = u32;
pub const INF: U = U::MAX;

pub trait HashPosition {
    type T: Eq + Hash + Copy;
    fn hand(&self, p: Piece) -> u8;
    fn in_check(&self, color: Color) -> bool;
    fn make_move(&mut self, m: Move) -> Result<(), MoveError>;
    fn move_candidates(&self, sq: Square, p: Piece) -> Bitboard;
    fn piece_at(&self, sq: Square) -> &Option<Piece>;
    fn player_bb(&self, c: Color) -> &Bitboard;
    fn side_to_move(&self) -> Color;
    fn unmake_move(&mut self) -> Result<(), MoveError>;

    fn current_hash(&self) -> Self::T;
}

pub trait Table {
    type T;
    fn look_up_hash(&self, key: &Self::T) -> (U, U);
    fn put_in_hash(&mut self, key: Self::T, value: (U, U));

    fn len(&self) -> usize;
    fn is_empty(&self) -> bool;
}

HashPosition traitは基本的に shogi-rsshogi::Position をwrapして使うことを想定して shogi::Position で使うメソッドをほぼそのまま羅列。そこに current_hash(&self) を追加して盤面の状態からハッシュ値を取得できるようにした。

Table はまぁ値を格納したり取り出したりするだけ。途中で状態確認したりtestで使ったりすることもあるかと思い len(&self)is_empty(&self) も一応用意しておいた。

それぞれ Associated type を持つようにしていて、繋ぎこむときに利用している。

Implementations

最初から struct にせず Trait にしたのは、複数の実装方法が考えられたから。

まず Position のハッシュ方法としては、最も簡単に考えられるのが 盤面の状態を表すそれぞれの値をもとに std::hash::Hasher traitを実装したものを使って u64 の値を作る、というもの。盤面各マスの各駒や手番などを何らかの基準で数値化してしまえば計算することができる。

その他に挙げられるのが、Zobrist hashing を使う方法。将棋やチェスなどでは手番ごとの遷移はせいぜい駒が一つ動くだけなので その部分の差分だけ更新して新しいハッシュ値を作る方が効率が良い、という考えのもの。各マスにおける各駒に乱数で値を割り振っておいて、駒が移動した際には移動元と移動先の変化を XOR をとることで高速に計算できる。

Table については 最も簡単なのが std::collections::HashMap に格納する方法。これは文句なしに簡単。

より高速に処理できそうな方法として Vec に格納してしまう、という方法。 Zobrist hashing で usize の値を返すようにして その値をmaskしたものをindexとして利用するという方法が考えられる。

ということで、今回はそれぞれ2つの実装を用意してみた。

  • HashPosition の実装として DefaultHashPositionZobristHashPosition
  • Table の実装として HashMapTableVecTable

DefaultHashPositionPosition をwrapしつつ std::hash::Hash Traitを実装し、std::collections::hash_map::DefaultHasher を使ってハッシュ値を計算している。

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

pub struct DefaultHashPosition {
    pos: Position,
}

...

impl HashPosition for DefaultHashPosition {
    type T = u64;

...

    fn current_hash(&self) -> u64 {
        let mut s = DefaultHasher::new();
        self.hash(&mut s);
        s.finish()
    }
}

impl Hash for DefaultHashPosition {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Square::iter().for_each(|sq| {
            self.pos.piece_at(sq).map_or(28, p8).hash(state);
        });
        PieceType::iter().for_each(|piece_type| {
            Color::iter().for_each(|color| self.pos.hand(Piece { piece_type, color }).hash(state))
        });
        match self.pos.side_to_move() {
            Color::Black => 0.hash(state),
            Color::White => 1.hash(state),
        };
    }
}

fn p8(p: Piece) -> u8 {
    ...
}

一方、 ZobristHashPositionBitXorAssign ができ Standard: Distribution<T> の定義できるいくつかの数値型を使えるようGeneric Data Typesで定義している。 make_move() が成功したときのみ 最新のhashを元にMoveに関わる部分だけXORを取って新しいhashを計算する。

use rand::distributions::{Distribution, Standard};
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

pub struct ZobristHashPosition<T> {
    pos: Position,
    table_board: [[[T; 2]; 14]; 81],
    table_hand: [[[T; 19]; 2]; 14],
    table_turn: [T; 2],
    hash_history: Vec<T>,
}

impl<T> ZobristHashPosition<T>
where
    T: Default + Copy + BitXorAssign,
    Standard: Distribution<T>,
{
    pub fn new(pos: Position) -> Self {
        // init table
        ...
        // 初期hash値だけは全マス全持駒から計算する
        let hash = ...
        Self {
            pos,
            table_board,
            table_hand,
            table_turn,
            hash_history: vec![hash],
        }
    }
}

impl<V> HashPosition for ZobristHashPosition<V>
where
    V: Copy + Eq + Hash + BitXorAssign,
{
    type T = V;

    ...

    fn make_move(&mut self, m: Move) -> Result<(), MoveError> {
        match self.pos.make_move(m) {
            Ok(_) => {
                let mut h = self.current_hash();

                h ^= ...  // mに対応する部分だけXORとる

                self.hash_history.push(h);
                Ok(())
            }
            Err(e) => Err(e),
        }
    }
    fn unmake_move(&mut self) -> Result<(), MoveError> {
        match self.pos.unmake_move() {
            Ok(_) => {
                self.hash_history.pop();
                Ok(())
            }
            Err(e) => Err(e),
        }
    }

    fn current_hash(&self) -> V {
        *self.hash_history.last().expect("latest hash has not found")
    }
}

HashMapTableVecTable は以下のような感じ。 HashMapTableEq + Hash の任意の型をキーにでき、 VecTableusize のみ受け付けて key値にmaskをかけた値で Vec の内容にアクセスする。 占有率が上がってきたら古いエントリを捨てるなどの処理も入れるべきかもしれないが、ここでは特に何もしていない。

use std::collections::HashMap;

pub struct HashMapTable<T>
where
    T: Eq + Hash,
{
    table: HashMap<T, (U, U)>,
}

impl<T> HashMapTable<T>
where
    T: Eq + Hash,
{
    pub fn new() -> Self {
        Self {
            table: HashMap::new(),
        }
    }
}


pub struct VecTable {
    table: Vec<(U, U)>,
    mask: usize,
    len: usize,
}

impl VecTable {
    pub fn new(bits: u32) -> Self {
        Self {
            table: vec![(1, 1); 1 << bits],
            mask: (1 << bits) - 1,
            len: 0,
        }
    }
}

impl Table for VecTable {
    type T = usize;
    fn look_up_hash(&self, key: &Self::T) -> (U, U) {
        self.table[*key & self.mask]
    }
    fn put_in_hash(&mut self, key: Self::T, value: (U, U)) {
        if self.look_up_hash(&key) == (1, 1) {
            self.len += 1;
        }
        self.table[key & self.mask] = value;
    }
    fn len(&self) -> usize {
        self.len
    }
    fn is_empty(&self) -> bool {
        self.len == 0
    }
}

Solver

で、こうして用意した HashPositionTable を使って Solver を作る。

pub struct Solver<P, T> {
    pub pos: P,
    pub table: T,
}

impl<P, T> Solver<P, T>
where
    P: HashPosition,
    T: Table<T = P::T>,
{
    pub fn new(pos: P, table: T) -> Self {
        Self { pos, table }
    }
    pub fn dfpn(&mut self) {
        ...
    }
}

HashPositionTable のそれぞれの Associated type T が一致することが求められるので、以下のような組み合わせで作ることが出来る

    let p = DefaultHashPosition::new(pos);
    let t = HashMapTable::new();
    let mut solver = Solver::new(p, t);

DefaultHashPositionT = u64 なので自動的に u64HashMap のkeyに使われる

    let p = ZobristHashPosition::<u64>::new(pos);
    let t = HashMapTable::new();
    let mut solver = Solver::new(p, t);

ZobristHashPosition で指定可能な任意の型が使える

    let p = ZobristHashPosition::new(pos);
    let t = VecTable::new(20);
    let mut solver = Solver::new(p, t);

VecTableT = usize なので自動的に usizeZobristHashPosition になる

それぞれ速度に影響でたり hash計算方法や格納方法によって衝突が起きる可能性が違ってきたりするはず。 解かせたい問題によって実装を切り替えるなどもアリなのかもしれない。

df-pn

で、この Solver がdf-pnアルゴリズムで探索していく。 このコア部分だけだと100行ちょっとしか必要ない。

impl<P, T> Solver<P, T>
where
    P: HashPosition,
    T: Table<T = P::T>,
{
    ...

    pub fn dfpn(&mut self) {
        let hash = self.pos.current_hash();
        // ルートでの反復深化
        let (pn, dn) = self.mid(hash, &(INF - 1, INF - 1));
        if pn != INF && dn != INF {
            self.mid(hash, &(INF, INF));
        }
    }
    // ノードの展開
    fn mid(&mut self, hash: P::T, pd: &(U, U)) -> (U, U) {
        // 1. ハッシュを引く
        ...
        // 2. 合法手の生成
        let children = generate_legal_moves(&mut self.pos);
        // 3. ハッシュによるサイクル回避
        match self.pos.side_to_move() {
            Color::Black => self.table.put_in_hash(hash, (pd.0, pd.1)),
            Color::White => self.table.put_in_hash(hash, (pd.1, pd.0)),
        }
        // 4. 多重反復深化
        loop {
            ...

            let (best, ...) = self.select_child(&children);
            let phi_n_c = ...
            let delta_n_c = ...
            let (m, h) = best.expect("best move");
            self.pos.make_move(m).expect("failed to make move");
            match self.pos.side_to_move() {
                Color::Black => self.mid(h, &(phi_n_c, delta_n_c)),
                Color::White => self.mid(h, &(delta_n_c, phi_n_c)),
            };
            self.pos.unmake_move().expect("failed to unmake move");
        }
    }
    // 子ノードの選択
    fn select_child(&mut self, children: &[(Move, P::T)]) -> (Option<(Move, P::T)>, ...) {
        ...
    }
}


pub fn generate_legal_moves<P>(pos: &mut P) -> Vec<(Move, P::T)>
where
    P: HashPosition,
{
    ...
}

合法手の生成によって child nodes を展開していくことになるが、その時点でその遷移後のハッシュ値が計算できるので generate_legal_movesMove とそれが適用された後のハッシュ値のペアを返すようにして使い回すようにしている。

generate legal moves

ここは HashPosition さえあれば計算できるので Table は必要ない。

通常の Move::Normal については単純で、 player_bb() で盤上の自陣駒が列挙できて そこから move_candidates() で移動先の候補を列挙できるので、そこからのすべての組み合わせで一旦 make_move() を実行してみる。その結果、

  • 攻方の場合は 動いた後が玉方に対する王手になっている
  • 玉方の場合は 動いた後が自陣の王手が外れている

という条件を満たしていれば legal move として選択されることになる。

pub fn generate_legal_moves<P>(pos: &mut P) -> Vec<(Move, P::T)>
where
    P: HashPosition,
{
    let mut children = Vec::new();
    // normal moves
    for from in *pos.player_bb(pos.side_to_move()) {
        if let Some(p) = *pos.piece_at(from) {
            for to in pos.move_candidates(from, p) {
                for promote in [true, false] {
                    let m = Move::Normal { from, to, promote };
                    if let Ok(h) = try_legal_move(pos, m) {
                        children.push((m, h));
                    }
                }
            }
        }
    }
    
    ...

    children
}

fn try_legal_move<P>(pos: &mut P, m: Move) -> Result<P::T, MoveError>
where
    P: HashPosition,
{
    match pos.make_move(m) {
        Ok(_) => {
            let mut hash = None;
            if pos.side_to_move() == Color::Black || pos.in_check(Color::White) {
                hash = Some(pos.current_hash());
            }
            pos.unmake_move().expect("failed to unmake move");
            if let Some(h) = hash {
                Ok(h)
            } else {
                Err(MoveError::Inconsistent("Not legal move for tsumeshogi"))
            }
        }
        Err(e) => Err(e),
    }
}

Move::Drop についてはちょっと面倒で、盤上すべて眺めて空いているマスすべてに対して持駒に持っているものを置いてみる、というのが簡単に列挙はできるのだけど、例えば攻方は相手玉に届かない位置に打つのはすべて無意味だし 玉方は飛び駒以外で王手されている際に駒を打つ手は有り得ないので、そういったものは事前に除外しておかないとかなり非効率な探索になってしまう。

ともかく shogi::Position のinterface的に、合法手の判定は「何らかの Movemake_move() してみて それが Err を返さなかったら合法」と判定するのが基本になる。明らかに非合法な手は事前に除外した方が勿論良いが、処理の先頭で判定しているものもあるので二度手間にならない程度であれば良いようには思える。

benchmark

とりあえずここまででdf-pnアルゴリズムによる探索が動くので、実装による速度の違いを調べてみた。 nightly で使える cargo bench

  • DefaultHashPosition + HashMapTable
  • ZobristHashPosition + HashMapTable
  • ZobristHashPosition + VecTable

のそれぞれの組み合わせで Solver を作り 幾つかの入力に対して探索を実行。

% cargo +nightly bench
    Finished bench [optimized] target(s) in 0.13s
     Running unittests (target/release/deps/tsumeshogi_solver-9e8e57cabb9c5863)

running 1 test
test tests::test_solve ... ignored

test result: ok. 0 passed; 0 failed; 1 ignored; 0 measured; 0 filtered out; finished in 0.00s

     Running unittests (target/release/deps/bench-2ca32c22a0dba4e1)

running 3 tests
test bench_default_hashmap ... bench:  11,456,869 ns/iter (+/- 681,707)
test bench_zobrist_hashmap ... bench:  11,205,930 ns/iter (+/- 674,396)
test bench_zobrist_vec     ... bench:  11,280,294 ns/iter (+/- 749,866)

test result: ok. 0 passed; 0 failed; 0 ignored; 3 measured; 0 filtered out; finished in 10.41s

どれも大して変わらなかった!!

おそらく現状 generate_legal_moves のところが遥かに支配的で、ここが大きなボトルネックになっている限りはハッシュの計算や格納方法での差異は殆ど出てこない、のかな…。

もうちょっとprofilingとか試してしっかり見直したいところ。

最適解の導出

ところで solver.dfpn() の探索処理が完了したところで、詰将棋の「解答」が出てくるわけではない。 手元に残るのは root から探索した各盤面に対する (証明数(以下pn), 反証数(以下dn)) が記録された Table があるだけなので、ここからそれらの値を元に詰みの経路を見つける必要がある。

root の (pn, dn)(0, INF) になっていた場合に、その root が「詰み」であると判明したことになる。その child nodes には今度は (INF, 0) になっているものが存在しているはず。その child ではまた (0, INF) 、… と辿るたびに AND/OR が反転するが ともかく正しく詰みを見つけられた場合はそういった node を見つけられるので、葉(それ以降は合法手が無い、つまり詰み)までDFSで辿っていけば詰みの手順を列挙することが出来る。

fn solve(pos: Position) {
    let mut solver = Solver::new(...);
    solver.dfpn();

    let mut answers = Vec::new();
    search_all_mates(&mut solver, &mut Vec::new(), &mut answers);
}

fn search_all_mates<P, T>(
    s: &mut Solver<P, T>,
    moves: &mut Vec<Move>,
    answers: &mut Vec<Vec<Move>>,
) where
    P: HashPosition,
    T: Table<T = P::T>,
{
    let mut leaf = true;
    for &(m, h) in &generate_legal_moves(&mut s.pos) {
        let pd = s.table.look_up_hash(&h);
        if (s.pos.side_to_move() == Color::Black && pd == (INF, 0))
            || (s.pos.side_to_move() == Color::White && pd == (0, INF))
        {
            leaf = false;
            moves.push(m);
            s.pos.make_move(m).expect("failed to make move");
            search_all_mates(s, moves, answers);
            s.pos.unmake_move().expect("failed to unmake move");
            moves.pop();
        }
    }
    if leaf {
        answers.push(moves.clone())
    }
}

これによって得られる answersdfpn() の探索で見つけることができた詰み手順だが、最短であるとは限らない。このへんは難しいところで 余詰がある場合はまた他の手順で詰むかどうかも判定する必要がある…。ここはまだ実装できていない。

ただ「攻方のある手で詰む」というのは見えているので、その中での玉方の最善応手を見つけることは出来そう。玉方は最長で駒を余らせない応手を選ぶのが正解なので、列挙された answers の中から「詰みまでの手数が最長で かつ詰んだ後の攻方の持駒が少ない(無い)もの」を解として選ぶことが出来る。

また面倒なのが「無駄合駒」で、単純に手数の長いものを選択しようとするとこれに引っ掛かる。今回は

  1. 玉方が合駒として打った駒が後に取られて
  2. 最終的に詰んだときにそれが攻方の持駒に入っている

という場合に無駄合駒だったとみなして候補から外すようにした。合駒をせずに別の方法で王手回避した場合は必ず探索しているはずなので、単純に除外するだけで良いはず…。

これで、余詰の無い問題に対しては「玉方が最善の応手をして駒の余らない最長の手順」を選択できるようになっている。と思う。

残課題

と、ここまでは上手く解けた例の話で、まだまだ上手く解けていないケースが存在しているのを把握している。 例えば以下のような問題で、

f:id:sugyan:20211111001858p:plain

(出典: https://www.shogi.or.jp/tsume_shogi/everyday/20211183_1.html)

正解手順は ▲3四飛 △同玉 ▲4四飛成 なのだけど、dfpn() で探索していると ▲4四飛成 △3二玉 ▲4一竜 △3三玉 ▲4四竜 △3二玉 ... と無限に追いかけっこが発生して詰まないループを検出し、結果として root の (pn, dn)(1, INF) という値になる。正解の詰み手順である ▲3四飛 は探索されずに終わってしまている。

元の論文では GHI (Graph History Interaction) 問題を回避するために閾値を調整して二回探索する手法で解決しているように書いてあるが、その通りに実装しているつもりなのだけど上手くいっていないようだ…。 ここはもうちょっと深く追って調べてみる必要がある。

その他にも 上記のような連続王手の千日手になる手順に辿り着く可能性がある問題に対して うまく探索が完了していないケースがありそうだ。

やねうら王公式からクリスマスプレゼントに詰将棋500万問を謹呈 の問題を食わせて様々なエッジケースを拾えるかな、と思ったが まずそもそも先手から攻める実装しか考えていなくて「後手番から始まって先手玉を詰ませる」というケースを考えていなかった。そこも対応できるように直す必要がある。

あとは前述した通り余詰のチェックもまだ出来ていないので、探索深さ上限を指定して(?)別の手順を探索できるようにはしておきたい。

優越関係などを利用することでもっと長手数のものも解けるようになるだろうか。性能限界と改善の可能性はもう少し調べたい。 shogi-rs を変更してチューニングしてもっと早くなるだろうか?とか。

その他にも入出力の形式をもっと選べるようにconverterを作っておきたいと思っている。

creative codingに入門してみている

動機・目的

上記の通り、「子に喜んでもらえるものを作る」ことを目指す。 特にshort codingなどにチャレンジしたりはせず、またgenerative artのようなものよりはどちらかというと"触って遊べる"インタラクティヴなものを優先的に。

1歳半の我が子は最近はiPadの操作も慣れてきていて、幼児向けアプリなどで画面を指で触って楽しむことは出来るようになっている。複雑な操作はまだ難しいが 直感的に動かせるようなものなら楽しんでくれるのでは(という勝手な期待)。

…とはいえ 単純に自分が面白そうと思ったからやるだけなのだけど。@nagayamaさんなど 周りでやっている人もいて見ているととても楽しそうで興味を持ったので。

環境

iPadで簡単に動かせるのが良い、ということで p5.js を使用してブラウザ上で動くものを作る。

最近Rust触っている身としては Nannou も気になるところだったけど、どうもブラウザ上での動作はまだ対応中?で厳しそうな雰囲気があったので今はまだ手を出さずにいる。

作品

1. おえかき

f:id:sugyan:20211008215333g:plain

まずは単純に、指でなぞった軌道に色がつく、くらいのものをやってみた。 その色がカラフルに変化すればまぁ面白いかな?と時間経過でHSB空間上でのHue値だけを変化させるようにした。

リセットの必要なく幾らでも描き続けられるように、時間経過とともに古い円から消えるようにすることを考えた。透過させていったり小さくしていったり試していたけど、randomで座標を揺らしながら小さくしていったら何か面白い動きになった。 そこで @amagitakayosi さんから「noiseを使うといいですよ」とアドバイスをいただいた。noiseでユラユラと揺らしながらだんだん小さくなりつつ揺れ幅を増やしていくと 亡霊のような不思議な消え方をするようになったのでこれを採用することにした。

音も出た方が良いよね、ということで p5.soundOscillator を使って三角波の音を発生させ、早く動かして円が増えるにつれて音量と周波数が増えるようにしてみたところ より亡霊っぽくなった。

2. シャボン玉

f:id:sugyan:20211008220348g:plain

noise の性質が面白かったのでもっと色んなものを揺らしてみよう、と思い シャボン玉を作ってみた。 ellipse で縦横のサイズを指定した楕円が描けるので、その各方向のサイズを noise で変化させながら さらに noise の載った回転角で rotate してやることで 泡のようなユラユラした球体の動きを表現できた。

風に吹かれて飛んでいく動きを表現するために左下方向からランダムな向きで等速直線運動、わずかに上方向には等加速度で移動する軌道をつくり、そこにさらにxy方向それぞれに noiseで微妙に揺らす。フワーっと飛んでいるような動きになった気がする。

シャボン玉っぽさを表現するのは難しかった。干渉による虹色のような光り方はどうにもムリそうだったので、個々に色をつけてそのhue値を時間経過で変化させるくらいに留めた。が 意外にそれはそれで可愛くてポップな感じになったし悪くなかった気がする。ただ赤は鮮やかすぎて気持ち悪かったのであまり赤くはならないよう範囲を調整した。 球体っぽさを出すために内側にいくにつれてalpha値が強くなるよう段階を変えながら複数の楕円を描画し、あとは何となく左上の方に白い楕円を配置するだけでわりとそれっぽい感じになったと思う。

当然タッチしたら割れるようにはしたい、と思い とりあえず適当に当たり判定つけて簡単なエフェクトはつけてみたがなんか微妙だ。ここも表現方法を工夫したいところ… 音も必要か…?

実装

というわけで、こうして作ってみた p5.js のsketchはそれ単体としても取っておきたいし、Webアプリ上でそれぞれを見られるようにしたい、と思い それらを閲覧できるような置き場を Next.js で作ってみた。

https://cc.sugyan.com/

今回、Next.jsに初めて触った。create-react-appでも良いかと思ったけどroutingを自分で書きたくなかったというのがあり Nextのdynamic routingで1個だけファイル置いておけば良いのはラクで良かったと思う。利点はそれだけかもしれないけど。 とはいえ Vercel はpushしたものがすぐに反映されて開発体験は良かった。

p5.js のsketchをReactで扱うための react-p5-wrapper というComponent libraryはあったのだけど、どうも画面遷移した際にmouse eventなどが残ってしまっていたり うまく用途と合わなかったので結局wrapper部分は最低限の自作で間に合わせた。それでもsound周りでおかしな挙動が残っていたりしてまだ完全には解決していない…。

p5.js のsketchはNextでSSRできないので Dynamic Import を利用する必要がある、というところは書き方がなかなか理解できずだいぶ苦戦した。。

あと当然ながらすべてTypeScriptで書いているのだけど、どうも @types/p5 が少し古くて最新の p5.js に追従できていない部分があったり それを更新するためのツール? p5.ts がメンテ止まってしまっていたり また p5.sound まわりはTypeScriptで上手くimportできなかったり 、など問題は色々あるように感じた。余裕があればこのへんが改善されるようcontributeしていきたいが、、、

展望

まだまだよく知らない部分も多いので、他の人の作品なども見つつ 自分なりに楽しいものが作れるようにしていきたい。 子にはまだ難しいだろうけど各種アルゴリズムの可視化とかは個人的にやってみたいところ。

しかしcreative codingはわりと沼で、ちょっと表現を変えようとパラメータいじったりするだけで数時間もってかれたりするのでハマりすぎないようには注意したいところ…。

Repository

関連

ISUCON11予選のNode.js実装を書いた

ISUCON11 予選おつかれさまでした。

ここ数年は参加者として予選敗退を繰り返してきたのだけど、今年はちょっと違う関わり方をしてみるか、と思い 「参考実装の移植」に立候補してみました。

isucon.net

Node.js担当として採用していただき、ちょっと不安もあったので id:hokaccha 氏にレビュアーとしてついてもらって、言語移植チームとして加わりました。

Node.js 実装

github.com

中身としては素朴な express のアプリケーションで、TypeScriptで実装しました。 mysql clientには mysql2/promise を使うことで async/await で簡潔に書けそうだったのでそれでやってみました。

www.npmjs.com

行数としては元実装の Go 1262行に対し 1243行とほぼ同等、それなりに丁寧に型定義などを書きつつ 元実装を忠実に再現したものになったかと思います。

昨年のisucon10のリポジトリなどをとても参考にしつつ、今回の自分の実装が未来の実装者の参考にもなるように、と気をつけて書いたつもりです。

1組だけですがNode.js実装で本選にも残ったチームも居たようで良かったです。

isucon.net

開発環境

移植作業をする時点で既にGoの参考実装やbenchmarker、そして開発環境がある程度できあがっていて、作業を進めやすくてとても助かりました。このへんは作問チームのレベルの高さをすごく感じました。

複数のコンポーネントからなるそれなりに複雑なアプリケーションながらも、docker-compose.yml や Makefile がしっかり用意されていて、そこに自分の実装する言語用の環境を用意できればコマンド一発でWebアプリが起動できるし benchmarkerを走らせたりもできる。 benchmarkerは -no-load オプションで負荷走行前のアプリケーション互換性チェックシナリオを走らせるので、まずはそこがすべて通るようになれば大体の移植は出来ていると判断できる、という具合。

勿論CIでもテストが走るようになっていて、それらが通っていてさらに作問チームメンバーやアドバイザリーからレビューを受けた上でapproveされたらmergeできる、といったルールもしっかり整備されていて、とても体験の良い開発環境でした。

実装についての疑問や相談もSlack上でいつでも作問メンバーが素早く回答してくれて、とにかく素晴らしいチームだな、と思いました。 このチームの人たちと同じプロジェクトに取り組めたというだけで 今回応募してみて本当に良かったと思っています。

Contributions

Node.jsの実装も一通り出来たところで 他の言語実装も出揃ってきていたので試しに動かしてみて、細かいエラーが出ていたところを修正したりもしました。

「オレでなきゃ見逃しちゃうね」的な細かいものとか

あとはPerl書ける人が少なかったようなので「いちおう私も多少Perlの経験あるのでレビューくらいなら」とちょっと見てみたり

benchmarkerの細かい挙動について指摘したりライブラリのバグをついたりもしました

あと実際に本番前々日くらいにサーバ上で各言語の初期実装に対してbenchmarkしてみたところ、何故かNode.jsがGo実装よりも2倍以上高いスコアが出た(!)という謎現象が見つかり、調査の結果それはNodeがたまたまbenchmarkerの秘孔をつく動作をしていたことが分かり修正されたのですが、そういうのを見つけることが出来るという点でも複数言語での移植は意義があるのだなぁと思いました。

tagomorisバグ

余談。

Perlの実装を手伝っているときに、benchmarkerを走らせていると何故か予期せぬところで 401 を返していてチェックが失敗するという現象が起きていた。

401 ってことはcookie-sessionまわりだよな〜 でも特に変なところとか無いはずだしな〜 と id:kfly8 氏と一緒に見ていて、ところで Plack::Middleware::Session::Cookiesecret"tagomoris" とかじゃなくて 今回は "isucondition" (もしくは SESSION_KEY 環境変数) で統一するって話じゃないですかと指摘してそこだけ直してもらった

https://github.com/isucon/isucon11-qualify/pull/1099/commits/57f165c73f92bac69a449912d8cb8118b05bf38c

 builder {
     enable 'ReverseProxy';
     enable 'Session::Cookie',
-        session_key => $ENV{SESSION_KEY} // 'isucondition_perl',
+        session_key => 'isucondition_perl',
         expires     => 3600,
-        secret      => 'tagomoris';
+        secret      => $ENV{SESSION_KEY} || 'isucondition',
     enable 'Static',
         path => qr!^/assets/!,
         root => $root_dir . '/../public/';

この変更を入れたら 先述の 401 のエラーが出なくなり…。

「えっ cookie の secret key を "tagomoris" から違う文字列に変えただけでバグが直ったの!?」と混乱した、という出来事。

種明かしをするとこの変更は secret 指定の末尾が ; だったのを , に変えてしまっているというミスが含まれていて。

-MO=Deparse してみると 変更前は

&builder(sub {
    enable('ReverseProxy');
    enable('Session::Cookie', 'session_key', 'isucondition_perl', 'expires', 3600, 'secret', 'tagomoris');
    enable('Static', 'path', qr"^/assets/", 'root', $root_dir . '/../public/');
    $app;
}

というものだったのが変更後は

&builder(sub {
    enable('ReverseProxy');
    enable('Session::Cookie', 'session_key', 'isucondition_perl', 'expires', 3600, 'secret', 'isucondition', enable('Static', 'path', qr"^/assets/", 'root', $root_dir . '/../public/'));
    $app;
}

と解釈されてしまう。続く Static middleware を enable した結果が Session::Cookie の引数に並べられる形になり、内部ではその値は無視されるのだろうけど、何が起こっているかというと「enable が呼ばれる順番が変わる」。

ここでまた別の事象として、benchmarkerが「/assets/ 以下に含まれるファイルをGETした際に Set-Cookie ヘッダが含まれているとそれによってbenchmark scenarioを回しているagentが別人になってしまい その後のリクエストで 401 になってしまう」というバグがあった。

つまり このpsgiでは 先に Session::Cookieenable していると その後に呼ばれる Static のファイルたちにも Set-Cookie がつくことになり、そのbenchmarkerのバグをつくことになってしまっていた。 "tagomoris" を修正した際に間違えて ;, に変えてしまたことにより意図せず StaticSession::Cookie の順に enable する形に変わっていて そのバグをつかないように変更されたので エラーに遭遇しなくなった、というオチでした。

奇妙な挙動にとても戸惑ったけど 複数の不具合が密接に絡んで起きた奇跡のような現象でした。という話。

いやー数年ぶりに書いたけどやっぱりPerlむずかしい…。 一晩で解決できたのも奇跡だったのかもしれない

本選へ

…というわけで ともかく予選の移植は無事(?)に完遂し、本番でも特に言語実装依存の不具合なく終えられたようで何よりでした。

また本選に向けて頑張っていこうと思います。

よろしくお願いします。

3Dモデルを動かせるアプリをPWAで作る

子(1歳5ヶ月)が最近すごく消防車とか救急車に興味を持っているようで、またiPadを人差し指で操作することを覚えてきているので、じゃあ好きな車を表示してグリグリ動かせるのを作ってあげよう、と思った。

Web系エンジニアとしてはやはりブラウザで動くようなものが作りやすいな、と思い Web技術を軸に実装してみた。

モデリングデータ

まずは3Dモデルを探してみた。海外の救急車のものなどはよくヒットしたが、国内のものでそれっぽく良いものはなかなか見つからなかった。 最終的にこれを購入。

booth.pm

消防車も欲しいけどここには無さそう… どこかに良いの無いかな……

Three.jsで表示

データが手に入れば、あとは Three.js のような優れたライブラリを使えば簡単に読み込んで表示して動かしたりできる。マトモに使ったことは無かったけど、exampleなど見ながらちょっと書いてみたらスッと動くものが出来上がった。

import {
  Scene,
  PerspectiveCamera,
  WebGLRenderer,
  AmbientLight,
  DirectionalLight,
} from "three";
import { OBJLoader } from "three/examples/jsm/loaders/OBJLoader";
import { MTLLoader } from "three/examples/jsm/loaders/MTLLoader";
import { OrbitControls } from "three/examples/jsm/controls/OrbitControls";

document.addEventListener("DOMContentLoaded", async () => {
  const renderer = new WebGLRenderer();
  renderer.setSize(window.innerWidth, window.innerHeight);
  document.body.appendChild(renderer.domElement);

  const camera = new PerspectiveCamera();
  camera.fov = 40;
  camera.aspect = window.innerWidth / window.innerHeight;
  camera.updateProjectionMatrix();
  camera.position.y = 3;
  camera.position.z = 8;

  const scene = new Scene();

  const aLight = new AmbientLight(0xffffff, 0.75);
  scene.add(aLight);
  const dLight = new DirectionalLight(0xffffff, 0.25);
  scene.add(dLight);

  const mtlLoader = new MTLLoader();
  const mtl = await mtlLoader.loadAsync("/data/ambulance_ob.mtl");
  const objLoader = new OBJLoader();
  objLoader.setMaterials(mtl);
  const ambulance = await objLoader.loadAsync("/data/ambulance_ob.obj");
  ambulance.position.y = -1;
  scene.add(ambulance);

  const controls = new OrbitControls(camera, renderer.domElement);
  const animate = function () {
    controls.update();
    renderer.render(scene, camera);
    requestAnimationFrame(animate);
  };
  animate();
});

PWA化

手元のブラウザでは動いても、これを子が遊ぶためのiPadで表示するにはどこかweb上にdeployしなければならない。しかし使用条件として「再配布禁止」というのがあり、web上にdeployするとそこから幾らでもダウンロードできるようになってしまう。

ので、作ったものをPWA (Progressive web apps) として配布できるように準備した。

web.dev

これも今までまともにやったことなくて苦労したが… manifest ファイルやアイコン画像やService Workerを準備し、レスポンスをキャッシュしてオフラインでも動作するように設定。 結局 iOS のアプリアイコンにするには manifesticons 指定ではダメで apple-touch-icon で指定する必要がある、とかハマりどころは多かった。

Install

PWAとして動くよう準備できたら、ngrok を使って一時的に手元のアプリをtunnelさせる。そこで発行されたURLにiPadから繋ぎ、「ホーム画面に追加」でアプリとしてインストール。 一度だけ開いてService Workerでレスポンスをキャッシュしてしまえば、あとはオフラインでも動作するアプリになる。 終わったらngrokを切断。

これで出来上がり。いつでもiPadから起動して触って動かすことができるようになる。

結果

OrbitControls での操作はまだ慣れないのか、結局あまり触ってくれず。泣

もう少し背景とか道路とか動きがあった方が興味持ってくれるかもしれないのでアップデートは必要そう

StyleGAN2で属性を指定して顔画像を生成する

f:id:sugyan:20210402004016p:plain

memo.sugyan.com

の記事の続き(?)。 ある程度の学習データを収集して学習させたモデルが出来たので、それを使って実際に色々やってみる。

StyleGAN2-ADA

前回の記事でも書いたけど、厳選した16,000枚の画像を使って StyleGAN2-ADA を使って生成モデルを学習させてみた。

github.com

これは StyleGAN2 から進化したもので、より少ない枚数からでも安定して学習が成功するようになっていて、さらにparameter数など調整されて学習や推論もより早くなっている、とのこと。

それまでのStyleGANシリーズはTensorFlowで実装されていたが、最近はPyTorchに移行しつつある?のか、今はPyTorch版が積極的に開発進んでいるようだ。 そういう時代の流れなのか…。

github.com

今回は最初に触ったのがTensorFlow版だったのでまだPyTorch版は使っておらず、本記事はすべてTensorFlowだけを使っている。

学習

256x256程度のサイズのものであれば、Google Colabでも適度にsnapshotを残しながら学習しておき 切断されたら休ませてまた続きから学習始める、を繰り返せば数日〜数週間でそれなりに学習できる感じだった。 512x512やそれ以上のサイズだとちょっと厳しいかもしれない。

mapping出力と生成画像

StyleGAN2学習済みモデルを使ったmorphing、latent spaceの探求 - すぎゃーんメモ の記事で書いたが、StyleGAN の generator は、"mapping network" と "synthesis network" の2つの network によって作られている。 実際に画像を生成するのは synthesis network の方で、前段の mapping network は乱数入力を synthesis network への入力として適したものに変換するような役割になっている。

どちらの入力も潜在空間(latent space)としてみなせるが、 synthesis network への入力(= mapping network の出力)の方が次元数も多く、より表現力を持つものになっている。 この dlatents (disentangled latents) と呼ばれるものを線形に変化させることでスムーズなmorphingを表現できることを前述の記事で確かめた。

ということはこの dlatents の中に生成画像の属性を決定させるような要素があり、例えば顔画像生成のモデルの場合は顔の表情や向きなどを表すベクトルなどが存在しているかもしれない。 というのが今回のテーマ。

生成画像の属性推定結果から潜在空間の偏りを抽出

今回試したのは、以下のような手法。

  • 生成モデルを使ってランダムに数千〜数万件の顔画像を生成
    • このとき、生成結果とともに dlatents の値もペアで保存しておく
  • 生成結果の画像すべてに対し、顔画像の属性を推定する
  • 推定結果の上位(または下位)数%を抽出し、それらを生成した dlatents たちの平均をとる

例えば表情に関する属性の場合、各生成結果の画像の「笑顔度」のようなものを機械的に推定し(もちろん手動で判別しても良いが、めちゃめちゃ高コストなので機械にやってもらいたい)、それが高scoreになっているものだけを集めて それらを生成した dlatents たちの平均値を計算する。 その値は、あらゆる顔画像を生成する dlatents の平均値と比較すると、笑顔を作る成分が強いものになっている、はず。

今回はまず学習済み生成モデルを使って適当な乱数から 20,000件の顔画像を生成し、それらに対して各属性を推定し、その結果で上位(または下位)0.5% の 100件だけを抽出するようにしてみた。 ものによってはもっとサンプル数が多く必要だったり、もっと少なくても問題ない場合もあるかもしれない。

表情推定

顔画像から表情を推定するモデルは幾つかあったが、今回は pypaz を利用した。

github.com

表情推定だけでなく、keypoint estimationやobject detectionなど様々な視覚機能を盛り込んでいて、それらを抽象化されたAPIで使えるようにしている便利ライブラリのようだ。

ここでは表情推定の部分だけを使用した。

import pathlib
from typing import Dict, List

import dlib
import numpy as np
import paz.processors as pr
from paz.abstract import Box2D
from paz.backend.image import load_image
from paz.pipelines import MiniXceptionFER


class EmotionDetector(pr.Processor):  # type: ignore
    def __init__(self) -> None:
        super(EmotionDetector, self).__init__()
        self.detector = dlib.get_frontal_face_detector()
        self.crop = pr.CropBoxes2D()
        self.classify = MiniXceptionFER()

    def call(self, image: np.ndarray) -> List[np.ndarray]:
        detections, scores, _ = self.detector.run(image, 1)
        boxes2D = []
        for detection, score in zip(detections, scores):
            boxes2D.append(
                Box2D(
                    [
                        detection.left(),
                        detection.top(),
                        detection.right(),
                        detection.bottom(),
                    ],
                    score,
                )
            )
        results = []
        for cropped_image in self.crop(image, boxes2D):
            results.append(self.classify(cropped_image)["scores"])
        return results


def predict(target_dir: pathlib.Path) -> Dict[str, np.ndarray]:
    results = {}
    detect = EmotionDetector()
    for i, img_file in enumerate(map(str, target_dir.glob("*.png"))):
        image = load_image(img_file)
        predictions = detect(image)
        if len(predictions) != 1:
            continue

        print(f"{i:05d} {img_file}", predictions[0][0].tolist())
        results[img_file] = predictions[0][0]

    return results

paz.processors として定義されたDetectorが、dlibで顔領域を検出した上でその領域に対し paz.pipelines.MiniXceptionFER によって表情を推定した結果を返してくれる。 MiniXceptionFER から返ってくるのは ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral'] の7 classesでの分類結果。

この結果で happy1.0に近いものが得られたなら、それは確信度高く笑顔である、ということなので、その顔を生成した dlatents を集めて平均値を算出し、全体平均からの偏りをvectorとして抽出した。

これをmapping出力に加えていくことで、ランダムな生成顔画像も笑顔に変えていくことができた。

f:id:sugyan:20210405225344g:plain

顔の向きや髪型も多少は影響を受けているが、概ね顔の特徴はそのままで主に口と目?あたりだけが変化している。 最初から笑顔だったものはより笑顔に、無表情だったものも口角が上がるくらいにはなっている。 口が開いて前歯が見えるようになったりするのも興味深い。

ちなみに他の表情に関しては、同様のことをやっても「怒っている顔」や「悲しい顔」のようなものは作れなかった。 まず、今の生成モデルの学習のために収集し厳選したデータは大部分が笑顔か無表情の顔画像であったため(悲しい顔ばかり載せているようなアイドルは居ない)、生成される画像も多くはそのどちらかであり、それ以外の表情の画像はほぼ生成されない。 ので、機械推定した結果もそれらの表情を強く検出するものはなく、笑顔のようにベクトルを抽出することは難しいようだ。

f:id:sugyan:20210405225113g:plain

無表情化の場合。元が笑顔のものが真顔になる程度の変化は一応ありそう。 真ん中上段の子は前髪がだいぶアレだが…。

f:id:sugyan:20210405225635g:plain

悲しみ。少し笑顔が消えて眉が困ってそうな感じになっているようには見える。 が、とても微妙…。

f:id:sugyan:20210405225510g:plain

怒り。何故か顔の向きばかり変化してしまっているが、結局表情はほとんど変化が無いようだ。

・結果まとめ

笑顔に関しては概ね上手く抽出できた。 その他の表情についてはほとんど良い結果にならなかったが、学習データに使う顔画像の表情がもっと豊富にあればそういった顔画像も生成できて抽出が可能になると思われる。

顔姿勢推定

顔の向きも推定して同様のことをしてみる。 機械学習モデルでも顔角度の推定できるものありそうだが、今回は dlib で検出したlandmarkの座標から計算する、というものをやってみた。

詳しくは以下の記事を参照。

learnopencv.com

ほぼこの記事の通りに実装して、入力画像から yaw, pitch, roll のEuler anglesを算出する。 参照している3Dモデルの座標や数によって精度も変わってきそうだが、とりあえずは上記記事で使われている6点だけのものでそれなりに正しく角度が導き出せるようだった。

import math
import pathlib
from typing import List

import cv2
import dlib
import numpy as np


class HeadposeDetector:
    def __init__(self) -> None:
        predictor_path = "shape_predictor_68_face_landmarks.dat"

        self.model_points = np.array(
            [
                (0.0, 0.0, 0.0),  # Nose tip
                (0.0, -330.0, -65.0),  # Chin
                (-225.0, 170.0, -135.0),  # Left eye left corner
                (225.0, 170.0, -135.0),  # Right eye right corne
                (-150.0, -150.0, -125.0),  # Left Mouth corner
                (150.0, -150.0, -125.0),  # Right mouth corner
            ]
        )
        self.detector = dlib.get_frontal_face_detector()
        self.predictor = dlib.shape_predictor(predictor_path)

    def __call__(self, img_file: pathlib.Path) -> List[float]:
        image = cv2.imread(str(img_file))
        size = image.shape

        # 2D image points
        points = []
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        dets = self.detector(rgb, 1)
        if len(dets) != 1:
            return [np.nan, np.nan, np.nan]

        d = dets[0]
        shape = self.predictor(rgb, d)
        for i in [30, 8, 36, 45, 48, 54]:
            points.append([shape.part(i).x, shape.part(i).y])
        image_points = np.array(points, dtype=np.float64)

        # Camera internals
        focal_length = size[1]
        center = (size[1] / 2, size[0] / 2)
        camera_matrix = np.array(
            [[focal_length, 0, center[0]], [0, focal_length, center[1]], [0, 0, 1]],
            dtype=np.float64,
        )

        # Calculate rotation vector and translation vector
        dist_coeffs = np.zeros((4, 1))  # Assuming no lens distortion
        success, rotation_vector, translation_vector = cv2.solvePnP(
            self.model_points,
            image_points,
            camera_matrix,
            dist_coeffs,
            flags=cv2.SOLVEPNP_ITERATIVE,
        )

        # Calculate euler angles
        rotation_mat, _ = cv2.Rodrigues(rotation_vector)
        _, _, _, _, _, _, euler_angles = cv2.decomposeProjectionMatrix(
            cv2.hconcat([rotation_mat, translation_vector])
        )

        return [
            math.degrees(math.asin(math.sin(math.radians(a))))
            for a in euler_angles.flatten()
        ]

表情と同様に、値の大きなものを生成した dlatents の平均、値の小さなものを生成した dlatents の平均、を使って顔向きを変化させるベクトルを求める。

f:id:sugyan:20210405225728g:plain

yaw (左右向き)。全体の大きさが変わってしまうので連続的に動くと不自然に感じてしまうが、ともかく顔の特徴はそのままに向きだけが変化しているのは観測できる。 顔の向きが変わっても視線は固定されている。

f:id:sugyan:20210405225817g:plain

pitch (上下向き)。yawほど顕著に差は出ないが、それなりには変化する。 真ん中上段の子の前髪はやはりハゲやすいようだ…。

f:id:sugyan:20210405225848g:plain

roll (傾き?)。これは首を傾けるように変化するはずのものだが、そもそも 学習データの前処理の段階 で正規化されているので傾いた画像が生成されるはずがない。 ので表情のときと同様に正しくベクトルを抽出できず、結果的に何故かyawと似たような動きになってしまっている。

・結果まとめ

yaw, pitch それぞれに関しては概ね上手く抽出できた。yawの方が学習データのバリエーションが多かったからか、より顕著に差が出るようだった。

髪領域推定 (顔解析)

次は髪色、髪の長さなどを変化させたい。 これらの属性を数値として得るには、まず髪の領域を抽出する必要がある。

ここでは、TensorFlowで動かせる学習済みモデルとして face_toolbox_keras を使用した。

github.com

これもpypazのように顔やlandmarkの検出など様々な機能があるが、その中の一つとして Face parsing がある。 元々は https://github.com/zllrunning/face-parsing.PyTorch で、それを移植したものらしい。 この Face parsing は入力の顔画像から「目」「鼻」「口」など約20のclassに各pixelを分類する。 髪の領域は値が 17 になっているので、その領域だけを抽出することで顔画像から髪の部分だけを取り出すことができる。

髪の領域だけ得ることができれば、あとは

  • その面積で「髪のボリューム」
  • 下端の位置・下端の幅で「髪の長さ」
  • 画素の平均値で「髪色の明るさ」

などを数値化できる。 顔姿勢と同様に、値の大きなものを生成した dlatents の平均、値の小さなものを生成した dlatents の平均、を使ってベクトルを求める。

f:id:sugyan:20210405225926g:plain

ボリューム。面積だけで計算しているのでちょっと頭の形が変になったりするかもしれない…。

f:id:sugyan:20210405230002g:plain

長さ。ボリュームよりは自然な感じで長さが変化しているように見える。 真ん中上段の子の前髪はやはり(ry

f:id:sugyan:20210405230038g:plain

明るさ。暗くするとみんな黒髪になるし、明るくすると茶髪や金髪など様々な明るい色になる。

・結果まとめ

髪の領域を推定することで、髪に関する属性を計算することができて長さや色などを変化させることができた。 もう少し頑張って上手く数値化できれば、前髪の具合や触覚の有無指定などもできるようになるかもしれない。

年齢 (上手くいかず)

これも他と同様、理論的には「顔画像から年齢を推定し、高い数値のものを生成した dlatents と 低い数値のものを生成した dlatents からベクトルを抽出」という感じで 童顔にしたり大人びた顔にしたりできると思っていたのだけど、そもそもの年齢の推定が全然正確にできなそうで断念した…。

など幾つかの学習済みモデルを使って年齢推定をかけてみたのだけど、どれも「東アジアの10〜20代女性」の学習データが乏しいのもあるのか(もしくは使い方が悪かった…?)、結果のブレが激しくて とても正しい年齢推定ができている感じがしなかった。

生成の学習に使ったデータからある程度は年齢ラベルつけたデータセットは作れるので、頑張れば年齢推定モデルを自前で学習させて より正確な推定ができるようになるかもしれないが… そこまでやる気にはならなかったので諦めた。

複合

とりあえずはここまでで

  • 表情 (笑顔◎、無表情△)
  • 顔角度 (左右◎、上下○)
  • 髪 (長さ◎、明るさ◎)

といった属性については変化させるためのベクトルが抽出できた。 ので、複数を足し合わせたりすることもできる。

f:id:sugyan:20210405230420g:plain

無表情 + 右上向き + 髪短く明るく

f:id:sugyan:20210405230514g:plain

笑顔 + 下向き + 髪長く暗く

というわけで、記事の冒頭に貼った画像は元は同じ顔からこうして変化させて作ったものでした。

Repository

N番目の素数を求める

SNSなどで話題になっていたので調べてみたら勉強になったのでメモ。

環境

手元のMacBook Pro 13-inchの開発機で実験した。

Pythonでの実装例

例1

最も単純に「2以上p未満のすべての数で割ってみて余りが0にならなかったら素数」とする、brute force 的なアプローチ。

import cProfile
import io
import pstats
import sys


def main(n: int) -> int:
    i = 0
    for p in range(2, 1000000):
        for q in range(2, p):
            if p % q == 0:
                break
        else:
            i += 1
        if i == n:
            return p
    raise ValueError


if __name__ == "__main__":
    n = int(sys.argv[1])
    with cProfile.Profile() as pr:
        for _ in range(10):
            result = main(n)
    print(result)

    s = io.StringIO()
    ps = pstats.Stats(pr, stream=s)
    ps.print_stats("main")
    print(s.getvalue())

Python3.9.1 で実行してみると

$ python3.9 1.py 1000
7919

...

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       10    2.237    0.224    2.237    0.224 /Users/sugyan/dev/sugyan/nth-prime-benchmark/python/1.py:7(main)

1,000番目のものを求めるのにも 224 msほどかかっていて遅い。10,000番目すら求めるのが大変。

例2

求めた素数をlistに入れていき、それらで割れるかどうかだけ確認していく。いわゆる「試し割り法」(trial division)というらしい。

試し割り法 - Wikipedia

def is_prime(num: int, primes: List[int]) -> bool:
    for p in primes:
        if num % p == 0:
            return False

    primes.append(num)
    return True


def main(n: int) -> int:
    i = 0
    primes: List[int] = []
    for p in range(2, 1000000):
        if is_prime(p, primes):
            i += 1
        if i == n:
            return p
    raise ValueError

これだと10倍ほど速くなって1,000番目も 27 ms程度で出てくる。

$ python3.9 2.py 1000
7919

...

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       10    0.015    0.002    0.275    0.027 /Users/sugyan/dev/sugyan/nth-prime-benchmark/python/2.py:17(main)

10,000番目だと 2,607 msくらい。

$ python3.9 2.py 10000
104729

...

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       10    0.205    0.021   26.067    2.607 /Users/sugyan/dev/sugyan/nth-prime-benchmark/python/2.py:17(main)

例3

「エラトステネスのふるい」と呼ばれるものを疑似したもの。limit までの数に含まれる素数を篩にかけて列挙していって、N個以上あればN番目を返す、無ければlimitを倍にしていく。

def faked_eratosthenes(limit: int) -> List[int]:
    nums = [i for i in range(2, limit + 1)]
    primes = []
    while True:
        p = min(nums)
        if p > math.sqrt(limit):
            break
        primes.append(p)
        i = 0
        while i < len(nums):
            if nums[i] % p == 0:
                nums.pop(i)
                continue
            i += 1
    return primes + nums


def main(n: int) -> int:
    limit = 1000
    while True:
        primes = faked_eratosthenes(limit)
        if len(primes) > n:
            return primes[n - 1]
        limit *= 2

例2のものより多少速くなるが、それほど大きくは変わらない。

$ python3.9 3.py 10000
104729

...

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       10    0.004    0.000   19.244    1.924 /Users/sugyan/dev/sugyan/nth-prime-benchmark/python/3.py:26(main)

エラトステネスの篩

下記記事が詳しいが、前述の例は「似非エラトステネスの篩」として言及されている。

zenn.dev

正しくは値をリストから削除するのではなくフラグで管理していく、とのこと。

def list_primes(limit: int) -> List[int]:
    primes = []
    is_prime = [True] * (limit + 1)
    is_prime[0] = False
    is_prime[1] = False

    for p in range(0, limit + 1):
        if not is_prime[p]:
            continue
        primes.append(p)
        for i in range(p * p, limit + 1, p):
            is_prime[i] = False

    return primes


def main(n: int) -> int:
    limit = 1000
    while True:
        primes = list_primes(limit)
        if len(primes) > n:
            return primes[n - 1]
        limit *= 2

こうすると(1回の list_primes の中では)リストのサイズ変更がなくなり領域の再確保やコピーもなくなり、倍数を篩によって除外するのも速くなる、ということ。

$ python3.9 eratosthenes.py 10000
104729

...

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       10    0.004    0.000    0.440    0.044 /Users/sugyan/dev/sugyan/nth-prime-benchmark/python/eratosthenes.py:24(main)

$ python3.9 eratosthenes.py 100000
1299709

...

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       10    0.092    0.009    7.978    0.798 /Users/sugyan/dev/sugyan/nth-prime-benchmark/python/eratosthenes.py:24(main)

10,000番目を求めるのも 40倍ほど速くなり、 100,000番目くらいでも 798 ms 程度で求められる。

Rustでの実装例

Pythonで満足したので次はRustで書いてみる。

試し割り法

pub fn trial_division(n: usize) -> u32 {
    let mut primes = Vec::with_capacity(n);
    primes.push(2u32);
    while primes.len() <= n {
        if let Some(prime) = (primes[primes.len() - 1] + 1..).find(|&m| {
            primes
                .iter()
                .take_while(|&e| e * e <= m)
                .all(|&e| m % e != 0)
        }) {
            primes.push(prime);
        }
    }
    primes[n - 1]
}

エラトステネスの篩

pub fn eratosthenes(n: usize) -> u32 {
    fn list_primes(limit: usize) -> Vec<u32> {
        let mut primes = Vec::new();
        let mut is_prime = vec![true; limit + 1];
        is_prime[0] = false;
        is_prime[1] = false;
        for p in 0..=limit {
            if !is_prime[p] {
                continue;
            }
            primes.push(p as u32);
            for i in (p * p..=limit).step_by(p) {
                is_prime[i] = false;
            }
        }
        primes
    }

    let mut limit = 1000;
    loop {
        let primes = list_primes(limit);
        if primes.len() > n {
            return primes[n - 1];
        }
        limit *= 2;
    }
}

アトキンの篩

全然知らなかったのだけど、エラトステネスの篩よりも速いアルゴリズムとして「アトキンの篩」というものがあるらしい。

Sieve of Atkin - Wikipedia

エラトステネスの篩の計算量 O(N \log \log N) に対し こちらは O(\frac{N}{\log \log N}) になる、とのこと。

原理は正直全然わからないけど、

  • 4 * x * x + y * y == n となる nmod 60{1, 13, 17, 29, 37, 41, 49, 53} に含まれる
  • 3 * x * x + y * y == n となる nmod 60{7, 19, 31, 43} に含まれる
  • x > y かつ 3 * x * x - y * y == n となる nmod 60{11, 23, 47, 59} に含まれる

の3つの式において xy の組み合わせの数が合計で 奇数個存在している 場合に、n素数の候補とすることが出来て、それらから平方数を除いたものが素数として列挙できるようだ。 (ちょっと解釈が間違ってるかもしれない)

とりあえず効率はあまり考えずに見様見真似で実装してみた。

pub fn atkin(n: usize) -> u32 {
    fn list_primes(limit: usize) -> Vec<u32> {
        let mut primes = Vec::new();
        if limit > 2 {
            primes.push(2);
        }
        if limit > 3 {
            primes.push(3);
        }
        let mut sieve = vec![false; limit];
        for x in (1..).take_while(|&x| x * x < limit) {
            for y in (1..).take_while(|&y| y * y < limit) {
                {
                    let n = (4 * x * x) + (y * y);
                    if n <= limit && (n % 12 == 1 || n % 12 == 5) {
                        sieve[n] ^= true;
                    }
                }
                {
                    let n = (3 * x * x) + (y * y);
                    if n <= limit && n % 12 == 7 {
                        sieve[n] ^= true;
                    }
                }
                if x > y {
                    let n = (3 * x * x) - (y * y);
                    if n <= limit && n % 12 == 11 {
                        sieve[n] ^= true;
                    }
                }
            }
        }
        for r in (5..).take_while(|&r| r * r < limit) {
            if sieve[r] {
                for i in (1..).map(|i| i * r * r).take_while(|&i| i < limit) {
                    sieve[i] = false;
                }
            }
        }
        primes.extend(
            sieve
                .iter()
                .enumerate()
                .filter_map(|(i, &b)| if b { Some(i as u32) } else { None }),
        );
        primes
    }

    let mut limit = 1000;
    loop {
        let primes = list_primes(limit);
        if primes.len() >= n {
            return primes[n - 1];
        }
        limit *= 2;
    }
}

おまけ: GMP

多倍長整数の算術ライブラリとしてGMPがあり、これのRust bindingがある。

https://crates.io/crates/rust-gmp

Mpz::nextprime() というのを呼ぶと「次の素数」を求められるらしいので、N回実行すればN番目の素数が求められそうだ。

gmp::mpz::Mpz - Rust

pub fn gmp(n: usize) -> u32 {
    let mut mpz = gmp::mpz::Mpz::new();
    for _ in 0..n {
        mpz = mpz.nextprime();
    }
    mpz.to_string().parse().unwrap()
}

Benchmark

というわけでこのあたりでbenchmarkを取ってみると

$ rustup run nightly cargo bench

...

test bench_100000_atkin           ... bench:  13,437,614 ns/iter (+/- 3,614,150)
test bench_100000_eratosthenes    ... bench:  30,768,639 ns/iter (+/- 24,144,191)
test bench_10000_atkin            ... bench:     858,282 ns/iter (+/- 724,131)
test bench_10000_eratosthenes     ... bench:   1,783,792 ns/iter (+/- 269,701)
test bench_10000_gmp              ... bench:  19,331,126 ns/iter (+/- 19,085,347)
test bench_10000_trial_division   ... bench:   2,958,690 ns/iter (+/- 6,219,626)

とりあえず gmp のものは問題外に遅かった。次に遅いのが trial_division 。そして eratosthenes だと 10,000番目は 1.78 ms程度、で 100,000番目だと 30.7 ms程度。これだけでもPython版より20倍くらい速い…

そして atkineratosthenes と比較しても2倍くらい速い!すごい!!

高速化のテクニック

エラトステネスの篩よりアトキンの篩の方が速いのでそれを使おう、で終わりにしても良いかもしれないけど、エラトステネスの篩を使う場合でも色々工夫すれば速くしていける。

例を幾つか

上限個数を見積もる

篩のアルゴリズムの性質上、「limitまでの数を用意して フラグ管理していくことで素数を列挙していく」ということしか出来ず、その結果が何個になるかは最終出力を見てみないと分からない。

「N番目の素数を求めたい」というときに limitを幾つに設定して篩にかけていけばN番目までの素数を導き出せるかが不明なので、前述の実装だと limit = 1000 から始めて素数列を列挙し、N個に満たなければ limit を倍々にしていってN番目が求められるまで繰り返していっている。

10,000番目を求めるためには limit128000になるまで8回、100,000番目を求めるためにはlimit2048000になるまで12回、list_primesを呼び出して毎回同じような篩の操作をしていることになる。

この繰り返しも無駄だし、limitが無駄に大きすぎても N番目以降の数まで篩にかける操作が発生して無駄になる。

これを避けるために、N個の素数を返すギリギリのlimitの値を設定してあげたい。 ある自然数までに含まれる素数の個数を求める研究はたくさんされているようで

素数定理 - Wikipedia

素数計数関数 - Wikipedia

などで既に求められているものを使うとかなり近いものが出せそう。とりあえずは足りなくならない程度に雑に多めに見積もってやってみる。

pub fn eratosthenes_pi(n: usize) -> u32 {
    let n_ = n as f64;
    let lg = n_.ln();
    let limit = std::cmp::max(100, (n_ * lg * 1.2) as usize);

    let mut primes = Vec::new();
    let mut is_prime = vec![true; limit + 1];
    is_prime[0] = false;
    is_prime[1] = false;
    for p in 0..=limit {
        if !is_prime[p] {
            continue;
        }
        primes.push(p as u32);
        if primes.len() == n {
            return primes[n - 1];
        }
        for i in (p * p..=limit).step_by(p) {
            is_prime[i] = false;
        }
    }
    unreachable!();
}

これで従来のeratosthenesなどと比較してみると

test bench_100000_atkin           ... bench:  13,750,106 ns/iter (+/- 5,263,586)
test bench_100000_eratosthenes    ... bench:  30,559,236 ns/iter (+/- 7,994,169)
test bench_100000_eratosthenes_pi ... bench:  10,841,103 ns/iter (+/- 7,613,241)
test bench_10000_atkin            ... bench:     984,568 ns/iter (+/- 331,771)
test bench_10000_eratosthenes     ... bench:   2,210,553 ns/iter (+/- 2,621,658)
test bench_10000_eratosthenes_pi  ... bench:     907,250 ns/iter (+/- 254,367)

これだけで格段に速くなり atkin よりも高速になった。(勿論atkinでも同様の最適化すればもっと速くなるだろうけど)

Wheel factorization

そもそもlimitまでの数を調べていくのに半分は偶数で明らかに素数じゃないし、3の倍数のものも33%ほど含まれていて無駄。5の倍数だってそれなりに存在している… ということで無駄なものを最初から省いて調べていくのが良さそう。

ということを考えると、{2, 3, 5}から始めると そこから続く 7, 11, 13, 17, 19, 23, 29, 31 だけが2の倍数でも3の倍数でも5の倍数でもなく、そこから先は30ずつの周期で同様の間隔で増加させていった数値だけ見ていけば良い。 増加周期は [4, 2, 4, 2, 4, 6, 2, 6] となり、37, 41, 43, 47, 49, 53, 59, 61, ... と増えていく。もちろんこれは2, 3, 5しか見ていないので 7の倍数である49などは残る。これらをエラトステネスの篩で消していけば良い。

こういう手法を Wheel factorization と呼ぶらしい。

Wheel factorization - Wikipedia

とりあえずは単純に for p in 0..=limit で1つずつ順番に見ていたloopの部分だけ変更。

pub fn eratosthenes_wf(n: usize) -> u32 {
    let n_ = n as f64;
    let lg = n_.ln();

    let limit = std::cmp::max(100, (n_ * lg * 1.2) as usize);

    let mut primes = vec![2, 3, 5];
    let mut is_prime = vec![true; limit + 1];
    is_prime[0] = false;
    is_prime[1] = false;

    let inc = [6, 4, 2, 4, 2, 4, 6, 2];
    let mut p = 1;
    for i in 0.. {
        p += inc[i & 7];
        if p >= limit {
            break;
        }
        if !is_prime[p] {
            continue;
        }
        primes.push(p as u32);
        if primes.len() >= n {
            return primes[n - 1];
        }
        for j in (p * p..=limit).step_by(p) {
            is_prime[j] = false;
        }
    }
    unreachable!();
}

Benchmark結果は…

test bench_100000_atkin           ... bench:  25,911,095 ns/iter (+/- 46,670,614)
test bench_100000_eratosthenes    ... bench:  33,172,283 ns/iter (+/- 24,657,454)
test bench_100000_eratosthenes_pi ... bench:  11,062,096 ns/iter (+/- 4,717,035)
test bench_100000_eratosthenes_wf ... bench:   5,971,694 ns/iter (+/- 3,127,972)
test bench_10000_atkin            ... bench:     936,174 ns/iter (+/- 178,170)
test bench_10000_eratosthenes     ... bench:   1,790,384 ns/iter (+/- 711,067)
test bench_10000_eratosthenes_pi  ... bench:     797,356 ns/iter (+/- 171,738)
test bench_10000_eratosthenes_wf  ... bench:     399,302 ns/iter (+/- 48,778)

これだけでeratosthenes_piよりさらに2倍以上速くなった!単純なeratosthenesと比較するともう圧倒的ですね。

オチ

とこのように、エラトステネスの篩を使う手法にも様々な最適化の手段があり、それらを盛り込んでいると思われる primal というcrateがあります。

https://crates.io/crates/primal

これを使って StreamingSieve::nth_prime() でN番目の素数を求められる。

pub fn primal(n: usize) -> u32 {
    primal::StreamingSieve::nth_prime(n) as u32
}

中では 入力値の範囲によってより近似された p(n) (n番目の素数を含む下限/上限値)を見積もっていたり、ビット演算を駆使して高速に篩をかけるようにしているようだ。

Benchmark結果は…

test bench_100000_atkin           ... bench:  25,911,095 ns/iter (+/- 46,670,614)
test bench_100000_eratosthenes    ... bench:  33,172,283 ns/iter (+/- 24,657,454)
test bench_100000_eratosthenes_pi ... bench:  11,062,096 ns/iter (+/- 4,717,035)
test bench_100000_eratosthenes_wf ... bench:   5,971,694 ns/iter (+/- 3,127,972)
test bench_100000_primal          ... bench:     134,676 ns/iter (+/- 11,903)
test bench_10000_atkin            ... bench:     936,174 ns/iter (+/- 178,170)
test bench_10000_eratosthenes     ... bench:   1,790,384 ns/iter (+/- 711,067)
test bench_10000_eratosthenes_pi  ... bench:     797,356 ns/iter (+/- 171,738)
test bench_10000_eratosthenes_wf  ... bench:     399,302 ns/iter (+/- 48,778)
test bench_10000_primal           ... bench:       9,083 ns/iter (+/- 3,915)

はい、さらに数十倍速くなっていて完全に大勝利です。N番目の素数を求めたければこれを使いましょう。

Repository

ツッコミあればご指摘いただけると助かります。

References