tokio_xmpp/connect/
dns.rs

1use core::{fmt, net::SocketAddr};
2#[cfg(feature = "dns")]
3use futures::{future::select_ok, FutureExt};
4#[cfg(feature = "dns")]
5use hickory_resolver::{
6    config::LookupIpStrategy, name_server::TokioConnectionProvider, IntoName, TokioAsyncResolver,
7};
8#[cfg(feature = "dns")]
9use log::debug;
10use tokio::net::TcpStream;
11
12use crate::Error;
13
/// Describes how to locate the XMPP server to connect to: through DNS SRV
/// records, through a plain host name and port, or through a literal socket
/// address.
#[derive(Clone, Debug)]
pub enum DnsConfig {
    /// Use SRV record to find server host
    #[cfg(feature = "dns")]
    UseSrv {
        /// Domain whose SRV records are looked up
        host: String,
        /// SRV service and protocol labels prepended to `host`,
        /// e.g. `_xmpp-client._tcp`
        srv: String,
        /// Port to use with `host` when the SRV lookup fails, or when `host`
        /// is a literal IP address
        fallback_port: u16,
    },

    /// Manually define server host and port
    #[allow(unused)]
    #[cfg(feature = "dns")]
    NoSrv {
        /// Server host name
        host: String,
        /// Server port
        port: u16,
    },

    /// Manually define IP and port (TODO: Unix domain socket)
    #[allow(unused)]
    Addr {
        /// Socket address in `IP:port` form, e.g. `127.0.0.1:5222` or
        /// `[::1]:5222`
        addr: String,
    },
}
45
46impl fmt::Display for DnsConfig {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        match self {
49            #[cfg(feature = "dns")]
50            Self::UseSrv { host, .. } => write!(f, "{}", host),
51            #[cfg(feature = "dns")]
52            Self::NoSrv { host, port } => write!(f, "{}:{}", host, port),
53            Self::Addr { addr } => write!(f, "{}", addr),
54        }
55    }
56}
57
58impl DnsConfig {
59    /// Constructor for DnsConfig::UseSrv variant
60    #[cfg(feature = "dns")]
61    pub fn srv(host: &str, srv: &str, fallback_port: u16) -> Self {
62        Self::UseSrv {
63            host: host.to_string(),
64            srv: srv.to_string(),
65            fallback_port,
66        }
67    }
68
69    /// Constructor for the default SRV resolution strategy for clients
70    #[cfg(feature = "dns")]
71    pub fn srv_default_client(host: &str) -> Self {
72        Self::UseSrv {
73            host: host.to_string(),
74            srv: "_xmpp-client._tcp".to_string(),
75            fallback_port: 5222,
76        }
77    }
78
79    /// Constructor for DnsConfig::NoSrv variant
80    #[cfg(feature = "dns")]
81    pub fn no_srv(host: &str, port: u16) -> Self {
82        Self::NoSrv {
83            host: host.to_string(),
84            port,
85        }
86    }
87
88    /// Constructor for DnsConfig::Addr variant
89    pub fn addr(addr: &str) -> Self {
90        Self::Addr {
91            addr: addr.to_string(),
92        }
93    }
94
95    /// Try resolve the DnsConfig to a TcpStream
96    pub async fn resolve(&self) -> Result<TcpStream, Error> {
97        match self {
98            #[cfg(feature = "dns")]
99            Self::UseSrv {
100                host,
101                srv,
102                fallback_port,
103            } => Self::resolve_srv(host, srv, *fallback_port).await,
104            #[cfg(feature = "dns")]
105            Self::NoSrv { host, port } => Self::resolve_no_srv(host, *port).await,
106            Self::Addr { addr } => {
107                // TODO: Unix domain socket
108                let addr: SocketAddr = addr.parse()?;
109                return Ok(TcpStream::connect(&SocketAddr::new(addr.ip(), addr.port())).await?);
110            }
111        }
112    }
113
114    #[cfg(feature = "dns")]
115    async fn resolve_srv(host: &str, srv: &str, fallback_port: u16) -> Result<TcpStream, Error> {
116        let ascii_domain = idna::domain_to_ascii(&host)?;
117
118        if let Ok(ip) = ascii_domain.parse() {
119            debug!("Attempting connection to {ip}:{fallback_port}");
120            return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
121        }
122
123        let resolver = TokioAsyncResolver::tokio_from_system_conf()?;
124
125        let srv_domain = format!("{}.{}.", srv, ascii_domain).into_name()?;
126        let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
127
128        match srv_records {
129            Some(lookup) => {
130                // TODO: sort lookup records by priority/weight
131                for srv in lookup.iter() {
132                    debug!("Attempting connection to {srv_domain} {srv}");
133                    if let Ok(stream) =
134                        Self::resolve_no_srv(&srv.target().to_ascii(), srv.port()).await
135                    {
136                        return Ok(stream);
137                    }
138                }
139                Err(Error::Disconnected)
140            }
141            None => {
142                // SRV lookup error, retry with hostname
143                debug!("Attempting connection to {host}:{fallback_port}");
144                Self::resolve_no_srv(host, fallback_port).await
145            }
146        }
147    }
148
149    #[cfg(feature = "dns")]
150    async fn resolve_no_srv(host: &str, port: u16) -> Result<TcpStream, Error> {
151        let ascii_domain = idna::domain_to_ascii(&host)?;
152
153        if let Ok(ip) = ascii_domain.parse() {
154            return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
155        }
156
157        let (config, mut options) = hickory_resolver::system_conf::read_system_conf()?;
158        options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
159        let resolver = TokioAsyncResolver::new(config, options, TokioConnectionProvider::default());
160
161        let ips = resolver.lookup_ip(ascii_domain).await?;
162
163        // Happy Eyeballs: connect to all records in parallel, return the
164        // first to succeed
165        select_ok(
166            ips.into_iter()
167                .map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
168        )
169        .await
170        .map(|(result, _)| result)
171        .map_err(|_| Error::Disconnected)
172    }
173}