40歳から始める関数型言語、OCaml

というわけで、今年に入ってからのここ数ヶ月、OCamlを勉強し始めている。

動機

これまで「関数型言語をちゃんと触ったことがない」ということが若干コンプレックスになっていた。40歳になった今、唐突に始めてみる気になったので、やることにした。

Why OCaml

そもそも関数型にどんな言語があるのか、それぞれがどんな特徴であり、どういったところで使われているのか、という情報すらほとんど知らなかったので、「関数型言語やってみたいんですけど、どれがいいんすかね?」と某所で雑に相談したところ、OCamlを挙げてもらったので、素直にそれに従ってみることにした。

そんな適当で良いのか?とも思うけど、とにかくどれであっても触ってみないことには何も分からないし。実際、OCamlを選んで間違いは無かったと今は思う。

学習方法

Real World OCaml

とにかくOCamlの学習にはこれが良いらしい、という話をきいた。

dev.realworldocaml.org

ので、これを最初の数章は頑張って読んでみて、あとは実際にコードを書いてみつつ必要になったら参照して勉強していく、という感じで。 実際とても詳しく丁寧に書かれていて、とても良い本だと思う。

Github Copilot と ChatGPT

これらはもう現代のプログラマにとって欠かせないものになっていると思う。

最近のメジャーな言語と比較してOCamlに関する学習データがどの程度の量と質で、どの程度優れたコード出力能力であるのかは分からないが、少なくとも初心者から始める人間にとっては十分に助けになる。

特にChatGPTは、環境構築からライブラリの使い方などは勿論、「Rustでこう書いてたのはOCamlではどう書くの」とか「こういうのできないの」「こう書いてみたけど他にも書き方あるのかな」みたいな、独り言で呟くようなことでもChatで気軽に書いて質問できて体験が良かった。 めちゃくちゃ親しくて時間を奪うこと気にせずに何でも質問できる友人、が隣に居ればそういう人に頼れるかもしれないが、そういう友達が居ない人間にとってはこうした心理的なハードル低く幾らでも質問できる存在はとても有り難いと思う。

とはいえやはり正確性は保証されなくてChatGPTも平気で実在しないライブラリ関数をでっちあげて使用例を提示してきたりするし、油断せずにちゃんと公式ドキュメントなど探して確認しながら書く必要はある。 このへんは仕方ないと腹を括って、「俺が学習データを提供してやるぜ」くらいの気概でできるだけ良い(と自分が思う)コードを書いてとにかく公開していく。

オンラインジャッジ (競プロ)

簡単なコード片を書いて正しい入出力を得られたか確認できる、また別の人の書いた同じ目的のコードを気軽に参照できる、という点で競技プログラミングの過去問などは良い学習材料になると思う。とりあえず他人と競うことは考えず、オンラインジャッジシステムとして利用させてもらう。

個人的には LeetCode でやりたかったが、LeetCodeにはOCamlの選択肢が無い…。ので、 AtCoder を使うことにした。 AtCoder Problems の「Boot camp for Beginners」 というのがあることを @naoya_ito さんのTweet で知ったので、真似してこれらを解いてみることにした。

とはいえAtCoderOCamlのランタイムや使えるライブラリが少し古くて(2023/04時点)、手元で Core 最新を使って動くコード書けて意気揚々とSubmitしてみたら CE (Compilation Error) 出てガックリすることも。 ここは今後 言語のアップデート で更新されてより使いやすくなることを期待している。

簡単な問題を解くことで基礎的なデータ構造やアルゴリズムの扱いは練習できるが、難易度が上がると当然 数学的な思考力やより高度なアルゴリズムの知識が主眼になってくる。 そこは言語習得とはまた別の話になってくるのでそこまで頑張る必要はない(本当に競プロを頑張るならC++とかRustとか他の言語でやったほうが良いと思う)、ということでほどほどにして止めるつもり。

Advent of Code

より実践的に「プログラミングによる問題解決」をできるようにする練習として、 Advent of Code は非常に良い題材だと思う。 AtCoderのような競技プログラミング的な要素もあるが、それだけではなく

  • 入力データが必ずしも扱いやすいものばかりではない
    • → 様々な形式に対応したparse処理を書く練習になる
  • part1/part2 で同じ入力から異なる解を出す、など出力も様々な型・形式になり得る
    • → interfaceを考える練習になる
  • アルゴリズムだけで解決しないこともある
    • → 力技での膨大な探索や泥臭い実装などが必要になることも

といった点でまた別の実装力を鍛えることができると考えられる。

2022を完走した記事 にも書いたが、2022年の問題は特にバランス良く様々な題材のものがあって言語習得の練習に十分に適している、と感じた。

というわけで、OCamlAdvent of Code 2022 を解くチャレンジをしている。

github.com

まだ18日目までだが、どうにか最後まで頑張りたい…!

その次?

やはり「問題を解く」だけではなく、最終的には「何らかのアプリケーションを自作する」ことができるくらいにはなりたいので、それを考えていく。

自分の場合は「過去に他の言語で作ったことがあるもの」「自分が使いたいもの」を作るのが最もモチベーションが高くなるので、そういうもので考えると、、、今はコンピュータ将棋関連だろうか。 幸い(?) opam を検索してみた限りではビットボード実装や将棋関連のライブラリなどは存在していないようなので、自分で作ってみるのは面白そうだ。関数型言語での将棋プログラムなんて全然需要は無さそうだけどw とりあえずOCamlでの実装がC++やRustと比較してどの程度の速度になるのかは知りたいので perft できるくらいまではやってみたい。

あとは今の流行でいうと Bluesky で使われている AT Protocol の実装をしてみる、とかだろうか。これもわざわざOCamlでやろうとする人はあんまり居なそう…

所感

とりあえず数週間触ってみての感想。

|> は多用することが分かったので自作キーボード(Claw44)のキーマップにマクロとして追加した。

関数型という概念

結局、現状どれくらい理解できているだろうか…。まだまだ理解が浅いとは思う。

基本的に Base ライブラリを使用しているが、 List に対する様々な処理が再帰で表現できるというのを見て驚かされた。というか List というデータ構造がこんなにも使いやすくて色々できるんだなぁ、と少し親しみを持てるようになった。 とりあえず、 List などを使って |> でパイプを繋げてデータを変換していく、という操作は慣れてきた。

最初はどうしても「手続き型の言語だと普通こう書くので、それを関数型で表現すると…」と変換していく感じにはなりがちだったが、だんだん「この関数とこの関数を組み合わせてこういう関数を作って、そしてこの入力にこういう出力を返す関数を作る、」という感じで考えられるようになってきた気がする。

高階関数の部分適用とか使ってみるとめっちゃ便利だな〜と感心したり。

プログラミングHaskell という本を何年か前に買ったものの序盤ですぐに挫折してしまっていたのだけど、OCamlを通じて多少なりとも理解が深まってきたおかげか、今は再チャレンジして楽しく読めている。

OCamlの書き味

思っていた以上にツールチェインなどが整っていて、環境構築からエディタ設定などほとんど問題なくできて良かった。 VSCodeでは OCaml Platform 入れるだけだし、 ocamlformat を使って自動整形もできる。このあたりの「揃っていて欲しい」ものは大抵ちゃんと揃っている。

コンパイラのエラーメッセージは正直あまり親切な感じはなくて、型が合わないのはどこをどうすれば… みたいなのがなかなか詰まりやすかったりした。とはいえ「ちゃんとコンパイルが通ればだいたいちゃんと動く」という感覚がRustに近いものがある、と感じる。

あとは dune がRustでいう cargo と同じような感覚で使えるので、これも便利。

Rust, Python の経験

Rustを書いたことがあると、色んなところで「似ているな」と感じるところがあった。例えば option とかそれに対するパターンマッチとか。 Rustのiteratorをメソッドチェーンで繋いで処理するのと同じような感覚で |> で繋ぐ処理を書けるな、とか。 Rustがまったくの未経験だったらもっとOCamlに戸惑っていたかもしれない。

Pythonはそんなに似てないかな…? Haskellみたいに内包表記があるともっと違う感想になったかな。 ただOCamlの影響で自分が書くPythonコードがちょっと変わってきたような気がする。関数の使い方をより意識するようになったというか、なんというか。上手く言語化できないけど。

関数型言語をやるとプログラミングが上手くなる」という言説を聞いたことがあるが、こういうところで少しは効果がでるのかな…?気のせいかもしれない。

