// Mirror of https://github.com/cwinfo/yggdrasil-go.git
// (synced 2025-01-11 11:55:41 +00:00; 244 lines, 6.1 KiB, Go)

package core
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/Arceliar/phony"
|
|
)
|
|
|
|
type linkTCP struct {
|
|
phony.Inbox
|
|
*links
|
|
listener *net.ListenConfig
|
|
_listeners map[*Listener]context.CancelFunc
|
|
}
|
|
|
|
func (l *links) newLinkTCP() *linkTCP {
|
|
lt := &linkTCP{
|
|
links: l,
|
|
listener: &net.ListenConfig{
|
|
KeepAlive: -1,
|
|
},
|
|
_listeners: map[*Listener]context.CancelFunc{},
|
|
}
|
|
lt.listener.Control = lt.tcpContext
|
|
return lt
|
|
}
|
|
|
|
type tcpDialer struct {
|
|
info linkInfo
|
|
dialer *net.Dialer
|
|
addr *net.TCPAddr
|
|
}
|
|
|
|
func (l *linkTCP) dialersFor(url *url.URL, options linkOptions, sintf string) ([]*tcpDialer, error) {
|
|
host, p, err := net.SplitHostPort(url.Host)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
port, err := strconv.Atoi(p)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ips, err := net.LookupIP(host)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
dialers := make([]*tcpDialer, 0, len(ips))
|
|
for _, ip := range ips {
|
|
addr := &net.TCPAddr{
|
|
IP: ip,
|
|
Port: port,
|
|
}
|
|
dialer, err := l.dialerFor(addr, sintf)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
info := linkInfoFor("tcp", sintf, tcpIDFor(dialer.LocalAddr, addr))
|
|
if l.links.isConnectedTo(info) {
|
|
return nil, nil
|
|
}
|
|
dialers = append(dialers, &tcpDialer{
|
|
info: info,
|
|
dialer: dialer,
|
|
addr: addr,
|
|
})
|
|
}
|
|
return dialers, nil
|
|
}
|
|
|
|
func (l *linkTCP) dial(url *url.URL, options linkOptions, sintf string) error {
|
|
dialers, err := l.dialersFor(url, options, sintf)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(dialers) == 0 {
|
|
return nil
|
|
}
|
|
for _, d := range dialers {
|
|
var conn net.Conn
|
|
conn, err = d.dialer.DialContext(l.core.ctx, "tcp", d.addr.String())
|
|
if err != nil {
|
|
l.core.log.Warnf("Failed to connect to %s: %s", d.addr, err)
|
|
continue
|
|
}
|
|
name := strings.TrimRight(strings.SplitN(url.String(), "?", 2)[0], "/")
|
|
dial := &linkDial{
|
|
url: url,
|
|
sintf: sintf,
|
|
}
|
|
return l.handler(dial, name, d.info, conn, options, false, false)
|
|
}
|
|
return fmt.Errorf("failed to connect via %d address(es), last error: %w", len(dialers), err)
|
|
}
|
|
|
|
func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) {
|
|
ctx, cancel := context.WithCancel(l.core.ctx)
|
|
hostport := url.Host
|
|
if sintf != "" {
|
|
if host, port, err := net.SplitHostPort(hostport); err == nil {
|
|
hostport = fmt.Sprintf("[%s%%%s]:%s", host, sintf, port)
|
|
}
|
|
}
|
|
listener, err := l.listener.Listen(ctx, "tcp", hostport)
|
|
if err != nil {
|
|
cancel()
|
|
return nil, err
|
|
}
|
|
entry := &Listener{
|
|
Listener: listener,
|
|
closed: make(chan struct{}),
|
|
}
|
|
phony.Block(l, func() {
|
|
l._listeners[entry] = cancel
|
|
})
|
|
l.core.log.Printf("TCP listener started on %s", listener.Addr())
|
|
go func() {
|
|
defer phony.Block(l, func() {
|
|
delete(l._listeners, entry)
|
|
})
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
cancel()
|
|
break
|
|
}
|
|
laddr := conn.LocalAddr().(*net.TCPAddr)
|
|
raddr := conn.RemoteAddr().(*net.TCPAddr)
|
|
name := fmt.Sprintf("tcp://%s", raddr)
|
|
info := linkInfoFor("tcp", sintf, tcpIDFor(laddr, raddr))
|
|
if err = l.handler(nil, name, info, conn, linkOptionsForListener(url), true, raddr.IP.IsLinkLocalUnicast()); err != nil {
|
|
l.core.log.Errorln("Failed to create inbound link:", err)
|
|
}
|
|
}
|
|
_ = listener.Close()
|
|
close(entry.closed)
|
|
l.core.log.Printf("TCP listener stopped on %s", listener.Addr())
|
|
}()
|
|
return entry, nil
|
|
}
|
|
|
|
func (l *linkTCP) handler(dial *linkDial, name string, info linkInfo, conn net.Conn, options linkOptions, incoming, force bool) error {
|
|
return l.links.create(
|
|
conn, // connection
|
|
dial, // connection URL
|
|
name, // connection name
|
|
info, // connection info
|
|
incoming, // not incoming
|
|
force, // not forced
|
|
options, // connection options
|
|
)
|
|
}
|
|
|
|
// Returns the address of the listener.
|
|
func (l *linkTCP) getAddr() *net.TCPAddr {
|
|
// TODO: Fix this, because this will currently only give a single address
|
|
// to multicast.go, which obviously is not great, but right now multicast.go
|
|
// doesn't have the ability to send more than one address in a packet either
|
|
var addr *net.TCPAddr
|
|
phony.Block(l, func() {
|
|
for listener := range l._listeners {
|
|
addr = listener.Addr().(*net.TCPAddr)
|
|
}
|
|
})
|
|
return addr
|
|
}
|
|
|
|
func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string) (*net.Dialer, error) {
|
|
if dst.IP.IsLinkLocalUnicast() {
|
|
if sintf != "" {
|
|
dst.Zone = sintf
|
|
}
|
|
if dst.Zone == "" {
|
|
return nil, fmt.Errorf("link-local address requires a zone")
|
|
}
|
|
}
|
|
dialer := &net.Dialer{
|
|
Timeout: time.Second * 5,
|
|
KeepAlive: -1,
|
|
Control: l.tcpContext,
|
|
}
|
|
if sintf != "" {
|
|
dialer.Control = l.getControl(sintf)
|
|
ief, err := net.InterfaceByName(sintf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("interface %q not found", sintf)
|
|
}
|
|
if ief.Flags&net.FlagUp == 0 {
|
|
return nil, fmt.Errorf("interface %q is not up", sintf)
|
|
}
|
|
addrs, err := ief.Addrs()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("interface %q addresses not available: %w", sintf, err)
|
|
}
|
|
for addrindex, addr := range addrs {
|
|
src, _, err := net.ParseCIDR(addr.String())
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if !src.IsGlobalUnicast() && !src.IsLinkLocalUnicast() {
|
|
continue
|
|
}
|
|
bothglobal := src.IsGlobalUnicast() == dst.IP.IsGlobalUnicast()
|
|
bothlinklocal := src.IsLinkLocalUnicast() == dst.IP.IsLinkLocalUnicast()
|
|
if !bothglobal && !bothlinklocal {
|
|
continue
|
|
}
|
|
if (src.To4() != nil) != (dst.IP.To4() != nil) {
|
|
continue
|
|
}
|
|
if bothglobal || bothlinklocal || addrindex == len(addrs)-1 {
|
|
dialer.LocalAddr = &net.TCPAddr{
|
|
IP: src,
|
|
Port: 0,
|
|
Zone: sintf,
|
|
}
|
|
break
|
|
}
|
|
}
|
|
if dialer.LocalAddr == nil {
|
|
return nil, fmt.Errorf("no suitable source address found on interface %q", sintf)
|
|
}
|
|
}
|
|
return dialer, nil
|
|
}
|
|
|
|
// tcpIDFor derives the identifier string used in the linkInfo of a TCP link.
//
//   - If local is a non-nil *net.TCPAddr with the same IP as remoteAddr
//     (both nodes on the same host), the ID is remoteAddr.String(), which
//     includes IP, zone and port.
//   - Otherwise, for a link-local remoteAddr (nodes discovered via
//     multicast), the ID is the bare IP without zone or port.
//   - Otherwise (nodes connected remotely), the ID is remoteAddr.String().
//
// local may be nil, for example an unset net.Dialer.LocalAddr. remoteAddr
// must not be nil.
func tcpIDFor(local net.Addr, remoteAddr *net.TCPAddr) string {
	// A typed nil *net.TCPAddr stored in the interface passes the type
	// assertion, so check for nil before touching .IP.
	if localAddr, ok := local.(*net.TCPAddr); ok && localAddr != nil && localAddr.IP.Equal(remoteAddr.IP) {
		return remoteAddr.String()
	}
	if remoteAddr.IP.IsLinkLocalUnicast() {
		return remoteAddr.IP.String()
	}
	return remoteAddr.String()
}
|