1use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll};
10use core::time::Duration;
11use std::io;
12
13use rand::{thread_rng, Rng};
14
15use futures::{ready, SinkExt, StreamExt};
16
17use tokio::{
18 sync::{mpsc, oneshot},
19 time::Instant,
20};
21
22use xmpp_parsers::{
23 iq,
24 jid::Jid,
25 ping,
26 stream_error::{DefinedCondition, StreamError},
27 stream_features::StreamFeatures,
28};
29
30use crate::connect::AsyncReadAndWrite;
31use crate::xmlstream::{ReadError, XmppStreamElement};
32use crate::Stanza;
33
34use super::connected::{ConnectedEvent, ConnectedState};
35use super::negotiation::NegotiationState;
36use super::queue::{QueueEntry, TransmitQueue};
37use super::stream_management::SmState;
38use super::{Event, StreamEvent};
39
40pub type XmppStream =
43 crate::xmlstream::XmlStream<Box<dyn AsyncReadAndWrite + Send + 'static>, XmppStreamElement>;
44
45pub struct Connection {
47 pub stream: XmppStream,
49
50 pub features: StreamFeatures,
52
53 pub identity: Jid,
60}
61
62pub(super) static LOCAL_SHUTDOWN_TIMEOUT: Duration = Duration::new(10, 0);
65pub(super) static REMOTE_SHUTDOWN_TIMEOUT: Duration = Duration::new(5, 0);
66pub(super) static PING_PROBE_ID_PREFIX: &str = "xmpp-rs-stanzastream-liveness-probe";
67
68pub(super) enum Never {}
69
70pub(super) enum WorkerEvent {
71 Reset {
73 bound_jid: Jid,
74 features: StreamFeatures,
75 },
76
77 Resumed,
79
80 Stanza(Stanza),
82
83 ParseError(xso::error::Error),
85
86 SoftTimeout,
88
89 Disconnected {
91 slot: oneshot::Sender<Connection>,
93
94 error: Option<io::Error>,
96 },
97
98 ReconnectAborted,
100}
101
102enum WorkerStream {
103 Connecting {
105 notify: Option<(oneshot::Sender<Connection>, Option<io::Error>)>,
107
108 slot: oneshot::Receiver<Connection>,
110
111 sm_state: Option<SmState>,
113 },
114
115 Connected {
117 stream: XmppStream,
118 substate: ConnectedState,
119 features: StreamFeatures,
120 identity: Jid,
121 },
122
123 Terminated,
125}
126
127impl WorkerStream {
128 fn disconnect(&mut self, sm_state: Option<SmState>, error: Option<io::Error>) -> WorkerEvent {
129 let (tx, rx) = oneshot::channel();
130 *self = Self::Connecting {
131 notify: None,
132 slot: rx,
133 sm_state,
134 };
135 WorkerEvent::Disconnected { slot: tx, error }
136 }
137
138 fn poll_duplex(
139 self: Pin<&mut Self>,
140 transmit_queue: &mut TransmitQueue<QueueEntry>,
141 cx: &mut Context<'_>,
142 ) -> Poll<Option<WorkerEvent>> {
143 let this = self.get_mut();
144 loop {
145 match this {
146 Self::Terminated => return Poll::Ready(None),
149
150 Self::Connecting {
153 notify,
154 slot,
155 sm_state,
156 } => {
157 if let Some((slot, error)) = notify.take() {
158 return Poll::Ready(Some(WorkerEvent::Disconnected { slot, error }));
159 }
160
161 match ready!(Pin::new(slot).poll(cx)) {
162 Ok(Connection {
163 stream,
164 features,
165 identity,
166 }) => {
167 let substate = ConnectedState::Negotiating {
168 substate: NegotiationState::new(&features, sm_state.take())
173 .expect("Non-negotiable stream"),
174 };
175 *this = Self::Connected {
176 substate,
177 stream,
178 features,
179 identity,
180 };
181 }
182 Err(_) => {
183 *this = Self::Terminated;
185 return Poll::Ready(Some(WorkerEvent::ReconnectAborted));
186 }
187 }
188 }
189
190 Self::Connected {
191 stream,
192 identity,
193 substate,
194 features,
195 } => {
196 match ready!(substate.poll(
197 Pin::new(stream),
198 identity,
199 features,
200 transmit_queue,
201 cx
202 )) {
203 None => (),
205
206 Some(ConnectedEvent::Worker(v)) => {
208 if let WorkerEvent::Reset { ref bound_jid, .. } = v {
210 *identity = bound_jid.clone();
211 }
212 return Poll::Ready(Some(v));
213 }
214
215 Some(ConnectedEvent::Disconnect { sm_state, error }) => {
217 return Poll::Ready(Some(this.disconnect(sm_state, error)));
218 }
219
220 Some(ConnectedEvent::RemoteShutdown { sm_state }) => {
221 let error = io::Error::new(
222 io::ErrorKind::ConnectionAborted,
223 "peer closed the XML stream",
224 );
225 let (tx, rx) = oneshot::channel();
226 let mut new_state = Self::Connecting {
227 notify: None,
228 slot: rx,
229 sm_state,
230 };
231 core::mem::swap(this, &mut new_state);
232 match new_state {
233 Self::Connected { stream, .. } => {
234 tokio::spawn(shutdown_stream_by_remote_choice(
235 stream,
236 REMOTE_SHUTDOWN_TIMEOUT,
237 ));
238 }
239 _ => unreachable!(),
240 }
241
242 return Poll::Ready(Some(WorkerEvent::Disconnected {
243 slot: tx,
244 error: Some(error),
245 }));
246 }
247
248 Some(ConnectedEvent::LocalShutdownRequested) => {
249 return Poll::Ready(None);
252 }
253 }
254 }
255 }
256 }
257 }
258
259 fn poll_writes(
273 &mut self,
274 transmit_queue: &mut TransmitQueue<QueueEntry>,
275 cx: &mut Context,
276 ) -> Poll<Never> {
277 match self {
278 Self::Terminated | Self::Connecting { .. } => Poll::Pending,
279 Self::Connected {
280 substate, stream, ..
281 } => {
282 ready!(substate.poll_writes(Pin::new(stream), transmit_queue, cx));
283 Poll::Pending
284 }
285 }
286 }
287
288 fn start_send_stream_error(&mut self, error: StreamError) {
289 match self {
290 Self::Terminated | Self::Connecting { .. } => {
293 *self = Self::Terminated;
294 }
295
296 Self::Connected { substate, .. } => substate.start_send_stream_error(error),
297 }
298 }
299
300 fn poll_close(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
301 match self {
302 Self::Terminated => Poll::Ready(Ok(())),
303 Self::Connecting { .. } => {
304 *self = Self::Terminated;
305 Poll::Ready(Ok(()))
306 }
307 Self::Connected {
308 substate, stream, ..
309 } => {
310 let result = ready!(substate.poll_close(Pin::new(stream), cx));
311 *self = Self::Terminated;
312 Poll::Ready(result)
313 }
314 }
315 }
316
317 fn drive_duplex<'a>(
318 &'a mut self,
319 transmit_queue: &'a mut TransmitQueue<QueueEntry>,
320 ) -> DriveDuplex<'a> {
321 DriveDuplex {
322 stream: Pin::new(self),
323 queue: transmit_queue,
324 }
325 }
326
327 fn drive_writes<'a>(
328 &'a mut self,
329 transmit_queue: &'a mut TransmitQueue<QueueEntry>,
330 ) -> DriveWrites<'a> {
331 DriveWrites {
332 stream: Pin::new(self),
333 queue: transmit_queue,
334 }
335 }
336
337 fn close(&mut self) -> Close {
338 Close {
339 stream: Pin::new(self),
340 }
341 }
342
343 fn queue_sm_request(&mut self) -> bool {
351 match self {
352 Self::Terminated | Self::Connecting { .. } => false,
353 Self::Connected { substate, .. } => substate.queue_sm_request(),
354 }
355 }
356}
357
358struct DriveDuplex<'x> {
359 stream: Pin<&'x mut WorkerStream>,
360 queue: &'x mut TransmitQueue<QueueEntry>,
361}
362
363impl Future for DriveDuplex<'_> {
364 type Output = Option<WorkerEvent>;
365
366 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
367 let this = self.get_mut();
368 this.stream.as_mut().poll_duplex(this.queue, cx)
369 }
370}
371
372struct DriveWrites<'x> {
373 stream: Pin<&'x mut WorkerStream>,
374 queue: &'x mut TransmitQueue<QueueEntry>,
375}
376
377impl Future for DriveWrites<'_> {
378 type Output = Never;
379
380 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
381 let this = self.get_mut();
382 this.stream.as_mut().poll_writes(this.queue, cx)
383 }
384}
385
386struct Close<'x> {
387 stream: Pin<&'x mut WorkerStream>,
388}
389
390impl Future for Close<'_> {
391 type Output = io::Result<()>;
392
393 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
394 let this = self.get_mut();
395 this.stream.as_mut().poll_close(cx)
396 }
397}
398
399pub(super) fn parse_error_to_stream_error(e: xso::error::Error) -> StreamError {
400 use xso::error::Error;
401 let condition = match e {
402 Error::XmlError(_) => DefinedCondition::NotWellFormed,
403 Error::TextParseError(_) | Error::Other(_) => DefinedCondition::InvalidXml,
404 Error::TypeMismatch => DefinedCondition::UnsupportedStanzaType,
405 };
406 StreamError {
407 condition,
408 text: Some((None, e.to_string())),
409 application_specific: vec![],
410 }
411}
412
413pub(super) struct StanzaStreamWorker {
415 reconnector: Box<dyn FnMut(Option<String>, oneshot::Sender<Connection>) + Send + 'static>,
416 frontend_tx: mpsc::Sender<Event>,
417 stream: WorkerStream,
418 transmit_queue: TransmitQueue<QueueEntry>,
419}
420
/// Deliver `$value` to the frontend, preferring a pre-reserved permit.
///
/// - If `$permit` (an `Option` holding a permit for `$ch`) contains a
///   permit, the permit is consumed and the value is sent without waiting.
/// - Otherwise the value is sent through `$ch`, and `$stream`'s writes from
///   `$txq` are driven while waiting for channel capacity.
///
/// If that send fails (the channel is closed), the expansion executes
/// `break` on the enclosing loop.
macro_rules! send_or_break {
    ($value:expr => $permit:ident in $ch:expr, $txq:expr => $stream:expr$(,)?) => {
        match $permit.take() {
            Some(permit) => {
                log::trace!("stanza received, passing to frontend via permit");
                permit.send($value);
            }
            None => {
                log::trace!("no permit for received stanza available, blocking on channel send while handling writes");
                tokio::select! {
                    // The write driver never completes; its output is uninhabited.
                    never = $stream.drive_writes(&mut $txq) => match never {},
                    outcome = $ch.send($value) => {
                        if outcome.is_err() {
                            break;
                        }
                    }
                }
            }
        }
    };
}
441
442impl StanzaStreamWorker {
443 pub fn spawn(
444 mut reconnector: Box<
445 dyn FnMut(Option<String>, oneshot::Sender<Connection>) + Send + 'static,
446 >,
447 queue_depth: usize,
448 ) -> (mpsc::Sender<QueueEntry>, mpsc::Receiver<Event>) {
449 let (conn_tx, conn_rx) = oneshot::channel();
450 reconnector(None, conn_tx);
451 let (c2f_tx, c2f_rx) = mpsc::channel(queue_depth);
453 let (f2c_tx, transmit_queue) = TransmitQueue::channel(queue_depth);
455 let mut worker = StanzaStreamWorker {
456 reconnector,
457 frontend_tx: c2f_tx,
458 stream: WorkerStream::Connecting {
459 slot: conn_rx,
460 sm_state: None,
461 notify: None,
462 },
463 transmit_queue,
464 };
465 tokio::spawn(async move { worker.run().await });
466 (f2c_tx, c2f_rx)
467 }
468
469 pub async fn run(&mut self) {
470 let mut ping_probe_ctr: u64 = thread_rng().gen();
476
477 let mut permit = None;
489 loop {
490 tokio::select! {
491 new_permit = self.frontend_tx.reserve(), if permit.is_none() && !self.frontend_tx.is_closed() => match new_permit {
492 Ok(new_permit) => permit = Some(new_permit),
493 Err(_) => break,
496 },
497 ev = self.stream.drive_duplex(&mut self.transmit_queue) => {
498 let Some(ev) = ev else {
499 break;
501 };
502 match ev {
503 WorkerEvent::Reset { bound_jid, features } => send_or_break!(
504 Event::Stream(StreamEvent::Reset { bound_jid, features }) => permit in self.frontend_tx,
505 self.transmit_queue => self.stream,
506 ),
507 WorkerEvent::Disconnected { slot, error } => {
508 send_or_break!(
509 Event::Stream(StreamEvent::Suspended) => permit in self.frontend_tx,
510 self.transmit_queue => self.stream,
511 );
512 if let Some(error) = error {
513 log::debug!("Backend stream got disconnected because of an I/O error: {error}. Attempting reconnect.");
514 } else {
515 log::debug!("Backend stream got disconnected for an unknown reason. Attempting reconnect.");
516 }
517 if self.frontend_tx.is_closed() || self.transmit_queue.is_closed() {
518 log::debug!("Immediately aborting reconnect because the frontend is gone.");
519 break;
520 }
521 (self.reconnector)(None, slot);
522 }
523 WorkerEvent::Resumed => send_or_break!(
524 Event::Stream(StreamEvent::Resumed) => permit in self.frontend_tx,
525 self.transmit_queue => self.stream,
526 ),
527 WorkerEvent::Stanza(stanza) => send_or_break!(
528 Event::Stanza(stanza) => permit in self.frontend_tx,
529 self.transmit_queue => self.stream,
530 ),
531 WorkerEvent::ParseError(e) => {
532 log::error!("Parse error on stream: {e}");
533 self.stream.start_send_stream_error(parse_error_to_stream_error(e));
534 }
537 WorkerEvent::SoftTimeout => {
538 if self.stream.queue_sm_request() {
539 log::debug!("SoftTimeout tripped: enqueued <sm:r/>");
540 } else {
541 log::debug!("SoftTimeout tripped. Stream Management is not enabled, enqueueing ping IQ");
542 ping_probe_ctr = ping_probe_ctr.wrapping_add(1);
543 self.transmit_queue.enqueue(QueueEntry::untracked(Box::new(iq::Iq::from_get(
549 format!("{}-{}", PING_PROBE_ID_PREFIX, ping_probe_ctr),
550 ping::Ping,
551 ).into())));
552 }
553 }
554 WorkerEvent::ReconnectAborted => {
555 panic!("Backend was unable to handle reconnect request.");
556 }
557 }
558 },
559 }
560 }
561 match self.stream.close().await {
562 Ok(()) => log::debug!("Stream closed successfully"),
563 Err(e) => log::debug!("Stream closure failed: {e}"),
564 }
565 }
566}
567
568async fn shutdown_stream_by_remote_choice(mut stream: XmppStream, timeout: Duration) {
569 let deadline = Instant::now() + timeout;
570 match tokio::time::timeout_at(
571 deadline,
572 <XmppStream as SinkExt<&Stanza>>::close(&mut stream),
573 )
574 .await
575 {
576 Ok(_) => (),
578 Err(_) => {
580 log::debug!("Giving up on clean stream shutdown after timeout elapsed.");
581 return;
582 }
583 }
584 let timeout = tokio::time::sleep_until(deadline);
585 tokio::pin!(timeout);
586 loop {
587 tokio::select! {
588 _ = &mut timeout => {
589 log::debug!("Giving up on clean stream shutdown after timeout elapsed.");
590 break;
591 }
592 ev = stream.next() => match ev {
593 None => break,
594 Some(Ok(data)) => {
595 log::debug!("Ignoring data on stream during shutdown: {data:?}");
596 break;
597 }
598 Some(Err(ReadError::HardError(e))) => {
599 log::debug!("Ignoring stream I/O error during shutdown: {e}");
600 break;
601 }
602 Some(Err(ReadError::SoftTimeout)) => (),
603 Some(Err(ReadError::ParseError(_))) => (),
604 Some(Err(ReadError::StreamFooterReceived)) => (),
605 }
606 }
607 }
608}