tokio_xmpp/stanzastream/worker.rs

// Copyright (c) 2019 Emmanuel Gil Peyrot <linkmauve@linkmauve.fr>
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};
use core::time::Duration;
use std::io;

use rand::{thread_rng, Rng};

use futures::{ready, SinkExt, StreamExt};

use tokio::{
    sync::{mpsc, oneshot},
    time::Instant,
};

use xmpp_parsers::{
    iq,
    jid::Jid,
    ping,
    stream_error::{DefinedCondition, StreamError},
    stream_features::StreamFeatures,
};

use crate::connect::AsyncReadAndWrite;
use crate::xmlstream::{ReadError, XmppStreamElement};
use crate::Stanza;

use super::connected::{ConnectedEvent, ConnectedState};
use super::negotiation::NegotiationState;
use super::queue::{QueueEntry, TransmitQueue};
use super::stream_management::SmState;
use super::{Event, StreamEvent};

/// Convenience alias for [`XmlStreams`][`crate::xmlstream::XmlStream`] which
/// may be used with [`StanzaStream`][`super::StanzaStream`].
pub type XmppStream =
    crate::xmlstream::XmlStream<Box<dyn AsyncReadAndWrite + Send + 'static>, XmppStreamElement>;

/// Underlying connection for a [`StanzaStream`][`super::StanzaStream`].
pub struct Connection {
    /// The stream to use to send and receive XMPP data.
    pub stream: XmppStream,

    /// The stream features offered by the peer.
    pub features: StreamFeatures,

    /// The identity to which this stream belongs.
    ///
    /// Note that connectors must not return bound streams. However, the JID
    /// may still be a full JID in order to request a specific resource at
    /// bind time. If `identity` is a bare JID, the peer will assign the
    /// resource.
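    ///
    /// For illustration only (the addresses below are placeholders, not part
    /// of this module), a bare or full JID can be parsed via the `jid` crate
    /// re-exported through `xmpp_parsers`; passing a full JID here requests
    /// that resource at bind time:
    ///
    /// ```ignore
    /// use xmpp_parsers::jid::Jid;
    ///
    /// // Bare JID: the peer will assign the resource at bind time.
    /// let bare = Jid::new("juliet@example.com").unwrap();
    /// // Full JID: requests the "balcony" resource at bind time.
    /// let full = Jid::new("juliet@example.com/balcony").unwrap();
    /// ```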
    pub identity: Jid,
}

// Allow for up to 10s for local shutdown.
// TODO: make this configurable maybe?
pub(super) static LOCAL_SHUTDOWN_TIMEOUT: Duration = Duration::new(10, 0);
pub(super) static REMOTE_SHUTDOWN_TIMEOUT: Duration = Duration::new(5, 0);
pub(super) static PING_PROBE_ID_PREFIX: &str = "xmpp-rs-stanzastream-liveness-probe";

pub(super) enum Never {}

pub(super) enum WorkerEvent {
    /// The stream was reset and can now be used for rx/tx.
    Reset {
        bound_jid: Jid,
        features: StreamFeatures,
    },

    /// The stream has been resumed successfully.
    Resumed,

    /// Data received successfully.
    Stanza(Stanza),

    /// Failed to parse pieces from the stream.
    ParseError(xso::error::Error),

    /// Soft timeout noted by the underlying XmppStream.
    SoftTimeout,

    /// Stream disconnected.
    Disconnected {
        /// Slot for a new connection.
        slot: oneshot::Sender<Connection>,

        /// Set to None if the stream was cleanly closed by the remote side.
        error: Option<io::Error>,
    },

    /// The reconnection backend dropped the connection channel.
    ReconnectAborted,
}

enum WorkerStream {
    /// Pending connection.
    Connecting {
        /// Optional contents of a [`WorkerEvent::Disconnected`] to emit.
        notify: Option<(oneshot::Sender<Connection>, Option<io::Error>)>,

        /// Receiver slot for the next connection.
        slot: oneshot::Receiver<Connection>,

        /// Stream management state from a previous connection.
        sm_state: Option<SmState>,
    },

    /// Connection available.
    Connected {
        stream: XmppStream,
        substate: ConnectedState,
        features: StreamFeatures,
        identity: Jid,
    },

    /// Disconnected permanently by local choice.
    Terminated,
}

impl WorkerStream {
    fn disconnect(&mut self, sm_state: Option<SmState>, error: Option<io::Error>) -> WorkerEvent {
        let (tx, rx) = oneshot::channel();
        *self = Self::Connecting {
            notify: None,
            slot: rx,
            sm_state,
        };
        WorkerEvent::Disconnected { slot: tx, error }
    }

    fn poll_duplex(
        self: Pin<&mut Self>,
        transmit_queue: &mut TransmitQueue<QueueEntry>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<WorkerEvent>> {
        let this = self.get_mut();
        loop {
            match this {
                // Disconnected cleanly (terminal state), signal end of
                // stream.
                Self::Terminated => return Poll::Ready(None),

                // In the process of reconnecting, wait for the reconnection
                // to complete and then switch states.
                Self::Connecting {
                    notify,
                    slot,
                    sm_state,
                } => {
                    if let Some((slot, error)) = notify.take() {
                        return Poll::Ready(Some(WorkerEvent::Disconnected { slot, error }));
                    }

                    match ready!(Pin::new(slot).poll(cx)) {
                        Ok(Connection {
                            stream,
                            features,
                            identity,
                        }) => {
                            let substate = ConnectedState::Negotiating {
                                // We panic here, but that is ok-ish, because
                                // that will "only" crash the worker and thus
                                // the stream, and that is kind of exactly
                                // what we want.
                                substate: NegotiationState::new(&features, sm_state.take())
                                    .expect("Non-negotiable stream"),
                            };
                            *this = Self::Connected {
                                substate,
                                stream,
                                features,
                                identity,
                            };
                        }
                        Err(_) => {
                            // The sender was dropped. This is fatal.
                            *this = Self::Terminated;
                            return Poll::Ready(Some(WorkerEvent::ReconnectAborted));
                        }
                    }
                }

                Self::Connected {
                    stream,
                    identity,
                    substate,
                    features,
                } => {
                    match ready!(substate.poll(
                        Pin::new(stream),
                        identity,
                        &features,
                        transmit_queue,
                        cx
                    )) {
                        // continue looping if the substate did not produce a result.
                        None => (),

                        // produced an event to emit.
                        Some(ConnectedEvent::Worker(v)) => {
                            match v {
                                // Capture the JID from a stream reset to
                                // update our state.
                                WorkerEvent::Reset { ref bound_jid, .. } => {
                                    *identity = bound_jid.clone();
                                }
                                _ => (),
                            }
                            return Poll::Ready(Some(v));
                        }

                        // stream broke or closed somehow.
                        Some(ConnectedEvent::Disconnect { sm_state, error }) => {
                            return Poll::Ready(Some(this.disconnect(sm_state, error)));
                        }

                        Some(ConnectedEvent::RemoteShutdown { sm_state }) => {
                            let error = io::Error::new(
                                io::ErrorKind::ConnectionAborted,
                                "peer closed the XML stream",
                            );
                            let (tx, rx) = oneshot::channel();
                            let mut new_state = Self::Connecting {
                                notify: None,
                                slot: rx,
                                sm_state,
                            };
                            core::mem::swap(this, &mut new_state);
                            match new_state {
                                Self::Connected { stream, .. } => {
                                    tokio::spawn(shutdown_stream_by_remote_choice(
                                        stream,
                                        REMOTE_SHUTDOWN_TIMEOUT,
                                    ));
                                }
                                _ => unreachable!(),
                            }

                            return Poll::Ready(Some(WorkerEvent::Disconnected {
                                slot: tx,
                                error: Some(error),
                            }));
                        }

                        Some(ConnectedEvent::LocalShutdownRequested) => {
                            // We don't switch to "terminated" here, but we
                            // return "end of stream" nonetheless.
                            return Poll::Ready(None);
                        }
                    }
                }
            }
        }
    }

    /// Poll the stream write-only.
    ///
    /// This never completes, not even if the `transmit_queue` is empty and
    /// its sender has been dropped, unless a write error occurs.
    ///
    /// The use case behind this is to run it in parallel to a blocking
    /// operation which should only block the receive side, but not the
    /// transmit side, of the stream.
    ///
    /// Calling this and `poll_duplex` from different tasks in parallel will
    /// cause havoc.
    ///
    /// Any errors are reported on the next call to `poll_duplex`.
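    ///
    /// (Concretely, the worker uses this via `drive_writes` inside a
    /// `tokio::select!`, alongside a blocking send on the frontend channel;
    /// see the `send_or_break!` macro below.)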
    fn poll_writes(
        &mut self,
        transmit_queue: &mut TransmitQueue<QueueEntry>,
        cx: &mut Context,
    ) -> Poll<Never> {
        match self {
            Self::Terminated | Self::Connecting { .. } => Poll::Pending,
            Self::Connected {
                substate, stream, ..
            } => {
                ready!(substate.poll_writes(Pin::new(stream), transmit_queue, cx));
                Poll::Pending
            }
        }
    }

    fn start_send_stream_error(&mut self, error: StreamError) {
        match self {
            // If we are not connected or still connecting, we feign success
            // and enter the Terminated state.
            Self::Terminated | Self::Connecting { .. } => {
                *self = Self::Terminated;
            }

            Self::Connected { substate, .. } => substate.start_send_stream_error(error),
        }
    }

    fn poll_close(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        match self {
            Self::Terminated => Poll::Ready(Ok(())),
            Self::Connecting { .. } => {
                *self = Self::Terminated;
                Poll::Ready(Ok(()))
            }
            Self::Connected {
                substate, stream, ..
            } => {
                let result = ready!(substate.poll_close(Pin::new(stream), cx));
                *self = Self::Terminated;
                Poll::Ready(result)
            }
        }
    }

    fn drive_duplex<'a>(
        &'a mut self,
        transmit_queue: &'a mut TransmitQueue<QueueEntry>,
    ) -> DriveDuplex<'a> {
        DriveDuplex {
            stream: Pin::new(self),
            queue: transmit_queue,
        }
    }

    fn drive_writes<'a>(
        &'a mut self,
        transmit_queue: &'a mut TransmitQueue<QueueEntry>,
    ) -> DriveWrites<'a> {
        DriveWrites {
            stream: Pin::new(self),
            queue: transmit_queue,
        }
    }

    fn close(&mut self) -> Close {
        Close {
            stream: Pin::new(self),
        }
    }

    /// Enqueue a `<sm:r/>`, if stream management is enabled.
    ///
    /// Multiple calls to `queue_sm_request` may cause only a single `<sm:r/>`
    /// to be sent.
    ///
    /// Returns true if stream management is enabled and a request could be
    /// queued or deduplicated with a previous request.
    fn queue_sm_request(&mut self) -> bool {
        match self {
            Self::Terminated | Self::Connecting { .. } => false,
            Self::Connected { substate, .. } => substate.queue_sm_request(),
        }
    }
}

struct DriveDuplex<'x> {
    stream: Pin<&'x mut WorkerStream>,
    queue: &'x mut TransmitQueue<QueueEntry>,
}

impl<'x> Future for DriveDuplex<'x> {
    type Output = Option<WorkerEvent>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        this.stream.as_mut().poll_duplex(this.queue, cx)
    }
}

struct DriveWrites<'x> {
    stream: Pin<&'x mut WorkerStream>,
    queue: &'x mut TransmitQueue<QueueEntry>,
}

impl<'x> Future for DriveWrites<'x> {
    type Output = Never;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        this.stream.as_mut().poll_writes(this.queue, cx)
    }
}

struct Close<'x> {
    stream: Pin<&'x mut WorkerStream>,
}

impl<'x> Future for Close<'x> {
    type Output = io::Result<()>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        this.stream.as_mut().poll_close(cx)
    }
}

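/// Map an error from parsing a stream element to the [`StreamError`] which
/// the worker sends to the peer before closing the stream.
///
/// Purely as an illustrative sketch of the mapping (hence `ignore`), a type
/// mismatch reported by `xso` becomes `unsupported-stanza-type`:
///
/// ```ignore
/// let err = parse_error_to_stream_error(xso::error::Error::TypeMismatch);
/// assert_eq!(err.condition, DefinedCondition::UnsupportedStanzaType);
/// ```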
pub(super) fn parse_error_to_stream_error(e: xso::error::Error) -> StreamError {
    use xso::error::Error;
    let condition = match e {
        Error::XmlError(_) => DefinedCondition::NotWellFormed,
        Error::TextParseError(_) | Error::Other(_) => DefinedCondition::InvalidXml,
        Error::TypeMismatch => DefinedCondition::UnsupportedStanzaType,
    };
    StreamError {
        condition,
        text: Some((None, e.to_string())),
        application_specific: vec![],
    }
}

/// Worker system for a [`StanzaStream`].
pub(super) struct StanzaStreamWorker {
    reconnector: Box<dyn FnMut(Option<String>, oneshot::Sender<Connection>) + Send + 'static>,
    frontend_tx: mpsc::Sender<Event>,
    stream: WorkerStream,
    transmit_queue: TransmitQueue<QueueEntry>,
}

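/// Send an `Event` to the frontend, breaking out of the worker loop if the
/// frontend channel is closed.
///
/// If a reserved permit is available it is used (non-blocking); otherwise
/// the blocking channel send is awaited while stream writes keep being
/// serviced in parallel via `drive_writes`. An illustrative invocation
/// (mirroring the calls in `run` below):
///
/// ```ignore
/// send_or_break!(
///     Event::Stanza(stanza) => permit in self.frontend_tx,
///     self.transmit_queue => self.stream,
/// );
/// ```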
macro_rules! send_or_break {
    ($value:expr => $permit:ident in $ch:expr, $txq:expr => $stream:expr$(,)?) => {
        if let Some(permit) = $permit.take() {
            log::trace!("stanza received, passing to frontend via permit");
            permit.send($value);
        } else {
            log::trace!("no permit for received stanza available, blocking on channel send while handling writes");
            tokio::select! {
                // drive_writes never completes: I/O errors are reported on
                // the next call to drive_duplex(), which makes it ideal for
                // use in parallel to $ch.send().
                result = $stream.drive_writes(&mut $txq) => { match result {} },
                result = $ch.send($value) => match result {
                    Err(_) => break,
                    Ok(()) => (),
                },
            }
        }
    };
}

impl StanzaStreamWorker {
    pub fn spawn(
        mut reconnector: Box<
            dyn FnMut(Option<String>, oneshot::Sender<Connection>) + Send + 'static,
        >,
        queue_depth: usize,
    ) -> (mpsc::Sender<QueueEntry>, mpsc::Receiver<Event>) {
        let (conn_tx, conn_rx) = oneshot::channel();
        reconnector(None, conn_tx);
        // c2f = core to frontend
        let (c2f_tx, c2f_rx) = mpsc::channel(queue_depth);
        // f2c = frontend to core
        let (f2c_tx, transmit_queue) = TransmitQueue::channel(queue_depth);
        let mut worker = StanzaStreamWorker {
            reconnector,
            frontend_tx: c2f_tx,
            stream: WorkerStream::Connecting {
                slot: conn_rx,
                sm_state: None,
                notify: None,
            },
            transmit_queue,
        };
        tokio::spawn(async move { worker.run().await });
        (f2c_tx, c2f_rx)
    }

    pub async fn run(&mut self) {
        // TODO: consider moving this into SmState somehow, i.e. run a kind
        // of fake stream management exploiting the sequentiality requirement
        // from RFC 6120.
        // NOTE: we use a random starting value here to avoid clashes with
        // other application code.
        let mut ping_probe_ctr: u64 = thread_rng().gen();

        // We use mpsc::Sender permits (check the docs on
        // [`tokio::sync::mpsc::Sender::reserve`]) as a way to avoid blocking
        // on the `frontend_tx` whenever possible.
        //
        // We always try to have a permit available. If we have a permit
        // available, any event we receive from the stream can be sent to
        // the frontend tx without blocking. If we do not have a permit
        // available, the code generated by the send_or_break macro will
        // use the normal Sender::send coroutine function, but will also
        // service stream writes in parallel (putting backpressure on the
        // sender while not blocking writes on our end).
        let mut permit = None;
        loop {
            tokio::select! {
                new_permit = self.frontend_tx.reserve(), if permit.is_none() && !self.frontend_tx.is_closed() => match new_permit {
                    Ok(new_permit) => permit = Some(new_permit),
                    // Receiver side dropped… That is stream closure, so we
                    // shut everything down and exit.
                    Err(_) => break,
                },
                ev = self.stream.drive_duplex(&mut self.transmit_queue) => {
                    let Some(ev) = ev else {
                        // Stream terminated by local choice. Exit.
                        break;
                    };
                    match ev {
                        WorkerEvent::Reset { bound_jid, features } => send_or_break!(
                            Event::Stream(StreamEvent::Reset { bound_jid, features }) => permit in self.frontend_tx,
                            self.transmit_queue => self.stream,
                        ),
                        WorkerEvent::Disconnected { slot, error } => {
                            send_or_break!(
                                Event::Stream(StreamEvent::Suspended) => permit in self.frontend_tx,
                                self.transmit_queue => self.stream,
                            );
                            if let Some(error) = error {
                                log::debug!("Backend stream got disconnected because of an I/O error: {error}. Attempting reconnect.");
                            } else {
                                log::debug!("Backend stream got disconnected for an unknown reason. Attempting reconnect.");
                            }
                            if self.frontend_tx.is_closed() || self.transmit_queue.is_closed() {
                                log::debug!("Immediately aborting reconnect because the frontend is gone.");
                                break;
                            }
                            (self.reconnector)(None, slot);
                        }
                        WorkerEvent::Resumed => send_or_break!(
                            Event::Stream(StreamEvent::Resumed) => permit in self.frontend_tx,
                            self.transmit_queue => self.stream,
                        ),
                        WorkerEvent::Stanza(stanza) => send_or_break!(
                            Event::Stanza(stanza) => permit in self.frontend_tx,
                            self.transmit_queue => self.stream,
                        ),
                        WorkerEvent::ParseError(e) => {
                            log::error!("Parse error on stream: {e}");
                            self.stream.start_send_stream_error(parse_error_to_stream_error(e));
                            // We are not break-ing here, because drive_duplex
                            // is sending the error.
                        }
                        WorkerEvent::SoftTimeout => {
                            if self.stream.queue_sm_request() {
                                log::debug!("SoftTimeout tripped: enqueued <sm:r/>");
                            } else {
                                log::debug!("SoftTimeout tripped. Stream Management is not enabled, enqueueing ping IQ");
                                ping_probe_ctr = ping_probe_ctr.wrapping_add(1);
                                // We can leave to/from blank because those
                                // are not needed to send a ping to the peer.
                                // (At least that holds true on c2s streams.
                                // On s2s, things are more complicated anyway
                                // due to how bidi works.)
                                self.transmit_queue.enqueue(QueueEntry::untracked(Box::new(iq::Iq::from_get(
                                    format!("{}-{}", PING_PROBE_ID_PREFIX, ping_probe_ctr),
                                    ping::Ping,
                                ).into())));
                            }
                        }
                        WorkerEvent::ReconnectAborted => {
                            panic!("Backend was unable to handle reconnect request.");
                        }
                    }
                },
            }
        }
        match self.stream.close().await {
            Ok(()) => log::debug!("Stream closed successfully"),
            Err(e) => log::debug!("Stream closure failed: {e}"),
        }
    }
}

async fn shutdown_stream_by_remote_choice(mut stream: XmppStream, timeout: Duration) {
    let deadline = Instant::now() + timeout;
    match tokio::time::timeout_at(
        deadline,
        <XmppStream as SinkExt<&Stanza>>::close(&mut stream),
    )
    .await
    {
        // We don't really care about success or failure here.
        Ok(_) => (),
        // .. but if we run into a timeout, we exit here right away.
        Err(_) => {
            log::debug!("Giving up on clean stream shutdown after timeout elapsed.");
            return;
        }
    }
    let timeout = tokio::time::sleep_until(deadline);
    tokio::pin!(timeout);
    loop {
        tokio::select! {
            _ = &mut timeout => {
                log::debug!("Giving up on clean stream shutdown after timeout elapsed.");
                break;
            }
            ev = stream.next() => match ev {
                None => break,
                Some(Ok(data)) => {
                    log::debug!("Ignoring data on stream during shutdown: {data:?}");
                    break;
                }
                Some(Err(ReadError::HardError(e))) => {
                    log::debug!("Ignoring stream I/O error during shutdown: {e}");
                    break;
                }
                Some(Err(ReadError::SoftTimeout)) => (),
                Some(Err(ReadError::ParseError(_))) => (),
                Some(Err(ReadError::StreamFooterReceived)) => (),
            }
        }
    }
}