Skip to main content

tokio_xmpp/connect/
dns.rs

1use core::{fmt, net::SocketAddr};
2#[cfg(feature = "dns")]
3use futures::{future::select_ok, FutureExt};
4#[cfg(feature = "dns")]
5use hickory_resolver::{
6    config::LookupIpStrategy, proto::rr::IntoName, proto::rr::RData, TokioResolver,
7};
8#[cfg(feature = "dns")]
9use log::debug;
10use tokio::net::TcpStream;
11
12use crate::Error;
13
/// XMPP server connection configuration
#[derive(Clone, Debug)]
pub enum DnsConfig {
    /// Find the server host and port through an SRV record lookup
    #[cfg(feature = "dns")]
    UseSrv {
        /// Hostname to resolve
        host: String,
        /// SRV service and protocol labels queried under `host`, e.g. `_xmpp-client._tcp`
        srv: String,
        /// Port to use when SRV resolution fails or `host` is an IP literal
        fallback_port: u16,
        /// Pre-configured DNS resolver; the system configuration is used when `None`
        resolver: Option<TokioResolver>,
    },

    /// Manually define server host and port (no SRV lookup)
    // NOTE(review): presumably allowed because the variant may not be constructed
    // inside this crate; confirm whether the lint suppression is still needed.
    #[allow(unused)]
    #[cfg(feature = "dns")]
    NoSrv {
        /// Server host name or IP literal
        host: String,
        /// Server port
        port: u16,
        /// Pre-configured DNS resolver; the system configuration is used when `None`
        resolver: Option<TokioResolver>,
    },

