1use alloc::collections::BTreeMap;
8use alloc::sync::{Arc, Weak};
9use core::error::Error;
10use core::fmt;
11use core::future::Future;
12use core::ops::ControlFlow;
13use core::pin::Pin;
14use core::task::{ready, Context, Poll};
15use std::io;
16use std::sync::Mutex;
17use xmpp_parsers::jid::BareJid;
18
19use futures::Stream;
20use tokio::sync::oneshot;
21
22use xmpp_parsers::{iq::Iq, stanza_error::StanzaError};
23
24use crate::{
25 event::make_id,
26 jid::Jid,
27 minidom::Element,
28 stanzastream::{StanzaState, StanzaToken},
29};
30
31#[derive(Debug)]
33pub enum IqRequest {
34 Get(Element),
36
37 Set(Element),
39}
40
41impl IqRequest {
42 fn into_iq(self, from: Option<Jid>, to: Option<Jid>, id: String) -> Iq {
43 match self {
44 Self::Get(payload) => Iq::Get {
45 from,
46 to,
47 id,
48 payload,
49 },
50 Self::Set(payload) => Iq::Set {
51 from,
52 to,
53 id,
54 payload,
55 },
56 }
57 }
58}
59
60#[derive(Debug)]
62pub enum IqResponse {
63 Result(Option<Element>),
65
66 Error(StanzaError),
68}
69
70impl IqResponse {
71 fn into_iq(self, from: Option<Jid>, to: Option<Jid>, id: String) -> Iq {
72 match self {
73 Self::Error(error) => Iq::Error {
74 from,
75 to,
76 id,
77 error,
78 payload: None,
79 },
80 Self::Result(payload) => Iq::Result {
81 from,
82 to,
83 id,
84 payload,
85 },
86 }
87 }
88}
89
/// Reasons why an IQ request ended without a response.
#[derive(Debug)]
pub enum IqFailure {
    /// The channel to the internal connection worker was lost, or the
    /// request stanza was dropped before it could be transmitted.
    LostWorker,

    /// Transmitting the request stanza failed with the given I/O error.
    SendError(io::Error),
}

impl fmt::Display for IqFailure {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::SendError(cause) => write!(f, "send error: {cause}"),
            Self::LostWorker => {
                f.write_str("disconnected from internal connection worker while sending IQ")
            }
        }
    }
}

