tokio_xmpp/connect/
tls_common.rs1use core::{error::Error as StdError, fmt};
10#[cfg(feature = "ktls")]
11use std::os::fd::AsRawFd;
12use tokio::io::{AsyncRead, AsyncWrite};
13
14#[cfg(feature = "ktls")]
17pub trait TlsAsyncStream: AsyncRead + AsyncWrite + Unpin + AsRawFd {}
18#[cfg(feature = "ktls")]
19impl<T: AsyncRead + AsyncWrite + Unpin + AsRawFd> TlsAsyncStream for T {}
20
21#[cfg(not(feature = "ktls"))]
23pub trait TlsAsyncStream: AsyncRead + AsyncWrite + Unpin {}
24#[cfg(not(feature = "ktls"))]
25impl<T: AsyncRead + AsyncWrite + Unpin> TlsAsyncStream for T {}
26
27#[cfg(feature = "native-tls")]
28use native_tls::Error as TlsError;
29#[cfg(feature = "rustls-any-backend")]
30use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
31#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
32use tokio_rustls::rustls::Error as TlsError;
33
34#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
35use {
36 alloc::sync::Arc,
37 tokio_rustls::{
38 rustls::pki_types::ServerName,
39 rustls::{ClientConfig, RootCertStore},
40 TlsConnector,
41 },
42};
43
44#[cfg(all(
45 feature = "rustls-any-backend",
46 not(feature = "ktls"),
47 not(feature = "native-tls")
48))]
49pub use tokio_rustls::client::TlsStream;
50
/// Stream type returned by `establish_tls_connection` when the `ktls` feature
/// is enabled (and `native-tls` is not): the rustls session is converted by
/// `ktls::config_ktls_client` into a `ktls::KtlsStream` over the original
/// transport `S`.
#[cfg(all(feature = "ktls", not(feature = "native-tls")))]
pub type TlsStream<S> = ktls::KtlsStream<S>;
54
55#[cfg(feature = "native-tls")]
56pub use tokio_native_tls::TlsStream;
57
58#[cfg(feature = "native-tls")]
59use {native_tls::TlsConnector as NativeTlsConnector, tokio_native_tls::TlsConnector};
60
61use crate::{connect::ServerConnectorError, error::Error};
62use sasl::common::ChannelBinding;
63
/// Errors that can occur while establishing the TLS layer of a connection.
///
/// The set of variants depends on the enabled TLS features.
#[derive(Debug)]
pub enum TlsConnectorError {
    /// Error reported by the TLS library: `native_tls::Error` with the
    /// `native-tls` feature, otherwise `tokio_rustls::rustls::Error`.
    Tls(TlsError),
    /// The target domain could not be converted into a rustls server name.
    #[cfg(feature = "rustls-any-backend")]
    DnsNameError(InvalidDnsNameError),
    /// The `ktls` crate failed to configure kernel TLS on the established
    /// session.
    #[cfg(feature = "ktls")]
    KtlsError(ktls::Error),
}

