tokio_xmpp/stanzastream/worker.rs

// Copyright (c) 2019 Emmanuel Gil Peyrot <linkmauve@linkmauve.fr>
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};
use core::time::Duration;
use std::io;

use rand::{thread_rng, Rng};

use futures::{ready, SinkExt, StreamExt};

use tokio::{
    sync::{mpsc, oneshot},
    time::Instant,
};

use xmpp_parsers::{
    iq,
    jid::Jid,
    ping,
    stream_error::{DefinedCondition, StreamError},
    stream_features::StreamFeatures,
};

use crate::connect::AsyncReadAndWrite;
use crate::xmlstream::{ReadError, XmppStreamElement};
use crate::Stanza;

use super::connected::{ConnectedEvent, ConnectedState};
use super::negotiation::NegotiationState;
use super::queue::{QueueEntry, TransmitQueue};
use super::stream_management::SmState;
use super::{Event, StreamEvent};

/// Convenience alias for [`XmlStreams`][`crate::xmlstream::XmlStream`] which
/// may be used with [`StanzaStream`][`super::StanzaStream`].
pub type XmppStream =
    crate::xmlstream::XmlStream<Box<dyn AsyncReadAndWrite + Send + 'static>, XmppStreamElement>;

/// Underlying connection for a [`StanzaStream`][`super::StanzaStream`].
pub struct Connection {
    /// The stream to use to send and receive XMPP data.
    pub stream: XmppStream,

    /// The stream features offered by the peer.
    pub features: StreamFeatures,

    /// The identity to which this stream belongs.
    ///
    /// Note that connectors must not return bound streams. However, the JID
    /// may still be a full JID in order to request a specific resource at
    /// bind time. If `identity` is a bare JID, the peer will assign the
    /// resource.
    pub identity: Jid,
}

// Allow for up to 10s for local shutdown.
// TODO: make this configurable maybe?
pub(super) static LOCAL_SHUTDOWN_TIMEOUT: Duration = Duration::new(10, 0);
pub(super) static REMOTE_SHUTDOWN_TIMEOUT: Duration = Duration::new(5, 0);
pub(super) static PING_PROBE_ID_PREFIX: &str = "xmpp-rs-stanzastream-liveness-probe";

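/// Uninhabited type: no value of type `Never` can exist, so a
/// `Poll<Never>` can only ever be `Poll::Pending`.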
pub(super) enum Never {}

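/// Event reported by the [`WorkerStream`] to the worker main loop.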
pub(super) enum WorkerEvent {
    /// The stream was reset and can now be used for rx/tx.
    Reset {
        bound_jid: Jid,
        features: StreamFeatures,
    },

    /// The stream has been resumed successfully.
    Resumed,

    /// Data received successfully.
    Stanza(Stanza),

    /// Failed to parse pieces from the stream.
    ParseError(xso::error::Error),

    /// Soft timeout noted by the underlying XmppStream.
    SoftTimeout,

    /// Stream disconnected.
    Disconnected {
        /// Slot for a new connection.
        slot: oneshot::Sender<Connection>,

        /// Set to None if the stream was cleanly closed by the remote side.
        error: Option<io::Error>,
    },

    /// The reconnection backend dropped the connection channel.
    ReconnectAborted,
}

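/// Connection state machine backing the worker: waiting for a (re)connect,
/// connected, or terminated for good.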
enum WorkerStream {
    /// Pending connection.
    Connecting {
        /// Optional contents of a [`WorkerEvent::Disconnected`] to emit.
        notify: Option<(oneshot::Sender<Connection>, Option<io::Error>)>,

        /// Receiver slot for the next connection.
        slot: oneshot::Receiver<Connection>,

        /// Stream management state from a previous connection.
        sm_state: Option<SmState>,
    },

    /// Connection available.
    Connected {
        stream: XmppStream,
        substate: ConnectedState,
        features: StreamFeatures,
        identity: Jid,
    },

    /// Disconnected permanently by local choice.
    Terminated,
}

impl WorkerStream {
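    /// Switch into [`Self::Connecting`], carrying over the given stream
    /// management state, and return the [`WorkerEvent::Disconnected`] whose
    /// `slot` must be filled with the next connection.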
    fn disconnect(&mut self, sm_state: Option<SmState>, error: Option<io::Error>) -> WorkerEvent {
        let (tx, rx) = oneshot::channel();
        *self = Self::Connecting {
            notify: None,
            slot: rx,
            sm_state,
        };
        WorkerEvent::Disconnected { slot: tx, error }
    }

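    /// Drive the stream in both directions and return the next
    /// [`WorkerEvent`].
    ///
    /// Completes with `None` once the stream has been shut down by local
    /// choice.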
    fn poll_duplex(
        self: Pin<&mut Self>,
        transmit_queue: &mut TransmitQueue<QueueEntry>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<WorkerEvent>> {
        let this = self.get_mut();
        loop {
            match this {
                // Disconnected cleanly (terminal state), signal end of
                // stream.
                Self::Terminated => return Poll::Ready(None),

                // In the process of reconnecting, wait for reconnection to
                // complete and then switch states.
                Self::Connecting {
                    notify,
                    slot,
                    sm_state,
                } => {
                    if let Some((slot, error)) = notify.take() {
                        return Poll::Ready(Some(WorkerEvent::Disconnected { slot, error }));
                    }

                    match ready!(Pin::new(slot).poll(cx)) {
                        Ok(Connection {
                            stream,
                            features,
                            identity,
                        }) => {
                            let substate = ConnectedState::Negotiating {
                                // We panic here, but that is ok-ish, because
                                // that will "only" crash the worker and thus
                                // the stream, and that is kind of exactly
                                // what we want.
                                substate: NegotiationState::new(&features, sm_state.take())
                                    .expect("Non-negotiable stream"),
                            };
                            *this = Self::Connected {
                                substate,
                                stream,
                                features,
                                identity,
                            };
                        }
                        Err(_) => {
                            // The sender was dropped. This is fatal.
                            *this = Self::Terminated;
                            return Poll::Ready(Some(WorkerEvent::ReconnectAborted));
                        }
                    }
                }

                Self::Connected {
                    stream,
                    identity,
                    substate,
                    features,
                } => {
                    match ready!(substate.poll(
                        Pin::new(stream),
                        identity,
                        features,
                        transmit_queue,
                        cx
                    )) {
                        // continue looping if the substate did not produce a result.
                        None => (),

                        // produced an event to emit.
                        Some(ConnectedEvent::Worker(v)) => {
                            // Capture the JID from a stream reset to update our state.
                            if let WorkerEvent::Reset { ref bound_jid, .. } = v {
                                *identity = bound_jid.clone();
                            }
                            return Poll::Ready(Some(v));
                        }

                        // stream broke or closed somehow.
                        Some(ConnectedEvent::Disconnect { sm_state, error }) => {
                            return Poll::Ready(Some(this.disconnect(sm_state, error)));
                        }

                        Some(ConnectedEvent::RemoteShutdown { sm_state }) => {
                            let error = io::Error::new(
                                io::ErrorKind::ConnectionAborted,
                                "peer closed the XML stream",
                            );
                            let (tx, rx) = oneshot::channel();
                            let mut new_state = Self::Connecting {
                                notify: None,
                                slot: rx,
                                sm_state,
                            };
                            core::mem::swap(this, &mut new_state);
                            match new_state {
                                Self::Connected { stream, .. } => {
                                    tokio::spawn(shutdown_stream_by_remote_choice(
                                        stream,
                                        REMOTE_SHUTDOWN_TIMEOUT,
                                    ));
                                }
                                _ => unreachable!(),
                            }

                            return Poll::Ready(Some(WorkerEvent::Disconnected {
                                slot: tx,
                                error: Some(error),
                            }));
                        }

                        Some(ConnectedEvent::LocalShutdownRequested) => {
                            // We don't switch to "terminated" here, but we
                            // return "end of stream" nonetheless.
                            return Poll::Ready(None);
                        }
                    }
                }
            }
        }
    }

    /// Poll the stream write-only.
    ///
    /// This never completes, not even if the `transmit_queue` is empty and
    /// its sender has been dropped, unless a write error occurs.
    ///
    /// The use case is to run this in parallel to a blocking operation
    /// which should only block the receive side, but not the transmit side,
    /// of the stream.
    ///
    /// Calling this and `poll_duplex` from different tasks in parallel will
    /// cause havoc.
    ///
    /// Any errors are reported on the next call to `poll_duplex`.
    fn poll_writes(
        &mut self,
        transmit_queue: &mut TransmitQueue<QueueEntry>,
        cx: &mut Context,
    ) -> Poll<Never> {
        match self {
            Self::Terminated | Self::Connecting { .. } => Poll::Pending,
            Self::Connected {
                substate, stream, ..
            } => {
                ready!(substate.poll_writes(Pin::new(stream), transmit_queue, cx));
                Poll::Pending
            }
        }
    }

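    /// Start sending a stream error on the current connection, if any.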
    fn start_send_stream_error(&mut self, error: StreamError) {
        match self {
            // If we are not connected or still connecting, we feign success
            // and enter the Terminated state.
            Self::Terminated | Self::Connecting { .. } => {
                *self = Self::Terminated;
            }

            Self::Connected { substate, .. } => substate.start_send_stream_error(error),
        }
    }

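    /// Close the stream and transition into [`Self::Terminated`].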
    fn poll_close(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        match self {
            Self::Terminated => Poll::Ready(Ok(())),
            Self::Connecting { .. } => {
                *self = Self::Terminated;
                Poll::Ready(Ok(()))
            }
            Self::Connected {
                substate, stream, ..
            } => {
                let result = ready!(substate.poll_close(Pin::new(stream), cx));
                *self = Self::Terminated;
                Poll::Ready(result)
            }
        }
    }

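    /// Future-returning wrapper around [`Self::poll_duplex`].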
    fn drive_duplex<'a>(
        &'a mut self,
        transmit_queue: &'a mut TransmitQueue<QueueEntry>,
    ) -> DriveDuplex<'a> {
        DriveDuplex {
            stream: Pin::new(self),
            queue: transmit_queue,
        }
    }

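    /// Future-returning wrapper around [`Self::poll_writes`].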
    fn drive_writes<'a>(
        &'a mut self,
        transmit_queue: &'a mut TransmitQueue<QueueEntry>,
    ) -> DriveWrites<'a> {
        DriveWrites {
            stream: Pin::new(self),
            queue: transmit_queue,
        }
    }

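    /// Future-returning wrapper around [`Self::poll_close`].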
    fn close(&mut self) -> Close {
        Close {
            stream: Pin::new(self),
        }
    }

    /// Enqueue a `<sm:r/>`, if stream management is enabled.
    ///
    /// Multiple calls to `queue_sm_request` may cause only a single
    /// `<sm:r/>` to be sent.
    ///
    /// Returns true if stream management is enabled and a request could be
    /// queued or deduplicated with a previous request.
    fn queue_sm_request(&mut self) -> bool {
        match self {
            Self::Terminated | Self::Connecting { .. } => false,
            Self::Connected { substate, .. } => substate.queue_sm_request(),
        }
    }
}

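/// Future which polls a [`WorkerStream`] in both directions via
/// [`WorkerStream::poll_duplex`].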
struct DriveDuplex<'x> {
    stream: Pin<&'x mut WorkerStream>,
    queue: &'x mut TransmitQueue<QueueEntry>,
}

impl Future for DriveDuplex<'_> {
    type Output = Option<WorkerEvent>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        this.stream.as_mut().poll_duplex(this.queue, cx)
    }
}

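/// Future which only drives the write side of a [`WorkerStream`] and
/// therefore never completes.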
struct DriveWrites<'x> {
    stream: Pin<&'x mut WorkerStream>,
    queue: &'x mut TransmitQueue<QueueEntry>,
}

impl Future for DriveWrites<'_> {
    type Output = Never;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        this.stream.as_mut().poll_writes(this.queue, cx)
    }
}

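/// Future which closes a [`WorkerStream`] via [`WorkerStream::poll_close`].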
struct Close<'x> {
    stream: Pin<&'x mut WorkerStream>,
}

impl Future for Close<'_> {
    type Output = io::Result<()>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        this.stream.as_mut().poll_close(cx)
    }
}

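/// Convert a parse error into the [`StreamError`] which is sent to the peer
/// before the stream is torn down.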
pub(super) fn parse_error_to_stream_error(e: xso::error::Error) -> StreamError {
    use xso::error::Error;
    let condition = match e {
        Error::XmlError(_) => DefinedCondition::NotWellFormed,
        Error::TextParseError(_) | Error::Other(_) => DefinedCondition::InvalidXml,
        Error::TypeMismatch => DefinedCondition::UnsupportedStanzaType,
    };
    StreamError {
        condition,
        text: Some((None, e.to_string())),
        application_specific: vec![],
    }
}

/// Worker system for a [`StanzaStream`].
pub(super) struct StanzaStreamWorker {
    reconnector: Box<dyn FnMut(Option<String>, oneshot::Sender<Connection>) + Send + 'static>,
    frontend_tx: mpsc::Sender<Event>,
    stream: WorkerStream,
    transmit_queue: TransmitQueue<QueueEntry>,
}

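/// Deliver `$value` to the frontend channel `$ch`, using a previously
/// reserved permit if one is available. Otherwise, block on the channel
/// send while continuing to drive writes on `$stream` from `$txq`, and
/// `break` the surrounding loop if the frontend receiver has been dropped.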
macro_rules! send_or_break {
    ($value:expr => $permit:ident in $ch:expr, $txq:expr => $stream:expr$(,)?) => {
        if let Some(permit) = $permit.take() {
            log::trace!("stanza received, passing to frontend via permit");
            permit.send($value);
        } else {
            log::trace!("no permit for received stanza available, blocking on channel send while handling writes");
            tokio::select! {
                // drive_writes never completes: I/O errors are reported on
                // the next call to drive_duplex(), which makes it ideal for
                // use in parallel to $ch.send().
                result = $stream.drive_writes(&mut $txq) => { match result {} },
                result = $ch.send($value) => match result {
                    Err(_) => break,
                    Ok(()) => (),
                },
            }
        }
    };
}

impl StanzaStreamWorker {
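    /// Spawn the worker on a new tokio task, kicking off the first
    /// connection attempt via `reconnector`, and return the frontend's
    /// queue sender and event receiver.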
    pub fn spawn(
        mut reconnector: Box<
            dyn FnMut(Option<String>, oneshot::Sender<Connection>) + Send + 'static,
        >,
        queue_depth: usize,
    ) -> (mpsc::Sender<QueueEntry>, mpsc::Receiver<Event>) {
        let (conn_tx, conn_rx) = oneshot::channel();
        reconnector(None, conn_tx);
        // c2f = core to frontend
        let (c2f_tx, c2f_rx) = mpsc::channel(queue_depth);
        // f2c = frontend to core
        let (f2c_tx, transmit_queue) = TransmitQueue::channel(queue_depth);
        let mut worker = StanzaStreamWorker {
            reconnector,
            frontend_tx: c2f_tx,
            stream: WorkerStream::Connecting {
                slot: conn_rx,
                sm_state: None,
                notify: None,
            },
            transmit_queue,
        };
        tokio::spawn(async move { worker.run().await });
        (f2c_tx, c2f_rx)
    }

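    /// Main worker loop: drive the stream, forward events to the frontend
    /// and request reconnects from the backend until the stream is
    /// terminated locally or the frontend is dropped.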
    pub async fn run(&mut self) {
        // TODO: consider moving this into SmState somehow, i.e. run a kind
        // of fake stream management exploiting the sequentiality requirement
        // from RFC 6120.
        // NOTE: we use a random starting value here to avoid clashes with
        // other application code.
        let mut ping_probe_ctr: u64 = thread_rng().gen();

        // We use mpsc::Sender permits (check the docs on
        // [`tokio::sync::mpsc::Sender::reserve`]) as a way to avoid blocking
        // on the `frontend_tx` whenever possible.
        //
        // We always try to have a permit available. If we have a permit
        // available, any event we receive from the stream can be sent to
        // the frontend tx without blocking. If we do not have a permit
        // available, the code generated by the send_or_break macro will
        // use the normal Sender::send coroutine function, but will also
        // service stream writes in parallel (putting backpressure on the
        // sender while not blocking writes on our end).
        let mut permit = None;
        loop {
            tokio::select! {
                new_permit = self.frontend_tx.reserve(), if permit.is_none() && !self.frontend_tx.is_closed() => match new_permit {
                    Ok(new_permit) => permit = Some(new_permit),
                    // Receiver side dropped… That is stream closure, so we
                    // shut everything down and exit.
                    Err(_) => break,
                },
                ev = self.stream.drive_duplex(&mut self.transmit_queue) => {
                    let Some(ev) = ev else {
                        // Stream terminated by local choice. Exit.
                        break;
                    };
                    match ev {
                        WorkerEvent::Reset { bound_jid, features } => send_or_break!(
                            Event::Stream(StreamEvent::Reset { bound_jid, features }) => permit in self.frontend_tx,
                            self.transmit_queue => self.stream,
                        ),
                        WorkerEvent::Disconnected { slot, error } => {
                            send_or_break!(
                                Event::Stream(StreamEvent::Suspended) => permit in self.frontend_tx,
                                self.transmit_queue => self.stream,
                            );
                            if let Some(error) = error {
                                log::debug!("Backend stream got disconnected because of an I/O error: {error}. Attempting reconnect.");
                            } else {
                                log::debug!("Backend stream got disconnected for an unknown reason. Attempting reconnect.");
                            }
                            if self.frontend_tx.is_closed() || self.transmit_queue.is_closed() {
                                log::debug!("Immediately aborting reconnect because the frontend is gone.");
                                break;
                            }
                            (self.reconnector)(None, slot);
                        }
                        WorkerEvent::Resumed => send_or_break!(
                            Event::Stream(StreamEvent::Resumed) => permit in self.frontend_tx,
                            self.transmit_queue => self.stream,
                        ),
                        WorkerEvent::Stanza(stanza) => send_or_break!(
                            Event::Stanza(stanza) => permit in self.frontend_tx,
                            self.transmit_queue => self.stream,
                        ),
                        WorkerEvent::ParseError(e) => {
                            log::error!("Parse error on stream: {e}");
                            self.stream.start_send_stream_error(parse_error_to_stream_error(e));
                            // We are not break-ing here, because drive_duplex
                            // is sending the error.
                        }
                        WorkerEvent::SoftTimeout => {
                            if self.stream.queue_sm_request() {
                                log::debug!("SoftTimeout tripped: enqueued <sm:r/>");
                            } else {
                                log::debug!("SoftTimeout tripped. Stream Management is not enabled, enqueueing ping IQ");
                                ping_probe_ctr = ping_probe_ctr.wrapping_add(1);
                                // We can leave to/from blank because those
                                // are not needed to send a ping to the peer.
                                // (At least that holds true on c2s streams.
                                // On s2s, things are more complicated anyway
                                // due to how bidi works.)
                                self.transmit_queue.enqueue(QueueEntry::untracked(Box::new(iq::Iq::from_get(
                                    format!("{}-{}", PING_PROBE_ID_PREFIX, ping_probe_ctr),
                                    ping::Ping,
                                ).into())));
                            }
                        }
                        WorkerEvent::ReconnectAborted => {
                            panic!("Backend was unable to handle reconnect request.");
                        }
                    }
                },
            }
        }
        match self.stream.close().await {
            Ok(()) => log::debug!("Stream closed successfully"),
            Err(e) => log::debug!("Stream closure failed: {e}"),
        }
    }
}

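/// Handle a shutdown initiated by the remote side: close our side of the
/// stream and drain any remaining input, giving up once `timeout` has
/// elapsed.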
async fn shutdown_stream_by_remote_choice(mut stream: XmppStream, timeout: Duration) {
    let deadline = Instant::now() + timeout;
    match tokio::time::timeout_at(
        deadline,
        <XmppStream as SinkExt<&Stanza>>::close(&mut stream),
    )
    .await
    {
        // We don't really care about success or failure here.
        Ok(_) => (),
        // .. but if we run into a timeout, we exit here right away.
        Err(_) => {
            log::debug!("Giving up on clean stream shutdown after timeout elapsed.");
            return;
        }
    }
    let timeout = tokio::time::sleep_until(deadline);
    tokio::pin!(timeout);
    loop {
        tokio::select! {
            _ = &mut timeout => {
                log::debug!("Giving up on clean stream shutdown after timeout elapsed.");
                break;
            }
            ev = stream.next() => match ev {
                None => break,
                Some(Ok(data)) => {
                    log::debug!("Ignoring data on stream during shutdown: {data:?}");
                    break;
                }
                Some(Err(ReadError::HardError(e))) => {
                    log::debug!("Ignoring stream I/O error during shutdown: {e}");
                    break;
                }
                Some(Err(ReadError::SoftTimeout)) => (),
                Some(Err(ReadError::ParseError(_))) => (),
                Some(Err(ReadError::StreamFooterReceived)) => (),
            }
        }
    }
}