diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/client.rs | 56 | ||||
-rw-r--r-- | src/main.rs | 42 |
2 files changed, 69 insertions, 29 deletions
diff --git a/src/client.rs b/src/client.rs index b105311..53022ee 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,13 +1,11 @@ use log::{ info, error, warn, trace }; use rust_tdlib::Tdlib; use std::{ + marker::PhantomData, task::{ Waker, Context, Poll }, future::Future, pin::Pin, - sync::{ - Arc, - Mutex - }, + sync::{ Arc, Mutex }, thread, }; use crossbeam::channel::{ @@ -19,29 +17,36 @@ use uuid::Uuid; use std::collections::HashMap; use serde_json::Value as JsonValue; use crate::update::Handler; -use pert_types::types::Update; - +use pert_types::methods::Method; #[derive(Debug)] pub struct RequestData { - req: JsonValue, resp: Option<JsonValue>, waker: Option<Waker>, } +type RequestDataRef = Arc<Mutex<RequestData>>; + +#[derive(Debug)] +struct RequestDataToStream { + data: RequestDataRef, + req: JsonValue, +} + #[derive(Debug, Clone)] -pub struct RequestFuture { - data: Arc<Mutex<RequestData>> +pub struct RequestFuture<M: Method> { + _response_type_holder: PhantomData<M>, + pub data: RequestDataRef, } -impl Future for RequestFuture { - type Output = JsonValue; +impl<M: Method> Future for RequestFuture<M> { + type Output = Result<M::Response, serde_json::error::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { let mut data = self.data.lock().unwrap(); if let Some(resp) = &data.resp { - Poll::Ready(resp.clone()) + Poll::Ready(serde_json::from_value(resp.clone())) } else { data.waker = Some(cx.waker().clone()); Poll::Pending @@ -51,7 +56,7 @@ impl Future for RequestFuture { #[derive(Debug)] enum JoinStreams { - NewRequest(RequestFuture), + NewRequest(RequestDataToStream), NewResponse(String), } @@ -89,18 +94,23 @@ impl Client { } } - pub fn send(&self, req: &JsonValue) -> RequestFuture { + pub fn send<R: Method>(&self, req: R) -> Result<RequestFuture<R>, serde_json::error::Error> { let request = RequestData { - req: req.to_owned(), resp: None, waker: None }; let fut = RequestFuture { + _response_type_holder: PhantomData, data: Arc::new(Mutex::new(request)) }; - self.sender.send(JoinStreams::NewRequest(fut.clone())).unwrap(); - fut + self.sender.send(JoinStreams::NewRequest( + RequestDataToStream { + data: fut.data.clone(), + req: serde_json::to_value(req.tag())?, + } + )).unwrap(); + Ok(fut) } fn listen_tg(tx: Sender<JoinStreams>, api: Arc<Tdlib>, timeout: f64) { @@ -117,7 +127,7 @@ impl Client { #[derive(Debug)] struct OneshotResponder { api: Arc<Tdlib>, - wakers_map: HashMap<Uuid, RequestFuture>, + wakers_map: HashMap<Uuid, RequestDataRef>, rx: Receiver<JoinStreams>, } @@ -133,7 +143,7 @@ impl OneshotResponder { fn run<H: Handler>(&mut self, updater: H, client: Client, rt: tokio::runtime::Handle) { loop { match self.rx.recv() { - Ok(JoinStreams::NewRequest(fut)) => { + Ok(JoinStreams::NewRequest(req_data)) => { let id = loop { let id = Uuid::new_v4(); if self.wakers_map.contains_key(&id) { @@ -142,14 +152,13 @@ impl OneshotResponder { break id; } }; - let data = fut.data.clone(); - let request: &mut JsonValue = &mut data.lock().unwrap().req; + let mut request = req_data.req; if !request["@extra"].is_null() { warn!("overwriting @extra in request"); } request["@extra"] = id.to_hyphenated().to_string().into(); self.api.send(request.to_string().as_ref()); - self.wakers_map.insert(id, fut); + self.wakers_map.insert(id, req_data.data); trace!("new req:\n{:#}", request); }, Ok(JoinStreams::NewResponse(resp)) => { @@ -186,8 +195,7 @@ impl OneshotResponder { if let Ok(id) = Uuid::parse_str(id_str) { let fut_extracted = self.wakers_map .remove(&id) - .unwrap() - .data; + .unwrap(); let mut fut_data = fut_extracted.lock().unwrap(); fut_data.resp = Some(resp); diff --git a/src/main.rs b/src/main.rs index 69f5ecc..49e4397 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ use std::env; use tokio; -use log::{ info, error }; +use log::{ info }; mod client; //mod auth; @@ -9,10 +9,44 @@ mod update; struct UpdateHandler; +fn make_lib_params() -> pert_types::types::TdlibParameters { + let cache = env::current_dir().unwrap().join("cache"); + let make_path = |p: &str| cache.join(p).to_str().map(|p| p.to_owned()).unwrap(); + pert_types::types::TdlibParameters { + use_test_dc: true, + database_directory: make_path("database"), + files_directory: make_path("files"), + use_file_database: true, + use_chat_info_database: true, + use_message_database: true, + use_secret_chats: false, + api_id: env::var("API_ID").unwrap().parse().unwrap(), + api_hash: env::var("API_HASH").unwrap(), + system_language_code: "en".to_owned(), + device_model: "mbia v1".to_owned(), + system_version: "15".to_owned(), + application_version: "0.1".to_owned(), + enable_storage_optimizer: false, + ignore_file_names: true, + } +} + impl update::Handler for UpdateHandler { - fn handle(&self, _client: client::Client, req: pert_types::types::Update) -> futures::future::BoxFuture<'static, ()> { + fn handle(&self, client: client::Client, req: pert_types::types::Update) -> futures::future::BoxFuture<'static, ()> { Box::pin(async move { - info!("update: {:#?}", req); + match req { + pert_types::types::Update::UpdateAuthorizationState(state) => { + match state.authorization_state { + pert_types::types::AuthorizationState::AuthorizationStateWaitTdlibParameters(_) => { + client.send(pert_types::methods::SetTdlibParameters { + parameters: make_lib_params(), + }).unwrap().await.unwrap(); + } + _ => info!("auth state unknown: {:#?}", state) + } + } + _ => info!("unknown update: {:#?}", req) + } }) } } @@ -27,7 +61,5 @@ async fn main() { let tg = client::Client::new(tg_log, UpdateHandler{}); - - std::thread::sleep(std::time::Duration::new(200, 0)); } |