1use alloc::collections::BTreeMap;
8use alloc::sync::{Arc, Weak};
9use core::error::Error;
10use core::fmt;
11use core::future::Future;
12use core::ops::ControlFlow;
13use core::pin::Pin;
14use core::task::{ready, Context, Poll};
15use std::io;
16use std::sync::Mutex;
17
18use futures::Stream;
19use tokio::sync::oneshot;
20
21use xmpp_parsers::{iq::Iq, stanza_error::StanzaError};
22
23use crate::{
24 event::make_id,
25 jid::Jid,
26 minidom::Element,
27 stanzastream::{StanzaState, StanzaToken},
28};
29
30#[derive(Debug)]
32pub enum IqRequest {
33 Get(Element),
35
36 Set(Element),
38}
39
40impl IqRequest {
41 fn into_iq(self, from: Option<Jid>, to: Option<Jid>, id: String) -> Iq {
42 match self {
43 Self::Get(payload) => Iq::Get {
44 from,
45 to,
46 id,
47 payload,
48 },
49 Self::Set(payload) => Iq::Set {
50 from,
51 to,
52 id,
53 payload,
54 },
55 }
56 }
57}
58
59#[derive(Debug)]
61pub enum IqResponse {
62 Result(Option<Element>),
64
65 Error(StanzaError),
67}
68
69impl IqResponse {
70 fn into_iq(self, from: Option<Jid>, to: Option<Jid>, id: String) -> Iq {
71 match self {
72 Self::Error(error) => Iq::Error {
73 from,
74 to,
75 id,
76 error,
77 payload: None,
78 },
79 Self::Result(payload) => Iq::Result {
80 from,
81 to,
82 id,
83 payload,
84 },
85 }
86 }
87}
88
/// Reasons why waiting for an IQ response can fail locally.
#[derive(Debug)]
pub enum IqFailure {
    /// The response can no longer arrive: the response channel was closed
    /// without a value, or the stanza was dropped before transmission.
    LostWorker,

    /// Sending the stanza failed with the contained I/O error.
    SendError(io::Error),
}

impl fmt::Display for IqFailure {
    /// Write a short human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::SendError(e) => write!(f, "send error: {e}"),
            Self::LostWorker => {
                f.write_str("disconnected from internal connection worker while sending IQ")
            }
        }
    }
}

impl Error for IqFailure {
    /// Expose the underlying I/O error for [`IqFailure::SendError`]; the
    /// other variant has no underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::LostWorker => None,
            Self::SendError(cause) => Some(cause),
        }
    }
}
121
122type IqKey = (Option<Jid>, String);
123type IqMap = BTreeMap<IqKey, IqResponseSink>;
124
125#[derive(Debug)]
126struct IqMapEntryHandle {
127 key: IqKey,
128 map: Weak<Mutex<IqMap>>,
129}
130
131impl Drop for IqMapEntryHandle {
132 fn drop(&mut self) {
133 let Some(map) = self.map.upgrade() else {
134 return;
135 };
136 let Some(mut map) = map.lock().ok() else {
137 return;
138 };
139 map.remove(&self.key);
140 }
141}
142
143pin_project_lite::pin_project! {
144 #[derive(Debug)]
159 pub struct IqResponseToken {
160 entry: Option<IqMapEntryHandle>,
161 #[pin]
162 stanza_token: Option<tokio_stream::wrappers::WatchStream<StanzaState>>,
163 #[pin]
164 inner: oneshot::Receiver<Result<IqResponse, IqFailure>>,
165 }
166}
167
168impl IqResponseToken {
169 pub(crate) fn set_stanza_token(&mut self, token: StanzaToken) {
177 assert!(self.stanza_token.is_none());
178 self.stanza_token = Some(token.into_stream());
179 }
180}
181
182impl Future for IqResponseToken {
183 type Output = Result<IqResponse, IqFailure>;
184
185 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
186 let mut this = self.project();
187 match this.inner.poll(cx) {
188 Poll::Ready(Ok(v)) => {
189 this.entry.take();
191 return Poll::Ready(v);
192 }
193 Poll::Ready(Err(_)) => {
194 log::warn!("IqResponseToken oneshot::Receiver returned receive error!");
195 this.entry.take();
197 return Poll::Ready(Err(IqFailure::LostWorker));
198 }
199 Poll::Pending => (),
200 };
201
202 loop {
203 match this.stanza_token.as_mut().as_pin_mut() {
204 Some(stream) => match ready!(stream.poll_next(cx)) {
206 Some(StanzaState::Queued) => (),
208
209 Some(StanzaState::Dropped) | None => {
210 log::warn!("StanzaToken associated with IqResponseToken signalled that the Stanza was dropped before transmission.");
211 this.entry.take();
213 return Poll::Ready(Err(IqFailure::LostWorker));
215 }
216
217 Some(StanzaState::Failed { error }) => {
218 this.entry.take();
220 return Poll::Ready(Err(IqFailure::SendError(error.into_io_error())));
222 }
223
224 Some(StanzaState::Sent { .. }) | Some(StanzaState::Acked { .. }) => {
225 *this.stanza_token = None;
230 return Poll::Pending;
231 }
232 },
233
234 None => return Poll::Pending,
237 }
238 }
239 }
240}
241
242#[derive(Debug)]
243struct IqResponseSink {
244 inner: oneshot::Sender<Result<IqResponse, IqFailure>>,
245}
246
247impl IqResponseSink {
248 fn complete(self, resp: IqResponse) {
249 let _: Result<_, _> = self.inner.send(Ok(resp));
250 }
251}
252
253#[derive(Debug)]
255pub struct IqResponseTracker {
256 map: Arc<Mutex<IqMap>>,
257}
258
259impl IqResponseTracker {
260 pub fn new() -> Self {
262 Self {
263 map: Arc::new(Mutex::new(IqMap::new())),
264 }
265 }
266
267 pub fn handle_iq(&self, iq: Iq) -> ControlFlow<(), Iq> {
272 let (from, to, id, payload) = match iq {
273 Iq::Error {
274 from,
275 to,
276 id,
277 error,
278 payload: _,
279 } => (from, to, id, IqResponse::Error(error)),
280 Iq::Result {
281 from,
282 to,
283 id,
284 payload,
285 } => (from, to, id, IqResponse::Result(payload)),
286 _ => return ControlFlow::Continue(iq),
287 };
288 let key = (from, id);
289 let mut map = self.map.lock().unwrap();
290 match map.remove(&key) {
291 None => {
292 log::trace!("not handling IQ response from {:?} with id {:?}: no active tracker for this tuple", key.0, key.1);
293 ControlFlow::Continue(payload.into_iq(key.0, to, key.1))
294 }
295 Some(sink) => {
296 sink.complete(payload);
297 ControlFlow::Break(())
298 }
299 }
300 }
301
302 pub fn allocate_iq_handle(
306 &self,
307 from: Option<Jid>,
308 to: Option<Jid>,
309 req: IqRequest,
310 ) -> (Iq, IqResponseToken) {
311 let key = (to, make_id());
312 let mut map = self.map.lock().unwrap();
313 let (tx, rx) = oneshot::channel();
314 let sink = IqResponseSink { inner: tx };
315 assert!(map.get(&key).is_none());
316 let token = IqResponseToken {
317 entry: Some(IqMapEntryHandle {
318 key: key.clone(),
319 map: Arc::downgrade(&self.map),
320 }),
321 stanza_token: None,
322 inner: rx,
323 };
324 map.insert(key.clone(), sink);
325 (req.into_iq(from, key.0, key.1), token)
326 }
327}