5
0
mirror of https://github.com/cwinfo/yggdrasil-go.git synced 2024-11-23 05:21:35 +00:00

add a newConn function that returns a pointer to a Conn with atomics properly initialized

This commit is contained in:
Arceliar 2019-04-26 18:07:57 -05:00
parent 75130f7735
commit 0059baf36c
4 changed files with 30 additions and 28 deletions

View File

@ -412,9 +412,9 @@ func (tun *TunAdapter) ifaceReader() error {
// Dial to the remote node // Dial to the remote node
if c, err := tun.dialer.DialByNodeIDandMask(dstNodeID, dstNodeIDMask); err == nil { if c, err := tun.dialer.DialByNodeIDandMask(dstNodeID, dstNodeIDMask); err == nil {
// We've been given a connection so start the connection reader goroutine // We've been given a connection so start the connection reader goroutine
go tun.connReader(&c) go tun.connReader(c)
// Then update our reference to the connection // Then update our reference to the connection
conn, isIn = &c, true conn, isIn = c, true
} else { } else {
// We weren't able to dial for some reason so there's no point in // We weren't able to dial for some reason so there's no point in
// continuing this iteration - skip to the next one // continuing this iteration - skip to the next one

View File

@ -15,12 +15,26 @@ type Conn struct {
core *Core core *Core
nodeID *crypto.NodeID nodeID *crypto.NodeID
nodeMask *crypto.NodeID nodeMask *crypto.NodeID
mutex *sync.RWMutex mutex sync.RWMutex
session *sessionInfo session *sessionInfo
readDeadline atomic.Value // time.Time // TODO timer readDeadline atomic.Value // time.Time // TODO timer
writeDeadline atomic.Value // time.Time // TODO timer writeDeadline atomic.Value // time.Time // TODO timer
searching atomic.Value // bool searching atomic.Value // bool
searchwait chan interface{} searchwait chan struct{}
}
// TODO func NewConn() that initializes atomic and channel fields so things don't crash or block indefinitely
func newConn(core *Core, nodeID *crypto.NodeID, nodeMask *crypto.NodeID, session *sessionInfo) *Conn {
conn := Conn{
core: core,
nodeID: nodeID,
nodeMask: nodeMask,
session: session,
searchwait: make(chan struct{}),
}
conn.SetDeadline(time.Time{})
conn.searching.Store(false)
return &conn
} }
func (c *Conn) String() string { func (c *Conn) String() string {
@ -33,9 +47,9 @@ func (c *Conn) startSearch() {
searchCompleted := func(sinfo *sessionInfo, err error) { searchCompleted := func(sinfo *sessionInfo, err error) {
// Make sure that any blocks on read/write operations are lifted // Make sure that any blocks on read/write operations are lifted
defer func() { defer func() {
defer func() { recover() }() // In case searchwait was closed by another goroutine
c.searching.Store(false) c.searching.Store(false)
close(c.searchwait) close(c.searchwait) // Never reset this to an open channel
c.searchwait = make(chan interface{})
}() }()
// If the search failed for some reason, e.g. it hit a dead end or timed // If the search failed for some reason, e.g. it hit a dead end or timed
// out, then do nothing // out, then do nothing
@ -106,6 +120,8 @@ func (c *Conn) Read(b []byte) (int, error) {
c.mutex.RLock() c.mutex.RLock()
sinfo := c.session sinfo := c.session
c.mutex.RUnlock() c.mutex.RUnlock()
timer := time.NewTimer(0)
util.TimerStop(timer)
// If there is a search in progress then wait for the result // If there is a search in progress then wait for the result
if sinfo == nil { if sinfo == nil {
// Wait for the search to complete // Wait for the search to complete

View File

@ -5,7 +5,6 @@ import (
"errors" "errors"
"strconv" "strconv"
"strings" "strings"
"sync"
"github.com/yggdrasil-network/yggdrasil-go/src/crypto" "github.com/yggdrasil-network/yggdrasil-go/src/crypto"
) )
@ -18,7 +17,7 @@ type Dialer struct {
// Dial opens a session to the given node. The first paramter should be "nodeid" // Dial opens a session to the given node. The first paramter should be "nodeid"
// and the second parameter should contain a hexadecimal representation of the // and the second parameter should contain a hexadecimal representation of the
// target node ID. // target node ID.
func (d *Dialer) Dial(network, address string) (Conn, error) { func (d *Dialer) Dial(network, address string) (*Conn, error) {
var nodeID crypto.NodeID var nodeID crypto.NodeID
var nodeMask crypto.NodeID var nodeMask crypto.NodeID
// Process // Process
@ -28,11 +27,11 @@ func (d *Dialer) Dial(network, address string) (Conn, error) {
if tokens := strings.Split(address, "/"); len(tokens) == 2 { if tokens := strings.Split(address, "/"); len(tokens) == 2 {
len, err := strconv.Atoi(tokens[1]) len, err := strconv.Atoi(tokens[1])
if err != nil { if err != nil {
return Conn{}, err return nil, err
} }
dest, err := hex.DecodeString(tokens[0]) dest, err := hex.DecodeString(tokens[0])
if err != nil { if err != nil {
return Conn{}, err return nil, err
} }
copy(nodeID[:], dest) copy(nodeID[:], dest)
for idx := 0; idx < len; idx++ { for idx := 0; idx < len; idx++ {
@ -41,7 +40,7 @@ func (d *Dialer) Dial(network, address string) (Conn, error) {
} else { } else {
dest, err := hex.DecodeString(tokens[0]) dest, err := hex.DecodeString(tokens[0])
if err != nil { if err != nil {
return Conn{}, err return nil, err
} }
copy(nodeID[:], dest) copy(nodeID[:], dest)
for i := range nodeMask { for i := range nodeMask {
@ -51,19 +50,13 @@ func (d *Dialer) Dial(network, address string) (Conn, error) {
return d.DialByNodeIDandMask(&nodeID, &nodeMask) return d.DialByNodeIDandMask(&nodeID, &nodeMask)
default: default:
// An unexpected address type was given, so give up // An unexpected address type was given, so give up
return Conn{}, errors.New("unexpected address type") return nil, errors.New("unexpected address type")
} }
} }
// DialByNodeIDandMask opens a session to the given node based on raw // DialByNodeIDandMask opens a session to the given node based on raw
// NodeID parameters. // NodeID parameters.
func (d *Dialer) DialByNodeIDandMask(nodeID, nodeMask *crypto.NodeID) (Conn, error) { func (d *Dialer) DialByNodeIDandMask(nodeID, nodeMask *crypto.NodeID) (*Conn, error) {
conn := Conn{ conn := newConn(d.core, nodeID, nodeMask, nil)
core: d.core,
mutex: &sync.RWMutex{},
nodeID: nodeID,
nodeMask: nodeMask,
searchwait: make(chan interface{}),
}
return conn, nil return conn, nil
} }

View File

@ -456,14 +456,7 @@ func (ss *sessions) handlePing(ping *sessionPing) {
// Check and see if there's a Listener waiting to accept connections // Check and see if there's a Listener waiting to accept connections
// TODO: this should not block if nothing is accepting // TODO: this should not block if nothing is accepting
if !ping.IsPong && ss.listener != nil { if !ping.IsPong && ss.listener != nil {
conn := &Conn{ conn := newConn(ss.core, crypto.GetNodeID(&sinfo.theirPermPub), &crypto.NodeID{}, sinfo)
core: ss.core,
session: sinfo,
mutex: &sync.RWMutex{},
nodeID: crypto.GetNodeID(&sinfo.theirPermPub),
nodeMask: &crypto.NodeID{},
searchwait: make(chan interface{}),
}
for i := range conn.nodeMask { for i := range conn.nodeMask {
conn.nodeMask[i] = 0xFF conn.nodeMask[i] = 0xFF
} }