1#![allow(missing_docs)]
2use std::{
3 net::{IpAddr, Ipv4Addr, SocketAddr, ToSocketAddrs},
4 task::{Context, Poll},
5};
6
7use futures::{FutureExt, future::BoxFuture};
8use hyper::client::connect::dns::Name;
9use snafu::ResultExt;
10use tokio::task::spawn_blocking;
11use tower::Service;
12
/// Iterator over the addresses found by `Resolver::lookup_ip`.
///
/// Wraps the resolved socket addresses; iteration yields only their IP part
/// (the port is a placeholder that was only needed to perform the lookup).
pub struct LookupIp(std::vec::IntoIter<SocketAddr>);
14
/// Stateless resolver that answers `localhost` directly and otherwise uses
/// the standard library's `ToSocketAddrs` lookup, run via `spawn_blocking`.
///
/// It carries no data, so it is `Copy` and can be handed out freely.
#[derive(Debug, Clone, Copy)]
pub(super) struct Resolver;
17
18impl Resolver {
19 pub(crate) async fn lookup_ip(self, name: String) -> Result<LookupIp, DnsError> {
20 let dummy_port = 9;
26 if name == "localhost" {
28 Ok(LookupIp(
31 vec![SocketAddr::new(Ipv4Addr::LOCALHOST.into(), dummy_port)].into_iter(),
32 ))
33 } else {
34 spawn_blocking(move || {
35 let name_str = name.as_str();
37 let name_ref = name_str
38 .strip_prefix('[')
39 .and_then(|s| s.strip_suffix(']'))
40 .unwrap_or(name_str);
41 (name_ref, dummy_port).to_socket_addrs()
42 })
43 .await
44 .context(JoinSnafu)?
45 .map(LookupIp)
46 .context(UnableLookupSnafu)
47 }
48 }
49}
50
51impl Iterator for LookupIp {
52 type Item = IpAddr;
53
54 fn next(&mut self) -> Option<Self::Item> {
55 self.0.next().map(|address| address.ip())
56 }
57}
58
59impl Service<Name> for Resolver {
60 type Response = LookupIp;
61 type Error = DnsError;
62 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
63
64 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
65 Ok(()).into()
66 }
67
68 fn call(&mut self, name: Name) -> Self::Future {
69 self.lookup_ip(name.as_str().to_owned()).boxed()
70 }
71}
72
73#[derive(Debug, snafu::Snafu)]
74pub enum DnsError {
75 #[snafu(display("Unable to resolve name: {}", source))]
76 UnableLookup { source: tokio::io::Error },
77 #[snafu(display("Failed to join with resolving future: {}", source))]
78 JoinError { source: tokio::task::JoinError },
79}
80
#[cfg(test)]
mod tests {
    use super::Resolver;

    /// Looks up `host` and reports whether the lookup finished without error.
    async fn resolve(host: &str) -> bool {
        Resolver.lookup_ip(String::from(host)).await.is_ok()
    }

    // Goes through the system resolver, so this one needs working name
    // resolution in the test environment.
    #[tokio::test]
    async fn resolve_example() {
        let resolved = resolve("example.com").await;
        assert!(resolved);
    }

    #[tokio::test]
    async fn resolve_localhost() {
        let resolved = resolve("localhost").await;
        assert!(resolved);
    }

    #[tokio::test]
    async fn resolve_ipv4() {
        let resolved = resolve("10.0.4.0").await;
        assert!(resolved);
    }

    #[tokio::test]
    async fn resolve_ipv6() {
        let resolved = resolve("::1").await;
        assert!(resolved);
    }
}