1use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll};
10use core::time::Duration;
11use std::io;
12
13use rand::{thread_rng, Rng};
14
15use futures::{ready, SinkExt, StreamExt};
16
17use tokio::{
18 sync::{mpsc, oneshot},
19 time::Instant,
20};
21
22use xmpp_parsers::{
23 iq,
24 jid::Jid,
25 ping,
26 stream_error::{DefinedCondition, StreamError},
27 stream_features::StreamFeatures,
28};
29
30use crate::connect::AsyncReadAndWrite;
31use crate::xmlstream::{ReadError, XmppStreamElement};
32use crate::Stanza;
33
34use super::connected::{ConnectedEvent, ConnectedState};
35use super::negotiation::NegotiationState;
36use super::queue::{QueueEntry, TransmitQueue};
37use super::stream_management::SmState;
38use super::{Event, StreamEvent};
39
/// The XML stream type driven by the stanza stream worker.
///
/// This is a `crate::xmlstream::XmlStream` over a boxed, type-erased
/// transport (`dyn AsyncReadAndWrite + Send + 'static`), instantiated with
/// `XmppStreamElement` as its second type parameter.
pub type XmppStream =
    crate::xmlstream::XmlStream<Box<dyn AsyncReadAndWrite + Send + 'static>, XmppStreamElement>;
44
/// A connection handed to the stanza stream worker.
///
/// Values of this type are delivered to the worker through the
/// `oneshot::Sender<Connection>` which is passed to the reconnector callback.
pub struct Connection {
    /// The XML stream over which the worker sends and receives data.
    pub stream: XmppStream,

    /// The stream features belonging to `stream`. The worker hands them to
    /// the negotiation state when it takes over the connection.
    pub features: StreamFeatures,

    /// The JID associated with this connection. The worker replaces its own
    /// copy with the bound JID whenever a stream reset is reported.
    pub identity: Jid,
}
61
62pub(super) static LOCAL_SHUTDOWN_TIMEOUT: Duration = Duration::new(10, 0);
65pub(super) static REMOTE_SHUTDOWN_TIMEOUT: Duration = Duration::new(5, 0);
66pub(super) static PING_PROBE_ID_PREFIX: &str = "xmpp-rs-stanzastream-liveness-probe";
67
/// An uninhabited type.
///
/// It is used as the output of polls and futures which can never complete;
/// a value of this type can be eliminated with an empty `match`.
pub(super) enum Never {}
69
/// Events produced by the worker's stream state machine and handled by
/// `StanzaStreamWorker::run`.
pub(super) enum WorkerEvent {
    /// The stream has been reset. Forwarded to the frontend as
    /// `StreamEvent::Reset`.
    Reset {
        /// The JID bound to the stream; the worker stream adopts it as its
        /// new identity.
        bound_jid: Jid,
        /// The stream features reported alongside the reset.
        features: StreamFeatures,
    },

    /// The stream has been resumed. Forwarded to the frontend as
    /// `StreamEvent::Resumed`.
    Resumed,

    /// A stanza was received. Forwarded to the frontend as `Event::Stanza`.
    Stanza(Stanza),

    /// Received data could not be parsed. The worker logs the error and
    /// starts sending a corresponding stream error.
    ParseError(xso::error::Error),

    /// A soft timeout was reported for the stream. The worker reacts by
    /// queueing a stream management request or, failing that, a ping IQ.
    SoftTimeout,

    /// The stream got disconnected and a new connection is needed.
    Disconnected {
        /// Sending half of the channel through which the replacement
        /// connection has to be delivered.
        slot: oneshot::Sender<Connection>,

        /// The error which caused the disconnect, if any.
        error: Option<io::Error>,
    },

    /// The sending half of the reconnect slot was dropped without
    /// delivering a connection. The worker panics on this event.
    ReconnectAborted,
}
101
/// The connection state machine owned by the worker.
enum WorkerStream {
    /// No connection is available; the worker waits for one to arrive.
    Connecting {
        /// A pending disconnect notification. If set, the next duplex poll
        /// emits it as `WorkerEvent::Disconnected` before waiting on `slot`.
        notify: Option<(oneshot::Sender<Connection>, Option<io::Error>)>,

        /// Receiving half of the channel which delivers the new connection.
        slot: oneshot::Receiver<Connection>,

        /// Stream management state, if any, which is handed to the
        /// negotiation state once the new connection has arrived.
        sm_state: Option<SmState>,
    },

