1use core::{fmt, net::SocketAddr};
2#[cfg(feature = "dns")]
3use futures::{future::select_ok, FutureExt};
4#[cfg(feature = "dns")]
5use hickory_resolver::{
6 config::LookupIpStrategy, proto::rr::IntoName, proto::rr::RData, TokioResolver,
7};
8#[cfg(feature = "dns")]
9use log::debug;
10use tokio::net::TcpStream;
11
12use crate::Error;
13
/// Describes how to locate and connect to the remote server.
///
/// The `UseSrv` and `NoSrv` variants exist only with the `dns` feature and
/// perform DNS lookups. `Addr` connects to a literal socket address with no
/// lookup at all.
#[derive(Clone, Debug)]
pub enum DnsConfig {
    /// Query the SRV record `<srv>.<host>.`. If that lookup fails, fall back to
    /// resolving `host` and connecting on `fallback_port`.
    #[cfg(feature = "dns")]
    UseSrv {
        /// Domain to look up. It may be Unicode and is converted with IDNA
        /// before querying.
        host: String,
        /// Service/protocol prefix of the SRV name, e.g. `_xmpp-client._tcp`.
        srv: String,
        /// Port used together with `host` when no SRV record is available.
        fallback_port: u16,
        /// Resolver to use. `None` builds one from the system configuration.
        resolver: Option<TokioResolver>,
    },

    /// Resolve `host` directly (A/AAAA lookup, no SRV query) and connect on `port`.
    // NOTE(review): the reason for `allow(unused)` is not visible here; presumably
    // it silences dead-code warnings for configurations that never build this
    // variant. Confirm it is still needed.
    #[cfg(feature = "dns")]
    #[allow(unused)]
    NoSrv {
        /// Domain (or IP literal) to connect to.
        host: String,
        /// TCP port to connect on.
        port: u16,
        /// Resolver to use. `None` builds one from the system configuration.
        resolver: Option<TokioResolver>,
    },

    /// Connect directly to a socket address string such as `127.0.0.1:5222`.
    // NOTE(review): `allow(unused)` kept as-is; rationale not visible in this file.
    #[allow(unused)]
    Addr {
        /// Textual socket address, parsed as a `SocketAddr` when connecting.
        addr: String,
    },
}
49
50impl fmt::Display for DnsConfig {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 match self {
53 #[cfg(feature = "dns")]
54 Self::UseSrv { host, .. } => write!(f, "{}", host),
55 #[cfg(feature = "dns")]
56 Self::NoSrv { host, port, .. } => write!(f, "{}:{}", host, port),
57 Self::Addr { addr } => write!(f, "{}", addr),
58 }
59 }
60}
61
62impl DnsConfig {
63 #[cfg(feature = "dns")]
65 pub fn srv(host: &str, srv: &str, fallback_port: u16) -> Self {
66 Self::UseSrv {
67 host: host.to_string(),
68 srv: srv.to_string(),
69 fallback_port,
70 resolver: None,
71 }
72 }
73
74 #[cfg(feature = "dns")]
76 pub fn srv_default_client(host: &str) -> Self {
77 Self::UseSrv {
78 host: host.to_string(),
79 srv: "_xmpp-client._tcp".to_string(),
80 fallback_port: 5222,
81 resolver: None,
82 }
83 }
84
85 #[cfg(feature = "dns")]
87 pub fn srv_xmpps(host: &str) -> Self {
88 Self::UseSrv {
89 host: host.to_string(),
90 srv: "_xmpps-client._tcp".to_string(),
91 fallback_port: 5223,
92 resolver: None,
93 }
94 }
95
96 #[cfg(feature = "dns")]
98 pub fn no_srv(host: &str, port: u16) -> Self {
99 Self::NoSrv {
100 host: host.to_string(),
101 port,
102 resolver: None,
103 }
104 }
105
106 pub fn addr(addr: &str) -> Self {
108 Self::Addr {
109 addr: addr.to_string(),
110 }
111 }
112
113 #[cfg(feature = "dns")]
115 pub fn with_resolver(&mut self, custom_resolver: TokioResolver) {
116 match self {
117 Self::UseSrv {
118 ref mut resolver, ..
119 } => *resolver = Some(custom_resolver),
120 Self::NoSrv {
121 ref mut resolver, ..
122 } => *resolver = Some(custom_resolver),
123 Self::Addr { .. } => {}
124 }
125 }
126
127 pub async fn resolve(&self) -> Result<TcpStream, Error> {
129 match self {
130 #[cfg(feature = "dns")]
131 Self::UseSrv {
132 host,
133 srv,
134 fallback_port,
135 resolver,
136 } => Self::resolve_srv(host, srv, *fallback_port, resolver).await,
137 #[cfg(feature = "dns")]
138 Self::NoSrv {
139 host,
140 port,
141 resolver,
142 } => Self::resolve_no_srv(host, *port, resolver).await,
143 Self::Addr { addr } => {
144 let addr: SocketAddr = addr.parse()?;
146 return Ok(TcpStream::connect(&SocketAddr::new(addr.ip(), addr.port())).await?);
147 }
148 }
149 }
150
151 #[cfg(feature = "dns")]
152 async fn resolve_srv(
153 host: &str,
154 srv: &str,
155 fallback_port: u16,
156 resolver: &Option<TokioResolver>,
157 ) -> Result<TcpStream, Error> {
158 let ascii_domain = idna::domain_to_ascii(host)?;
159
160 if let Ok(ip) = ascii_domain.parse() {
161 debug!("Attempting connection to {ip}:{fallback_port}");
162 return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
163 }
164
165 let resolver = Self::new_resolver(resolver)?;
166 let srv_domain = format!("{}.{}.", srv, ascii_domain).into_name()?;
167 let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
168
169 let resolver_ref = Some(resolver);
170 match srv_records {
171 Some(lookup) => {
172 for record in lookup.answers().iter() {
174 let RData::SRV(ref srv) = record.data else {
175 continue;
176 };
177
178 debug!("Attempting connection to {srv_domain} {srv}");
179 if let Ok(stream) =
180 Self::resolve_no_srv(&srv.target.to_ascii(), srv.port, &resolver_ref).await
181 {
182 return Ok(stream);
183 }
184 }
185 Err(Error::Disconnected)
186 }
187 None => {
188 debug!("Attempting connection to {host}:{fallback_port}");
190 Self::resolve_no_srv(host, fallback_port, &resolver_ref).await
191 }
192 }
193 }
194
195 #[cfg(feature = "dns")]
196 async fn resolve_no_srv(
197 host: &str,
198 port: u16,
199 resolver: &Option<TokioResolver>,
200 ) -> Result<TcpStream, Error> {
201 let ascii_domain = idna::domain_to_ascii(host)?;
202
203 if let Ok(ip) = ascii_domain.parse() {
204 return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
205 }
206
207 let resolver = Self::new_resolver(resolver)?;
208 let ips = resolver.lookup_ip(ascii_domain).await?;
209
210 select_ok(
213 ips.iter()
214 .map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
215 )
216 .await
217 .map(|(result, _)| result)
218 .map_err(|_| Error::Disconnected)
219 }
220
221 #[cfg(feature = "dns")]
222 fn new_resolver(resolver: &Option<TokioResolver>) -> Result<TokioResolver, Error> {
223 if let Some(resolver) = resolver {
224 return Ok(resolver.clone());
225 }
226
227 let (_config, mut options) = hickory_resolver::system_conf::read_system_conf()?;
228 options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
229
230 Ok(TokioResolver::builder_tokio()?
231 .with_options(options)
232 .build()?)
233 }
234}