tokio_xmpp/connect/
dns.rs

1use core::{fmt, net::SocketAddr};
2#[cfg(feature = "dns")]
3use futures::{future::select_ok, FutureExt};
4#[cfg(feature = "dns")]
5use hickory_resolver::{
6    config::LookupIpStrategy, name_server::TokioConnectionProvider, IntoName, TokioAsyncResolver,
7};
8#[cfg(feature = "dns")]
9use log::debug;
10use tokio::net::TcpStream;
11
12use crate::Error;
13
/// XMPP server connection configuration
///
/// Describes how the address of the server to connect to is obtained. The
/// `UseSrv` and `NoSrv` variants are only compiled in when the `dns` feature
/// is enabled; `Addr` is always available.
#[derive(Clone, Debug)]
pub enum DnsConfig {
    /// Use SRV record to find server host
    #[cfg(feature = "dns")]
    UseSrv {
        /// Hostname to resolve
        host: String,
        /// SRV service and protocol labels prepended to `host` for the SRV
        /// query, e.g. `_xmpp-client._tcp`
        srv: String,
        /// Port to use with `host` when SRV resolution fails
        fallback_port: u16,
    },

    /// Manually define server host and port
    #[allow(unused)]
    #[cfg(feature = "dns")]
    NoSrv {
        /// Server host name
        host: String,
        /// Server port
        port: u16,
    },

    /// Manually define the socket address as `IP:port`, bypassing DNS
    /// (TODO: socket)
    #[allow(unused)]
    Addr {
        /// Socket address in `IP:port` form
        addr: String,
    },
}
45
46impl fmt::Display for DnsConfig {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        match self {
49            #[cfg(feature = "dns")]
50            Self::UseSrv { host, .. } => write!(f, "{}", host),
51            #[cfg(feature = "dns")]
52            Self::NoSrv { host, port } => write!(f, "{}:{}", host, port),
53            Self::Addr { addr } => write!(f, "{}", addr),
54        }
55    }
56}
57
58impl DnsConfig {
59    /// Constructor for DnsConfig::UseSrv variant
60    #[cfg(feature = "dns")]
61    pub fn srv(host: &str, srv: &str, fallback_port: u16) -> Self {
62        Self::UseSrv {
63            host: host.to_string(),
64            srv: srv.to_string(),
65            fallback_port,
66        }
67    }
68
69    /// Constructor for the default SRV resolution strategy for clients (StartTLS)
70    #[cfg(feature = "dns")]
71    pub fn srv_default_client(host: &str) -> Self {
72        Self::UseSrv {
73            host: host.to_string(),
74            srv: "_xmpp-client._tcp".to_string(),
75            fallback_port: 5222,
76        }
77    }
78
79    /// Constructor for direct TLS connections using RFC 7590 _xmpps-client._tcp
80    #[cfg(feature = "dns")]
81    pub fn srv_xmpps(host: &str) -> Self {
82        Self::UseSrv {
83            host: host.to_string(),
84            srv: "_xmpps-client._tcp".to_string(),
85            fallback_port: 5223,
86        }
87    }
88
89    /// Constructor for DnsConfig::NoSrv variant
90    #[cfg(feature = "dns")]
91    pub fn no_srv(host: &str, port: u16) -> Self {
92        Self::NoSrv {
93            host: host.to_string(),
94            port,
95        }
96    }
97
98    /// Constructor for DnsConfig::Addr variant
99    pub fn addr(addr: &str) -> Self {
100        Self::Addr {
101            addr: addr.to_string(),
102        }
103    }
104
105    /// Try resolve the DnsConfig to a TcpStream
106    pub async fn resolve(&self) -> Result<TcpStream, Error> {
107        match self {
108            #[cfg(feature = "dns")]
109            Self::UseSrv {
110                host,
111                srv,
112                fallback_port,
113            } => Self::resolve_srv(host, srv, *fallback_port).await,
114            #[cfg(feature = "dns")]
115            Self::NoSrv { host, port } => Self::resolve_no_srv(host, *port).await,
116            Self::Addr { addr } => {
117                // TODO: Unix domain socket
118                let addr: SocketAddr = addr.parse()?;
119                return Ok(TcpStream::connect(&SocketAddr::new(addr.ip(), addr.port())).await?);
120            }
121        }
122    }
123
124    #[cfg(feature = "dns")]
125    async fn resolve_srv(host: &str, srv: &str, fallback_port: u16) -> Result<TcpStream, Error> {
126        let ascii_domain = idna::domain_to_ascii(host)?;
127
128        if let Ok(ip) = ascii_domain.parse() {
129            debug!("Attempting connection to {ip}:{fallback_port}");
130            return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
131        }
132
133        let resolver = TokioAsyncResolver::tokio_from_system_conf()?;
134
135        let srv_domain = format!("{}.{}.", srv, ascii_domain).into_name()?;
136        let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
137
138        match srv_records {
139            Some(lookup) => {
140                // TODO: sort lookup records by priority/weight
141                for srv in lookup.iter() {
142                    debug!("Attempting connection to {srv_domain} {srv}");
143                    if let Ok(stream) =
144                        Self::resolve_no_srv(&srv.target().to_ascii(), srv.port()).await
145                    {
146                        return Ok(stream);
147                    }
148                }
149                Err(Error::Disconnected)
150            }
151            None => {
152                // SRV lookup error, retry with hostname
153                debug!("Attempting connection to {host}:{fallback_port}");
154                Self::resolve_no_srv(host, fallback_port).await
155            }
156        }
157    }
158
159    #[cfg(feature = "dns")]
160    async fn resolve_no_srv(host: &str, port: u16) -> Result<TcpStream, Error> {
161        let ascii_domain = idna::domain_to_ascii(host)?;
162
163        if let Ok(ip) = ascii_domain.parse() {
164            return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
165        }
166
167        let (config, mut options) = hickory_resolver::system_conf::read_system_conf()?;
168        options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
169        let resolver = TokioAsyncResolver::new(config, options, TokioConnectionProvider::default());
170
171        let ips = resolver.lookup_ip(ascii_domain).await?;
172
173        // Happy Eyeballs: connect to all records in parallel, return the
174        // first to succeed
175        select_ok(
176            ips.into_iter()
177                .map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
178        )
179        .await
180        .map(|(result, _)| result)
181        .map_err(|_| Error::Disconnected)
182    }
183}