1use crate::client::{receiver::ClientReceiver, sender::ClientSender};
8use crate::connect::ServerConnector;
9use crate::error::Error;
10use crate::event::{ensure_stanza_id, Event};
11use crate::stanzastream::{self, StanzaStage, StanzaState, StanzaStream, StanzaToken};
12use crate::xmlstream::Timeouts;
13use crate::Stanza;
14use std::io;
15use std::sync::Arc;
16use tokio::sync::{mpsc, oneshot, Mutex};
17use tokio::task::JoinHandle;
18use xmpp_parsers::{
19 jid::{FullJid, Jid},
20 stream_features::StreamFeatures,
21};
22
23#[cfg(feature = "direct-tls")]
24use crate::connect::DirectTlsServerConnector;
25#[cfg(any(feature = "direct-tls", feature = "starttls", feature = "insecure-tcp"))]
26use crate::connect::DnsConfig;
27#[cfg(feature = "starttls")]
28use crate::connect::StartTlsServerConnector;
29#[cfg(feature = "insecure-tcp")]
30use crate::connect::TcpServerConnector;
31
32mod iq;
33pub(crate) mod login;
34pub(crate) mod receiver;
35pub(crate) mod sender;
36mod stream;
37mod worker;
38
39pub use iq::{IqFailure, IqRequest, IqResponse, IqResponseToken};
40pub use login::auth;
41
/// An XMPP client session.
///
/// Outgoing stanzas go through the sending half of a `StanzaStream`. Incoming events are
/// produced by a background `ClientWorker` task, which owns the receiving half of the stream.
/// Use [`Client::split`] to obtain separate sender/receiver handles that share this client.
#[derive(Debug)]
pub struct Client {
    /// Events produced by the background worker task (returned by `ClientWorker::new`).
    stanza_rx: mpsc::Receiver<Event>,
    /// Sending half of the stanza stream; used by `send_stanza` and `send_iq`.
    stream_tx: stanzastream::StanzaSender,
    /// One-shot signal that `send_end` uses to ask the worker task to stop.
    shutdown_tx: oneshot::Sender<()>,
    /// Handle of the spawned worker task. When the task finishes, it hands back the receiving
    /// half of the stanza stream so that `send_end` can reunite and close the stream.
    worker: JoinHandle<stanzastream::StanzaReceiver>,
    /// Full JID the session is bound to, if known; the constructors initialise it to `None`.
    bound_jid: Option<FullJid>,
    /// Stream features advertised by the server, if known; initialised to `None`.
    features: Option<StreamFeatures>,
    /// Tracker for in-flight IQ requests; a clone of it is handed to the worker task.
    iq_response_tracker: iq::IqResponseTracker,
}
66
67impl Client {
68 pub fn bound_jid(&self) -> Option<&FullJid> {
71 self.bound_jid.as_ref()
72 }
73
74 pub async fn send_stanza(&mut self, mut stanza: Stanza) -> Result<StanzaToken, io::Error> {
91 ensure_stanza_id(&mut stanza);
92 let mut token = self.stream_tx.send(Box::new(stanza)).await;
93
94 match token.wait_for(StanzaStage::Sent).await {
95 Some(StanzaState::Queued) => unreachable!(),
97
98 None | Some(StanzaState::Dropped) => Err(io::Error::new(
99 io::ErrorKind::NotConnected,
100 "stream disconnected fatally before stanza could be sent",
101 )),
102 Some(StanzaState::Failed { error }) => Err(error.into_io_error()),
103 Some(StanzaState::Sent { .. }) | Some(StanzaState::Acked { .. }) => Ok(token),
104 }
105 }
106
107 pub async fn send_iq(&mut self, to: Option<Jid>, req: IqRequest) -> IqResponseToken {
118 let (iq, mut token) = self.iq_response_tracker.allocate_iq_handle(
119 None, to, req,
121 );
122 let stanza_token = self.stream_tx.send(Box::new(iq.into())).await;
123
124 token.set_stanza_token(stanza_token);
125 token
126 }
127
128 pub fn get_stream_features(&self) -> Option<&StreamFeatures> {
135 self.features.as_ref()
136 }
137
138 pub async fn send_end(self) -> Result<(), Error> {
143 self.shutdown_tx.send(()).expect("ClientWorker crashed.");
144
145 let stream_rx = self.worker.await.unwrap();
146 let stream = StanzaStream::reunite(self.stream_tx, stream_rx);
147 stream.close().await;
148
149 Ok(())
150 }
151
152 pub fn split(self) -> (ClientSender, ClientReceiver) {
154 let client = Arc::new(Mutex::new(self));
155
156 let sender = ClientSender(client.clone());
157 let receiver = ClientReceiver(client);
158
159 (sender, receiver)
160 }
161
162 pub fn reunite(sender: ClientSender, receiver: ClientReceiver) -> Self {
169 assert!(
170 Arc::ptr_eq(&sender.0, &receiver.0),
171 "Unrelated ClientSender and ClientReceiver passed to reunite."
172 );
173
174 drop(sender);
175
176 let inner = Arc::try_unwrap(receiver.0).expect("Failed to unwrap ClientReceiver Arc");
177 inner.into_inner()
178 }
179}
180
#[cfg(feature = "direct-tls")]
impl Client {
    /// Creates a client that connects over direct TLS using default settings.
    ///
    /// The server is located with `DnsConfig::srv_xmpps` for the domain of `jid`, and
    /// `Timeouts::default()` is used. This is shorthand for
    /// [`Client::new_direct_tls_with_config`].
    pub fn new_direct_tls<J: Into<Jid>, P: Into<String>>(jid: J, password: P) -> Self {
        let jid = jid.into();
        let dns_config = DnsConfig::srv_xmpps(jid.domain().as_ref());
        // Delegate rather than duplicating the connector setup, mirroring `new` -> `new_starttls`.
        Self::new_direct_tls_with_config(jid, password, dns_config, Timeouts::default())
    }

