AT Protocol(Bluesky)のためのRustライブラリを作った

前記事OCamlやってくぞ、と書いたけど結局Rustです。

Bluesky

Twitter代替の候補として噂される(?)分散型SNS Bluesky。 現状ではまだprivate betaということで招待コードが無いと使えないのだけど、先月運良くコードをいただくことができて使い始めてみている。

一通りのSNS的な動きはしているものの、まだまだ鋭意開発中ということで足りていない機能があったり 日々新機能が追加されたりと目紛しく動いている世界のようだ。

AT Protocolとエコシステムの現状 (〜2023/04)

AT Protocolについてはこちら。

atproto.com

日本語では以下の記事が詳しいでしょう。

ざっくりとした理解では、新しい分散型SNSのために「AT Protocol」が開発されていて、いま使っているBlueskyというSNSもこのAT Protocolに従って実装されているSmall-worldの一つである、という感じなのかな。

「AT Protocol」は様々なテクノロジーの総称という感じで、その中の汎用通信レイヤーとして「XRPC」というものがある。これを使ってサーバーと通信する、ということらしい。 BlueskyはWeb UIも提供されているが、DevToolsを覗いてみる限りでは確かにすべてのデータ取得や更新などこのXPRCのリクエストによって行われているようだ。

GitHubでは以下のリポジトリが公開されている。

主要なプロトコル実装としてTypeScriptで書かれた atproto があり、そのGo版ということで indigobluesky-social organizationによって管理されている。

その他の非公式の言語実装やClient, Applicationなどが以下のrepositoryにまとめられている。

言語としてTypeScriptがメインであり それを使ってすべての現存機能が動かせるので、Webフロントエンドで完結する、ということでWeb Clientが多く作られているようだ。

Rust版の実装

上記で紹介されているライブラリでRustのものは、(2023/04時点で)以下の adenosine のみ。

このrepository ownerはBlueskyの中の人ではあるが、公式としてGitHubのorganizationに入れられはしなかった、ようだ。

Lexiconとコード生成

前述したXRPCによるメッセージングは、すべて Lexicon というスキーマシステムによって型が定義されている。 JSON Schemaのように記述されるJSONによって、すべてのリクエストやレスポンス、その中のデータ型やフィールド名などが定義される。

ので、これがAPI通信の仕様のすべて、ということになる。

前述した atprotoindigoAPIライブラリは、このLexicon schemaを示すJSONファイルからコード生成によって作成されているようだ。

この仕組みで作っていない限り、どうしてもclient libraryは最新のAPIに追従できない、ということになる。AT Protocolはまだまだすごい頻度で変更が加えられていて、lexicon schemaも毎日、とはいかないまでも毎週くらいのペースでは何らかの変更がある。

上述の adenosine は、このコード生成によって作られたものではないため、実際既に最近追加されたAPIは使えないようだった。

その他にも reqwest の blocking client に依存している、とか 結局内部のxprc clientには serde_json::Value を渡していて何でも自由にできてしまう、など 微妙だなと思う点もあったので、自分でまったく別のRustライブラリを作ってみよう、ということになった。

(一応他にもそれっぽいlibraryやrepositoryを検索して先行事例を調査してみたりもしたが、ちゃんとそういったコード生成によって作られているものはなさそうだった。)

自作Rust版実装: ATrium

github.com

AT Protocolのための小さなライブラリが集まり青空を望む中庭、ということでATrium、という名前に。最初は atprs とか雑な名前で書き始めていたけど ChatGPT に「いいかんじのプロジェクト名ないですかね」と相談したところ、ATriumという名前を提案してくれたので採用した。最近 YAPC::Kyoto の会場でも聞いた単語だったし丁度いいな、と。

Lexicon schema

まずは必要になるのはLexicon schemaのJSONを読み込むためのライブラリ。 これ自身のJSON Schemaなどが存在するわけではなく、実際のvalidationは zodを使って書かれている ので、これを必死に翻訳する形で。

このJSON自体もかなり柔軟性があり、色んな入力が有り得るので大変…。いちおう type の tag を使って判別可能なようにはなっているので、それを使って enum に振り分けることでどうにか serde で正しく読み込めるようになった。 Option なfieldが多く、一度 deserialize してから元のJSONに同等に戻せるようにすると記述が大変だったので、serde_withskip_serializing_none を使った。

unionで新しい定義を作られるところが、単純に enum に変換すると一段深くなってしまうしtagがずれても困るし…ということで仕方なくいちいちコピーしてしまっているが、もっといい方法あるだろうか。…macroを使えば良いのか?

コード生成

ともかくLexicon schemaを正しく読み込むことができれば、あとはそこからRustコードに落とし込んでいける。 細かいところはまだ実装できていないが、基本的な object などはだいたい書けている、と思う。

Lexiconではdefsの名前は camelCase だが Rustコードでは 型名は PascalCase で field名やmodule名は snake_case にしたい、などの変換が必要だったので、 heck というライブラリを使用した。

正しくnamespaceをマッピングさせていれば ref はそのまま定義済みの型として参照できるが、 union はまぁ厄介で、それ用の enum を追加で定義して Lexicon解析のとき同様に $type tagで判別して振り分けるようにする必要があった。

とりあえず正しいコードとして生成されているかどうかは cargo build で確認できて 通ればまぁ動くだろう、という判断ができて便利だった。

API設計

型の定義だけならまだラクだが、Lexiconには query もしくは procedure のXRPCリクエストの定義が含まれる。 これは関数の実装を提供することになるが、そのためにはHTTP requestを実行するためのclientが必要になる、がAPIライブラリとしてはそういったものは含めたくない、と考えた。

実はRustでHTTP clientをマトモに使ったことが無かったが、 hyper だったり isahc だったり surf だったり、いろいろあるらしい。ライブラリでClientも含めることになると、どれを使うか判断する(もしくはすべてサポートするなどの)必要がでてきてしまう。

ので、それらは回避して HttpClient というTraitを定義して、それを実装するものを外部から注入する形にした。 各XRPCリクエストは「HttpClient をSupertraitとする、デフォルト実装を持つTrait」となり、外部で実装したClientがそれらの実装を宣言すれば各リクエストが使えるようになる、ということになる。

#[async_trait::async_trait]
pub trait HttpClient {
    async fn send(&self, req: http::Request<Vec<u8>>) -> Result<http::Response<Vec<u8>>, Box<dyn Error>>;
}

最初こういうのもライブラリとして存在しているのでは、と探して http-client というのがあったので使おうとしたが、同梱されている HyperClient を使ってみたら依存のtokioのバージョンが古すぎて動かず、またこのTraitで使われている http-types も良いけど http を使ったほうがより汎用的で良いですよ、とChatGPTが教えてくれたので自前で簡単に定義することにした。

ともかく、この形にすることで、ライブラリのユーザーは一手間はかかるけど自分の好きなHTTP clientを使って各種XRPCリクエストを送ることができるようになっている。はず。

とりあえずこうして作った APIライブラリをまず atrium-api という名前でpublishした。

docs.rs https://crates.io/crates/atrium-api

Lex

ちなみに Lexicon schema解析のための型定義だけのもので一つのライブラリとして切り出して作っていて、まだ publishはしていないけどやる予定。 上記の atrium-api の設計が気に入らない、という人はこれを使って自分でコード生成を作って別のAPIライブラリを作ってくれれば良いと思います。

CLI

作成したAPIの動作実験も兼ねて、 mattn/bsky のようなCLIも作ってみている。 これはこれで便利、なのかな? バイナリ配布だけしてみても良いかもしれない。

今後

まだまだ未完成で未実装のところも多いので埋めていく予定。 あとは本家のLexicon更新を検知して自動でコード生成してpushして publishまで…が自動化できても良さそうだけど そこまでは無理か? ちょっと調べてやってみたいところ。

あとは…折角Rustで動かせるようになるならTauriとか使ってGUIアプリまで作れると良いのかな… さすがにそんな余裕は無さそうだけども。

ともかくちゃんと使ってもらえるかもしれない(現時点で22 starsとかだけど)OSSっぽいものを初めて作った気がするので、まずは色々整備していきたいところ。

40歳から始める関数型言語、OCaml

というわけで、今年に入ってからのここ数ヶ月、OCamlを勉強し始めている。

