Go Concurrency Patterns: Pipelines and cancellation - The Go Blog を読んでいて、なかなか理解するまで苦しんだので復習がてら自分でもコードを書いて確かめてみた。
お題
複数の入力データそれぞれに対して重い処理を行い、結果として返ってくる値をまとめて取得する。
途中でエラーが発生したら直ちに処理を中止して終了する。
コード
いちばん簡単な例
エラーを考慮しない場合。
package main import ( "fmt" "log" "math/rand" "time" ) func init() { log.SetFlags(log.Lmicroseconds) rand.Seed(time.Now().UnixNano()) } func doSomething(id int) string { wait := rand.Intn(1000) time.Sleep(time.Millisecond * time.Duration(wait)) // something heavy return fmt.Sprintf("%02d-%03d", id, wait) } func getAllData() (results []string) { for i := 0; i < 100; i++ { value := doSomething(i) log.Println("got", value) results = append(results, value) } return results } func main() { data := getAllData() log.Println("Finished.", data) }
こんなかんじ。100ループで毎回数百ミリ秒かかる処理(本当はCPUぶん回すような処理だったり)をして、その結果をひとつずつ繋げていって結果が格納されたsliceを返す。
$ go run example.go 22:57:33.931461 got 00-955 22:57:34.706984 got 01-774 22:57:35.204441 got 02-497 ... 22:58:23.020169 got 97-116 22:58:23.528240 got 98-507 22:58:24.178353 got 99-649 22:58:24.178596 Finished. [00-955 01-774 02-497 ...
当然ながら順番に1個ずつ処理していくのでとても時間かかる。
エラー処理を加える
doSomething
の中で、もしくはその前に繰り返し処理の内部でエラーが起こりうる、とする。適当に100分の1くらいの確率で起こることにして それぞれの関数をerror
も返すよう変更
var errUnfortunate1 = errors.New("unfortunate error 1") var errUnfortunate2 = errors.New("unfortunate error 2") func doSomething(id int) (string, error) { wait := rand.Intn(1000) time.Sleep(time.Millisecond * time.Duration(wait)) // something heavy if rand.Intn(100) == 0 { return "", errUnfortunate1 } return fmt.Sprintf("%02d-%03d", id, wait), nil } func getAllData() (results []string, err error) { for i := 0; i < 100; i++ { if rand.Intn(100) == 0 { return nil, errUnfortunate2 } value, err := doSomething(i) if err != nil { return nil, err } log.Println("got", value) results = append(results, value) } return results, nil } func main() { data, err := getAllData() if err != nil { log.Println("Failed!", err) return } log.Println("Finished.", date) }
関数の返り値からエラーチェックして 何かあればすぐにgetAllData
を抜けてmain
内で出力して終了するようになっている。
並行化その1 channel化
処理を並行で行うための準備として、goroutineとchannelを使った形に変えていく。
まずは入力を送ってくれるchannelを作って返す関数を作り、そこからrange
で読み取るようにしてみる。エラー処理を無視すると
func getDataChannel() <-chan string { c := make(chan string) go func() { for i := 0; i < 100; i++ { value, _ := doSomething(i) log.Println("got", value) c <- value } close(c) }() return c } func getAllData() (results []string, err error) { c := getDataChannel() for value := range c { results = append(results, value) } return results, nil }
こんなかんじ。
並行化その2 重い処理を並行に
引き続きエラー処理を無視したままだけど、doSomething
部分をgoroutineに。
単純に即時関数で囲んで並行化するだけだと処理が終わる前にc
が閉じてしまったりmainまで終了してしまったりするので、sync.WaitGroup
を使って全部おわるまで待つ。
import ( ... "sync" ) func getDataChannel() <-chan string { c := make(chan string) go func() { var wg sync.WaitGroup for i := 0; i < 100; i++ { wg.Add(1) go func(id int) { value, _ := doSomething(id) log.Println("got", value) c <- value wg.Done() }(i) } wg.Wait() close(c) }() return c }
早く終わったものから順にどんどんデータが送られて、全部おわるまで待ってからc
がcloseされる。
エラーが何もなければこれで良いのだけど、、
並行化その3 エラー処理1
まずはdoSomething
で返ってくるerrUnfortunate1を捕捉。
これはgoroutine内で起こり得るので関数の返り値としては使いづらい。ので、返ってくるvalueとともにstructに含めてchannelに送るようにする
type result struct { value string err error } func getDataChannel() <-chan result { c := make(chan result) go func() { var wg sync.WaitGroup for i := 0; i < 100; i++ { wg.Add(1) go func(id int) { value, err := doSomething(id) log.Println("got", value, err) c <- result{value: value, err: err} wg.Done() }(i) } wg.Wait() close(c) }() return c } func getAllData() (results []string, err error) { c := getDataChannel() for r := range c { if r.err != nil { return nil, r.err } results = append(results, r.value) } return results, nil }
受け取る側のrange
ループ内でresult.err
をチェックして、エラーを検出したらそこで終了。
これだけではまだまだ問題あるのだけど とりあえずここではこれで捕捉できたことにする
並行化その4 エラー処理2
次に、繰り返し処理の内部で起こり得るerrUnfortunate2を捕捉する。
goroutineでは返り値をとれないので、ループする部分をfunc() error {}()
な即時関数で囲むことで取得する。得たerrorを送る手段としてerror用のchannelを用意し、resultを送るchannelと一緒に返して使ってもらうようにする
func getDataChannel() (<-chan result, <-chan error) { c := make(chan result) errc := make(chan error) go func() { var wg sync.WaitGroup err := func() error { for i := 0; i < 100; i++ { if rand.Intn(100) == 0 { return errUnfortunate2 } wg.Add(1) go func(id int) { value, err := doSomething(id) log.Println("got", value, err) c <- result{value: value, err: err} wg.Done() }(i) } return nil }() wg.Wait() close(c) errc <- err }() return c, errc } func getAllData() (results []string, err error) { c, errc := getDataChannel() for r := range c { results = append(results, r.value) if r.err != nil { return nil, r.err } } err = <-errc if err != nil { return } return results, nil }
errorが起きようと起きまいと即時関数が終了した後にc
はcloseされるのでrange c
ループが終了し、その後にerrc
から即時関数の返り値として得たerrorを取得してチェックすることができる。
これまた問題があるけど一応捕捉はできた。
並行化その5 中断されたことを知らせる
ここまでだと、errUnfortunate1が起きたときにはc
がcloseすることもなく走ってる処理が続くし、errUnfortunate2のときにも走ってるもの待ってからcloseすることになってしまったり、まだ正しく中断できているとは言えない。
並行化して走っている処理たちに中断されたことを知らせるために、もう一つchannelを用意してそれを使って判定するようにする。
func getDataChannel(done <-chan struct{}) (<-chan result, <-chan error) { c := make(chan result) errc := make(chan error) go func() { var wg sync.WaitGroup err := func(walkFunc func(int) error) (err error) { for i := 0; i < 100; i++ { time.Sleep(time.Millisecond * 50) if rand.Intn(100) == 0 { return errUnfortunate2 } err = walkFunc(i) if err != nil { return } } return nil }(func(id int) error { wg.Add(1) go func() { value, err := doSomething(id) log.Println("got", value, err) select { case c <- result{value: value, err: err}: log.Println("sent.") case <-done: log.Println("not sent.") } wg.Done() }() select { case <-done: return errors.New("canceled") default: return nil } }) wg.Wait() close(c) errc <- err }() return c, errc } func getAllData() (results []string, err error) { done := make(chan struct{}) defer close(done) c, errc := getDataChannel(done) for r := range c { results = append(results, r.value) if r.err != nil { return nil, r.err } } err = <-errc if err != nil { return } return results, nil }
getAllData
側で用意したdone
channelは、deferによって関数を抜けるときにcloseする。これをgetDataChannel
に渡しておいて、そちらではselect
を使って処理を分岐させることができる。doneが閉じていればそちら側が実行されるのでdoSomething
から値が返ってきてもc
には送信されないし、ループを実行するwalkFunc
は"canceled"なエラーを受け取りループを中断するようになる。
並行化その6 完成形?
中断したときにsync.WaitGroup
で全部終わるまでWaitするのはブロックする必要ないのでgoroutineにする(deferでも良いかも?)。でもrangeでc
が閉じるまで待っていては結局errc
からすぐには受け取れないのでこちらもselectを使う。
あと、errc
は送る前に受け取り側が終了してしまっていると書き込みがブロックされる可能性があるのでバッファリングしておく必要がある、のでmake
の第2引数で1以上を指定しておく。
func getDataChannel(done <-chan struct{}) (<-chan result, <-chan error) { c := make(chan result) errc := make(chan error, 1) go func() { var wg sync.WaitGroup err := func(walkFunc func(int) error) (err error) { for i := 0; i < 100; i++ { time.Sleep(time.Millisecond * 50) if rand.Intn(100) == 0 { return errUnfortunate2 } err = walkFunc(i) if err != nil { return } } return nil }(func(id int) error { wg.Add(1) go func() { log.Println("start", id) value, err := doSomething(id) log.Println("got", value, err) select { case c <- result{value: value, err: err}: case <-done: } wg.Done() }() select { case <-done: return errors.New("canceled") default: return nil } }) go func() { wg.Wait() close(c) }() errc <- err }() return c, errc } func getAllData() (results []string, err error) { done := make(chan struct{}) defer close(done) c, errc := getDataChannel(done) Loop: for { select { case r, ok := <-c: if !ok { break Loop } results = append(results, r.value) if r.err != nil { return nil, r.err } case err = <-errc: if err != nil { return } } } return results, nil }
これで、並行かつ エラー時には即座に処理が中断されて余計なデータ送受信などもなく後始末もできるようになった。
goroutineの起動数を制限
とはいえ上記の方法だと入力受け取るたびにどんどんgoroutineを起動することになりメモリ使用量などマズいことになり得る。
ので、並行に走らせる数を制限させる別のパターンを用意する。
まず、処理の結果を送るchannelを返していたgetDataChannel
を、"入力"を送るchannelを返すだけのものに変更する。
func getInputChannel(done <-chan struct{}) (<-chan int, <-chan error) { ids := make(chan int) errc := make(chan error, 1) go func() { defer close(ids) err := func(walkFunc func(int) error) (err error) { for i := 0; i < 100; i++ { time.Sleep(time.Millisecond * 50) if rand.Intn(100) == 0 { return errUnfortunate2 } err = walkFunc(i) if err != nil { return } } return nil }(func(id int) error { select { case <-done: return errors.New("canceled") case ids <- id: } return nil }) errc <- err }() return ids, errc }
こんなかんじ、doneが閉じてない限りは入力データとなるidを送りつづける。
で、その入力channelを受け取って出力に結果を流すworker的なものを別に作る。
func worker(ids <-chan int, c chan<- result, done <-chan struct{}) { for id := range ids { value, err := doSomething(id) log.Println("got", value, err) select { case c <- result{value: value, err: err}: case <-done: return } } }
単純に入力が流れてくる限りdoSomething
な処理をして、doneが閉じていない限りはc
にresultを送りつづける。役割がハッキリしている。
んで、あとはこれをgoroutineで起動させて受け取るだけ。ただし終了するのを待ってからc
をcloseしてやる必要はある。
func getAllData() (results []string, err error) { done := make(chan struct{}) defer close(done) ids, errc := getInputChannel(done) var wg sync.WaitGroup c := make(chan result) wg.Add(1) go func() { worker(ids, c, done) wg.Done() }() go func() { wg.Wait() close(c) }() Loop: for { select { case r, ok := <-c: if !ok { break Loop } results = append(results, r.value) if r.err != nil { return nil, r.err } case err = <-errc: if err != nil { return } } } return results, nil }
この形で呼び出されるworkerは任意の数のgoroutineで並行起動してもそれぞれが「入力を受け取り出力を送る」という役目をこなすだけなので上手く動作してくれる。
var wg sync.WaitGroup c := make(chan result) for i := 0; i < 10; i++ { wg.Add(1) go func() { worker(ids, c, done) wg.Done() }() } go func() { wg.Wait() close(c) }()
最終形
というわけで最終的に出来上がったのが以下。ちゃんと納得できるかたちで http://blog.golang.org/pipelines/bounded.go と同じような形にできたので大丈夫だと思う。
package main import ( "errors" "fmt" "log" "math/rand" "runtime" "sync" "time" ) func init() { log.SetFlags(log.Lmicroseconds) rand.Seed(time.Now().UnixNano()) } var errUnfortunate1 = errors.New("unfortunate error 1") var errUnfortunate2 = errors.New("unfortunate error 2") type result struct { value string err error } func doSomething(id int) (string, error) { wait := rand.Intn(1000) time.Sleep(time.Millisecond * time.Duration(wait)) // something heavy if rand.Intn(100) == 0 { return "", errUnfortunate1 } return fmt.Sprintf("%02d-%03d", id, wait), nil } func getInputChannel(done <-chan struct{}) (<-chan int, <-chan error) { ids := make(chan int) errc := make(chan error, 1) go func() { defer close(ids) err := func(walkFunc func(int) error) (err error) { for i := 0; i < 100; i++ { time.Sleep(time.Millisecond * 50) if rand.Intn(100) == 0 { return errUnfortunate2 } err = walkFunc(i) if err != nil { return } } return nil }(func(id int) error { select { case <-done: return errors.New("canceled") case ids <- id: } return nil }) errc <- err }() return ids, errc } func worker(ids <-chan int, c chan<- result, done <-chan struct{}) { for id := range ids { value, err := doSomething(id) log.Println("got", value, err) select { case c <- result{value: value, err: err}: log.Println("sent") case <-done: log.Println("not sent") return } } } func getAllData() (results []string, err error) { done := make(chan struct{}) defer close(done) ids, errc := getInputChannel(done) var wg sync.WaitGroup c := make(chan result) for i := 0; i < 10; i++ { wg.Add(1) go func() { worker(ids, c, done) wg.Done() }() } go func() { wg.Wait() close(c) }() Loop: for { select { case r, ok := <-c: if !ok { break Loop } results = append(results, r.value) if r.err != nil { return nil, r.err } case err = <-errc: if err != nil { return } } } return results, nil } func main() { defer func() { // 異常に大きな数のgoroutineが起動しっぱなしでないか確かめる time.Sleep(time.Millisecond * 2000) log.Println(runtime.NumGoroutine()) }() data, err := getAllData() if err != nil { log.Println("Failed!", err) return } log.Println("Finished.", data) }
まとめ
なかなか処理の流れが複雑なかんじがして「なんでこんな書き方するの」「ここがエラーになったらどうなるの」とか悩んだけど、書きながら読んでるうちにようやく「あー、だからこうするのか」「確かに、こうしようと思ったらこういう形になるよねー」って納得できた。
とはいえスラスラとこういうのが書ける気はまだしないけど…。