1use crate::client::{receiver::ClientReceiver, sender::ClientSender};
8use crate::connect::ServerConnector;
9use crate::error::Error;
10use crate::event::Event;
11use crate::stanzastream::{self, StanzaStage, StanzaState, StanzaStream, StanzaToken};
12use crate::xmlstream::Timeouts;
13use crate::Stanza;
14use std::io;
15use std::sync::Arc;
16use tokio::sync::{mpsc, oneshot, Mutex};
17use tokio::task::JoinHandle;
18use xmpp_parsers::{jid::Jid, stream_features::StreamFeatures};
19
20#[cfg(feature = "direct-tls")]
21use crate::connect::DirectTlsServerConnector;
22#[cfg(any(feature = "direct-tls", feature = "starttls", feature = "insecure-tcp"))]
23use crate::connect::DnsConfig;
24#[cfg(feature = "starttls")]
25use crate::connect::StartTlsServerConnector;
26#[cfg(feature = "insecure-tcp")]
27use crate::connect::TcpServerConnector;
28
29mod iq;
30pub(crate) mod login;
31pub(crate) mod receiver;
32pub(crate) mod sender;
33mod stream;
34mod worker;
35
36pub use iq::{IqFailure, IqRequest, IqResponse, IqResponseToken};
37pub use login::auth;
38
39#[derive(Debug)]
47pub struct Client {
48 stanza_rx: mpsc::Receiver<Event>,
50 stream_tx: stanzastream::StanzaSender,
52 shutdown_tx: oneshot::Sender<()>,
54 worker: JoinHandle<stanzastream::StanzaReceiver>,
56 bound_jid: Option<Jid>,
58 features: Option<StreamFeatures>,
60 iq_response_tracker: iq::IqResponseTracker,
62}
63
64impl Client {
65 pub fn bound_jid(&self) -> Option<&Jid> {
68 self.bound_jid.as_ref()
69 }
70
71 pub async fn send_stanza(&mut self, mut stanza: Stanza) -> Result<StanzaToken, io::Error> {
88 stanza.ensure_id();
89 let mut token = self.stream_tx.send(Box::new(stanza)).await;
90
91 match token.wait_for(StanzaStage::Sent).await {
92 Some(StanzaState::Queued) => unreachable!(),
94
95 None | Some(StanzaState::Dropped) => Err(io::Error::new(
96 io::ErrorKind::NotConnected,
97 "stream disconnected fatally before stanza could be sent",
98 )),
99 Some(StanzaState::Failed { error }) => Err(error.into_io_error()),
100 Some(StanzaState::Sent { .. }) | Some(StanzaState::Acked { .. }) => Ok(token),
101 }
102 }
103
104 pub async fn send_iq(&mut self, to: Option<Jid>, req: IqRequest) -> IqResponseToken {
115 let (iq, mut token) = self.iq_response_tracker.allocate_iq_handle(
116 None, to, req,
118 );
119 let stanza_token = self.stream_tx.send(Box::new(iq.into())).await;
120
121 token.set_stanza_token(stanza_token);
122 token
123 }
124
125 pub fn get_stream_features(&self) -> Option<&StreamFeatures> {
132 self.features.as_ref()
133 }
134
135 pub async fn send_end(self) -> Result<(), Error> {
140 self.shutdown_tx.send(()).expect("ClientWorker crashed.");
141
142 let stream_rx = self.worker.await.unwrap();
143 let stream = StanzaStream::reunite(self.stream_tx, stream_rx);
144 stream.close().await;
145
146 Ok(())
147 }
148
149 pub fn split(self) -> (ClientSender, ClientReceiver) {
151 let client = Arc::new(Mutex::new(self));
152
153 let sender = ClientSender(client.clone());
154 let receiver = ClientReceiver(client);
155
156 (sender, receiver)
157 }
158
159 pub fn reunite(sender: ClientSender, receiver: ClientReceiver) -> Self {
166 assert!(
167 Arc::ptr_eq(&sender.0, &receiver.0),
168 "Unrelated ClientSender and ClientReceiver passed to reunite."
169 );
170
171 drop(sender);
172
173 let inner = Arc::try_unwrap(receiver.0).expect("Failed to unwrap ClientReceiver Arc");
174 inner.into_inner()
175 }
176}
177
#[cfg(feature = "direct-tls")]
impl Client {
    /// Builds a client that connects through a [`DirectTlsServerConnector`].
    ///
    /// The DNS configuration is derived from the domain of `jid` via
    /// `DnsConfig::srv_xmpps`, and default [`Timeouts`] are used.
    pub fn new_direct_tls<J: Into<Jid>, P: Into<String>>(jid: J, password: P) -> Self {
        let account: Jid = jid.into();
        let connector =
            DirectTlsServerConnector::from(DnsConfig::srv_xmpps(account.domain().as_ref()));
        Self::new_with_connector(account, password, connector, Timeouts::default())
    }