動機

これまで「関数型言語をちゃんと触ったことがない」ということが若干コンプレックスになっていた。40歳になった今、唐突に始めてみる気になったので、やることにした。

Why OCaml

そもそも関数型にどんな言語があるのか、それぞれがどんな特徴であり、どういったところで使われているのか、という情報すらほとんど知らなかったので、「関数型言語やってみたいんですけど、どれがいいんすかね?」と某所で雑に相談したところ、OCamlを挙げてもらったので、素直にそれに従ってみることにした。

そんな適当で良いのか?とも思うけど、とにかくどれであっても触ってみないことには何も分からないし。実際、OCamlを選んで間違いは無かったと今は思う。

学習方法

Real World OCaml

とにかくOCamlの学習にはこれが良いらしい、という話をきいた。

dev.realworldocaml.org

ので、これを最初の数章は頑張って読んでみて、あとは実際にコードを書いてみつつ必要になったら参照して勉強していく、という感じで。 実際とても詳しく丁寧に書かれていて、とても良い本だと思う。

Github Copilot と ChatGPT

これらはもう現代のプログラマにとって欠かせないものになっていると思う。

最近のメジャーな言語と比較してOCamlに関する学習データがどの程度の量と質で、どの程度優れたコード出力能力であるのかは分からないが、少なくとも初心者から始める人間にとっては十分に助けになる。

特にChatGPTは、環境構築からライブラリの使い方などは勿論、「Rustでこう書いてたのはOCamlではどう書くの」とか「こういうのできないの」「こう書いてみたけど他にも書き方あるのかな」みたいな、独り言で呟くようなことでもChatで気軽に書いて質問できて体験が良かった。 めちゃくちゃ親しくて時間を奪うこと気にせずに何でも質問できる友人、が隣に居ればそういう人に頼れるかもしれないが、そういう友達が居ない人間にとってはこうした心理的なハードル低く幾らでも質問できる存在はとても有り難いと思う。

とはいえやはり正確性は保証されなくてChatGPTも平気で実在しないライブラリ関数をでっちあげて使用例を提示してきたりするし、油断せずにちゃんと公式ドキュメントなど探して確認しながら書く必要はある。 このへんは仕方ないと腹を括って、「俺が学習データを提供してやるぜ」くらいの気概でできるだけ良い(と自分が思う)コードを書いてとにかく公開していく。

オンラインジャッジ (競プロ)

簡単なコード片を書いて正しい入出力を得られたか確認できる、また別の人の書いた同じ目的のコードを気軽に参照できる、という点で競技プログラミングの過去問などは良い学習材料になると思う。とりあえず他人と競うことは考えず、オンラインジャッジシステムとして利用させてもらう。

個人的には LeetCode でやりたかったが、LeetCodeにはOCamlの選択肢が無い…。ので、 AtCoder を使うことにした。 AtCoder Problems の「Boot camp for Beginners」 というのがあることを @naoya_ito さんのTweet で知ったので、真似してこれらを解いてみることにした。

とはいえAtCoderOCamlのランタイムや使えるライブラリが少し古くて(2023/04時点)、手元で Core 最新を使って動くコード書けて意気揚々とSubmitしてみたら CE (Compilation Error) 出てガックリすることも。 ここは今後 言語のアップデート で更新されてより使いやすくなることを期待している。

簡単な問題を解くことで基礎的なデータ構造やアルゴリズムの扱いは練習できるが、難易度が上がると当然 数学的な思考力やより高度なアルゴリズムの知識が主眼になってくる。 そこは言語習得とはまた別の話になってくるのでそこまで頑張る必要はない(本当に競プロを頑張るならC++とかRustとか他の言語でやったほうが良いと思う)、ということでほどほどにして止めるつもり。

Advent of Code

より実践的に「プログラミングによる問題解決」をできるようにする練習として、 Advent of Code は非常に良い題材だと思う。 AtCoderのような競技プログラミング的な要素もあるが、それだけではなく

  • 入力データが必ずしも扱いやすいものばかりではない
    • → 様々な形式に対応したparse処理を書く練習になる
  • part1/part2 で同じ入力から異なる解を出す、など出力も様々な型・形式になり得る
    • → interfaceを考える練習になる
  • アルゴリズムだけで解決しないこともある
    • → 力技での膨大な探索や泥臭い実装などが必要になることも

といった点でまた別の実装力を鍛えることができると考えられる。

2022を完走した記事 にも書いたが、2022年の問題は特にバランス良く様々な題材のものがあって言語習得の練習に十分に適している、と感じた。

というわけで、OCamlAdvent of Code 2022 を解くチャレンジをしている。

github.com

まだ18日目までだが、どうにか最後まで頑張りたい…!

その次?

やはり「問題を解く」だけではなく、最終的には「何らかのアプリケーションを自作する」ことができるくらいにはなりたいので、それを考えていく。

自分の場合は「過去に他の言語で作ったことがあるもの」「自分が使いたいもの」を作るのが最もモチベーションが高くなるので、そういうもので考えると、、、今はコンピュータ将棋関連だろうか。 幸い(?) opam を検索してみた限りではビットボード実装や将棋関連のライブラリなどは存在していないようなので、自分で作ってみるのは面白そうだ。関数型言語での将棋プログラムなんて全然需要は無さそうだけどw とりあえずOCamlでの実装がC++やRustと比較してどの程度の速度になるのかは知りたいので perft できるくらいまではやってみたい。

あとは今の流行でいうと Bluesky で使われている AT Protocol の実装をしてみる、とかだろうか。これもわざわざOCamlでやろうとする人はあんまり居なそう…

所感

とりあえず数週間触ってみての感想。

|> は多用することが分かったので自作キーボード(Claw44)のキーマップにマクロとして追加した。

関数型という概念

結局、現状どれくらい理解できているだろうか…。まだまだ理解が浅いとは思う。

基本的に Base ライブラリを使用しているが、 List に対する様々な処理が再帰で表現できるというのを見て驚かされた。というか List というデータ構造がこんなにも使いやすくて色々できるんだなぁ、と少し親しみを持てるようになった。 とりあえず、 List などを使って |> でパイプを繋げてデータを変換していく、という操作は慣れてきた。

最初はどうしても「手続き型の言語だと普通こう書くので、それを関数型で表現すると…」と変換していく感じにはなりがちだったが、だんだん「この関数とこの関数を組み合わせてこういう関数を作って、そしてこの入力にこういう出力を返す関数を作る、」という感じで考えられるようになってきた気がする。

高階関数の部分適用とか使ってみるとめっちゃ便利だな〜と感心したり。

プログラミングHaskell という本を何年か前に買ったものの序盤ですぐに挫折してしまっていたのだけど、OCamlを通じて多少なりとも理解が深まってきたおかげか、今は再チャレンジして楽しく読めている。

OCamlの書き味

思っていた以上にツールチェインなどが整っていて、環境構築からエディタ設定などほとんど問題なくできて良かった。 VSCodeでは OCaml Platform 入れるだけだし、 ocamlformat を使って自動整形もできる。このあたりの「揃っていて欲しい」ものは大抵ちゃんと揃っている。

コンパイラのエラーメッセージは正直あまり親切な感じはなくて、型が合わないのはどこをどうすれば… みたいなのがなかなか詰まりやすかったりした。とはいえ「ちゃんとコンパイルが通ればだいたいちゃんと動く」という感覚がRustに近いものがある、と感じる。

あとは dune がRustでいう cargo と同じような感覚で使えるので、これも便利。

Rust, Python の経験

Rustを書いたことがあると、色んなところで「似ているな」と感じるところがあった。例えば option とかそれに対するパターンマッチとか。 Rustのiteratorをメソッドチェーンで繋いで処理するのと同じような感覚で |> で繋ぐ処理を書けるな、とか。 Rustがまったくの未経験だったらもっとOCamlに戸惑っていたかもしれない。

Pythonはそんなに似てないかな…? Haskellみたいに内包表記があるともっと違う感想になったかな。 ただOCamlの影響で自分が書くPythonコードがちょっと変わってきたような気がする。関数の使い方をより意識するようになったというか、なんというか。上手く言語化できないけど。

関数型言語をやるとプログラミングが上手くなる」という言説を聞いたことがあるが、こういうところで少しは効果がでるのかな…?気のせいかもしれない。

