1use crate::client::{receiver::ClientReceiver, sender::ClientSender};
8use crate::connect::ServerConnector;
9use crate::error::Error;
10use crate::event::Event;
11use crate::stanzastream::{self, StanzaStage, StanzaState, StanzaStream, StanzaToken};
12use crate::xmlstream::Timeouts;
13use crate::Stanza;
14use std::io;
15use std::sync::Arc;
16use tokio::sync::{mpsc, oneshot, Mutex};
17use tokio::task::JoinHandle;
18use xmpp_parsers::{jid::Jid, stream_features::StreamFeatures};
19
20#[cfg(feature = "direct-tls")]
21use crate::connect::DirectTlsServerConnector;
22#[cfg(any(feature = "direct-tls", feature = "starttls", feature = "insecure-tcp"))]
23use crate::connect::DnsConfig;
24#[cfg(feature = "starttls")]
25use crate::connect::StartTlsServerConnector;
26#[cfg(feature = "insecure-tcp")]
27use crate::connect::TcpServerConnector;
28
29mod iq;
30pub(crate) mod login;
31pub(crate) mod receiver;
32pub(crate) mod sender;
33mod stream;
34mod worker;
35
36pub use iq::{IqFailure, IqRequest, IqResponse, IqResponseToken};
37
/// Handle to an XMPP client session.
///
/// A background `ClientWorker` task (spawned in `new_with_connector`) owns the receiving
/// half of the stanza stream; this handle keeps the sending half, receives the worker's
/// `Event`s over a channel, and can signal the worker to shut down.
#[derive(Debug)]
pub struct Client {
    /// Receives `Event`s produced by the background worker (channel created by
    /// `ClientWorker::new`).
    stanza_rx: mpsc::Receiver<Event>,
    /// Sending half of the split `StanzaStream`; `send_stanza` and `send_iq` queue
    /// stanzas through it.
    stream_tx: stanzastream::StanzaSender,
    /// Sending `()` here asks the background worker to stop (see `send_end`).
    shutdown_tx: oneshot::Sender<()>,
    /// Join handle of the worker task. When the task finishes it yields the receiving
    /// half of the stanza stream, which `send_end` reunites with `stream_tx` to close it.
    worker: JoinHandle<stanzastream::StanzaReceiver>,
    /// JID bound for this session, if known; initialised to `None`.
    /// NOTE(review): presumably filled in after resource binding — not set in this view.
    bound_jid: Option<Jid>,
    /// Stream features advertised by the server, if known; initialised to `None`.
    /// NOTE(review): presumably updated elsewhere when features arrive — confirm.
    features: Option<StreamFeatures>,
    /// Tracker for outstanding IQ requests; a clone is shared with the worker task.
    iq_response_tracker: iq::IqResponseTracker,
}
62
63impl Client {
64 pub fn bound_jid(&self) -> Option<&Jid> {
67 self.bound_jid.as_ref()
68 }
69
70 pub async fn send_stanza(&mut self, mut stanza: Stanza) -> Result<StanzaToken, io::Error> {
87 stanza.ensure_id();
88 let mut token = self.stream_tx.send(Box::new(stanza)).await;
89
90 match token.wait_for(StanzaStage::Sent).await {
91 Some(StanzaState::Queued) => unreachable!(),
93
94 None | Some(StanzaState::Dropped) => Err(io::Error::new(
95 io::ErrorKind::NotConnected,
96 "stream disconnected fatally before stanza could be sent",
97 )),
98 Some(StanzaState::Failed { error }) => Err(error.into_io_error()),
99 Some(StanzaState::Sent { .. }) | Some(StanzaState::Acked { .. }) => Ok(token),
100 }
101 }
102
103 pub async fn send_iq(&mut self, to: Option<Jid>, req: IqRequest) -> IqResponseToken {
114 let (iq, mut token) = self.iq_response_tracker.allocate_iq_handle(
115 None, to, req,
117 );
118 let stanza_token = self.stream_tx.send(Box::new(iq.into())).await;
119
120 token.set_stanza_token(stanza_token);
121 token
122 }
123
124 pub fn get_stream_features(&self) -> Option<&StreamFeatures> {
131 self.features.as_ref()
132 }
133
134 pub async fn send_end(self) -> Result<(), Error> {
139 self.shutdown_tx.send(()).expect("ClientWorker crashed.");
140
141 let stream_rx = self.worker.await.unwrap();
142 let stream = StanzaStream::reunite(self.stream_tx, stream_rx);
143 stream.close().await;
144
145 Ok(())
146 }
147
148 pub fn split(self) -> (ClientSender, ClientReceiver) {
150 let client = Arc::new(Mutex::new(self));
151
152 let sender = ClientSender(client.clone());
153 let receiver = ClientReceiver(client);
154
155 (sender, receiver)
156 }
157
158 pub fn reunite(sender: ClientSender, receiver: ClientReceiver) -> Self {
165 assert!(
166 Arc::ptr_eq(&sender.0, &receiver.0),
167 "Unrelated ClientSender and ClientReceiver passed to reunite."
168 );
169
170 drop(sender);
171
172 let inner = Arc::try_unwrap(receiver.0).expect("Failed to unwrap ClientReceiver Arc");
173 inner.into_inner()
174 }
175}
176
#[cfg(feature = "direct-tls")]
impl Client {
    /// Creates a client that connects through a `DirectTlsServerConnector`.
    ///
    /// The DNS configuration comes from `DnsConfig::srv_xmpps` for the JID's domain, and
    /// default timeouts are used.
    pub fn new_direct_tls<J: Into<Jid>, P: Into<String>>(jid: J, password: P) -> Self {
        let jid: Jid = jid.into();
        let dns_config = DnsConfig::srv_xmpps(jid.domain().as_ref());
        Self::new_direct_tls_with_config(jid, password, dns_config, Timeouts::default())
    }

