use futures::stream::StreamExt;
use sasl::client::mechanisms::{Anonymous, Plain, Scram};
use sasl::client::Mechanism;
use sasl::common::scram::{Sha1, Sha256};
use sasl::common::Credentials;
use std::collections::HashSet;
use std::str::FromStr;
use tokio::io::{AsyncRead, AsyncWrite};
use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success};
use crate::error::{AuthError, Error, ProtocolError};
use crate::xmpp_codec::Packet;
use crate::xmpp_stream::XMPPStream;
pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>(
mut stream: XMPPStream<S>,
creds: Credentials,
) -> Result<S, Error> {
let local_mechs: Vec<Box<dyn Fn() -> Box<dyn Mechanism + Send + Sync> + Send>> = vec![
Box::new(|| Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap())),
Box::new(|| Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap())),
Box::new(|| Box::new(Plain::from_credentials(creds.clone()).unwrap())),
Box::new(|| Box::new(Anonymous::new())),
];
let remote_mechs: HashSet<String> = stream
.stream_features
.sasl_mechanisms
.mechanisms
.iter()
.cloned()
.collect();
for local_mech in local_mechs {
let mut mechanism = local_mech();
if remote_mechs.contains(mechanism.name()) {
let initial = mechanism.initial();
let mechanism_name =
XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?;
stream
.send_stanza(Auth {
mechanism: mechanism_name,
data: initial,
})
.await?;
loop {
match stream.next().await {
Some(Ok(Packet::Stanza(stanza))) => {
if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
let response = mechanism
.response(&challenge.data)
.map_err(|e| AuthError::Sasl(e))?;
stream.send_stanza(Response { data: response }).await?;
} else if let Ok(_) = Success::try_from(stanza.clone()) {
return Ok(stream.into_inner());
} else if let Ok(failure) = Failure::try_from(stanza.clone()) {
return Err(Error::Auth(AuthError::Fail(failure.defined_condition)));
} else {
}
}
Some(Ok(_)) => {
}
Some(Err(e)) => return Err(e),
None => return Err(Error::Disconnected),
}
}
}
}
Err(AuthError::NoMechanism.into())
}