SIMDによる将棋Bitboard計算の高速化

ということでSIMDでの高速化のメモ。

SIMDとは

ja.wikipedia.org

の通り、複数のデータを1命令で同時に演算する、というもの。 将棋Bitboardは81マスのデータを表現するのに64bitでは足りないので、一般的に [u64; 2] など複数の要素を使うことになる。 このBitboard同士の論理演算を普通に実装しようとすると、

struct Bitboard([u64; 2]);

impl std::ops::BitAnd for Bitboard {
    type Output = Self;

    fn bitand(self, rhs: Self) -> Self::Output {
        Self([self.0[0] & rhs.0[0], self.0[1] & rhs.0[1]])
    }

}

のように各要素同士をそれぞれ演算した結果を新たに格納する、という操作が必要になる。 こういった場面でSIMDを利用することで、u64同士の2つの演算を1命令で同時に行うことができる。

2命令が1命令に減る程度ではあるが、合法手生成のための駒の利きの計算などで多くの論理演算を行っているので、計算の高速化のためには重要なものになる。

実装

defaultの実装にしても良いのかもしれないが、とりあえずは features として定義した。

[features]
simd = []

cargo build --features simd など指定したときのみ有効になり、これが指定されなかったり 対応する実装が存在していない場合は、SIMDを使わない shogi_core::Bitboard を使用する実装にfallbackするようにしている。

x86_64

まずはx86_64のものを実装した。128bitレジスタを利用するSSE命令をメインに、一部でAVX2を使用して256bitレジスタでの計算も行っている。

use std::arch::x86_64;

pub(crate) struct Bitboard(x86_64::__m128i);

SIMD Bitboard by sugyan · Pull Request #17 · sugyan/yasai · GitHub

基本演算

前述したような論理演算については x86_64::_mm_and_si128, x86_64::_mm_or_si128, x86_64::_mm_xor_si128 でそれぞれ128bitの計算を1命令で出来るのでそれを使う。

同値判定、ゼロ値判定には x86_64::_mm_test_all_zerosi32 の値を返してくれるのでこれで判定できる。

飛び利き計算

前回の記事 でLeading/Trailing Zerosを使う方法について書いたが、このLeading/Trailing Zerosを求めるためには一度SIMDレジスタから何らかの形で値を取り出す必要があり、SIMDとはあまり相性が良くないようだった。ので極力使わずに論理演算のみで解決する方向を模索した。

まずは一番簡単な後手香車の利き。decrementして xor とる、という方法。 1〜7筋か/8,9筋か、で対応する u64 を取り出して -1 するのが一番命令数少なくて済みそうではあるが、分岐をなくして論理演算だけで計算するようにしてみた。

 impl Bitboard {

    ...

   fn sliding_positive_consecutive(&self, mask: &Self) -> Self {
        unsafe {
            let and = x86_64::_mm_and_si128(self.0, mask.0);
            let all = x86_64::_mm_cmpeq_epi64(self.0, self.0);
            let add = x86_64::_mm_add_epi64(and, all);
            let xor = x86_64::_mm_xor_si128(and, add);
            Self(x86_64::_mm_and_si128(xor, mask.0))
        }
    }
 }

_mm_cmpeq_epi64 に同値を渡しすべて1になっているものを作り、それを足すことで両方を同時に1減じる。 ここではmaskがある筋の連続的なbitしか渡ってこない前提であり、最終的にこのmaskとの論理積をとるので関係ない部分まで変化してしまっても問題ない。

先手香車はQugiyの手法が最もシンプルで良さそうだったのでそのまま使わせていただいた。

impl Bitboard {

    ...