    /// Creates a client that connects through a `DirectTlsServerConnector` built from
    /// the given `dns_config`, using the given `timeouts`.
    pub fn new_direct_tls_with_config<J: Into<Jid>, P: Into<String>>(
        jid: J,
        password: P,
        dns_config: DnsConfig,
        timeouts: Timeouts,
    ) -> Self {
        let connector = DirectTlsServerConnector::from(dns_config);
        Self::new_with_connector(jid, password, connector, timeouts)
    }
}
209
#[cfg(feature = "starttls")]
impl Client {
    /// Creates a client using STARTTLS with default settings.
    ///
    /// The DNS configuration comes from `DnsConfig::srv_default_client` for the JID's
    /// domain, and default timeouts are used; see [`Client::new_starttls`].
    pub fn new<J: Into<Jid>, P: Into<String>>(jid: J, password: P) -> Self {
        let jid: Jid = jid.into();
        let dns_config = DnsConfig::srv_default_client(jid.domain().as_ref());
        let timeouts = Timeouts::default();
        Self::new_starttls(jid, password, dns_config, timeouts)
    }

    /// Creates a client that connects through a `StartTlsServerConnector` built from
    /// the given `dns_config`, using the given `timeouts`.
    pub fn new_starttls<J: Into<Jid>, P: Into<String>>(
        jid: J,
        password: P,
        dns_config: DnsConfig,
        timeouts: Timeouts,
    ) -> Self {
        let connector = StartTlsServerConnector::from(dns_config);
        Self::new_with_connector(jid, password, connector, timeouts)
    }
}
237
#[cfg(feature = "insecure-tcp")]
impl Client {
    /// Creates a client that connects through a plain `TcpServerConnector`.
    ///
    /// The connector is built from the given `dns_config`, and the given `timeouts` are
    /// used.
    pub fn new_plaintext<J: Into<Jid>, P: Into<String>>(
        jid: J,
        password: P,
        dns_config: DnsConfig,
        timeouts: Timeouts,
    ) -> Self {
        let connector = TcpServerConnector::from(dns_config);
        Self::new_with_connector(jid, password, connector, timeouts)
    }
}
255
256impl Client {
257 pub fn new_with_connector<J: Into<Jid>, P: Into<String>, C: ServerConnector>(
259 jid: J,
260 password: P,
261 connector: C,
262 timeouts: Timeouts,
263 ) -> Self {
264 let stream = StanzaStream::new_c2s(connector, jid.into(), password.into(), timeouts, 16);
265 let (stream_tx, stream_rx) = stream.split();
266
267 let iq_response_tracker = iq::IqResponseTracker::new();
268 let (worker, shutdown_tx, stanza_rx) =
269 worker::ClientWorker::new(stream_rx, iq_response_tracker.clone(), 16);
270
271 let worker = tokio::task::spawn(async move { worker.run().await });
272
273 Self {
274 stream_tx,
275 stanza_rx,
276 worker,
277 shutdown_tx,
278 iq_response_tracker,
279 bound_jid: None,
280 features: None,
281 }
282 }
283}