tokio_xmpp/stanzastream/worker.rs

// Copyright (c) 2019 Emmanuel Gil Peyrot <linkmauve@linkmauve.fr>
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};
use core::time::Duration;
use std::io;

use futures::{ready, SinkExt, StreamExt};

use tokio::{
    sync::{mpsc, oneshot},
    time::Instant,
};

use xmpp_parsers::{
    iq,
    jid::Jid,
    ping,
    stream_error::{DefinedCondition, StreamError},
    stream_features::StreamFeatures,
};

use crate::connect::AsyncReadAndWrite;
use crate::xmlstream::{ReadError, XmppStreamElement};
use crate::Stanza;

use super::connected::{ConnectedEvent, ConnectedState};
use super::negotiation::NegotiationState;
use super::queue::{QueueEntry, TransmitQueue};
use super::stream_management::SmState;
use super::{Event, StreamEvent};

/// Convenience alias for [`XmlStreams`][`crate::xmlstream::XmlStream`] which
/// may be used with [`StanzaStream`][`super::StanzaStream`].
pub type XmppStream =
    crate::xmlstream::XmlStream<Box<dyn AsyncReadAndWrite + Send + 'static>, XmppStreamElement>;

/// Underlying connection for a [`StanzaStream`][`super::StanzaStream`].
pub struct Connection {
    /// The stream to use to send and receive XMPP data.
    pub stream: XmppStream,

    /// The stream features offered by the peer.
    pub features: StreamFeatures,

    /// The identity to which this stream belongs.
    ///
    /// Note that connectors must not return bound streams. However, the JID
    /// may still be a full JID in order to request a specific resource at
    /// bind time. If `identity` is a bare JID, the peer will assign the
    /// resource.
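    ///
    /// For illustration only (a sketch, not part of this module's API):
    ///
    /// ```ignore
    /// use xmpp_parsers::jid::Jid;
    ///
    /// // Bare JID: the peer assigns the resource at bind time.
    /// let bare: Jid = "juliet@example.org".parse().unwrap();
    /// // Full JID: requests that this specific resource be bound.
    /// let full: Jid = "juliet@example.org/balcony".parse().unwrap();
    /// ```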
    pub identity: Jid,
}

// Allow for up to 10s for local shutdown.
// TODO: make this configurable maybe?
pub(super) static LOCAL_SHUTDOWN_TIMEOUT: Duration = Duration::new(10, 0);
pub(super) static REMOTE_SHUTDOWN_TIMEOUT: Duration = Duration::new(5, 0);
pub(super) static PING_PROBE_ID_PREFIX: &str = "xmpp-rs-stanzastream-liveness-probe";

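/// Uninhabited type for futures which never complete normally; a
/// `Poll::Ready(Never)` value can only be dispatched with an empty `match`.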
pub(super) enum Never {}

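/// Event reported by the stream state machine to the worker's main loop.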
pub(super) enum WorkerEvent {
    /// The stream was reset and can now be used for rx/tx.
    Reset {
        bound_jid: Jid,
        features: StreamFeatures,
    },

    /// The stream has been resumed successfully.
    Resumed,

    /// Data received successfully.
    Stanza(Stanza),

    /// Failed to parse pieces from the stream.
    ParseError(xso::error::Error),

    /// Soft timeout noted by the underlying XmppStream.
    SoftTimeout,

    /// Stream disconnected.
    Disconnected {
        /// Slot for a new connection.
        slot: oneshot::Sender<Connection>,

        /// Set to None if the stream was cleanly closed by the remote side.
        error: Option<io::Error>,
    },

    /// The reconnection backend dropped the connection channel.
    ReconnectAborted,
}
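/// Connection state machine driven by the worker: waiting for a (re)connect,
/// connected, or terminated for good.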
enum WorkerStream {
    /// Pending connection.
    Connecting {
        /// Optional contents of a [`WorkerEvent::Disconnected`] to emit.
        notify: Option<(oneshot::Sender<Connection>, Option<io::Error>)>,

        /// Receiver slot for the next connection.
        slot: oneshot::Receiver<Connection>,

        /// Stream management state from a previous connection.
        sm_state: Option<SmState>,
    },

    /// Connection available.
    Connected {
        stream: XmppStream,
        substate: ConnectedState,
        features: StreamFeatures,
        identity: Jid,
    },

    /// Disconnected permanently by local choice.
    Terminated,
}

impl WorkerStream {
    fn disconnect(&mut self, sm_state: Option<SmState>, error: Option<io::Error>) -> WorkerEvent {
        let (tx, rx) = oneshot::channel();
        *self = Self::Connecting {
            notify: None,
            slot: rx,
            sm_state,
        };
        WorkerEvent::Disconnected { slot: tx, error }
    }

    fn poll_duplex(
        self: Pin<&mut Self>,
        transmit_queue: &mut TransmitQueue<QueueEntry>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<WorkerEvent>> {
        let this = self.get_mut();
        loop {
            match this {
                // Disconnected cleanly (terminal state), signal end of
                // stream.
                Self::Terminated => return Poll::Ready(None),

                // In the process of reconnecting, wait for reconnection to
                // complete and then switch states.
                Self::Connecting {
                    notify,
                    slot,
                    sm_state,
                } => {
                    if let Some((slot, error)) = notify.take() {
                        return Poll::Ready(Some(WorkerEvent::Disconnected { slot, error }));
                    }

                    match ready!(Pin::new(slot).poll(cx)) {
                        Ok(Connection {
                            stream,
                            features,
                            identity,
                        }) => {
                            let substate = ConnectedState::Negotiating {
                                // We panic here, but that is ok-ish, because
                                // that will "only" crash the worker and thus
                                // the stream, and that is kind of exactly
                                // what we want.
                                substate: NegotiationState::new(&features, sm_state.take())
                                    .expect("Non-negotiable stream"),
                            };
                            *this = Self::Connected {
                                substate,
                                stream,
                                features,
                                identity,
                            };
                        }
                        Err(_) => {
                            // The sender was dropped. This is fatal.
                            *this = Self::Terminated;
                            return Poll::Ready(Some(WorkerEvent::ReconnectAborted));
                        }
                    }
                }

                Self::Connected {
                    stream,
                    identity,
                    substate,
                    features,
                } => {
                    match ready!(substate.poll(
                        Pin::new(stream),
                        identity,
                        features,
                        transmit_queue,
                        cx
                    )) {
                        // continue looping if the substate did not produce a result.
                        None => (),

                        // produced an event to emit.
                        Some(ConnectedEvent::Worker(v)) => {
                            // Capture the JID from a stream reset to update our state.
                            if let WorkerEvent::Reset { ref bound_jid, .. } = v {
                                *identity = bound_jid.clone();
                            }
                            return Poll::Ready(Some(v));
                        }

                        // stream broke or closed somehow.
                        Some(ConnectedEvent::Disconnect { sm_state, error }) => {
                            return Poll::Ready(Some(this.disconnect(sm_state, error)));
                        }

                        Some(ConnectedEvent::RemoteShutdown { sm_state }) => {
                            let error = io::Error::new(
                                io::ErrorKind::ConnectionAborted,
                                "peer closed the XML stream",
                            );
                            let (tx, rx) = oneshot::channel();
                            let mut new_state = Self::Connecting {
                                notify: None,
                                slot: rx,
                                sm_state,
                            };
                            core::mem::swap(this, &mut new_state);
                            match new_state {
                                Self::Connected { stream, .. } => {
                                    tokio::spawn(shutdown_stream_by_remote_choice(
                                        stream,
                                        REMOTE_SHUTDOWN_TIMEOUT,
                                    ));
                                }
                                _ => unreachable!(),
                            }

                            return Poll::Ready(Some(WorkerEvent::Disconnected {
                                slot: tx,
                                error: Some(error),
                            }));
                        }

                        Some(ConnectedEvent::LocalShutdownRequested) => {
                            // We don't switch to "terminated" here, but we
                            // return "end of stream" nonetheless.
                            return Poll::Ready(None);
                        }
                    }
                }
            }
        }
    }

    /// Poll the stream write-only.
    ///
    /// This never completes, not even if the `transmit_queue` is empty and
    /// its sender has been dropped, unless a write error occurs.
    ///
    /// The use case behind this is to run this in parallel to a blocking
    /// operation which should only block the receive side, but not the
    /// transmit side of the stream.
    ///
    /// Calling this and `poll_duplex` from different tasks in parallel will
    /// cause havoc.
    ///
    /// Any errors are reported on the next call to `poll_duplex`.
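    ///
    /// A rough usage sketch via the `drive_writes` wrapper (the names other
    /// than `drive_writes` stand in for the surrounding worker state):
    ///
    /// ```ignore
    /// tokio::select! {
    ///     // Keeps flushing outgoing data; errors only surface on the next
    ///     // poll_duplex call, so this arm never completes.
    ///     never = stream.drive_writes(&mut transmit_queue) => match never {},
    ///     // The receive-side operation that is allowed to block.
    ///     result = frontend_tx.send(event) => { /* handle result */ }
    /// }
    /// ```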
    fn poll_writes(
        &mut self,
        transmit_queue: &mut TransmitQueue<QueueEntry>,
        cx: &mut Context,
    ) -> Poll<Never> {
        match self {
            Self::Terminated | Self::Connecting { .. } => Poll::Pending,
            Self::Connected {
                substate, stream, ..
            } => {
                ready!(substate.poll_writes(Pin::new(stream), transmit_queue, cx));
                Poll::Pending
            }
        }
    }

    fn start_send_stream_error(&mut self, error: StreamError) {
        match self {
            // If we are not connected or still connecting, we feign success
            // and enter the Terminated state.
            Self::Terminated | Self::Connecting { .. } => {
                *self = Self::Terminated;
            }

            Self::Connected { substate, .. } => substate.start_send_stream_error(error),
        }
    }

    fn poll_close(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        match self {
            Self::Terminated => Poll::Ready(Ok(())),
            Self::Connecting { .. } => {
                *self = Self::Terminated;
                Poll::Ready(Ok(()))
            }
            Self::Connected {
                substate, stream, ..
            } => {
                let result = ready!(substate.poll_close(Pin::new(stream), cx));
                *self = Self::Terminated;
                Poll::Ready(result)
            }
        }
    }

    fn drive_duplex<'a>(
        &'a mut self,
        transmit_queue: &'a mut TransmitQueue<QueueEntry>,
    ) -> DriveDuplex<'a> {
        DriveDuplex {
            stream: Pin::new(self),
            queue: transmit_queue,
        }
    }

    fn drive_writes<'a>(
        &'a mut self,
        transmit_queue: &'a mut TransmitQueue<QueueEntry>,
    ) -> DriveWrites<'a> {
        DriveWrites {
            stream: Pin::new(self),
            queue: transmit_queue,
        }
    }

    fn close(&mut self) -> Close<'_> {
        Close {
            stream: Pin::new(self),
        }
    }

    /// Enqueue a `<sm:r/>` if stream management is enabled.
    ///
    /// Multiple calls to `queue_sm_request` may cause only a single `<sm:r/>`
    /// to be sent.
    ///
    /// Returns true if stream management is enabled and a request could be
    /// queued or deduplicated with a previous request.
    fn queue_sm_request(&mut self) -> bool {
        match self {
            Self::Terminated | Self::Connecting { .. } => false,
            Self::Connected { substate, .. } => substate.queue_sm_request(),
        }
    }
}

struct DriveDuplex<'x> {
    stream: Pin<&'x mut WorkerStream>,
    queue: &'x mut TransmitQueue<QueueEntry>,
}

impl Future for DriveDuplex<'_> {
    type Output = Option<WorkerEvent>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        this.stream.as_mut().poll_duplex(this.queue, cx)
    }
}

struct DriveWrites<'x> {
    stream: Pin<&'x mut WorkerStream>,
    queue: &'x mut TransmitQueue<QueueEntry>,
}

impl Future for DriveWrites<'_> {
    type Output = Never;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        this.stream.as_mut().poll_writes(this.queue, cx)
    }
}

struct Close<'x> {
    stream: Pin<&'x mut WorkerStream>,
}

impl Future for Close<'_> {
    type Output = io::Result<()>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        this.stream.as_mut().poll_close(cx)
    }
}

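/// Map a parse error from the XML layer to the closest matching stream
/// error condition, so it can be reported to the peer before the stream is
/// torn down.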
pub(super) fn parse_error_to_stream_error(e: xso::error::Error) -> StreamError {
    use xso::error::Error;
    let condition = match e {
        Error::XmlError(_) => DefinedCondition::NotWellFormed,
        Error::TextParseError(_) | Error::Other(_) => DefinedCondition::InvalidXml,
        Error::TypeMismatch => DefinedCondition::UnsupportedStanzaType,
    };
    StreamError::new(condition, "en", e.to_string())
}

/// Worker system for a [`StanzaStream`].
pub(super) struct StanzaStreamWorker {
    reconnector: Box<dyn FnMut(Option<String>, oneshot::Sender<Connection>) + Send + 'static>,
    frontend_tx: mpsc::Sender<Event>,
    stream: WorkerStream,
    transmit_queue: TransmitQueue<QueueEntry>,
}

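/// Send `$value` to the frontend channel `$ch`, preferring a previously
/// reserved permit from `$permit` if one is available. If no permit is
/// available, await channel capacity while concurrently driving stream
/// writes from `$txq` over `$stream`, and `break` out of the enclosing loop
/// if the frontend receiver has been dropped.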
macro_rules! send_or_break {
    ($value:expr => $permit:ident in $ch:expr, $txq:expr => $stream:expr$(,)?) => {
        if let Some(permit) = $permit.take() {
            log::trace!("stanza received, passing to frontend via permit");
            permit.send($value);
        } else {
            log::trace!("no permit for received stanza available, blocking on channel send while handling writes");
            tokio::select! {
                // drive_writes never completes: I/O errors are reported on
                // the next call to drive_duplex(), which makes it ideal for
                // use in parallel to $ch.send().
                result = $stream.drive_writes(&mut $txq) => { match result {} },
                result = $ch.send($value) => match result {
                    Err(_) => break,
                    Ok(()) => (),
                },
            }
        }
    };
}

impl StanzaStreamWorker {
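    /// Spawn the worker task on the current tokio runtime.
    ///
    /// The `reconnector` callback is invoked immediately to establish the
    /// first connection and again for every reconnect. Returns the sender
    /// the frontend uses to enqueue stanzas and the receiver on which it
    /// obtains [`Event`]s.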
    pub fn spawn(
        mut reconnector: Box<
            dyn FnMut(Option<String>, oneshot::Sender<Connection>) + Send + 'static,
        >,
        queue_depth: usize,
    ) -> (mpsc::Sender<QueueEntry>, mpsc::Receiver<Event>) {
        let (conn_tx, conn_rx) = oneshot::channel();
        reconnector(None, conn_tx);
        // c2f = core to frontend
        let (c2f_tx, c2f_rx) = mpsc::channel(queue_depth);
        // f2c = frontend to core
        let (f2c_tx, transmit_queue) = TransmitQueue::channel(queue_depth);
        let mut worker = StanzaStreamWorker {
            reconnector,
            frontend_tx: c2f_tx,
            stream: WorkerStream::Connecting {
                slot: conn_rx,
                sm_state: None,
                notify: None,
            },
            transmit_queue,
        };
        tokio::spawn(async move { worker.run().await });
        (f2c_tx, c2f_rx)
    }

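    /// Main loop of the worker.
    ///
    /// Drives the underlying stream, forwards received stanzas and stream
    /// events to the frontend, triggers reconnects on disconnection, and
    /// closes the stream once it terminates locally or the frontend goes
    /// away.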
    pub async fn run(&mut self) {
        // TODO: consider moving this into SmState somehow, i.e. run a kind
        // of fake stream management exploiting the sequentiality requirement
        // from RFC 6120.
        // NOTE: we use a random starting value here to avoid clashes with
        // other application code.
        let mut ping_probe_ctr: u64 = rand::random();

        // We use mpsc::Sender permits (check the docs on
        // [`tokio::sync::mpsc::Sender::reserve`]) as a way to avoid blocking
        // on the `frontend_tx` whenever possible.
        //
        // We always try to have a permit available. If we have a permit
        // available, any event we receive from the stream can be sent to
        // the frontend tx without blocking. If we do not have a permit
        // available, the code generated by the send_or_break macro will
        // use the normal Sender::send coroutine function, but will also
        // service stream writes in parallel (putting backpressure on the
        // sender while not blocking writes on our end).
        let mut permit = None;
        loop {
            tokio::select! {
                new_permit = self.frontend_tx.reserve(), if permit.is_none() && !self.frontend_tx.is_closed() => match new_permit {
                    Ok(new_permit) => permit = Some(new_permit),
                    // Receiver side dropped… That is stream closure, so we
                    // shut everything down and exit.
                    Err(_) => break,
                },
                ev = self.stream.drive_duplex(&mut self.transmit_queue) => {
                    let Some(ev) = ev else {
                        // Stream terminated by local choice. Exit.
                        break;
                    };
                    match ev {
                        WorkerEvent::Reset { bound_jid, features } => send_or_break!(
                            Event::Stream(StreamEvent::Reset { bound_jid, features }) => permit in self.frontend_tx,
                            self.transmit_queue => self.stream,
                        ),
                        WorkerEvent::Disconnected { slot, error } => {
                            send_or_break!(
                                Event::Stream(StreamEvent::Suspended) => permit in self.frontend_tx,
                                self.transmit_queue => self.stream,
                            );
                            if let Some(error) = error {
                                log::debug!("Backend stream got disconnected because of an I/O error: {error}. Attempting reconnect.");
                            } else {
                                log::debug!("Backend stream got disconnected for an unknown reason. Attempting reconnect.");
                            }
                            if self.frontend_tx.is_closed() || self.transmit_queue.is_closed() {
                                log::debug!("Immediately aborting reconnect because the frontend is gone.");
                                break;
                            }
                            (self.reconnector)(None, slot);
                        }
                        WorkerEvent::Resumed => send_or_break!(
                            Event::Stream(StreamEvent::Resumed) => permit in self.frontend_tx,
                            self.transmit_queue => self.stream,
                        ),
                        WorkerEvent::Stanza(stanza) => send_or_break!(
                            Event::Stanza(stanza) => permit in self.frontend_tx,
                            self.transmit_queue => self.stream,
                        ),
                        WorkerEvent::ParseError(e) => {
                            log::error!("Parse error on stream: {e}");
                            self.stream.start_send_stream_error(parse_error_to_stream_error(e));
                            // We don't break here, because drive_duplex
                            // takes care of sending the error.
                        }
                        WorkerEvent::SoftTimeout => {
                            if self.stream.queue_sm_request() {
                                log::debug!("SoftTimeout tripped: enqueued <sm:r/>");
                            } else {
                                log::debug!("SoftTimeout tripped. Stream Management is not enabled, enqueueing ping IQ");
                                ping_probe_ctr = ping_probe_ctr.wrapping_add(1);
                                // We can leave to/from blank because those
                                // are not needed to send a ping to the peer.
                                // (At least that holds true on c2s streams.
                                // On s2s, things are more complicated anyway
                                // due to how bidi works.)
                                self.transmit_queue.enqueue(QueueEntry::untracked(Box::new(iq::Iq::from_get(
                                    format!("{}-{}", PING_PROBE_ID_PREFIX, ping_probe_ctr),
                                    ping::Ping,
                                ).into())));
                            }
                        }
                        WorkerEvent::ReconnectAborted => {
                            panic!("Backend was unable to handle reconnect request.");
                        }
                    }
                },
            }
        }
        match self.stream.close().await {
            Ok(()) => log::debug!("Stream closed successfully"),
            Err(e) => log::debug!("Stream closure failed: {e}"),
        }
    }
}
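/// Close the local side of a stream after the peer has closed theirs.
///
/// Sends our closing handshake and then waits for the stream to end,
/// stopping early on a timeout, unexpected data, or an I/O error.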
async fn shutdown_stream_by_remote_choice(mut stream: XmppStream, timeout: Duration) {
    let deadline = Instant::now() + timeout;
    match tokio::time::timeout_at(
        deadline,
        <XmppStream as SinkExt<&Stanza>>::close(&mut stream),
    )
    .await
    {
        // We don't really care about success or failure here.
        Ok(_) => (),
        // .. but if we run into a timeout, we exit here right away.
        Err(_) => {
            log::debug!("Giving up on clean stream shutdown after timeout elapsed.");
            return;
        }
    }
    let timeout = tokio::time::sleep_until(deadline);
    tokio::pin!(timeout);
    loop {
        tokio::select! {
            _ = &mut timeout => {
                log::debug!("Giving up on clean stream shutdown after timeout elapsed.");
                break;
            }
            ev = stream.next() => match ev {
                None => break,
                Some(Ok(data)) => {
                    log::debug!("Ignoring data on stream during shutdown: {data:?}");
                    break;
                }
                Some(Err(ReadError::HardError(e))) => {
                    log::debug!("Ignoring stream I/O error during shutdown: {e}");
                    break;
                }
                Some(Err(ReadError::SoftTimeout)) => (),
                Some(Err(ReadError::ParseError(_))) => (),
                Some(Err(ReadError::StreamFooterReceived)) => (),
            }
        }
    }
}