use futures::{sink::SinkExt, stream::StreamExt};
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
use {
std::sync::Arc,
tokio_rustls::{
client::TlsStream,
rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName},
TlsConnector,
},
};
#[cfg(feature = "tls-native")]
use {
native_tls::TlsConnector as NativeTlsConnector,
tokio_native_tls::{TlsConnector, TlsStream},
};
use minidom::Element;
use sasl::common::ChannelBinding;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
};
use xmpp_parsers::{jid::Jid, ns};
use crate::{connect::ServerConnector, xmpp_codec::Packet, AsyncClient, SimpleClient};
use crate::{connect::ServerConnectorError, xmpp_stream::XMPPStream};
use self::error::Error;
use self::happy_eyeballs::{connect_to_host, connect_with_srv};
mod client;
pub mod error;
mod happy_eyeballs;
/// [`AsyncClient`] that reaches its server through the STARTTLS [`ServerConfig`] connector.
pub type StartTlsAsyncClient = AsyncClient<ServerConfig>;
/// [`SimpleClient`] that reaches its server through the STARTTLS [`ServerConfig`] connector.
pub type StartTlsSimpleClient = SimpleClient<ServerConfig>;
/// Describes how the STARTTLS connector locates the server to open a TCP connection to.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ServerConfig {
    /// Look the server up through DNS SRV records (`_xmpp-client._tcp`) for the JID's domain.
    UseSrv,
    /// Connect directly to `host`:`port`, without any SRV lookup.
    // NOTE(review): presumably silences a dead-code warning in builds where nothing
    // constructs this variant — confirm before removing.
    #[allow(unused)]
    Manual {
        /// Host name or address to connect to.
        host: String,
        /// TCP port to connect to.
        port: u16,
    },
}
impl ServerConnectorError for Error {}

/// STARTTLS connector: plain TCP first, then an in-band upgrade to TLS.
impl ServerConnector for ServerConfig {
    type Stream = TlsStream<TcpStream>;
    type Error = Error;