AIとの親和性

ところで Github Copilot にコードの補完をしてもらうときに、関数のシグネチャの情報はわりと重要だと思っていて。 「こういう名前の関数でこの型のこの名前の引数を受け取って、この型の値を返す」という情報が書いてあると当然補完も正確になりやすいだろう、と。

一方でOCaml型推論が強力なので、型はあるもののtype annotationを書く必要が殆どなく、またちょっとした処理は無名関数を書くことが多いし、そうでなくとも let rec loop acc x = ... とか let f x y = ... とか、簡潔な書き方になりがち。 そうすると、補完しようとするAIとしては情報が足りなすぎて難しいよなぁ、ということを思った。

明瞭な関数名や型注釈をいちいち書くのは面倒で人間としては省略できる方が有り難いが、それはそれでAIに助けてもらいながら書く場合には意図や情報を渡せるという点で良いのかもしれない。 今後新しく作られる言語の設計もそういう観点が考慮されるようになってくるんだろうか、とか。

まとめ

コードを書く仕事はAIに奪われていくかもしれないが、プログラミングは楽しいので好きなように書きたいだけ書いたら良い。

YAPC::Kyoto 2023 に参加した #yapcjapan

yapcjapan.org

おそらく YAPC::Tokyo 2019 以来?4年ぶりにオフラインのイベントに参加しました。地元開催ということで日帰りで行けてありがたい…! 最初だけちょろっと家族で参加。

子ゃーんの相手をしてくださった方々、ありがとうございました。

久しぶりすぎて誰と会ってもキョドってしまうくらいだったけど、とにかく色んな人たちとリアルで再会できて嬉しかった!!

そんなに長居はしてなかったのでトークは多くは訊けなかったけど、

id:ar_tama さんの『あの日ハッカーに憧れた自分が、「ハッカーの呪縛」から解き放たれるまで』がとても良かったです。 このトークを聴いた若手のエンジニアがこれを心に刻んで大きく成長し、また10年後のYAPCで良い発表してくれるといいな、と思いました。

(…と思ったら本当にベストトーク賞だったようで。おめでとうございます!)

とにかく楽しい1日でした。 運営の皆様、今回もありがとうございました!!

Advent of Code 2022 を完走した

毎年12月に開催されている Advent of Code に、2019年から参加している。

過去記事:

2022年のAdvent of Codeにも挑戦していて、年が明けてしまったが先日ようやく25日すべての問題に解答して 50 個のスターを集めることができた。

2022年の問題もどれもとても面白かった。

データ構造を扱う問題、ある程度のアルゴリズム知識が求められる問題、ちょっとした数学的な考え方が必要な問題、ひたすら実装が大変な問題、視覚化が面白い問題、などが揃っているのは毎年だが、特に2022年はバランス良くバラエティに富むものが出題されていて取り組み甲斐があったように思う。

工夫のしどころも多くあって、自分では上手く解いたつもりでも Reddit で他の人の解答を見ると目から鱗のものがたくさん発見できたり。

基本的にはヒント無しで自力で正解まで辿り着いたが、day19のpart2だけはどうにもならなくて他の人の解答を見たりしてようやく、という感じになってしまった。悔しい。

全部Rustで解いていて、あとでPythonでも解こうとしている。

github.com

とても楽しかったし勉強になったので、もうちょっと同じ問題に取り組んだ人達でわちゃわちゃと「ここが難しかった」「これはこんな手法があって」「こんな書き方もあって」など議論できると嬉しい…。 それはRedditで英語で頑張れば良いのだろうけど、ともかくもう少し日本語圏の開発者の間でも流行って仲間が増えてくれるといいな、と思っている。

という想いもあり、2020のときにもやったけど 日本語の解説本を書いてみている。

github.com

まだ半分もいってないけど… 少しでも多くのプログラマの人たちに届くといいな

2023パズル をRustで解いてみる

tkihiraさんの問題が面白そうだったので挑戦してみた。

既に解説記事が出ているので解答はこちらをどうぞ。

nmi.jp

結局自分は自力では解けなくて 他の人の解法や上記の解説記事を読んでようやくできた、のだけど… 自分なりに理解して改めてRustで実装してみた。

RPN(逆ポーランド記法)の backtracking

まずは型定義。

#[derive(Clone, Copy, PartialEq, Eq)]
enum Op {
    Add,
    Sub,
    Mul,
    Div,
}

#[derive(Clone)]
enum ExpressionElement {
    Operand(i32),
    Operator(Op),
}

この ExpressionElement を stack に詰んでいって、末尾が Operator だったらさらに Operand を 2つ pop() して、演算結果を push() していくことで計算が実現できる。

常に Operand の方が多い状態でないと Operator は挿入できないので、各々の出現回数をカウントしながら条件を満たしているときだけ操作するようにする。

様々な実装が試せるように Trait を定義。

trait Rpn {
    fn traverse(
        &mut self,
        expr: &mut Vec<ExpressionElement>,
        (i, j): (usize, usize),
        results: &mut Vec<Vec<ExpressionElement>>,
    ) {
        if i == j + 1 && self.get(i).is_none() {
            if self.evaluate(expr) {
                results.push(expr.clone());
            }
            return;
        }
        if let Some(&n) = self.get(i) {
            self.backtrack(expr, (i + 1, j), results, ExpressionElement::Operand(n));
        }
        if i > j + 1 {
            for op in &[Op::Add, Op::Sub, Op::Mul, Op::Div] {
                self.backtrack(expr, (i, j + 1), results, ExpressionElement::Operator(*op));
            }
        }
    }
    fn evaluate(&self, expr: &[ExpressionElement]) -> bool;
    fn get(&self, i: usize) -> Option<&i32>;
    fn backtrack(
        &mut self,
        expr: &mut Vec<ExpressionElement>,
        (i, j): (usize, usize),
        results: &mut Vec<Vec<ExpressionElement>>,
        e: ExpressionElement,
    ) {
        expr.push(e);
        self.traverse(expr, (i, j), results);
        expr.pop();
    }
}

self.get(i)i番目に使う数字を呼び出す。すべて使い切っていて Operatorも適切な回数使われていたら self.evaluate(expr) で計算結果を確認し、条件に合致していた場合のみ その exprresults に格納する。

exprpush / pop して再帰的に self.traverse() を呼ぶ backtracking の部分は別の実装で上書きできるようにあえて独立のmethodで定義。

何も難しいことを考えずにこれを満たす実装を作ると、以下のようになる。

struct Fraction(i32, i32);

struct DefaultSearcher {
    nums: Vec<i32>,
    target: i32,
}

impl DefaultSearcher {
    pub fn new(nums: Vec<i32>, target: i32) -> Self {
        Self { nums, target }
    }
}

impl Rpn for DefaultSearcher {
    fn evaluate(&self, expr: &[ExpressionElement]) -> bool {
        let mut stack = Vec::new();
        for e in expr {
            match e {
                ExpressionElement::Operand(n) => stack.push(Fraction(*n, 1)),
                ExpressionElement::Operator(op) => {
                    if let (Some(n0), Some(n1)) = (stack.pop(), stack.pop()) {
                        if let Some(n) = op.apply(&n1, &n0) {
                            stack.push(n);
                        } else {
                            return false;
                        }
                    }
                }
            }
        }
        stack
            .last()
            .map(|n| n.1 * self.target == n.0)
            .unwrap_or(false)
    }
    fn get(&self, i: usize) -> Option<&i32> {
        self.nums.get(i)
    }
}

impl Op {
    fn apply(&self, lhs: &Fraction, rhs: &Fraction) -> Option<Fraction> {
        match self {
            Self::Add => Some(Fraction(lhs.0 * rhs.1 + rhs.0 * lhs.1, lhs.1 * rhs.1)),
            Self::Sub => Some(Fraction(lhs.0 * rhs.1 - rhs.0 * lhs.1, lhs.1 * rhs.1)),
            Self::Mul => Some(Fraction(lhs.0 * rhs.0, lhs.1 * rhs.1)),
            Self::Div if rhs.0 != 0 => Some(Fraction(lhs.0 * rhs.1, lhs.1 * rhs.0)),
            _ => None,
        }
    }
}

