1use alloc::borrow::Cow;
4use core::{error::Error as StdError, fmt};
5#[cfg(feature = "tls-native")]
6use native_tls::Error as TlsError;
7use std::io;
8use std::os::fd::AsRawFd;
9#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
10use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
11#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
12use tokio_rustls::rustls::Error as TlsError;
13
14use futures::{sink::SinkExt, stream::StreamExt};
15
16#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
17use {
18 alloc::sync::Arc,
19 tokio_rustls::{
20 rustls::pki_types::ServerName,
21 rustls::{ClientConfig, RootCertStore},
22 TlsConnector,
23 },
24};
25
26#[cfg(all(
27 feature = "tls-rust",
28 not(feature = "tls-native"),
29 not(feature = "tls-rust-ktls")
30))]
31use tokio_rustls::client::TlsStream;
32
33#[cfg(all(feature = "tls-rust-ktls", not(feature = "tls-native")))]
34type TlsStream<S> = ktls::KtlsStream<S>;
35
36#[cfg(feature = "tls-native")]
37use {
38 native_tls::TlsConnector as NativeTlsConnector,
39 tokio_native_tls::{TlsConnector, TlsStream},
40};
41
42use sasl::common::ChannelBinding;
43use tokio::{
44 io::{AsyncRead, AsyncWrite, BufStream},
45 net::TcpStream,
46};
47use xmpp_parsers::{
48 jid::Jid,
49 starttls::{self, Request},
50};
51
52use crate::{
53 connect::{DnsConfig, ServerConnector, ServerConnectorError},
54 error::{Error, ProtocolError},
55 xmlstream::{
56 initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, Timeouts, XmppStream,
57 XmppStreamElement,
58 },
59 Client,
60};
61
/// Former name of the StartTLS client type.
///
/// Kept only so existing code keeps compiling; it is a plain alias for
/// [`Client`].
#[deprecated(since = "5.0.0", note = "use tokio_xmpp::Client instead")]
pub type StartTlsClient = Client;
65
66#[derive(Debug, Clone)]
68pub struct StartTlsServerConnector(pub DnsConfig);
69
70impl From<DnsConfig> for StartTlsServerConnector {
71 fn from(dns_config: DnsConfig) -> StartTlsServerConnector {
72 Self(dns_config)
73 }
74}
75
76impl ServerConnector for StartTlsServerConnector {
77 type Stream = BufStream<TlsStream<TcpStream>>;
78
79 async fn connect(
80 &self,
81 jid: &Jid,
82 ns: &'static str,
83 timeouts: Timeouts,
84 ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
85 let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
86
87 let xmpp_stream = initiate_stream(
89 tcp_stream,
90 ns,
91 StreamHeader {
92 to: Some(Cow::Borrowed(jid.domain().as_str())),
93 from: None,
94 id: None,
95 },
96 timeouts,
97 )
98 .await?;
99 let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
100
101 if features.can_starttls() {
102 let (tls_stream, channel_binding) =
104 starttls(xmpp_stream, jid.domain().as_str()).await?;
105 Ok((
107 initiate_stream(
108 tokio::io::BufStream::new(tls_stream),
109 ns,
110 StreamHeader {
111 to: Some(Cow::Borrowed(jid.domain().as_str())),
112 from: None,
113 id: None,
114 },
115 timeouts,
116 )
117 .await?,
118 channel_binding,
119 ))
120 } else {
121 Err(crate::Error::Protocol(ProtocolError::NoTls).into())
122 }
123 }
124}
125
/// Upgrade the transport underneath `xmpp_stream` to TLS using `native-tls`.
///
/// The XMPP stream wrapper and its `BufStream` are unwrapped to recover the
/// raw transport `S`, which is handed to a default `native-tls` connector;
/// `domain` is the server name passed to the handshake.
///
/// This backend provides no channel binding data, so
/// [`ChannelBinding::None`] is always returned (and a warning is logged).
///
/// # Errors
///
/// Returns [`StartTlsError::Tls`] (converted into [`Error`]) if the TLS
/// connector cannot be built or the handshake fails.
#[cfg(feature = "tls-native")]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
    xmpp_stream: XmppStream<BufStream<S>>,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
    let stream = xmpp_stream.into_inner().into_inner();
    // Building the connector is fallible; report the failure to the caller
    // instead of panicking inside the library.
    let connector = NativeTlsConnector::builder()
        .build()
        .map_err(StartTlsError::Tls)?;
    let tls_stream = TlsConnector::from(connector)
        .connect(domain, stream)
        .await
        .map_err(StartTlsError::Tls)?;
    log::warn!(
        "tls-native doesn’t support channel binding, please use tls-rust if you want this feature!"
    );
    Ok((tls_stream, ChannelBinding::None))
}
142
/// Upgrade the transport underneath `xmpp_stream` to TLS using `rustls`.
///
/// The server certificate is checked for `domain` against a root store filled
/// from whichever of the `webpki-roots` / `rustls-native-certs` features are
/// enabled. With the `tls-rust-ktls` feature, the established session is then
/// passed to `ktls::config_ktls_client`.
///
/// Channel binding: on TLS 1.3 connections, 32 bytes of keying material are
/// exported with the label `EXPORTER-Channel-Binding` and no context, and
/// returned as [`ChannelBinding::TlsExporter`] (the `tls-exporter` binding of
/// RFC 9266). Any other protocol version yields [`ChannelBinding::None`].
///
/// # Errors
///
/// * [`StartTlsError::DnsNameError`] if `domain` is not a valid server name;
/// * an I/O error if the TLS handshake fails;
/// * [`StartTlsError::Tls`] if the keying material cannot be exported;
/// * [`StartTlsError::KtlsError`] if the kernel TLS setup fails.
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
    xmpp_stream: XmppStream<BufStream<S>>,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
    let domain = ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?;
    // Strip the XMPP framing and the buffering layer to reach the raw transport.
    let stream = xmpp_stream.into_inner().into_inner();

    let mut root_store = RootCertStore::empty();
    #[cfg(feature = "webpki-roots")]
    {
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    }
    #[cfg(feature = "rustls-native-certs")]
    {
        root_store.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
    }
    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
    let mut config = ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();
    #[cfg(feature = "tls-rust-ktls")]
    let stream = {
        // Kernel TLS needs the session secrets after the handshake.
        config.enable_secret_extraction = true;
        ktls::CorkStream::new(stream)
    };
    let tls_stream = TlsConnector::from(Arc::new(config))
        .connect(domain, stream)
        .await
        .map_err(crate::Error::Io)?;

    let (_, connection) = tls_stream.get_ref();
    let channel_binding = match connection.protocol_version() {
        Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
            let data = connection
                .export_keying_material(vec![0u8; 32], b"EXPORTER-Channel-Binding", None)
                .map_err(StartTlsError::Tls)?;
            ChannelBinding::TlsExporter(data)
        }
        _ => ChannelBinding::None,
    };

    #[cfg(feature = "tls-rust-ktls")]
    let tls_stream = ktls::config_ktls_client(tls_stream)
        .await
        .map_err(StartTlsError::KtlsError)?;
    Ok((tls_stream, channel_binding))
}
193
194pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
197 mut stream: XmppStream<BufStream<S>>,
198 domain: &str,
199) -> Result<(TlsStream<S>, ChannelBinding), Error> {
200 stream
201 .send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
202 Request,
203 )))
204 .await?;
205
206 loop {
207 match stream.next().await {
208 Some(Ok(XmppStreamElement::Starttls(starttls::Nonza::Proceed(_)))) => {
209 break;
210 }
211 Some(Ok(_)) => (),
212 Some(Err(ReadError::SoftTimeout)) => (),
213 Some(Err(ReadError::HardError(e))) => return Err(e.into()),
214 Some(Err(ReadError::ParseError(e))) => {
215 return Err(io::Error::new(io::ErrorKind::InvalidData, e).into())
216 }
217 None | Some(Err(ReadError::StreamFooterReceived)) => {
218 return Err(crate::Error::Disconnected)
219 }
220 }
221 }
222
223 get_tls_stream(stream, domain).await
224}
225
/// Errors that can occur while upgrading a connection with STARTTLS.
#[derive(Debug)]
pub enum StartTlsError {
    /// Error reported by the TLS backend in use (`native-tls` or `rustls`).
    Tls(TlsError),
    /// The server domain could not be converted into a TLS server name
    /// (rustls backend only).
    #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
    DnsNameError(InvalidDnsNameError),
    /// Error returned by `ktls` while configuring kernel TLS on the session.
    #[cfg(feature = "tls-rust-ktls")]
    KtlsError(ktls::Error),
}
238
239impl ServerConnectorError for StartTlsError {}
240
241impl fmt::Display for StartTlsError {
242 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
243 match self {
244 Self::Tls(e) => write!(fmt, "TLS error: {}", e),
245 #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
246 Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
247 #[cfg(feature = "tls-rust-ktls")]
248 Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
249 }
250 }
251}
252
253impl StdError for StartTlsError {}
254
255impl From<TlsError> for StartTlsError {
256 fn from(e: TlsError) -> Self {
257 Self::Tls(e)
258 }
259}
260
261#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
262impl From<InvalidDnsNameError> for StartTlsError {
263 fn from(e: InvalidDnsNameError) -> Self {
264 Self::DnsNameError(e)
265 }
266}