1use alloc::collections::BTreeMap;
8use alloc::sync::{Arc, Weak};
9use core::error::Error;
10use core::fmt;
11use core::future::Future;
12use core::ops::ControlFlow;
13use core::pin::Pin;
14use core::task::{ready, Context, Poll};
15use std::io;
16use std::sync::Mutex;
17
18use futures::Stream;
19use tokio::sync::oneshot;
20
21use xmpp_parsers::{
22 iq::{Iq, IqType},
23 stanza_error::StanzaError,
24};
25
26use crate::{
27 event::make_id,
28 jid::Jid,
29 minidom::Element,
30 stanzastream::{StanzaState, StanzaToken},
31};
32
33pub enum IqRequest {
35 Get(Element),
37
38 Set(Element),
40}
41
42impl From<IqRequest> for IqType {
43 fn from(other: IqRequest) -> IqType {
44 match other {
45 IqRequest::Get(v) => Self::Get(v),
46 IqRequest::Set(v) => Self::Set(v),
47 }
48 }
49}
50
51pub enum IqResponse {
53 Result(Option<Element>),
55
56 Error(StanzaError),
58}
59
60impl From<IqResponse> for IqType {
61 fn from(other: IqResponse) -> IqType {
62 match other {
63 IqResponse::Result(v) => Self::Result(v),
64 IqResponse::Error(v) => Self::Error(v),
65 }
66 }
67}
68
/// Reasons why waiting for an IQ response can fail without a response
/// having been received.
#[derive(Debug)]
pub enum IqFailure {
    /// The connection to the internal connection worker was lost while the
    /// IQ was being sent.
    LostWorker,

    /// Sending the IQ failed with the contained I/O error.
    SendError(io::Error),
}

impl fmt::Display for IqFailure {
    /// Writes a human-readable description of the failure.
    ///
    /// For [`IqFailure::SendError`], the text of the underlying I/O error is
    /// appended to the description.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::SendError(cause) => write!(f, "send error: {cause}"),
            Self::LostWorker => {
                f.write_str("disconnected from internal connection worker while sending IQ")
            }
        }
    }
}

impl Error for IqFailure {
    /// Returns the wrapped I/O error for [`IqFailure::SendError`]; the other
    /// variant has no underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::LostWorker => None,
            Self::SendError(cause) => Some(cause),
        }
    }
}
101
102type IqKey = (Option<Jid>, String);
103type IqMap = BTreeMap<IqKey, IqResponseSink>;
104
105struct IqMapEntryHandle {
106 key: IqKey,
107 map: Weak<Mutex<IqMap>>,
108}
109
110impl Drop for IqMapEntryHandle {
111 fn drop(&mut self) {
112 let Some(map) = self.map.upgrade() else {
113 return;
114 };
115 let Some(mut map) = map.lock().ok() else {
116 return;
117 };
118 map.remove(&self.key);
119 }
120}
121
122pin_project_lite::pin_project! {
123 pub struct IqResponseToken {
138 entry: Option<IqMapEntryHandle>,
139 #[pin]
140 stanza_token: Option<tokio_stream::wrappers::WatchStream<StanzaState>>,
141 #[pin]
142 inner: oneshot::Receiver<Result<IqResponse, IqFailure>>,
143 }
144}
145
146impl IqResponseToken {
147 pub(crate) fn set_stanza_token(&mut self, token: StanzaToken) {
155 assert!(self.stanza_token.is_none());
156 self.stanza_token = Some(token.into_stream());
157 }
158}
159
160impl Future for IqResponseToken {
161 type Output = Result<IqResponse, IqFailure>;
162
163 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164 let mut this = self.project();
165 match this.inner.poll(cx) {
166 Poll::Ready(Ok(v)) => {
167 this.entry.take();
169 return Poll::Ready(v);
170 }
171 Poll::Ready(Err(_)) => {
172 log::warn!("IqResponseToken oneshot::Receiver returned receive error!");
173 this.entry.take();
175 return Poll::Ready(Err(IqFailure::LostWorker));
176 }
177 Poll::Pending => (),
178 };
179
180 loop {
181 match this.stanza_token.as_mut().as_pin_mut() {
182 Some(stream) => match ready!(stream.poll_next(cx)) {
184 Some(StanzaState::Queued) => (),
186
187 Some(StanzaState::Dropped) | None => {
188 log::warn!("StanzaToken associated with IqResponseToken signalled that the Stanza was dropped before transmission.");
189 this.entry.take();
191 return Poll::Ready(Err(IqFailure::LostWorker));
193 }
194
195 Some(StanzaState::Failed { error }) => {
196 this.entry.take();
198 return Poll::Ready(Err(IqFailure::SendError(error.into_io_error())));
200 }
201
202 Some(StanzaState::Sent { .. }) | Some(StanzaState::Acked { .. }) => {
203 *this.stanza_token = None;
208 return Poll::Pending;
209 }
210 },
211
212 None => return Poll::Pending,
215 }
216 }
217 }
218}
219
220struct IqResponseSink {
221 inner: oneshot::Sender<Result<IqResponse, IqFailure>>,
222}
223
224impl IqResponseSink {
225 fn complete(self, resp: IqResponse) {
226 let _: Result<_, _> = self.inner.send(Ok(resp));
227 }
228}
229
230pub struct IqResponseTracker {
232 map: Arc<Mutex<IqMap>>,
233}
234
235impl IqResponseTracker {
236 pub fn new() -> Self {
238 Self {
239 map: Arc::new(Mutex::new(IqMap::new())),
240 }
241 }
242
243 pub fn handle_iq(&self, iq: Iq) -> ControlFlow<(), Iq> {
248 let payload = match iq.payload {
249 IqType::Error(error) => IqResponse::Error(error),
250 IqType::Result(result) => IqResponse::Result(result),
251 _ => return ControlFlow::Continue(iq),
252 };
253 let key = (iq.from, iq.id);
254 let mut map = self.map.lock().unwrap();
255 match map.remove(&key) {
256 None => {
257 log::trace!("not handling IQ response from {:?} with id {:?}: no active tracker for this tuple", key.0, key.1);
258 ControlFlow::Continue(Iq {
259 from: key.0,
260 id: key.1,
261 to: iq.to,
262 payload: payload.into(),
263 })
264 }
265 Some(sink) => {
266 sink.complete(payload);
267 ControlFlow::Break(())
268 }
269 }
270 }
271
272 pub fn allocate_iq_handle(
276 &self,
277 from: Option<Jid>,
278 to: Option<Jid>,
279 req: IqRequest,
280 ) -> (Iq, IqResponseToken) {
281 let key = (to, make_id());
282 let mut map = self.map.lock().unwrap();
283 let (tx, rx) = oneshot::channel();
284 let sink = IqResponseSink { inner: tx };
285 assert!(map.get(&key).is_none());
286 let token = IqResponseToken {
287 entry: Some(IqMapEntryHandle {
288 key: key.clone(),
289 map: Arc::downgrade(&self.map),
290 }),
291 stanza_token: None,
292 inner: rx,
293 };
294 map.insert(key.clone(), sink);
295 (
296 Iq {
297 from,
298 to: key.0,
299 id: key.1,
300 payload: req.into(),
301 },
302 token,
303 )
304 }
305}