diff --git a/src/tuntap/tun.go b/src/tuntap/tun.go index 07264da..c65c15e 100644 --- a/src/tuntap/tun.go +++ b/src/tuntap/tun.go @@ -28,19 +28,20 @@ const tun_ETHER_HEADER_LENGTH = 14 // you should pass this object to the yggdrasil.SetRouterAdapter() function // before calling yggdrasil.Start(). type TunAdapter struct { - config *config.NodeState - log *log.Logger - reconfigure chan chan error - listener *yggdrasil.Listener - dialer *yggdrasil.Dialer - addr address.Address - subnet address.Subnet - icmpv6 ICMPv6 - mtu int - iface *water.Interface - mutex sync.RWMutex // Protects the below - conns map[crypto.NodeID]*yggdrasil.Conn - isOpen bool + config *config.NodeState + log *log.Logger + reconfigure chan chan error + listener *yggdrasil.Listener + dialer *yggdrasil.Dialer + addr address.Address + subnet address.Subnet + icmpv6 ICMPv6 + mtu int + iface *water.Interface + mutex sync.RWMutex // Protects the below + addrToConn map[address.Address]*yggdrasil.Conn + subnetToConn map[address.Subnet]*yggdrasil.Conn + isOpen bool } // Gets the maximum supported MTU for the platform based on the defaults in @@ -102,7 +103,8 @@ func (tun *TunAdapter) Init(config *config.NodeState, log *log.Logger, listener tun.log = log tun.listener = listener tun.dialer = dialer - tun.conns = make(map[crypto.NodeID]*yggdrasil.Conn) + tun.addrToConn = make(map[address.Address]*yggdrasil.Conn) + tun.subnetToConn = make(map[address.Subnet]*yggdrasil.Conn) } // Start the setup process for the TUN/TAP adapter. If successful, starts the @@ -181,23 +183,40 @@ func (tun *TunAdapter) handler() error { func (tun *TunAdapter) connReader(conn *yggdrasil.Conn) error { remoteNodeID := conn.RemoteAddr() - tun.mutex.Lock() - if _, isIn := tun.conns[remoteNodeID]; isIn { - tun.mutex.Unlock() - return errors.New("duplicate connection") + remoteAddr := address.AddrForNodeID(&remoteNodeID) + remoteSubnet := address.SubnetForNodeID(&remoteNodeID) + err := func() error { + tun.mutex.RLock() + defer tun.mutex.RUnlock() + if _, isIn := tun.addrToConn[*remoteAddr]; isIn { + return errors.New("duplicate connection for address " + net.IP(remoteAddr[:]).String()) + } + if _, isIn := tun.subnetToConn[*remoteSubnet]; isIn { + return errors.New("duplicate connection for subnet " + net.IP(remoteSubnet[:]).String()) + } + return nil + }() + if err != nil { + //return err + panic(err) } - tun.conns[remoteNodeID] = conn + // Store the connection mapped to address and subnet + tun.mutex.Lock() + tun.addrToConn[*remoteAddr] = conn + tun.subnetToConn[*remoteSubnet] = conn tun.mutex.Unlock() + // Make sure to clean those up later when the connection is closed defer func() { tun.mutex.Lock() - delete(tun.conns, remoteNodeID) + delete(tun.addrToConn, *remoteAddr) + delete(tun.subnetToConn, *remoteSubnet) tun.mutex.Unlock() }() b := make([]byte, 65535) for { n, err := conn.Read(b) if err != nil { - tun.log.Errorln(conn.String(), "TUN/TAP conn read error:", err) + //tun.log.Errorln(conn.String(), "TUN/TAP conn read error:", err) continue } if n == 0 { @@ -261,21 +280,28 @@ func (tun *TunAdapter) ifaceReader() error { // For now don't deal with any non-Yggdrasil ranges continue } - dstNodeID, dstNodeIDMask = dstAddr.GetNodeIDandMask() - // Do we have an active connection for this node ID? + // Do we have an active connection for this node address? tun.mutex.RLock() - conn, isIn := tun.conns[*dstNodeID] + conn, isIn := tun.addrToConn[dstAddr] + if !isIn || conn == nil { + conn, isIn = tun.subnetToConn[dstSnet] + if !isIn || conn == nil { + // Neither an address nor a subnet mapping matched, therefore populate + // the node ID and mask to commence a search + dstNodeID, dstNodeIDMask = dstAddr.GetNodeIDandMask() + } + } tun.mutex.RUnlock() // If we don't have a connection then we should open one - if !isIn { + if !isIn || conn == nil { + // Check we haven't been given empty node ID, really this shouldn't ever + // happen but just to be sure... + if dstNodeID == nil || dstNodeIDMask == nil { + panic("Given empty dstNodeID and dstNodeIDMask - this shouldn't happen") + } // Dial to the remote node if c, err := tun.dialer.DialByNodeIDandMask(dstNodeID, dstNodeIDMask); err == nil { - // We've been given a connection, so save it in our connections so we - // can refer to it the next time we send a packet to this destination - tun.mutex.Lock() - tun.conns[*dstNodeID] = &c - tun.mutex.Unlock() - // Start the connection reader goroutine + // We've been given a connection so start the connection reader goroutine go tun.connReader(&c) // Then update our reference to the connection conn, isIn = &c, true @@ -285,9 +311,10 @@ func (tun *TunAdapter) ifaceReader() error { continue } } - // If we have an open connection, either because we already had one or - // because we opened one above, try writing the packet to it - if isIn && conn != nil { + // If we have a connection now, try writing to it + if conn != nil { + // If we have an open connection, either because we already had one or + // because we opened one above, try writing the packet to it w, err := conn.Write(bs[:n]) if err != nil { tun.log.Errorln(conn.String(), "TUN/TAP conn write error:", err) diff --git a/src/yggdrasil/conn.go b/src/yggdrasil/conn.go index 0accf16..903152d 100644 --- a/src/yggdrasil/conn.go +++ b/src/yggdrasil/conn.go @@ -20,7 +20,6 @@ type Conn struct { session *sessionInfo readDeadline atomic.Value // time.Time // TODO timer writeDeadline atomic.Value // time.Time // TODO timer - expired atomic.Value // bool searching atomic.Value // bool } @@ -30,39 +29,58 @@ func (c *Conn) String() string { // This method should only be called from the router goroutine func (c *Conn) startSearch() { + // The searchCompleted callback is given to the search searchCompleted := func(sinfo *sessionInfo, err error) { + // Update the connection with the fact that the search completed, which + // allows another search to be triggered if necessary c.searching.Store(false) - c.mutex.Lock() - defer c.mutex.Unlock() + // If the search failed for some reason, e.g. it hit a dead end or timed + // out, then do nothing if err != nil { c.core.log.Debugln(c.String(), "DHT search failed:", err) - c.expired.Store(true) return } + // Take the connection mutex + c.mutex.Lock() + defer c.mutex.Unlock() + // Were we successfully given a sessionInfo pointeR? if sinfo != nil { + // Store it, and update the nodeID and nodeMask (which may have been + // wildcarded before now) with their complete counterparts c.core.log.Debugln(c.String(), "DHT search completed") c.session = sinfo - c.nodeID, c.nodeMask = sinfo.theirAddr.GetNodeIDandMask() - c.expired.Store(false) + c.nodeID = crypto.GetNodeID(&sinfo.theirPermPub) + for i := range c.nodeMask { + c.nodeMask[i] = 0xFF + } } else { - c.core.log.Debugln(c.String(), "DHT search failed: no session returned") - c.expired.Store(true) - return + // No session was returned - this shouldn't really happen because we + // should always return an error reason if we don't return a session + panic("DHT search didn't return an error or a sessionInfo") } } + // doSearch will be called below in response to one or more conditions doSearch := func() { + // Store the fact that we're searching, so that we don't start additional + // searches until this one has completed c.searching.Store(true) + // Check to see if there is a search already matching the destination sinfo, isIn := c.core.searches.searches[*c.nodeID] if !isIn { + // Nothing was found, so create a new search sinfo = c.core.searches.newIterSearch(c.nodeID, c.nodeMask, searchCompleted) c.core.log.Debugf("%s DHT search started: %p", c.String(), sinfo) } + // Continue the search c.core.searches.continueSearch(sinfo) } + // Take a copy of the session object, in case it changes later c.mutex.RLock() sinfo := c.session c.mutex.RUnlock() if c.session == nil { + // No session object is present so previous searches, if we ran any, have + // not yielded a useful result (dead end, remote host not found) doSearch() } else { sinfo.worker <- func() { @@ -83,10 +101,6 @@ func (c *Conn) startSearch() { } func (c *Conn) Read(b []byte) (int, error) { - // If the session is marked as expired then do nothing at this point - if e, ok := c.expired.Load().(bool); ok && e { - return 0, errors.New("session is closed") - } // Take a copy of the session object c.mutex.RLock() sinfo := c.session @@ -95,17 +109,15 @@ func (c *Conn) Read(b []byte) (int, error) { // in a write, we would trigger a new session, but it doesn't make sense for // us to block forever here if the session will not reopen. // TODO: should this return an error or just a zero-length buffer? - if !sinfo.init { + if sinfo == nil || !sinfo.init { return 0, errors.New("session is closed") } // Wait for some traffic to come through from the session select { // TODO... case p, ok := <-c.recv: - // If the channel was closed then mark the connection as expired, this will - // mean that the next write will start a new search and reopen the session + // If the session is closed then do nothing if !ok { - c.expired.Store(true) return 0, errors.New("session is closed") } defer util.PutBytes(p.Payload) @@ -155,13 +167,9 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) { c.mutex.RLock() sinfo := c.session c.mutex.RUnlock() - // Check whether the connection is expired, if it is we can start a new - // search to revive it - expired, eok := c.expired.Load().(bool) // If the session doesn't exist, or isn't initialised (which probably means - // that the session was never set up or it closed by timeout), or the conn - // is marked as expired, then see if we can start a new search - if sinfo == nil || !sinfo.init || (eok && expired) { + // that the search didn't complete successfully) then try to search again + if sinfo == nil || !sinfo.init { // Is a search already taking place? if searching, sok := c.searching.Load().(bool); !sok || (sok && !searching) { // No search was already taking place so start a new one @@ -173,7 +181,7 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) { // A search is already taking place so wait for it to finish return 0, errors.New("waiting for search to complete") } - //defer util.PutBytes(b) + // defer util.PutBytes(b) var packet []byte // Hand over to the session worker sinfo.doWorker(func() { @@ -197,11 +205,9 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) { } func (c *Conn) Close() error { - // Mark the connection as expired, so that a future read attempt will fail - // and a future write attempt will start a new search - c.expired.Store(true) // Close the session, if it hasn't been closed already c.session.close() + c.session = nil // This can't fail yet - TODO? return nil } diff --git a/src/yggdrasil/session.go b/src/yggdrasil/session.go index e356f61..967d5f6 100644 --- a/src/yggdrasil/session.go +++ b/src/yggdrasil/session.go @@ -107,15 +107,6 @@ func (s *sessionInfo) update(p *sessionPing) bool { return true } -// Returns true if the session has been idle for longer than the allowed timeout. -func (s *sessionInfo) timedout() bool { - var timedout bool - s.doWorker(func() { - timedout = time.Since(s.time) > time.Minute - }) - return timedout -} - // Struct of all active sessions. // Sessions are indexed by handle. // Additionally, stores maps of address/subnet onto keys, and keys onto handles. @@ -233,10 +224,6 @@ func (ss *sessions) isSessionAllowed(pubkey *crypto.BoxPubKey, initiator bool) b // Gets the session corresponding to a given handle. func (ss *sessions) getSessionForHandle(handle *crypto.Handle) (*sessionInfo, bool) { sinfo, isIn := ss.sinfos[*handle] - if isIn && sinfo.timedout() { - // We have a session, but it has timed out - return nil, false - } return sinfo, isIn } @@ -280,8 +267,9 @@ func (ss *sessions) getByTheirSubnet(snet *address.Subnet) (*sessionInfo, bool) return sinfo, isIn } -// Creates a new session and lazily cleans up old/timedout existing sessions. -// This includse initializing session info to sane defaults (e.g. lowest supported MTU). +// Creates a new session and lazily cleans up old existing sessions. This +// includse initializing session info to sane defaults (e.g. lowest supported +// MTU). func (ss *sessions) createSession(theirPermKey *crypto.BoxPubKey) *sessionInfo { if !ss.isSessionAllowed(theirPermKey, true) { return nil @@ -341,11 +329,6 @@ func (ss *sessions) cleanup() { if time.Since(ss.lastCleanup) < time.Minute { return } - for _, s := range ss.sinfos { - if s.timedout() { - s.close() - } - } permShared := make(map[crypto.BoxPubKey]*crypto.BoxSharedKey, len(ss.permShared)) for k, v := range ss.permShared { permShared[k] = v @@ -387,7 +370,6 @@ func (sinfo *sessionInfo) close() { delete(sinfo.core.sessions.addrToPerm, sinfo.theirAddr) delete(sinfo.core.sessions.subnetToPerm, sinfo.theirSubnet) close(sinfo.worker) - sinfo.init = false } // Returns a session ping appropriate for the given session info. @@ -465,17 +447,16 @@ func (ss *sessions) handlePing(ping *sessionPing) { return } } - if !isIn || sinfo.timedout() { - if isIn { - sinfo.close() - } + if !isIn { ss.createSession(&ping.SendPermPub) sinfo, isIn = ss.getByTheirPerm(&ping.SendPermPub) if !isIn { panic("This should not happen") } ss.listenerMutex.Lock() - if ss.listener != nil { + // Check and see if there's a Listener waiting to accept connections + // TODO: this should not block if nothing is accepting + if !ping.IsPong && ss.listener != nil { conn := &Conn{ core: ss.core, session: sinfo, @@ -488,8 +469,6 @@ func (ss *sessions) handlePing(ping *sessionPing) { conn.nodeMask[i] = 0xFF } ss.listener.conn <- conn - } else { - ss.core.log.Debugln("Received new session but there is no listener, ignoring") } ss.listenerMutex.Unlock() }