    fn sliding_negative_consecutive(&self, mask: &Self) -> Self {
        unsafe {
            let m = x86_64::_mm_and_si128(self.0, mask.0);
            let m = x86_64::_mm_or_si128(m, x86_64::_mm_srli_epi64::<1>(m));
            let m = x86_64::_mm_or_si128(m, x86_64::_mm_srli_epi64::<2>(m));
            let m = x86_64::_mm_or_si128(m, x86_64::_mm_srli_epi64::<4>(m));
            let m = x86_64::_mm_srli_epi64::<1>(m);
            Self(x86_64::_mm_andnot_si128(m, mask.0))
        }
    }
}

飛車・角行の飛び利きは、4方向をそれぞれ(Squareのindex的に)正方向のものと負方向のもので分けて2つずつ計算する。 本当は飛車の縦方向は香車の利きを再利用するのが良いのだとは思うけど、両駒の利きを同じinterfaceで計算できるようにするためにそれはあえてしなかった。

正方向は前述のdecrementしてxor取るものを、ちゃんと桁借りを考慮して128bit全体のdecrementになるようにした上で、AVX2の256bitレジスタを使って2方向分を同時に計算する、という形に。

impl Bitboard {

    ...

    fn sliding_positives(&self, masks: &[Self; 2]) -> Self {
        unsafe {
            let self256 = x86_64::_mm256_broadcastsi128_si256(self.0);
            let mask256 = x86_64::_mm256_set_m128i(masks[0].0, masks[1].0);
            let masked = x86_64::_mm256_and_si256(self256, mask256);
            // decrement masked 256
            let all = x86_64::_mm256_cmpeq_epi64(self256, self256);
            let add = x86_64::_mm256_add_epi64(masked, all);
            let cmp = x86_64::_mm256_cmpeq_epi64(add, all);
            let shl = x86_64::_mm256_slli_si256::<8>(x86_64::_mm256_xor_si256(cmp, all));
            let dec = x86_64::_mm256_sub_epi64(add, shl);
            // (masked ^ masked.decrement()) & mask
            let xor = x86_64::_mm256_xor_si256(masked, dec);
            let ret = x86_64::_mm256_and_si256(xor, mask256);
            Self(x86_64::_mm_or_si128(
                x86_64::_mm256_castsi256_si128(ret),
                x86_64::_mm256_extracti128_si256::<1>(ret),
            ))
        }
    }
}

負方向はおそらくQugiyのように swap_bytes して unpack して桁借りを計算して… というのが良いのだと思いつつも、前述した leading_zeros を使うのが気に入っていたのでこれを応用することにした。

SIMDレジスタ上での leading_zeros の正確な値は求めづらいが、「leftmost set bitが含まれる8bit」がどこかは見つけやすい。 _mm256_movemask_epi8 を使って、各8bitにおける最上位の値を集めた u32 を得ることができる。 ので、 _mm256_cmpeq_epi8_mm256_setzero_si256 と比較した結果に対してこのmovemaskをかけてやれば、「各8bitにおいて値が 0 ではなかったもの」をbit列で得ることが出来る。あとは16bitずつ上位と下位で切り出して反転させてから leading_zeros をとれば、どの8bitに leftmost set bit があったかを知ることができる。 ということでその8bit以降を全部塗り潰すmaskだけ用意しておく。

あとは先手香車のように >>(shift right) と or を3回だけ繰り返すことで、各8bit内で leftmost set bit 以降をすべて 1 にすることは出来るので、それと上述のmaskの論理和をとれば「leftmost set bit以降がすべて 1 になっているもの」を作り出すことはできる。

const MASKED_VALUES: [(i64, i64); 16] = [
    (0, 0),
    (0x0000_0000_0000_00ff, 0),
    (0x0000_0000_0000_ffff, 0),
    (0x0000_0000_00ff_ffff, 0),
    (0x0000_0000_ffff_ffff, 0),
    (0x0000_00ff_ffff_ffff, 0),
    (0x0000_ffff_ffff_ffff, 0),
    (0x00ff_ffff_ffff_ffff, 0),
    (-1, 0),
    (-1, 0x0000_0000_0000_00ff),
    (-1, 0x0000_0000_0000_ffff),
    (-1, 0x0000_0000_00ff_ffff),
    (-1, 0x0000_0000_ffff_ffff),
    (-1, 0x0000_00ff_ffff_ffff),
    (-1, 0x0000_ffff_ffff_ffff),
    (-1, 0x00ff_ffff_ffff_ffff),
];