AIとの親和性

ところで Github Copilot にコードの補完をしてもらうときに、関数のシグネチャの情報はわりと重要だと思っていて。 「こういう名前の関数でこの型のこの名前の引数を受け取って、この型の値を返す」という情報が書いてあると当然補完も正確になりやすいだろう、と。

一方でOCaml型推論が強力なので、型はあるもののtype annotationを書く必要が殆どなく、またちょっとした処理は無名関数を書くことが多いし、そうでなくとも let rec loop acc x = ... とか let f x y = ... とか、簡潔な書き方になりがち。 そうすると、補完しようとするAIとしては情報が足りなすぎて難しいよなぁ、ということを思った。

明瞭な関数名や型注釈をいちいち書くのは面倒で人間としては省略できる方が有り難いが、それはそれでAIに助けてもらいながら書く場合には意図や情報を渡せるという点で良いのかもしれない。 今後新しく作られる言語の設計もそういう観点が考慮されるようになってくるんだろうか、とか。

まとめ

コードを書く仕事はAIに奪われていくかもしれないが、プログラミングは楽しいので好きなように書きたいだけ書いたら良い。

YAPC::Kyoto 2023 に参加した #yapcjapan

yapcjapan.org

おそらく YAPC::Tokyo 2019 以来?4年ぶりにオフラインのイベントに参加しました。地元開催ということで日帰りで行けてありがたい…! 最初だけちょろっと家族で参加。

子ゃーんの相手をしてくださった方々、ありがとうございました。

久しぶりすぎて誰と会ってもキョドってしまうくらいだったけど、とにかく色んな人たちとリアルで再会できて嬉しかった!!

そんなに長居はしてなかったのでトークは多くは訊けなかったけど、

id:ar_tama さんの『あの日ハッカーに憧れた自分が、「ハッカーの呪縛」から解き放たれるまで』がとても良かったです。 このトークを聴いた若手のエンジニアがこれを心に刻んで大きく成長し、また10年後のYAPCで良い発表してくれるといいな、と思いました。

(…と思ったら本当にベストトーク賞だったようで。おめでとうございます!)

とにかく楽しい1日でした。 運営の皆様、今回もありがとうございました!!

Advent of Code 2022 を完走した

毎年12月に開催されている Advent of Code に、2019年から参加している。

過去記事:

2022年のAdvent of Codeにも挑戦していて、年が明けてしまったが先日ようやく25日すべての問題に解答して 50 個のスターを集めることができた。

2022年の問題もどれもとても面白かった。

データ構造を扱う問題、ある程度のアルゴリズム知識が求められる問題、ちょっとした数学的な考え方が必要な問題、ひたすら実装が大変な問題、視覚化が面白い問題、などが揃っているのは毎年だが、特に2022年はバランス良くバラエティに富むものが出題されていて取り組み甲斐があったように思う。

工夫のしどころも多くあって、自分では上手く解いたつもりでも Reddit で他の人の解答を見ると目から鱗のものがたくさん発見できたり。

基本的にはヒント無しで自力で正解まで辿り着いたが、day19のpart2だけはどうにもならなくて他の人の解答を見たりしてようやく、という感じになってしまった。悔しい。

全部Rustで解いていて、あとでPythonでも解こうとしている。

github.com

とても楽しかったし勉強になったので、もうちょっと同じ問題に取り組んだ人達でわちゃわちゃと「ここが難しかった」「これはこんな手法があって」「こんな書き方もあって」など議論できると嬉しい…。 それはRedditで英語で頑張れば良いのだろうけど、ともかくもう少し日本語圏の開発者の間でも流行って仲間が増えてくれるといいな、と思っている。

という想いもあり、2020のときにもやったけど 日本語の解説本を書いてみている。

github.com

まだ半分もいってないけど… 少しでも多くのプログラマの人たちに届くといいな

2023パズル をRustで解いてみる

tkihiraさんの問題が面白そうだったので挑戦してみた。

既に解説記事が出ているので解答はこちらをどうぞ。

nmi.jp

結局自分は自力では解けなくて 他の人の解法や上記の解説記事を読んでようやくできた、のだけど… 自分なりに理解して改めてRustで実装してみた。

RPN(逆ポーランド記法)の backtracking

まずは型定義。

#[derive(Clone, Copy, PartialEq, Eq)]
enum Op {
    Add,
    Sub,
    Mul,
    Div,
}

#[derive(Clone)]
enum ExpressionElement {
    Operand(i32),
    Operator(Op),
}

この ExpressionElement を stack に詰んでいって、末尾が Operator だったらさらに Operand を 2つ pop() して、演算結果を push() していくことで計算が実現できる。

常に Operand の方が多い状態でないと Operator は挿入できないので、各々の出現回数をカウントしながら条件を満たしているときだけ操作するようにする。

様々な実装が試せるように Trait を定義。

trait Rpn {
    fn traverse(
        &mut self,
        expr: &mut Vec<ExpressionElement>,
        (i, j): (usize, usize),
        results: &mut Vec<Vec<ExpressionElement>>,
    ) {
        if i == j + 1 && self.get(i).is_none() {
            if self.evaluate(expr) {
                results.push(expr.clone());
            }
            return;
        }
        if let Some(&n) = self.get(i) {
            self.backtrack(expr, (i + 1, j), results, ExpressionElement::Operand(n));
        }
        if i > j + 1 {
            for op in &[Op::Add, Op::Sub, Op::Mul, Op::Div] {
                self.backtrack(expr, (i, j + 1), results, ExpressionElement::Operator(*op));
            }
        }
    }
    fn evaluate(&self, expr: &[ExpressionElement]) -> bool;
    fn get(&self, i: usize) -> Option<&i32>;
    fn backtrack(
        &mut self,
        expr: &mut Vec<ExpressionElement>,
        (i, j): (usize, usize),
        results: &mut Vec<Vec<ExpressionElement>>,
        e: ExpressionElement,
    ) {
        expr.push(e);
        self.traverse(expr, (i, j), results);
        expr.pop();
    }
}

self.get(i)i番目に使う数字を呼び出す。すべて使い切っていて Operatorも適切な回数使われていたら self.evaluate(expr) で計算結果を確認し、条件に合致していた場合のみ その exprresults に格納する。

exprpush / pop して再帰的に self.traverse() を呼ぶ backtracking の部分は別の実装で上書きできるようにあえて独立のmethodで定義。

何も難しいことを考えずにこれを満たす実装を作ると、以下のようになる。

struct Fraction(i32, i32);

struct DefaultSearcher {
    nums: Vec<i32>,
    target: i32,
}

impl DefaultSearcher {
    pub fn new(nums: Vec<i32>, target: i32) -> Self {
        Self { nums, target }
    }
}

impl Rpn for DefaultSearcher {
    fn evaluate(&self, expr: &[ExpressionElement]) -> bool {
        let mut stack = Vec::new();
        for e in expr {
            match e {
                ExpressionElement::Operand(n) => stack.push(Fraction(*n, 1)),
                ExpressionElement::Operator(op) => {
                    if let (Some(n0), Some(n1)) = (stack.pop(), stack.pop()) {
                        if let Some(n) = op.apply(&n1, &n0) {
                            stack.push(n);
                        } else {
                            return false;
                        }
                    }
                }
            }
        }
        stack
            .last()
            .map(|n| n.1 * self.target == n.0)
            .unwrap_or(false)
    }
    fn get(&self, i: usize) -> Option<&i32> {
        self.nums.get(i)
    }
}

impl Op {
    fn apply(&self, lhs: &Fraction, rhs: &Fraction) -> Option<Fraction> {
        match self {
            Self::Add => Some(Fraction(lhs.0 * rhs.1 + rhs.0 * lhs.1, lhs.1 * rhs.1)),
            Self::Sub => Some(Fraction(lhs.0 * rhs.1 - rhs.0 * lhs.1, lhs.1 * rhs.1)),
            Self::Mul => Some(Fraction(lhs.0 * rhs.0, lhs.1 * rhs.1)),
            Self::Div if rhs.0 != 0 => Some(Fraction(lhs.0 * rhs.1, lhs.1 * rhs.0)),
            _ => None,
        }
    }
}

