From 723097fbf61f447abd6cd027f183bdea30598054 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Sat, 26 Nov 2022 16:18:15 +0000 Subject: [PATCH] Deduplicate some logic --- src/core/link_tcp.go | 43 +++++++++++++++++++++++++++++++++++-------- src/core/link_tls.go | 35 +++++++++-------------------------- 2 files changed, 44 insertions(+), 34 deletions(-) diff --git a/src/core/link_tcp.go b/src/core/link_tcp.go index c8020fe..60054d4 100644 --- a/src/core/link_tcp.go +++ b/src/core/link_tcp.go @@ -31,19 +31,26 @@ func (l *links) newLinkTCP() *linkTCP { return lt } -func (l *linkTCP) dial(url *url.URL, options linkOptions, sintf string) error { +type tcpDialer struct { + info linkInfo + dialer *net.Dialer + addr *net.TCPAddr +} + +func (l *linkTCP) dialersFor(url *url.URL, options linkOptions, sintf string) ([]*tcpDialer, error) { host, p, err := net.SplitHostPort(url.Host) if err != nil { - return err + return nil, err } port, err := strconv.Atoi(p) if err != nil { - return err + return nil, err } ips, err := net.LookupIP(host) if err != nil { - return err + return nil, err } + dialers := make([]*tcpDialer, 0, len(ips)) for _, ip := range ips { addr := &net.TCPAddr{ IP: ip, @@ -55,10 +62,30 @@ func (l *linkTCP) dial(url *url.URL, options linkOptions, sintf string) error { } info := linkInfoFor("tcp", sintf, tcpIDFor(dialer.LocalAddr, addr)) if l.links.isConnectedTo(info) { - return nil + return nil, nil } - conn, err := dialer.DialContext(l.core.ctx, "tcp", addr.String()) + dialers = append(dialers, &tcpDialer{ + info: info, + dialer: dialer, + addr: addr, + }) + } + return dialers, nil +} + +func (l *linkTCP) dial(url *url.URL, options linkOptions, sintf string) error { + dialers, err := l.dialersFor(url, options, sintf) + if err != nil { + return err + } + if len(dialers) == 0 { + return nil + } + for _, d := range dialers { + var conn net.Conn + conn, err = d.dialer.DialContext(l.core.ctx, "tcp", d.addr.String()) if err != nil { + l.core.log.Warnf("Failed to connect to %s: %s", d.addr, err) continue } name := strings.TrimRight(strings.SplitN(url.String(), "?", 2)[0], "/") @@ -66,9 +93,9 @@ func (l *linkTCP) dial(url *url.URL, options linkOptions, sintf string) error { url: url, sintf: sintf, } - return l.handler(dial, name, info, conn, options, false, false) + return l.handler(dial, name, d.info, conn, options, false, false) } - return fmt.Errorf("failed to connect via %d addresses", len(ips)) + return fmt.Errorf("failed to connect via %d address(es), last error: %w", len(dialers), err) } func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) { diff --git a/src/core/link_tls.go b/src/core/link_tls.go index 33ea4dc..6323a72 100644 --- a/src/core/link_tls.go +++ b/src/core/link_tls.go @@ -13,7 +13,6 @@ import ( "math/big" "net" "net/url" - "strconv" "strings" "time" @@ -48,38 +47,22 @@ func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS { } func (l *linkTLS) dial(url *url.URL, options linkOptions, sintf, sni string) error { - host, p, err := net.SplitHostPort(url.Host) + dialers, err := l.tcp.dialersFor(url, options, sintf) if err != nil { return err } - port, err := strconv.Atoi(p) - if err != nil { - return err + if len(dialers) == 0 { + return nil } - ips, err := net.LookupIP(host) - if err != nil { - return err - } - for _, ip := range ips { - addr := &net.TCPAddr{ - IP: ip, - Port: port, - } - dialer, err := l.tcp.dialerFor(addr, sintf) - if err != nil { - continue - } - info := linkInfoFor("tls", sintf, tcpIDFor(dialer.LocalAddr, addr)) - if l.links.isConnectedTo(info) { - return nil - } + for _, d := range dialers { tlsconfig := l.config.Clone() tlsconfig.ServerName = sni tlsdialer := &tls.Dialer{ - NetDialer: dialer, + NetDialer: d.dialer, Config: tlsconfig, } - conn, err := tlsdialer.DialContext(l.core.ctx, "tcp", addr.String()) + var conn net.Conn + conn, err = tlsdialer.DialContext(l.core.ctx, "tcp", d.addr.String()) if err != nil { continue } @@ -88,9 +71,9 @@ func (l *linkTLS) dial(url *url.URL, options linkOptions, sintf, sni string) err url: url, sintf: sintf, } - return l.handler(dial, name, info, conn, options, false, false) + return l.handler(dial, name, d.info, conn, options, false, false) } - return fmt.Errorf("failed to connect via %d addresses", len(ips)) + return fmt.Errorf("failed to connect via %d address(es), last error: %w", len(dialers), err) } func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) {