解説記事では浮動小数点でも十分ということでやっていたけど、自分的には整数値だけで処理したいので Fraction(i32, i32) を定義して分数で計算するようにした。約分のことはとりあえず考えなくてよくて、最終的な値の計算結果の確認は 分子が分母の 2023 倍になっているか否か、で判定できる。 各Opの適用結果を Op::apply で計算。ゼロ除算の可能性があるので返り値は Option<Fraction> に。

self.evaluate() 内で stackを用意し、exprを順に見ながら Operand だったら Fraction に変換してpushし、 Operator だったら 2つを pop して演算結果をまた格納。理論上 正しい形式のRPN式であれば最終的にはstackには1要素だけ残るはずなので、その値を確認すれば良い。

これでとりあえず、(98の間の割り算を無視して)計算結果が 2023 になる式の組み合わせを列挙できる。

let nums = (1..=10).rev().collect();
let mut results = Vec::new();
DefaultSearcher::new(nums, 2023).traverse(&mut Vec::new(), (0, 0), &mut results);

手元の環境で実行すると 3370 件の結果が約2分弱で列挙された。

...

10-(9-(((8*(7*6))+(5-4))*(3*(2*1))))
10-(9-(((8*(7*6))+(5-4))*(3*(2/1))))
10-(9-(((8*(7*6))+(5-4))*(3+(2+1))))
10-(9-(8*(((7*(6+(5/4)))*(3+2))-1)))
10-(9-(8*((7*((6+(5/4))*(3+2)))-1)))
Completed 3370 results in 111.933449708s

探索の高速化

計算結果を保持

前述の実装だと、式が完成するたびにstack用意して計算結果を確認して… というのが非常に無駄になる。解説記事によると 1,274,544,128 通り生成されるようなので、ここはもっと計算量を削りたい。

途中までの計算結果は変わらないものなので、それを保持しながら探索した方が効率的になる。 内部に別のstackを持って、 Operator を受け取った時点で演算処理結果を格納しておくようにしておく。

前述の self.evaluate() でやっていた計算を self.backtrack() の中で逐次やっておく、という感じ。

pub struct FastSearcher {
    nums: Vec<i32>,
    target: i32,
    stack: Vec<Fraction>,
}

impl FastSearcher {
    pub fn new(nums: Vec<i32>, target: i32) -> Self {
        let stack = Vec::with_capacity(nums.len());
        Self {
            nums,
            target,
            stack,
        }
    }
}

impl Rpn for FastSearcher {
    fn evaluate(&self, _: &[ExpressionElement]) -> bool {
        self.stack[0].1 * self.target == self.stack[0].0
    }
    fn get(&self, i: usize) -> Option<&i32> {
        self.nums.get(i)
    }
    fn backtrack(
        &mut self,
        expr: &mut Vec<ExpressionElement>,
        (i, j): (usize, usize),
        results: &mut Vec<Vec<ExpressionElement>>,
        e: ExpressionElement,
    ) {
        match e {
            ExpressionElement::Operand(n) => {
                self.stack.push(Fraction(n, 1));
                expr.push(e);
                self.traverse(expr, (i, j), results);
                expr.pop();
                self.stack.pop();
            }
            ExpressionElement::Operator(op) => {
                if let (Some(n0), Some(n1)) = (self.stack.pop(), self.stack.pop()) {
                    if let Some(n) = op.apply(&n1, &n0) {
                        self.stack.push(n);
                        expr.push(e);
                        self.traverse(expr, (i, j), results);
                        expr.pop();
                        self.stack.pop();
                    }
                    self.stack.push(n1);
                    self.stack.push(n0);
                }
            }
        }
    }
}

これによって self.evaluate() に到達した時点で計算結果は出ているので、その値だけを確認すれば良い、ということになる。また、途中でゼロ除算が生じた場合にその先の探索をしなくなるので無駄を省く効果もありそうだ。

これで、前述の 3370 件の列挙が 約30秒に短縮された。約4倍の高速化。

探索結果のキャッシュ

また、この内部stackは同じ状態を経過することも多い。例えば

  • 10*(9-8)+...
  • 10/(9-8)+...

8 までの計算結果はどちらも変わらない。どちらか一方で全探索して1件も見つからなかったなら、もう片方でも1件も見つからないことが分かる。

ので、探索して条件に合致する式が見つかったか否かを返すようにし、それが false だった場合には同様の状態からの探索をスキップするように変更。

use std::collections::HashSet;

type Key = (Vec<Fraction>, (usize, usize));

pub struct FastSearcher {
    nums: Vec<i32>,
    target: i32,
    stack: Vec<Fraction>,
    seen: HashSet<Key>,
}

impl FastSearcher {
    pub fn new(nums: Vec<i32>, target: i32) -> Self {
        let stack = Vec::with_capacity(nums.len());
        Self {
            nums,
            target,
            stack,
            seen: HashSet::new(),
        }
    }
}

impl Rpn for FastSearcher {
    fn traverse(
        &mut self,
        expr: &mut Vec<ExpressionElement>,
        (i, j): (usize, usize),
        results: &mut Vec<Vec<ExpressionElement>>,
    ) -> bool {
        if i == j + 1 && self.get(i).is_none() {
            let found = self.evaluate(expr);
            if found {
                results.push(expr.clone());
            }
            return found;
        }
        let key = (self.stack.clone(), (i, j));
        if self.seen.contains(&key) {
            return false;
        }
        let mut ret = false;
        if let Some(&n) = self.get(i) {
            ret |= self.backtrack(expr, (i + 1, j), results, ExpressionElement::Operand(n));
        }
        if i > j + 1 {
            for op in &Op::ALL {
                ret |= self.backtrack(expr, (i, j + 1), results, ExpressionElement::Operator(*op));
            }
        }
        if !ret {
            self.seen.insert(key);
        }
        ret
    }
}

これで、約12秒くらいに短縮された。さらに2倍以上速くなった。

実際にはこの key に使っている self.stack は約分されていない計算結果なので、同じ状態を経由しているはずなのに 値が異なっているという扱いになってしまうケースがある。 gcd を使って既約分数を格納するようにしてみる。

impl FastSearcher {
    fn gcd(a: i32, b: i32) -> i32 {
        if b == 0 {
            a
        } else {
            Self::gcd(b, a % b)
        }
    }
}

impl Rpn for FastSearcher {
    ...

    fn backtrack(
        &mut self,
        expr: &mut Vec<ExpressionElement>,
        (i, j): (usize, usize),
        results: &mut Vec<Vec<ExpressionElement>>,
        e: ExpressionElement,
    ) -> bool {
        let mut ret = false;
        match e {
            ExpressionElement::Operand(n) => {...}
            ExpressionElement::Operator(op) => {
                if let (Some(n0), Some(n1)) = (self.stack.pop(), self.stack.pop()) {
                    if let Some(n) = op.apply(&n1, &n0) {
                        let gcd = Self::gcd(n.0, n.1);
                        self.stack.push(Fraction(n.0 / gcd, n.1 / gcd));

                        ...
                    }
                    self.stack.push(n1);
                    self.stack.push(n0);
                }
            }
        }
        ret
    }
}

これで、3370 件の列挙が 約7.5秒程度まで短縮された。 このキャッシュ手法は探索目標の値などによって効率が変わってくるので一概には言えないかもしれないが、少なくともこの例においては数倍の高速化が実現できた。

割り算を考慮した条件つき生成

ここまでは計算結果が 2023 になるものを全列挙することだけを考えていたが、件のパズル問題では「98の間には既に÷が入っています」という制約がある。

前述した 3370 件から条件に合致するものを抽出しても良いが、そもそも探索する時点で条件に合致しない式になるものを除外していくほうが効率が良い。

これは自分ではまったく思いつかなくて解説記事を読んで目から鱗だったのだけど、 8の数字が出た後、数字と演算子の数が一致していた場合」の演算子を割り算に固定 することによって実現できる。

解説記事を参考に、専用のSearcherを実装してみる。

struct Div8Searcher {
    nums: Vec<i32>,
    target: i32,
    stack: Vec<Fraction>,
    eight_depth: i32,
}

impl Div8Searcher {
    pub fn new(nums: Vec<i32>, target: i32) -> Self {
        let stack = Vec::with_capacity(nums.len());
        Self {
            nums,
            target,
            stack,
            eight_depth: -1,
        }
    }
}