    /// Open a TCP connection, start an XMPP stream, negotiate STARTTLS, and
    /// restart the XMPP stream over the encrypted transport.
    ///
    /// The TCP endpoint is chosen by `self`:
    /// - [`ServerConfig::UseSrv`] does an SRV lookup for the JID's domain with service
    ///   `_xmpp-client._tcp`, passing 5222 as the port argument.
    /// - [`ServerConfig::Manual`] uses the given host and port.
    ///
    /// # Errors
    ///
    /// Returns `ProtocolError::NoTls` if the server's stream features do not offer
    /// STARTTLS. Connection, stream, and TLS negotiation errors are propagated.
    async fn connect(&self, jid: &Jid, ns: &str) -> Result<XMPPStream<Self::Stream>, Error> {
        let tcp_stream = match self {
            ServerConfig::UseSrv => {
                connect_with_srv(jid.domain().as_str(), "_xmpp-client._tcp", 5222).await?
            }
            ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), *port).await?,
        };
        let xmpp_stream = XMPPStream::start(tcp_stream, jid.clone(), ns.to_owned()).await?;
        // Refuse to continue in plaintext when the server does not offer STARTTLS.
        if !xmpp_stream.stream_features.can_starttls() {
            return Err(crate::Error::Protocol(crate::ProtocolError::NoTls).into());
        }
        let tls_stream = starttls(xmpp_stream).await?;
        // After the TLS handshake, the XMPP stream has to be started afresh.
        Ok(XMPPStream::start(tls_stream, jid.clone(), ns.to_owned()).await?)
    }

    /// Derive SASL channel-binding data from the established TLS stream.
    ///
    /// With the `tls-rust` backend on a TLS 1.3 connection, this exports 32 bytes of
    /// keying material. It uses the label `EXPORTER-Channel-Binding` and no context,
    /// which is the `tls-exporter` binding type (RFC 9266).
    ///
    /// Other TLS versions, and the `tls-native` backend, yield `ChannelBinding::None`.
    ///
    /// # Errors
    ///
    /// Propagates a failure of the keying-material export (`tls-rust` only).
    fn channel_binding(
        // `stream` is only read by the tls-rust implementation below.
        #[allow(unused_variables)] stream: &Self::Stream,
    ) -> Result<sasl::common::ChannelBinding, Error> {
        #[cfg(feature = "tls-native")]
        {
            log::warn!("tls-native doesn’t support channel binding, please use tls-rust if you want this feature!");
            Ok(ChannelBinding::None)
        }
        #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
        {
            let (_, connection) = stream.get_ref();
            Ok(match connection.protocol_version() {
                Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
                    // Output buffer length determines how many bytes are exported.
                    let data = vec![0u8; 32];
                    let data = connection.export_keying_material(
                        data,
                        b"EXPORTER-Channel-Binding",
                        None,
                    )?;
                    ChannelBinding::TlsExporter(data)
                }
                _ => ChannelBinding::None,
            })
        }
    }
}
/// Upgrade the transport underlying `xmpp_stream` to TLS using the `native-tls` backend.
///
/// The JID's domain is passed as the `domain` argument of `TlsConnector::connect`.
///
/// # Errors
///
/// Returns an error in either of these cases:
/// - the `native-tls` connector cannot be built;
/// - the TLS handshake fails.
#[cfg(feature = "tls-native")]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
    xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
    let domain = xmpp_stream.jid.domain().to_owned();
    let stream = xmpp_stream.into_inner();
    // Building the connector depends on the platform TLS backend and can fail;
    // propagate that as an error instead of panicking.
    let connector = NativeTlsConnector::builder().build()?;
    let tls_stream = TlsConnector::from(connector)
        .connect(&domain, stream)
        .await?;
    Ok(tls_stream)
}
/// Upgrade the transport underlying `xmpp_stream` to TLS using the `rustls` backend.
///
/// The JID's domain becomes the `ServerName` handed to the connector. Certificates
/// are checked against the root set shipped in `webpki_roots::TLS_SERVER_ROOTS`.
///
/// # Errors
///
/// Returns an error in either of these cases:
/// - the domain is not a valid `ServerName`;
/// - the TLS handshake fails (reported as `crate::Error::Io`).
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
    xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
    // Validate the name before giving up the XMPP wrapper.
    let host = xmpp_stream.jid.domain().to_string();
    let server_name = ServerName::try_from(host.as_str())?;
    let socket = xmpp_stream.into_inner();

    let anchors = webpki_roots::TLS_SERVER_ROOTS.iter().map(|anchor| {
        OwnedTrustAnchor::from_subject_spki_name_constraints(
            anchor.subject,
            anchor.spki,
            anchor.name_constraints,
        )
    });
    let mut roots = RootCertStore::empty();
    roots.add_trust_anchors(anchors);

    let connector = TlsConnector::from(Arc::new(
        ClientConfig::builder()
            .with_safe_defaults()
            .with_root_certificates(roots)
            .with_no_client_auth(),
    ));
    match connector.connect(server_name, socket).await {
        Ok(tls) => Ok(tls),
        Err(io_err) => Err(Error::from(crate::Error::Io(io_err))),
    }
}
/// Negotiate STARTTLS on `xmpp_stream` and return the resulting TLS stream.
///
/// Sends a `<starttls/>` element in the `ns::TLS` namespace and skips text packets.
/// It then waits for a `<proceed/>` element in that same namespace before handing
/// the underlying transport to `get_tls_stream`.
///
/// # Errors
///
/// - Errors from sending the request or reading the stream are returned.
/// - If the stream ends, or the server answers with anything other than a
///   TLS-namespaced `<proceed/>`, returns `ProtocolError::NoTls`.
/// - Errors from the TLS handshake in `get_tls_stream` are propagated.
pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
    mut xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
    let nonza = Element::builder("starttls", ns::TLS).build();
    let packet = Packet::Stanza(nonza);
    xmpp_stream.send(packet).await?;
    loop {
        match xmpp_stream.next().await {
            // Only `proceed` in the TLS namespace authorises the upgrade; an element
            // with the same local name in another namespace must not.
            Some(Ok(Packet::Stanza(ref stanza))) if stanza.is("proceed", ns::TLS) => break,
            // Text packets between elements are ignored.
            Some(Ok(Packet::Text(_))) => {}
            Some(Err(e)) => return Err(e.into()),
            _ => {
                return Err(crate::Error::Protocol(crate::ProtocolError::NoTls).into());
            }
        }
    }
    get_tls_stream(xmpp_stream).await
}