1use alloc::collections::BTreeMap;
8use alloc::sync::{Arc, Weak};
9use core::error::Error;
10use core::fmt;
11use core::future::Future;
12use core::ops::ControlFlow;
13use core::pin::Pin;
14use core::task::{ready, Context, Poll};
15use std::io;
16use std::sync::Mutex;
17
18use futures::Stream;
19use tokio::sync::oneshot;
20
21use xmpp_parsers::{
22 iq::{Iq, IqType},
23 stanza_error::StanzaError,
24};
25
26use crate::{
27 event::make_id,
28 jid::Jid,
29 minidom::Element,
30 stanzastream::{StanzaState, StanzaToken},
31};
32
33#[derive(Debug)]
35pub enum IqRequest {
36 Get(Element),
38
39 Set(Element),
41}
42
43impl From<IqRequest> for IqType {
44 fn from(other: IqRequest) -> IqType {
45 match other {
46 IqRequest::Get(v) => Self::Get(v),
47 IqRequest::Set(v) => Self::Set(v),
48 }
49 }
50}
51
52#[derive(Debug)]
54pub enum IqResponse {
55 Result(Option<Element>),
57
58 Error(StanzaError),
60}
61
62impl From<IqResponse> for IqType {
63 fn from(other: IqResponse) -> IqType {
64 match other {
65 IqResponse::Result(v) => Self::Result(v),
66 IqResponse::Error(v) => Self::Error(v),
67 }
68 }
69}
70
/// Reasons an IQ request can end without a response being received.
#[derive(Debug)]
pub enum IqFailure {
    /// The request could not be completed because the connection machinery
    /// went away: the channel that would deliver the response was closed, or
    /// the stanza carrying the request was dropped before transmission.
    LostWorker,

    /// Transmitting the request stanza failed with the given I/O error.
    SendError(io::Error),
}
83
84impl fmt::Display for IqFailure {
85 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
86 match self {
87 Self::LostWorker => {
88 f.write_str("disconnected from internal connection worker while sending IQ")
89 }
90 Self::SendError(e) => write!(f, "send error: {e}"),
91 }
92 }
93}
94
95impl Error for IqFailure {
96 fn source(&self) -> Option<&(dyn Error + 'static)> {
97 match self {
98 Self::SendError(ref e) => Some(e),
99 Self::LostWorker => None,
100 }
101 }
102}
103
104type IqKey = (Option<Jid>, String);
105type IqMap = BTreeMap<IqKey, IqResponseSink>;
106
107#[derive(Debug)]
108struct IqMapEntryHandle {
109 key: IqKey,
110 map: Weak<Mutex<IqMap>>,
111}
112
113impl Drop for IqMapEntryHandle {
114 fn drop(&mut self) {
115 let Some(map) = self.map.upgrade() else {
116 return;
117 };
118 let Some(mut map) = map.lock().ok() else {
119 return;
120 };
121 map.remove(&self.key);
122 }
123}
124
125pin_project_lite::pin_project! {
126 #[derive(Debug)]
141 pub struct IqResponseToken {
142 entry: Option<IqMapEntryHandle>,
143 #[pin]
144 stanza_token: Option<tokio_stream::wrappers::WatchStream<StanzaState>>,
145 #[pin]
146 inner: oneshot::Receiver<Result<IqResponse, IqFailure>>,
147 }
148}
149
150impl IqResponseToken {
151 pub(crate) fn set_stanza_token(&mut self, token: StanzaToken) {
159 assert!(self.stanza_token.is_none());
160 self.stanza_token = Some(token.into_stream());
161 }
162}
163
164impl Future for IqResponseToken {
165 type Output = Result<IqResponse, IqFailure>;
166
167 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
168 let mut this = self.project();
169 match this.inner.poll(cx) {
170 Poll::Ready(Ok(v)) => {
171 this.entry.take();
173 return Poll::Ready(v);
174 }
175 Poll::Ready(Err(_)) => {
176 log::warn!("IqResponseToken oneshot::Receiver returned receive error!");
177 this.entry.take();
179 return Poll::Ready(Err(IqFailure::LostWorker));
180 }
181 Poll::Pending => (),
182 };
183
184 loop {
185 match this.stanza_token.as_mut().as_pin_mut() {
186 Some(stream) => match ready!(stream.poll_next(cx)) {
188 Some(StanzaState::Queued) => (),
190
191 Some(StanzaState::Dropped) | None => {
192 log::warn!("StanzaToken associated with IqResponseToken signalled that the Stanza was dropped before transmission.");
193 this.entry.take();
195 return Poll::Ready(Err(IqFailure::LostWorker));
197 }
198
199 Some(StanzaState::Failed { error }) => {
200 this.entry.take();
202 return Poll::Ready(Err(IqFailure::SendError(error.into_io_error())));
204 }
205
206 Some(StanzaState::Sent { .. }) | Some(StanzaState::Acked { .. }) => {
207 *this.stanza_token = None;
212 return Poll::Pending;
213 }
214 },
215
216 None => return Poll::Pending,
219 }
220 }
221 }
222}
223
224#[derive(Debug)]
225struct IqResponseSink {
226 inner: oneshot::Sender<Result<IqResponse, IqFailure>>,
227}
228
229impl IqResponseSink {
230 fn complete(self, resp: IqResponse) {
231 let _: Result<_, _> = self.inner.send(Ok(resp));
232 }
233}
234
235#[derive(Debug)]
237pub struct IqResponseTracker {
238 map: Arc<Mutex<IqMap>>,
239}
240
241impl IqResponseTracker {
242 pub fn new() -> Self {
244 Self {
245 map: Arc::new(Mutex::new(IqMap::new())),
246 }
247 }
248
249 pub fn handle_iq(&self, iq: Iq) -> ControlFlow<(), Iq> {
254 let payload = match iq.payload {
255 IqType::Error(error) => IqResponse::Error(error),
256 IqType::Result(result) => IqResponse::Result(result),
257 _ => return ControlFlow::Continue(iq),
258 };
259 let key = (iq.from, iq.id);
260 let mut map = self.map.lock().unwrap();
261 match map.remove(&key) {
262 None => {
263 log::trace!("not handling IQ response from {:?} with id {:?}: no active tracker for this tuple", key.0, key.1);
264 ControlFlow::Continue(Iq {
265 from: key.0,
266 id: key.1,
267 to: iq.to,
268 payload: payload.into(),
269 })
270 }
271 Some(sink) => {
272 sink.complete(payload);
273 ControlFlow::Break(())
274 }
275 }
276 }
277
278 pub fn allocate_iq_handle(
282 &self,
283 from: Option<Jid>,
284 to: Option<Jid>,
285 req: IqRequest,
286 ) -> (Iq, IqResponseToken) {
287 let key = (to, make_id());
288 let mut map = self.map.lock().unwrap();
289 let (tx, rx) = oneshot::channel();
290 let sink = IqResponseSink { inner: tx };
291 assert!(map.get(&key).is_none());
292 let token = IqResponseToken {
293 entry: Some(IqMapEntryHandle {
294 key: key.clone(),
295 map: Arc::downgrade(&self.map),
296 }),
297 stanza_token: None,
298 inner: rx,
299 };
300 map.insert(key.clone(), sink);
301 (
302 Iq {
303 from,
304 to: key.0,
305 id: key.1,
306 payload: req.into(),
307 },
308 token,
309 )
310 }
311}