diff options
Diffstat (limited to 'src/client.rs')
-rw-r--r-- | src/client.rs | 56 |
1 files changed, 32 insertions, 24 deletions
diff --git a/src/client.rs b/src/client.rs index b105311..53022ee 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,13 +1,11 @@ use log::{ info, error, warn, trace }; use rust_tdlib::Tdlib; use std::{ + marker::PhantomData, task::{ Waker, Context, Poll }, future::Future, pin::Pin, - sync::{ - Arc, - Mutex - }, + sync::{ Arc, Mutex }, thread, }; use crossbeam::channel::{ @@ -19,29 +17,36 @@ use uuid::Uuid; use std::collections::HashMap; use serde_json::Value as JsonValue; use crate::update::Handler; -use pert_types::types::Update; - +use pert_types::methods::Method; #[derive(Debug)] pub struct RequestData { - req: JsonValue, resp: Option<JsonValue>, waker: Option<Waker>, } +type RequestDataRef = Arc<Mutex<RequestData>>; + +#[derive(Debug)] +struct RequestDataToStream { + data: RequestDataRef, + req: JsonValue, +} + #[derive(Debug, Clone)] -pub struct RequestFuture { - data: Arc<Mutex<RequestData>> +pub struct RequestFuture<M: Method> { + _response_type_holder: PhantomData<M>, + pub data: RequestDataRef, } -impl Future for RequestFuture { - type Output = JsonValue; +impl<M: Method> Future for RequestFuture<M> { + type Output = Result<M::Response, serde_json::error::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { let mut data = self.data.lock().unwrap(); if let Some(resp) = &data.resp { - Poll::Ready(resp.clone()) + Poll::Ready(serde_json::from_value(resp.clone())) } else { data.waker = Some(cx.waker().clone()); Poll::Pending @@ -51,7 +56,7 @@ impl Future for RequestFuture { #[derive(Debug)] enum JoinStreams { - NewRequest(RequestFuture), + NewRequest(RequestDataToStream), NewResponse(String), } @@ -89,18 +94,23 @@ impl Client { } } - pub fn send(&self, req: &JsonValue) -> RequestFuture { + pub fn send<R: Method>(&self, req: R) -> Result<RequestFuture<R>, serde_json::error::Error> { let request = RequestData { - req: req.to_owned(), resp: None, waker: None }; let fut = RequestFuture { + _response_type_holder: PhantomData, data: Arc::new(Mutex::new(request)) }; - self.sender.send(JoinStreams::NewRequest(fut.clone())).unwrap(); - fut + self.sender.send(JoinStreams::NewRequest( + RequestDataToStream { + data: fut.data.clone(), + req: serde_json::to_value(req.tag())?, + } + )).unwrap(); + Ok(fut) } fn listen_tg(tx: Sender<JoinStreams>, api: Arc<Tdlib>, timeout: f64) { @@ -117,7 +127,7 @@ impl Client { #[derive(Debug)] struct OneshotResponder { api: Arc<Tdlib>, - wakers_map: HashMap<Uuid, RequestFuture>, + wakers_map: HashMap<Uuid, RequestDataRef>, rx: Receiver<JoinStreams>, } @@ -133,7 +143,7 @@ impl OneshotResponder { fn run<H: Handler>(&mut self, updater: H, client: Client, rt: tokio::runtime::Handle) { loop { match self.rx.recv() { - Ok(JoinStreams::NewRequest(fut)) => { + Ok(JoinStreams::NewRequest(req_data)) => { let id = loop { let id = Uuid::new_v4(); if self.wakers_map.contains_key(&id) { @@ -142,14 +152,13 @@ impl OneshotResponder { break id; } }; - let data = fut.data.clone(); - let request: &mut JsonValue = &mut data.lock().unwrap().req; + let mut request = req_data.req; if !request["@extra"].is_null() { warn!("overwriting @extra in request"); } request["@extra"] = id.to_hyphenated().to_string().into(); self.api.send(request.to_string().as_ref()); - self.wakers_map.insert(id, fut); + self.wakers_map.insert(id, req_data.data); trace!("new req:\n{:#}", request); }, Ok(JoinStreams::NewResponse(resp)) => { @@ -186,8 +195,7 @@ impl OneshotResponder { if let Ok(id) = Uuid::parse_str(id_str) { let fut_extracted = self.wakers_map .remove(&id) - .unwrap() - .data; + .unwrap(); let mut fut_data = fut_extracted.lock().unwrap(); fut_data.resp = Some(resp); |