5
0
mirror of https://github.com/cwinfo/yggdrasil-go.git synced 2024-11-26 08:21:36 +00:00

Add per-session read/write workers, work in progress, they still unfortunately need to take a mutex for safety

This commit is contained in:
Arceliar 2019-08-03 21:46:18 -05:00
parent a2966291b9
commit df0090e32a
4 changed files with 106 additions and 54 deletions

View File

@ -137,7 +137,6 @@ func (c *Conn) Read(b []byte) (int, error) {
sinfo := c.session sinfo := c.session
cancel := c.getDeadlineCancellation(&c.readDeadline) cancel := c.getDeadlineCancellation(&c.readDeadline)
defer cancel.Cancel(nil) defer cancel.Cancel(nil)
var bs []byte
for { for {
// Wait for some traffic to come through from the session // Wait for some traffic to come through from the session
select { select {
@ -147,54 +146,25 @@ func (c *Conn) Read(b []byte) (int, error) {
} else { } else {
return 0, ConnError{errors.New("session closed"), false, false, true, 0} return 0, ConnError{errors.New("session closed"), false, false, true, 0}
} }
case p, ok := <-sinfo.recv: case bs := <-sinfo.recv:
// If the session is closed then do nothing
if !ok {
return 0, ConnError{errors.New("session closed"), false, false, true, 0}
}
var err error var err error
sessionFunc := func() { n := len(bs)
defer util.PutBytes(p.Payload) if len(bs) > len(b) {
// If the nonce is bad then drop the packet and return an error n = len(b)
if !sinfo.nonceIsOK(&p.Nonce) { err = ConnError{errors.New("read buffer too small for entire packet"), false, true, false, 0}
err = ConnError{errors.New("packet dropped due to invalid nonce"), false, true, false, 0}
return
}
// Decrypt the packet
var isOK bool
bs, isOK = crypto.BoxOpen(&sinfo.sharedSesKey, p.Payload, &p.Nonce)
// Check if we were unable to decrypt the packet for some reason and
// return an error if we couldn't
if !isOK {
err = ConnError{errors.New("packet dropped due to decryption failure"), false, true, false, 0}
return
}
// Update the session
sinfo.updateNonce(&p.Nonce)
sinfo.time = time.Now()
sinfo.bytesRecvd += uint64(len(bs))
}
sinfo.doFunc(sessionFunc)
// Something went wrong in the session worker so abort
if err != nil {
if ce, ok := err.(*ConnError); ok && ce.Temporary() {
continue
}
return 0, err
} }
// Copy results to the output slice and clean up // Copy results to the output slice and clean up
copy(b, bs) copy(b, bs)
util.PutBytes(bs) util.PutBytes(bs)
// If we've reached this point then everything went to plan, return the // If we've reached this point then everything went to plan, return the
// number of bytes we populated back into the given slice // number of bytes we populated back into the given slice
return len(bs), nil return n, err
} }
} }
} }
func (c *Conn) Write(b []byte) (bytesWritten int, err error) { func (c *Conn) Write(b []byte) (bytesWritten int, err error) {
sinfo := c.session sinfo := c.session
var packet []byte
written := len(b) written := len(b)
sessionFunc := func() { sessionFunc := func() {
// Does the packet exceed the permitted size for the session? // Does the packet exceed the permitted size for the session?
@ -202,18 +172,6 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) {
written, err = 0, ConnError{errors.New("packet too big"), true, false, false, int(sinfo.getMTU())} written, err = 0, ConnError{errors.New("packet too big"), true, false, false, int(sinfo.getMTU())}
return return
} }
// Encrypt the packet
payload, nonce := crypto.BoxSeal(&sinfo.sharedSesKey, b, &sinfo.myNonce)
defer util.PutBytes(payload)
// Construct the wire packet to send to the router
p := wire_trafficPacket{
Coords: sinfo.coords,
Handle: sinfo.theirHandle,
Nonce: *nonce,
Payload: payload,
}
packet = p.encode()
sinfo.bytesSent += uint64(len(b))
// The rest of this work is session keep-alive traffic // The rest of this work is session keep-alive traffic
doSearch := func() { doSearch := func() {
routerWork := func() { routerWork := func() {
@ -244,11 +202,10 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) {
} }
} }
sinfo.doFunc(sessionFunc) sinfo.doFunc(sessionFunc)
// Give the packet to the router
if written > 0 { if written > 0 {
sinfo.core.router.out(packet) bs := append(util.GetBytes(), b...)
sinfo.send <- bs
} }
// Finally return the number of bytes we wrote
return written, err return written, err
} }

