tokio_xmpp/connect/
tls_common.rs1use core::{error::Error as StdError, fmt};
10#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
11use std::os::fd::AsRawFd;
12use tokio::io::{AsyncRead, AsyncWrite};
13
14#[cfg(feature = "native-tls")]
15use native_tls::Error as TlsError;
16#[cfg(feature = "rustls-any-backend")]
17use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
18#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
19use tokio_rustls::rustls::Error as TlsError;
20
21#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
22use {
23 alloc::sync::Arc,
24 tokio_rustls::{
25 rustls::pki_types::ServerName,
26 rustls::{ClientConfig, RootCertStore},
27 TlsConnector,
28 },
29};
30
31#[cfg(all(
32 feature = "rustls-any-backend",
33 not(feature = "ktls"),
34 not(feature = "native-tls")
35))]
36pub use tokio_rustls::client::TlsStream;
37
38#[cfg(all(feature = "ktls", not(feature = "native-tls")))]
39pub type TlsStream<S> = ktls::KtlsStream<S>;
41
42#[cfg(feature = "native-tls")]
43pub use tokio_native_tls::TlsStream;
44
45#[cfg(feature = "native-tls")]
46use {native_tls::TlsConnector as NativeTlsConnector, tokio_native_tls::TlsConnector};
47
48use crate::{connect::ServerConnectorError, error::Error};
49use sasl::common::ChannelBinding;
50
51#[derive(Debug)]
53pub enum TlsConnectorError {
54 Tls(TlsError),
56 #[cfg(feature = "rustls-any-backend")]
57 DnsNameError(InvalidDnsNameError),
59 #[cfg(feature = "ktls")]
60 KtlsError(ktls::Error),
62}
63
64impl ServerConnectorError for TlsConnectorError {}
65
66impl fmt::Display for TlsConnectorError {
67 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
68 match self {
69 Self::Tls(e) => write!(fmt, "TLS error: {}", e),
70 #[cfg(feature = "rustls-any-backend")]
71 Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
72 #[cfg(feature = "ktls")]
73 Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
74 }
75 }
76}
77
78impl StdError for TlsConnectorError {}
79
80impl From<TlsError> for TlsConnectorError {
81 fn from(e: TlsError) -> Self {
82 Self::Tls(e)
83 }
84}
85
86#[cfg(feature = "rustls-any-backend")]
87impl From<InvalidDnsNameError> for TlsConnectorError {
88 fn from(e: InvalidDnsNameError) -> Self {
89 Self::DnsNameError(e)
90 }
91}
92
/// Performs a TLS handshake over `stream` using the platform TLS implementation
/// (through `native-tls`), with `domain` as the server name for the handshake.
///
/// `native-tls` offers no way to extract channel-binding data. The returned
/// [`ChannelBinding`] is therefore always [`ChannelBinding::None`], and a warning
/// is logged on every call.
///
/// # Errors
///
/// Returns an error built from [`TlsConnectorError::Tls`] when the native TLS
/// connector cannot be constructed or when the handshake fails.
#[cfg(feature = "native-tls")]
pub async fn establish_tls_connection<S>(
    stream: S,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Building the connector is fallible (the platform TLS backend may fail to
    // initialise). Report that to the caller instead of panicking.
    let connector = NativeTlsConnector::builder()
        .build()
        .map_err(TlsConnectorError::Tls)?;

    // `connect` borrows the domain, so no owned copy is needed.
    let tls_stream = TlsConnector::from(connector)
        .connect(domain, stream)
        .await
        .map_err(TlsConnectorError::Tls)?;

    log::warn!(
        "tls-native doesn't support channel binding, please use tls-rust if you want this feature!"
    );
    Ok((tls_stream, ChannelBinding::None))
}
112
/// Performs a TLS handshake over `stream` using rustls, with `domain` as the
/// server name, and extracts channel-binding data when the protocol allows it.
///
/// Trust anchors come from the bundled `webpki-roots` set and/or the platform
/// certificate store (`rustls-native-certs`), depending on the enabled features.
/// With the `ktls` feature, the established session is then handed to kernel TLS.
///
/// On a TLS 1.3 connection, the returned [`ChannelBinding`] is
/// [`ChannelBinding::TlsExporter`]. For any other protocol version it is
/// [`ChannelBinding::None`].
///
/// # Errors
///
/// Returns an error when:
/// - `domain` is not a valid server name;
/// - the TLS handshake fails (reported as an I/O error);
/// - keying material cannot be exported;
/// - with `ktls`, the kernel TLS setup fails.
#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
pub async fn establish_tls_connection<S>(
    stream: S,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error>
where
    S: AsyncRead + AsyncWrite + Unpin + AsRawFd,
{
    // `TlsConnector::connect` needs an owned `ServerName<'static>`, hence the copy.
    let domain =
        ServerName::try_from(domain.to_owned()).map_err(TlsConnectorError::DnsNameError)?;
    let mut root_store = RootCertStore::empty();

    #[cfg(feature = "webpki-roots")]
    {
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    }

    #[cfg(feature = "rustls-native-certs")]
    {
        let native_certs = rustls_native_certs::load_native_certs();
        // Loading is best-effort: keep whatever certificates were found, but
        // surface the failures instead of silently discarding them.
        for err in &native_certs.errors {
            log::warn!("Failed to load some native certificates: {}", err);
        }
        root_store.add_parsable_certificates(native_certs.certs);
    }

    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
    let mut config = ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();

    // kTLS needs the negotiated session secrets. The `ktls` crate also expects
    // the handshake to run over a `CorkStream`.
    #[cfg(feature = "ktls")]
    let stream = {
        config.enable_secret_extraction = true;
        ktls::CorkStream::new(stream)
    };

    let tls_stream = TlsConnector::from(Arc::new(config))
        .connect(domain, stream)
        .await
        .map_err(crate::Error::Io)?;

    // Read channel-binding data while the rustls connection is still
    // accessible, i.e. before the optional kTLS hand-off below.
    let (_, connection) = tls_stream.get_ref();
    // RFC 9266 `tls-exporter`: 32 bytes exported with the label
    // "EXPORTER-Channel-Binding" and no context. It is only offered for TLS 1.3.
    let channel_binding = match connection.protocol_version() {
        Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
            let data = connection
                .export_keying_material(vec![0u8; 32], b"EXPORTER-Channel-Binding", None)
                .map_err(TlsConnectorError::Tls)?;
            ChannelBinding::TlsExporter(data)
        }
        _ => ChannelBinding::None,
    };

    #[cfg(feature = "ktls")]
    let tls_stream = ktls::config_ktls_client(tls_stream)
        .await
        .map_err(TlsConnectorError::KtlsError)?;

    Ok((tls_stream, channel_binding))
}