解説記事では浮動小数点でも十分ということでやっていたけど、自分的には整数値だけで処理したいので Fraction(i32, i32) を定義して分数で計算するようにした。約分のことはとりあえず考えなくてよくて、最終的な値の計算結果の確認は 分子が分母の 2023 倍になっているか否か、で判定できる。 各Opの適用結果を Op::apply で計算。ゼロ除算の可能性があるので返り値は Option<Fraction> に。

self.evaluate() 内で stackを用意し、exprを順に見ながら Operand だったら Fraction に変換してpushし、 Operator だったら 2つを pop して演算結果をまた格納。理論上 正しい形式のRPN式であれば最終的にはstackには1要素だけ残るはずなので、その値を確認すれば良い。

これでとりあえず、(98の間の割り算を無視して)計算結果が 2023 になる式の組み合わせを列挙できる。

let nums = (1..=10).rev().collect();
let mut results = Vec::new();
DefaultSearcher::new(nums, 2023).traverse(&mut Vec::new(), (0, 0), &mut results);

手元の環境で実行すると 3370 件の結果が約2分弱で列挙された。

...

10-(9-(((8*(7*6))+(5-4))*(3*(2*1))))
10-(9-(((8*(7*6))+(5-4))*(3*(2/1))))
10-(9-(((8*(7*6))+(5-4))*(3+(2+1))))
10-(9-(8*(((7*(6+(5/4)))*(3+2))-1)))
10-(9-(8*((7*((6+(5/4))*(3+2)))-1)))
Completed 3370 results in 111.933449708s

探索の高速化

計算結果を保持

前述の実装だと、式が完成するたびにstack用意して計算結果を確認して… というのが非常に無駄になる。解説記事によると 1,274,544,128 通り生成されるようなので、ここはもっと計算量を削りたい。

途中までの計算結果は変わらないものなので、それを保持しながら探索した方が効率的になる。 内部に別のstackを持って、 Operator を受け取った時点で演算処理結果を格納しておくようにしておく。

前述の self.evaluate() でやっていた計算を self.backtrack() の中で逐次やっておく、という感じ。

pub struct FastSearcher {
    nums: Vec<i32>,
    target: i32,
    stack: Vec<Fraction>,
}

impl FastSearcher {
    pub fn new(nums: Vec<i32>, target: i32) -> Self {
        let stack = Vec::with_capacity(nums.len());
        Self {
            nums,
            target,
            stack,
        }
    }
}

impl Rpn for FastSearcher {
    fn evaluate(&self, _: &[ExpressionElement]) -> bool {
        self.stack[0].1 * self.target == self.stack[0].0
    }
    fn get(&self, i: usize) -> Option<&i32> {
        self.nums.get(i)
    }
    fn backtrack(
        &mut self,
        expr: &mut Vec<ExpressionElement>,
        (i, j): (usize, usize),
        results: &mut Vec<Vec<ExpressionElement>>,
        e: ExpressionElement,
    ) {
        match e {
            ExpressionElement::Operand(n) => {
                self.stack.push(Fraction(n, 1));
                expr.push(e);
                self.traverse(expr, (i, j), results);
                expr.pop();
                self.stack.pop();
            }
            ExpressionElement::Operator(op) => {
                if let (Some(n0), Some(n1)) = (self.stack.pop(), self.stack.pop()) {
                    if let Some(n) = op.apply(&n1, &n0) {
                        self.stack.push(n);
                        expr.push(e);
                        self.traverse(expr, (i, j), results);
                        expr.pop();
                        self.stack.pop();
                    }
                    self.stack.push(n1);
                    self.stack.push(n0);
                }
            }
        }
    }
}

これによって self.evaluate() に到達した時点で計算結果は出ているので、その値だけを確認すれば良い、ということになる。また、途中でゼロ除算が生じた場合にその先の探索をしなくなるので無駄を省く効果もありそうだ。

これで、前述の 3370 件の列挙が 約30秒に短縮された。約4倍の高速化。

探索結果のキャッシュ

また、この内部stackは同じ状態を経過することも多い。例えば

  • 10*(9-8)+...
  • 10/(9-8)+...

8 までの計算結果はどちらも変わらない。どちらか一方で全探索して1件も見つからなかったなら、もう片方でも1件も見つからないことが分かる。

ので、探索して条件に合致する式が見つかったか否かを返すようにし、それが false だった場合には同様の状態からの探索をスキップするように変更。

use std::collections::HashSet;

type Key = (Vec<Fraction>, (usize, usize));

pub struct FastSearcher {
    nums: Vec<i32>,
    target: i32,
    stack: Vec<Fraction>,
    seen: HashSet<Key>,
}

impl FastSearcher {
    pub fn new(nums: Vec<i32>, target: i32) -> Self {
        let stack = Vec::with_capacity(nums.len());
        Self {
            nums,
            target,
            stack,
            seen: HashSet::new(),
        }
    }
}

impl Rpn for FastSearcher {
    fn traverse(
        &mut self,
        expr: &mut Vec<ExpressionElement>,
        (i, j): (usize, usize),
        results: &mut Vec<Vec<ExpressionElement>>,
    ) -> bool {
        if i == j + 1 && self.get(i).is_none() {
            let found = self.evaluate(expr);
            if found {
                results.push(expr.clone());
            }
            return found;
        }
        let key = (self.stack.clone(), (i, j));
        if self.seen.contains(&key) {
            return false;
        }
        let mut ret = false;
        if let Some(&n) = self.get(i) {
            ret |= self.backtrack(expr, (i + 1, j), results, ExpressionElement::Operand(n));
        }
        if i > j + 1 {
            for op in &Op::ALL {
                ret |= self.backtrack(expr, (i, j + 1), results, ExpressionElement::Operator(*op));
            }
        }
        if !ret {
            self.seen.insert(key);
        }
        ret
    }
}

これで、約12秒くらいに短縮された。さらに2倍以上速くなった。

実際にはこの key に使っている self.stack は約分されていない計算結果なので、同じ状態を経由しているはずなのに 値が異なっているという扱いになってしまうケースがある。 gcd を使って既約分数を格納するようにしてみる。

impl FastSearcher {
    fn gcd(a: i32, b: i32) -> i32 {
        if b == 0 {
            a
        } else {
            Self::gcd(b, a % b)
        }
    }
}

impl Rpn for FastSearcher {
    ...

    fn backtrack(
        &mut self,
        expr: &mut Vec<ExpressionElement>,
        (i, j): (usize, usize),
        results: &mut Vec<Vec<ExpressionElement>>,
        e: ExpressionElement,
    ) -> bool {
        let mut ret = false;
        match e {
            ExpressionElement::Operand(n) => {...}
            ExpressionElement::Operator(op) => {
                if let (Some(n0), Some(n1)) = (self.stack.pop(), self.stack.pop()) {
                    if let Some(n) = op.apply(&n1, &n0) {
                        let gcd = Self::gcd(n.0, n.1);
                        self.stack.push(Fraction(n.0 / gcd, n.1 / gcd));

                        ...
                    }
                    self.stack.push(n1);
                    self.stack.push(n0);
                }
            }
        }
        ret
    }
}

これで、3370 件の列挙が 約7.5秒程度まで短縮された。 このキャッシュ手法は探索目標の値などによって効率が変わってくるので一概には言えないかもしれないが、少なくともこの例においては数倍の高速化が実現できた。

割り算を考慮した条件つき生成

ここまでは計算結果が 2023 になるものを全列挙することだけを考えていたが、件のパズル問題では「98の間には既に÷が入っています」という制約がある。

前述した 3370 件から条件に合致するものを抽出しても良いが、そもそも探索する時点で条件に合致しない式になるものを除外していくほうが効率が良い。

これは自分ではまったく思いつかなくて解説記事を読んで目から鱗だったのだけど、 8の数字が出た後、数字と演算子の数が一致していた場合」の演算子を割り算に固定 することによって実現できる。

解説記事を参考に、専用のSearcherを実装してみる。

struct Div8Searcher {
    nums: Vec<i32>,
    target: i32,
    stack: Vec<Fraction>,
    eight_depth: i32,
}

impl Div8Searcher {
    pub fn new(nums: Vec<i32>, target: i32) -> Self {
        let stack = Vec::with_capacity(nums.len());
        Self {
            nums,
            target,
            stack,
            eight_depth: -1,
        }
    }
}

