From 110613b234025078990143749cd30f0b8a964e92 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 8 Nov 2022 21:59:13 +0000 Subject: [PATCH] Try all addresses when connecting to a DNS name Fixes #980 --- src/core/link_tcp.go | 56 +++++++++++++++++++++++--------------------- src/core/link_tls.go | 55 +++++++++++++++++++++++++++---------------- 2 files changed, 64 insertions(+), 47 deletions(-) diff --git a/src/core/link_tcp.go b/src/core/link_tcp.go index 9c3c329..4392fb7 100644 --- a/src/core/link_tcp.go +++ b/src/core/link_tcp.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/url" + "strconv" "strings" "time" @@ -31,24 +32,39 @@ func (l *links) newLinkTCP() *linkTCP { } func (l *linkTCP) dial(url *url.URL, options linkOptions, sintf string) error { - addr, err := net.ResolveTCPAddr("tcp", url.Host) - if err != nil { - return err - } - dialer, err := l.dialerFor(addr, sintf) - if err != nil { - return err - } - info := linkInfoFor("tcp", sintf, tcpIDFor(dialer.LocalAddr, addr)) + info := linkInfoFor("tcp", sintf, url.Host) if l.links.isConnectedTo(info) { return nil } - conn, err := dialer.DialContext(l.core.ctx, "tcp", addr.String()) + host, p, err := net.SplitHostPort(url.Host) if err != nil { return err } - uri := strings.TrimRight(strings.SplitN(url.String(), "?", 2)[0], "/") - return l.handler(uri, info, conn, options, false, false) + port, err := strconv.Atoi(p) + if err != nil { + return err + } + ips, err := net.LookupIP(host) + if err != nil { + return err + } + for _, ip := range ips { + addr := &net.TCPAddr{ + IP: ip, + Port: port, + } + dialer, err := l.dialerFor(addr, sintf) + if err != nil { + continue + } + conn, err := dialer.DialContext(l.core.ctx, "tcp", addr.String()) + if err != nil { + continue + } + uri := strings.TrimRight(strings.SplitN(url.String(), "?", 2)[0], "/") + return l.handler(uri, info, conn, options, false, false) + } + return fmt.Errorf("failed to connect via %d addresses", len(ips)) } func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) { @@ -82,10 +98,9 @@ func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) { cancel() break } - laddr := conn.LocalAddr().(*net.TCPAddr) raddr := conn.RemoteAddr().(*net.TCPAddr) name := fmt.Sprintf("tcp://%s", raddr) - info := linkInfoFor("tcp", sintf, tcpIDFor(laddr, raddr)) + info := linkInfoFor("tcp", sintf, raddr.String()) if err = l.handler(name, info, conn, linkOptionsForListener(url), true, raddr.IP.IsLinkLocalUnicast()); err != nil { l.core.log.Errorln("Failed to create inbound link:", err) } @@ -180,16 +195,3 @@ func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string) (*net.Dialer, error) } return dialer, nil } - -func tcpIDFor(local net.Addr, remoteAddr *net.TCPAddr) string { - if localAddr, ok := local.(*net.TCPAddr); ok && localAddr.IP.Equal(remoteAddr.IP) { - // Nodes running on the same host — include both the IP and port. - return remoteAddr.String() - } - if remoteAddr.IP.IsLinkLocalUnicast() { - // Nodes discovered via multicast — include the IP only. - return remoteAddr.IP.String() - } - // Nodes connected remotely — include both the IP and port. - return remoteAddr.String() -} diff --git a/src/core/link_tls.go b/src/core/link_tls.go index 4eeb871..8e7f870 100644 --- a/src/core/link_tls.go +++ b/src/core/link_tls.go @@ -13,6 +13,7 @@ import ( "math/big" "net" "net/url" + "strconv" "strings" "time" @@ -47,30 +48,45 @@ func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS { } func (l *linkTLS) dial(url *url.URL, options linkOptions, sintf, sni string) error { - addr, err := net.ResolveTCPAddr("tcp", url.Host) - if err != nil { - return err - } - dialer, err := l.tcp.dialerFor(addr, sintf) - if err != nil { - return err - } - info := linkInfoFor("tls", sintf, tcpIDFor(dialer.LocalAddr, addr)) + info := linkInfoFor("tls", sintf, url.Host) if l.links.isConnectedTo(info) { return nil } - tlsconfig := l.config.Clone() - tlsconfig.ServerName = sni - tlsdialer := &tls.Dialer{ - NetDialer: dialer, - Config: tlsconfig, - } - conn, err := tlsdialer.DialContext(l.core.ctx, "tcp", addr.String()) + host, p, err := net.SplitHostPort(url.Host) if err != nil { return err } - uri := strings.TrimRight(strings.SplitN(url.String(), "?", 2)[0], "/") - return l.handler(uri, info, conn, options, false, false) + port, err := strconv.Atoi(p) + if err != nil { + return err + } + ips, err := net.LookupIP(host) + if err != nil { + return err + } + for _, ip := range ips { + addr := &net.TCPAddr{ + IP: ip, + Port: port, + } + dialer, err := l.tcp.dialerFor(addr, sintf) + if err != nil { + continue + } + tlsconfig := l.config.Clone() + tlsconfig.ServerName = sni + tlsdialer := &tls.Dialer{ + NetDialer: dialer, + Config: tlsconfig, + } + conn, err := tlsdialer.DialContext(l.core.ctx, "tcp", addr.String()) + if err != nil { + continue + } + uri := strings.TrimRight(strings.SplitN(url.String(), "?", 2)[0], "/") + return l.handler(uri, info, conn, options, false, false) + } + return fmt.Errorf("failed to connect via %d addresses", len(ips)) } func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) { @@ -105,10 +121,9 @@ func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) { cancel() break } - laddr := conn.LocalAddr().(*net.TCPAddr) raddr := conn.RemoteAddr().(*net.TCPAddr) name := fmt.Sprintf("tls://%s", raddr) - info := linkInfoFor("tls", sintf, tcpIDFor(laddr, raddr)) + info := linkInfoFor("tls", sintf, raddr.String()) if err = l.handler(name, info, conn, linkOptionsForListener(url), true, raddr.IP.IsLinkLocalUnicast()); err != nil { l.core.log.Errorln("Failed to create inbound link:", err) }