tokio_xmpp/connect/
dns.rs1use core::{fmt, net::SocketAddr};
2#[cfg(feature = "dns")]
3use futures::{future::select_ok, FutureExt};
4#[cfg(feature = "dns")]
5use hickory_resolver::{
6 config::LookupIpStrategy, name_server::TokioConnectionProvider, IntoName, TokioAsyncResolver,
7};
8#[cfg(feature = "dns")]
9use log::debug;
10use tokio::net::TcpStream;
11
12use crate::Error;
13
/// Describes how to locate and reach the server a connection should be opened to.
#[derive(Clone, Debug)]
pub enum DnsConfig {
    /// Look up the SRV record `srv` published under `host`; if that lookup
    /// fails, resolve `host` itself and connect on `fallback_port`.
    #[cfg(feature = "dns")]
    UseSrv {
        /// Domain to resolve; it is converted with IDNA before lookup and may
        /// also be an IP literal.
        host: String,
        /// SRV service/protocol prefix, e.g. `_xmpp-client._tcp`.
        srv: String,
        /// Port used when the SRV lookup fails or `host` is an IP literal.
        fallback_port: u16,
    },

    /// Resolve `host` directly and connect on `port`, without any SRV lookup.
    // NOTE(review): the reason for `allow(unused)` is not visible here — confirm it is still needed.
    #[allow(unused)]
    #[cfg(feature = "dns")]
    NoSrv {
        /// Domain name or IP literal to connect to.
        host: String,
        /// TCP port to connect on.
        port: u16,
    },

    /// Connect to a literal socket address (`ip:port`); no DNS lookup is made.
    // NOTE(review): the reason for `allow(unused)` is not visible here — confirm it is still needed.
    #[allow(unused)]
    Addr {
        /// Socket address in the textual form accepted by `SocketAddr::from_str`.
        addr: String,
    },
}
45
46impl fmt::Display for DnsConfig {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 match self {
49 #[cfg(feature = "dns")]
50 Self::UseSrv { host, .. } => write!(f, "{}", host),
51 #[cfg(feature = "dns")]
52 Self::NoSrv { host, port } => write!(f, "{}:{}", host, port),
53 Self::Addr { addr } => write!(f, "{}", addr),
54 }
55 }
56}
57
58impl DnsConfig {
59 #[cfg(feature = "dns")]
61 pub fn srv(host: &str, srv: &str, fallback_port: u16) -> Self {
62 Self::UseSrv {
63 host: host.to_string(),
64 srv: srv.to_string(),
65 fallback_port,
66 }
67 }
68
69 #[cfg(feature = "dns")]
71 pub fn srv_default_client(host: &str) -> Self {
72 Self::UseSrv {
73 host: host.to_string(),
74 srv: "_xmpp-client._tcp".to_string(),
75 fallback_port: 5222,
76 }
77 }
78
79 #[cfg(feature = "dns")]
81 pub fn srv_xmpps(host: &str) -> Self {
82 Self::UseSrv {
83 host: host.to_string(),
84 srv: "_xmpps-client._tcp".to_string(),
85 fallback_port: 5223,
86 }
87 }
88
89 #[cfg(feature = "dns")]
91 pub fn no_srv(host: &str, port: u16) -> Self {
92 Self::NoSrv {
93 host: host.to_string(),
94 port,
95 }
96 }
97
98 pub fn addr(addr: &str) -> Self {
100 Self::Addr {
101 addr: addr.to_string(),
102 }
103 }
104
105 pub async fn resolve(&self) -> Result<TcpStream, Error> {
107 match self {
108 #[cfg(feature = "dns")]
109 Self::UseSrv {
110 host,
111 srv,
112 fallback_port,
113 } => Self::resolve_srv(host, srv, *fallback_port).await,
114 #[cfg(feature = "dns")]
115 Self::NoSrv { host, port } => Self::resolve_no_srv(host, *port).await,
116 Self::Addr { addr } => {
117 let addr: SocketAddr = addr.parse()?;
119 return Ok(TcpStream::connect(&SocketAddr::new(addr.ip(), addr.port())).await?);
120 }
121 }
122 }
123
124 #[cfg(feature = "dns")]
125 async fn resolve_srv(host: &str, srv: &str, fallback_port: u16) -> Result<TcpStream, Error> {
126 let ascii_domain = idna::domain_to_ascii(host)?;
127
128 if let Ok(ip) = ascii_domain.parse() {
129 debug!("Attempting connection to {ip}:{fallback_port}");
130 return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
131 }
132
133 let resolver = TokioAsyncResolver::tokio_from_system_conf()?;
134
135 let srv_domain = format!("{}.{}.", srv, ascii_domain).into_name()?;
136 let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
137
138 match srv_records {
139 Some(lookup) => {
140 for srv in lookup.iter() {
142 debug!("Attempting connection to {srv_domain} {srv}");
143 if let Ok(stream) =
144 Self::resolve_no_srv(&srv.target().to_ascii(), srv.port()).await
145 {
146 return Ok(stream);
147 }
148 }
149 Err(Error::Disconnected)
150 }
151 None => {
152 debug!("Attempting connection to {host}:{fallback_port}");
154 Self::resolve_no_srv(host, fallback_port).await
155 }
156 }
157 }
158
159 #[cfg(feature = "dns")]
160 async fn resolve_no_srv(host: &str, port: u16) -> Result<TcpStream, Error> {
161 let ascii_domain = idna::domain_to_ascii(host)?;
162
163 if let Ok(ip) = ascii_domain.parse() {
164 return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
165 }
166
167 let (config, mut options) = hickory_resolver::system_conf::read_system_conf()?;
168 options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
169 let resolver = TokioAsyncResolver::new(config, options, TokioConnectionProvider::default());
170
171 let ips = resolver.lookup_ip(ascii_domain).await?;
172
173 select_ok(
176 ips.into_iter()
177 .map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
178 )
179 .await
180 .map(|(result, _)| result)
181 .map_err(|_| Error::Disconnected)
182 }
183}