Skip to main content

tokio_xmpp/connect/
tls_common.rs

1// Copyright (c) 2025 Saarko <saarko@tutanota.com>
2//
3// This Source Code Form is subject to the terms of the Mozilla Public
4// License, v. 2.0. If a copy of the MPL was not distributed with this
5// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6
7//! Common TLS functionality shared between direct_tls and starttls modules
8
9use core::{error::Error as StdError, fmt};
10#[cfg(feature = "ktls")]
11use std::os::fd::AsRawFd;
12use tokio::io::{AsyncRead, AsyncWrite};
13
/// Trait alias for async streams that can be used with TLS.
///
/// With the `ktls` feature enabled, a stream must additionally expose its raw file
/// descriptor (`AsRawFd`) on top of `AsyncRead + AsyncWrite + Unpin`.
#[cfg(feature = "ktls")]
pub trait TlsAsyncStream: AsyncRead + AsyncWrite + Unpin + AsRawFd {}

// Blanket impl: every type satisfying the bounds is automatically a `TlsAsyncStream`.
#[cfg(feature = "ktls")]
impl<T> TlsAsyncStream for T where T: AsyncRead + AsyncWrite + Unpin + AsRawFd {}
20
21/// Trait alias for async streams that can be used with TLS.
22#[cfg(not(feature = "ktls"))]
23pub trait TlsAsyncStream: AsyncRead + AsyncWrite + Unpin {}
24#[cfg(not(feature = "ktls"))]
25impl<T: AsyncRead + AsyncWrite + Unpin> TlsAsyncStream for T {}
26
27#[cfg(feature = "native-tls")]
28use native_tls::Error as TlsError;
29#[cfg(feature = "rustls-any-backend")]
30use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
31#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
32use tokio_rustls::rustls::Error as TlsError;
33
34#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
35use {
36    alloc::sync::Arc,
37    tokio_rustls::{
38        rustls::pki_types::ServerName,
39        rustls::{ClientConfig, RootCertStore},
40        TlsConnector,
41    },
42};
43
44#[cfg(all(
45    feature = "rustls-any-backend",
46    not(feature = "ktls"),
47    not(feature = "native-tls")
48))]
49pub use tokio_rustls::client::TlsStream;
50
#[cfg(all(feature = "ktls", not(feature = "native-tls")))]
/// TLS stream type backed by kernel TLS (`ktls::KtlsStream`).
///
/// Selected when the `ktls` feature is enabled and `native-tls` is not.
pub type TlsStream<S> = ktls::KtlsStream<S>;
54
55#[cfg(feature = "native-tls")]
56pub use tokio_native_tls::TlsStream;
57
58#[cfg(feature = "native-tls")]
59use {native_tls::TlsConnector as NativeTlsConnector, tokio_native_tls::TlsConnector};
60
61use crate::{connect::ServerConnectorError, error::Error};
62use sasl::common::ChannelBinding;
63
64/// Common TLS error type used by both direct_tls and starttls
65#[derive(Debug)]
66pub enum TlsConnectorError {
67    /// TLS error
68    Tls(TlsError),
69    #[cfg(feature = "rustls-any-backend")]
70    /// DNS name parsing error
71    DnsNameError(InvalidDnsNameError),
72    #[cfg(feature = "ktls")]
73    /// Error while setting up kernel TLS
74    KtlsError(ktls::Error),
75}
76
// Lets `TlsConnectorError` serve as the error type of a server connector
// (`crate::connect::ServerConnectorError`); no trait items are overridden.
impl ServerConnectorError for TlsConnectorError {}
78
79impl fmt::Display for TlsConnectorError {
80    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
81        match self {
82            Self::Tls(e) => write!(fmt, "TLS error: {}", e),
83            #[cfg(feature = "rustls-any-backend")]
84            Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
85            #[cfg(feature = "ktls")]
86            Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
87        }
88    }
89}
90
// Standard error trait; relies on the derived `Debug` and the `Display` impl.
// NOTE(review): `source()` is not overridden, so the wrapped backend error is only
// visible through the `Display` text, not through the error-source chain.
impl StdError for TlsConnectorError {}
92
93impl From<TlsError> for TlsConnectorError {
94    fn from(e: TlsError) -> Self {
95        Self::Tls(e)
96    }
97}
98
/// Lets `?` turn an invalid server name into a [`TlsConnectorError`].
#[cfg(feature = "rustls-any-backend")]
impl From<InvalidDnsNameError> for TlsConnectorError {
    /// Wraps a server-name parsing error as [`TlsConnectorError::DnsNameError`].
    fn from(err: InvalidDnsNameError) -> Self {
        TlsConnectorError::DnsNameError(err)
    }
}
105
/// Establish TLS connection using native-tls
///
/// Builds a native-tls connector with default settings and performs the client
/// handshake over `stream`, using `domain` as the expected server name.
///
/// Channel binding is not supported with this backend: the returned binding is
/// always [`ChannelBinding::None`], and a warning is logged on every connection.
///
/// # Errors
///
/// Returns a [`TlsConnectorError::Tls`] (converted into [`Error`]) if the native TLS
/// connector cannot be built or if the handshake fails.
#[cfg(feature = "native-tls")]
pub async fn establish_tls_connection<S: TlsAsyncStream>(
    stream: S,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
    // Building the connector is fallible; report failure to the caller instead of
    // panicking.
    let connector = NativeTlsConnector::builder()
        .build()
        .map_err(TlsConnectorError::Tls)?;
    // `connect` only needs a borrowed `&str`, so no owned copy of `domain` is required.
    let tls_stream = TlsConnector::from(connector)
        .connect(domain, stream)
        .await
        .map_err(TlsConnectorError::Tls)?;
    log::warn!(
        "tls-native doesn't support channel binding, please use tls-rust if you want this feature!"
    );
    Ok((tls_stream, ChannelBinding::None))
}
122
/// Establish TLS connection using rustls
///
/// Validates `domain` as a TLS server name and builds a root certificate store from
/// whichever of the `webpki-roots` / `rustls-native-certs` features are enabled. It then
/// performs the client handshake over `stream` and derives channel-binding data from
/// the session.
///
/// If the negotiated protocol is TLS 1.3, the binding is
/// [`ChannelBinding::TlsExporter`] holding 32 bytes of keying material. The material is
/// exported with the label `EXPORTER-Channel-Binding` and no context, matching the
/// `tls-exporter` binding of RFC 9266. For any other protocol version the binding is
/// [`ChannelBinding::None`].
///
/// With the `ktls` feature, secret extraction is enabled on the config and the stream is
/// wrapped in a `ktls::CorkStream` before the handshake. Afterwards, the connection is
/// handed to kernel TLS via `ktls::config_ktls_client`.
///
/// # Errors
///
/// - [`TlsConnectorError::DnsNameError`] if `domain` is not a valid server name.
/// - [`Error::Io`](crate::Error::Io) if the handshake fails.
/// - [`TlsConnectorError::Tls`] if exporting the channel-binding keying material fails.
/// - `TlsConnectorError::KtlsError` (with `ktls`) if kernel TLS setup fails.
///
/// `TlsConnectorError` values are converted into [`Error`] by `?`.
#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
pub async fn establish_tls_connection<S: TlsAsyncStream>(
    stream: S,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
    // Reject names that are not valid TLS server names before doing any I/O.
    let domain =
        ServerName::try_from(domain.to_owned()).map_err(TlsConnectorError::DnsNameError)?;
    // Trust anchors come only from the feature-gated sources below.
    // NOTE(review): if neither `webpki-roots` nor `rustls-native-certs` is enabled, the
    // store stays empty, so presumably no server certificate will verify.
    let mut root_store = RootCertStore::empty();

    #[cfg(feature = "webpki-roots")]
    {
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    }

    #[cfg(feature = "rustls-native-certs")]
    {
        // NOTE(review): only `.certs` is used, so any load errors reported alongside them
        // are dropped. The return value of `add_parsable_certificates` is also ignored.
        // Consider logging both.
        root_store.add_parsable_certificates(rustls_native_certs::load_native_certs().certs);
    }

    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
    let mut config = ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();

    // Prepare for handing the session to kernel TLS after the handshake: enable secret
    // extraction on the config and wrap the raw stream in a `CorkStream`.
    #[cfg(feature = "ktls")]
    let stream = {
        config.enable_secret_extraction = true;
        ktls::CorkStream::new(stream)
    };

    // Handshake failures surface as `Error::Io`.
    let tls_stream = TlsConnector::from(Arc::new(config))
        .connect(domain, stream)
        .await
        .map_err(crate::Error::Io)?;

    // Extract the channel-binding information before we hand the stream over to ktls.
    let (_, connection) = tls_stream.get_ref();
    let channel_binding = match connection.protocol_version() {
        // TODO: Add support for TLS 1.2 and earlier.
        Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
            // 32-byte output buffer; filled and returned by `export_keying_material`.
            let data = vec![0u8; 32];
            let data = connection
                .export_keying_material(data, b"EXPORTER-Channel-Binding", None)
                .map_err(TlsConnectorError::Tls)?;
            ChannelBinding::TlsExporter(data)
        }
        _ => ChannelBinding::None,
    };

    // Hand the established session to kernel TLS; this shadows `tls_stream` with the
    // ktls stream type that `TlsStream<S>` aliases under this feature.
    #[cfg(feature = "ktls")]
    let tls_stream = ktls::config_ktls_client(tls_stream)
        .await
        .map_err(TlsConnectorError::KtlsError)?;

    Ok((tls_stream, channel_binding))
}