impl Bitboard {

    ...

    fn sliding_negatives(&self, masks: &[Self; 2]) -> Self {
        unsafe {
            let self256 = x86_64::_mm256_broadcastsi128_si256(self.0);
            let mask256 = x86_64::_mm256_set_m128i(masks[0].0, masks[1].0);
            let masked = x86_64::_mm256_and_si256(self256, mask256);

            let eq = x86_64::_mm256_cmpeq_epi8(masked, x86_64::_mm256_setzero_si256());
            let mv = x86_64::_mm256_movemask_epi8(eq) as u32;
            let e0 = MASKED_VALUES[15 - (mv ^ 0xffff_ffff | 0x0001_0000).leading_zeros() as usize];
            let e1 = MASKED_VALUES[31 - (mv & 0xffff ^ 0xffff | 0x0001).leading_zeros() as usize];

            let m = masked;
            let m = x86_64::_mm256_or_si256(m, x86_64::_mm256_srli_epi16::<1>(m));
            let m = x86_64::_mm256_or_si256(m, x86_64::_mm256_srli_epi16::<2>(m));
            let m = x86_64::_mm256_or_si256(m, x86_64::_mm256_srli_epi16::<4>(m));
            let m = x86_64::_mm256_or_si256(
                x86_64::_mm256_srli_epi16::<1>(m),
                x86_64::_mm256_set_epi64x(e0.1, e0.0, e1.1, e1.0),
            );
            let ret = x86_64::_mm256_andnot_si256(m, mask256);
            Self(x86_64::_mm_or_si128(
                x86_64::_mm256_castsi256_si128(ret),
                x86_64::_mm256_extracti128_si256::<1>(ret),
            ))
        }
    }
}

これでとりあえずは同時に2方向ずつの飛び利きを求めることはできた。

AArch64

x86_64用のSIMDを頑張って実装したものの、じつは手元のメインの開発機としては最近M1 MBPを使っている。

のでこちらでもSIMDを使えるようにしたいと思い、NEONを使って同様に128bit計算を同時に行うように実装してみた。

use std::arch::aarch64;

pub(crate) struct Bitboard(aarch64::uint64x2_t);

SIMD Bitboard for AArch64 NEON by sugyan · Pull Request #19 · sugyan/yasai · GitHub

だいたいx86_64と同じように出来るかな…と思ったが まぁまぁ違うところがあって苦戦した。

同値判定、ゼロ値判定

x86_64::_mm_test_all_zerosi32 で値を返してくれたが、 aarch64 にはそれに相当するようなものは無い。 aarch64::vceqq_u64 なり aarch64::veorq_u64 なりで2値の比較をしても、いちいちその結果から中身を取り出して値を確認しないと bool を返すことはできない。

検索して調べてみたが、どうやら aarch64::vqmovn_u64 を使って3命令でゼロ値判定するのが最良のようだった。

impl Bitboard {

    ...

    pub fn is_empty(&self) -> bool {
        unsafe {
            let vqmovn = aarch64::vqmovn_u64(self.0);
            let result = aarch64::vreinterpret_u64_u32(vqmovn);
            aarch64::vget_lane_u64::<0>(result) == 0
        }
    }
}

VQMOVN (Vector Saturating Move and Narrow) copies each element of the operand vector to the corresponding element of the destination vector. The result element is half the width of the operand element, and values are saturated to the result width.

https://developer.arm.com/documentation/dui0489/h/neon-and-vfp-programming/vmovl--v-q-movn--vqmovun

とのことで、bit幅を半分にし 値がオーバーフローする場合は飽和させてコピーする、というものらしい。 以下の記事が詳しくて分かりやすかった。

qiita.com

