tokio_xmpp/connect/
starttls.rs

//! `starttls::StartTlsServerConnector` provides a `ServerConnector` for StartTLS connections.
2
3use alloc::borrow::Cow;
4use std::io;
5use std::os::fd::AsRawFd;
6
7use futures::{sink::SinkExt, stream::StreamExt};
8use sasl::common::ChannelBinding;
9use tokio::{
10    io::{AsyncRead, AsyncWrite, BufStream},
11    net::TcpStream,
12};
13use xmpp_parsers::{
14    jid::Jid,
15    starttls::{self, Request},
16};
17
18use crate::{
19    connect::{
20        tls_common::{establish_tls_connection, TlsConnectorError, TlsStream},
21        DnsConfig, ServerConnector,
22    },
23    error::{Error, ProtocolError},
24    xmlstream::{
25        initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, Timeouts, XmppStream,
26        XmppStreamElement,
27    },
28    Client,
29};
30
31/// Client that connects over StartTls
32#[deprecated(since = "5.0.0", note = "use tokio_xmpp::Client instead")]
33pub type StartTlsClient = Client;
34
35/// Connect via TCP+StartTLS to an XMPP server
36#[derive(Debug, Clone)]
37pub struct StartTlsServerConnector(pub DnsConfig);
38
39impl From<DnsConfig> for StartTlsServerConnector {
40    fn from(dns_config: DnsConfig) -> StartTlsServerConnector {
41        Self(dns_config)
42    }
43}
44
45impl ServerConnector for StartTlsServerConnector {
46    type Stream = BufStream<TlsStream<TcpStream>>;
47
48    async fn connect(
49        &self,
50        jid: &Jid,
51        ns: &'static str,
52        timeouts: Timeouts,
53    ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
54        let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
55
56        // Unencryped XmppStream
57        let xmpp_stream = initiate_stream(
58            tcp_stream,
59            ns,
60            StreamHeader {
61                to: Some(Cow::Borrowed(jid.domain().as_str())),
62                from: None,
63                id: None,
64            },
65            timeouts,
66        )
67        .await?;
68        let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
69
70        if features.can_starttls() {
71            // TlsStream
72            let (tls_stream, channel_binding) =
73                starttls(xmpp_stream, jid.domain().as_str()).await?;
74            // Encrypted XmppStream
75            Ok((
76                initiate_stream(
77                    tokio::io::BufStream::new(tls_stream),
78                    ns,
79                    StreamHeader {
80                        to: Some(Cow::Borrowed(jid.domain().as_str())),
81                        from: None,
82                        id: None,
83                    },
84                    timeouts,
85                )
86                .await?,
87                channel_binding,
88            ))
89        } else {
90            Err(crate::Error::Protocol(ProtocolError::NoTls))
91        }
92    }
93}
94
95/// Performs `<starttls/>` on an XmppStream and returns a binary
96/// TlsStream.
97pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
98    mut stream: XmppStream<BufStream<S>>,
99    domain: &str,
100) -> Result<(TlsStream<S>, ChannelBinding), Error> {
101    stream
102        .send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
103            Request,
104        )))
105        .await?;
106
107    loop {
108        match stream.next().await {
109            Some(Ok(XmppStreamElement::Starttls(starttls::Nonza::Proceed(_)))) => {
110                break;
111            }
112            Some(Ok(_)) => (),
113            Some(Err(ReadError::SoftTimeout)) => (),
114            Some(Err(ReadError::HardError(e))) => return Err(e.into()),
115            Some(Err(ReadError::ParseError(e))) => {
116                return Err(io::Error::new(io::ErrorKind::InvalidData, e).into())
117            }
118            None | Some(Err(ReadError::StreamFooterReceived)) => {
119                return Err(crate::Error::Disconnected)
120            }
121        }
122    }
123
124    let inner_stream = stream.into_inner().into_inner();
125    establish_tls_connection(inner_stream, domain).await
126}
127
128/// StartTLS ServerConnector Error - now just an alias to the common error type
129pub type StartTlsError = TlsConnectorError;