diff --git a/src/yggdrasil/link.go b/src/yggdrasil/link.go index df73cc4..a4a41e7 100644 --- a/src/yggdrasil/link.go +++ b/src/yggdrasil/link.go @@ -25,6 +25,7 @@ type link struct { mutex sync.RWMutex // protects interfaces below interfaces map[linkInfo]*linkInterface tcp tcp // TCP interface support + stopped chan struct{} // TODO timeout (to remove from switch), read from config.ReadTimeout } @@ -70,6 +71,7 @@ func (l *link) init(c *Core) error { l.mutex.Lock() l.interfaces = make(map[linkInfo]*linkInterface) l.mutex.Unlock() + l.stopped = make(chan struct{}) if err := l.tcp.init(l); err != nil { c.log.Errorln("Failed to start TCP interface") @@ -135,6 +137,7 @@ func (l *link) create(msgIO linkInterfaceMsgIO, name, linkType, local, remote st } func (l *link) stop() error { + close(l.stopped) if err := l.tcp.stop(); err != nil { return err } @@ -231,7 +234,18 @@ func (intf *linkInterface) handler() error { go intf.peer.start() intf.reader.Act(nil, intf.reader._read) // Wait for the reader to finish + // TODO find a way to do this without keeping live goroutines around + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-intf.link.stopped: + intf.msgIO.close() + case <-done: + } + }() err = <-intf.reader.err + // TODO don't report an error if it's just a 'use of closed network connection' if err != nil { intf.link.core.log.Infof("Disconnected %s: %s, source %s; error: %s", strings.ToUpper(intf.info.linkType), themString, intf.info.local, err) diff --git a/src/yggdrasil/tcp.go b/src/yggdrasil/tcp.go index 8389ecc..36d8058 100644 --- a/src/yggdrasil/tcp.go +++ b/src/yggdrasil/tcp.go @@ -34,6 +34,7 @@ const tcp_ping_interval = (default_timeout * 2 / 3) // The TCP listener and information about active TCP connections, to avoid duplication. type tcp struct { link *link + waitgroup sync.WaitGroup mutex sync.Mutex // Protecting the below listeners map[string]*TcpListener calls map[string]struct{} @@ -97,9 +98,12 @@ func (t *tcp) init(l *link) error { } func (t *tcp) stop() error { + t.mutex.Lock() for _, listener := range t.listeners { close(listener.Stop) } + t.mutex.Unlock() + t.waitgroup.Wait() return nil } @@ -150,6 +154,7 @@ func (t *tcp) listen(listenaddr string) (*TcpListener, error) { Listener: listener, Stop: make(chan bool), } + t.waitgroup.Add(1) go t.listener(&l, listenaddr) return &l, nil } @@ -159,6 +164,7 @@ func (t *tcp) listen(listenaddr string) (*TcpListener, error) { // Runs the listener, which spawns off goroutines for incoming connections. func (t *tcp) listener(l *TcpListener, listenaddr string) { + defer t.waitgroup.Done() if l == nil { return } @@ -199,8 +205,10 @@ func (t *tcp) listener(l *TcpListener, listenaddr string) { t.link.core.log.Errorln("Failed to accept connection:", err) return } + t.waitgroup.Add(1) go t.handler(sock, true, nil) case <-l.Stop: + // FIXME this races with the goroutine that Accepts a TCP connection, may leak connections when a listener is removed return } } @@ -257,6 +265,7 @@ func (t *tcp) call(saddr string, options interface{}, sintf string) { if err != nil { return } + t.waitgroup.Add(1) t.handler(conn, false, saddr) } else { dst, err := net.ResolveTCPAddr("tcp", saddr) @@ -321,12 +330,14 @@ func (t *tcp) call(saddr string, options interface{}, sintf string) { t.link.core.log.Debugln("Failed to dial TCP:", err) return } + t.waitgroup.Add(1) t.handler(conn, false, nil) } }() } func (t *tcp) handler(sock net.Conn, incoming bool, options interface{}) { + defer t.waitgroup.Done() // Happens after sock.close defer sock.Close() t.setExtraOptions(sock) stream := stream{}