tokio_xmpp/connect/
starttls.rs

//! `starttls::StartTlsServerConnector` provides a `ServerConnector` for StartTLS connections.
2
3use alloc::borrow::Cow;
4use core::{error::Error as StdError, fmt};
5#[cfg(feature = "native-tls")]
6use native_tls::Error as TlsError;
7use std::io;
8use std::os::fd::AsRawFd;
9#[cfg(feature = "rustls-any-backend")]
10use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
11// Note: feature = "rustls-any-backend" and feature = "native-tls" are
12// mutually exclusive during normal compiles, but we allow it for rustdoc
13// builds. Thus, we have to make sure that the compilation still succeeds in
14// such a case.
15#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
16use tokio_rustls::rustls::Error as TlsError;
17
18use futures::{sink::SinkExt, stream::StreamExt};
19
20#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
21use {
22    alloc::sync::Arc,
23    tokio_rustls::{
24        rustls::pki_types::ServerName,
25        rustls::{ClientConfig, RootCertStore},
26        TlsConnector,
27    },
28};
29
30#[cfg(all(
31    feature = "rustls-any-backend",
32    not(feature = "ktls"),
33    not(feature = "native-tls")
34))]
35use tokio_rustls::client::TlsStream;
36
/// With the `ktls` feature, the rustls session is handed over to the kernel
/// after the handshake (see `get_tls_stream`), so the TLS stream type produced
/// by this module is ktls's stream rather than tokio-rustls's.
#[cfg(all(feature = "ktls", not(feature = "native-tls")))]
type TlsStream<S> = ktls::KtlsStream<S>;
39
40#[cfg(feature = "native-tls")]
41use {
42    native_tls::TlsConnector as NativeTlsConnector,
43    tokio_native_tls::{TlsConnector, TlsStream},
44};
45
46use sasl::common::ChannelBinding;
47use tokio::{
48    io::{AsyncRead, AsyncWrite, BufStream},
49    net::TcpStream,
50};
51use xmpp_parsers::{
52    jid::Jid,
53    starttls::{self, Request},
54};
55
56use crate::{
57    connect::{DnsConfig, ServerConnector, ServerConnectorError},
58    error::{Error, ProtocolError},
59    xmlstream::{
60        initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, Timeouts, XmppStream,
61        XmppStreamElement,
62    },
63    Client,
64};
65
/// Client that connects over StartTLS.
///
/// Deprecated alias kept for backward compatibility; it is simply [`Client`].
#[deprecated(since = "5.0.0", note = "use tokio_xmpp::Client instead")]
pub type StartTlsClient = Client;
69
/// Connect via TCP+StartTLS to an XMPP server.
///
/// The wrapped [`DnsConfig`] is used to resolve and open the TCP connection;
/// the plaintext XML stream is then upgraded with `<starttls/>` before it is
/// handed back to the caller.
#[derive(Debug, Clone)]
pub struct StartTlsServerConnector(pub DnsConfig);
73
74impl From<DnsConfig> for StartTlsServerConnector {
75    fn from(dns_config: DnsConfig) -> StartTlsServerConnector {
76        Self(dns_config)
77    }
78}
79
80impl ServerConnector for StartTlsServerConnector {
81    type Stream = BufStream<TlsStream<TcpStream>>;
82
83    async fn connect(
84        &self,
85        jid: &Jid,
86        ns: &'static str,
87        timeouts: Timeouts,
88    ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
89        let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
90
91        // Unencryped XmppStream
92        let xmpp_stream = initiate_stream(
93            tcp_stream,
94            ns,
95            StreamHeader {
96                to: Some(Cow::Borrowed(jid.domain().as_str())),
97                from: None,
98                id: None,
99            },
100            timeouts,
101        )
102        .await?;
103        let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
104
105        if features.can_starttls() {
106            // TlsStream
107            let (tls_stream, channel_binding) =
108                starttls(xmpp_stream, jid.domain().as_str()).await?;
109            // Encrypted XmppStream
110            Ok((
111                initiate_stream(
112                    tokio::io::BufStream::new(tls_stream),
113                    ns,
114                    StreamHeader {
115                        to: Some(Cow::Borrowed(jid.domain().as_str())),
116                        from: None,
117                        id: None,
118                    },
119                    timeouts,
120                )
121                .await?,
122                channel_binding,
123            ))
124        } else {
125            Err(crate::Error::Protocol(ProtocolError::NoTls))
126        }
127    }
128}
129
/// Upgrades the socket underneath `xmpp_stream` to TLS using native-tls.
///
/// Called by `starttls` once the server has answered with `<proceed/>`.
/// `domain` is passed to the TLS connector as the server name.
///
/// native-tls offers no keying-material exporter here, so the returned
/// channel binding is always [`ChannelBinding::None`].
///
/// # Errors
///
/// Returns `StartTlsError::Tls` (converted into [`Error`]) if the TLS
/// connector cannot be built or the handshake fails.
#[cfg(feature = "native-tls")]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
    xmpp_stream: XmppStream<BufStream<S>>,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
    // Building the connector can fail (e.g. the platform TLS library cannot be
    // initialised); report that to the caller instead of panicking.
    let connector = NativeTlsConnector::builder()
        .build()
        .map_err(StartTlsError::Tls)?;
    let stream = xmpp_stream.into_inner().into_inner();
    let tls_stream = TlsConnector::from(connector)
        .connect(domain, stream)
        .await
        .map_err(StartTlsError::Tls)?;
    log::warn!(
        "tls-native doesn’t support channel binding, please use tls-rust if you want this feature!"
    );
    Ok((tls_stream, ChannelBinding::None))
}
146
/// Upgrades the socket underneath `xmpp_stream` to TLS using rustls and, with
/// the `ktls` feature, hands the established session over to kernel TLS.
///
/// `domain` is converted into a rustls [`ServerName`] for the handshake. Trust
/// anchors come from `webpki-roots` and/or the platform certificate store,
/// depending on which of those features are enabled.
/// NOTE(review): with neither feature enabled the root store stays empty and
/// every handshake will fail verification — confirm the feature set prevents
/// that combination.
///
/// Returns the TLS stream plus `ChannelBinding::TlsExporter` data when TLS 1.3
/// was negotiated, or `ChannelBinding::None` for any other version.
#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
    xmpp_stream: XmppStream<BufStream<S>>,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
    let server_name =
        ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?;
    let socket = xmpp_stream.into_inner().into_inner();

    // Collect trust anchors from whichever sources are compiled in.
    let mut roots = RootCertStore::empty();
    #[cfg(feature = "webpki-roots")]
    {
        roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    }
    #[cfg(feature = "rustls-native-certs")]
    {
        roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
    }

    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
    let mut client_config = ClientConfig::builder()
        .with_root_certificates(roots)
        .with_no_client_auth();

    // ktls needs the session secrets, and these must be made extractable
    // before the handshake runs.
    #[cfg(feature = "ktls")]
    let socket = {
        client_config.enable_secret_extraction = true;
        ktls::CorkStream::new(socket)
    };

    let session = TlsConnector::from(Arc::new(client_config))
        .connect(server_name, socket)
        .await
        .map_err(crate::Error::Io)?;

    // Read the channel-binding exporter from the rustls connection now: once
    // the session is handed over to ktls it is no longer available.
    let (_, connection) = session.get_ref();
    let channel_binding = if matches!(
        connection.protocol_version(),
        Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3)
    ) {
        let exported = connection
            .export_keying_material(vec![0u8; 32], b"EXPORTER-Channel-Binding", None)
            .map_err(StartTlsError::Tls)?;
        ChannelBinding::TlsExporter(exported)
    } else {
        // TODO: Add support for TLS 1.2 and earlier.
        ChannelBinding::None
    };

    #[cfg(feature = "ktls")]
    let session = ktls::config_ktls_client(session)
        .await
        .map_err(StartTlsError::KtlsError)?;

    Ok((session, channel_binding))
}
197
198/// Performs `<starttls/>` on an XmppStream and returns a binary
199/// TlsStream.
200pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
201    mut stream: XmppStream<BufStream<S>>,
202    domain: &str,
203) -> Result<(TlsStream<S>, ChannelBinding), Error> {
204    stream
205        .send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
206            Request,
207        )))
208        .await?;
209
210    loop {
211        match stream.next().await {
212            Some(Ok(XmppStreamElement::Starttls(starttls::Nonza::Proceed(_)))) => {
213                break;
214            }
215            Some(Ok(_)) => (),
216            Some(Err(ReadError::SoftTimeout)) => (),
217            Some(Err(ReadError::HardError(e))) => return Err(e.into()),
218            Some(Err(ReadError::ParseError(e))) => {
219                return Err(io::Error::new(io::ErrorKind::InvalidData, e).into())
220            }
221            None | Some(Err(ReadError::StreamFooterReceived)) => {
222                return Err(crate::Error::Disconnected)
223            }
224        }
225    }
226
227    get_tls_stream(stream, domain).await
228}
229
/// Errors specific to establishing a StartTLS connection.
///
/// The set of variants depends on which TLS features are enabled.
#[derive(Debug)]
pub enum StartTlsError {
    /// Error reported by the active TLS backend (`native_tls::Error` with the
    /// `native-tls` feature, otherwise `rustls::Error`).
    Tls(TlsError),
    #[cfg(feature = "rustls-any-backend")]
    /// The server domain could not be turned into a rustls server name.
    DnsNameError(InvalidDnsNameError),
    #[cfg(feature = "ktls")]
    /// Error while handing the TLS session over to kernel TLS.
    KtlsError(ktls::Error),
}
242
// Marker impl registering `StartTlsError` as a connector error type.
// NOTE(review): presumably this is what lets `?` convert a `StartTlsError`
// into `crate::Error` in this file — confirm against `connect/mod.rs`.
impl ServerConnectorError for StartTlsError {}
244
245impl fmt::Display for StartTlsError {
246    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
247        match self {
248            Self::Tls(e) => write!(fmt, "TLS error: {}", e),
249            #[cfg(feature = "rustls-any-backend")]
250            Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
251            #[cfg(feature = "ktls")]
252            Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
253        }
254    }
255}
256
257impl StdError for StartTlsError {}
258
259impl From<TlsError> for StartTlsError {
260    fn from(e: TlsError) -> Self {
261        Self::Tls(e)
262    }
263}
264
#[cfg(feature = "rustls-any-backend")]
impl From<InvalidDnsNameError> for StartTlsError {
    /// Wraps a server-name parsing failure as `StartTlsError::DnsNameError`.
    fn from(err: InvalidDnsNameError) -> Self {
        StartTlsError::DnsNameError(err)
    }
}