ゼロ値との比較に使うぶんには飽和の影響は無く、正しく 0 か否かを保持したまま uint64x2_t から uint32x2_t に変換できる、ということになる。あとはこれを uint64x1_t として解釈した上でその要素として単一の u64 を取り出して 0 と比較すれば良い。

飛び利き計算

香車の利きはx86_64と同様にQugiyの手法で問題なかったが、飛車・角行の利きには同じことをしようにも x86_64::_mm256_movemask_epi8 と同等なものが見つからなかった。 各8bitから値をかき集めて単一の u32(もしくは u16)を得る、のようなことが簡単には出来ないらしい。

また、 x86_64::_mm_slli_si128 のようにlaneを跨いだbit shiftも出来ないようだった。まぁdecrementの桁借り計算で必要なのは64bitのlane丸ごと移動だから aarch64::vextq_u64 でどうにかはなるのだけど…

ここはまぁ仕方ないか、ということで前回の記事で書いた Leading/Trailing Zeros を使う手法で実装した。毎回 [u64; 2] の値を取得して分岐して…というのが必要になるのは微妙ではあるが そこまで遅いというわけでもないと思う。

use std::mem::MaybeUninit;

const MASKED_VALUES: [[u64; 2]; Square::NUM + 2] = {
    let mut values = [[0; 2]; Square::NUM + 2];
    let mut i = 0;
    while i < Square::NUM + 2 {
        let u = (1_u128 << i) - 1;
        values[i] = [u as u64, (u >> 64) as u64];
        i += 1;
    }
    values
};

impl Bitboard {

    ...

    #[inline(always)]
    fn values(self) -> [u64; 2] {
        unsafe {
            let m = MaybeUninit::<[u64; 2]>::uninit();
            aarch64::vst1q_u64(m.as_ptr() as *mut _, self.0);
            m.assume_init()
        }
    }
    fn sliding_positive(&self, mask: &Bitboard) -> Bitboard {
        let m = (*self & mask).values();
        let tz = if m[0] == 0 {
            (m[1] | 0x0002_0000).trailing_zeros() + 64
        } else {
            m[0].trailing_zeros()
        };
        Self(unsafe {
            aarch64::vandq_u64(
                mask.0,
                aarch64::vld1q_u64(MASKED_VALUES[tz as usize + 1].as_ptr()),
            )
        })
    }
    fn sliding_negative(&self, mask: &Bitboard) -> Bitboard {
        let m = (*self & mask).values();
        let lz = if m[1] == 0 {
            (m[0] | 1).leading_zeros() + 64
        } else {
            m[1].leading_zeros()
        };
        Self(unsafe {
            aarch64::vbicq_u64(
                mask.0,
                aarch64::vld1q_u64(MASKED_VALUES[127 - lz as usize].as_ptr()),
            )
        })
    }
}

Iterator

しかしどうも実装して実際にbenchmarkを取ってみると、すごく遅い。SIMD使わない方が速いくらい。一体何故…?と思って調べてみると、Bitboardの計算自体ではなく、そこから Square を取り出す処理が遅くなっているようだった。

合法手の生成にはそれぞれ利きの通る位置への移動を列挙したりするので、論理演算から得たBitboardからすべての対応する Square を列挙するような操作が多くなる。 これには Iterator を実装し その内部で fn pop(&mut self) -> Option<Square> を呼ぶのが初期の実装だった。

impl Bitboard {
    pub fn pop(&mut self) -> Option<Square> {
        let mut m = unsafe {
            let mut u = std::mem::MaybeUninit::<[u64; 2]>::uninit();
            aarch64::vst1q_u64(u.as_mut_ptr() as *mut _, self.0);
            u.assume_init()
        };
        if m[0] != 0 {
            unsafe {
                let sq = Some(Square::from_u8_unchecked(Self::pop_lsb(&mut m[0]) + 1));
                self.0 = aarch64::vsetq_lane_u64::<0>(m[0], self.0);
                sq
            }
        } else if m[1] != 0 {
            unsafe {
                let sq = Some(Square::from_u8_unchecked(Self::pop_lsb(&mut m[1]) + 64));
                self.0 = aarch64::vsetq_lane_u64::<1>(m[1], self.0);
                sq
            }
        } else {
            None
        }
    }
    fn pop_lsb(n: &mut u64) -> u8 {
        let ret = n.trailing_zeros() as u8;
        *n = *n & (*n - 1);
        ret
    }
}