View File

@ -69,6 +69,7 @@ func (d *Dialer) DialByNodeIDandMask(nodeID, nodeMask *crypto.NodeID) (*Conn, er
defer t.Stop() defer t.Stop()
select { select {
case <-conn.session.init: case <-conn.session.init:
conn.session.startWorkers(conn.cancel)
return conn, nil return conn, nil
case <-t.C: case <-t.C:
conn.Close() conn.Close()

View File

@ -166,7 +166,7 @@ func (r *router) handleTraffic(packet []byte) {
return return
} }
select { select {
case sinfo.recv <- &p: // FIXME ideally this should be front drop case sinfo.fromRouter <- &p: // FIXME ideally this should be front drop
default: default:
util.PutBytes(p.Payload) util.PutBytes(p.Payload)
} }

View File

@ -6,11 +6,13 @@ package yggdrasil
import ( import (
"bytes" "bytes"
"errors"
"sync" "sync"
"time" "time"
"github.com/yggdrasil-network/yggdrasil-go/src/address" "github.com/yggdrasil-network/yggdrasil-go/src/address"
"github.com/yggdrasil-network/yggdrasil-go/src/crypto" "github.com/yggdrasil-network/yggdrasil-go/src/crypto"
"github.com/yggdrasil-network/yggdrasil-go/src/util"
) )
// All the information we know about an active session. // All the information we know about an active session.
@ -44,8 +46,11 @@ type sessionInfo struct {
tstamp int64 // ATOMIC - tstamp from their last session ping, replay attack mitigation tstamp int64 // ATOMIC - tstamp from their last session ping, replay attack mitigation
bytesSent uint64 // Bytes of real traffic sent in this session bytesSent uint64 // Bytes of real traffic sent in this session
bytesRecvd uint64 // Bytes of real traffic received in this session bytesRecvd uint64 // Bytes of real traffic received in this session
recv chan *wire_trafficPacket // Received packets go here, picked up by the associated Conn fromRouter chan *wire_trafficPacket // Received packets go here, picked up by the associated Conn
init chan struct{} // Closed when the first session pong arrives, used to signal that the session is ready for initial use init chan struct{} // Closed when the first session pong arrives, used to signal that the session is ready for initial use
cancel util.Cancellation // Used to terminate workers
recv chan []byte
send chan []byte
} }
func (sinfo *sessionInfo) doFunc(f func()) { func (sinfo *sessionInfo) doFunc(f func()) {
@ -222,7 +227,9 @@ func (ss *sessions) createSession(theirPermKey *crypto.BoxPubKey) *sessionInfo {
sinfo.myHandle = *crypto.NewHandle() sinfo.myHandle = *crypto.NewHandle()
sinfo.theirAddr = *address.AddrForNodeID(crypto.GetNodeID(&sinfo.theirPermPub)) sinfo.theirAddr = *address.AddrForNodeID(crypto.GetNodeID(&sinfo.theirPermPub))
sinfo.theirSubnet = *address.SubnetForNodeID(crypto.GetNodeID(&sinfo.theirPermPub)) sinfo.theirSubnet = *address.SubnetForNodeID(crypto.GetNodeID(&sinfo.theirPermPub))
sinfo.recv = make(chan *wire_trafficPacket, 32) sinfo.fromRouter = make(chan *wire_trafficPacket, 32)
sinfo.recv = make(chan []byte, 32)
sinfo.send = make(chan []byte, 32)
ss.sinfos[sinfo.myHandle] = &sinfo ss.sinfos[sinfo.myHandle] = &sinfo
ss.byTheirPerm[sinfo.theirPermPub] = &sinfo.myHandle ss.byTheirPerm[sinfo.theirPermPub] = &sinfo.myHandle
return &sinfo return &sinfo
@ -355,6 +362,7 @@ func (ss *sessions) handlePing(ping *sessionPing) {
for i := range conn.nodeMask { for i := range conn.nodeMask {
conn.nodeMask[i] = 0xFF conn.nodeMask[i] = 0xFF
} }
conn.session.startWorkers(conn.cancel)
ss.listener.conn <- conn ss.listener.conn <- conn
} }
ss.listenerMutex.Unlock() ss.listenerMutex.Unlock()
@ -418,3 +426,89 @@ func (ss *sessions) reset() {
}) })
} }
} }
////////////////////////////////////////////////////////////////////////////////
//////////////////////////// Worker Functions Below ////////////////////////////
////////////////////////////////////////////////////////////////////////////////
func (sinfo *sessionInfo) startWorkers(cancel util.Cancellation) {
sinfo.cancel = cancel
go sinfo.recvWorker()
go sinfo.sendWorker()
}
func (sinfo *sessionInfo) recvWorker() {
// TODO move theirNonce etc into a struct that gets stored here, passed in over a channel
// Since there's no reason for anywhere else in the session code to need to *read* it...
// Only needs to be updated from the outside if a ping resets it...
// That would get rid of the need to take a mutex for the sessionFunc
for {
select {
case <-sinfo.cancel.Finished():
return
case p := <-sinfo.fromRouter:
var bs []byte
var err error
sessionFunc := func() {
defer util.PutBytes(p.Payload)
// If the nonce is bad then drop the packet and return an error
if !sinfo.nonceIsOK(&p.Nonce) {
err = ConnError{errors.New("packet dropped due to invalid nonce"), false, true, false, 0}
return
}
// Decrypt the packet
var isOK bool
bs, isOK = crypto.BoxOpen(&sinfo.sharedSesKey, p.Payload, &p.Nonce)
// Check if we were unable to decrypt the packet for some reason and
// return an error if we couldn't
if !isOK {
err = ConnError{errors.New("packet dropped due to decryption failure"), false, true, false, 0}
return
}
// Update the session
sinfo.updateNonce(&p.Nonce)
sinfo.time = time.Now()
sinfo.bytesRecvd += uint64(len(bs))
}
sinfo.doFunc(sessionFunc)
if len(bs) > 0 {
if err != nil {
// Bad packet, drop it
util.PutBytes(bs)
} else {
// Pass the packet to the buffer for Conn.Read
sinfo.recv <- bs
}
}
}
}
}
func (sinfo *sessionInfo) sendWorker() {
// TODO move info that this worker needs here, send updates via a channel
// Otherwise we need to take a mutex to avoid races with update()
for {
select {
case <-sinfo.cancel.Finished():
return
case bs := <-sinfo.send:
// TODO
var packet []byte
sessionFunc := func() {
sinfo.bytesSent += uint64(len(bs))
payload, nonce := crypto.BoxSeal(&sinfo.sharedSesKey, bs, &sinfo.myNonce)
defer util.PutBytes(payload)
// Construct the wire packet to send to the router
p := wire_trafficPacket{
Coords: sinfo.coords,
Handle: sinfo.theirHandle,
Nonce: *nonce,
Payload: payload,
}
packet = p.encode()
}
sinfo.doFunc(sessionFunc)
sinfo.core.router.out(packet)
}
}
}