impl Rpn for Div8Searcher {
    fn traverse(
        &mut self,
        expr: &mut Vec<ExpressionElement>,
        (i, j): (usize, usize),
        results: &mut Vec<Vec<ExpressionElement>>,
    ) -> bool {

        ...

        if i > j + 1 {
            for op in &Op::ALL {
                if self.eight_depth == 0 && op != &Op::Div {
                    continue;
                }
                ret |= self.backtrack(expr, (i, j + 1), results, ExpressionElement::Operator(*op));
            }
        }
        ret
    }
    fn backtrack(
        &mut self,
        expr: &mut Vec<ExpressionElement>,
        (i, j): (usize, usize),
        results: &mut Vec<Vec<ExpressionElement>>,
        e: ExpressionElement,
    ) -> bool {
        let mut ret = false;
        match e {
            ExpressionElement::Operand(n) => {
                let orig = self.eight_depth;
                if n == 8 {
                    self.eight_depth = 0;
                } else if self.eight_depth >= 0 {
                    self.eight_depth += 1;
                }
                self.stack.push(Fraction(n, 1));
                expr.push(e);
                ret |= self.traverse(expr, (i, j), results);
                expr.pop();
                self.stack.pop();
                self.eight_depth = orig;
            }
            ExpressionElement::Operator(op) => {
                if let (Some(n0), Some(n1)) = (self.stack.pop(), self.stack.pop()) {
                    if let Some(n) = op.apply(&n1, &n0) {
                        self.eight_depth -= 1;
                        self.stack.push(n);
                        expr.push(e);
                        ret |= self.traverse(expr, (i, j), results);
                        expr.pop();
                        self.stack.pop();
                        self.eight_depth += 1;
                    }
                    self.stack.push(n1);
                    self.stack.push(n0);
                }
            }
        }
        ret
    }
}

eight_depth の値を保持し、8 が出現したら 0 にセットして OperandOperator かによって値を増減。値が 0 のときには次に push する OpeatorOp::Div だけに制限される。

前述したような探索結果のキャッシュも使いたいところだが、新たな状態としてこの self.eight_depth もキーに含める必要があり、その割にはそれほど効果も無さそうなのでここでは省いた。

これによって 「98の間には必ず/が入る」式だけが探索されるようになる。

条件に合致する 530 件が、約4秒強で列挙されるようになった。全列挙してから抽出するよりも明らかに速いことが分かる。

...

(10*(9/((8-7)/((6*5)/(4/3)))))-(2/1)
(10*(9/((8-7)/(6*((5/4)*3)))))-(2*1)
(10*(9/((8-7)/(6*((5/4)*3)))))-(2/1)
(10*(9/((8-7)/(6*(5/(4/3))))))-(2*1)
(10*(9/((8-7)/(6*(5/(4/3))))))-(2/1)
Completed 530 results in 4.187013041s

中置記法、正規化

ここまでで探索して収集された resultsVec<Vec<ExpressionElement>> であり、これを中置記法に書き換える操作もまたstackを使って実現していくことになる。

impl Display for Op {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
        f.write_char(match self {
            Self::Add => '+',
            Self::Sub => '-',
            Self::Mul => '*',
            Self::Div => '/',
        })
    }
}

fn solve(searcher: &mut impl Rpn) {
    let mut results = Vec::new();
    searcher.traverse(&mut Vec::new(), (0, 0), &mut results);
    for result in results {
        let mut stack = Vec::new();
        for e in &result {
            match e {
                ExpressionElement::Operand(n) => stack.push((n.to_string(), None)),
                ExpressionElement::Operator(op) => {
                    if let (Some((mut s0, o0)), Some((mut s1, o1))) = (stack.pop(), stack.pop()) {
                        if o1.is_some() {
                            s1 = format!("({s1})");
                        }
                        if o0.is_some() {
                            s0 = format!("({s0})");
                        }
                        stack.push((format!("{s1}{op}{s0}"), Some(*op)));
                    }
                }
            }
        }
        println!("{}", stack[0].0);
    }
}

連結していくための String と、どの Operator で得られたものか(もしくは Operand か)、を保持しながら処理していく。簡単にすべて括弧をつけてしまうのなら上述のように書ける。

不要な括弧を除去して出力する場合は、左operandの場合と右operandの場合で処理が変わる(引き算・割り算は右operandに注意する必要がある)。matches! で判定してやってみる。

        ...

        ExpressionElement::Operator(op) => {
            if let (Some((mut s0, o0)), Some((mut s1, o1))) = (stack.pop(), stack.pop()) {
                if matches!((op, o1), (Op::Mul | Op::Div, Some(Op::Add | Op::Sub))) {
                    s1 = format!("({s1})");
                }
                if matches!(
                    (op, o0),
                    (Op::Sub | Op::Mul, Some(Op::Add | Op::Sub)) | (Op::Div, Some(_))
                ) {
                    s0 = format!("({s0})");
                }
                stack.push((format!("{s1}{op}{s0}"), Some(*op)));
            }
        }

これで、得られた結果を HashSet などで重複排除することで、解説記事と同様の 81 件を約4秒強で列挙することができた。

((10+(9/8+7)*6*5)*4-3)*2-1
(10*9/((8-7)/(6*5))/(4/3)-2)*1
(10*9/((8-7)/(6*5))/(4/3)-2)/1
(10*9/((8-7)/(6*5))/4*3-2)*1
(10*9/((8-7)/(6*5))/4*3-2)/1
(10*9/((8-7)/(6*5)*4)*3-2)*1
(10*9/((8-7)/(6*5)*4)*3-2)/1
(10*9/((8-7)/(6*5)*4/3)-2)*1
(10*9/((8-7)/(6*5)*4/3)-2)/1
(10*9/((8-7)/(6*5/(4/3)))-2)*1
(10*9/((8-7)/(6*5/(4/3)))-2)/1
(10*9/((8-7)/(6*5/4))*3-2)*1
(10*9/((8-7)/(6*5/4))*3-2)/1
(10*9/((8-7)/(6*5/4)/3)-2)*1
(10*9/((8-7)/(6*5/4)/3)-2)/1
(10*9/((8-7)/(6*5/4*3))-2)*1
(10*9/((8-7)/(6*5/4*3))-2)/1
(10*9/((8-7)/6)*5/(4/3)-2)*1
(10*9/((8-7)/6)*5/(4/3)-2)/1
(10*9/((8-7)/6)*5/4*3-2)*1
(10*9/((8-7)/6)*5/4*3-2)/1
(10*9/((8-7)/6/(5/(4/3)))-2)*1
(10*9/((8-7)/6/(5/(4/3)))-2)/1
(10*9/((8-7)/6/(5/4))*3-2)*1
(10*9/((8-7)/6/(5/4))*3-2)/1
(10*9/((8-7)/6/(5/4)/3)-2)*1
(10*9/((8-7)/6/(5/4)/3)-2)/1
(10*9/((8-7)/6/(5/4*3))-2)*1
(10*9/((8-7)/6/(5/4*3))-2)/1
(10*9/((8-7)/6/5)/(4/3)-2)*1
(10*9/((8-7)/6/5)/(4/3)-2)/1
(10*9/((8-7)/6/5)/4*3-2)*1
(10*9/((8-7)/6/5)/4*3-2)/1
(10*9/((8-7)/6/5*4)*3-2)*1
(10*9/((8-7)/6/5*4)*3-2)/1
(10*9/((8-7)/6/5*4/3)-2)*1
(10*9/((8-7)/6/5*4/3)-2)/1
(10*9/(8-7)*6*5/(4/3)-2)*1
(10*9/(8-7)*6*5/(4/3)-2)/1
(10*9/(8-7)*6*5/4*3-2)*1
(10*9/(8-7)*6*5/4*3-2)/1
10*9/((8-7)/(6*5))/(4/3)-2*1
10*9/((8-7)/(6*5))/(4/3)-2/1
10*9/((8-7)/(6*5))/4*3-2*1
10*9/((8-7)/(6*5))/4*3-2/1
10*9/((8-7)/(6*5)*4)*3-2*1
10*9/((8-7)/(6*5)*4)*3-2/1
10*9/((8-7)/(6*5)*4/3)-2*1
10*9/((8-7)/(6*5)*4/3)-2/1
10*9/((8-7)/(6*5/(4/3)))-2*1
10*9/((8-7)/(6*5/(4/3)))-2/1
10*9/((8-7)/(6*5/4))*3-2*1
10*9/((8-7)/(6*5/4))*3-2/1
10*9/((8-7)/(6*5/4)/3)-2*1
10*9/((8-7)/(6*5/4)/3)-2/1
10*9/((8-7)/(6*5/4*3))-2*1
10*9/((8-7)/(6*5/4*3))-2/1
10*9/((8-7)/6)*5/(4/3)-2*1
10*9/((8-7)/6)*5/(4/3)-2/1
10*9/((8-7)/6)*5/4*3-2*1
10*9/((8-7)/6)*5/4*3-2/1
10*9/((8-7)/6/(5/(4/3)))-2*1
10*9/((8-7)/6/(5/(4/3)))-2/1
10*9/((8-7)/6/(5/4))*3-2*1
10*9/((8-7)/6/(5/4))*3-2/1
10*9/((8-7)/6/(5/4)/3)-2*1
10*9/((8-7)/6/(5/4)/3)-2/1
10*9/((8-7)/6/(5/4*3))-2*1
10*9/((8-7)/6/(5/4*3))-2/1
10*9/((8-7)/6/5)/(4/3)-2*1
10*9/((8-7)/6/5)/(4/3)-2/1
10*9/((8-7)/6/5)/4*3-2*1
10*9/((8-7)/6/5)/4*3-2/1
10*9/((8-7)/6/5*4)*3-2*1
10*9/((8-7)/6/5*4)*3-2/1
10*9/((8-7)/6/5*4/3)-2*1
10*9/((8-7)/6/5*4/3)-2/1
10*9/(8-7)*6*5/(4/3)-2*1
10*9/(8-7)*6*5/(4/3)-2/1
10*9/(8-7)*6*5/4*3-2*1
10*9/(8-7)*6*5/4*3-2/1
Completed 81 results in 4.229359166s