// Lets this error be reported through the generic server-connector error path.
impl ServerConnectorError for TlsConnectorError {}
78
79impl fmt::Display for TlsConnectorError {
80 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
81 match self {
82 Self::Tls(e) => write!(fmt, "TLS error: {}", e),
83 #[cfg(feature = "rustls-any-backend")]
84 Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
85 #[cfg(feature = "ktls")]
86 Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
87 }
88 }
89}
90
91impl StdError for TlsConnectorError {}
92
93impl From<TlsError> for TlsConnectorError {
94 fn from(e: TlsError) -> Self {
95 Self::Tls(e)
96 }
97}
98
99#[cfg(feature = "rustls-any-backend")]
100impl From<InvalidDnsNameError> for TlsConnectorError {
101 fn from(e: InvalidDnsNameError) -> Self {
102 Self::DnsNameError(e)
103 }
104}
105
/// Establishes a TLS session over `stream` for `domain` using the platform
/// TLS library (via `native-tls`).
///
/// This backend does not support channel binding, so the returned
/// `ChannelBinding` is always `ChannelBinding::None` (a warning is logged).
///
/// # Errors
///
/// Returns `TlsConnectorError::Tls` (converted into the crate `Error`) if the
/// native TLS connector cannot be built or if the handshake fails.
#[cfg(feature = "native-tls")]
pub async fn establish_tls_connection<S: TlsAsyncStream>(
    stream: S,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
    // Building the connector is fallible (e.g. the platform trust store may
    // fail to load); report it as an error instead of panicking.
    let connector = NativeTlsConnector::builder()
        .build()
        .map_err(TlsConnectorError::Tls)?;
    let tls_stream = TlsConnector::from(connector)
        .connect(domain, stream)
        .await
        .map_err(TlsConnectorError::Tls)?;
    log::warn!(
        "tls-native doesn't support channel binding, please use tls-rust if you want this feature!"
    );
    Ok((tls_stream, ChannelBinding::None))
}
122
/// Establishes a TLS session over `stream` for `domain` using rustls.
///
/// The server certificate is verified against the roots gathered by
/// `build_root_store`. When TLS 1.3 is negotiated, a `tls-exporter` channel
/// binding is derived from the session; for any other protocol version
/// `ChannelBinding::None` is returned. With the `ktls` feature, the
/// established session is additionally converted to kernel TLS before being
/// returned.
///
/// # Errors
///
/// - `TlsConnectorError::DnsNameError` if `domain` is not a valid server name.
/// - `Error::Io` if the TLS handshake fails.
/// - `TlsConnectorError::Tls` if exporting the channel-binding data fails.
/// - `TlsConnectorError::KtlsError` (with `ktls`) if kernel TLS setup fails.
#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
pub async fn establish_tls_connection<S: TlsAsyncStream>(
    stream: S,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
    // rustls needs an owned (`'static`) server name for `connect`.
    let domain =
        ServerName::try_from(domain.to_owned()).map_err(TlsConnectorError::DnsNameError)?;

    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
    let mut config = ClientConfig::builder()
        .with_root_certificates(build_root_store())
        .with_no_client_auth();

    #[cfg(feature = "ktls")]
    let stream = {
        // The ktls hand-off below needs the session secrets, and the transport
        // is wrapped in `CorkStream` as required by `ktls::config_ktls_client`.
        config.enable_secret_extraction = true;
        ktls::CorkStream::new(stream)
    };

    let tls_stream = TlsConnector::from(Arc::new(config))
        .connect(domain, stream)
        .await
        .map_err(crate::Error::Io)?;

    let (_, connection) = tls_stream.get_ref();
    let channel_binding = match connection.protocol_version() {
        // RFC 9266 `tls-exporter`: 32 bytes exported with the label
        // "EXPORTER-Channel-Binding" and no context; only used for TLS 1.3.
        Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
            let data = connection
                .export_keying_material(vec![0u8; 32], b"EXPORTER-Channel-Binding", None)
                .map_err(TlsConnectorError::Tls)?;
            ChannelBinding::TlsExporter(data)
        }
        _ => ChannelBinding::None,
    };

    #[cfg(feature = "ktls")]
    let tls_stream = ktls::config_ktls_client(tls_stream)
        .await
        .map_err(TlsConnectorError::KtlsError)?;

    Ok((tls_stream, channel_binding))
}

/// Collects the trust anchors used to verify server certificates.
///
/// Roots come from `webpki-roots` and/or the platform store
/// (`rustls-native-certs`), depending on the enabled features. Problems
/// loading platform certificates are logged instead of being silently
/// dropped, and a warning is emitted when the resulting store is empty,
/// since no server certificate can then be verified.
#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
fn build_root_store() -> RootCertStore {
    #[allow(
        unused_mut,
        reason = "Only mutated when a root-certificate feature is enabled"
    )]
    let mut root_store = RootCertStore::empty();

    #[cfg(feature = "webpki-roots")]
    {
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    }

    #[cfg(feature = "rustls-native-certs")]
    {
        let native = rustls_native_certs::load_native_certs();
        for err in &native.errors {
            log::warn!("Failed to load native root certificates: {err}");
        }
        let (_added, ignored) = root_store.add_parsable_certificates(native.certs);
        if ignored > 0 {
            log::warn!("Ignored {ignored} unparsable native root certificate(s)");
        }
    }

    if root_store.is_empty() {
        log::warn!(
            "No TLS root certificates available; server certificate verification will fail"
        );
    }

    root_store
}