    /// Creates a client that connects over direct TLS with an explicit DNS configuration
    /// and explicit timeouts.
    pub fn new_direct_tls_with_config<J: Into<Jid>, P: Into<String>>(
        jid: J,
        password: P,
        dns_config: DnsConfig,
        timeouts: Timeouts,
    ) -> Self {
        Self::new_with_connector(
            jid,
            password,
            DirectTlsServerConnector::from(dns_config),
            timeouts,
        )
    }
}
213
#[cfg(feature = "starttls")]
impl Client {
    /// Creates a client using STARTTLS with default settings.
    ///
    /// The server is located with `DnsConfig::srv_default_client` for the domain of `jid`,
    /// and `Timeouts::default()` is used.
    pub fn new<J: Into<Jid>, P: Into<String>>(jid: J, password: P) -> Self {
        let jid: Jid = jid.into();
        let dns = DnsConfig::srv_default_client(jid.domain().as_ref());
        Self::new_starttls(jid, password, dns, Timeouts::default())
    }

    /// Creates a client using STARTTLS with an explicit DNS configuration and timeouts.
    pub fn new_starttls<J: Into<Jid>, P: Into<String>>(
        jid: J,
        password: P,
        dns_config: DnsConfig,
        timeouts: Timeouts,
    ) -> Self {
        let connector = StartTlsServerConnector::from(dns_config);
        Self::new_with_connector(jid, password, connector, timeouts)
    }
}
241
#[cfg(feature = "insecure-tcp")]
impl Client {
    /// Creates a client over a plain (unencrypted) TCP connection.
    ///
    /// The server is located according to `dns_config`; `timeouts` is passed to the stream.
    pub fn new_plaintext<J: Into<Jid>, P: Into<String>>(
        jid: J,
        password: P,
        dns_config: DnsConfig,
        timeouts: Timeouts,
    ) -> Self {
        let connector = TcpServerConnector::from(dns_config);
        Self::new_with_connector(jid, password, connector, timeouts)
    }
}
259
260impl Client {
261 pub fn new_with_connector<J: Into<Jid>, P: Into<String>, C: ServerConnector>(
263 jid: J,
264 password: P,
265 connector: C,
266 timeouts: Timeouts,
267 ) -> Self {
268 let stream = StanzaStream::new_c2s(connector, jid.into(), password.into(), timeouts, 16);
269 let (stream_tx, stream_rx) = stream.split();
270
271 let iq_response_tracker = iq::IqResponseTracker::new();
272 let (worker, shutdown_tx, stanza_rx) =
273 worker::ClientWorker::new(stream_rx, iq_response_tracker.clone(), 16);
274
275 let worker = tokio::task::spawn(async move { worker.run().await });
276
277 Self {
278 stream_tx,
279 stanza_rx,
280 worker,
281 shutdown_tx,
282 iq_response_tracker,
283 bound_jid: None,
284 features: None,
285 }
286 }
287}