tokio_xmpp/connect/
starttls.rs

//! `StartTlsServerConnector` provides a `ServerConnector` for StartTLS connections.
2
3use alloc::borrow::Cow;
4use core::{error::Error as StdError, fmt};
5#[cfg(feature = "tls-native")]
6use native_tls::Error as TlsError;
7use std::io;
8use std::os::fd::AsRawFd;
9#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
10use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
11#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
12use tokio_rustls::rustls::Error as TlsError;
13
14use futures::{sink::SinkExt, stream::StreamExt};
15
16#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
17use {
18    alloc::sync::Arc,
19    tokio_rustls::{
20        rustls::pki_types::ServerName,
21        rustls::{ClientConfig, RootCertStore},
22        TlsConnector,
23    },
24};
25
26#[cfg(all(
27    feature = "tls-rust",
28    not(feature = "tls-native"),
29    not(feature = "tls-rust-ktls")
30))]
31use tokio_rustls::client::TlsStream;
32
33#[cfg(all(feature = "tls-rust-ktls", not(feature = "tls-native")))]
34type TlsStream<S> = ktls::KtlsStream<S>;
35
36#[cfg(feature = "tls-native")]
37use {
38    native_tls::TlsConnector as NativeTlsConnector,
39    tokio_native_tls::{TlsConnector, TlsStream},
40};
41
42use sasl::common::ChannelBinding;
43use tokio::{
44    io::{AsyncRead, AsyncWrite, BufStream},
45    net::TcpStream,
46};
47use xmpp_parsers::{
48    jid::Jid,
49    starttls::{self, Request},
50};
51
52use crate::{
53    connect::{DnsConfig, ServerConnector, ServerConnectorError},
54    error::{Error, ProtocolError},
55    xmlstream::{
56        initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, Timeouts, XmppStream,
57        XmppStreamElement,
58    },
59    Client,
60};
61
62/// Client that connects over StartTls
63#[deprecated(since = "5.0.0", note = "use tokio_xmpp::Client instead")]
64pub type StartTlsClient = Client;
65
66/// Connect via TCP+StartTLS to an XMPP server
67#[derive(Debug, Clone)]
68pub struct StartTlsServerConnector(pub DnsConfig);
69
70impl From<DnsConfig> for StartTlsServerConnector {
71    fn from(dns_config: DnsConfig) -> StartTlsServerConnector {
72        Self(dns_config)
73    }
74}
75
76impl ServerConnector for StartTlsServerConnector {
77    type Stream = BufStream<TlsStream<TcpStream>>;
78
79    async fn connect(
80        &self,
81        jid: &Jid,
82        ns: &'static str,
83        timeouts: Timeouts,
84    ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
85        let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
86
87        // Unencryped XmppStream
88        let xmpp_stream = initiate_stream(
89            tcp_stream,
90            ns,
91            StreamHeader {
92                to: Some(Cow::Borrowed(jid.domain().as_str())),
93                from: None,
94                id: None,
95            },
96            timeouts,
97        )
98        .await?;
99        let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
100
101        if features.can_starttls() {
102            // TlsStream
103            let (tls_stream, channel_binding) =
104                starttls(xmpp_stream, jid.domain().as_str()).await?;
105            // Encrypted XmppStream
106            Ok((
107                initiate_stream(
108                    tokio::io::BufStream::new(tls_stream),
109                    ns,
110                    StreamHeader {
111                        to: Some(Cow::Borrowed(jid.domain().as_str())),
112                        from: None,
113                        id: None,
114                    },
115                    timeouts,
116                )
117                .await?,
118                channel_binding,
119            ))
120        } else {
121            Err(crate::Error::Protocol(ProtocolError::NoTls).into())
122        }
123    }
124}
125
/// Performs the TLS handshake with `native-tls` on the raw transport of `xmpp_stream`.
///
/// The XMPP and buffering layers are stripped off with `into_inner`, and the inner
/// stream is handed to the TLS connector. This backend never provides channel
/// binding, so [`ChannelBinding::None`] is always returned (and a warning is logged).
///
/// # Errors
///
/// Returns [`StartTlsError::Tls`] (converted into [`Error`]) when either of these fails:
/// - building the native TLS connector;
/// - the handshake itself.
#[cfg(feature = "tls-native")]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
    xmpp_stream: XmppStream<BufStream<S>>,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
    let stream = xmpp_stream.into_inner().into_inner();
    // Building the connector can fail (e.g. backend initialisation); report it
    // as a TLS error instead of panicking.
    let connector = NativeTlsConnector::builder()
        .build()
        .map_err(StartTlsError::Tls)?;
    let tls_stream = TlsConnector::from(connector)
        .connect(domain, stream)
        .await
        .map_err(StartTlsError::Tls)?;
    log::warn!(
        "tls-native doesn’t support channel binding, please use tls-rust if you want this feature!"
    );
    Ok((tls_stream, ChannelBinding::None))
}
142
/// Performs the TLS handshake with `rustls` on the raw transport of `xmpp_stream`.
///
/// Trust anchors come from the `webpki-roots` and/or `rustls-native-certs`
/// features.
///
/// On TLS 1.3 the `tls-exporter` channel binding is derived from the session:
/// 32 bytes of keying material with label `EXPORTER-Channel-Binding` and no
/// context. Other protocol versions yield [`ChannelBinding::None`].
///
/// With `tls-rust-ktls` enabled, the handshake runs over a `ktls::CorkStream`
/// with secret extraction enabled. The established session is then handed to
/// kernel TLS, after the channel binding has been extracted.
///
/// # Errors
///
/// - [`StartTlsError::DnsNameError`] if `domain` is not a valid server name.
/// - An I/O [`Error`] if the handshake fails.
/// - [`StartTlsError::Tls`] if exporting keying material fails.
/// - [`StartTlsError::KtlsError`] if kernel TLS cannot be configured.
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
    xmpp_stream: XmppStream<BufStream<S>>,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
    /// Length in bytes of the exported `tls-exporter` channel binding value.
    const TLS_EXPORTER_LEN: usize = 32;

    let domain = ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?;
    let stream = xmpp_stream.into_inner().into_inner();

    let mut root_store = RootCertStore::empty();
    #[cfg(feature = "webpki-roots")]
    {
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    }
    #[cfg(feature = "rustls-native-certs")]
    {
        root_store.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
    }

    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
    let mut config = ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();
    #[cfg(feature = "tls-rust-ktls")]
    let stream = {
        // ktls needs the session secrets once the handshake is done.
        config.enable_secret_extraction = true;
        ktls::CorkStream::new(stream)
    };

    let tls_stream = TlsConnector::from(Arc::new(config))
        .connect(domain, stream)
        .await
        .map_err(crate::Error::Io)?;

    // Extract the channel-binding information before we hand the stream over to ktls.
    let (_, connection) = tls_stream.get_ref();
    let channel_binding = match connection.protocol_version() {
        // TODO: Add support for TLS 1.2 and earlier.
        Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
            let data = connection
                .export_keying_material(
                    vec![0u8; TLS_EXPORTER_LEN],
                    b"EXPORTER-Channel-Binding",
                    None,
                )
                .map_err(StartTlsError::Tls)?;
            ChannelBinding::TlsExporter(data)
        }
        _ => ChannelBinding::None,
    };

    #[cfg(feature = "tls-rust-ktls")]
    let tls_stream = ktls::config_ktls_client(tls_stream)
        .await
        .map_err(StartTlsError::KtlsError)?;
    Ok((tls_stream, channel_binding))
}
193
194/// Performs `<starttls/>` on an XmppStream and returns a binary
195/// TlsStream.
196pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
197    mut stream: XmppStream<BufStream<S>>,
198    domain: &str,
199) -> Result<(TlsStream<S>, ChannelBinding), Error> {
200    stream
201        .send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
202            Request,
203        )))
204        .await?;
205
206    loop {
207        match stream.next().await {
208            Some(Ok(XmppStreamElement::Starttls(starttls::Nonza::Proceed(_)))) => {
209                break;
210            }
211            Some(Ok(_)) => (),
212            Some(Err(ReadError::SoftTimeout)) => (),
213            Some(Err(ReadError::HardError(e))) => return Err(e.into()),
214            Some(Err(ReadError::ParseError(e))) => {
215                return Err(io::Error::new(io::ErrorKind::InvalidData, e).into())
216            }
217            None | Some(Err(ReadError::StreamFooterReceived)) => {
218                return Err(crate::Error::Disconnected)
219            }
220        }
221    }
222
223    get_tls_stream(stream, domain).await
224}
225
226/// StartTLS ServerConnector Error
227#[derive(Debug)]
228pub enum StartTlsError {
229    /// TLS error
230    Tls(TlsError),
231    #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
232    /// DNS name parsing error
233    DnsNameError(InvalidDnsNameError),
234    #[cfg(feature = "tls-rust-ktls")]
235    /// Error while setting up kernel TLS
236    KtlsError(ktls::Error),
237}
238
239impl ServerConnectorError for StartTlsError {}
240
241impl fmt::Display for StartTlsError {
242    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
243        match self {
244            Self::Tls(e) => write!(fmt, "TLS error: {}", e),
245            #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
246            Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
247            #[cfg(feature = "tls-rust-ktls")]
248            Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
249        }
250    }
251}
252
253impl StdError for StartTlsError {}
254
255impl From<TlsError> for StartTlsError {
256    fn from(e: TlsError) -> Self {
257        Self::Tls(e)
258    }
259}
260
261#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
262impl From<InvalidDnsNameError> for StartTlsError {
263    fn from(e: InvalidDnsNameError) -> Self {
264        Self::DnsNameError(e)
265    }
266}