diff --git a/src/yggdrasil/dialer.go b/src/yggdrasil/dialer.go index e3f24e9..8ef0582 100644 --- a/src/yggdrasil/dialer.go +++ b/src/yggdrasil/dialer.go @@ -19,15 +19,12 @@ type Dialer struct { // Dial opens a session to the given node. The first paramter should be "nodeid" // and the second parameter should contain a hexadecimal representation of the -// target node ID. Internally, it uses DialContext with a 6-second timeout. +// target node ID. It uses DialContext internally. func (d *Dialer) Dial(network, address string) (net.Conn, error) { - const timeout = 6 * time.Second - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return d.DialContext(ctx, network, address) + return d.DialContext(nil, network, address) } -// DialContext is used internally by Dial, and should only be used with a context that includes a timeout. +// DialContext is used internally by Dial, and should only be used with a context that includes a timeout. It uses DialByNodeIDandMask internally. func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { var nodeID crypto.NodeID var nodeMask crypto.NodeID @@ -66,7 +63,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. } // DialByNodeIDandMask opens a session to the given node based on raw -// NodeID parameters. +// NodeID parameters. If ctx is nil or has no timeout, then a default timeout of 6 seconds will apply, beginning *after* the search finishes. func (d *Dialer) DialByNodeIDandMask(ctx context.Context, nodeID, nodeMask *crypto.NodeID) (net.Conn, error) { conn := newConn(d.core, nodeID, nodeMask, nil) if err := conn.search(); err != nil { @@ -75,10 +72,19 @@ func (d *Dialer) DialByNodeIDandMask(ctx context.Context, nodeID, nodeMask *cryp return nil, err } conn.session.setConn(nil, conn) + var c context.Context + var cancel context.CancelFunc + const timeout = 6 * time.Second + if ctx != nil { + c, cancel = context.WithTimeout(ctx, timeout) + } else { + c, cancel = context.WithTimeout(context.Background(), timeout) + } + defer cancel() select { case <-conn.session.init: return conn, nil - case <-ctx.Done(): + case <-c.Done(): conn.Close() return nil, errors.New("session handshake timeout") }