tokio_xmpp/connect/
starttls.rs

//! [`StartTlsServerConnector`] provides a [`ServerConnector`] for STARTTLS connections.
2
3use alloc::borrow::Cow;
4use std::io;
5
6use futures::{sink::SinkExt, stream::StreamExt};
7use sasl::common::ChannelBinding;
8use tokio::{io::BufStream, net::TcpStream};
9use xmpp_parsers::{
10    jid::Jid,
11    starttls::{self, Request},
12};
13
14use crate::{
15    connect::{
16        tls_common::{establish_tls_connection, TlsAsyncStream, TlsConnectorError, TlsStream},
17        DnsConfig, ServerConnector,
18    },
19    error::{Error, ProtocolError},
20    xmlstream::{
21        initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, Timeouts, XmppStream,
22        XmppStreamElement,
23    },
24    Client,
25};
26
27/// Client that connects over StartTls
28#[deprecated(since = "5.0.0", note = "use tokio_xmpp::Client instead")]
29pub type StartTlsClient = Client;
30
31/// Connect via TCP+StartTLS to an XMPP server
32#[derive(Debug, Clone)]
33pub struct StartTlsServerConnector(pub DnsConfig);
34
35impl From<DnsConfig> for StartTlsServerConnector {
36    fn from(dns_config: DnsConfig) -> StartTlsServerConnector {
37        Self(dns_config)
38    }
39}
40
41impl ServerConnector for StartTlsServerConnector {
42    type Stream = BufStream<TlsStream<TcpStream>>;
43
44    async fn connect(
45        &self,
46        jid: &Jid,
47        ns: &'static str,
48        timeouts: Timeouts,
49    ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
50        let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
51
52        // Unencryped XmppStream
53        let xmpp_stream = initiate_stream(
54            tcp_stream,
55            ns,
56            StreamHeader {
57                to: Some(Cow::Borrowed(jid.domain().as_str())),
58                from: None,
59                id: None,
60            },
61            timeouts,
62        )
63        .await?;
64        let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
65
66        if features.can_starttls() {
67            // TlsStream
68            let (tls_stream, channel_binding) =
69                starttls(xmpp_stream, jid.domain().as_str()).await?;
70            // Encrypted XmppStream
71            Ok((
72                initiate_stream(
73                    tokio::io::BufStream::new(tls_stream),
74                    ns,
75                    StreamHeader {
76                        to: Some(Cow::Borrowed(jid.domain().as_str())),
77                        from: None,
78                        id: None,
79                    },
80                    timeouts,
81                )
82                .await?,
83                channel_binding,
84            ))
85        } else {
86            Err(crate::Error::Protocol(ProtocolError::NoTls))
87        }
88    }
89}
90
91/// Performs `<starttls/>` on an XmppStream and returns a binary
92/// TlsStream.
93pub async fn starttls<S: TlsAsyncStream>(
94    mut stream: XmppStream<BufStream<S>>,
95    domain: &str,
96) -> Result<(TlsStream<S>, ChannelBinding), Error> {
97    stream
98        .send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
99            Request,
100        )))
101        .await?;
102
103    loop {
104        match stream.next().await {
105            Some(Ok(XmppStreamElement::Starttls(starttls::Nonza::Proceed(_)))) => {
106                break;
107            }
108            Some(Ok(_)) => (),
109            Some(Err(ReadError::SoftTimeout)) => (),
110            Some(Err(ReadError::HardError(e))) => return Err(e.into()),
111            Some(Err(ReadError::ParseError(e))) => {
112                return Err(io::Error::new(io::ErrorKind::InvalidData, e).into())
113            }
114            None | Some(Err(ReadError::StreamFooterReceived)) => {
115                return Err(crate::Error::Disconnected)
116            }
117        }
118    }
119
120    let inner_stream = stream.into_inner().into_inner();
121    establish_tls_connection(inner_stream, domain).await
122}
123
124/// StartTLS ServerConnector Error - now just an alias to the common error type
125pub type StartTlsError = TlsConnectorError;