impl Rpn for Div8Searcher {
    fn traverse(
        &mut self,
        expr: &mut Vec<ExpressionElement>,
        (i, j): (usize, usize),
        results: &mut Vec<Vec<ExpressionElement>>,
    ) -> bool {

        ...

        if i > j + 1 {
            for op in &Op::ALL {
                if self.eight_depth == 0 && op != &Op::Div {
                    continue;
                }
                ret |= self.backtrack(expr, (i, j + 1), results, ExpressionElement::Operator(*op));
            }
        }
        ret
    }
    fn backtrack(
        &mut self,
        expr: &mut Vec<ExpressionElement>,
        (i, j): (usize, usize),
        results: &mut Vec<Vec<ExpressionElement>>,
        e: ExpressionElement,
    ) -> bool {
        let mut ret = false;
        match e {
            ExpressionElement::Operand(n) => {
                let orig = self.eight_depth;
                if n == 8 {
                    self.eight_depth = 0;
                } else if self.eight_depth >= 0 {
                    self.eight_depth += 1;
                }
                self.stack.push(Fraction(n, 1));
                expr.push(e);
                ret |= self.traverse(expr, (i, j), results);
                expr.pop();
                self.stack.pop();
                self.eight_depth = orig;
            }
            ExpressionElement::Operator(op) => {
                if let (Some(n0), Some(n1)) = (self.stack.pop(), self.stack.pop()) {
                    if let Some(n) = op.apply(&n1, &n0) {
                        self.eight_depth -= 1;
                        self.stack.push(n);
                        expr.push(e);
                        ret |= self.traverse(expr, (i, j), results);
                        expr.pop();
                        self.stack.pop();
                        self.eight_depth += 1;
                    }
                    self.stack.push(n1);
                    self.stack.push(n0);
                }
            }
        }
        ret
    }
}

eight_depth の値を保持し、8 が出現したら 0 にセットして OperandOperator かによって値を増減。値が 0 のときには次に push する OpeatorOp::Div だけに制限される。

前述したような探索結果のキャッシュも使いたいところだが、新たな状態としてこの self.eight_depth もキーに含める必要があり、その割にはそれほど効果も無さそうなのでここでは省いた。

これによって 「98の間には必ず/が入る」式だけが探索されるようになる。

条件に合致する 530 件が、約4秒強で列挙されるようになった。全列挙してから抽出するよりも明らかに速いことが分かる。

...

(10*(9/((8-7)/((6*5)/(4/3)))))-(2/1)
(10*(9/((8-7)/(6*((5/4)*3)))))-(2*1)
(10*(9/((8-7)/(6*((5/4)*3)))))-(2/1)
(10*(9/((8-7)/(6*(5/(4/3))))))-(2*1)
(10*(9/((8-7)/(6*(5/(4/3))))))-(2/1)
Completed 530 results in 4.187013041s

中置記法、正規化

ここまでで探索して収集された resultsVec<Vec<ExpressionElement>> であり、これを中置記法に書き換える操作もまたstackを使って実現していくことになる。

impl Display for Op {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
        f.write_char(match self {
            Self::Add => '+',
            Self::Sub => '-',
            Self::Mul => '*',
            Self::Div => '/',
        })
    }
}

fn solve(searcher: &mut impl Rpn) {
    let mut results = Vec::new();
    searcher.traverse(&mut Vec::new(), (0, 0), &mut results);
    for result in results {
        let mut stack = Vec::new();
        for e in &result {
            match e {
                ExpressionElement::Operand(n) => stack.push((n.to_string(), None)),
                ExpressionElement::Operator(op) => {
                    if let (Some((mut s0, o0)), Some((mut s1, o1))) = (stack.pop(), stack.pop()) {
                        if o1.is_some() {
                            s1 = format!("({s1})");
                        }
                        if o0.is_some() {
                            s0 = format!("({s0})");
                        }
                        stack.push((format!("{s1}{op}{s0}"), Some(*op)));
                    }
                }
            }
        }
        println!("{}", stack[0].0);
    }
}

連結していくための String と、どの Operator で得られたものか(もしくは Operand か)、を保持しながら処理していく。簡単にすべて括弧をつけてしまうのなら上述のように書ける。

不要な括弧を除去して出力する場合は、左operandの場合と右operandの場合で処理が変わる(引き算・割り算は右operandに注意する必要がある)。matches! で判定してやってみる。

        ...

        ExpressionElement::Operator(op) => {
            if let (Some((mut s0, o0)), Some((mut s1, o1))) = (stack.pop(), stack.pop()) {
                if matches!((op, o1), (Op::Mul | Op::Div, Some(Op::Add | Op::Sub))) {
                    s1 = format!("({s1})");
                }
                if matches!(
                    (op, o0),
                    (Op::Sub | Op::Mul, Some(Op::Add | Op::Sub)) | (Op::Div, Some(_))
                ) {
                    s0 = format!("({s0})");
                }
                stack.push((format!("{s1}{op}{s0}"), Some(*op)));
            }
        }

これで、得られた結果を HashSet などで重複排除することで、解説記事と同様の 81 件を約4秒強で列挙することができた。

((10+(9/8+7)*6*5)*4-3)*2-1
(10*9/((8-7)/(6*5))/(4/3)-2)*1
(10*9/((8-7)/(6*5))/(4/3)-2)/1
(10*9/((8-7)/(6*5))/4*3-2)*1
(10*9/((8-7)/(6*5))/4*3-2)/1
(10*9/((8-7)/(6*5)*4)*3-2)*1
(10*9/((8-7)/(6*5)*4)*3-2)/1
(10*9/((8-7)/(6*5)*4/3)-2)*1
(10*9/((8-7)/(6*5)*4/3)-2)/1
(10*9/((8-7)/(6*5/(4/3)))-2)*1
(10*9/((8-7)/(6*5/(4/3)))-2)/1
(10*9/((8-7)/(6*5/4))*3-2)*1
(10*9/((8-7)/(6*5/4))*3-2)/1
(10*9/((8-7)/(6*5/4)/3)-2)*1
(10*9/((8-7)/(6*5/4)/3)-2)/1
(10*9/((8-7)/(6*5/4*3))-2)*1
(10*9/((8-7)/(6*5/4*3))-2)/1
(10*9/((8-7)/6)*5/(4/3)-2)*1
(10*9/((8-7)/6)*5/(4/3)-2)/1
(10*9/((8-7)/6)*5/4*3-2)*1
(10*9/((8-7)/6)*5/4*3-2)/1
(10*9/((8-7)/6/(5/(4/3)))-2)*1
(10*9/((8-7)/6/(5/(4/3)))-2)/1
(10*9/((8-7)/6/(5/4))*3-2)*1
(10*9/((8-7)/6/(5/4))*3-2)/1
(10*9/((8-7)/6/(5/4)/3)-2)*1
(10*9/((8-7)/6/(5/4)/3)-2)/1
(10*9/((8-7)/6/(5/4*3))-2)*1
(10*9/((8-7)/6/(5/4*3))-2)/1
(10*9/((8-7)/6/5)/(4/3)-2)*1
(10*9/((8-7)/6/5)/(4/3)-2)/1
(10*9/((8-7)/6/5)/4*3-2)*1
(10*9/((8-7)/6/5)/4*3-2)/1
(10*9/((8-7)/6/5*4)*3-2)*1
(10*9/((8-7)/6/5*4)*3-2)/1
(10*9/((8-7)/6/5*4/3)-2)*1
(10*9/((8-7)/6/5*4/3)-2)/1
(10*9/(8-7)*6*5/(4/3)-2)*1
(10*9/(8-7)*6*5/(4/3)-2)/1
(10*9/(8-7)*6*5/4*3-2)*1
(10*9/(8-7)*6*5/4*3-2)/1
10*9/((8-7)/(6*5))/(4/3)-2*1
10*9/((8-7)/(6*5))/(4/3)-2/1
10*9/((8-7)/(6*5))/4*3-2*1
10*9/((8-7)/(6*5))/4*3-2/1
10*9/((8-7)/(6*5)*4)*3-2*1
10*9/((8-7)/(6*5)*4)*3-2/1
10*9/((8-7)/(6*5)*4/3)-2*1
10*9/((8-7)/(6*5)*4/3)-2/1
10*9/((8-7)/(6*5/(4/3)))-2*1
10*9/((8-7)/(6*5/(4/3)))-2/1
10*9/((8-7)/(6*5/4))*3-2*1
10*9/((8-7)/(6*5/4))*3-2/1
10*9/((8-7)/(6*5/4)/3)-2*1
10*9/((8-7)/(6*5/4)/3)-2/1
10*9/((8-7)/(6*5/4*3))-2*1
10*9/((8-7)/(6*5/4*3))-2/1
10*9/((8-7)/6)*5/(4/3)-2*1
10*9/((8-7)/6)*5/(4/3)-2/1
10*9/((8-7)/6)*5/4*3-2*1
10*9/((8-7)/6)*5/4*3-2/1
10*9/((8-7)/6/(5/(4/3)))-2*1
10*9/((8-7)/6/(5/(4/3)))-2/1
10*9/((8-7)/6/(5/4))*3-2*1
10*9/((8-7)/6/(5/4))*3-2/1
10*9/((8-7)/6/(5/4)/3)-2*1
10*9/((8-7)/6/(5/4)/3)-2/1
10*9/((8-7)/6/(5/4*3))-2*1
10*9/((8-7)/6/(5/4*3))-2/1
10*9/((8-7)/6/5)/(4/3)-2*1
10*9/((8-7)/6/5)/(4/3)-2/1
10*9/((8-7)/6/5)/4*3-2*1
10*9/((8-7)/6/5)/4*3-2/1
10*9/((8-7)/6/5*4)*3-2*1
10*9/((8-7)/6/5*4)*3-2/1
10*9/((8-7)/6/5*4/3)-2*1
10*9/((8-7)/6/5*4/3)-2/1
10*9/(8-7)*6*5/(4/3)-2*1
10*9/(8-7)*6*5/(4/3)-2/1
10*9/(8-7)*6*5/4*3-2*1
10*9/(8-7)*6*5/4*3-2/1
Completed 81 results in 4.229359166s