    /// A connection is available and is being driven.
    Connected {
        /// The underlying XML stream.
        stream: XmppStream,
        /// The sub-state machine which drives `stream`.
        substate: ConnectedState,
        /// The stream features of the connection.
        features: StreamFeatures,
        /// The JID of this stream; updated when a reset event reports a
        /// bound JID.
        identity: Jid,
    },

    /// Final state: duplex polls report end-of-stream and closing succeeds
    /// immediately.
    Terminated,
}
126
127impl WorkerStream {
128 fn disconnect(&mut self, sm_state: Option<SmState>, error: Option<io::Error>) -> WorkerEvent {
129 let (tx, rx) = oneshot::channel();
130 *self = Self::Connecting {
131 notify: None,
132 slot: rx,
133 sm_state,
134 };
135 WorkerEvent::Disconnected { slot: tx, error }
136 }
137
/// Drive the stream in both directions until an event is available.
///
/// The behaviour depends on the current state:
///
/// - `Terminated`: reports end-of-stream (`Poll::Ready(None)`).
/// - `Connecting`: emits a pending disconnect notification, if there is
///   one; otherwise waits for the new connection and switches to
///   `Connected`.
/// - `Connected`: polls the connected substate and translates its events.
///
/// Returns `Poll::Ready(None)` when the stream is terminated or a local
/// shutdown was requested, `Poll::Ready(Some(event))` when an event has to
/// be handled by the caller, and `Poll::Pending` when whatever is being
/// polled is not ready yet.
///
/// # Panics
///
/// Panics if `NegotiationState::new` fails for a newly received connection.
fn poll_duplex(
    self: Pin<&mut Self>,
    transmit_queue: &mut TransmitQueue<QueueEntry>,
    cx: &mut Context<'_>,
) -> Poll<Option<WorkerEvent>> {
    let this = self.get_mut();
    loop {
        match this {
            // A terminated stream cannot make progress: end of stream.
            Self::Terminated => return Poll::Ready(None),

            // Waiting for a (re-)connection to be delivered.
            Self::Connecting {
                notify,
                slot,
                sm_state,
            } => {
                // A stored disconnect notification is emitted first; `take`
                // makes sure this happens only once.
                if let Some((slot, error)) = notify.take() {
                    return Poll::Ready(Some(WorkerEvent::Disconnected { slot, error }));
                }

                match ready!(Pin::new(slot).poll(cx)) {
                    Ok(Connection {
                        stream,
                        features,
                        identity,
                    }) => {
                        // The new connection starts out in negotiation,
                        // carrying over the stream management state (if any).
                        let substate = ConnectedState::Negotiating {
                            substate: NegotiationState::new(&features, sm_state.take())
                                .expect("Non-negotiable stream"),
                        };
                        // No return here: the loop continues and polls the
                        // new `Connected` state right away.
                        *this = Self::Connected {
                            substate,
                            stream,
                            features,
                            identity,
                        };
                    }
                    Err(_) => {
                        // The sending half was dropped without delivering a
                        // connection; there is no way to continue.
                        *this = Self::Terminated;
                        return Poll::Ready(Some(WorkerEvent::ReconnectAborted));
                    }
                }
            }

            Self::Connected {
                stream,
                identity,
                substate,
                features,
            } => {
                match ready!(substate.poll(
                    Pin::new(stream),
                    identity,
                    &features,
                    transmit_queue,
                    cx
                )) {
                    // Nothing to report; keep looping.
                    None => (),

                    // The substate produced an event for the worker.
                    Some(ConnectedEvent::Worker(v)) => {
                        match v {
                            // Remember the bound JID of a stream reset as
                            // the new identity of this stream.
                            WorkerEvent::Reset { ref bound_jid, .. } => {
                                *identity = bound_jid.clone();
                            }
                            _ => (),
                        }
                        return Poll::Ready(Some(v));
                    }

                    // The stream got disconnected: go back to `Connecting`
                    // and let the caller request a new connection.
                    Some(ConnectedEvent::Disconnect { sm_state, error }) => {
                        return Poll::Ready(Some(this.disconnect(sm_state, error)));
                    }

                    // The peer shut the stream down.
                    Some(ConnectedEvent::RemoteShutdown { sm_state }) => {
                        let error = io::Error::new(
                            io::ErrorKind::ConnectionAborted,
                            "peer closed the XML stream",
                        );
                        let (tx, rx) = oneshot::channel();
                        let mut new_state = Self::Connecting {
                            notify: None,
                            slot: rx,
                            sm_state,
                        };
                        // Swap the states so that the old stream can be
                        // moved out of the previous `Connected` state.
                        core::mem::swap(this, &mut new_state);
                        match new_state {
                            Self::Connected { stream, .. } => {
                                // The old stream is shut down in a separate
                                // task, bounded by REMOTE_SHUTDOWN_TIMEOUT.
                                tokio::spawn(shutdown_stream_by_remote_choice(
                                    stream,
                                    REMOTE_SHUTDOWN_TIMEOUT,
                                ));
                            }
                            // We are inside the `Connected` arm, so the
                            // swapped-out state must be `Connected`.
                            _ => unreachable!(),
                        }

                        return Poll::Ready(Some(WorkerEvent::Disconnected {
                            slot: tx,
                            error: Some(error),
                        }));
                    }

                    // End of stream is reported, but the state is not
                    // changed to `Terminated` here.
                    Some(ConnectedEvent::LocalShutdownRequested) => {
                        return Poll::Ready(None);
                    }
                }
            }
        }
    }
}
262
263 fn poll_writes(
277 &mut self,
278 transmit_queue: &mut TransmitQueue<QueueEntry>,
279 cx: &mut Context,
280 ) -> Poll<Never> {
281 match self {
282 Self::Terminated | Self::Connecting { .. } => Poll::Pending,
283 Self::Connected {
284 substate, stream, ..
285 } => {
286 ready!(substate.poll_writes(Pin::new(stream), transmit_queue, cx));
287 Poll::Pending
288 }
289 }
290 }
291
292 fn start_send_stream_error(&mut self, error: StreamError) {
293 match self {
294 Self::Terminated | Self::Connecting { .. } => {
297 *self = Self::Terminated;
298 }
299
300 Self::Connected { substate, .. } => substate.start_send_stream_error(error),
301 }
302 }
303
304 fn poll_close(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
305 match self {
306 Self::Terminated => Poll::Ready(Ok(())),
307 Self::Connecting { .. } => {
308 *self = Self::Terminated;
309 Poll::Ready(Ok(()))
310 }
311 Self::Connected {
312 substate, stream, ..
313 } => {
314 let result = ready!(substate.poll_close(Pin::new(stream), cx));
315 *self = Self::Terminated;
316 Poll::Ready(result)
317 }
318 }
319 }
320
321 fn drive_duplex<'a>(
322 &'a mut self,
323 transmit_queue: &'a mut TransmitQueue<QueueEntry>,
324 ) -> DriveDuplex<'a> {
325 DriveDuplex {
326 stream: Pin::new(self),
327 queue: transmit_queue,
328 }
329 }
330
331 fn drive_writes<'a>(
332 &'a mut self,
333 transmit_queue: &'a mut TransmitQueue<QueueEntry>,
334 ) -> DriveWrites<'a> {
335 DriveWrites {
336 stream: Pin::new(self),
337 queue: transmit_queue,
338 }
339 }
340
341 fn close(&mut self) -> Close {
342 Close {
343 stream: Pin::new(self),
344 }
345 }
346
347 fn queue_sm_request(&mut self) -> bool {
355 match self {
356 Self::Terminated | Self::Connecting { .. } => false,
357 Self::Connected { substate, .. } => substate.queue_sm_request(),
358 }
359 }
360}
361
/// Future returned by `WorkerStream::drive_duplex`.
struct DriveDuplex<'x> {
    /// The worker stream to poll.
    stream: Pin<&'x mut WorkerStream>,
    /// The transmit queue handed to each poll of the stream.
    queue: &'x mut TransmitQueue<QueueEntry>,
}
366
367impl<'x> Future for DriveDuplex<'x> {
368 type Output = Option<WorkerEvent>;
369
370 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
371 let this = self.get_mut();
372 this.stream.as_mut().poll_duplex(this.queue, cx)
373 }
374}
375
/// Future returned by `WorkerStream::drive_writes`; it never completes.
struct DriveWrites<'x> {
    /// The worker stream to poll.
    stream: Pin<&'x mut WorkerStream>,
    /// The transmit queue handed to each poll of the stream.
    queue: &'x mut TransmitQueue<QueueEntry>,
}
380
381impl<'x> Future for DriveWrites<'x> {
382 type Output = Never;
383
384 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
385 let this = self.get_mut();
386 this.stream.as_mut().poll_writes(this.queue, cx)
387 }
388}
389
/// Future returned by `WorkerStream::close`.
struct Close<'x> {
    /// The worker stream to close.
    stream: Pin<&'x mut WorkerStream>,
}
393
394impl<'x> Future for Close<'x> {
395 type Output = io::Result<()>;
396
397 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
398 let this = self.get_mut();
399 this.stream.as_mut().poll_close(cx)
400 }
401}
402
403pub(super) fn parse_error_to_stream_error(e: xso::error::Error) -> StreamError {
404 use xso::error::Error;
405 let condition = match e {
406 Error::XmlError(_) => DefinedCondition::NotWellFormed,
407 Error::TextParseError(_) | Error::Other(_) => DefinedCondition::InvalidXml,
408 Error::TypeMismatch => DefinedCondition::UnsupportedStanzaType,
409 };
410 StreamError {
411 condition,
412 text: Some((None, e.to_string())),
413 application_specific: vec![],
414 }
415}
416
/// The background worker which owns the stream and connects it to the
/// frontend.
pub(super) struct StanzaStreamWorker {
    /// Callback used to request a (new) connection. The connection has to
    /// be delivered through the given sender. Within this file, the first
    /// argument is always `None`.
    reconnector: Box<dyn FnMut(Option<String>, oneshot::Sender<Connection>) + Send + 'static>,
    /// Channel on which events are sent to the frontend.
    frontend_tx: mpsc::Sender<Event>,
    /// The connection state machine.
    stream: WorkerStream,
    /// Queue of entries to be transmitted on the stream.
    transmit_queue: TransmitQueue<QueueEntry>,
}
424
// Send a value to the frontend, preferring an already reserved permit.
//
// Usage: `send_or_break!(value => permit in channel, txq => stream)`
//
// - `permit` names an `Option` holding a send permit. If it is `Some`, the
//   permit is taken and used to send `value`.
// - Otherwise `value` is sent on `channel` directly, while `stream` keeps
//   handling writes from `txq` concurrently.
//
// If sending on the channel fails, the macro executes `break`; it must
// therefore be used inside a loop.
macro_rules! send_or_break {
    ($value:expr => $permit:ident in $ch:expr, $txq:expr => $stream:expr$(,)?) => {
        if let Some(permit) = $permit.take() {
            log::trace!("stanza received, passing to frontend via permit");
            permit.send($value);
        } else {
            log::trace!("no permit for received stanza available, blocking on channel send while handling writes");
            tokio::select! {
                // drive_writes yields the uninhabited `Never` type, hence
                // the empty match: this branch can never actually complete.
                result = $stream.drive_writes(&mut $txq) => { match result {} },
                result = $ch.send($value) => match result {
                    // The send failed; leave the surrounding loop.
                    Err(_) => break,
                    Ok(()) => (),
                },
            }
        }
    };
}
445
446impl StanzaStreamWorker {
447 pub fn spawn(
448 mut reconnector: Box<
449 dyn FnMut(Option<String>, oneshot::Sender<Connection>) + Send + 'static,
450 >,
451 queue_depth: usize,
452 ) -> (mpsc::Sender<QueueEntry>, mpsc::Receiver<Event>) {
453 let (conn_tx, conn_rx) = oneshot::channel();
454 reconnector(None, conn_tx);
455 let (c2f_tx, c2f_rx) = mpsc::channel(queue_depth);
457 let (f2c_tx, transmit_queue) = TransmitQueue::channel(queue_depth);
459 let mut worker = StanzaStreamWorker {
460 reconnector,
461 frontend_tx: c2f_tx,
462 stream: WorkerStream::Connecting {
463 slot: conn_rx,
464 sm_state: None,
465 notify: None,
466 },
467 transmit_queue,
468 };
469 tokio::spawn(async move { worker.run().await });
470 (f2c_tx, c2f_rx)
471 }
472
/// Run the worker until the stream ends or the frontend goes away.
///
/// In a loop, this reserves send permits on the frontend channel and
/// drives the stream, handling each `WorkerEvent` it produces. After the
/// loop has been left, the stream is closed and the outcome is logged.
///
/// # Panics
///
/// Panics when the stream reports `WorkerEvent::ReconnectAborted`.
pub async fn run(&mut self) {
    // Counter appended to the id of liveness probe pings. It starts at a
    // random value.
    let mut ping_probe_ctr: u64 = thread_rng().gen();

    // A reserved slot on the frontend channel, if we currently hold one.
    // Holding a permit allows `send_or_break!` to deliver an event
    // without having to wait on the channel.
    let mut permit = None;
    loop {
        tokio::select! {
            // Only try to reserve a permit if we do not hold one already
            // and the frontend channel is still open.
            new_permit = self.frontend_tx.reserve(), if permit.is_none() && !self.frontend_tx.is_closed() => match new_permit {
                Ok(new_permit) => permit = Some(new_permit),
                // Reserving failed; stop the worker.
                Err(_) => break,
            },
            ev = self.stream.drive_duplex(&mut self.transmit_queue) => {
                // `None` signals the end of the stream.
                let Some(ev) = ev else {
                    break;
                };
                match ev {
                    WorkerEvent::Reset { bound_jid, features } => send_or_break!(
                        Event::Stream(StreamEvent::Reset { bound_jid, features }) => permit in self.frontend_tx,
                        self.transmit_queue => self.stream,
                    ),
                    WorkerEvent::Disconnected { slot, error } => {
                        // Tell the frontend first, then request a new
                        // connection.
                        send_or_break!(
                            Event::Stream(StreamEvent::Suspended) => permit in self.frontend_tx,
                            self.transmit_queue => self.stream,
                        );
                        if let Some(error) = error {
                            log::debug!("Backend stream got disconnected because of an I/O error: {error}. Attempting reconnect.");
                        } else {
                            log::debug!("Backend stream got disconnected for an unknown reason. Attempting reconnect.");
                        }
                        // Do not reconnect if either channel to the
                        // frontend has been closed.
                        if self.frontend_tx.is_closed() || self.transmit_queue.is_closed() {
                            log::debug!("Immediately aborting reconnect because the frontend is gone.");
                            break;
                        }
                        (self.reconnector)(None, slot);
                    }
                    WorkerEvent::Resumed => send_or_break!(
                        Event::Stream(StreamEvent::Resumed) => permit in self.frontend_tx,
                        self.transmit_queue => self.stream,
                    ),
                    WorkerEvent::Stanza(stanza) => send_or_break!(
                        Event::Stanza(stanza) => permit in self.frontend_tx,
                        self.transmit_queue => self.stream,
                    ),
                    WorkerEvent::ParseError(e) => {
                        log::error!("Parse error on stream: {e}");
                        // Only starts the send; the loop keeps running
                        // afterwards.
                        self.stream.start_send_stream_error(parse_error_to_stream_error(e));
                    }
                    WorkerEvent::SoftTimeout => {
                        if self.stream.queue_sm_request() {
                            log::debug!("SoftTimeout tripped: enqueued <sm:r/>");
                        } else {
                            log::debug!("SoftTimeout tripped. Stream Management is not enabled, enqueueing ping IQ");
                            // Advance the counter (wrapping around on
                            // overflow) so that this probe gets a new id.
                            ping_probe_ctr = ping_probe_ctr.wrapping_add(1);
                            self.transmit_queue.enqueue(QueueEntry::untracked(Box::new(iq::Iq::from_get(
                                format!("{}-{}", PING_PROBE_ID_PREFIX, ping_probe_ctr),
                                ping::Ping,
                            ).into())));
                        }
                    }
                    WorkerEvent::ReconnectAborted => {
                        panic!("Backend was unable to handle reconnect request.");
                    }
                }
            },
        }
    }
    // The loop has ended: close the stream and log how that went.
    match self.stream.close().await {
        Ok(()) => log::debug!("Stream closed successfully"),
        Err(e) => log::debug!("Stream closure failed: {e}"),
    }
}
570}
571
572async fn shutdown_stream_by_remote_choice(mut stream: XmppStream, timeout: Duration) {
573 let deadline = Instant::now() + timeout;
574 match tokio::time::timeout_at(
575 deadline,
576 <XmppStream as SinkExt<&Stanza>>::close(&mut stream),
577 )
578 .await
579 {
580 Ok(_) => (),
582 Err(_) => {
584 log::debug!("Giving up on clean stream shutdown after timeout elapsed.");
585 return;
586 }
587 }
588 let timeout = tokio::time::sleep_until(deadline);
589 tokio::pin!(timeout);
590 loop {
591 tokio::select! {
592 _ = &mut timeout => {
593 log::debug!("Giving up on clean stream shutdown after timeout elapsed.");
594 break;
595 }
596 ev = stream.next() => match ev {
597 None => break,
598 Some(Ok(data)) => {
599 log::debug!("Ignoring data on stream during shutdown: {data:?}");
600 break;
601 }
602 Some(Err(ReadError::HardError(e))) => {
603 log::debug!("Ignoring stream I/O error during shutdown: {e}");
604 break;
605 }
606 Some(Err(ReadError::SoftTimeout)) => (),
607 Some(Err(ReadError::ParseError(_))) => (),
608 Some(Err(ReadError::StreamFooterReceived)) => (),
609 }
610 }
611 }
612}