tokio_xmpp/connect/
tls_common.rs

1// Copyright (c) 2025 Saarko <saarko@tutanota.com>
2//
3// This Source Code Form is subject to the terms of the Mozilla Public
4// License, v. 2.0. If a copy of the MPL was not distributed with this
5// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6
7//! Common TLS functionality shared between direct_tls and starttls modules
8
9use core::{error::Error as StdError, fmt};
10use std::os::fd::AsRawFd;
11use tokio::io::{AsyncRead, AsyncWrite};
12
13#[cfg(feature = "native-tls")]
14use native_tls::Error as TlsError;
15#[cfg(feature = "rustls-any-backend")]
16use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
17#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
18use tokio_rustls::rustls::Error as TlsError;
19
20#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
21use {
22    alloc::sync::Arc,
23    tokio_rustls::{
24        rustls::pki_types::ServerName,
25        rustls::{ClientConfig, RootCertStore},
26        TlsConnector,
27    },
28};
29
30#[cfg(all(
31    feature = "rustls-any-backend",
32    not(feature = "ktls"),
33    not(feature = "native-tls")
34))]
35pub use tokio_rustls::client::TlsStream;
36
/// Stream type produced by `establish_tls_connection` when the `ktls` feature is
/// enabled (and `native-tls` is not): the session ends up wrapped in `ktls::KtlsStream`.
#[cfg(all(feature = "ktls", not(feature = "native-tls")))]
pub type TlsStream<IO> = ktls::KtlsStream<IO>;
39
40#[cfg(feature = "native-tls")]
41pub use tokio_native_tls::TlsStream;
42
43#[cfg(feature = "native-tls")]
44use {native_tls::TlsConnector as NativeTlsConnector, tokio_native_tls::TlsConnector};
45
46use crate::{connect::ServerConnectorError, error::Error};
47use sasl::common::ChannelBinding;
48
49/// Common TLS error type used by both direct_tls and starttls
50#[derive(Debug)]
51pub enum TlsConnectorError {
52    /// TLS error
53    Tls(TlsError),
54    #[cfg(feature = "rustls-any-backend")]
55    /// DNS name parsing error
56    DnsNameError(InvalidDnsNameError),
57    #[cfg(feature = "ktls")]
58    /// Error while setting up kernel TLS
59    KtlsError(ktls::Error),
60}
61
62impl ServerConnectorError for TlsConnectorError {}
63
64impl fmt::Display for TlsConnectorError {
65    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
66        match self {
67            Self::Tls(e) => write!(fmt, "TLS error: {}", e),
68            #[cfg(feature = "rustls-any-backend")]
69            Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
70            #[cfg(feature = "ktls")]
71            Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
72        }
73    }
74}
75
76impl StdError for TlsConnectorError {}
77
78impl From<TlsError> for TlsConnectorError {
79    fn from(e: TlsError) -> Self {
80        Self::Tls(e)
81    }
82}
83
84#[cfg(feature = "rustls-any-backend")]
85impl From<InvalidDnsNameError> for TlsConnectorError {
86    fn from(e: InvalidDnsNameError) -> Self {
87        Self::DnsNameError(e)
88    }
89}
90
/// Establish a TLS connection over `stream` using native-tls.
///
/// Channel binding is not supported with this backend, so the returned
/// binding is always [`ChannelBinding::None`] (a warning is logged each time).
///
/// # Errors
///
/// Returns an error if the native-tls connector cannot be built or if the TLS
/// handshake with `domain` fails.
#[cfg(feature = "native-tls")]
pub async fn establish_tls_connection<S>(
    stream: S,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // `build()` is fallible; report its failure as a TLS error instead of panicking.
    let connector = NativeTlsConnector::builder()
        .build()
        .map_err(TlsConnectorError::Tls)?;
    let tls_stream = TlsConnector::from(connector)
        .connect(domain, stream)
        .await
        .map_err(TlsConnectorError::Tls)?;
    log::warn!(
        "tls-native doesn't support channel binding, please use tls-rust if you want this feature!"
    );
    Ok((tls_stream, ChannelBinding::None))
}
110
/// Establish a TLS connection over `stream` using rustls.
///
/// The server certificate is verified against the root certificates provided by
/// the enabled `webpki-roots` and/or `rustls-native-certs` features. When TLS 1.3
/// was negotiated, `tls-exporter` channel-binding data is exported from the
/// session; for any other protocol version [`ChannelBinding::None`] is returned.
/// With the `ktls` feature, the established session is then handed to kernel TLS.
///
/// # Errors
///
/// Fails if `domain` is not a valid DNS name, if the TLS handshake fails, if the
/// channel-binding keying material cannot be exported, or (with `ktls`) if kernel
/// TLS cannot be configured.
#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
pub async fn establish_tls_connection<S>(
    stream: S,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error>
where
    S: AsyncRead + AsyncWrite + Unpin + AsRawFd,
{
    // `connect` needs an owned `ServerName<'static>`, hence the `to_owned()`.
    let domain =
        ServerName::try_from(domain.to_owned()).map_err(TlsConnectorError::DnsNameError)?;

    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
    let mut config = ClientConfig::builder()
        .with_root_certificates(root_cert_store())
        .with_no_client_auth();

    #[cfg(feature = "ktls")]
    let stream = {
        config.enable_secret_extraction = true;
        ktls::CorkStream::new(stream)
    };

    let tls_stream = TlsConnector::from(Arc::new(config))
        .connect(domain, stream)
        .await
        .map_err(crate::Error::Io)?;

    // Extract the channel-binding data now: `config_ktls_client` below consumes
    // the rustls stream, after which the session is no longer accessible here.
    let (_, connection) = tls_stream.get_ref();
    let channel_binding = tls_exporter_channel_binding(connection)?;

    #[cfg(feature = "ktls")]
    let tls_stream = ktls::config_ktls_client(tls_stream)
        .await
        .map_err(TlsConnectorError::KtlsError)?;

    Ok((tls_stream, channel_binding))
}

/// Build the root certificate store from every trust-anchor source enabled at
/// build time.
///
/// Problems loading the platform certificates are logged rather than discarded,
/// and a warning is emitted when the resulting store is empty.
#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
fn root_cert_store() -> RootCertStore {
    #[allow(
        unused_mut,
        reason = "The store is only extended when a root certificate feature is enabled"
    )]
    let mut root_store = RootCertStore::empty();

    #[cfg(feature = "webpki-roots")]
    {
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    }

    #[cfg(feature = "rustls-native-certs")]
    {
        let native = rustls_native_certs::load_native_certs();
        // Loading errors are returned next to the certificates that did load;
        // surface them instead of silently dropping them.
        for err in &native.errors {
            log::warn!("Error while loading native root certificates: {err}");
        }
        let (_, ignored) = root_store.add_parsable_certificates(native.certs);
        if ignored > 0 {
            log::warn!("Ignored {ignored} unparsable native root certificate(s)");
        }
    }

    if root_store.is_empty() {
        log::warn!(
            "No root certificates available; server certificate verification is expected to fail"
        );
    }

    root_store
}

/// Compute the channel binding for an established rustls client session.
///
/// Only TLS 1.3 is handled: 32 bytes of keying material are exported with the
/// `EXPORTER-Channel-Binding` label and no context. Other versions yield
/// [`ChannelBinding::None`].
#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
fn tls_exporter_channel_binding(
    connection: &tokio_rustls::rustls::ClientConnection,
) -> Result<ChannelBinding, TlsConnectorError> {
    match connection.protocol_version() {
        // TODO: Add support for TLS 1.2 and earlier.
        Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
            let data = connection
                .export_keying_material(vec![0u8; 32], b"EXPORTER-Channel-Binding", None)
                .map_err(TlsConnectorError::Tls)?;
            Ok(ChannelBinding::TlsExporter(data))
        }
        _ => Ok(ChannelBinding::None),
    }
}