5
0
mirror of https://github.com/cwinfo/yggdrasil-go.git synced 2024-11-23 06:31:35 +00:00

Break deadlock by creating session recv queue when session is created instead of repointing at search completion, also make expired atomic

This commit is contained in:
Neil Alexander 2019-04-22 11:20:35 +01:00
parent 5a02e2ff44
commit 47eb2fc47f
No known key found for this signature in database
GPG Key ID: A02A2019A2BB0944
3 changed files with 32 additions and 44 deletions

View File

@ -193,12 +193,13 @@ func (tun *TunAdapter) connReader(conn *yggdrasil.Conn) error {
delete(tun.conns, remoteNodeID) delete(tun.conns, remoteNodeID)
tun.mutex.Unlock() tun.mutex.Unlock()
}() }()
tun.log.Debugln("Start connection reader for", conn.String())
b := make([]byte, 65535) b := make([]byte, 65535)
for { for {
n, err := conn.Read(b) n, err := conn.Read(b)
if err != nil { if err != nil {
tun.log.Errorln("TUN/TAP conn read error:", err) tun.log.Errorln("TUN/TAP conn read error:", err)
return err continue
} }
if n == 0 { if n == 0 {
continue continue
@ -209,7 +210,7 @@ func (tun *TunAdapter) connReader(conn *yggdrasil.Conn) error {
continue continue
} }
if w != n { if w != n {
tun.log.Errorln("TUN/TAP iface write len didn't match conn read len") tun.log.Errorln("TUN/TAP iface write mismatch:", w, "bytes written vs", n, "bytes given")
continue continue
} }
} }
@ -220,7 +221,7 @@ func (tun *TunAdapter) ifaceReader() error {
for { for {
n, err := tun.iface.Read(bs) n, err := tun.iface.Read(bs)
if err != nil { if err != nil {
tun.log.Errorln("TUN/TAP iface read error:", err) continue
} }
// Look up if the dstination address is somewhere we already have an // Look up if the dstination address is somewhere we already have an
// open connection to // open connection to
@ -253,6 +254,10 @@ func (tun *TunAdapter) ifaceReader() error {
// Unknown address length or protocol // Unknown address length or protocol
continue continue
} }
if !dstAddr.IsValid() && !dstSnet.IsValid() {
// For now don't deal with any non-Yggdrasil ranges
continue
}
dstNodeID, dstNodeIDMask = dstAddr.GetNodeIDandMask() dstNodeID, dstNodeIDMask = dstAddr.GetNodeIDandMask()
// Do we have an active connection for this node ID? // Do we have an active connection for this node ID?
tun.mutex.Lock() tun.mutex.Lock()
@ -260,10 +265,11 @@ func (tun *TunAdapter) ifaceReader() error {
tun.mutex.Unlock() tun.mutex.Unlock()
w, err := conn.Write(bs) w, err := conn.Write(bs)
if err != nil { if err != nil {
tun.log.Println("TUN/TAP conn write error:", err) tun.log.Errorln("TUN/TAP conn write error:", err)
continue continue
} }
if w != n { if w != n {
tun.log.Errorln("TUN/TAP conn write mismatch:", w, "bytes written vs", n, "bytes given")
continue continue
} }
} else { } else {
@ -273,7 +279,7 @@ func (tun *TunAdapter) ifaceReader() error {
go tun.connReader(&conn) go tun.connReader(&conn)
} else { } else {
tun.mutex.Unlock() tun.mutex.Unlock()
tun.log.Println("TUN/TAP dial error:", err) tun.log.Errorln("TUN/TAP dial error:", err)
} }
} }

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/yggdrasil-network/yggdrasil-go/src/crypto" "github.com/yggdrasil-network/yggdrasil-go/src/crypto"
@ -17,9 +18,9 @@ type Conn struct {
recv chan *wire_trafficPacket // Eventually gets attached to session.recv recv chan *wire_trafficPacket // Eventually gets attached to session.recv
mutex *sync.RWMutex mutex *sync.RWMutex
session *sessionInfo session *sessionInfo
readDeadline time.Time // TODO timer readDeadline atomic.Value // time.Time // TODO timer
writeDeadline time.Time // TODO timer writeDeadline atomic.Value // time.Time // TODO timer
expired bool expired atomic.Value // bool
} }
func (c *Conn) String() string { func (c *Conn) String() string {
@ -32,14 +33,12 @@ func (c *Conn) startSearch() {
if err != nil { if err != nil {
c.core.log.Debugln("DHT search failed:", err) c.core.log.Debugln("DHT search failed:", err)
c.mutex.Lock() c.mutex.Lock()
c.expired = true c.expired.Store(true)
c.mutex.Unlock()
return return
} }
if sinfo != nil { if sinfo != nil {
c.mutex.Lock() c.mutex.Lock()
c.session = sinfo c.session = sinfo
c.session.recv = c.recv
c.nodeID, c.nodeMask = sinfo.theirAddr.GetNodeIDandMask() c.nodeID, c.nodeMask = sinfo.theirAddr.GetNodeIDandMask()
c.mutex.Unlock() c.mutex.Unlock()
} }
@ -75,30 +74,20 @@ func (c *Conn) startSearch() {
} }
func (c *Conn) Read(b []byte) (int, error) { func (c *Conn) Read(b []byte) (int, error) {
err := func() error { if e, ok := c.expired.Load().(bool); ok && e {
return 0, errors.New("session is closed")
}
c.mutex.RLock() c.mutex.RLock()
defer c.mutex.RUnlock() sinfo := c.session
if c.expired { c.mutex.RUnlock()
return errors.New("session is closed")
}
return nil
}()
if err != nil {
return 0, err
}
select { select {
// TODO... // TODO...
case p, ok := <-c.recv: case p, ok := <-c.recv:
if !ok { if !ok {
c.mutex.Lock() c.expired.Store(true)
c.expired = true
c.mutex.Unlock()
return 0, errors.New("session is closed") return 0, errors.New("session is closed")
} }
defer util.PutBytes(p.Payload) defer util.PutBytes(p.Payload)
c.mutex.RLock()
sinfo := c.session
c.mutex.RUnlock()
var err error var err error
sinfo.doWorker(func() { sinfo.doWorker(func() {
if !sinfo.nonceIsOK(&p.Nonce) { if !sinfo.nonceIsOK(&p.Nonce) {
@ -131,19 +120,12 @@ func (c *Conn) Read(b []byte) (int, error) {
} }
func (c *Conn) Write(b []byte) (bytesWritten int, err error) { func (c *Conn) Write(b []byte) (bytesWritten int, err error) {
var sinfo *sessionInfo if e, ok := c.expired.Load().(bool); ok && e {
err = func() error { return 0, errors.New("session is closed")
}
c.mutex.RLock() c.mutex.RLock()
defer c.mutex.RUnlock() sinfo := c.session
if c.expired { c.mutex.RUnlock()
return errors.New("session is closed")
}
sinfo = c.session
return nil
}()
if err != nil {
return 0, err
}
if sinfo == nil { if sinfo == nil {
c.core.router.doAdmin(func() { c.core.router.doAdmin(func() {
c.startSearch() c.startSearch()
@ -173,7 +155,7 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) {
} }
func (c *Conn) Close() error { func (c *Conn) Close() error {
c.expired = true c.expired.Store(true)
c.session.close() c.session.close()
return nil return nil
} }
@ -195,11 +177,11 @@ func (c *Conn) SetDeadline(t time.Time) error {
} }
func (c *Conn) SetReadDeadline(t time.Time) error { func (c *Conn) SetReadDeadline(t time.Time) error {
c.readDeadline = t c.readDeadline.Store(t)
return nil return nil
} }
func (c *Conn) SetWriteDeadline(t time.Time) error { func (c *Conn) SetWriteDeadline(t time.Time) error {
c.writeDeadline = t c.writeDeadline.Store(t)
return nil return nil
} }

View File

@ -321,6 +321,7 @@ func (ss *sessions) createSession(theirPermKey *crypto.BoxPubKey) *sessionInfo {
sinfo.theirAddr = *address.AddrForNodeID(crypto.GetNodeID(&sinfo.theirPermPub)) sinfo.theirAddr = *address.AddrForNodeID(crypto.GetNodeID(&sinfo.theirPermPub))
sinfo.theirSubnet = *address.SubnetForNodeID(crypto.GetNodeID(&sinfo.theirPermPub)) sinfo.theirSubnet = *address.SubnetForNodeID(crypto.GetNodeID(&sinfo.theirPermPub))
sinfo.worker = make(chan func(), 1) sinfo.worker = make(chan func(), 1)
sinfo.recv = make(chan *wire_trafficPacket, 32)
ss.sinfos[sinfo.myHandle] = &sinfo ss.sinfos[sinfo.myHandle] = &sinfo
ss.byMySes[sinfo.mySesPub] = &sinfo.myHandle ss.byMySes[sinfo.mySesPub] = &sinfo.myHandle
ss.byTheirPerm[sinfo.theirPermPub] = &sinfo.myHandle ss.byTheirPerm[sinfo.theirPermPub] = &sinfo.myHandle
@ -480,12 +481,11 @@ func (ss *sessions) handlePing(ping *sessionPing) {
mutex: &sync.RWMutex{}, mutex: &sync.RWMutex{},
nodeID: crypto.GetNodeID(&sinfo.theirPermPub), nodeID: crypto.GetNodeID(&sinfo.theirPermPub),
nodeMask: &crypto.NodeID{}, nodeMask: &crypto.NodeID{},
recv: make(chan *wire_trafficPacket, 32), recv: sinfo.recv,
} }
for i := range conn.nodeMask { for i := range conn.nodeMask {
conn.nodeMask[i] = 0xFF conn.nodeMask[i] = 0xFF
} }
sinfo.recv = conn.recv
ss.listener.conn <- conn ss.listener.conn <- conn
} else { } else {
ss.core.log.Debugln("Received new session but there is no listener, ignoring") ss.core.log.Debugln("Received new session but there is no listener, ignoring")