tokio_xmpp/connect/
direct_tls.rs1use alloc::borrow::Cow;
4use core::{error::Error as StdError, fmt};
5#[cfg(feature = "native-tls")]
6use native_tls::Error as TlsError;
7#[cfg(feature = "rustls-any-backend")]
8use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
9#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
14use tokio_rustls::rustls::Error as TlsError;
15
16#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
17use {
18 alloc::sync::Arc,
19 tokio_rustls::{
20 rustls::pki_types::ServerName,
21 rustls::{ClientConfig, RootCertStore},
22 TlsConnector,
23 },
24};
25
26#[cfg(all(
27 feature = "rustls-any-backend",
28 not(feature = "ktls"),
29 not(feature = "native-tls")
30))]
31use tokio_rustls::client::TlsStream;
32
33#[cfg(all(feature = "ktls", not(feature = "native-tls")))]
34type TlsStream<S> = ktls::KtlsStream<S>;
35
36#[cfg(feature = "native-tls")]
37use {
38 native_tls::TlsConnector as NativeTlsConnector,
39 tokio_native_tls::{TlsConnector, TlsStream},
40};
41
42use sasl::common::ChannelBinding;
43use tokio::{io::BufStream, net::TcpStream};
44use xmpp_parsers::jid::Jid;
45
46use crate::{
47 connect::{DnsConfig, ServerConnector, ServerConnectorError},
48 error::Error,
49 xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts},
50};
51
52#[derive(Debug, Clone)]
54pub struct DirectTlsServerConnector(pub DnsConfig);
55
56impl From<DnsConfig> for DirectTlsServerConnector {
57 fn from(dns_config: DnsConfig) -> DirectTlsServerConnector {
58 Self(dns_config)
59 }
60}
61
62impl ServerConnector for DirectTlsServerConnector {
63 type Stream = BufStream<TlsStream<TcpStream>>;
64
65 async fn connect(
66 &self,
67 jid: &Jid,
68 ns: &'static str,
69 timeouts: Timeouts,
70 ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
71 let tcp_stream = self.0.resolve().await?;
72
73 let (tls_stream, channel_binding) =
75 establish_tls(tcp_stream, jid.domain().as_str()).await?;
76
77 Ok((
79 initiate_stream(
80 tokio::io::BufStream::new(tls_stream),
81 ns,
82 StreamHeader {
83 to: Some(Cow::Borrowed(jid.domain().as_str())),
84 from: Some(Cow::Borrowed(jid.to_bare().as_str())), id: None,
86 },
87 timeouts,
88 )
89 .await?,
90 channel_binding,
91 ))
92 }
93}
94
/// Performs the TLS handshake over `tcp_stream` using the `native-tls` backend.
///
/// The server certificate is checked against `domain`. This backend produces
/// no channel binding, so [`ChannelBinding::None`] is always returned and a
/// warning is logged on every connection.
///
/// # Errors
///
/// Returns [`DirectTlsError::Tls`], converted into [`Error`], when either the
/// TLS connector cannot be built or the handshake fails.
#[cfg(feature = "native-tls")]
async fn establish_tls(
    tcp_stream: TcpStream,
    domain: &str,
) -> Result<(TlsStream<TcpStream>, ChannelBinding), Error> {
    // `build()` is fallible; report the failure to the caller instead of
    // panicking inside library code.
    let native_connector = NativeTlsConnector::builder()
        .build()
        .map_err(DirectTlsError::Tls)?;
    let tls_stream = TlsConnector::from(native_connector)
        .connect(domain, tcp_stream)
        .await
        .map_err(DirectTlsError::Tls)?;
    log::warn!(
        "tls-native doesn't support channel binding, please use tls-rust if you want this feature!"
    );
    Ok((tls_stream, ChannelBinding::None))
}
110
/// Performs the TLS handshake over `tcp_stream` using rustls.
///
/// Trust anchors come from the `webpki-roots` and/or `rustls-native-certs`
/// features. If neither is enabled, the root store stays empty.
///
/// When TLS 1.3 was negotiated, a 32-byte `tls-exporter` channel binding is
/// exported: label `EXPORTER-Channel-Binding`, no context. For any other
/// protocol version, [`ChannelBinding::None`] is returned.
///
/// With the `ktls` feature, the session is handed to kernel TLS. This happens
/// only after the keying material has been exported.
///
/// # Errors
///
/// - [`DirectTlsError::DnsNameError`] if `domain` is not a valid server name.
/// - [`Error::Io`] if the handshake fails.
/// - [`DirectTlsError::Tls`] if exporting the keying material fails.
/// - [`DirectTlsError::KtlsError`] (with `ktls`) if the kTLS hand-off fails.
/// - With `rustls-native-certs`, any error from loading the native certificates.
#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
async fn establish_tls(
    tcp_stream: TcpStream,
    domain: &str,
) -> Result<(TlsStream<TcpStream>, ChannelBinding), Error> {
    let server_name =
        ServerName::try_from(domain.to_owned()).map_err(DirectTlsError::DnsNameError)?;

    let mut trust_anchors = RootCertStore::empty();
    #[cfg(feature = "webpki-roots")]
    {
        trust_anchors.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    }
    #[cfg(feature = "rustls-native-certs")]
    {
        trust_anchors.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
    }

    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
    let mut client_config = ClientConfig::builder()
        .with_root_certificates(trust_anchors)
        .with_no_client_auth();

    // kTLS needs the session secrets after the handshake, and
    // `ktls::config_ktls_client` expects the socket wrapped in `CorkStream`.
    #[cfg(feature = "ktls")]
    let tcp_stream = {
        client_config.enable_secret_extraction = true;
        ktls::CorkStream::new(tcp_stream)
    };

    let connector = TlsConnector::from(Arc::new(client_config));
    let tls_stream = connector
        .connect(server_name, tcp_stream)
        .await
        .map_err(crate::Error::Io)?;

    // Export the binding while rustls still owns the session; the borrow of
    // `tls_stream` ends with this block, before any kTLS hand-off.
    let channel_binding = {
        let (_, session) = tls_stream.get_ref();
        if matches!(
            session.protocol_version(),
            Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3)
        ) {
            let exported = session
                .export_keying_material(vec![0u8; 32], b"EXPORTER-Channel-Binding", None)
                .map_err(DirectTlsError::Tls)?;
            ChannelBinding::TlsExporter(exported)
        } else {
            ChannelBinding::None
        }
    };

    #[cfg(feature = "ktls")]
    let tls_stream = ktls::config_ktls_client(tls_stream)
        .await
        .map_err(DirectTlsError::KtlsError)?;

    Ok((tls_stream, channel_binding))
}
160
161#[derive(Debug)]
163pub enum DirectTlsError {
164 Tls(TlsError),
166 #[cfg(feature = "rustls-any-backend")]
167 DnsNameError(InvalidDnsNameError),
169 #[cfg(feature = "ktls")]
170 KtlsError(ktls::Error),
172}
173
174impl ServerConnectorError for DirectTlsError {}
175
176impl fmt::Display for DirectTlsError {
177 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
178 match self {
179 Self::Tls(e) => write!(fmt, "TLS error: {}", e),
180 #[cfg(feature = "rustls-any-backend")]
181 Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
182 #[cfg(feature = "ktls")]
183 Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
184 }
185 }
186}
187
188impl StdError for DirectTlsError {}
189
190impl From<TlsError> for DirectTlsError {
191 fn from(e: TlsError) -> Self {
192 Self::Tls(e)
193 }
194}
195
#[cfg(feature = "rustls-any-backend")]
impl From<InvalidDnsNameError> for DirectTlsError {
    /// Wraps an invalid server name as [`DirectTlsError::DnsNameError`].
    fn from(source: InvalidDnsNameError) -> DirectTlsError {
        DirectTlsError::DnsNameError(source)
    }
}