impl Iterator for Bitboard {
    type Item = Square;

    fn next(&mut self) -> Option<Self::Item> {
        self.pop()
    }
}

内部の値の trailing_zeros() で rightmost set bit の位置を探して Square を作り、その後そのbitだけを 0 にclearする。これを値がすべてゼロになるまで繰り返す。

とはいえ trailing_zeros() を得るのはSIMDでは出来ないのでどうしても一度 u64 の値で読み込む必要はある。各要素で分岐は必要だし、それを pop_lsb で bit clear した値を書き戻す必要もある。

どうもこの読み込みと書き込みを繰り返す操作が、x86_64のときはそこまでオーバーヘッドを気にすることが無かったが AArch64では特に遅くなっていて影響が出ているようだ。 合法手生成でいうと持駒を打つ場所を列挙するときなど、多くの候補がある(すなわち多くの 1 が存在している)Bitboardに対する Square 列挙の繰り返しで影響が大きくなっていることが分かった。

では、SIMDレジスタからの読み込みを1回だけにして、その後はSIMDレジスタを使わずに処理するようにしたらどうか?ということで Iterator 用のstructを用意し、 IntoIterator を実装してそちらに pop の処理を任せる実装にしてみた。

pub(crate) struct SquareIterator([u64; 2]);

impl SquareIterator {
    #[inline(always)]
    fn pop_lsb(n: &mut u64) -> u8 {
        let pos = n.trailing_zeros() as u8;
        *n &= n.wrapping_sub(1);
        pos
    }
}

impl Iterator for SquareIterator {
     type Item = Square;

     // #[inline(always)]
     fn next(&mut self) -> Option<Self::Item> {
        if self.0[0] != 0 {
            return Some(unsafe { Square::from_u8_unchecked(Self::pop_lsb(&mut self.0[0]) + 1) });
        }
        if self.0[1] != 0 {
            return Some(unsafe { Square::from_u8_unchecked(Self::pop_lsb(&mut self.0[1]) + 64) });
        }
        None
    }
}

impl IntoIterator for Bitboard {
    type Item = Square;
    type IntoIter = SquareIterator;

    fn into_iter(self) -> Self::IntoIter {
        unsafe {
            let m = std::mem::MaybeUninit::<[u64; 2]>::uninit();
            aarch64::vst1q_u64(m.as_ptr() as *mut _, self.0);
            SquareIterator(m.assume_init())
        }
    }
}

これで、以前と同様に for sq in bb { ... } のようにfor loopですべての Square 列挙の繰り返しができるし、SIMDレジスタからの読み込みは1回で済むから速くなるはず… と思ったがbenchmarkとってみると効果が無い。。

ところが、上述のコードで // #[inline(always)]コメントアウトしている部分、fn next(&mut self) -> Option<Self::Item>inline(always) attribute をつけると劇的に改善された。どうして…!?

一応、手元の環境で様々な実装を比較しながらbenchmarkを回してみた。

https://gist.github.com/sugyan/03c1e73a253b4591dc9f29f9609ad93e

SIMDを使わずに [u64; 2]Iterator専用structを用意するのが、やはり素朴で最も速いようだ。それと同様の実装をしているはずの aarch64_iterator, x86_64_iterator や、他の pop を呼ぶ実装などはすべて遅い。 同じ実装だが next(&mut self)inline(always) をつけただけの差異しか無い aarch64_iterator_inline, x86_64_iterator_inline は最速に近い値になった。これだけでこんなに変わるものなのか…

