5
0
mirror of https://github.com/cwinfo/yggdrasil-go.git synced 2024-09-19 21:52:32 +00:00

safer dial timeout handling, in case it was used with a nil context or a context that had no timeout set

This commit is contained in:
Arceliar 2019-10-21 20:47:50 -05:00
parent eccd9a348f
commit 681c8ca6f9

View File

@ -19,15 +19,12 @@ type Dialer struct {
// Dial opens a session to the given node. The first paramter should be "nodeid"
// and the second parameter should contain a hexadecimal representation of the
// target node ID. Internally, it uses DialContext with a 6-second timeout.
// target node ID. It uses DialContext internally.
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
const timeout = 6 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return d.DialContext(ctx, network, address)
return d.DialContext(nil, network, address)
}
// DialContext is used internally by Dial, and should only be used with a context that includes a timeout.
// DialContext is used internally by Dial, and should only be used with a context that includes a timeout. It uses DialByNodeIDandMask internally.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
var nodeID crypto.NodeID
var nodeMask crypto.NodeID
@ -66,7 +63,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
}
// DialByNodeIDandMask opens a session to the given node based on raw
// NodeID parameters.
// NodeID parameters. If ctx is nil or has no timeout, then a default timeout of 6 seconds will apply, beginning *after* the search finishes.
func (d *Dialer) DialByNodeIDandMask(ctx context.Context, nodeID, nodeMask *crypto.NodeID) (net.Conn, error) {
conn := newConn(d.core, nodeID, nodeMask, nil)
if err := conn.search(); err != nil {
@ -75,10 +72,19 @@ func (d *Dialer) DialByNodeIDandMask(ctx context.Context, nodeID, nodeMask *cryp
return nil, err
}
conn.session.setConn(nil, conn)
var c context.Context
var cancel context.CancelFunc
const timeout = 6 * time.Second
if ctx != nil {
c, cancel = context.WithTimeout(ctx, timeout)
} else {
c, cancel = context.WithTimeout(context.Background(), timeout)
}
defer cancel()
select {
case <-conn.session.init:
return conn, nil
case <-ctx.Done():
case <-c.Done():
conn.Close()
return nil, errors.New("session handshake timeout")
}