CLIアプリケーション化

ここまで幾つかの実装ができたので、折角なので引数でパラメータを変えながら全列挙できる CLI アプリケーションを作成してみた。

$ ./puzzle2023 --help
Solvers for https://twitter.com/tkihira/status/1609313732034965506

Usage: puzzle2023 [OPTIONS]

Options:
  -t, --target <TARGET>      [default: 2023]
  -m, --max-num <MAX_NUM>    [default: 10]
  -r, --rev <REV>            [default: true] [possible values: true, false]
  -s, --searcher <SEARCHER>  [default: fast] [possible values: default, fast, div8]
  -n, --normalize
  -v, --verbose
  -h, --help                 Print help information
  -V, --version              Print version information

全パターン列挙

$ ./puzzle2023 -v

...

10-(9-(((8*(7*6))+(5-4))*(3*(2*1))))
10-(9-(((8*(7*6))+(5-4))*(3*(2/1))))
10-(9-(((8*(7*6))+(5-4))*(3+(2+1))))
10-(9-(8*(((7*(6+(5/4)))*(3+2))-1)))
10-(9-(8*((7*((6+(5/4))*(3+2)))-1)))
Completed 3370 results in 7.338064208s

9と8の間の割り算に限定

$ ./puzzle2023 -v --searcher div8

...

(10*(9/((8-7)/(6*((5/4)*3)))))-(2/1)
(10*(9/((8-7)/(6*(5/(4/3))))))-(2*1)
(10*(9/((8-7)/(6*(5/(4/3))))))-(2/1)
Completed 530 results in 4.219031208s

正規化

$ ./puzzle2023 -v --searcher div8 --normalize

...

10*9/(8-7)*6*5/(4/3)-2/1
10*9/(8-7)*6*5/4*3-2*1
10*9/(8-7)*6*5/4*3-2/1
Completed 81 results in 4.239637041s

計算結果の確認

$ ./puzzle2023 --searcher div8 --normalize | python3 -c 'for s in open(0): print(f"{eval(s):.1f}")' | sort | uniq -c
  81 2023.0

使うのを 1 から 8 の昇順にして 2024 を作るようにしてみる

$ ./puzzle2023 -m8 -rfalse -t2024 -nv
(((1+2)*3*4+5)*6+7)*8
(1*2*(3+4*5*6)+7)*8
(1+(2+3*4)*(5+6+7))*8
(1+(2+3-(4-5))*6*7)*8
(1+(2+3-4+5)*6*7)*8
(1+2*(3+4)*(5+6+7))*8
(1+2*(3+4+5+6)*7)*8
(1+2*3+4)*(5*6-7)*8
(1+2/(3/((4+5)*6))*7)*8
(1+2/(3/((4+5)*6)/7))*8
(1+2/(3/((4+5)*6*7)))*8
(1+2/(3/(4+5))*6*7)*8
(1+2/(3/(4+5)/(6*7)))*8
(1+2/(3/(4+5)/6)*7)*8
(1+2/(3/(4+5)/6/7))*8
(1+2/3*(4+5)*6*7)*8
(1-(2-3*4))*(5*6-7)*8
(1-2*(3*4-5*6)*7)*8
(1-2*3*(4-5)*6*7)*8
(1-2*3/((4-5)/(6*7)))*8
(1-2*3/((4-5)/6)*7)*8
(1-2*3/((4-5)/6/7))*8
(1-2*3/(4-5)*6*7)*8
(1-2+3*4)*(5*6-7)*8
1*(2*(3+4*5*6)+7)*8
Completed 25 results in 131.605458ms

「8の前を割り算に限定」にしているのも、コマンドラインオプションで違う値で指定できるようにしたかったけど、複雑になりすぎる感じがしたのでそれは断念…

改善?

tkihiraさんの最適化ではstackを push / pop するのではなく固定長配列を用意して見るべきindexだけをズラしていく、といった工夫をしていた。真似して実装してみたが本実装においてはそこはあまり効果なさそうだった。

あとはこうして再帰的に探索した結果を一度 results という Vec に格納してしまっているのが実は非効率なのでは、と思ったが… Iterator として使いたかったが そうするためには探索の途中の状態を保持しておいて next() が呼ばれるたびに続きを探索していく、という処理にする必要があり、そのあたりの実装がうまくできなかった。何か良い方法あるのだろうか…

Repository

github.com

Rubyでバイナリデータに対するrindex検索の挙動でハマったので調べたことメモ

自分の手元の環境でこんなことが起きた。

$ ruby -v
ruby 3.1.2p20 (2022-04-12 revision 4491bb740a) [arm64-darwin21]
$ irb
irb(main):001:0> "\x01\x80\x00\x00".index("\x01")
=> 0
irb(main):002:0> "\x01\x80\x00\x00".rindex("\x01")
=> 1

\x010 番目にしかないのだから、 .index でも .rindex でも 0 が返ってくるはずではないの??

先に結論

バイナリデータを扱うときには必ずEncodingを ASCII-8BIT に指定しておくこと!

きっかけ

roo というgem (記事時点で 2.9.0)を使って、Excelファイルを開こうとした。

require 'roo'

p Roo::Excelx.new("hoge.xlsx")

これは問題ないが、 #initialize の引数は file_or_strem ということでファイルを open したものを渡しても良いはず、と

require 'roo'

File.open("hoge.xlsx") do |f|
  p Roo::Excelx.new(f)
end

ということをすると以下のような謎のエラーを吐いて落ちる。

/Users/sugyan/.rbenv/versions/3.1.2/lib/ruby/gems/3.1.0/gems/rubyzip-2.3.2/lib/zip/entry.rb:365:in `check_c_dir_entry_static_header_length': undefined method `bytesize' for nil:NilClass (NoMethodError)

      return if buf.bytesize == ::Zip::CDIR_ENTRY_STATIC_HEADER_LENGTH
                   ^^^^^^^^^

で困っていたので他の人に聞いてみたところ「自分のところではそんな問題は起きていない」と言われ。 試しにDockerでRuby環境作って実行してみると、確かに問題なく動く…何故? と調べはじめた。

String#rindex の謎挙動

自分の環境と問題なく動く環境とで比較しながら見てみると、roo が使っている rubyzip (記事時点で 2.3.2) の中でファイル内容を読み取った結果の値が違っていた。

module Zip
  class CentralDirectory
    include Enumerable

    END_OF_CDS             = 0x06054b50

    ...

    def get_e_o_c_d(buf) #:nodoc:
      sig_index = buf.rindex([END_OF_CDS].pack('V'))
      ...

buf の中から 0x06054b50 をpackしたもの つまり "PK\x05\x06" のデータを末尾から探すために .rindex() を呼んでいるのだが、手元の環境と 問題なく動く環境ではその値が 1 ズレている。。