    /// Builds a client that connects through a [`DirectTlsServerConnector`]
    /// created from the caller-supplied `dns_config`, using the given
    /// `timeouts`.
    pub fn new_direct_tls_with_config<J: Into<Jid>, P: Into<String>>(
        jid: J,
        password: P,
        dns_config: DnsConfig,
        timeouts: Timeouts,
    ) -> Self {
        let connector = DirectTlsServerConnector::from(dns_config);
        Self::new_with_connector(jid, password, connector, timeouts)
    }
}
210
#[cfg(feature = "starttls")]
impl Client {
    /// Builds a client with default settings.
    ///
    /// The DNS configuration is derived from the domain of `jid` via
    /// `DnsConfig::srv_default_client`; construction is then delegated to
    /// [`Client::new_starttls`] with default [`Timeouts`].
    pub fn new<J: Into<Jid>, P: Into<String>>(jid: J, password: P) -> Self {
        let account: Jid = jid.into();
        // Computed before `account` is moved into the delegated constructor.
        let lookup = DnsConfig::srv_default_client(account.domain().as_ref());
        Self::new_starttls(account, password, lookup, Timeouts::default())
    }

    /// Builds a client that connects through a [`StartTlsServerConnector`]
    /// created from `dns_config`, using the given `timeouts`.
    pub fn new_starttls<J: Into<Jid>, P: Into<String>>(
        jid: J,
        password: P,
        dns_config: DnsConfig,
        timeouts: Timeouts,
    ) -> Self {
        let connector = StartTlsServerConnector::from(dns_config);
        Self::new_with_connector(jid, password, connector, timeouts)
    }
}
238
#[cfg(feature = "insecure-tcp")]
impl Client {
    /// Builds a client that connects through a [`TcpServerConnector`] created
    /// from `dns_config`, using the given `timeouts`.
    ///
    /// Only available when the `insecure-tcp` feature is enabled.
    pub fn new_plaintext<J: Into<Jid>, P: Into<String>>(
        jid: J,
        password: P,
        dns_config: DnsConfig,
        timeouts: Timeouts,
    ) -> Self {
        let connector = TcpServerConnector::from(dns_config);
        Self::new_with_connector(jid, password, connector, timeouts)
    }
}
256
257impl Client {
258 pub fn new_with_connector<J: Into<Jid>, P: Into<String>, C: ServerConnector>(
260 jid: J,
261 password: P,
262 connector: C,
263 timeouts: Timeouts,
264 ) -> Self {
265 let stream = StanzaStream::new_c2s(connector, jid.into(), password.into(), timeouts, 16);
266 let (stream_tx, stream_rx) = stream.split();
267
268 let iq_response_tracker = iq::IqResponseTracker::new();
269 let (worker, shutdown_tx, stanza_rx) =
270 worker::ClientWorker::new(stream_rx, iq_response_tracker.clone(), 16);
271
272 let worker = tokio::task::spawn(async move { worker.run().await });
273
274 Self {
275 stream_tx,
276 stanza_rx,
277 worker,
278 shutdown_tx,
279 iq_response_tracker,
280 bound_jid: None,
281 features: None,
282 }
283 }
284}