tokio_xmpp/connect/
tls_common.rs1use core::{error::Error as StdError, fmt};
10use std::os::fd::AsRawFd;
11use tokio::io::{AsyncRead, AsyncWrite};
12
13#[cfg(feature = "native-tls")]
14use native_tls::Error as TlsError;
15#[cfg(feature = "rustls-any-backend")]
16use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
17#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
18use tokio_rustls::rustls::Error as TlsError;
19
20#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
21use {
22 alloc::sync::Arc,
23 tokio_rustls::{
24 rustls::pki_types::ServerName,
25 rustls::{ClientConfig, RootCertStore},
26 TlsConnector,
27 },
28};
29
30#[cfg(all(
31 feature = "rustls-any-backend",
32 not(feature = "ktls"),
33 not(feature = "native-tls")
34))]
35pub use tokio_rustls::client::TlsStream;
36
37#[cfg(all(feature = "ktls", not(feature = "native-tls")))]
38pub type TlsStream<S> = ktls::KtlsStream<S>;
39
40#[cfg(feature = "native-tls")]
41pub use tokio_native_tls::TlsStream;
42
43#[cfg(feature = "native-tls")]
44use {native_tls::TlsConnector as NativeTlsConnector, tokio_native_tls::TlsConnector};
45
46use crate::{connect::ServerConnectorError, error::Error};
47use sasl::common::ChannelBinding;
48
49#[derive(Debug)]
51pub enum TlsConnectorError {
52 Tls(TlsError),
54 #[cfg(feature = "rustls-any-backend")]
55 DnsNameError(InvalidDnsNameError),
57 #[cfg(feature = "ktls")]
58 KtlsError(ktls::Error),
60}
61
// Marker impl: no trait items are provided here, so everything comes from
// whatever `ServerConnectorError` (imported from `crate::connect`) supplies.
62impl ServerConnectorError for TlsConnectorError {}
63
64impl fmt::Display for TlsConnectorError {
65 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
66 match self {
67 Self::Tls(e) => write!(fmt, "TLS error: {}", e),
68 #[cfg(feature = "rustls-any-backend")]
69 Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
70 #[cfg(feature = "ktls")]
71 Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
72 }
73 }
74}
75
// Empty impl: only the trait's default methods are used, so the wrapped error
// is surfaced through the `Display` text rather than through `source()`.
76impl StdError for TlsConnectorError {}
77
78impl From<TlsError> for TlsConnectorError {
79 fn from(e: TlsError) -> Self {
80 Self::Tls(e)
81 }
82}
83
/// Wraps an invalid-server-name error as [`TlsConnectorError::DnsNameError`].
#[cfg(feature = "rustls-any-backend")]
impl From<InvalidDnsNameError> for TlsConnectorError {
    fn from(name_err: InvalidDnsNameError) -> Self {
        TlsConnectorError::DnsNameError(name_err)
    }
}
90
/// Establishes a TLS session over `stream` using the `native-tls` backend.
///
/// `domain` is passed to the connector as the name of the server being
/// connected to.
///
/// On success, returns the encrypted stream together with
/// [`ChannelBinding::None`]: no channel-binding data is produced with this
/// backend, and a warning is logged to say so.
///
/// # Errors
///
/// Returns an error if the TLS connector cannot be built or if the TLS
/// handshake fails. Both cases are reported as [`TlsConnectorError::Tls`].
#[cfg(feature = "native-tls")]
pub async fn establish_tls_connection<S>(
    stream: S,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // `build()` is fallible; report the failure to the caller instead of
    // panicking inside library code.
    let connector = NativeTlsConnector::builder()
        .build()
        .map(TlsConnector::from)
        .map_err(TlsConnectorError::Tls)?;

    // `connect` only needs a borrowed name, so no owned copy of `domain` is made.
    let tls_stream = connector
        .connect(domain, stream)
        .await
        .map_err(TlsConnectorError::Tls)?;

    log::warn!(
        "tls-native doesn't support channel binding, please use tls-rust if you want this feature!"
    );
    Ok((tls_stream, ChannelBinding::None))
}
110
/// Establishes a TLS session over `stream` using the rustls backend.
///
/// Trust anchors are gathered from whichever of the `webpki-roots` and
/// `rustls-native-certs` features are enabled. With the `ktls` feature, the
/// stream is wrapped in a `ktls::CorkStream` before the handshake and the
/// finished session is passed to `ktls::config_ktls_client` before being
/// returned.
///
/// On success, returns the encrypted stream and the channel binding to use:
/// a 32-byte `EXPORTER-Channel-Binding` keying-material export when TLS 1.3
/// was negotiated, otherwise [`ChannelBinding::None`].
///
/// # Errors
///
/// Fails if `domain` is not a valid server name, if the TLS handshake fails,
/// if the keying material cannot be exported, or (with `ktls`) if the kTLS
/// setup fails.
#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
pub async fn establish_tls_connection<S>(
    stream: S,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error>
where
    S: AsyncRead + AsyncWrite + Unpin + AsRawFd,
{
    let server_name =
        ServerName::try_from(domain.to_owned()).map_err(TlsConnectorError::DnsNameError)?;

    // Collect trust anchors from every certificate source enabled at build time.
    let mut trust_anchors = RootCertStore::empty();
    #[cfg(feature = "webpki-roots")]
    {
        trust_anchors.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    }
    #[cfg(feature = "rustls-native-certs")]
    {
        trust_anchors.add_parsable_certificates(rustls_native_certs::load_native_certs().certs);
    }

    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
    let mut client_config = ClientConfig::builder()
        .with_root_certificates(trust_anchors)
        .with_no_client_auth();

    // With ktls, secret extraction is switched on in the config and the
    // stream is wrapped before the handshake starts.
    #[cfg(feature = "ktls")]
    let stream = {
        client_config.enable_secret_extraction = true;
        ktls::CorkStream::new(stream)
    };

    let connector = TlsConnector::from(Arc::new(client_config));
    let tls_stream = connector
        .connect(server_name, stream)
        .await
        .map_err(crate::Error::Io)?;

    // The channel binding is derived here, before the session is handed to
    // `ktls::config_ktls_client` below.
    let (_, session) = tls_stream.get_ref();
    let negotiated_tls13 = matches!(
        session.protocol_version(),
        Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3)
    );
    let channel_binding = if negotiated_tls13 {
        let exported = session
            .export_keying_material(vec![0u8; 32], b"EXPORTER-Channel-Binding", None)
            .map_err(TlsConnectorError::Tls)?;
        ChannelBinding::TlsExporter(exported)
    } else {
        ChannelBinding::None
    };

    #[cfg(feature = "ktls")]
    let tls_stream = ktls::config_ktls_client(tls_stream)
        .await
        .map_err(TlsConnectorError::KtlsError)?;

    Ok((tls_stream, channel_binding))
}