何故?と思って色々試しているうちに冒頭の例のようなものに辿り着いた。

もう少し深く追う

少なくとも \x80 などを含まないASCIIだけのものであればおかしな挙動にはならない。

irb(main):001:0> "\x01\x7F\x00\x00".rindex("\x01")
=> 0

また、これが増えるとズレもどんどん大きくなる。

irb(main):001:0> "\x01\x80".rindex("\x01")
=> 0
irb(main):002:0> "\x01\x80\x80".rindex("\x01")
=> 1
irb(main):003:0> "\x01\x80\x80\x80".rindex("\x01")
=> 2
irb(main):004:0> "\x01\x80\x80\x80\x80".rindex("\x01")
=> 3

逆から走査する際に異常な文字が検出された際に読み飛ばし、その結果の走査数を文字列長から引いたものを返したことにより値がおかしくなる、というのが予想される。

ソースを読んでみる。 rstring 関連はこのへんのようだ。

#ifdef HAVE_MEMRCHR によって実装が分かれている。 memrchr というのがglibcに入っているもので、自分のmacOS環境では使えなくて Docker内のLinux環境などでは使える、これによって環境によって挙動が変わったりする、のかもしれない。

で、使えない環境の方での実装は。 対象が見つかるまで while loop の中で rb_enc_prev_char で前の文字を走査している、という感じのようだ。encodingの影響を受けそう…?

    while (s) {
        if (memcmp(s, t, slen) == 0) {
            return pos;
        }
        if (pos == 0) break;
        pos--;
        s = rb_enc_prev_char(sbeg, s, e, enc);
    }

Encodingと実行環境

Encodingが関係しているようだったので色々調べてみた。

Rubyではバイナリデータ(バイト列)も String で扱う。

バイナリの取扱い

Ruby の String は、文字の列を扱うためだけでなく、バイトの列を扱うためにも使われます。しかし、Ruby M17N には直接にバイナリを表すエンコーディングは存在しません。このため、バイナリを String で扱う際には、ASCII 互換オクテット列を意味する ASCII-8BIT を用います。これにより、ASCII 互換であるこの String は 7bit クリーンな文字列と比較・結合が可能となります。

https://docs.ruby-lang.org/ja/latest/doc/spec=2fm17n.html

ということで、バイナリの場合は本来は ASCII-8BIT のEncodingを持つ文字列として検索しなければならないのに、手元の環境では UTF-8 のままで検索しようとしていた、というところに「使い方の問題」があったようだ。

irb(main):001:0> "\x01\x80\x00\x00".encoding
=> #<Encoding:UTF-8>

Encodingが ASCII-8BIT になっていれば、rindex でも正しい値を得ることができそうだ。 全然知らなかったが String#bASCII-8BIT の複製を得ることもできるらしい。

https://docs.ruby-lang.org/ja/latest/method/String/i/b.html

irb(main):001:0> "\x01\x80\x00\x00".rindex("\x01")
=> 1
irb(main):002:0> "\x01\x80\x00\x00".force_encoding(Encoding::ASCII_8BIT).rindex("\x01")
=> 0
irb(main):003:0> "\x01\x80\x00\x00".b.rindex("\x01")
=> 0

なるほど〜。

ではこの Encoding は何で決まるのか?というのも上記リンクに書いてある。

リテラルエンコーディング

文字列リテラル正規表現リテラルそしてシンボルリテラルから生成されるオブジェクトのエンコーディングスクリプトエンコーディングになります。

またスクリプトエンコーディングが US-ASCII である場合、7bit クリーンではないバックスラッシュ記法で表記されたリテラルエンコーディングは ASCII-8BIT になります。

さらに Unicode エスケープ (\uXXXX) を含む場合、リテラルエンコーディングUTF-8 になります。

複雑。。。

とにかく通常はスクリプトエンコーディングがまず大事のようだ。

スクリプトエンコーディング

スクリプトエンコーディングとは Ruby スクリプトを書くのに使われているエンコーディングです。スクリプトエンコーディングは マジックコメントを用いて指定します。スクリプトエンコーディングには ASCII 互換エンコーディングを用いることができます。 ASCII 非互換のエンコーディングや、ダミーエンコーディングは用いることができません。

現在のスクリプトエンコーディング__ENCODING__ により取得することができます。

さらに、magic commentによってこのスクリプトエンコーディングを決定でき、それが無い場合の挙動も書かれている。

マジックコメントが指定されなかった場合、コマンド引数 -K, RUBYOPT およびファイルの shebang からスクリプトエンコーディングは以下のように決定されます。左が優先です。

magic comment(最優先) > -K > RUBYOPTの-K > shebang

上のどれもが指定されていない場合、通常のスクリプトなら UTF-8、-e や stdin から実行されたものなら locale がスクリプトエンコーディングになります。 -K オプションが複数指定されていた場合は、後のものが優先されます。

通常のスクリプト-e かどうか、でも変わったりするのね…。そして特に指定ない場合は最終的にはlocaleが使われる。 ので、 LC_CTYPE などでも挙動が変わり得るようだ。

$ ruby -e 's="\x01\x80\x00\x00"; p __ENCODING__, s.encoding, s.rindex("\x01")'
#<Encoding:UTF-8>
#<Encoding:UTF-8>
1
$ ruby -Kn -e 's="\x01\x80\x00\x00"; p __ENCODING__, s.encoding, s.rindex("\x01")'
#<Encoding:ASCII-8BIT>
#<Encoding:ASCII-8BIT>
0
$ RUBYOPT="-Kn" ruby -e 's="\x01\x80\x00\x00"; p __ENCODING__, s.encoding, s.rindex("\x01")'
#<Encoding:ASCII-8BIT>
#<Encoding:ASCII-8BIT>
0
$ LC_CTYPE=C ruby -e 's="\x01\x80\x00\x00"; p __ENCODING__, s.encoding, s.rindex("\x01")'
#<Encoding:US-ASCII>
#<Encoding:ASCII-8BIT>
0

という具合に、何もしないと UTF-8 文字列として扱ってしまうが、オプションや環境変数によって リテラルの Encoding を変えることもできる。

こうやって正しく ASCII-8BIT のEncodingを持つ文字列として検索すれば、 String#rindex の値がズレるということはなさそうだ。

つまり再現条件は

手元の環境で String#rindex の値がズレたのは2つの要因があって

  1. memrchr が使えない環境でビルドしたRubyで実行していて
  2. ASCII-8BIT でないEncodingの文字列に対して rindex をかけていた

ということになる。2つ揃っていないと起きないものと思われる。

Rooの問題

前述の、Rooでエラーが起きる件について考える。

そもそもEncodingが不明で渡ってくるものに対して安易に rindex を使うものではない、ということで rubyzip 側に落ち度がありそうではある。が 現在の master branch では リリースされている 2.3.2 とは大きく変わっていて、もう同じことは起きないのかもしれない。詳しくは追っていないので分からないが、今回のと関連するissueがあった。

既にCloseされているが、今回調べた Zip::CentralDirectory は直接使用するものではなく Zip::File を使ってください、ということのようだ。

つまりこの場合、 Roo 側の使い方が悪い、ということになりそう。 で、Rooの方も調べてみると ちゃんとそれに対応しようとしていると思われる修正があった。

これが取り込まれれば手元の環境でも問題なく動くようになるかな…?

もしくは現時点でも、使う側が渡す引数のEncodingを明示的に指定してやることで一応回避できそうではある。

require 'roo'

File.open("hoge.xlsx", "r:ASCII-8BIT") do |f|
  p Roo::Excelx.new(f)
end

Rubyのバグではないの?

しかし Encodingを正しく指定していなかったために起きているとしても、 String#indexString#rindex も検索した対象の開始indexを返すものであるはずなのだから、それがズレるのはおかしいのではないのか…?という気はする。

さらに memrchr が使えるか否かのビルド環境によってだけで結果が変わる(ちゃんと確かめてないけど おそらくそう…)、というのも気持ち悪い。

かなり昔から現在の実装になっているようだし 今からでは挙動を変えづらそうではある。 仕様といってしまえばそれまでだけど…。 このあたりはRuby開発陣の方々の見解も聞いてみたいところではあります。

3.2

ちなみに Ruby 3.2 からは String#byteindexString#byterindex が追加されるそうです。今回のようなバイナリデータに対する検索にピッタリ使えそうですね。

