tokio_xmpp/client/
login.rs1use alloc::borrow::Cow;
2use core::str::FromStr;
3use futures::{SinkExt, StreamExt};
4use sasl::client::mechanisms::{Anonymous, Plain, Scram};
5use sasl::client::Mechanism;
6use sasl::common::scram::{Sha1, Sha256};
7use sasl::common::Credentials;
8use std::collections::HashSet;
9use std::io;
10use tokio::io::{AsyncBufRead, AsyncWrite};
11use xmpp_parsers::{
12 jid::Jid,
13 ns,
14 sasl::{Auth, Mechanism as XMPPMechanism, Nonza, Response},
15 stream_features::{SaslMechanisms, StreamFeatures},
16};
17
18use crate::{
19 connect::ServerConnector,
20 error::{AuthError, Error, ProtocolError},
21 xmlstream::{
22 xmpp::XmppStreamElement, InitiatingStream, ReadError, StreamHeader, Timeouts, XmppStream,
23 },
24};
25
26pub async fn auth<S: AsyncBufRead + AsyncWrite + Unpin>(
32 mut stream: XmppStream<S>,
33 sasl_mechanisms: &SaslMechanisms,
34 creds: Credentials,
35) -> Result<InitiatingStream<S>, Error> {
36 let local_mechs: Vec<Box<dyn Fn() -> Box<dyn Mechanism + Send + Sync> + Send>> = vec![
37 Box::new(|| Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap())),
38 Box::new(|| Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap())),
39 Box::new(|| Box::new(Plain::from_credentials(creds.clone()).unwrap())),
40 Box::new(|| Box::new(Anonymous::new())),
41 ];
42
43 let remote_mechs: HashSet<String> = sasl_mechanisms.mechanisms.iter().cloned().collect();
44
45 for local_mech in local_mechs {
46 let mut mechanism = local_mech();
47 if remote_mechs.contains(mechanism.name()) {
48 let initial = mechanism.initial();
49 let mechanism_name =
50 XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?;
51
52 stream
53 .send(&XmppStreamElement::Sasl(Nonza::Auth(Auth {
54 mechanism: mechanism_name,
55 data: initial,
56 })))
57 .await?;
58
59 loop {
60 match stream.next().await {
61 Some(Ok(XmppStreamElement::Sasl(sasl))) => match sasl {
62 Nonza::Challenge(challenge) => {
63 let response = mechanism
64 .response(&challenge.data)
65 .map_err(AuthError::Sasl)?;
66
67 stream
69 .send(&XmppStreamElement::Sasl(Nonza::Response(Response {
70 data: response,
71 })))
72 .await?;
73 }
74 Nonza::Success(_) => return Ok(stream.initiate_reset()),
75 Nonza::Failure(failure) => {
76 return Err(Error::Auth(AuthError::Fail(failure.defined_condition)));
77 }
78 _ => {
79 }
81 },
82 Some(Ok(el)) => {
83 return Err(io::Error::new(
84 io::ErrorKind::InvalidData,
85 format!(
86 "unexpected stream element during SASL negotiation: {:?}",
87 el
88 ),
89 )
90 .into())
91 }
92 Some(Err(ReadError::HardError(e))) => return Err(e.into()),
93 Some(Err(ReadError::ParseError(e))) => {
94 return Err(io::Error::new(io::ErrorKind::InvalidData, e).into())
95 }
96 Some(Err(ReadError::SoftTimeout)) => {
97 }
99 Some(Err(ReadError::StreamFooterReceived)) | None => {
100 return Err(Error::Disconnected)
101 }
102 }
103 }
104 }
105 }
106
107 Err(AuthError::NoMechanism.into())
108}
109
110pub async fn client_auth<C: ServerConnector>(
112 server: C,
113 jid: Jid,
114 password: String,
115 timeouts: Timeouts,
116) -> Result<(StreamFeatures, XmppStream<C::Stream>), Error> {
117 let username = jid.node().unwrap().as_str();
118
119 let (xmpp_stream, channel_binding) = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?;
120 let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
121
122 let creds = Credentials::default()
123 .with_username(username)
124 .with_password(password)
125 .with_channel_binding(channel_binding);
126 let stream = auth(xmpp_stream, &features.sasl_mechanisms, creds).await?;
128 let stream = stream
129 .send_header(StreamHeader {
130 to: Some(Cow::Borrowed(jid.domain().as_str())),
131 from: None,
132 id: None,
133 })
134 .await?;
135 Ok(stream.recv_features().await?)
136}