https://godbolt.org/z/r3qj898rE

inline(always) の有無による違いも見てみたけどイマイチよく分からない… どういうことなんだろう。

ともかく、これをつけるかどうかで3〜4倍も速度差があることが分かった。x86_64では2.5倍くらいの差しか出ていなくて、この影響に気付けなかったようだ。 とはいえIterator専用structを用意する手法の方が良さそうなのでどちらもこれを採用することにした。

WebAssembly

ここまでやったらwasm32の 128-bit packed SIMD にも対応してやるぞ、ということで同様にやってみた。

use std::arch::wasm32;

pub(crate) struct Bitboard(wasm32::v128);

Implement wasm32 simd128 Bitboard by sugyan · Pull Request #23 · sugyan/yasai · GitHub

std::arch::wasm32APIが扱いやすく、それほど苦労することなく実装できた。 同値判定、ゼロ値判定には wasm32::u64x2_all_truewasm32::v128_any_true が使えるので便利。 SIMDレジスタ間の読み書きオーバーヘッドなどは特に考える必要なさそうで、とにかく命令数を少なめにしておけば良い感じ。素朴に読み込んで Leading/Trailing Zeros を使って飛び利きを計算するようにした。

各moduleで unit test を書いてあるがwasm32の場合の実行の仕方がよく分からなかった。とりあえず Wasmer を入れて、 wasm32-wasi でbuildして実行するようにしてみた。

export RUSTFLAGS="-C target-feature=+simd128" 
cargo clean
cargo build --tests --target=wasm32-wasi --features simd
wasmer run target/wasm32-wasi/debug/deps/yasai-*.wasm

また、これを使って wasm-packSIMDを有効にしたアプリケーションを作ろうとすると wasm-opt が古くてSIMDに対応していないらしく、自前で最新版をinstallして使う必要があるようだった。

ハマりどころはそれくらいだっただろうか。

Benchmark

実際どれくらい効果が出るか、と手元のPCで perft 5 を計測してみると。

x86_64

MacBook Pro (13-inch, 2017, Four Thunderbolt 3 Ports) 3.3 GHz Dual-Core Intel Core i5

$ cargo +nightly bench perft::bench_perft_5

...

test perft::bench_perft_5_from_default       ... bench: 357,959,818 ns/iter (+/- 5,399,589)

$ cargo +nightly bench --features simd perft::bench_perft_5

...

test perft::bench_perft_5_from_default       ... bench: 276,508,910 ns/iter (+/- 1,445,515)

NPS約1.3倍。

AArch64

MacBook Pro (14-inch, 2021) Apple M1 Pro

$ cargo +nightly bench perft::bench_perft_5

...

test perft::bench_perft_5_from_default       ... bench: 194,382,054 ns/iter (+/- 2,067,121)

$ cargo +nightly bench --features simd perft::bench_perft_5

...

test perft::bench_perft_5_from_default       ... bench: 163,265,674 ns/iter (+/- 4,929,186)

NPS約1.2倍。

WebAssembly

これはWebで確かめられるので実際にSafari以外のブラウザで確かめてみていただきたい。 (SafariSIMD対応していなかったの知らなかった…)

手元のChromeでは1.5倍ちかく速くなる。

感想

Benchmarkの数値を見てわかる通り、古いPCでSIMDつけてちょっと速くするより 良いPCに替えた方がSIMD無しでも圧倒的に速いので、まぁ結局はマシンパワーが強い方がつよい。

とはいえこういったBitboardの実装を差し替えるだけで効果が出る高速化というのはやる価値はちゃんとあると思うのでやってみて良かった。勉強になることもたくさんあった。

逆にこのBitboard関連ではこれ以上の高速化はそんなに望めないので、合法手生成をもっと速くしようと思ったらあとは生成ロジックの方をちゃんと見直す必要がある、ということになる。今後はここを頑張っていきたい。