半年ほど前から、BlueskyのAT ProtocolのRust版ライブラリを作っている。 memo.sugyan.com
その中で最近実装した機能の話。
API Agent
本家の atproto (TypeScript実装)に AtpAgent
というものがある。
この AtpAgent
は、(Blueskyだけに限らない) AT Protocol のための汎用的なエージェントとして提供されている。その機能の一つとしてtokenの管理機能がある。
AT Protocolの認証
少なくとも 2023/11 時点では HTTP Bearer auth でJWTを送信することで認証を行う方式となっている。
com.atproto.server.createSession
でログイン成功すると accessJwt
と refreshJwt
などが含まれる認証情報が返ってくるので、その accessJwt
をBearer tokenに使って各エンポイントにアクセスする。
また、 com.atproto.server.refreshSession
というエンドポイントがあり、ここを refreshJwt
をtokenに使って叩くことでtokenを更新することができる。
tokenの管理と自動更新機構
で、 @atproto/api
で提供されている AtpAgent
はそれらのtokenの管理を自動で行う機能を持っている。
export class AtpAgent { ... session?: AtpSessionData /** * Internal fetch handler which adds access-token management */ private async _fetch(...): Promise<AtpAgentFetchHandlerResponse> { ... // wait for any active session-refreshes to finish await this._refreshSessionPromise // send the request let res = await AtpAgent.fetch(...) // handle session-refreshes as needed if (isErrorResponse(res, ['ExpiredToken']) && this.session?.refreshJwt) { // attempt refresh await this._refreshSession() // resend the request with the new access token res = await AtpAgent.fetch(...) } return res } }
まず現時点でもっているsession情報を元にリクエスト処理を試み、そのレスポンスが expired
のエラーだったときのみそれを検出してtokenをrefreshして同じ内容のリクエストを再送する。最初のリクエストが成功していた場合はそのままそのレスポンスを返す。
という動き。
ATriumでの実装
で、これと同様の仕組みを持つ AtpAgent
をATriumでも実装しようと考えた。
XrpcClient
trait
ATriumでは、AT ProtocolのXRPCリクエストを送るための XrpcClient
というtraitを定義している。
#[async_trait] pub trait HttpClient { async fn send_http( &self, request: Request<Vec<u8>>, ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>>; } pub type XrpcResult<O, E> = Result<OutputDataOrBytes<O>, self::Error<E>>; #[async_trait] pub trait XrpcClient: HttpClient { fn base_uri(&self) -> String; #[allow(unused_variables)] async fn auth(&self, is_refresh: bool) -> Option<String> { None } async fn send_xrpc<P, I, O, E>(&self, request: &XrpcRequest<P, I>) -> XrpcResult<O, E> where P: Serialize + Send + Sync, I: Serialize + Send + Sync, O: DeserializeOwned + Send + Sync, E: DeserializeOwned + Send + Sync, { ... } }
XRPC は規約に沿ったHTTPリクエスト/レスポンスでしかないので、内部としてはHTTPの処理をすることになる。
が、RustではHTTPリクエスト/レスポンスを処理するための標準ライブラリのようなものはなく、特に非同期の場合は reqwest
や Isahc
、Surf
のような3rd partyのライブラリが多く使われている、と思う。
ATriumではこれらのライブラリをバックエンドとして開発者が選択できるよう、HTTP部分を抽象化する HttpClient
というtraitを定義し、それを継承した XrpcClient
を実装したインスタンスを内部に持つ AtpServiceClient
が各XRPCを処理する形にしている。
XrpcClient::send_xrpc()
はデフォルト実装を持っており、HttpClient::send_http()
さえ実装されていれば、あとはそれを使ってリクエストに使う入力のserializeや返ってきたレスポンスJSONのdeserializeなどの処理を行うようになっている。
session管理するwrapper
認証については XrpcClient
のメソッドとして async fn auth(&self, is_refresh: bool) -> Option<String>
を定義しているだけなので、ここでどのようなtokenを返すかはtrait実装者に任される。
インメモリで管理する場合、内部で Arc<RwLock<Option<Session>>>
のように持っておいて管理することで、マルチスレッドで共有されていても安全に扱えるようになる。
(参考: https://github.com/usagi/rust-memory-container-cs)
なので、内部で XrpcClient
を実装したものを持つWrapperを作り、主な処理はそれに移譲して auth()
だけを実装することで、保持しているsession情報からtokenを返す XrpcClient
実装を作ることができる。
use std::sync::Arc; use tokio::sync::RwLock; struct Wrapper<T> where T: XrpcClient + Send + Sync, { inner: T, session: Arc<RwLock<Option<Session>>>, } #[async_trait] impl<T> HttpClient for Wrapper<T> where T: XrpcClient + Send + Sync, { async fn send_http( &self, request: Request<Vec<u8>>, ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> { self.inner.send_http(request).await } } #[async_trait] impl<T> XrpcClient for Wrapper<T> where T: XrpcClient + Send + Sync, { fn base_uri(&self) -> String { self.inner.base_uri() } async fn auth(&self, is_refresh: bool) -> Option<String> { self.session.read().unwrap().as_ref().map(|session| { if is_refresh { session.refresh_jwt.clone() } else { session.access_jwt.clone() } }) } }
ここでは非同期の RwLock
として tokio::sync
を使っている。
これで、AtpAgent
も同じsessionを共有し、例えばログイン成功時に取得したsession情報を書き込む、といったことをすれば、その後のリクエストでその値が使われるようになる。
struct AtpAgent<T> where T: XrpcClient + Send + Sync, { api: Service<Wrapper<T>>, session: Arc<RwLock<Option<Session>>>, } impl<T> AtpAgent<T> where T: XrpcClient + Send + Sync, { fn new(xrpc: T) -> Self { let session = Arc::new(RwLock::new(None)); let api = Service::new(Wrapper { inner: xrpc, session: Arc::clone(&session) }); Self { api, session } } async fn login(&self, ...) { // login process let session: Session = ...; self.session.write().await.replace(session); } }
tokenの自動更新 (失敗例)
で、これを使ってtokenの自動更新を行うためには、XrpcClient::send_xrpc()
をオーバーライドする形で実装すれば良さそう。
impl Wrapper<T> where T: XrpcClient + Send + Sync, { async fn refresh_session() { // refresh process let session: Session = ...; self.session.write().await.replace(session); } fn is_expired<O, E>(result: &XrpcResult<O, E>) -> bool where O: DeserializeOwned + Send + Sync, E: DeserializeOwned + Send + Sync, { if let Err(Error::XrpcResponse(response)) = &result { if let Some(XrpcErrorKind::Undefined(body)) = &response.error { if let Some("ExpiredToken") = &body.error.as_deref() { return true; } } } false } } impl<T> XrpcClient for Wrapper<T> where T: XrpcClient + Send + Sync, { fn base_uri(&self) -> String { ... } async fn auth(&self, is_refresh: bool) -> Option<String> { ... } async fn send_xrpc<P, I, O, E>(&self, request: &XrpcRequest<P, I>) -> XrpcResult<O, E> where P: Serialize + Send + Sync, I: Serialize + Send + Sync, O: DeserializeOwned + Send + Sync, E: DeserializeOwned + Send + Sync, { let result = self.inner.send_xrpc(request).await; // handle session-refreshes as needed if Self::is_expired(&result) { self.refresh_session().await; self.inner.send_xrpc(request).await } else { result } } }
self.inner
に send_xrpc()
を移譲し、その結果を判定して Expired
のエラーであった場合のみ self.refresh_session()
を呼び出してtokenを更新し、再度同じ send_xrpc()
を呼び出す、という形。
self.refresh_session()
が成功して内部のsessionが書き変わっていれば、再度同じリクエストを送ったときにはtokenが更新されているので成功する。
…と思ったが、実際にはこれは思った通りには動かない。
Rustのtrait実装はインスタンスのメソッドをオーバーライドしているわけではなく、あくまでtraitのメソッドを実装しているだけなので、 self.inner
に移譲したメソッドは self
とは別のインスタンスとして扱われる。
つまり self.inner.send_xrpc()
のデフォルト実装中で呼ばれる self.auth()
はあくまで self.inner
に実装されている auth()
であり、Wrapper
に実装されている auth()
は呼ばれない。なので、いくら Wrapper
の内部でsessionを更新しても self.inner
には関係がない、ということになる。
2重のwrapperで解決
これをどうやって解決するかしばらく悩んだが、結局wrapperをもう一つ作って2重にsessionを共有することで想定した動きをするようになった。
主な処理を移譲された inner
が同じsession情報を持って使ってくれていれば問題ないので、元のwrapperを RefreshWrapper
として、同じように XrpcClient
を実装する SessionWrapper
を作り、 auth()
で self.session
を参照する機能だけをそちらに持たせるようにする。
struct RefreshWrapper<T> where T: XrpcClient + Send + Sync, { inner: T, session: Arc<RwLock<Option<Session>>>, } #[async_trait] impl<T> HttpClient for RefreshWrapper<T> where T: XrpcClient + Send + Sync, { ... // (use inner) } #[async_trait] impl<T> XrpcClient for RefreshWrapper<T> where T: XrpcClient + Send + Sync, { async fn send_xrpc<P, I, O, E>(&self, request: &XrpcRequest<P, I>) -> XrpcResult<O, E> where P: Serialize + Send + Sync, I: Serialize + Send + Sync, O: DeserializeOwned + Send + Sync, E: DeserializeOwned + Send + Sync, { let result = self.inner.send_xrpc(request).await; // handle session-refreshes as needed if Self::is_expired(&result) { self.refresh_session().await; self.inner.send_xrpc(request).await } else { result } } } struct SessionWrapper<T> where T: XrpcClient + Send + Sync, { inner: T, session: Arc<RwLock<Option<Session>>>, } #[async_trait] impl<T> HttpClient for SessionWrapper<T> where T: XrpcClient + Send + Sync, { ... // (use inner) } #[async_trait] impl<T> XrpcClient for SessionWrapper<T> where T: XrpcClient + Send + Sync, { ... async fn auth(&self, is_refresh: bool) -> Option<String> { self.session.read().await.as_ref().map(|session| { if is_refresh { session.refresh_jwt.clone() } else { session.access_jwt.clone() } }) } }
そして、 AtpAgent
では XrpcClient
を実装したものとして RefreshWrapper<SessionWrapper<T>>
を使うようにする。両者間で session
を共有するために Arc<RwLock<Option<Session>>>
を渡している。
struct AtpAgent<T> where T: XrpcClient + Send + Sync, { api: Service<RefreshWrapper<SessionWrapper<T>>>, session: Arc<RwLock<Option<Session>>>, } impl<T> AtpAgent<T> where T: XrpcClient + Send + Sync, { fn new(xrpc: T) -> Self { let session = Arc::new(RwLock::new(None)); let api = Service::new(Arc::new(RefreshWrapper { inner: SessionWrapper { inner: xrpc, session: Arc::clone(&session), }, session: Arc::clone(&session), })); Self { api, session } } }
send_xrpc()
内でのexpire検出と更新・再送の処理は RefreshWrapper
で行われるが、実際の送信は内部の SessionWrapper
に移譲されることになる。
SessionWrapper
では共有されている self.session
を auth()
内で参照する動作をするので、 RefreshWrapper
(や、それを使う AtpAgent
本体) で更新されたsession情報がそのまま使われることになる。
これでtokenの自動更新が実現された。
並行処理での同時更新の問題
もう一つ、TypeScript実装の AtpAgent
が持っている機能として、複数の refreshSession()
が同時に来ても1回しか実行されないように制御する、というものがある。
/** * Internal helper to refresh sessions * - Wraps the actual implementation in a promise-guard to ensure only * one refresh is attempted at a time. */ private async _refreshSession() { if (this._refreshSessionPromise) { return this._refreshSessionPromise } this._refreshSessionPromise = this._refreshSessionInner() try { await this._refreshSessionPromise } finally { this._refreshSessionPromise = undefined } }
内部状態として Promise<void> | undefined
を持っており、 _refreshSession()
が呼ばれた時点でそれが undefined
であれば Promise<void>
をセットした上で実際にその処理を await
し、既に Promise<void>
がセットされている場合はそれを返す、という動き。
もしAgentが並行して複数のAPIをほぼ同時に叩いたときに、tokenがexpiredだった場合はその結果がほぼ同時に返ってくることになる。その時点ではtokenがrefreshされていないので、それぞれがほぼ同時に自動refresh処理を行うことになるが、実際には1回だけrefreshされていれば良いので、そのように制御するための仕組み、と考えられる。
そもそも同じrefresh tokenを使って複数回refreshのリクエストすると、既に使用されたrefresh tokenが無効になり2回目以降のrefreshはエラーになる可能性もあり(ここはサーバ実装次第と思われる、現時点でのblueskyでは特に何も起こらず新しいtokenが発行されるだけ、のようだ)、非同期で並行処理され得る環境ではこういった制御は必要になる。
…しかしRustでは同じように実装しようとしても上手くいかなかった。 async
なinnerを呼んだ返り値を保持するとなると Pin<Box<impl Future<Output = Result<...>> + 'static>>
のような形になり、しかしこれを Mutex
とかで保持しようとしてもこれは Send
ではないので Mutex
に入れられない、など… 何度もコンパイルエラーに悩まされて挫折した。
Rustの非同期に詳しい人なら上手く実装できるんだろうか…
Notify
による制御実装
ChatGPTと長々と議論し、結局 tokio::sync
にある Mutex
と Notify
を使うことで同様の動作を実現させた。
use tokio::sync::{Mutex, Notify}; struct RefreshWrapper<T> where T: XrpcClient + Send + Sync, { inner: T, session: Arc<RwLock<Option<Session>>>, is_refreshing: Mutex<bool>, notify: Notify, } impl<T> RefreshWrapper<T> where T: XrpcClient + Send + Sync, { ... async fn refresh_session(&self) { { let mut is_refreshing = self.is_refreshing.lock().await; if *is_refreshing { drop(is_refreshing); return self.notify.notified().await; } *is_refreshing = true; } self.refresh_session_inner().await; *self.is_refreshing.lock().await = false; self.notify.notify_waiters(); } }
まず self.is_refreshing
で、refresh実行中のものがあるか否かを保持する。これは Mutex
で管理されるので、 最初にロックを取得して true
に変更したものが完了してまた false
に戻すまでは他の処理からの読み取り結果は必ず true
になる。
そして、その true
に変更できたものだけが、後続の実際の更新処理である self.refresh_session_inner()
を実行する。
self.is_refreshing
が true
になっている間は、 self.refresh_session_inner()
が終わったあとに呼ばれる self.notify.notify_waiters()
によって完了を知らされるまで待機するだけ、という形になる。
こういった他のスレッドからの通知を待つための仕組みとして tokio::sync::Notify
があるようだ。この場合は完了したことを知りたいだけなので Notify
だが、処理結果も知りたい場合は oneshot
などを使うと良いのかもしれない。
ライブラリとしては特定の非同期ランタイムに依存するようにはしたくないので tokio
などを使うのは避けたかったが、標準ライブラリや futures
などには同様の仕組みがないようだったので仕方なく tokio
を使うことにした。実際のところ sync
featureで使うこれらのものは特にランタイム依存は無いようで、 async-std
など別の非同期ランタイムで実行しても問題なく動作するようではあった。
まとめ
ライブラリの依存は増えてしまったが、どうにか AtpAgent
として実装したい機能は実現できた。Rustむずかしい…。
他にもっと良い方法をご存知の方がいればpull-requestなどいただけると嬉しいです。