CLIアプリケーション化

ここまで幾つかの実装ができたので、折角なので引数でパラメータを変えながら全列挙できる CLI アプリケーションを作成してみた。

$ ./puzzle2023 --help
Solvers for https://twitter.com/tkihira/status/1609313732034965506

Usage: puzzle2023 [OPTIONS]

Options:
  -t, --target <TARGET>      [default: 2023]
  -m, --max-num <MAX_NUM>    [default: 10]
  -r, --rev <REV>            [default: true] [possible values: true, false]
  -s, --searcher <SEARCHER>  [default: fast] [possible values: default, fast, div8]
  -n, --normalize
  -v, --verbose
  -h, --help                 Print help information
  -V, --version              Print version information

全パターン列挙

$ ./puzzle2023 -v

...

10-(9-(((8*(7*6))+(5-4))*(3*(2*1))))
10-(9-(((8*(7*6))+(5-4))*(3*(2/1))))
10-(9-(((8*(7*6))+(5-4))*(3+(2+1))))
10-(9-(8*(((7*(6+(5/4)))*(3+2))-1)))
10-(9-(8*((7*((6+(5/4))*(3+2)))-1)))
Completed 3370 results in 7.338064208s

9と8の間の割り算に限定

$ ./puzzle2023 -v --searcher div8

...

(10*(9/((8-7)/(6*((5/4)*3)))))-(2/1)
(10*(9/((8-7)/(6*(5/(4/3))))))-(2*1)
(10*(9/((8-7)/(6*(5/(4/3))))))-(2/1)
Completed 530 results in 4.219031208s

正規化

$ ./puzzle2023 -v --searcher div8 --normalize

...

10*9/(8-7)*6*5/(4/3)-2/1
10*9/(8-7)*6*5/4*3-2*1
10*9/(8-7)*6*5/4*3-2/1
Completed 81 results in 4.239637041s

計算結果の確認

$ ./puzzle2023 --searcher div8 --normalize | python3 -c 'for s in open(0): print(f"{eval(s):.1f}")' | sort | uniq -c
  81 2023.0

使うのを 1 から 8 の昇順にして 2024 を作るようにしてみる

$ ./puzzle2023 -m8 -rfalse -t2024 -nv
(((1+2)*3*4+5)*6+7)*8
(1*2*(3+4*5*6)+7)*8
(1+(2+3*4)*(5+6+7))*8
(1+(2+3-(4-5))*6*7)*8
(1+(2+3-4+5)*6*7)*8
(1+2*(3+4)*(5+6+7))*8
(1+2*(3+4+5+6)*7)*8
(1+2*3+4)*(5*6-7)*8
(1+2/(3/((4+5)*6))*7)*8
(1+2/(3/((4+5)*6)/7))*8
(1+2/(3/((4+5)*6*7)))*8
(1+2/(3/(4+5))*6*7)*8
(1+2/(3/(4+5)/(6*7)))*8
(1+2/(3/(4+5)/6)*7)*8
(1+2/(3/(4+5)/6/7))*8
(1+2/3*(4+5)*6*7)*8
(1-(2-3*4))*(5*6-7)*8
(1-2*(3*4-5*6)*7)*8
(1-2*3*(4-5)*6*7)*8
(1-2*3/((4-5)/(6*7)))*8
(1-2*3/((4-5)/6)*7)*8
(1-2*3/((4-5)/6/7))*8
(1-2*3/(4-5)*6*7)*8
(1-2+3*4)*(5*6-7)*8
1*(2*(3+4*5*6)+7)*8
Completed 25 results in 131.605458ms

「8の前を割り算に限定」にしているのも、コマンドラインオプションで違う値で指定できるようにしたかったけど、複雑になりすぎる感じがしたのでそれは断念…

改善?

tkihiraさんの最適化ではstackを push / pop するのではなく固定長配列を用意して見るべきindexだけをズラしていく、といった工夫をしていた。真似して実装してみたが本実装においてはそこはあまり効果なさそうだった。

あとはこうして再帰的に探索した結果を一度 results という Vec に格納してしまっているのが実は非効率なのでは、と思ったが… Iterator として使いたかったが そうするためには探索の途中の状態を保持しておいて next() が呼ばれるたびに続きを探索していく、という処理にする必要があり、そのあたりの実装がうまくできなかった。何か良い方法あるのだろうか…

Repository

github.com

Rubyでバイナリデータに対するrindex検索の挙動でハマったので調べたことメモ

自分の手元の環境でこんなことが起きた。

$ ruby -v
ruby 3.1.2p20 (2022-04-12 revision 4491bb740a) [arm64-darwin21]
$ irb
irb(main):001:0> "\x01\x80\x00\x00".index("\x01")
=> 0
irb(main):002:0> "\x01\x80\x00\x00".rindex("\x01")
=> 1

\x010 番目にしかないのだから、 .index でも .rindex でも 0 が返ってくるはずではないの??

先に結論

バイナリデータを扱うときには必ずEncodingを ASCII-8BIT に指定しておくこと!

きっかけ

roo というgem (記事時点で 2.9.0)を使って、Excelファイルを開こうとした。

require 'roo'

p Roo::Excelx.new("hoge.xlsx")

これは問題ないが、 #initialize の引数は file_or_strem ということでファイルを open したものを渡しても良いはず、と

require 'roo'

File.open("hoge.xlsx") do |f|
  p Roo::Excelx.new(f)
end

ということをすると以下のような謎のエラーを吐いて落ちる。

/Users/sugyan/.rbenv/versions/3.1.2/lib/ruby/gems/3.1.0/gems/rubyzip-2.3.2/lib/zip/entry.rb:365:in `check_c_dir_entry_static_header_length': undefined method `bytesize' for nil:NilClass (NoMethodError)

      return if buf.bytesize == ::Zip::CDIR_ENTRY_STATIC_HEADER_LENGTH
                   ^^^^^^^^^

で困っていたので他の人に聞いてみたところ「自分のところではそんな問題は起きていない」と言われ。 試しにDockerでRuby環境作って実行してみると、確かに問題なく動く…何故? と調べはじめた。

String#rindex の謎挙動

