Skip to main content

tokio_xmpp/connect/
dns.rs

1use core::{fmt, net::SocketAddr};
2#[cfg(feature = "dns")]
3use futures::{future::select_ok, FutureExt};
4#[cfg(feature = "dns")]
5use hickory_resolver::{
6    config::LookupIpStrategy, proto::rr::IntoName, proto::rr::RData, TokioResolver,
7};
8#[cfg(feature = "dns")]
9use log::debug;
10use tokio::net::TcpStream;
11
12use crate::Error;
13
/// XMPP server connection configuration: how to locate the server and which
/// address to open a TCP connection to.
#[derive(Clone, Debug)]
pub enum DnsConfig {
    /// Look up an SRV record under `host` to find the server host(s) and port(s).
    #[cfg(feature = "dns")]
    UseSrv {
        /// Domain to resolve; it is converted to its ASCII (IDNA) form before lookup.
        host: String,
        /// SRV service and protocol labels prepended to `host` for the query,
        /// e.g. `_xmpp-client._tcp`.
        srv: String,
        /// Port used to connect to `host` directly when the SRV lookup fails,
        /// or when `host` is an IP literal.
        fallback_port: u16,
    },

    /// Connect to `host:port` directly, resolving `host` through A/AAAA lookups
    /// unless it is already an IP literal.
    #[allow(unused)]
    #[cfg(feature = "dns")]
    NoSrv {
        /// Server host name or IP literal
        host: String,
        /// Server port
        port: u16,
    },

    /// Connect to a literal socket address such as `127.0.0.1:5222` or
    /// `[::1]:5222`; no DNS lookup is performed. (TODO: Unix domain socket)
    #[allow(unused)]
    Addr {
        /// Socket address in `ip:port` form
        addr: String,
    },
}
45
46impl fmt::Display for DnsConfig {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        match self {
49            #[cfg(feature = "dns")]
50            Self::UseSrv { host, .. } => write!(f, "{}", host),
51            #[cfg(feature = "dns")]
52            Self::NoSrv { host, port } => write!(f, "{}:{}", host, port),
53            Self::Addr { addr } => write!(f, "{}", addr),
54        }
55    }
56}
57
58impl DnsConfig {
59    /// Constructor for DnsConfig::UseSrv variant
60    #[cfg(feature = "dns")]
61    pub fn srv(host: &str, srv: &str, fallback_port: u16) -> Self {
62        Self::UseSrv {
63            host: host.to_string(),
64            srv: srv.to_string(),
65            fallback_port,
66        }
67    }
68
69    /// Constructor for the default SRV resolution strategy for clients (StartTLS)
70    #[cfg(feature = "dns")]
71    pub fn srv_default_client(host: &str) -> Self {
72        Self::UseSrv {
73            host: host.to_string(),
74            srv: "_xmpp-client._tcp".to_string(),
75            fallback_port: 5222,
76        }
77    }
78
79    /// Constructor for direct TLS connections using RFC 7590 _xmpps-client._tcp
80    #[cfg(feature = "dns")]
81    pub fn srv_xmpps(host: &str) -> Self {
82        Self::UseSrv {
83            host: host.to_string(),
84            srv: "_xmpps-client._tcp".to_string(),
85            fallback_port: 5223,
86        }
87    }
88
89    /// Constructor for DnsConfig::NoSrv variant
90    #[cfg(feature = "dns")]
91    pub fn no_srv(host: &str, port: u16) -> Self {
92        Self::NoSrv {
93            host: host.to_string(),
94            port,
95        }
96    }
97
98    /// Constructor for DnsConfig::Addr variant
99    pub fn addr(addr: &str) -> Self {
100        Self::Addr {
101            addr: addr.to_string(),
102        }
103    }
104
105    /// Try resolve the DnsConfig to a TcpStream
106    pub async fn resolve(&self) -> Result<TcpStream, Error> {
107        match self {
108            #[cfg(feature = "dns")]
109            Self::UseSrv {
110                host,
111                srv,
112                fallback_port,
113            } => Self::resolve_srv(host, srv, *fallback_port).await,
114            #[cfg(feature = "dns")]
115            Self::NoSrv { host, port } => Self::resolve_no_srv(host, *port).await,
116            Self::Addr { addr } => {
117                // TODO: Unix domain socket
118                let addr: SocketAddr = addr.parse()?;
119                return Ok(TcpStream::connect(&SocketAddr::new(addr.ip(), addr.port())).await?);
120            }
121        }
122    }
123
124    #[cfg(feature = "dns")]
125    async fn resolve_srv(host: &str, srv: &str, fallback_port: u16) -> Result<TcpStream, Error> {
126        let ascii_domain = idna::domain_to_ascii(host)?;
127
128        if let Ok(ip) = ascii_domain.parse() {
129            debug!("Attempting connection to {ip}:{fallback_port}");
130            return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
131        }
132
133        let (_config, options) = hickory_resolver::system_conf::read_system_conf()?;
134        let resolver = TokioResolver::builder_tokio()?
135            .with_options(options)
136            .build()?;
137
138        let srv_domain = format!("{}.{}.", srv, ascii_domain).into_name()?;
139        let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
140
141        match srv_records {
142            Some(lookup) => {
143                // TODO: sort lookup records by priority/weight
144                for record in lookup.answers().iter() {
145                    let RData::SRV(ref srv) = record.data else {
146                        continue;
147                    };
148
149                    debug!("Attempting connection to {srv_domain} {srv}");
150                    if let Ok(stream) = Self::resolve_no_srv(&srv.target.to_ascii(), srv.port).await
151                    {
152                        return Ok(stream);
153                    }
154                }
155                Err(Error::Disconnected)
156            }
157            None => {
158                // SRV lookup error, retry with hostname
159                debug!("Attempting connection to {host}:{fallback_port}");
160                Self::resolve_no_srv(host, fallback_port).await
161            }
162        }
163    }
164
165    #[cfg(feature = "dns")]
166    async fn resolve_no_srv(host: &str, port: u16) -> Result<TcpStream, Error> {
167        let ascii_domain = idna::domain_to_ascii(host)?;
168
169        if let Ok(ip) = ascii_domain.parse() {
170            return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
171        }
172
173        let (_config, mut options) = hickory_resolver::system_conf::read_system_conf()?;
174        options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
175        let resolver = TokioResolver::builder_tokio()?
176            .with_options(options)
177            .build()?;
178
179        let ips = resolver.lookup_ip(ascii_domain).await?;
180
181        // Happy Eyeballs: connect to all records in parallel, return the
182        // first to succeed
183        select_ok(
184            ips.iter()
185                .map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
186        )
187        .await
188        .map(|(result, _)| result)
189        .map_err(|_| Error::Disconnected)
190    }
191}