tokio_xmpp/connect/
tls_common.rs

1// Copyright (c) 2025 Saarko <saarko@tutanota.com>
2//
3// This Source Code Form is subject to the terms of the Mozilla Public
4// License, v. 2.0. If a copy of the MPL was not distributed with this
5// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6
7//! Common TLS functionality shared between direct_tls and starttls modules
8
9use core::{error::Error as StdError, fmt};
10#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
11use std::os::fd::AsRawFd;
12use tokio::io::{AsyncRead, AsyncWrite};
13
14#[cfg(feature = "native-tls")]
15use native_tls::Error as TlsError;
16#[cfg(feature = "rustls-any-backend")]
17use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
18#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
19use tokio_rustls::rustls::Error as TlsError;
20
21#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
22use {
23    alloc::sync::Arc,
24    tokio_rustls::{
25        rustls::pki_types::ServerName,
26        rustls::{ClientConfig, RootCertStore},
27        TlsConnector,
28    },
29};
30
31#[cfg(all(
32    feature = "rustls-any-backend",
33    not(feature = "ktls"),
34    not(feature = "native-tls")
35))]
36pub use tokio_rustls::client::TlsStream;
37
#[cfg(all(feature = "ktls", not(feature = "native-tls")))]
/// TLS stream type used when kernel TLS offload is enabled.
///
/// Selected when the `ktls` feature is enabled and `native-tls` is not; it is an
/// alias for [`ktls::KtlsStream`] over the underlying transport `S`.
pub type TlsStream<S> = ktls::KtlsStream<S>;
41
42#[cfg(feature = "native-tls")]
43pub use tokio_native_tls::TlsStream;
44
45#[cfg(feature = "native-tls")]
46use {native_tls::TlsConnector as NativeTlsConnector, tokio_native_tls::TlsConnector};
47
48use crate::{connect::ServerConnectorError, error::Error};
49use sasl::common::ChannelBinding;
50
51/// Common TLS error type used by both direct_tls and starttls
52#[derive(Debug)]
53pub enum TlsConnectorError {
54    /// TLS error
55    Tls(TlsError),
56    #[cfg(feature = "rustls-any-backend")]
57    /// DNS name parsing error
58    DnsNameError(InvalidDnsNameError),
59    #[cfg(feature = "ktls")]
60    /// Error while setting up kernel TLS
61    KtlsError(ktls::Error),
62}
63
// Lets `TlsConnectorError` be used where the `connect` module expects a
// `ServerConnectorError`; no trait methods are overridden here.
impl ServerConnectorError for TlsConnectorError {}
65
66impl fmt::Display for TlsConnectorError {
67    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
68        match self {
69            Self::Tls(e) => write!(fmt, "TLS error: {}", e),
70            #[cfg(feature = "rustls-any-backend")]
71            Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
72            #[cfg(feature = "ktls")]
73            Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
74        }
75    }
76}
77
// Uses the default `Error` methods, so `source()` returns `None`; the wrapped
// error's message is instead embedded in the `Display` output.
impl StdError for TlsConnectorError {}
79
80impl From<TlsError> for TlsConnectorError {
81    fn from(e: TlsError) -> Self {
82        Self::Tls(e)
83    }
84}
85
#[cfg(feature = "rustls-any-backend")]
impl From<InvalidDnsNameError> for TlsConnectorError {
    /// Wraps a DNS-name parsing failure so `?` can convert it automatically.
    fn from(err: InvalidDnsNameError) -> Self {
        TlsConnectorError::DnsNameError(err)
    }
}
92
/// Establish a TLS connection over `stream` using the native-tls backend.
///
/// The connector is built from the default `native_tls::TlsConnector`
/// configuration and the handshake is performed against `domain`.
///
/// This backend does not support channel binding, so [`ChannelBinding::None`]
/// is always returned alongside the stream (a warning is logged).
///
/// # Errors
///
/// Returns [`TlsConnectorError::Tls`] (converted into [`Error`]) if the
/// connector cannot be built or if the TLS handshake fails.
#[cfg(feature = "native-tls")]
pub async fn establish_tls_connection<S>(
    stream: S,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // `build()` is fallible; report its failure as an error instead of panicking.
    let connector = NativeTlsConnector::builder()
        .build()
        .map_err(TlsConnectorError::Tls)?;
    let tls_stream = TlsConnector::from(connector)
        .connect(domain, stream)
        .await
        .map_err(TlsConnectorError::Tls)?;
    log::warn!(
        "tls-native doesn't support channel binding, please use tls-rust if you want this feature!"
    );
    Ok((tls_stream, ChannelBinding::None))
}
112
/// Establish a TLS connection over `stream` using rustls.
///
/// Trust anchors are taken from the `webpki-roots` and/or
/// `rustls-native-certs` features, whichever are enabled. When the `ktls`
/// feature is enabled the handshake runs through rustls first, and the
/// established session is then handed over to kernel TLS.
///
/// On TLS 1.3 sessions a `tls-exporter` channel binding is derived (32 bytes,
/// label `EXPORTER-Channel-Binding`, no context); on any other protocol version
/// [`ChannelBinding::None`] is returned.
///
/// # Errors
///
/// Returns an error if `domain` is not a valid DNS name, if the handshake
/// fails, if exporting keying material fails, or (with `ktls`) if kernel TLS
/// cannot be configured.
#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
pub async fn establish_tls_connection<S>(
    stream: S,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error>
where
    S: AsyncRead + AsyncWrite + Unpin + AsRawFd,
{
    let domain =
        ServerName::try_from(domain.to_owned()).map_err(TlsConnectorError::DnsNameError)?;
    let mut root_store = RootCertStore::empty();

    #[cfg(feature = "webpki-roots")]
    {
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    }

    #[cfg(feature = "rustls-native-certs")]
    {
        // Loading the platform store is best-effort: keep every certificate that
        // could be read, but surface failures instead of discarding them silently.
        let native = rustls_native_certs::load_native_certs();
        for err in &native.errors {
            log::warn!("Failed to load some native root certificates: {}", err);
        }
        let (_, ignored) = root_store.add_parsable_certificates(native.certs);
        if ignored > 0 {
            log::warn!("Ignored {} unparsable native root certificate(s)", ignored);
        }
    }

    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
    let mut config = ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();

    // ktls needs the session secrets and a corked transport during the handshake.
    #[cfg(feature = "ktls")]
    let stream = {
        config.enable_secret_extraction = true;
        ktls::CorkStream::new(stream)
    };

    let tls_stream = TlsConnector::from(Arc::new(config))
        .connect(domain, stream)
        .await
        .map_err(crate::Error::Io)?;

    // Extract the channel-binding information before we hand the stream over to ktls.
    let (_, connection) = tls_stream.get_ref();
    let channel_binding = match connection.protocol_version() {
        // TODO: Add support for TLS 1.2 and earlier.
        Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
            let data = vec![0u8; 32];
            let data = connection
                .export_keying_material(data, b"EXPORTER-Channel-Binding", None)
                .map_err(TlsConnectorError::Tls)?;
            ChannelBinding::TlsExporter(data)
        }
        _ => ChannelBinding::None,
    };

    #[cfg(feature = "ktls")]
    let tls_stream = ktls::config_ktls_client(tls_stream)
        .await
        .map_err(TlsConnectorError::KtlsError)?;

    Ok((tls_stream, channel_binding))
}
170
171    Ok((tls_stream, channel_binding))
172}