tokio_xmpp/connect/
starttls.rs1use alloc::borrow::Cow;
4use std::io;
5use std::os::fd::AsRawFd;
6
7use futures::{sink::SinkExt, stream::StreamExt};
8use sasl::common::ChannelBinding;
9use tokio::{
10 io::{AsyncRead, AsyncWrite, BufStream},
11 net::TcpStream,
12};
13use xmpp_parsers::{
14 jid::Jid,
15 starttls::{self, Request},
16};
17
18use crate::{
19 connect::{
20 tls_common::{establish_tls_connection, TlsConnectorError, TlsStream},
21 DnsConfig, ServerConnector,
22 },
23 error::{Error, ProtocolError},
24 xmlstream::{
25 initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, Timeouts, XmppStream,
26 XmppStreamElement,
27 },
28 Client,
29};
30
31#[deprecated(since = "5.0.0", note = "use tokio_xmpp::Client instead")]
33pub type StartTlsClient = Client;
34
35#[derive(Debug, Clone)]
37pub struct StartTlsServerConnector(pub DnsConfig);
38
39impl From<DnsConfig> for StartTlsServerConnector {
40 fn from(dns_config: DnsConfig) -> StartTlsServerConnector {
41 Self(dns_config)
42 }
43}
44
45impl ServerConnector for StartTlsServerConnector {
46 type Stream = BufStream<TlsStream<TcpStream>>;
47
48 async fn connect(
49 &self,
50 jid: &Jid,
51 ns: &'static str,
52 timeouts: Timeouts,
53 ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
54 let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
55
56 let xmpp_stream = initiate_stream(
58 tcp_stream,
59 ns,
60 StreamHeader {
61 to: Some(Cow::Borrowed(jid.domain().as_str())),
62 from: None,
63 id: None,
64 },
65 timeouts,
66 )
67 .await?;
68 let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
69
70 if features.can_starttls() {
71 let (tls_stream, channel_binding) =
73 starttls(xmpp_stream, jid.domain().as_str()).await?;
74 Ok((
76 initiate_stream(
77 tokio::io::BufStream::new(tls_stream),
78 ns,
79 StreamHeader {
80 to: Some(Cow::Borrowed(jid.domain().as_str())),
81 from: None,
82 id: None,
83 },
84 timeouts,
85 )
86 .await?,
87 channel_binding,
88 ))
89 } else {
90 Err(crate::Error::Protocol(ProtocolError::NoTls))
91 }
92 }
93}
94
95pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
98 mut stream: XmppStream<BufStream<S>>,
99 domain: &str,
100) -> Result<(TlsStream<S>, ChannelBinding), Error> {
101 stream
102 .send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
103 Request,
104 )))
105 .await?;
106
107 loop {
108 match stream.next().await {
109 Some(Ok(XmppStreamElement::Starttls(starttls::Nonza::Proceed(_)))) => {
110 break;
111 }
112 Some(Ok(_)) => (),
113 Some(Err(ReadError::SoftTimeout)) => (),
114 Some(Err(ReadError::HardError(e))) => return Err(e.into()),
115 Some(Err(ReadError::ParseError(e))) => {
116 return Err(io::Error::new(io::ErrorKind::InvalidData, e).into())
117 }
118 None | Some(Err(ReadError::StreamFooterReceived)) => {
119 return Err(crate::Error::Disconnected)
120 }
121 }
122 }
123
124 let inner_stream = stream.into_inner().into_inner();
125 establish_tls_connection(inner_stream, domain).await
126}
127
128pub type StartTlsError = TlsConnectorError;