diff --git a/src/yggdrasil/conn.go b/src/yggdrasil/conn.go index 08591c4..a019908 100644 --- a/src/yggdrasil/conn.go +++ b/src/yggdrasil/conn.go @@ -11,19 +11,35 @@ import ( "github.com/yggdrasil-network/yggdrasil-go/src/util" ) +// ConnError implements the net.Error interface +type ConnError struct { + error + timeout bool + temporary bool +} + +func (e *ConnError) Timeout() bool { + return e.timeout +} + +func (e *ConnError) Temporary() bool { + return e.temporary +} + type Conn struct { core *Core nodeID *crypto.NodeID nodeMask *crypto.NodeID mutex sync.RWMutex + closed bool session *sessionInfo - readDeadline atomic.Value // time.Time // TODO timer - writeDeadline atomic.Value // time.Time // TODO timer - searching atomic.Value // bool - searchwait chan struct{} + readDeadline atomic.Value // time.Time // TODO timer + writeDeadline atomic.Value // time.Time // TODO timer + searching atomic.Value // bool + searchwait chan struct{} // Never reset this, it's only used for the initial search } -// TODO func NewConn() that initializes atomic and channel fields so things don't crash or block indefinitely +// TODO func NewConn() that initializes additional fields as needed func newConn(core *Core, nodeID *crypto.NodeID, nodeMask *crypto.NodeID, session *sessionInfo) *Conn { conn := Conn{ core: core, @@ -32,7 +48,6 @@ func newConn(core *Core, nodeID *crypto.NodeID, nodeMask *crypto.NodeID, session session: session, searchwait: make(chan struct{}), } - conn.SetDeadline(time.Time{}) conn.searching.Store(false) return &conn } @@ -45,22 +60,27 @@ func (c *Conn) String() string { func (c *Conn) startSearch() { // The searchCompleted callback is given to the search searchCompleted := func(sinfo *sessionInfo, err error) { - // Make sure that any blocks on read/write operations are lifted - defer func() { - defer func() { recover() }() // In case searchwait was closed by another goroutine - c.searching.Store(false) - close(c.searchwait) // Never reset this to an open channel - }()
+ defer c.searching.Store(false) // If the search failed for some reason, e.g. it hit a dead end or timed // out, then do nothing if err != nil { c.core.log.Debugln(c.String(), "DHT search failed:", err) + go func() { + time.Sleep(time.Second) + c.mutex.RLock() + closed := c.closed + c.mutex.RUnlock() + if !closed { + // Restart the search, or else Write can stay blocked forever + c.core.router.admin <- c.startSearch + } + }() return } // Take the connection mutex c.mutex.Lock() defer c.mutex.Unlock() - // Were we successfully given a sessionInfo pointeR? + // Were we successfully given a sessionInfo pointer? if sinfo != nil { // Store it, and update the nodeID and nodeMask (which may have been // wildcarded before now) with their complete counterparts @@ -70,11 +90,19 @@ func (c *Conn) startSearch() { for i := range c.nodeMask { c.nodeMask[i] = 0xFF } + // Make sure that any blocks on read/write operations are lifted + defer func() { recover() }() // So duplicate searches don't panic + close(c.searchwait) } else { // No session was returned - this shouldn't really happen because we // should always return an error reason if we don't return a session panic("DHT search didn't return an error or a sessionInfo") } + if c.closed { + // Things were closed before the search returned + // Go ahead and close it again to make sure the session is cleaned up + go c.Close() + } } // doSearch will be called below in response to one or more conditions doSearch := func() { @@ -115,17 +143,30 @@ func (c *Conn) startSearch() { } } +func getDeadlineTimer(value *atomic.Value) *time.Timer { + timer := time.NewTimer(0) + util.TimerStop(timer) + if deadline, ok := value.Load().(time.Time); ok { + timer.Reset(time.Until(deadline)) + } + return timer +} + func (c *Conn) Read(b []byte) (int, error) { // Take a copy of the session object c.mutex.RLock() sinfo := c.session c.mutex.RUnlock() - timer := time.NewTimer(0) - util.TimerStop(timer) + timer := getDeadlineTimer(&c.readDeadline) + 
defer util.TimerStop(timer) // If there is a search in progress then wait for the result if sinfo == nil { // Wait for the search to complete - <-c.searchwait + select { + case <-c.searchwait: + case <-timer.C: + return 0, ConnError{errors.New("Timeout"), true, false} + } // Retrieve our session info again c.mutex.RLock() sinfo = c.session @@ -146,8 +187,9 @@ func (c *Conn) Read(b []byte) (int, error) { } defer util.PutBytes(p.Payload) var err error - // Hand over to the session worker - sinfo.doWorker(func() { + done := make(chan struct{}) + workerFunc := func() { + defer close(done) // If the nonce is bad then drop the packet and return an error if !sinfo.nonceIsOK(&p.Nonce) { err = errors.New("packet dropped due to invalid nonce") @@ -172,7 +214,18 @@ func (c *Conn) Read(b []byte) (int, error) { sinfo.updateNonce(&p.Nonce) sinfo.time = time.Now() sinfo.bytesRecvd += uint64(len(b)) - }) + } + // Hand over to the session worker + select { // Send to worker + case sinfo.worker <- workerFunc: + case <-timer.C: + return 0, ConnError{errors.New("Timeout"), true, false} + } + select { // Wait for worker to return + case <-done: + case <-timer.C: + return 0, ConnError{errors.New("Timeout"), true, false} + } // Something went wrong in the session worker so abort if err != nil { return 0, err @@ -187,6 +240,8 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) { c.mutex.RLock() sinfo := c.session c.mutex.RUnlock() + timer := getDeadlineTimer(&c.writeDeadline) + defer util.TimerStop(timer) // If the session doesn't exist, or isn't initialised (which probably means // that the search didn't complete successfully) then we may need to wait for // the search to complete or start the search again @@ -199,7 +254,11 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) { }) } // Wait for the search to complete - <-c.searchwait + select { + case <-c.searchwait: + case <-timer.C: + return 0, ConnError{errors.New("Timeout"), true, false} + } // Retrieve our 
session info again c.mutex.RLock() sinfo = c.session @@ -213,8 +272,9 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) { } // defer util.PutBytes(b) var packet []byte - // Hand over to the session worker - sinfo.doWorker(func() { + done := make(chan struct{}) + workerFunc := func() { + defer close(done) // Encrypt the packet payload, nonce := crypto.BoxSeal(&sinfo.sharedSesKey, b, &sinfo.myNonce) defer util.PutBytes(payload) @@ -227,7 +287,18 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) { } packet = p.encode() sinfo.bytesSent += uint64(len(b)) - }) + } + // Hand over to the session worker + select { // Send to worker + case sinfo.worker <- workerFunc: + case <-timer.C: + return 0, ConnError{errors.New("Timeout"), true, false} + } + select { // Wait for worker to return + case <-done: + case <-timer.C: + return 0, ConnError{errors.New("Timeout"), true, false} + } // Give the packet to the router sinfo.core.router.out(packet) // Finally return the number of bytes we wrote @@ -235,10 +306,15 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) { } func (c *Conn) Close() error { - // Close the session, if it hasn't been closed already - c.session.close() - c.session = nil + c.mutex.Lock() + defer c.mutex.Unlock() + if c.session != nil { + // Close the session, if it hasn't been closed already + c.session.close() + c.session = nil + } // This can't fail yet - TODO? + c.closed = true return nil }