spherical linear interpolation(slerp)によるlatent spaceでのnoise補間

memo.sugyan.com

の記事を書いてから、先行事例の調査が足りていなかったなと反省。 Latent Seed の Gaussian noise 間での morphing はあんまりやっている人いないんじゃないかな、と書いたけど、検索してみると普通に居た。

Stable Diffusion が公開されるよりも前の話だった…。

そしてこの morphing を作るためのコードも gist で公開されている

読んでみると、2 つの noise の間の値を得る方法として slerp という関数が使われている。

        for i, t in enumerate(np.linspace(0, 1, num_steps)):
            init = slerp(float(t), init1, init2)

https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355#file-stablediffusionwalk-py-L179-L180

この定義は以下のようになっている。

def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """ helper function to spherically interpolate two arrays v1 v2 """

    if not isinstance(v0, np.ndarray):
        inputs_are_torch = True
        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()

    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)

    return v2

https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355#file-stablediffusionwalk-py-L101-L125

不勉強で知らなかったが、これは "spherical linear interpolation"、日本語では「球面線形補間」と呼ばれるのかな、という手法で、3DCGの世界などではクォータニオン(Quaternion)の補間としてよく使われていたりする、ようだ。

en.wikipedia.org

書かれている通り、元々はクォータニオンの補間の文脈で紹介されていたが、次元数に関係なく適用することができるということらしい。 空間上での原点から2つの点へのベクトルを考え、その2つがなす角  \Omega をドット積とノルムを使って求めることができる。 始点から終点まで、 0 から  \Omega へと角度を線形に変化させながら2点を結ぶ円弧上を移動させていく、という感じだろうか。

 \displaystyle Slerp(p_0,p_1;t) = \frac{\sin{[(1-t)\Omega]}}{\sin{\Omega}}p_0 + \frac{\sin{[t\Omega]}}{\sin{\Omega}}p_1

画像生成モデルでの latent space における補間としては 2016年くらいには提案されていたようだ。 そういえば GAN で遊んでいたときにちょっとそういう話題を聞いたことがあったような気もする…。

arxiv.org

しかしこれによって Gaussian noise として分散を保ったまま変化させられる、ことになるのか…? 数学的なことはちょっとよく分からない。 まぁ球面上を角度だけ変えて動いていると考えればノルムが変化しないから大丈夫なのかな??くらいの感覚でしかない…。

実際にこれを使って補間していったときにどのような変化になるか、前回の記事で書いていた sqrt を使うものと上述の slerp を使うもので ランダムな2点を繋ぐ軌跡を見てみた。2次元、3次元 空間上でそれぞれプロットすると以下のような違いがあった。

間隔が異なるだけでなく、明らかに軌道も異なるところを通るようになるようだ。

実装

引用したコードでも良いが、自分でもPyTorch前提で Slerp を書いてみる。 NumPy への変換はしなくても計算できそうだった。

def myslerp(
    t: float, v0: torch.Tensor, v1: torch.Tensor, DOT_THRESHOLD: float = 0.9995
) -> torch.Tensor:
    u0 = v0 / v0.norm()
    u1 = v1 / v1.norm()
    dot = (u0 * u1).sum()
    if dot.abs() > DOT_THRESHOLD:
        return (1 - t) * v0 + t * v1
    omega = dot.acos()
    return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()

前述の数式の定義通りに実装するだけではあるが、やはり特別な場合に注意する必要はある。2点の原点からの方向がまったく同じか、もしくは正反対の方向のとき、理論上では単位ベクトルのドット積は 1-1 になる。 が、float値の多次元配列で計算してみると 多少の誤差が生じることがある。

for _ in range(10):
    v0 = np.random.normal([1, 4, 64, 64])
    v1 = 1.0 * v0
    print((v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))).sum())
0.9999999999999997
1.0
1.0
0.9999999999999998
0.9999999999999999
1.0
1.0
0.9999999999999998
1.0
1.0000000000000002
for _ in range(10):
    v0 = np.random.normal([1, 4, 64, 64])
    v1 = -1.0 * v0
    print((v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))).sum())
-1.0000000000000004
-1.0
-0.9999999999999998
-1.0
-1.0
-1.0
-0.9999999999999999
-1.0
-1.0000000000000002
-1.0

slerp では このドット積の値から arccos() で角度を求める必要があるが、例えば np.arccos() は 引数が -1.01.0 の間に収まっていない場合に nan を返すことになってしまう。 また、後でこの角度の sin() で割る処理があるので、ドット積が 1.0 で返せていたとしてもその後の計算で今度は inf が出てきたりする。 ので、こういうケースでは球面ではなく普通に線形で繋いでしまった方が良いだろうね、ということで DOT_THRESHOLD というものが使われて処理を分岐させているようだ。


※追記

指摘いただいたが、まったく同じ向きのときはともかく 正反対の方向の場合は線形補間が正しいとは限らない。 例えば2次元において (1, 0)(-1, 0) を補間する場合は原点を通らず (0, 1) もしくは (0, -1) を通る円軌道になった方が良かったり 3次元の球でいうと北極と南極を繋ぐのは赤道上を通る球面軌道であって欲しい、など。 これは2点のどちらかを適度にブレさせれば算出できる場合もあるかもしれないし、上記のように明らかに通るべき点があればまず両点からそこにまず補間してやるなど、色々な考え方がありそう。これもまた用途や場合によると思われる。 今回の Stable Diffusion で使っている乱数においては正反対向きはまず有り得ないものと思ってよさそうだ。


この小数点誤差がどれくらいになるのかは、要素数だったり 計算の順番だったり numpy で計算するか torch で計算するか などによっても変わってくるようだ。 どれくらいの値になったら線形と見做すか、の閾値(上述では 0.9995)は場合によると思われる。 少なくとも Stable Diffusion で使う [4, 64, 64] の shape で torch.randn() で生成される乱数同士の場合は、通常はドット積は 0.1 にすら届くことがまずないので、極端な話ではそれくらい低くても問題なさそうではある。 morphing の2点間で同一の noise を使う可能性がない、という前提であれば分岐なくしても問題ないくらい。

比較

で、実際に slerp を使うように変更すると morphing はどう変化するのか。 以前に載せた anime girl での morphing の slerp 版を作ってみた。

並べて比較

ちょっと変化のタイミングが変わったかな…? という程度で 大きな差は感じられない。

考察

結局のところ、Stable Diffusion においては Gaussian noise である限り何らかの画像を生成できるようになっているし、指定の2点の間をどう遷移しようと違いはないのかもしれない。 ただ slerp は線形に角度変化していくという観点でも、その遷移の移動量としてはより安定した変化が期待できるかな、という気はする。 むしろ前回の記事の手法が適当に思い付きでやった割にはそこそこ近いものが出来ていたのすごかったのでは…? というくらい。

また、今回は Gaussian noise 間でということを考えていたが slerp の手法自体は別にどこでも使えるものだとは思うので知っておいて損はないはず。また StyleGAN とかで遊ぶことがあったらそこでも使ってみても良さそう。

あと、prompt の embedding 結果に対する morphing でも slerp を使うことはできるはず、だが どうなんだろう。実験してみていないけど、A から B を遷移するのに まったく無関係な C を通過することになってしまったり 余計な変化が増えてしまうだけのような気もするので、こちらは線形に最短距離で遷移して刻み幅だけ調整するくらいが良いのではないかな、と思っている。実際にはやはりどちらでも大して変わらないかもしれないし prompt 次第かもしれない。試行錯誤するときの選択肢として持っておくと良さそうではある。

Stable Diffusionでmorphing

ということで少し触って遊んでみたのでメモ。

Stable Diffusion をザックリ理解

先月公開された Stable Diffusion。

stability.ai

高精度で美しい画像を出力できる高性能なモデルながら、Google Colab などでも手軽に動かせるし、 Apple silicon でもそれなりに動かせる、というのが魅力だ。

中身については 以下の記事の "How does Stable Diffusion work?" 以降のところが分かりやすい。

huggingface.co

図をそのまま引用させていただくと

という仕組みになっていて、受け取る入力は "User Prompt" と "Latent Seed" のみ。 前者が「どのような画像を生成するか」を決めて、後者で「どんなバリエーションでその画像を生成するか」を決めるような感じ。