    /// Connect directly to a literal `ip:port` socket address (TODO: Unix domain socket)
    // NOTE(review): presumably allowed because the variant may not be constructed
    // inside this crate; confirm whether the lint suppression is still needed.
    #[allow(unused)]
    Addr {
        /// Socket address in `ip:port` form
        addr: String,
    },
}
49
50impl fmt::Display for DnsConfig {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        match self {
53            #[cfg(feature = "dns")]
54            Self::UseSrv { host, .. } => write!(f, "{}", host),
55            #[cfg(feature = "dns")]
56            Self::NoSrv { host, port, .. } => write!(f, "{}:{}", host, port),
57            Self::Addr { addr } => write!(f, "{}", addr),
58        }
59    }
60}
61
62impl DnsConfig {
63    /// Constructor for DnsConfig::UseSrv variant
64    #[cfg(feature = "dns")]
65    pub fn srv(host: &str, srv: &str, fallback_port: u16) -> Self {
66        Self::UseSrv {
67            host: host.to_string(),
68            srv: srv.to_string(),
69            fallback_port,
70            resolver: None,
71        }
72    }
73
74    /// Constructor for the default SRV resolution strategy for clients (StartTLS)
75    #[cfg(feature = "dns")]
76    pub fn srv_default_client(host: &str) -> Self {
77        Self::UseSrv {
78            host: host.to_string(),
79            srv: "_xmpp-client._tcp".to_string(),
80            fallback_port: 5222,
81            resolver: None,
82        }
83    }
84
85    /// Constructor for direct TLS connections using RFC 7590 _xmpps-client._tcp
86    #[cfg(feature = "dns")]
87    pub fn srv_xmpps(host: &str) -> Self {
88        Self::UseSrv {
89            host: host.to_string(),
90            srv: "_xmpps-client._tcp".to_string(),
91            fallback_port: 5223,
92            resolver: None,
93        }
94    }
95
96    /// Constructor for DnsConfig::NoSrv variant
97    #[cfg(feature = "dns")]
98    pub fn no_srv(host: &str, port: u16) -> Self {
99        Self::NoSrv {
100            host: host.to_string(),
101            port,
102            resolver: None,
103        }
104    }
105
106    /// Constructor for DnsConfig::Addr variant
107    pub fn addr(addr: &str) -> Self {
108        Self::Addr {
109            addr: addr.to_string(),
110        }
111    }
112
113    /// Set pre-configured DNS resolver
114    #[cfg(feature = "dns")]
115    pub fn with_resolver(&mut self, custom_resolver: TokioResolver) {
116        match self {
117            Self::UseSrv {
118                ref mut resolver, ..
119            } => *resolver = Some(custom_resolver),
120            Self::NoSrv {
121                ref mut resolver, ..
122            } => *resolver = Some(custom_resolver),
123            Self::Addr { .. } => {}
124        }
125    }
126
127    /// Try resolve the DnsConfig to a TcpStream
128    pub async fn resolve(&self) -> Result<TcpStream, Error> {
129        match self {
130            #[cfg(feature = "dns")]
131            Self::UseSrv {
132                host,
133                srv,
134                fallback_port,
135                resolver,
136            } => Self::resolve_srv(host, srv, *fallback_port, resolver).await,
137            #[cfg(feature = "dns")]
138            Self::NoSrv {
139                host,
140                port,
141                resolver,
142            } => Self::resolve_no_srv(host, *port, resolver).await,
143            Self::Addr { addr } => {
144                // TODO: Unix domain socket
145                let addr: SocketAddr = addr.parse()?;
146                return Ok(TcpStream::connect(&SocketAddr::new(addr.ip(), addr.port())).await?);
147            }
148        }
149    }
150
151    #[cfg(feature = "dns")]
152    async fn resolve_srv(
153        host: &str,
154        srv: &str,
155        fallback_port: u16,
156        resolver: &Option<TokioResolver>,
157    ) -> Result<TcpStream, Error> {
158        let ascii_domain = idna::domain_to_ascii(host)?;
159
160        if let Ok(ip) = ascii_domain.parse() {
161            debug!("Attempting connection to {ip}:{fallback_port}");
162            return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
163        }
164
165        let resolver = Self::new_resolver(resolver)?;
166        let srv_domain = format!("{}.{}.", srv, ascii_domain).into_name()?;
167        let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
168
169        let resolver_ref = Some(resolver);
170        match srv_records {
171            Some(lookup) => {
172                // TODO: sort lookup records by priority/weight
173                for record in lookup.answers().iter() {
174                    let RData::SRV(ref srv) = record.data else {
175                        continue;
176                    };
177
178                    debug!("Attempting connection to {srv_domain} {srv}");
179                    if let Ok(stream) =
180                        Self::resolve_no_srv(&srv.target.to_ascii(), srv.port, &resolver_ref).await
181                    {
182                        return Ok(stream);
183                    }
184                }
185                Err(Error::Disconnected)
186            }
187            None => {
188                // SRV lookup error, retry with hostname
189                debug!("Attempting connection to {host}:{fallback_port}");
190                Self::resolve_no_srv(host, fallback_port, &resolver_ref).await
191            }
192        }
193    }
194
195    #[cfg(feature = "dns")]
196    async fn resolve_no_srv(
197        host: &str,
198        port: u16,
199        resolver: &Option<TokioResolver>,
200    ) -> Result<TcpStream, Error> {
201        let ascii_domain = idna::domain_to_ascii(host)?;
202
203        if let Ok(ip) = ascii_domain.parse() {
204            return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
205        }
206
207        let resolver = Self::new_resolver(resolver)?;
208        let ips = resolver.lookup_ip(ascii_domain).await?;
209
210        // Happy Eyeballs: connect to all records in parallel, return the
211        // first to succeed
212        select_ok(
213            ips.iter()
214                .map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
215        )
216        .await
217        .map(|(result, _)| result)
218        .map_err(|_| Error::Disconnected)
219    }
220
221    #[cfg(feature = "dns")]
222    fn new_resolver(resolver: &Option<TokioResolver>) -> Result<TokioResolver, Error> {
223        if let Some(resolver) = resolver {
224            return Ok(resolver.clone());
225        }
226
227        let (_config, mut options) = hickory_resolver::system_conf::read_system_conf()?;
228        options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
229
230        Ok(TokioResolver::builder_tokio()?
231            .with_options(options)
232            .build()?)
233    }
234}