1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
use super::error::Error as StartTlsError;
use crate::Error;
use futures::{future::select_ok, FutureExt};
use hickory_resolver::{
    config::LookupIpStrategy, name_server::TokioConnectionProvider, IntoName, TokioAsyncResolver,
};
use log::debug;
use std::net::SocketAddr;
use tokio::net::TcpStream;

pub async fn connect_to_host(domain: &str, port: u16) -> Result<TcpStream, Error> {
    let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| StartTlsError::Idna)?;

    if let Ok(ip) = ascii_domain.parse() {
        return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
    }

    let (config, mut options) =
        hickory_resolver::system_conf::read_system_conf().map_err(StartTlsError::Resolve)?;
    options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
    let resolver = TokioAsyncResolver::new(config, options, TokioConnectionProvider::default());

    let ips = resolver
        .lookup_ip(ascii_domain)
        .await
        .map_err(StartTlsError::Resolve)?;
    // Happy Eyeballs: connect to all records in parallel, return the
    // first to succeed
    select_ok(
        ips.into_iter()
            .map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
    )
    .await
    .map(|(result, _)| result)
    .map_err(|_| crate::Error::Disconnected)
}

pub async fn connect_with_srv(
    domain: &str,
    srv: &str,
    fallback_port: u16,
) -> Result<TcpStream, Error> {
    let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| StartTlsError::Idna)?;

    if let Ok(ip) = ascii_domain.parse() {
        debug!("Attempting connection to {ip}:{fallback_port}");
        return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
    }

    let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(StartTlsError::Resolve)?;

    let srv_domain = format!("{}.{}.", srv, ascii_domain)
        .into_name()
        .map_err(StartTlsError::Dns)?;
    let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();

    match srv_records {
        Some(lookup) => {
            // TODO: sort lookup records by priority/weight
            for srv in lookup.iter() {
                debug!("Attempting connection to {srv_domain} {srv}");
                match connect_to_host(&srv.target().to_ascii(), srv.port()).await {
                    Ok(stream) => return Ok(stream),
                    Err(_) => {}
                }
            }
            Err(crate::Error::Disconnected.into())
        }
        None => {
            // SRV lookup error, retry with hostname
            debug!("Attempting connection to {domain}:{fallback_port}");
            connect_to_host(domain, fallback_port).await
        }
    }
}