diff --git a/src/util/cancellation.go b/src/util/cancellation.go index fa98008..2a78c19 100644 --- a/src/util/cancellation.go +++ b/src/util/cancellation.go @@ -75,6 +75,8 @@ func CancellationChild(parent Cancellation) Cancellation { return child } +var CancellationTimeoutError = errors.New("timeout") + func CancellationWithTimeout(parent Cancellation, timeout time.Duration) Cancellation { child := CancellationChild(parent) go func() { @@ -83,7 +85,7 @@ func CancellationWithTimeout(parent Cancellation, timeout time.Duration) Cancell select { case <-child.Finished(): case <-timer.C: - child.Cancel(errors.New("timeout")) + child.Cancel(CancellationTimeoutError) } }() return child diff --git a/src/yggdrasil/conn.go b/src/yggdrasil/conn.go index 1d686f8..bc884fb 100644 --- a/src/yggdrasil/conn.go +++ b/src/yggdrasil/conn.go @@ -46,13 +46,13 @@ func (e *ConnError) Closed() bool { type Conn struct { core *Core - nodeID *crypto.NodeID - nodeMask *crypto.NodeID - mutex sync.RWMutex - close chan bool - session *sessionInfo readDeadline atomic.Value // time.Time // TODO timer writeDeadline atomic.Value // time.Time // TODO timer + cancel util.Cancellation + mutex sync.RWMutex // protects the below + nodeID *crypto.NodeID + nodeMask *crypto.NodeID + session *sessionInfo } // TODO func NewConn() that initializes additional fields as needed @@ -62,12 +62,14 @@ func newConn(core *Core, nodeID *crypto.NodeID, nodeMask *crypto.NodeID, session nodeID: nodeID, nodeMask: nodeMask, session: session, - close: make(chan bool), + cancel: util.NewCancellation(), } return &conn } func (c *Conn) String() string { + c.mutex.RLock() + defer c.mutex.RUnlock() return fmt.Sprintf("conn=%p", c) } @@ -111,28 +113,31 @@ func (c *Conn) search() error { return nil } -func getDeadlineTimer(value *atomic.Value) *time.Timer { - timer := time.NewTimer(24 * 365 * time.Hour) // FIXME for some reason setting this to 0 doesn't always let it stop and drain the channel correctly - util.TimerStop(timer) +func (c *Conn) getDeadlineCancellation(value *atomic.Value) util.Cancellation { if deadline, ok := value.Load().(time.Time); ok { - timer.Reset(time.Until(deadline)) + // A deadline is set, so return a Cancellation that uses it + return util.CancellationWithDeadline(c.cancel, deadline) + } else { + // No cancellation was set, so return a child cancellation with no timeout + return util.CancellationChild(c.cancel) } - return timer } func (c *Conn) Read(b []byte) (int, error) { // Take a copy of the session object sinfo := c.session - timer := getDeadlineTimer(&c.readDeadline) - defer util.TimerStop(timer) + cancel := c.getDeadlineCancellation(&c.readDeadline) + defer cancel.Cancel(nil) var bs []byte for { // Wait for some traffic to come through from the session select { - case <-c.close: - return 0, ConnError{errors.New("session closed"), false, false, true, 0} - case <-timer.C: - return 0, ConnError{errors.New("read timeout"), true, false, false, 0} + case <-cancel.Finished(): + if cancel.Error() == util.CancellationTimeoutError { + return 0, ConnError{errors.New("read timeout"), true, false, false, 0} + } else { + return 0, ConnError{errors.New("session closed"), false, false, true, 0} + } case p, ok := <-sinfo.recv: // If the session is closed then do nothing if !ok { @@ -172,18 +177,22 @@ func (c *Conn) Read(b []byte) (int, error) { // Send to worker select { case sinfo.worker <- workerFunc: - case <-c.close: - return 0, ConnError{errors.New("session closed"), false, false, true, 0} - case <-timer.C: - return 0, ConnError{errors.New("read timeout"), true, false, false, 0} + case <-cancel.Finished(): + if cancel.Error() == util.CancellationTimeoutError { + return 0, ConnError{errors.New("read timeout"), true, false, false, 0} + } else { + return 0, ConnError{errors.New("session closed"), false, false, true, 0} + } } // Wait for the worker to finish select { case <-done: // Wait for the worker to finish, failing this can cause memory errors (util.[Get||Put]Bytes stuff) - case <-c.close: - return 0, ConnError{errors.New("session closed"), false, false, true, 0} - case <-timer.C: - return 0, ConnError{errors.New("read timeout"), true, false, false, 0} + case <-cancel.Finished(): + if cancel.Error() == util.CancellationTimeoutError { + return 0, ConnError{errors.New("read timeout"), true, false, false, 0} + } else { + return 0, ConnError{errors.New("session closed"), false, false, true, 0} + } } // Something went wrong in the session worker so abort if err != nil { @@ -256,8 +265,8 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) { } } // Set up a timer so this doesn't block forever - timer := getDeadlineTimer(&c.writeDeadline) - defer util.TimerStop(timer) + cancel := c.getDeadlineCancellation(&c.writeDeadline) + defer cancel.Cancel(nil) // Hand over to the session worker defer func() { if recover() != nil { @@ -267,8 +276,12 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) { }() // In case we're racing with a close select { // Send to worker case sinfo.worker <- workerFunc: - case <-timer.C: - return 0, ConnError{errors.New("write timeout"), true, false, false, 0} + case <-cancel.Finished(): + if cancel.Error() == util.CancellationTimeoutError { + return 0, ConnError{errors.New("write timeout"), true, false, false, 0} + } else { + return 0, ConnError{errors.New("session closed"), false, false, true, 0} + } } // Wait for the worker to finish, otherwise there are memory errors ([Get||Put]Bytes stuff) <-done @@ -287,13 +300,9 @@ func (c *Conn) Close() (err error) { // Close the session, if it hasn't been closed already c.core.router.doAdmin(c.session.close) } - func() { - defer func() { - recover() - err = ConnError{errors.New("close failed, session already closed"), false, false, true, 0} - }() - close(c.close) // Closes reader/writer goroutines - }() + if e := c.cancel.Cancel(errors.New("connection closed")); e != nil { + err = ConnError{errors.New("close failed, session already closed"), false, false, true, 0} + } return }