diff --git a/src/address/address.go b/src/address/address.go index 3960b78..eba6170 100644 --- a/src/address/address.go +++ b/src/address/address.go @@ -2,7 +2,11 @@ // Of particular importance are the functions used to derive addresses or subnets from a NodeID, or to get the NodeID and bitmask of the bits visible from an address, which is needed for DHT searches. package address -import "github.com/yggdrasil-network/yggdrasil-go/src/crypto" +import ( + "fmt" + + "github.com/yggdrasil-network/yggdrasil-go/src/crypto" +) // Address represents an IPv6 address in the yggdrasil address range. type Address [16]byte @@ -128,6 +132,13 @@ func (a *Address) GetNodeIDandMask() (*crypto.NodeID, *crypto.NodeID) { return &nid, &mask } +// GetNodeIDLengthString returns a string representation of the known bits of the NodeID, along with the number of known bits, for use with yggdrasil.Dialer's Dial and DialContext functions. +func (a *Address) GetNodeIDLengthString() string { + nid, mask := a.GetNodeIDandMask() + l := mask.PrefixLength() + return fmt.Sprintf("%s/%d", nid.String(), l) +} + // GetNodeIDandMask returns two *NodeID. // The first is a NodeID with all the bits known from the Subnet set to their correct values. // The second is a bitmask with 1 bit set for each bit that was known from the Subnet. @@ -156,3 +167,10 @@ func (s *Subnet) GetNodeIDandMask() (*crypto.NodeID, *crypto.NodeID) { } return &nid, &mask } + +// GetNodeIDLengthString returns a string representation of the known bits of the NodeID, along with the number of known bits, for use with yggdrasil.Dialer's Dial and DialContext functions. +func (s *Subnet) GetNodeIDLengthString() string { + nid, mask := s.GetNodeIDandMask() + l := mask.PrefixLength() + return fmt.Sprintf("%s/%d", nid.String(), l) +} diff --git a/src/tuntap/iface.go b/src/tuntap/iface.go index 0da9963..3d788b1 100644 --- a/src/tuntap/iface.go +++ b/src/tuntap/iface.go @@ -9,6 +9,7 @@ import ( "github.com/yggdrasil-network/yggdrasil-go/src/address" "github.com/yggdrasil-network/yggdrasil-go/src/crypto" "github.com/yggdrasil-network/yggdrasil-go/src/util" + "github.com/yggdrasil-network/yggdrasil-go/src/yggdrasil" "github.com/Arceliar/phony" ) @@ -225,7 +226,7 @@ func (tun *TunAdapter) _handlePacket(recvd []byte, err error) { return } // Do we have an active connection for this node address? - var dstNodeID, dstNodeIDMask *crypto.NodeID + var dstString string session, isIn := tun.addrToConn[dstAddr] if !isIn || session == nil { session, isIn = tun.subnetToConn[dstSnet] @@ -233,9 +234,9 @@ func (tun *TunAdapter) _handlePacket(recvd []byte, err error) { // Neither an address nor a subnet mapping matched, therefore populate // the node ID and mask to commence a search if dstAddr.IsValid() { - dstNodeID, dstNodeIDMask = dstAddr.GetNodeIDandMask() + dstString = dstAddr.GetNodeIDLengthString() } else { - dstNodeID, dstNodeIDMask = dstSnet.GetNodeIDandMask() + dstString = dstSnet.GetNodeIDLengthString() } } } @@ -243,27 +244,27 @@ func (tun *TunAdapter) _handlePacket(recvd []byte, err error) { if !isIn || session == nil { // Check we haven't been given empty node ID, really this shouldn't ever // happen but just to be sure... - if dstNodeID == nil || dstNodeIDMask == nil { - panic("Given empty dstNodeID and dstNodeIDMask - this shouldn't happen") + if dstString == "" { + panic("Given empty dstString - this shouldn't happen") } - _, known := tun.dials[*dstNodeID] - tun.dials[*dstNodeID] = append(tun.dials[*dstNodeID], bs) - for len(tun.dials[*dstNodeID]) > 32 { - util.PutBytes(tun.dials[*dstNodeID][0]) - tun.dials[*dstNodeID] = tun.dials[*dstNodeID][1:] + _, known := tun.dials[dstString] + tun.dials[dstString] = append(tun.dials[dstString], bs) + for len(tun.dials[dstString]) > 32 { + util.PutBytes(tun.dials[dstString][0]) + tun.dials[dstString] = tun.dials[dstString][1:] } if !known { go func() { - conn, err := tun.dialer.DialByNodeIDandMask(dstNodeID, dstNodeIDMask) + conn, err := tun.dialer.Dial("nodeid", dstString) tun.Act(nil, func() { - packets := tun.dials[*dstNodeID] - delete(tun.dials, *dstNodeID) + packets := tun.dials[dstString] + delete(tun.dials, dstString) if err != nil { return } // We've been given a connection so prepare the session wrapper var tc *tunConn - if tc, err = tun._wrap(conn); err != nil { + if tc, err = tun._wrap(conn.(*yggdrasil.Conn)); err != nil { // Something went wrong when storing the connection, typically that // something already exists for this address or subnet tun.log.Debugln("TUN/TAP iface wrap:", err) diff --git a/src/tuntap/tun.go b/src/tuntap/tun.go index 0bef909..5d77eca 100644 --- a/src/tuntap/tun.go +++ b/src/tuntap/tun.go @@ -52,7 +52,7 @@ type TunAdapter struct { //mutex sync.RWMutex // Protects the below addrToConn map[address.Address]*tunConn subnetToConn map[address.Subnet]*tunConn - dials map[crypto.NodeID][][]byte // Buffer of packets to send after dialing finishes + dials map[string][][]byte // Buffer of packets to send after dialing finishes isOpen bool } @@ -117,7 +117,7 @@ func (tun *TunAdapter) Init(config *config.NodeState, log *log.Logger, listener tun.dialer = dialer tun.addrToConn = make(map[address.Address]*tunConn) tun.subnetToConn = make(map[address.Subnet]*tunConn) - tun.dials = make(map[crypto.NodeID][][]byte) + tun.dials = make(map[string][][]byte) tun.writer.tun = tun tun.reader.tun = tun } diff --git a/src/yggdrasil/dialer.go b/src/yggdrasil/dialer.go index 0441085..e3f24e9 100644 --- a/src/yggdrasil/dialer.go +++ b/src/yggdrasil/dialer.go @@ -1,8 +1,10 @@ package yggdrasil import ( + "context" "encoding/hex" "errors" + "net" "strconv" "strings" "time" @@ -15,12 +17,18 @@ type Dialer struct { core *Core } -// TODO DialContext that allows timeouts/cancellation, Dial should just call this with no timeout set in the context - // Dial opens a session to the given node. The first paramter should be "nodeid" // and the second parameter should contain a hexadecimal representation of the -// target node ID. -func (d *Dialer) Dial(network, address string) (*Conn, error) { +// target node ID. Internally, it uses DialContext with a 6-second timeout. +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + const timeout = 6 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return d.DialContext(ctx, network, address) +} + +// DialContext is used internally by Dial, and should only be used with a context that includes a timeout. +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { var nodeID crypto.NodeID var nodeMask crypto.NodeID // Process @@ -28,7 +36,7 @@ func (d *Dialer) Dial(network, address string) (*Conn, error) { case "nodeid": // A node ID was provided - we don't need to do anything special with it if tokens := strings.Split(address, "/"); len(tokens) == 2 { - len, err := strconv.Atoi(tokens[1]) + l, err := strconv.Atoi(tokens[1]) if err != nil { return nil, err } @@ -37,7 +45,7 @@ func (d *Dialer) Dial(network, address string) (*Conn, error) { return nil, err } copy(nodeID[:], dest) - for idx := 0; idx < len; idx++ { + for idx := 0; idx < l; idx++ { nodeMask[idx/8] |= 0x80 >> byte(idx%8) } } else { @@ -50,7 +58,7 @@ func (d *Dialer) Dial(network, address string) (*Conn, error) { nodeMask[i] = 0xFF } } - return d.DialByNodeIDandMask(&nodeID, &nodeMask) + return d.DialByNodeIDandMask(ctx, &nodeID, &nodeMask) default: // An unexpected address type was given, so give up return nil, errors.New("unexpected address type") @@ -59,19 +67,18 @@ func (d *Dialer) Dial(network, address string) (*Conn, error) { // DialByNodeIDandMask opens a session to the given node based on raw // NodeID parameters. -func (d *Dialer) DialByNodeIDandMask(nodeID, nodeMask *crypto.NodeID) (*Conn, error) { +func (d *Dialer) DialByNodeIDandMask(ctx context.Context, nodeID, nodeMask *crypto.NodeID) (net.Conn, error) { conn := newConn(d.core, nodeID, nodeMask, nil) if err := conn.search(); err != nil { + // TODO: make searches take a context, so they can be cancelled early conn.Close() return nil, err } conn.session.setConn(nil, conn) - t := time.NewTimer(6 * time.Second) // TODO use a context instead - defer t.Stop() select { case <-conn.session.init: return conn, nil - case <-t.C: + case <-ctx.Done(): conn.Close() return nil, errors.New("session handshake timeout") }