Skip to main content

tokio_xmpp/client/
login.rs

1use alloc::borrow::Cow;
2use core::str::FromStr;
3use futures::{SinkExt, StreamExt};
4use sasl::client::mechanisms::{Anonymous, Plain, Scram};
5use sasl::client::Mechanism;
6use sasl::common::scram::{Sha1, Sha256};
7use sasl::common::Credentials;
8use std::collections::HashSet;
9use std::io;
10use tokio::io::{AsyncBufRead, AsyncWrite};
11use xmpp_parsers::{
12    jid::Jid,
13    ns,
14    sasl::{Auth, Mechanism as XMPPMechanism, Nonza, Response},
15    stream_features::{SaslMechanisms, StreamFeatures},
16};
17
18use crate::{
19    connect::ServerConnector,
20    error::{AuthError, Error, ProtocolError},
21    xmlstream::{
22        xmpp::XmppStreamElement, InitiatingStream, ReadError, StreamHeader, Timeouts, XmppStream,
23    },
24};
25
26/// Run the authentication handshake on a given stream.
27///
28/// Uses the given `sasl_mechanisms` and `creds` to perform the full
29/// authentication handshake. As authentication ends with a stream reset,
30/// this returns the `stream` as [`InitiatingStream`] on success.
31pub async fn auth<S: AsyncBufRead + AsyncWrite + Unpin>(
32    mut stream: XmppStream<S>,
33    sasl_mechanisms: &SaslMechanisms,
34    creds: Credentials,
35) -> Result<InitiatingStream<S>, Error> {
36    let local_mechs: Vec<Box<dyn Fn() -> Box<dyn Mechanism + Send + Sync> + Send>> = vec![
37        Box::new(|| Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap())),
38        Box::new(|| Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap())),
39        Box::new(|| Box::new(Plain::from_credentials(creds.clone()).unwrap())),
40        Box::new(|| Box::new(Anonymous::new())),
41    ];
42
43    let remote_mechs: HashSet<String> = sasl_mechanisms.mechanisms.iter().cloned().collect();
44
45    for local_mech in local_mechs {
46        let mut mechanism = local_mech();
47        if remote_mechs.contains(mechanism.name()) {
48            let initial = mechanism.initial();
49            let mechanism_name =
50                XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?;
51
52            stream
53                .send(&XmppStreamElement::Sasl(Nonza::Auth(Auth {
54                    mechanism: mechanism_name,
55                    data: initial,
56                })))
57                .await?;
58
59            loop {
60                match stream.next().await {
61                    Some(Ok(XmppStreamElement::Sasl(sasl))) => match sasl {
62                        Nonza::Challenge(challenge) => {
63                            let response = mechanism
64                                .response(&challenge.data)
65                                .map_err(AuthError::Sasl)?;
66
67                            // Send response and loop
68                            stream
69                                .send(&XmppStreamElement::Sasl(Nonza::Response(Response {
70                                    data: response,
71                                })))
72                                .await?;
73                        }
74                        Nonza::Success(_) => return Ok(stream.initiate_reset()),
75                        Nonza::Failure(failure) => {
76                            return Err(Error::Auth(AuthError::Fail(failure.defined_condition)));
77                        }
78                        _ => {
79                            // Ignore?!
80                        }
81                    },
82                    Some(Ok(el)) => {
83                        return Err(io::Error::new(
84                            io::ErrorKind::InvalidData,
85                            format!(
86                                "unexpected stream element during SASL negotiation: {:?}",
87                                el
88                            ),
89                        )
90                        .into())
91                    }
92                    Some(Err(ReadError::HardError(e))) => return Err(e.into()),
93                    Some(Err(ReadError::ParseError(e))) => {
94                        return Err(io::Error::new(io::ErrorKind::InvalidData, e).into())
95                    }
96                    Some(Err(ReadError::SoftTimeout)) => {
97                        // We cannot do anything about soft timeouts here...
98                    }
99                    Some(Err(ReadError::StreamFooterReceived)) | None => {
100                        return Err(Error::Disconnected)
101                    }
102                }
103            }
104        }
105    }
106
107    Err(AuthError::NoMechanism.into())
108}
109
110/// Authenticate to an XMPP server, but do not bind a resource.
111pub async fn client_auth<C: ServerConnector>(
112    server: C,
113    jid: Jid,
114    password: String,
115    timeouts: Timeouts,
116) -> Result<(StreamFeatures, XmppStream<C::Stream>), Error> {
117    let username = jid.node().unwrap().as_str();
118
119    let (xmpp_stream, channel_binding) = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?;
120    let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
121
122    let creds = Credentials::default()
123        .with_username(username)
124        .with_password(password)
125        .with_channel_binding(channel_binding);
126    // Authenticated (unspecified) stream
127    let stream = auth(xmpp_stream, &features.sasl_mechanisms, creds).await?;
128    let stream = stream
129        .send_header(StreamHeader {
130            to: Some(Cow::Borrowed(jid.domain().as_str())),
131            from: None,
132            id: None,
133        })
134        .await?;
135    Ok(stream.recv_features().await?)
136}