tokio_xmpp/connect/
direct_tls.rs

//! `direct_tls::DirectTlsServerConnector` provides a `ServerConnector` for direct TLS connections
2
3use alloc::borrow::Cow;
4use core::{error::Error as StdError, fmt};
5#[cfg(feature = "native-tls")]
6use native_tls::Error as TlsError;
7#[cfg(feature = "rustls-any-backend")]
8use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
9// Note: feature = "rustls-any-backend" and feature = "native-tls" are
10// mutually exclusive during normal compiles, but we allow it for rustdoc
11// builds. Thus, we have to make sure that the compilation still succeeds in
12// such a case.
13#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
14use tokio_rustls::rustls::Error as TlsError;
15
16#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
17use {
18    alloc::sync::Arc,
19    tokio_rustls::{
20        rustls::pki_types::ServerName,
21        rustls::{ClientConfig, RootCertStore},
22        TlsConnector,
23    },
24};
25
26#[cfg(all(
27    feature = "rustls-any-backend",
28    not(feature = "ktls"),
29    not(feature = "native-tls")
30))]
31use tokio_rustls::client::TlsStream;
32
33#[cfg(all(feature = "ktls", not(feature = "native-tls")))]
34type TlsStream<S> = ktls::KtlsStream<S>;
35
36#[cfg(feature = "native-tls")]
37use {
38    native_tls::TlsConnector as NativeTlsConnector,
39    tokio_native_tls::{TlsConnector, TlsStream},
40};
41
42use sasl::common::ChannelBinding;
43use tokio::{io::BufStream, net::TcpStream};
44use xmpp_parsers::jid::Jid;
45
46use crate::{
47    connect::{DnsConfig, ServerConnector, ServerConnectorError},
48    error::Error,
49    xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts},
50};
51
52/// Connect via direct TLS to an XMPP server
53#[derive(Debug, Clone)]
54pub struct DirectTlsServerConnector(pub DnsConfig);
55
56impl From<DnsConfig> for DirectTlsServerConnector {
57    fn from(dns_config: DnsConfig) -> DirectTlsServerConnector {
58        Self(dns_config)
59    }
60}
61
62impl ServerConnector for DirectTlsServerConnector {
63    type Stream = BufStream<TlsStream<TcpStream>>;
64
65    async fn connect(
66        &self,
67        jid: &Jid,
68        ns: &'static str,
69        timeouts: Timeouts,
70    ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
71        let tcp_stream = self.0.resolve().await?;
72
73        // Immediately establish TLS connection
74        let (tls_stream, channel_binding) =
75            establish_tls(tcp_stream, jid.domain().as_str()).await?;
76
77        // Establish XMPP stream over TLS
78        Ok((
79            initiate_stream(
80                tokio::io::BufStream::new(tls_stream),
81                ns,
82                StreamHeader {
83                    to: Some(Cow::Borrowed(jid.domain().as_str())),
84                    from: Some(Cow::Borrowed(jid.to_bare().as_str())), // setting explicitly `from` here, because ejabberd for example does not advertise sasl2/bind2 otherwise
85                    id: None,
86                },
87                timeouts,
88            )
89            .await?,
90            channel_binding,
91        ))
92    }
93}
94
/// Performs a TLS handshake over `tcp_stream` using the `native-tls` backend.
///
/// The server certificate is verified against `domain`.
///
/// # Errors
///
/// Returns a [`DirectTlsError::Tls`] error, propagated as [`Error`], in two cases:
/// - the native TLS connector cannot be built;
/// - the TLS handshake fails.
///
/// This backend does not expose channel-binding data, so the returned binding
/// is always [`ChannelBinding::None`].
#[cfg(feature = "native-tls")]
async fn establish_tls(
    tcp_stream: TcpStream,
    domain: &str,
) -> Result<(TlsStream<TcpStream>, ChannelBinding), Error> {
    // `TlsConnectorBuilder::build` is fallible; report its error to the caller
    // instead of panicking inside the connector.
    let connector = NativeTlsConnector::builder()
        .build()
        .map_err(DirectTlsError::Tls)?;
    let tls_stream = TlsConnector::from(connector)
        .connect(domain, tcp_stream)
        .await
        .map_err(DirectTlsError::Tls)?;
    log::warn!(
        "tls-native doesn't support channel binding, please use tls-rust if you want this feature!"
    );
    Ok((tls_stream, ChannelBinding::None))
}
110
/// Performs a TLS handshake over `tcp_stream` with rustls and derives the
/// channel-binding data for the resulting session.
///
/// The server certificate is verified against `domain`. The trust anchors come
/// from whichever of the `webpki-roots` / `rustls-native-certs` features are
/// enabled.
///
/// With the `ktls` feature, the session is handed over to kernel TLS only after
/// the channel-binding data has been read from it.
///
/// Only TLS 1.3 sessions yield [`ChannelBinding::TlsExporter`]; any other
/// protocol version yields [`ChannelBinding::None`].
#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
async fn establish_tls(
    tcp_stream: TcpStream,
    domain: &str,
) -> Result<(TlsStream<TcpStream>, ChannelBinding), Error> {
    let server_name =
        ServerName::try_from(domain.to_owned()).map_err(DirectTlsError::DnsNameError)?;

    let mut trust_anchors = RootCertStore::empty();
    #[cfg(feature = "webpki-roots")]
    {
        trust_anchors.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    }
    #[cfg(feature = "rustls-native-certs")]
    {
        let native_roots = rustls_native_certs::load_native_certs()?;
        trust_anchors.add_parsable_certificates(native_roots);
    }

    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
    let mut config = ClientConfig::builder()
        .with_root_certificates(trust_anchors)
        .with_no_client_auth();

    // For kTLS, enable secret extraction and wrap the socket in
    // `ktls::CorkStream` before the handshake takes place.
    #[cfg(feature = "ktls")]
    let tcp_stream = {
        config.enable_secret_extraction = true;
        ktls::CorkStream::new(tcp_stream)
    };

    let tls_stream = TlsConnector::from(Arc::new(config))
        .connect(server_name, tcp_stream)
        .await
        .map_err(crate::Error::Io)?;

    // Read the channel-binding data while the rustls session is still
    // accessible, i.e. before the stream is handed over to kTLS below.
    let channel_binding = {
        let (_, session) = tls_stream.get_ref();
        // TODO: Add support for TLS 1.2 and earlier.
        if matches!(
            session.protocol_version(),
            Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3)
        ) {
            let keying_material = session
                .export_keying_material(vec![0u8; 32], b"EXPORTER-Channel-Binding", None)
                .map_err(DirectTlsError::Tls)?;
            ChannelBinding::TlsExporter(keying_material)
        } else {
            ChannelBinding::None
        }
    };

    #[cfg(feature = "ktls")]
    let tls_stream = ktls::config_ktls_client(tls_stream)
        .await
        .map_err(DirectTlsError::KtlsError)?;

    Ok((tls_stream, channel_binding))
}
160
161/// Direct TLS ServerConnector Error
162#[derive(Debug)]
163pub enum DirectTlsError {
164    /// TLS error
165    Tls(TlsError),
166    #[cfg(feature = "rustls-any-backend")]
167    /// DNS name parsing error
168    DnsNameError(InvalidDnsNameError),
169    #[cfg(feature = "ktls")]
170    /// Error while setting up kernel TLS
171    KtlsError(ktls::Error),
172}
173
174impl ServerConnectorError for DirectTlsError {}
175
176impl fmt::Display for DirectTlsError {
177    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
178        match self {
179            Self::Tls(e) => write!(fmt, "TLS error: {}", e),
180            #[cfg(feature = "rustls-any-backend")]
181            Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
182            #[cfg(feature = "ktls")]
183            Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
184        }
185    }
186}
187
188impl StdError for DirectTlsError {}
189
190impl From<TlsError> for DirectTlsError {
191    fn from(e: TlsError) -> Self {
192        Self::Tls(e)
193    }
194}
195
196#[cfg(feature = "rustls-any-backend")]
197impl From<InvalidDnsNameError> for DirectTlsError {
198    fn from(e: InvalidDnsNameError) -> Self {
199        Self::DnsNameError(e)
200    }
201}