impl Error for IqFailure {
    /// Exposes the underlying I/O error for send failures; `LostWorker`
    /// has no further cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        if let Self::SendError(cause) = self {
            Some(cause)
        } else {
            None
        }
    }
}
122
123type IqKey = (Option<Jid>, String);
124type IqMap = BTreeMap<IqKey, IqResponseSink>;
125
126#[derive(Debug)]
127struct IqMapEntryHandle {
128 key: IqKey,
129 map: Weak<Mutex<IqMap>>,
130}
131
132impl Drop for IqMapEntryHandle {
133 fn drop(&mut self) {
134 let Some(map) = self.map.upgrade() else {
135 return;
136 };
137 let Some(mut map) = map.lock().ok() else {
138 return;
139 };
140 map.remove(&self.key);
141 }
142}
143
144pin_project_lite::pin_project! {
145 #[derive(Debug)]
160 pub struct IqResponseToken {
161 entry: Option<IqMapEntryHandle>,
162 #[pin]
163 stanza_token: Option<tokio_stream::wrappers::WatchStream<StanzaState>>,
164 #[pin]
165 inner: oneshot::Receiver<Result<IqResponse, IqFailure>>,
166 }
167}
168
169impl IqResponseToken {
170 pub(crate) fn set_stanza_token(&mut self, token: StanzaToken) {
178 assert!(self.stanza_token.is_none());
179 self.stanza_token = Some(token.into_stream());
180 }
181}
182
183impl Future for IqResponseToken {
184 type Output = Result<IqResponse, IqFailure>;
185
186 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
187 let mut this = self.project();
188 match this.inner.poll(cx) {
189 Poll::Ready(Ok(v)) => {
190 this.entry.take();
192 return Poll::Ready(v);
193 }
194 Poll::Ready(Err(_)) => {
195 log::warn!("IqResponseToken oneshot::Receiver returned receive error!");
196 this.entry.take();
198 return Poll::Ready(Err(IqFailure::LostWorker));
199 }
200 Poll::Pending => (),
201 };
202
203 loop {
204 match this.stanza_token.as_mut().as_pin_mut() {
205 Some(stream) => match ready!(stream.poll_next(cx)) {
207 Some(StanzaState::Queued) => (),
209
210 Some(StanzaState::Dropped) | None => {
211 log::warn!("StanzaToken associated with IqResponseToken signalled that the Stanza was dropped before transmission.");
212 this.entry.take();
214 return Poll::Ready(Err(IqFailure::LostWorker));
216 }
217
218 Some(StanzaState::Failed { error }) => {
219 this.entry.take();
221 return Poll::Ready(Err(IqFailure::SendError(error.into_io_error())));
223 }
224
225 Some(StanzaState::Sent { .. }) | Some(StanzaState::Acked { .. }) => {
226 *this.stanza_token = None;
231 return Poll::Pending;
232 }
233 },
234
235 None => return Poll::Pending,
238 }
239 }
240 }
241}
242
243#[derive(Debug)]
244struct IqResponseSink {
245 inner: oneshot::Sender<Result<IqResponse, IqFailure>>,
246}
247
248impl IqResponseSink {
249 fn complete(self, resp: IqResponse) {
250 let _: Result<_, _> = self.inner.send(Ok(resp));
251 }
252}
253
254#[derive(Clone, Debug)]
256pub struct IqResponseTracker {
257 map: Arc<Mutex<IqMap>>,
258 account_jid: Arc<Mutex<Option<BareJid>>>,
259}
260
261impl IqResponseTracker {
262 pub fn new() -> Self {
264 Self {
265 map: Arc::new(Mutex::new(IqMap::new())),
266 account_jid: Arc::new(Mutex::new(None)),
267 }
268 }
269
270 pub fn set_account_jid(&self, jid: BareJid) {
272 let mut guard = self.account_jid.lock().unwrap();
273 *guard = Some(jid);
274 }
275
276 pub fn handle_iq(&self, iq: Iq) -> ControlFlow<(), Iq> {
281 let (from, to, id, payload) = match iq {
282 Iq::Error {
283 from,
284 to,
285 id,
286 error,
287 payload: _,
288 } => (from, to, id, IqResponse::Error(error)),
289 Iq::Result {
290 from,
291 to,
292 id,
293 payload,
294 } => (from, to, id, IqResponse::Result(payload)),
295 _ => return ControlFlow::Continue(iq),
296 };
297 let key = (from, id);
298 let mut map = self.map.lock().unwrap();
299 match map.remove(&key) {
300 None => {
301 log::debug!("not handling IQ response from {:?} with id {:?}: no active tracker for this tuple", key.0, key.1);
302 log::trace!("active trackers: {map:?}");
303 ControlFlow::Continue(payload.into_iq(key.0, to, key.1))
304 }
305 Some(sink) => {
306 log::trace!("completing IQ {:?}", key.0);
307 sink.complete(payload);
308 ControlFlow::Break(())
309 }
310 }
311 }
312
313 pub fn allocate_iq_handle(
317 &self,
318 from: Option<Jid>,
319 mut to: Option<Jid>,
320 req: IqRequest,
321 ) -> (Iq, IqResponseToken) {
322 if to.is_none() {
323 let account_jid = self.account_jid.lock().unwrap();
326 to = account_jid.clone().map(Jid::from);
327 }
328
329 let key = (to, make_id());
330 let mut map = self.map.lock().unwrap();
331 let (tx, rx) = oneshot::channel();
332 let sink = IqResponseSink { inner: tx };
333 assert!(map.get(&key).is_none());
334 let token = IqResponseToken {
335 entry: Some(IqMapEntryHandle {
336 key: key.clone(),
337 map: Arc::downgrade(&self.map),
338 }),
339 stanza_token: None,
340 inner: rx,
341 };
342 map.insert(key.clone(), sink);
343 (req.into_iq(from, key.0, key.1), token)
344 }
345}