自分の環境と問題なく動く環境とで比較しながら見てみると、roo が使っている rubyzip (記事時点で 2.3.2) の中でファイル内容を読み取った結果の値が違っていた。

module Zip
  class CentralDirectory
    include Enumerable

    END_OF_CDS             = 0x06054b50

    ...

    def get_e_o_c_d(buf) #:nodoc:
      sig_index = buf.rindex([END_OF_CDS].pack('V'))
      ...

buf の中から 0x06054b50 をpackしたもの つまり "PK\x05\x06" のデータを末尾から探すために .rindex() を呼んでいるのだが、手元の環境と 問題なく動く環境ではその値が 1 ズレている。。

何故?と思って色々試しているうちに冒頭の例のようなものに辿り着いた。

もう少し深く追う

少なくとも \x80 などを含まないASCIIだけのものであればおかしな挙動にはならない。

irb(main):001:0> "\x01\x7F\x00\x00".rindex("\x01")
=> 0

また、これが増えるとズレもどんどん大きくなる。

irb(main):001:0> "\x01\x80".rindex("\x01")
=> 0
irb(main):002:0> "\x01\x80\x80".rindex("\x01")
=> 1
irb(main):003:0> "\x01\x80\x80\x80".rindex("\x01")
=> 2
irb(main):004:0> "\x01\x80\x80\x80\x80".rindex("\x01")
=> 3

逆から走査する際に異常な文字が検出された際に読み飛ばし、その結果の走査数を文字列長から引いたものを返したことにより値がおかしくなる、というのが予想される。

ソースを読んでみる。 rstring 関連はこのへんのようだ。

#ifdef HAVE_MEMRCHR によって実装が分かれている。 memrchr というのがglibcに入っているもので、自分のmacOS環境では使えなくて Docker内のLinux環境などでは使える、これによって環境によって挙動が変わったりする、のかもしれない。

で、使えない環境の方での実装は。 対象が見つかるまで while loop の中で rb_enc_prev_char で前の文字を走査している、という感じのようだ。encodingの影響を受けそう…?

    while (s) {
        if (memcmp(s, t, slen) == 0) {
            return pos;
        }
        if (pos == 0) break;
        pos--;
        s = rb_enc_prev_char(sbeg, s, e, enc);
    }

Encodingと実行環境

Encodingが関係しているようだったので色々調べてみた。

Rubyではバイナリデータ(バイト列)も String で扱う。

バイナリの取扱い

Ruby の String は、文字の列を扱うためだけでなく、バイトの列を扱うためにも使われます。しかし、Ruby M17N には直接にバイナリを表すエンコーディングは存在しません。このため、バイナリを String で扱う際には、ASCII 互換オクテット列を意味する ASCII-8BIT を用います。これにより、ASCII 互換であるこの String は 7bit クリーンな文字列と比較・結合が可能となります。

https://docs.ruby-lang.org/ja/latest/doc/spec=2fm17n.html

ということで、バイナリの場合は本来は ASCII-8BIT のEncodingを持つ文字列として検索しなければならないのに、手元の環境では UTF-8 のままで検索しようとしていた、というところに「使い方の問題」があったようだ。

irb(main):001:0> "\x01\x80\x00\x00".encoding
=> #<Encoding:UTF-8>

Encodingが ASCII-8BIT になっていれば、rindex でも正しい値を得ることができそうだ。 全然知らなかったが String#bASCII-8BIT の複製を得ることもできるらしい。

https://docs.ruby-lang.org/ja/latest/method/String/i/b.html

irb(main):001:0> "\x01\x80\x00\x00".rindex("\x01")
=> 1
irb(main):002:0> "\x01\x80\x00\x00".force_encoding(Encoding::ASCII_8BIT).rindex("\x01")
=> 0
irb(main):003:0> "\x01\x80\x00\x00".b.rindex("\x01")
=> 0

なるほど〜。

ではこの Encoding は何で決まるのか?というのも上記リンクに書いてある。

リテラルエンコーディング

文字列リテラル正規表現リテラルそしてシンボルリテラルから生成されるオブジェクトのエンコーディングスクリプトエンコーディングになります。

またスクリプトエンコーディングが US-ASCII である場合、7bit クリーンではないバックスラッシュ記法で表記されたリテラルエンコーディングは ASCII-8BIT になります。

さらに Unicode エスケープ (\uXXXX) を含む場合、リテラルエンコーディングUTF-8 になります。

複雑。。。

とにかく通常はスクリプトエンコーディングがまず大事のようだ。

スクリプトエンコーディング

スクリプトエンコーディングとは Ruby スクリプトを書くのに使われているエンコーディングです。スクリプトエンコーディングは マジックコメントを用いて指定します。スクリプトエンコーディングには ASCII 互換エンコーディングを用いることができます。 ASCII 非互換のエンコーディングや、ダミーエンコーディングは用いることができません。

現在のスクリプトエンコーディング__ENCODING__ により取得することができます。

さらに、magic commentによってこのスクリプトエンコーディングを決定でき、それが無い場合の挙動も書かれている。

マジックコメントが指定されなかった場合、コマンド引数 -K, RUBYOPT およびファイルの shebang からスクリプトエンコーディングは以下のように決定されます。左が優先です。

magic comment(最優先) > -K > RUBYOPTの-K > shebang

上のどれもが指定されていない場合、通常のスクリプトなら UTF-8、-e や stdin から実行されたものなら locale がスクリプトエンコーディングになります。 -K オプションが複数指定されていた場合は、後のものが優先されます。

通常のスクリプト-e かどうか、でも変わったりするのね…。そして特に指定ない場合は最終的にはlocaleが使われる。 ので、 LC_CTYPE などでも挙動が変わり得るようだ。

$ ruby -e 's="\x01\x80\x00\x00"; p __ENCODING__, s.encoding, s.rindex("\x01")'
#<Encoding:UTF-8>
#<Encoding:UTF-8>
1
$ ruby -Kn -e 's="\x01\x80\x00\x00"; p __ENCODING__, s.encoding, s.rindex("\x01")'
#<Encoding:ASCII-8BIT>
#<Encoding:ASCII-8BIT>
0
$ RUBYOPT="-Kn" ruby -e 's="\x01\x80\x00\x00"; p __ENCODING__, s.encoding, s.rindex("\x01")'
#<Encoding:ASCII-8BIT>
#<Encoding:ASCII-8BIT>
0
$ LC_CTYPE=C ruby -e 's="\x01\x80\x00\x00"; p __ENCODING__, s.encoding, s.rindex("\x01")'
#<Encoding:US-ASCII>
#<Encoding:ASCII-8BIT>
0

という具合に、何もしないと UTF-8 文字列として扱ってしまうが、オプションや環境変数によって リテラルの Encoding を変えることもできる。

こうやって正しく ASCII-8BIT のEncodingを持つ文字列として検索すれば、 String#rindex の値がズレるということはなさそうだ。

つまり再現条件は

手元の環境で String#rindex の値がズレたのは2つの要因があって

  1. memrchr が使えない環境でビルドしたRubyで実行していて
  2. ASCII-8BIT でないEncodingの文字列に対して rindex をかけていた

ということになる。2つ揃っていないと起きないものと思われる。

Rooの問題

前述の、Rooでエラーが起きる件について考える。

そもそもEncodingが不明で渡ってくるものに対して安易に rindex を使うものではない、ということで rubyzip 側に落ち度がありそうではある。が 現在の master branch では リリースされている 2.3.2 とは大きく変わっていて、もう同じことは起きないのかもしれない。詳しくは追っていないので分からないが、今回のと関連するissueがあった。

既にCloseされているが、今回調べた Zip::CentralDirectory は直接使用するものではなく Zip::File を使ってください、ということのようだ。

つまりこの場合、 Roo 側の使い方が悪い、ということになりそう。 で、Rooの方も調べてみると ちゃんとそれに対応しようとしていると思われる修正があった。

これが取り込まれれば手元の環境でも問題なく動くようになるかな…?

もしくは現時点でも、使う側が渡す引数のEncodingを明示的に指定してやることで一応回避できそうではある。

require 'roo'

File.open("hoge.xlsx", "r:ASCII-8BIT") do |f|
  p Roo::Excelx.new(f)
end

Rubyのバグではないの?