User Prompt は [77, 768] の空間にエンコードされて、これを使って [4, 64, 64] の Gaussian noise を scheduler によって繰り返し denoise していくことで目標の画像のための "latent image representations" を生成していく。最後はこれを VAE(Variational Auto-Encoder) で拡大していくことで最終的な [512, 512, 3] の画像を得る。 この途中の scheduler による sampling がアルゴリズムによって結果の質が変わってきたりするし、回数が少なすぎると denoise が足りなくて汚い出力になったりする、というわけだ。

ともかく、重要なのはこの 2 つの入力だけで出力が決まるということ、そして prompt の入力は [77, 768] に embedding されたものが使われる、ということ。 prompt の文字列を工夫していくのも良いが、そこから embedding されて渡すものを直接指定してしまっても良いわけだ。 また、 Latent Seed の noise の方も少しずつ変えていくことで少しだけ違う感じの出力を得たりすることができそうだ。

自分は以前に StyeleGAN で latent space を変化させて生成画像の morphing などをやっていたので、それと同じようなことをやってみることにした。

Prompt 間での interpolation, morphing

まずは 2 つの異なる prompt から生成される画像の間を補間して繋いでみる。

(当然ながら、京都と東京の町並みを補間するからといって愛知や静岡の風景が生成されたりはしない。)

scripts/txt2img.py の中にある、prompt から画像を生成する部分のメインはここにある。

github.com

    c = model.get_learned_conditioning(prompts)
    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
    samples_ddim, _ = sampler.sample(
        S=opt.ddim_steps,
        conditioning=c,
        batch_size=opt.n_samples,
        shape=shape,
        verbose=False,
        unconditional_guidance_scale=opt.scale,
        unconditional_conditioning=uc,
        eta=opt.ddim_eta,
        x_T=start_code
    )
    x_samples_ddim = model.decode_first_stage(samples_ddim)

model.get_learned_conditioning() で与えられた prompt 文字列から [N, 77, 768]Tensor に embedding されたものが得られる。 これを sampler (ここではデフォルトで使われる DDIMSampler を使用している) に与えて ddim_steps 回数の sampling を実行して [N, 4, 64, 64] の denoise された結果が得られる。これを model.decode_first_stage() に与えることで最終的な画像に使われる値が得られるようだ。

sampler.sample() には色々なパラメータがあるが、とにかく重要なのは conditioning (c) と x_T (start_code) だけ。これを変化させることで生成画像をコントロールしていく。

start_code の方はここでは固定した値を使うようにすることで、「何を描くか」だけを徐々に変化させていく様子を作れる。一度だけ乱数を生成してそれを繰り返し使うようにすると良い。

で、 c の方は 2 つの異なる prompt からそれぞれ embedding された値を取り出して、線形に変化させていく。

指定した cstart_code から生成画像だけを得るような関数を書いておくとやりやすい。 model のロード方法などについては割愛。

from contextlib import nullcontext

import numpy as np
import torch
from PIL import Image

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ddpm import LatentDiffusion


def get_device() -> torch.device:
    ...


def load_model() -> LatentDiffusion:
    ...


model = load_model(...)


def generate(
    c: torch.Tensor, start_code: torch.Tensor, ddim_steps: int = 50
) -> Image:
    batch_size = 1
    device = get_device()
    precision_scope = torch.autocast if device.type == "cuda" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                uc = model.get_learned_conditioning(batch_size * [""])
                shape = [4, 64, 64]
                samples_ddim, _ = DDIMSampler(model).sample(
                    S=ddim_steps,
                    conditioning=c,
                    batch_size=batch_size,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=7.5,
                    unconditional_conditioning=uc,
                    eta=0.0,
                    x_T=start_code,
                )

                x_samples_ddim = model.decode_first_stage(samples_ddim)
                x_samples_ddim = torch.clamp(
                    (x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0
                )
                image = (
                    255.0 * x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()[0]
                ).astype(np.uint8)
                return Image.fromarray(image)

ちなみに、 sampler.sample() された結果の方を線形に繋いで変化させていくという手法もあるのだけど、これはもはやどんな画像が生成されるかほぼ決定された後の値なので、 morphing しても単なる画像合成のような感じにしかならなくて面白くはない。

指定した cstart_code で画像を生成する準備ができたら、あとはその入力を作っていくだけ。

def morph_prompts(prompts: Tuple[str, str], steps: int) -> None:
    start_code = torch.randn([1, 4, 64, 64], device=get_device())
    c0 = model.get_learned_conditioning(prompts[0])
    c1 = model.get_learned_conditioning(prompts[1])
    for i in range(steps + 1):
        x = i / steps
        c = c0 * (1.0 - x) + c1 * x
        img = generate(c, start_code)
        img.save(f"morphing_{i:03d}.png")

これらを繋げてアニメーションさせれば、2 つの異なる prompt 間の morphing が出来上がる。

ただ、似ているものならまだあまり違和感ないが あまりに異なる 2 つを morphing させようとすると、急激に変化してしまって面白くない。

embedding された空間がどんなものかは未知だが、ともかく A と B の 2 点間には必ず「denoise された結果 A になるもの」と「denoise された結果 B になるもの」が分断される地点がどこかに存在してしまう。 それは A B の中心かもしれないし、少しズレたところかもしれないが、そのあたりで急激な変化が起こり得る。 ので、中点に近い位置は出来るだけ細かい step で刻んだ方がよりシームレスな morphing になりやすいように感じた。 ので、単純な線形に繋ぐのではなく双曲線関数で刻み幅を微妙に変えながら作ってみることにした。

    a = np.arccosh(5.0)
    for i in range(steps + 1):
        t = i / steps
        x = sinh(a * (t * 2.0 - 1.0)) / sinh(a) / 2.0 + 0.5
        c = c0 * (1.0 - x) + c1 * x
        ...

それでもやっぱり急激な変化は捉えられないことが多々あるけれども…。

Seed 間での interpolation, morphing

今度は、同一の prompt で異なる Latent Seed を使用した 2つの画像間での morphing。

prompt の方は固定して、 torch.randn() で生成していた Gaussian noise の方を徐々に変えていく。生成する画像の「お題」は一緒だが、違うバリエーションのものになっていく、という morphing。

prompt のときと同じように変化させていけば良いだけ、と思ったが そうはいかない。実際やってみると中間点あたりはボヤけた画像になってしまうようだ。 最初 何故だろう…?と思ったが どうやらこの noise は "Gaussian noise" であることが重要で、標準正規分布として  {\mu} = 0, {\sigma}^2 = 1 になっていなければならない、ということらしい。 単純に v0 * (1.0 - x) + v1 * x のように単純な線形結合で変化させていくと、中心に近づくにつれてその標準偏差は小さくなってしまう。

それを防ぐために、足し合わせる前にそれぞれの倍率の sqrt をとるようにすると、合成された noise は標準偏差を保持したまま遷移することができそうだ。

すると今度は 0 付近と 1 付近で急激な変化が起こりやすそうなので、 prompt morphing のときのように刻み幅を調節する。

def morph_noises(prompt: str, steps: int) -> None:
    c = model.get_learned_conditioning([prompt])
    n0 = torch.randn([1, 4, 64, 64], device=get_device())
    n1 = torch.randn([1, 4, 64, 64], device=get_device())
    for i in range(steps + 1):
        t = i / steps
        x = 2.0 * t**2 if t < 0.5 else 1.0 - 2.0 * (1.0 - t) ** 2
        start_code = n0 * math.sqrt(1.0 - x) + n1 * math.sqrt(x)
        img = generate(c, start_code)
        img.save(f"morphing_{i:03d}.png")

これで、同じお題(prompt)に対して複数のバリエーションで描かれたものを連続的に変化させていくことができる。 好みの画像を出力する seed を幾つかピックアップして繋いでみたりするとより好みのものが見つかるかもしれないし、意外とブレンドされたものは好みではないものになるかもしれない。


※追記

memo.sugyan.com


まとめ

以上の2つができれば、その応用として prompt と noise を同時に変化させていったり、交互に変化させていったり、数回変化させた後にまた元の画像に戻ってきたり、といったものも作っていける。

雑に試行錯誤しながら実行できるように Google Colab でscriptを書いていたけど、ちょっと整理して公開する予定(需要あるかどうか分からないけど)。 prompt 間の morphing はやっている人結構いるけど、noise 間のものはまだあんまり見かけないような気はする?

※追記: GitHub - sugyan/stable-diffusion-morphing でとりあえず公開しておきました。多分動くはず?