tokio_xmpp/connect/
starttls.rs1use alloc::borrow::Cow;
4use std::io;
5
6use futures::{sink::SinkExt, stream::StreamExt};
7use sasl::common::ChannelBinding;
8use tokio::{io::BufStream, net::TcpStream};
9use xmpp_parsers::{
10 jid::Jid,
11 starttls::{self, Request},
12};
13
14use crate::{
15 connect::{
16 tls_common::{establish_tls_connection, TlsAsyncStream, TlsConnectorError, TlsStream},
17 DnsConfig, ServerConnector,
18 },
19 error::{Error, ProtocolError},
20 xmlstream::{
21 initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, Timeouts, XmppStream,
22 XmppStreamElement,
23 },
24 Client,
25};
26
27#[deprecated(since = "5.0.0", note = "use tokio_xmpp::Client instead")]
29pub type StartTlsClient = Client;
30
31#[derive(Debug, Clone)]
33pub struct StartTlsServerConnector(pub DnsConfig);
34
35impl From<DnsConfig> for StartTlsServerConnector {
36 fn from(dns_config: DnsConfig) -> StartTlsServerConnector {
37 Self(dns_config)
38 }
39}
40
41impl ServerConnector for StartTlsServerConnector {
42 type Stream = BufStream<TlsStream<TcpStream>>;
43
44 async fn connect(
45 &self,
46 jid: &Jid,
47 ns: &'static str,
48 timeouts: Timeouts,
49 ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
50 let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
51
52 let xmpp_stream = initiate_stream(
54 tcp_stream,
55 ns,
56 StreamHeader {
57 to: Some(Cow::Borrowed(jid.domain().as_str())),
58 from: None,
59 id: None,
60 },
61 timeouts,
62 )
63 .await?;
64 let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
65
66 if features.can_starttls() {
67 let (tls_stream, channel_binding) =
69 starttls(xmpp_stream, jid.domain().as_str()).await?;
70 Ok((
72 initiate_stream(
73 tokio::io::BufStream::new(tls_stream),
74 ns,
75 StreamHeader {
76 to: Some(Cow::Borrowed(jid.domain().as_str())),
77 from: None,
78 id: None,
79 },
80 timeouts,
81 )
82 .await?,
83 channel_binding,
84 ))
85 } else {
86 Err(crate::Error::Protocol(ProtocolError::NoTls))
87 }
88 }
89}
90
91pub async fn starttls<S: TlsAsyncStream>(
94 mut stream: XmppStream<BufStream<S>>,
95 domain: &str,
96) -> Result<(TlsStream<S>, ChannelBinding), Error> {
97 stream
98 .send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
99 Request,
100 )))
101 .await?;
102
103 loop {
104 match stream.next().await {
105 Some(Ok(XmppStreamElement::Starttls(starttls::Nonza::Proceed(_)))) => {
106 break;
107 }
108 Some(Ok(_)) => (),
109 Some(Err(ReadError::SoftTimeout)) => (),
110 Some(Err(ReadError::HardError(e))) => return Err(e.into()),
111 Some(Err(ReadError::ParseError(e))) => {
112 return Err(io::Error::new(io::ErrorKind::InvalidData, e).into())
113 }
114 None | Some(Err(ReadError::StreamFooterReceived)) => {
115 return Err(crate::Error::Disconnected)
116 }
117 }
118 }
119
120 let inner_stream = stream.into_inner().into_inner();
121 establish_tls_connection(inner_stream, domain).await
122}
123
124pub type StartTlsError = TlsConnectorError;