しかし Encodingを正しく指定していなかったために起きているとしても、 String#indexString#rindex も検索した対象の開始indexを返すものであるはずなのだから、それがズレるのはおかしいのではないのか…?という気はする。

さらに memrchr が使えるか否かのビルド環境によってだけで結果が変わる(ちゃんと確かめてないけど おそらくそう…)、というのも気持ち悪い。

かなり昔から現在の実装になっているようだし 今からでは挙動を変えづらそうではある。 仕様といってしまえばそれまでだけど…。 このあたりはRuby開発陣の方々の見解も聞いてみたいところではあります。

3.2

ちなみに Ruby 3.2 からは String#byteindexString#byterindex が追加されるそうです。今回のようなバイナリデータに対する検索にピッタリ使えそうですね。

spherical linear interpolation(slerp)によるlatent spaceでのnoise補間

memo.sugyan.com

の記事を書いてから、先行事例の調査が足りていなかったなと反省。 Latent Seed の Gaussian noise 間での morphing はあんまりやっている人いないんじゃないかな、と書いたけど、検索してみると普通に居た。

Stable Diffusion が公開されるよりも前の話だった…。

そしてこの morphing を作るためのコードも gist で公開されている

読んでみると、2 つの noise の間の値を得る方法として slerp という関数が使われている。

        for i, t in enumerate(np.linspace(0, 1, num_steps)):
            init = slerp(float(t), init1, init2)

https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355#file-stablediffusionwalk-py-L179-L180

この定義は以下のようになっている。

def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """ helper function to spherically interpolate two arrays v1 v2 """

    if not isinstance(v0, np.ndarray):
        inputs_are_torch = True
        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()

    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)

    return v2

https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355#file-stablediffusionwalk-py-L101-L125

不勉強で知らなかったが、これは "spherical linear interpolation"、日本語では「球面線形補間」と呼ばれるのかな、という手法で、3DCGの世界などではクォータニオン(Quaternion)の補間としてよく使われていたりする、ようだ。

en.wikipedia.org

書かれている通り、元々はクォータニオンの補間の文脈で紹介されていたが、次元数に関係なく適用することができるということらしい。 空間上での原点から2つの点へのベクトルを考え、その2つがなす角  \Omega をドット積とノルムを使って求めることができる。 始点から終点まで、 0 から  \Omega へと角度を線形に変化させながら2点を結ぶ円弧上を移動させていく、という感じだろうか。

 \displaystyle Slerp(p_0,p_1;t) = \frac{\sin{[(1-t)\Omega]}}{\sin{\Omega}}p_0 + \frac{\sin{[t\Omega]}}{\sin{\Omega}}p_1

画像生成モデルでの latent space における補間としては 2016年くらいには提案されていたようだ。 そういえば GAN で遊んでいたときにちょっとそういう話題を聞いたことがあったような気もする…。

arxiv.org

しかしこれによって Gaussian noise として分散を保ったまま変化させられる、ことになるのか…? 数学的なことはちょっとよく分からない。 まぁ球面上を角度だけ変えて動いていると考えればノルムが変化しないから大丈夫なのかな??くらいの感覚でしかない…。

実際にこれを使って補間していったときにどのような変化になるか、前回の記事で書いていた sqrt を使うものと上述の slerp を使うもので ランダムな2点を繋ぐ軌跡を見てみた。2次元、3次元 空間上でそれぞれプロットすると以下のような違いがあった。

間隔が異なるだけでなく、明らかに軌道も異なるところを通るようになるようだ。

実装

引用したコードでも良いが、自分でもPyTorch前提で Slerp を書いてみる。 NumPy への変換はしなくても計算できそうだった。

def myslerp(
    t: float, v0: torch.Tensor, v1: torch.Tensor, DOT_THRESHOLD: float = 0.9995
) -> torch.Tensor:
    u0 = v0 / v0.norm()
    u1 = v1 / v1.norm()
    dot = (u0 * u1).sum()
    if dot.abs() > DOT_THRESHOLD:
        return (1 - t) * v0 + t * v1
    omega = dot.acos()
    return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()

前述の数式の定義通りに実装するだけではあるが、やはり特別な場合に注意する必要はある。2点の原点からの方向がまったく同じか、もしくは正反対の方向のとき、理論上では単位ベクトルのドット積は 1-1 になる。 が、float値の多次元配列で計算してみると 多少の誤差が生じることがある。

for _ in range(10):
    v0 = np.random.normal([1, 4, 64, 64])
    v1 = 1.0 * v0
    print((v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))).sum())
0.9999999999999997
1.0
1.0
0.9999999999999998
0.9999999999999999
1.0
1.0
0.9999999999999998
1.0
1.0000000000000002
for _ in range(10):
    v0 = np.random.normal([1, 4, 64, 64])
    v1 = -1.0 * v0
    print((v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))).sum())
-1.0000000000000004
-1.0
-0.9999999999999998
-1.0
-1.0
-1.0
-0.9999999999999999
-1.0
-1.0000000000000002
-1.0

slerp では このドット積の値から arccos() で角度を求める必要があるが、例えば np.arccos() は 引数が -1.01.0 の間に収まっていない場合に nan を返すことになってしまう。 また、後でこの角度の sin() で割る処理があるので、ドット積が 1.0 で返せていたとしてもその後の計算で今度は inf が出てきたりする。 ので、こういうケースでは球面ではなく普通に線形で繋いでしまった方が良いだろうね、ということで DOT_THRESHOLD というものが使われて処理を分岐させているようだ。


※追記

指摘いただいたが、まったく同じ向きのときはともかく 正反対の方向の場合は線形補間が正しいとは限らない。 例えば2次元において (1, 0)(-1, 0) を補間する場合は原点を通らず (0, 1) もしくは (0, -1) を通る円軌道になった方が良かったり 3次元の球でいうと北極と南極を繋ぐのは赤道上を通る球面軌道であって欲しい、など。 これは2点のどちらかを適度にブレさせれば算出できる場合もあるかもしれないし、上記のように明らかに通るべき点があればまず両点からそこにまず補間してやるなど、色々な考え方がありそう。これもまた用途や場合によると思われる。 今回の Stable Diffusion で使っている乱数においては正反対向きはまず有り得ないものと思ってよさそうだ。


この小数点誤差がどれくらいになるのかは、要素数だったり 計算の順番だったり numpy で計算するか torch で計算するか などによっても変わってくるようだ。 どれくらいの値になったら線形と見做すか、の閾値(上述では 0.9995)は場合によると思われる。 少なくとも Stable Diffusion で使う [4, 64, 64] の shape で torch.randn() で生成される乱数同士の場合は、通常はドット積は 0.1 にすら届くことがまずないので、極端な話ではそれくらい低くても問題なさそうではある。 morphing の2点間で同一の noise を使う可能性がない、という前提であれば分岐なくしても問題ないくらい。

比較

で、実際に slerp を使うように変更すると morphing はどう変化するのか。 以前に載せた anime girl での morphing の slerp 版を作ってみた。

並べて比較

ちょっと変化のタイミングが変わったかな…? という程度で 大きな差は感じられない。

考察

結局のところ、Stable Diffusion においては Gaussian noise である限り何らかの画像を生成できるようになっているし、指定の2点の間をどう遷移しようと違いはないのかもしれない。 ただ slerp は線形に角度変化していくという観点でも、その遷移の移動量としてはより安定した変化が期待できるかな、という気はする。 むしろ前回の記事の手法が適当に思い付きでやった割にはそこそこ近いものが出来ていたのすごかったのでは…? というくらい。

また、今回は Gaussian noise 間でということを考えていたが slerp の手法自体は別にどこでも使えるものだとは思うので知っておいて損はないはず。また StyleGAN とかで遊ぶことがあったらそこでも使ってみても良さそう。

あと、prompt の embedding 結果に対する morphing でも slerp を使うことはできるはず、だが どうなんだろう。実験してみていないけど、A から B を遷移するのに まったく無関係な C を通過することになってしまったり 余計な変化が増えてしまうだけのような気もするので、こちらは線形に最短距離で遷移して刻み幅だけ調整するくらいが良いのではないかな、と思っている。実際にはやはりどちらでも大して変わらないかもしれないし prompt 次第かもしれない。試行錯誤するときの選択肢として持っておくと良さそうではある。