tokio_xmpp/connect/
dns.rs

1use core::{fmt, net::SocketAddr};
2#[cfg(feature = "dns")]
3use futures::{future::select_ok, FutureExt};
4#[cfg(feature = "dns")]
5use hickory_resolver::{config::LookupIpStrategy, IntoName, TokioResolver};
6#[cfg(feature = "dns")]
7use log::debug;
8use tokio::net::TcpStream;
9
10use crate::Error;
11
/// Describes how the address of an XMPP server is determined before a TCP
/// connection is opened.
#[derive(Debug, Clone)]
pub enum DnsConfig {
    /// Locate the server through a DNS SRV lookup
    #[cfg(feature = "dns")]
    UseSrv {
        /// Domain name whose SRV records are queried
        host: String,
        /// SRV service and protocol labels that are prepended to the host,
        /// e.g. `_xmpp-client._tcp`
        srv: String,
        /// Port to use on `host` when the SRV lookup fails
        fallback_port: u16,
    },

    /// Skip the SRV lookup and use an explicit host name and port
    #[cfg(feature = "dns")]
    #[allow(unused)]
    NoSrv {
        /// Host name of the server
        host: String,
        /// Port of the server
        port: u16,
    },

    /// Use a literal `IP:port` socket address (TODO: socket)
    #[allow(unused)]
    Addr {
        /// Socket address in `IP:port` form
        addr: String,
    },
}
43
44impl fmt::Display for DnsConfig {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        match self {
47            #[cfg(feature = "dns")]
48            Self::UseSrv { host, .. } => write!(f, "{}", host),
49            #[cfg(feature = "dns")]
50            Self::NoSrv { host, port } => write!(f, "{}:{}", host, port),
51            Self::Addr { addr } => write!(f, "{}", addr),
52        }
53    }
54}
55
56impl DnsConfig {
57    /// Constructor for DnsConfig::UseSrv variant
58    #[cfg(feature = "dns")]
59    pub fn srv(host: &str, srv: &str, fallback_port: u16) -> Self {
60        Self::UseSrv {
61            host: host.to_string(),
62            srv: srv.to_string(),
63            fallback_port,
64        }
65    }
66
67    /// Constructor for the default SRV resolution strategy for clients (StartTLS)
68    #[cfg(feature = "dns")]
69    pub fn srv_default_client(host: &str) -> Self {
70        Self::UseSrv {
71            host: host.to_string(),
72            srv: "_xmpp-client._tcp".to_string(),
73            fallback_port: 5222,
74        }
75    }
76
77    /// Constructor for direct TLS connections using RFC 7590 _xmpps-client._tcp
78    #[cfg(feature = "dns")]
79    pub fn srv_xmpps(host: &str) -> Self {
80        Self::UseSrv {
81            host: host.to_string(),
82            srv: "_xmpps-client._tcp".to_string(),
83            fallback_port: 5223,
84        }
85    }
86
87    /// Constructor for DnsConfig::NoSrv variant
88    #[cfg(feature = "dns")]
89    pub fn no_srv(host: &str, port: u16) -> Self {
90        Self::NoSrv {
91            host: host.to_string(),
92            port,
93        }
94    }
95
96    /// Constructor for DnsConfig::Addr variant
97    pub fn addr(addr: &str) -> Self {
98        Self::Addr {
99            addr: addr.to_string(),
100        }
101    }
102
103    /// Try resolve the DnsConfig to a TcpStream
104    pub async fn resolve(&self) -> Result<TcpStream, Error> {
105        match self {
106            #[cfg(feature = "dns")]
107            Self::UseSrv {
108                host,
109                srv,
110                fallback_port,
111            } => Self::resolve_srv(host, srv, *fallback_port).await,
112            #[cfg(feature = "dns")]
113            Self::NoSrv { host, port } => Self::resolve_no_srv(host, *port).await,
114            Self::Addr { addr } => {
115                // TODO: Unix domain socket
116                let addr: SocketAddr = addr.parse()?;
117                return Ok(TcpStream::connect(&SocketAddr::new(addr.ip(), addr.port())).await?);
118            }
119        }
120    }
121
122    #[cfg(feature = "dns")]
123    async fn resolve_srv(host: &str, srv: &str, fallback_port: u16) -> Result<TcpStream, Error> {
124        let ascii_domain = idna::domain_to_ascii(host)?;
125
126        if let Ok(ip) = ascii_domain.parse() {
127            debug!("Attempting connection to {ip}:{fallback_port}");
128            return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
129        }
130
131        let (_config, options) = hickory_resolver::system_conf::read_system_conf()?;
132        let resolver = TokioResolver::builder_tokio()?
133            .with_options(options)
134            .build();
135
136        let srv_domain = format!("{}.{}.", srv, ascii_domain).into_name()?;
137        let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
138
139        match srv_records {
140            Some(lookup) => {
141                // TODO: sort lookup records by priority/weight
142                for srv in lookup.iter() {
143                    debug!("Attempting connection to {srv_domain} {srv}");
144                    if let Ok(stream) =
145                        Self::resolve_no_srv(&srv.target().to_ascii(), srv.port()).await
146                    {
147                        return Ok(stream);
148                    }
149                }
150                Err(Error::Disconnected)
151            }
152            None => {
153                // SRV lookup error, retry with hostname
154                debug!("Attempting connection to {host}:{fallback_port}");
155                Self::resolve_no_srv(host, fallback_port).await
156            }
157        }
158    }
159
160    #[cfg(feature = "dns")]
161    async fn resolve_no_srv(host: &str, port: u16) -> Result<TcpStream, Error> {
162        let ascii_domain = idna::domain_to_ascii(host)?;
163
164        if let Ok(ip) = ascii_domain.parse() {
165            return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
166        }
167
168        let (_config, mut options) = hickory_resolver::system_conf::read_system_conf()?;
169        options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
170        let resolver = TokioResolver::builder_tokio()?
171            .with_options(options)
172            .build();
173
174        let ips = resolver.lookup_ip(ascii_domain).await?;
175
176        // Happy Eyeballs: connect to all records in parallel, return the
177        // first to succeed
178        select_ok(
179            ips.into_iter()
180                .map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
181        )
182        .await
183        .map(|(result, _)| result)
184        .map_err(|_| Error::Disconnected)
185    }
186}