1use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll};
10use core::time::Duration;
11use std::io;
12
13use futures::{ready, SinkExt, StreamExt};
14
15use tokio::{
16 sync::{mpsc, oneshot},
17 time::Instant,
18};
19
20use xmpp_parsers::{
21 iq,
22 jid::Jid,
23 ping,
24 stream_error::{DefinedCondition, StreamError},
25 stream_features::StreamFeatures,
26};
27
28use crate::connect::AsyncReadAndWrite;
29use crate::xmlstream::{ReadError, XmppStreamElement};
30use crate::Stanza;
31
32use super::connected::{ConnectedEvent, ConnectedState};
33use super::negotiation::NegotiationState;
34use super::queue::{QueueEntry, TransmitQueue};
35use super::stream_management::SmState;
36use super::{Event, StreamEvent};
37
38pub type XmppStream =
41 crate::xmlstream::XmlStream<Box<dyn AsyncReadAndWrite + Send + 'static>, XmppStreamElement>;
42
43pub struct Connection {
45 pub stream: XmppStream,
47
48 pub features: StreamFeatures,
50
51 pub identity: Jid,
58}
59
60pub(super) static LOCAL_SHUTDOWN_TIMEOUT: Duration = Duration::new(10, 0);
63pub(super) static REMOTE_SHUTDOWN_TIMEOUT: Duration = Duration::new(5, 0);
64pub(super) static PING_PROBE_ID_PREFIX: &str = "xmpp-rs-stanzastream-liveness-probe";
65
66pub(super) enum Never {}
67
68pub(super) enum WorkerEvent {
69 Reset {
71 bound_jid: Jid,
72 features: StreamFeatures,
73 },
74
75 Resumed,
77
78 Stanza(Stanza),
80
81 ParseError(xso::error::Error),
83
84 SoftTimeout,
86
87 Disconnected {
89 slot: oneshot::Sender<Connection>,
91
92 error: Option<io::Error>,
94 },
95
96 ReconnectAborted,
98}
99
100enum WorkerStream {
101 Connecting {
103 notify: Option<(oneshot::Sender<Connection>, Option<io::Error>)>,
105
106 slot: oneshot::Receiver<Connection>,
108
109 sm_state: Option<SmState>,
111 },
112
113 Connected {
115 stream: XmppStream,
116 substate: ConnectedState,
117 features: StreamFeatures,
118 identity: Jid,
119 },
120
121 Terminated,
123}
124
125impl WorkerStream {
126 fn disconnect(&mut self, sm_state: Option<SmState>, error: Option<io::Error>) -> WorkerEvent {
127 let (tx, rx) = oneshot::channel();
128 *self = Self::Connecting {
129 notify: None,
130 slot: rx,
131 sm_state,
132 };
133 WorkerEvent::Disconnected { slot: tx, error }
134 }
135
136 fn poll_duplex(
137 self: Pin<&mut Self>,
138 transmit_queue: &mut TransmitQueue<QueueEntry>,
139 cx: &mut Context<'_>,
140 ) -> Poll<Option<WorkerEvent>> {
141 let this = self.get_mut();
142 loop {
143 match this {
144 Self::Terminated => return Poll::Ready(None),
147
148 Self::Connecting {
151 notify,
152 slot,
153 sm_state,
154 } => {
155 if let Some((slot, error)) = notify.take() {
156 return Poll::Ready(Some(WorkerEvent::Disconnected { slot, error }));
157 }
158
159 match ready!(Pin::new(slot).poll(cx)) {
160 Ok(Connection {
161 stream,
162 features,
163 identity,
164 }) => {
165 let substate = ConnectedState::Negotiating {
166 substate: NegotiationState::new(&features, sm_state.take())
171 .expect("Non-negotiable stream"),
172 };
173 *this = Self::Connected {
174 substate,
175 stream,
176 features,
177 identity,
178 };
179 }
180 Err(_) => {
181 *this = Self::Terminated;
183 return Poll::Ready(Some(WorkerEvent::ReconnectAborted));
184 }
185 }
186 }
187
188 Self::Connected {
189 stream,
190 identity,
191 substate,
192 features,
193 } => {
194 match ready!(substate.poll(
195 Pin::new(stream),
196 identity,
197 features,
198 transmit_queue,
199 cx
200 )) {
201 None => (),
203
204 Some(ConnectedEvent::Worker(v)) => {
206 if let WorkerEvent::Reset { ref bound_jid, .. } = v {
208 *identity = bound_jid.clone();
209 }
210 return Poll::Ready(Some(v));
211 }
212
213 Some(ConnectedEvent::Disconnect { sm_state, error }) => {
215 return Poll::Ready(Some(this.disconnect(sm_state, error)));
216 }
217
218 Some(ConnectedEvent::RemoteShutdown { sm_state }) => {
219 let error = io::Error::new(
220 io::ErrorKind::ConnectionAborted,
221 "peer closed the XML stream",
222 );
223 let (tx, rx) = oneshot::channel();
224 let mut new_state = Self::Connecting {
225 notify: None,
226 slot: rx,
227 sm_state,
228 };
229 core::mem::swap(this, &mut new_state);
230 match new_state {
231 Self::Connected { stream, .. } => {
232 tokio::spawn(shutdown_stream_by_remote_choice(
233 stream,
234 REMOTE_SHUTDOWN_TIMEOUT,
235 ));
236 }
237 _ => unreachable!(),
238 }
239
240 return Poll::Ready(Some(WorkerEvent::Disconnected {
241 slot: tx,
242 error: Some(error),
243 }));
244 }
245
246 Some(ConnectedEvent::LocalShutdownRequested) => {
247 return Poll::Ready(None);
250 }
251 }
252 }
253 }
254 }
255 }
256
257 fn poll_writes(
271 &mut self,
272 transmit_queue: &mut TransmitQueue<QueueEntry>,
273 cx: &mut Context,
274 ) -> Poll<Never> {
275 match self {
276 Self::Terminated | Self::Connecting { .. } => Poll::Pending,
277 Self::Connected {
278 substate, stream, ..
279 } => {
280 ready!(substate.poll_writes(Pin::new(stream), transmit_queue, cx));
281 Poll::Pending
282 }
283 }
284 }
285
286 fn start_send_stream_error(&mut self, error: StreamError) {
287 match self {
288 Self::Terminated | Self::Connecting { .. } => {
291 *self = Self::Terminated;
292 }
293
294 Self::Connected { substate, .. } => substate.start_send_stream_error(error),
295 }
296 }
297
298 fn poll_close(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
299 match self {
300 Self::Terminated => Poll::Ready(Ok(())),
301 Self::Connecting { .. } => {
302 *self = Self::Terminated;
303 Poll::Ready(Ok(()))
304 }
305 Self::Connected {
306 substate, stream, ..
307 } => {
308 let result = ready!(substate.poll_close(Pin::new(stream), cx));
309 *self = Self::Terminated;
310 Poll::Ready(result)
311 }
312 }
313 }
314
315 fn drive_duplex<'a>(
316 &'a mut self,
317 transmit_queue: &'a mut TransmitQueue<QueueEntry>,
318 ) -> DriveDuplex<'a> {
319 DriveDuplex {
320 stream: Pin::new(self),
321 queue: transmit_queue,
322 }
323 }
324
325 fn drive_writes<'a>(
326 &'a mut self,
327 transmit_queue: &'a mut TransmitQueue<QueueEntry>,
328 ) -> DriveWrites<'a> {
329 DriveWrites {
330 stream: Pin::new(self),
331 queue: transmit_queue,
332 }
333 }
334
335 fn close(&mut self) -> Close<'_> {
336 Close {
337 stream: Pin::new(self),
338 }
339 }
340
341 fn queue_sm_request(&mut self) -> bool {
349 match self {
350 Self::Terminated | Self::Connecting { .. } => false,
351 Self::Connected { substate, .. } => substate.queue_sm_request(),
352 }
353 }
354}
355
356struct DriveDuplex<'x> {
357 stream: Pin<&'x mut WorkerStream>,
358 queue: &'x mut TransmitQueue<QueueEntry>,
359}
360
361impl Future for DriveDuplex<'_> {
362 type Output = Option<WorkerEvent>;
363
364 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
365 let this = self.get_mut();
366 this.stream.as_mut().poll_duplex(this.queue, cx)
367 }
368}
369
370struct DriveWrites<'x> {
371 stream: Pin<&'x mut WorkerStream>,
372 queue: &'x mut TransmitQueue<QueueEntry>,
373}
374
375impl Future for DriveWrites<'_> {
376 type Output = Never;
377
378 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
379 let this = self.get_mut();
380 this.stream.as_mut().poll_writes(this.queue, cx)
381 }
382}
383
384struct Close<'x> {
385 stream: Pin<&'x mut WorkerStream>,
386}
387
388impl Future for Close<'_> {
389 type Output = io::Result<()>;
390
391 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
392 let this = self.get_mut();
393 this.stream.as_mut().poll_close(cx)
394 }
395}
396
397pub(super) fn parse_error_to_stream_error(e: xso::error::Error) -> StreamError {
398 use xso::error::Error;
399 let condition = match e {
400 Error::XmlError(_) => DefinedCondition::NotWellFormed,
401 Error::TextParseError(_) | Error::Other(_) => DefinedCondition::InvalidXml,
402 Error::TypeMismatch => DefinedCondition::UnsupportedStanzaType,
403 };
404 StreamError::new(condition, "en", e.to_string())
405}
406
407pub(super) struct StanzaStreamWorker {
409 reconnector: Box<dyn FnMut(Option<String>, oneshot::Sender<Connection>) + Send + 'static>,
410 frontend_tx: mpsc::Sender<Event>,
411 stream: WorkerStream,
412 transmit_queue: TransmitQueue<QueueEntry>,
413}
414
/// Deliver `$value` to the frontend, or `break` out of the enclosing loop if
/// the frontend channel is closed.
///
/// If `$permit` (an `Option` holding a reserved channel slot) is set, it is
/// consumed and the value is sent through it without waiting. Otherwise the
/// value is sent on `$ch`; while that send is pending, the write path of
/// `$stream` keeps being driven with `$txq`.
macro_rules! send_or_break {
    ($value:expr => $permit:ident in $ch:expr, $txq:expr => $stream:expr$(,)?) => {
        match $permit.take() {
            Some(reserved) => {
                log::trace!("stanza received, passing to frontend via permit");
                reserved.send($value);
            }
            None => {
                log::trace!("no permit for received stanza available, blocking on channel send while handling writes");
                tokio::select! {
                    // `drive_writes` yields `Never`, so this arm cannot
                    // actually be taken; the empty match proves it.
                    never = $stream.drive_writes(&mut $txq) => { match never {} },
                    outcome = $ch.send($value) => {
                        if outcome.is_err() {
                            break;
                        }
                    },
                }
            }
        }
    };
}
435
436impl StanzaStreamWorker {
437 pub fn spawn(
438 mut reconnector: Box<
439 dyn FnMut(Option<String>, oneshot::Sender<Connection>) + Send + 'static,
440 >,
441 queue_depth: usize,
442 ) -> (mpsc::Sender<QueueEntry>, mpsc::Receiver<Event>) {
443 let (conn_tx, conn_rx) = oneshot::channel();
444 reconnector(None, conn_tx);
445 let (c2f_tx, c2f_rx) = mpsc::channel(queue_depth);
447 let (f2c_tx, transmit_queue) = TransmitQueue::channel(queue_depth);
449 let mut worker = StanzaStreamWorker {
450 reconnector,
451 frontend_tx: c2f_tx,
452 stream: WorkerStream::Connecting {
453 slot: conn_rx,
454 sm_state: None,
455 notify: None,
456 },
457 transmit_queue,
458 };
459 tokio::spawn(async move { worker.run().await });
460 (f2c_tx, c2f_rx)
461 }
462
463 pub async fn run(&mut self) {
464 let mut ping_probe_ctr: u64 = rand::random();
470
471 let mut permit = None;
483 loop {
484 tokio::select! {
485 new_permit = self.frontend_tx.reserve(), if permit.is_none() && !self.frontend_tx.is_closed() => match new_permit {
486 Ok(new_permit) => permit = Some(new_permit),
487 Err(_) => break,
490 },
491 ev = self.stream.drive_duplex(&mut self.transmit_queue) => {
492 let Some(ev) = ev else {
493 break;
495 };
496 match ev {
497 WorkerEvent::Reset { bound_jid, features } => send_or_break!(
498 Event::Stream(StreamEvent::Reset { bound_jid, features }) => permit in self.frontend_tx,
499 self.transmit_queue => self.stream,
500 ),
501 WorkerEvent::Disconnected { slot, error } => {
502 send_or_break!(
503 Event::Stream(StreamEvent::Suspended) => permit in self.frontend_tx,
504 self.transmit_queue => self.stream,
505 );
506 if let Some(error) = error {
507 log::debug!("Backend stream got disconnected because of an I/O error: {error}. Attempting reconnect.");
508 } else {
509 log::debug!("Backend stream got disconnected for an unknown reason. Attempting reconnect.");
510 }
511 if self.frontend_tx.is_closed() || self.transmit_queue.is_closed() {
512 log::debug!("Immediately aborting reconnect because the frontend is gone.");
513 break;
514 }
515 (self.reconnector)(None, slot);
516 }
517 WorkerEvent::Resumed => send_or_break!(
518 Event::Stream(StreamEvent::Resumed) => permit in self.frontend_tx,
519 self.transmit_queue => self.stream,
520 ),
521 WorkerEvent::Stanza(stanza) => send_or_break!(
522 Event::Stanza(stanza) => permit in self.frontend_tx,
523 self.transmit_queue => self.stream,
524 ),
525 WorkerEvent::ParseError(e) => {
526 log::error!("Parse error on stream: {e}");
527 self.stream.start_send_stream_error(parse_error_to_stream_error(e));
528 }
531 WorkerEvent::SoftTimeout => {
532 if self.stream.queue_sm_request() {
533 log::debug!("SoftTimeout tripped: enqueued <sm:r/>");
534 } else {
535 log::debug!("SoftTimeout tripped. Stream Management is not enabled, enqueueing ping IQ");
536 ping_probe_ctr = ping_probe_ctr.wrapping_add(1);
537 self.transmit_queue.enqueue(QueueEntry::untracked(Box::new(iq::Iq::from_get(
543 format!("{}-{}", PING_PROBE_ID_PREFIX, ping_probe_ctr),
544 ping::Ping,
545 ).into())));
546 }
547 }
548 WorkerEvent::ReconnectAborted => {
549 panic!("Backend was unable to handle reconnect request.");
550 }
551 }
552 },
553 }
554 }
555 match self.stream.close().await {
556 Ok(()) => log::debug!("Stream closed successfully"),
557 Err(e) => log::debug!("Stream closure failed: {e}"),
558 }
559 }
560}
561
562async fn shutdown_stream_by_remote_choice(mut stream: XmppStream, timeout: Duration) {
563 let deadline = Instant::now() + timeout;
564 match tokio::time::timeout_at(
565 deadline,
566 <XmppStream as SinkExt<&Stanza>>::close(&mut stream),
567 )
568 .await
569 {
570 Ok(_) => (),
572 Err(_) => {
574 log::debug!("Giving up on clean stream shutdown after timeout elapsed.");
575 return;
576 }
577 }
578 let timeout = tokio::time::sleep_until(deadline);
579 tokio::pin!(timeout);
580 loop {
581 tokio::select! {
582 _ = &mut timeout => {
583 log::debug!("Giving up on clean stream shutdown after timeout elapsed.");
584 break;
585 }
586 ev = stream.next() => match ev {
587 None => break,
588 Some(Ok(data)) => {
589 log::debug!("Ignoring data on stream during shutdown: {data:?}");
590 break;
591 }
592 Some(Err(ReadError::HardError(e))) => {
593 log::debug!("Ignoring stream I/O error during shutdown: {e}");
594 break;
595 }
596 Some(Err(ReadError::SoftTimeout)) => (),
597 Some(Err(ReadError::ParseError(_))) => (),
598 Some(Err(ReadError::StreamFooterReceived)) => (),
599 }
600 }
601 }
602}