1use alloc::borrow::Cow;
4use core::{error::Error as StdError, fmt};
5#[cfg(feature = "native-tls")]
6use native_tls::Error as TlsError;
7use std::io;
8use std::os::fd::AsRawFd;
9#[cfg(feature = "rustls-any-backend")]
10use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
11#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
16use tokio_rustls::rustls::Error as TlsError;
17
18use futures::{sink::SinkExt, stream::StreamExt};
19
20#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
21use {
22 alloc::sync::Arc,
23 tokio_rustls::{
24 rustls::pki_types::ServerName,
25 rustls::{ClientConfig, RootCertStore},
26 TlsConnector,
27 },
28};
29
30#[cfg(all(
31 feature = "rustls-any-backend",
32 not(feature = "ktls"),
33 not(feature = "native-tls")
34))]
35use tokio_rustls::client::TlsStream;
36
/// TLS stream type used by this module when the `ktls` feature is enabled and
/// `native-tls` is not: an alias for the stream type provided by the `ktls` crate.
#[cfg(all(feature = "ktls", not(feature = "native-tls")))]
type TlsStream<S> = ktls::KtlsStream<S>;
39
40#[cfg(feature = "native-tls")]
41use {
42 native_tls::TlsConnector as NativeTlsConnector,
43 tokio_native_tls::{TlsConnector, TlsStream},
44};
45
46use sasl::common::ChannelBinding;
47use tokio::{
48 io::{AsyncRead, AsyncWrite, BufStream},
49 net::TcpStream,
50};
51use xmpp_parsers::{
52 jid::Jid,
53 starttls::{self, Request},
54};
55
56use crate::{
57 connect::{DnsConfig, ServerConnector, ServerConnectorError},
58 error::{Error, ProtocolError},
59 xmlstream::{
60 initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, Timeouts, XmppStream,
61 XmppStreamElement,
62 },
63 Client,
64};
65
/// Backwards-compatible alias for [`Client`].
///
/// Deprecated since 5.0.0; use `tokio_xmpp::Client` instead.
#[deprecated(since = "5.0.0", note = "use tokio_xmpp::Client instead")]
pub type StartTlsClient = Client;
69
/// [`ServerConnector`] that connects over TCP and upgrades the connection to
/// TLS using XMPP StartTLS.
///
/// Wraps the [`DnsConfig`] that is resolved to open the initial TCP connection.
#[derive(Debug, Clone)]
pub struct StartTlsServerConnector(pub DnsConfig);
73
74impl From<DnsConfig> for StartTlsServerConnector {
75 fn from(dns_config: DnsConfig) -> StartTlsServerConnector {
76 Self(dns_config)
77 }
78}
79
80impl ServerConnector for StartTlsServerConnector {
81 type Stream = BufStream<TlsStream<TcpStream>>;
82
83 async fn connect(
84 &self,
85 jid: &Jid,
86 ns: &'static str,
87 timeouts: Timeouts,
88 ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
89 let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
90
91 let xmpp_stream = initiate_stream(
93 tcp_stream,
94 ns,
95 StreamHeader {
96 to: Some(Cow::Borrowed(jid.domain().as_str())),
97 from: None,
98 id: None,
99 },
100 timeouts,
101 )
102 .await?;
103 let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
104
105 if features.can_starttls() {
106 let (tls_stream, channel_binding) =
108 starttls(xmpp_stream, jid.domain().as_str()).await?;
109 Ok((
111 initiate_stream(
112 tokio::io::BufStream::new(tls_stream),
113 ns,
114 StreamHeader {
115 to: Some(Cow::Borrowed(jid.domain().as_str())),
116 from: None,
117 id: None,
118 },
119 timeouts,
120 )
121 .await?,
122 channel_binding,
123 ))
124 } else {
125 Err(crate::Error::Protocol(ProtocolError::NoTls))
126 }
127 }
128}
129
/// Performs the TLS handshake over the raw transport of `xmpp_stream` using
/// the native-tls backend.
///
/// `domain` is the server name passed to the TLS connector. The returned
/// channel binding is always [`ChannelBinding::None`] with this backend.
///
/// # Errors
///
/// Returns [`StartTlsError::Tls`] (converted into [`Error`]) when the TLS
/// connector cannot be built or when the handshake fails.
#[cfg(feature = "native-tls")]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
    xmpp_stream: XmppStream<BufStream<S>>,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
    // Strip the XMPP framing and the buffering layer to get the raw transport.
    let stream = xmpp_stream.into_inner().into_inner();

    // Building the connector can fail (backend initialisation); report that as
    // a TLS error instead of panicking inside the connection routine.
    let connector = NativeTlsConnector::builder()
        .build()
        .map(TlsConnector::from)
        .map_err(StartTlsError::Tls)?;

    let tls_stream = connector
        .connect(domain, stream)
        .await
        .map_err(StartTlsError::Tls)?;
    log::warn!(
        "tls-native doesn’t support channel binding, please use tls-rust if you want this feature!"
    );
    Ok((tls_stream, ChannelBinding::None))
}
146
/// Performs the TLS handshake over the raw transport of `xmpp_stream` using
/// rustls, optionally handing the session over to kernel TLS.
///
/// `domain` is parsed into the server name used for the handshake. Root
/// certificates come from `webpki-roots` and/or the platform store, depending
/// on the enabled features.
///
/// Returns the TLS stream and, for TLS 1.3 sessions, the exported
/// channel-binding data; other protocol versions yield
/// [`ChannelBinding::None`].
///
/// # Errors
///
/// Fails when `domain` is not a valid server name, when the handshake fails,
/// when keying material cannot be exported, or (with `ktls`) when the kernel
/// TLS setup fails.
#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
    xmpp_stream: XmppStream<BufStream<S>>,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
    let server_name =
        ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?;
    // Strip the XMPP framing and the buffering layer to get the raw transport.
    let socket = xmpp_stream.into_inner().into_inner();

    let mut trust_anchors = RootCertStore::empty();
    #[cfg(feature = "webpki-roots")]
    {
        trust_anchors.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    }
    #[cfg(feature = "rustls-native-certs")]
    {
        trust_anchors.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
    }

    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
    let mut client_config = ClientConfig::builder()
        .with_root_certificates(trust_anchors)
        .with_no_client_auth();
    // With ktls, enable secret extraction on the config and wrap the socket
    // in the ktls cork stream before the handshake.
    #[cfg(feature = "ktls")]
    let socket = {
        client_config.enable_secret_extraction = true;
        ktls::CorkStream::new(socket)
    };

    let connector = TlsConnector::from(Arc::new(client_config));
    let tls_stream = connector
        .connect(server_name, socket)
        .await
        .map_err(crate::Error::Io)?;

    // Extract the channel-binding data while the rustls connection is still
    // accessible, i.e. before the stream is handed over to ktls below.
    let (_, connection) = tls_stream.get_ref();
    let channel_binding = if let Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) =
        connection.protocol_version()
    {
        let exported = connection
            .export_keying_material(vec![0u8; 32], b"EXPORTER-Channel-Binding", None)
            .map_err(StartTlsError::Tls)?;
        ChannelBinding::TlsExporter(exported)
    } else {
        ChannelBinding::None
    };

    #[cfg(feature = "ktls")]
    let tls_stream = ktls::config_ktls_client(tls_stream)
        .await
        .map_err(StartTlsError::KtlsError)?;
    Ok((tls_stream, channel_binding))
}
197
198pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
201 mut stream: XmppStream<BufStream<S>>,
202 domain: &str,
203) -> Result<(TlsStream<S>, ChannelBinding), Error> {
204 stream
205 .send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
206 Request,
207 )))
208 .await?;
209
210 loop {
211 match stream.next().await {
212 Some(Ok(XmppStreamElement::Starttls(starttls::Nonza::Proceed(_)))) => {
213 break;
214 }
215 Some(Ok(_)) => (),
216 Some(Err(ReadError::SoftTimeout)) => (),
217 Some(Err(ReadError::HardError(e))) => return Err(e.into()),
218 Some(Err(ReadError::ParseError(e))) => {
219 return Err(io::Error::new(io::ErrorKind::InvalidData, e).into())
220 }
221 None | Some(Err(ReadError::StreamFooterReceived)) => {
222 return Err(crate::Error::Disconnected)
223 }
224 }
225 }
226
227 get_tls_stream(stream, domain).await
228}
229
/// Errors produced by the StartTLS connector.
#[derive(Debug)]
pub enum StartTlsError {
    /// Error reported by the TLS backend.
    Tls(TlsError),
    /// The server domain could not be converted into a TLS server name.
    #[cfg(feature = "rustls-any-backend")]
    DnsNameError(InvalidDnsNameError),
    /// Error while setting up kernel TLS.
    #[cfg(feature = "ktls")]
    KtlsError(ktls::Error),
}
242
// Opts `StartTlsError` into the crate's `ServerConnectorError` trait. No items
// are provided here, so whatever defaults the trait defines apply.
impl ServerConnectorError for StartTlsError {}
244
245impl fmt::Display for StartTlsError {
246 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
247 match self {
248 Self::Tls(e) => write!(fmt, "TLS error: {}", e),
249 #[cfg(feature = "rustls-any-backend")]
250 Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
251 #[cfg(feature = "ktls")]
252 Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
253 }
254 }
255}
256
// Uses the default `Error` methods only: no `source()` is provided, so the
// wrapped error is exposed solely through the `Display` output.
impl StdError for StartTlsError {}
258
259impl From<TlsError> for StartTlsError {
260 fn from(e: TlsError) -> Self {
261 Self::Tls(e)
262 }
263}
264
#[cfg(feature = "rustls-any-backend")]
impl From<InvalidDnsNameError> for StartTlsError {
    /// Wraps an invalid-DNS-name error in [`StartTlsError::DnsNameError`].
    fn from(err: InvalidDnsNameError) -> Self {
        StartTlsError::DnsNameError(err)
    }
}
270}