mirror of
https://github.com/cwinfo/yggdrasil-go.git
synced 2024-11-26 10:41:40 +00:00
commit
bd3b42022b
@ -53,15 +53,14 @@ func (s *tunConn) reader() (err error) {
|
||||
}
|
||||
s.tun.log.Debugln("Starting conn reader for", s.conn.String())
|
||||
defer s.tun.log.Debugln("Stopping conn reader for", s.conn.String())
|
||||
var n int
|
||||
b := make([]byte, 65535)
|
||||
for {
|
||||
select {
|
||||
case <-s.stop:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
if n, err = s.conn.Read(b); err != nil {
|
||||
var bs []byte
|
||||
if bs, err = s.conn.ReadNoCopy(); err != nil {
|
||||
if e, eok := err.(yggdrasil.ConnError); eok && !e.Temporary() {
|
||||
if e.Closed() {
|
||||
s.tun.log.Debugln(s.conn.String(), "TUN/TAP conn read debug:", err)
|
||||
@ -70,14 +69,11 @@ func (s *tunConn) reader() (err error) {
|
||||
}
|
||||
return e
|
||||
}
|
||||
} else if n > 0 {
|
||||
bs := append(util.GetBytes(), b[:n]...)
|
||||
select {
|
||||
case s.tun.send <- bs:
|
||||
default:
|
||||
util.PutBytes(bs)
|
||||
}
|
||||
} else if len(bs) > 0 {
|
||||
s.tun.send <- bs
|
||||
s.stillAlive()
|
||||
} else {
|
||||
util.PutBytes(bs)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -96,12 +92,12 @@ func (s *tunConn) writer() error {
|
||||
select {
|
||||
case <-s.stop:
|
||||
return nil
|
||||
case b, ok := <-s.send:
|
||||
case bs, ok := <-s.send:
|
||||
if !ok {
|
||||
return errors.New("send closed")
|
||||
}
|
||||
// TODO write timeout and close
|
||||
if _, err := s.conn.Write(b); err != nil {
|
||||
if err := s.conn.WriteNoCopy(bs); err != nil {
|
||||
if e, eok := err.(yggdrasil.ConnError); !eok {
|
||||
if e.Closed() {
|
||||
s.tun.log.Debugln(s.conn.String(), "TUN/TAP generic write debug:", err)
|
||||
@ -112,9 +108,9 @@ func (s *tunConn) writer() error {
|
||||
// TODO: This currently isn't aware of IPv4 for CKR
|
||||
ptb := &icmp.PacketTooBig{
|
||||
MTU: int(e.PacketMaximumSize()),
|
||||
Data: b[:900],
|
||||
Data: bs[:900],
|
||||
}
|
||||
if packet, err := CreateICMPv6(b[8:24], b[24:40], ipv6.ICMPTypePacketTooBig, 0, ptb); err == nil {
|
||||
if packet, err := CreateICMPv6(bs[8:24], bs[24:40], ipv6.ICMPTypePacketTooBig, 0, ptb); err == nil {
|
||||
s.tun.send <- packet
|
||||
}
|
||||
} else {
|
||||
@ -127,7 +123,6 @@ func (s *tunConn) writer() error {
|
||||
} else {
|
||||
s.stillAlive()
|
||||
}
|
||||
util.PutBytes(b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -139,8 +139,10 @@ func (tun *TunAdapter) readerPacketHandler(ch chan []byte) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Shift forward to avoid leaking bytes off the front of the slide when we eventually store it
|
||||
bs = append(recvd[:0], bs...)
|
||||
if offset != 0 {
|
||||
// Shift forward to avoid leaking bytes off the front of the slice when we eventually store it
|
||||
bs = append(recvd[:0], bs...)
|
||||
}
|
||||
// From the IP header, work out what our source and destination addresses
|
||||
// and node IDs are. We will need these in order to work out where to send
|
||||
// the packet
|
||||
@ -260,11 +262,8 @@ func (tun *TunAdapter) readerPacketHandler(ch chan []byte) {
|
||||
tun.mutex.Unlock()
|
||||
if tc != nil {
|
||||
for _, packet := range packets {
|
||||
select {
|
||||
case tc.send <- packet:
|
||||
default:
|
||||
util.PutBytes(packet)
|
||||
}
|
||||
p := packet // Possibly required because of how range
|
||||
tc.send <- p
|
||||
}
|
||||
}
|
||||
}()
|
||||
@ -274,21 +273,18 @@ func (tun *TunAdapter) readerPacketHandler(ch chan []byte) {
|
||||
}
|
||||
// If we have a connection now, try writing to it
|
||||
if isIn && session != nil {
|
||||
select {
|
||||
case session.send <- bs:
|
||||
default:
|
||||
util.PutBytes(bs)
|
||||
}
|
||||
session.send <- bs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tun *TunAdapter) reader() error {
|
||||
recvd := make([]byte, 65535+tun_ETHER_HEADER_LENGTH)
|
||||
toWorker := make(chan []byte, 32)
|
||||
defer close(toWorker)
|
||||
go tun.readerPacketHandler(toWorker)
|
||||
for {
|
||||
// Get a slice to store the packet in
|
||||
recvd := util.ResizeBytes(util.GetBytes(), 65535+tun_ETHER_HEADER_LENGTH)
|
||||
// Wait for a packet to be delivered to us through the TUN/TAP adapter
|
||||
n, err := tun.iface.Read(recvd)
|
||||
if err != nil {
|
||||
@ -298,9 +294,10 @@ func (tun *TunAdapter) reader() error {
|
||||
panic(err)
|
||||
}
|
||||
if n == 0 {
|
||||
util.PutBytes(recvd)
|
||||
continue
|
||||
}
|
||||
bs := append(util.GetBytes(), recvd[:n]...)
|
||||
toWorker <- bs
|
||||
// Send the packet to the worker
|
||||
toWorker <- recvd[:n]
|
||||
}
|
||||
}
|
||||
|
@ -26,27 +26,25 @@ func UnlockThread() {
|
||||
}
|
||||
|
||||
// This is used to buffer recently used slices of bytes, to prevent allocations in the hot loops.
|
||||
var byteStoreMutex sync.Mutex
|
||||
var byteStore [][]byte
|
||||
var byteStore = sync.Pool{New: func() interface{} { return []byte(nil) }}
|
||||
|
||||
// Gets an empty slice from the byte store.
|
||||
func GetBytes() []byte {
|
||||
byteStoreMutex.Lock()
|
||||
defer byteStoreMutex.Unlock()
|
||||
if len(byteStore) > 0 {
|
||||
var bs []byte
|
||||
bs, byteStore = byteStore[len(byteStore)-1][:0], byteStore[:len(byteStore)-1]
|
||||
return bs
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
return byteStore.Get().([]byte)[:0]
|
||||
}
|
||||
|
||||
// Puts a slice in the store.
|
||||
func PutBytes(bs []byte) {
|
||||
byteStoreMutex.Lock()
|
||||
defer byteStoreMutex.Unlock()
|
||||
byteStore = append(byteStore, bs)
|
||||
byteStore.Put(bs)
|
||||
}
|
||||
|
||||
// Gets a slice of the appropriate length, reusing existing slice capacity when possible
|
||||
func ResizeBytes(bs []byte, length int) []byte {
|
||||
if cap(bs) >= length {
|
||||
return bs[:length]
|
||||
} else {
|
||||
return make([]byte, length)
|
||||
}
|
||||
}
|
||||
|
||||
// This is a workaround to go's broken timer implementation
|
||||
|
29
src/util/workerpool.go
Normal file
29
src/util/workerpool.go
Normal file
@ -0,0 +1,29 @@
|
||||
package util
|
||||
|
||||
import "runtime"
|
||||
|
||||
var workerPool chan func()
|
||||
|
||||
func init() {
|
||||
maxProcs := runtime.GOMAXPROCS(0)
|
||||
if maxProcs < 1 {
|
||||
maxProcs = 1
|
||||
}
|
||||
workerPool = make(chan func(), maxProcs)
|
||||
for idx := 0; idx < maxProcs; idx++ {
|
||||
go func() {
|
||||
for f := range workerPool {
|
||||
f()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// WorkerGo submits a job to a pool of GOMAXPROCS worker goroutines.
|
||||
// This is meant for short non-blocking functions f() where you could just go f(),
|
||||
// but you want some kind of backpressure to prevent spawning endless goroutines.
|
||||
// WorkerGo returns as soon as the function is queued to run, not when it finishes.
|
||||
// In Yggdrasil, these workers are used for certain cryptographic operations.
|
||||
func WorkerGo(f func()) {
|
||||
workerPool <- f
|
||||
}
|
@ -82,7 +82,7 @@ func (c *Conn) String() string {
|
||||
return fmt.Sprintf("conn=%p", c)
|
||||
}
|
||||
|
||||
// This should never be called from the router goroutine
|
||||
// This should never be called from the router goroutine, used in the dial functions
|
||||
func (c *Conn) search() error {
|
||||
var sinfo *searchInfo
|
||||
var isIn bool
|
||||
@ -122,6 +122,23 @@ func (c *Conn) search() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Used in session keep-alive traffic in Conn.Write
|
||||
func (c *Conn) doSearch() {
|
||||
routerWork := func() {
|
||||
// Check to see if there is a search already matching the destination
|
||||
sinfo, isIn := c.core.searches.searches[*c.nodeID]
|
||||
if !isIn {
|
||||
// Nothing was found, so create a new search
|
||||
searchCompleted := func(sinfo *sessionInfo, e error) {}
|
||||
sinfo = c.core.searches.newIterSearch(c.nodeID, c.nodeMask, searchCompleted)
|
||||
c.core.log.Debugf("%s DHT search started: %p", c.String(), sinfo)
|
||||
}
|
||||
// Continue the search
|
||||
sinfo.continueSearch()
|
||||
}
|
||||
go func() { c.core.router.admin <- routerWork }()
|
||||
}
|
||||
|
||||
func (c *Conn) getDeadlineCancellation(value *atomic.Value) util.Cancellation {
|
||||
if deadline, ok := value.Load().(time.Time); ok {
|
||||
// A deadline is set, so return a Cancellation that uses it
|
||||
@ -132,123 +149,90 @@ func (c *Conn) getDeadlineCancellation(value *atomic.Value) util.Cancellation {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (int, error) {
|
||||
// Take a copy of the session object
|
||||
sinfo := c.session
|
||||
// Used internally by Read, the caller is responsible for util.PutBytes when they're done.
|
||||
func (c *Conn) ReadNoCopy() ([]byte, error) {
|
||||
cancel := c.getDeadlineCancellation(&c.readDeadline)
|
||||
defer cancel.Cancel(nil)
|
||||
var bs []byte
|
||||
for {
|
||||
// Wait for some traffic to come through from the session
|
||||
select {
|
||||
case <-cancel.Finished():
|
||||
if cancel.Error() == util.CancellationTimeoutError {
|
||||
return 0, ConnError{errors.New("read timeout"), true, false, false, 0}
|
||||
} else {
|
||||
return 0, ConnError{errors.New("session closed"), false, false, true, 0}
|
||||
}
|
||||
case p, ok := <-sinfo.recv:
|
||||
// If the session is closed then do nothing
|
||||
if !ok {
|
||||
return 0, ConnError{errors.New("session closed"), false, false, true, 0}
|
||||
}
|
||||
var err error
|
||||
sessionFunc := func() {
|
||||
defer util.PutBytes(p.Payload)
|
||||
// If the nonce is bad then drop the packet and return an error
|
||||
if !sinfo.nonceIsOK(&p.Nonce) {
|
||||
err = ConnError{errors.New("packet dropped due to invalid nonce"), false, true, false, 0}
|
||||
return
|
||||
}
|
||||
// Decrypt the packet
|
||||
var isOK bool
|
||||
bs, isOK = crypto.BoxOpen(&sinfo.sharedSesKey, p.Payload, &p.Nonce)
|
||||
// Check if we were unable to decrypt the packet for some reason and
|
||||
// return an error if we couldn't
|
||||
if !isOK {
|
||||
err = ConnError{errors.New("packet dropped due to decryption failure"), false, true, false, 0}
|
||||
return
|
||||
}
|
||||
// Update the session
|
||||
sinfo.updateNonce(&p.Nonce)
|
||||
sinfo.time = time.Now()
|
||||
sinfo.bytesRecvd += uint64(len(bs))
|
||||
}
|
||||
sinfo.doFunc(sessionFunc)
|
||||
// Something went wrong in the session worker so abort
|
||||
if err != nil {
|
||||
if ce, ok := err.(*ConnError); ok && ce.Temporary() {
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
// Copy results to the output slice and clean up
|
||||
copy(b, bs)
|
||||
util.PutBytes(bs)
|
||||
// If we've reached this point then everything went to plan, return the
|
||||
// number of bytes we populated back into the given slice
|
||||
return len(bs), nil
|
||||
// Wait for some traffic to come through from the session
|
||||
select {
|
||||
case <-cancel.Finished():
|
||||
if cancel.Error() == util.CancellationTimeoutError {
|
||||
return nil, ConnError{errors.New("read timeout"), true, false, false, 0}
|
||||
} else {
|
||||
return nil, ConnError{errors.New("session closed"), false, false, true, 0}
|
||||
}
|
||||
case bs := <-c.session.recv:
|
||||
return bs, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (bytesWritten int, err error) {
|
||||
sinfo := c.session
|
||||
var packet []byte
|
||||
written := len(b)
|
||||
// Implements net.Conn.Read
|
||||
func (c *Conn) Read(b []byte) (int, error) {
|
||||
bs, err := c.ReadNoCopy()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := len(bs)
|
||||
if len(bs) > len(b) {
|
||||
n = len(b)
|
||||
err = ConnError{errors.New("read buffer too small for entire packet"), false, true, false, 0}
|
||||
}
|
||||
// Copy results to the output slice and clean up
|
||||
copy(b, bs)
|
||||
util.PutBytes(bs)
|
||||
// Return the number of bytes copied to the slice, along with any error
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Used internally by Write, the caller must not reuse the argument bytes when no error occurs
|
||||
func (c *Conn) WriteNoCopy(bs []byte) error {
|
||||
var err error
|
||||
sessionFunc := func() {
|
||||
// Does the packet exceed the permitted size for the session?
|
||||
if uint16(len(b)) > sinfo.getMTU() {
|
||||
written, err = 0, ConnError{errors.New("packet too big"), true, false, false, int(sinfo.getMTU())}
|
||||
if uint16(len(bs)) > c.session.getMTU() {
|
||||
err = ConnError{errors.New("packet too big"), true, false, false, int(c.session.getMTU())}
|
||||
return
|
||||
}
|
||||
// Encrypt the packet
|
||||
payload, nonce := crypto.BoxSeal(&sinfo.sharedSesKey, b, &sinfo.myNonce)
|
||||
defer util.PutBytes(payload)
|
||||
// Construct the wire packet to send to the router
|
||||
p := wire_trafficPacket{
|
||||
Coords: sinfo.coords,
|
||||
Handle: sinfo.theirHandle,
|
||||
Nonce: *nonce,
|
||||
Payload: payload,
|
||||
}
|
||||
packet = p.encode()
|
||||
sinfo.bytesSent += uint64(len(b))
|
||||
// The rest of this work is session keep-alive traffic
|
||||
doSearch := func() {
|
||||
routerWork := func() {
|
||||
// Check to see if there is a search already matching the destination
|
||||
sinfo, isIn := c.core.searches.searches[*c.nodeID]
|
||||
if !isIn {
|
||||
// Nothing was found, so create a new search
|
||||
searchCompleted := func(sinfo *sessionInfo, e error) {}
|
||||
sinfo = c.core.searches.newIterSearch(c.nodeID, c.nodeMask, searchCompleted)
|
||||
c.core.log.Debugf("%s DHT search started: %p", c.String(), sinfo)
|
||||
}
|
||||
// Continue the search
|
||||
sinfo.continueSearch()
|
||||
}
|
||||
go func() { c.core.router.admin <- routerWork }()
|
||||
}
|
||||
switch {
|
||||
case time.Since(sinfo.time) > 6*time.Second:
|
||||
if sinfo.time.Before(sinfo.pingTime) && time.Since(sinfo.pingTime) > 6*time.Second {
|
||||
case time.Since(c.session.time) > 6*time.Second:
|
||||
if c.session.time.Before(c.session.pingTime) && time.Since(c.session.pingTime) > 6*time.Second {
|
||||
// TODO double check that the above condition is correct
|
||||
doSearch()
|
||||
c.doSearch()
|
||||
} else {
|
||||
sinfo.core.sessions.ping(sinfo)
|
||||
c.core.sessions.ping(c.session)
|
||||
}
|
||||
case sinfo.reset && sinfo.pingTime.Before(sinfo.time):
|
||||
sinfo.core.sessions.ping(sinfo)
|
||||
case c.session.reset && c.session.pingTime.Before(c.session.time):
|
||||
c.core.sessions.ping(c.session)
|
||||
default: // Don't do anything, to keep traffic throttled
|
||||
}
|
||||
}
|
||||
sinfo.doFunc(sessionFunc)
|
||||
// Give the packet to the router
|
||||
if written > 0 {
|
||||
sinfo.core.router.out(packet)
|
||||
c.session.doFunc(sessionFunc)
|
||||
if err == nil {
|
||||
cancel := c.getDeadlineCancellation(&c.writeDeadline)
|
||||
defer cancel.Cancel(nil)
|
||||
select {
|
||||
case <-cancel.Finished():
|
||||
if cancel.Error() == util.CancellationTimeoutError {
|
||||
err = ConnError{errors.New("write timeout"), true, false, false, 0}
|
||||
} else {
|
||||
err = ConnError{errors.New("session closed"), false, false, true, 0}
|
||||
}
|
||||
case c.session.send <- bs:
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Implements net.Conn.Write
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
written := len(b)
|
||||
bs := append(util.GetBytes(), b...)
|
||||
err := c.WriteNoCopy(bs)
|
||||
if err != nil {
|
||||
util.PutBytes(bs)
|
||||
written = 0
|
||||
}
|
||||
// Finally return the number of bytes we wrote
|
||||
return written, err
|
||||
}
|
||||
|
||||
|
@ -69,6 +69,7 @@ func (d *Dialer) DialByNodeIDandMask(nodeID, nodeMask *crypto.NodeID) (*Conn, er
|
||||
defer t.Stop()
|
||||
select {
|
||||
case <-conn.session.init:
|
||||
conn.session.startWorkers(conn.cancel)
|
||||
return conn, nil
|
||||
case <-t.C:
|
||||
conn.Close()
|
||||
|
@ -127,7 +127,6 @@ func (r *router) mainLoop() {
|
||||
r.core.switchTable.doMaintenance()
|
||||
r.core.dht.doMaintenance()
|
||||
r.core.sessions.cleanup()
|
||||
util.GetBytes() // To slowly drain things
|
||||
}
|
||||
case f := <-r.admin:
|
||||
f()
|
||||
@ -166,8 +165,8 @@ func (r *router) handleTraffic(packet []byte) {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case sinfo.recv <- &p: // FIXME ideally this should be front drop
|
||||
default:
|
||||
case sinfo.fromRouter <- &p:
|
||||
case <-sinfo.cancel.Finished():
|
||||
util.PutBytes(p.Payload)
|
||||
}
|
||||
}
|
||||
|
@ -6,11 +6,13 @@ package yggdrasil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/yggdrasil-network/yggdrasil-go/src/address"
|
||||
"github.com/yggdrasil-network/yggdrasil-go/src/crypto"
|
||||
"github.com/yggdrasil-network/yggdrasil-go/src/util"
|
||||
)
|
||||
|
||||
// All the information we know about an active session.
|
||||
@ -44,8 +46,11 @@ type sessionInfo struct {
|
||||
tstamp int64 // ATOMIC - tstamp from their last session ping, replay attack mitigation
|
||||
bytesSent uint64 // Bytes of real traffic sent in this session
|
||||
bytesRecvd uint64 // Bytes of real traffic received in this session
|
||||
recv chan *wire_trafficPacket // Received packets go here, picked up by the associated Conn
|
||||
fromRouter chan *wire_trafficPacket // Received packets go here, picked up by the associated Conn
|
||||
init chan struct{} // Closed when the first session pong arrives, used to signal that the session is ready for initial use
|
||||
cancel util.Cancellation // Used to terminate workers
|
||||
recv chan []byte
|
||||
send chan []byte
|
||||
}
|
||||
|
||||
func (sinfo *sessionInfo) doFunc(f func()) {
|
||||
@ -222,7 +227,9 @@ func (ss *sessions) createSession(theirPermKey *crypto.BoxPubKey) *sessionInfo {
|
||||
sinfo.myHandle = *crypto.NewHandle()
|
||||
sinfo.theirAddr = *address.AddrForNodeID(crypto.GetNodeID(&sinfo.theirPermPub))
|
||||
sinfo.theirSubnet = *address.SubnetForNodeID(crypto.GetNodeID(&sinfo.theirPermPub))
|
||||
sinfo.recv = make(chan *wire_trafficPacket, 32)
|
||||
sinfo.fromRouter = make(chan *wire_trafficPacket, 1)
|
||||
sinfo.recv = make(chan []byte, 32)
|
||||
sinfo.send = make(chan []byte, 32)
|
||||
ss.sinfos[sinfo.myHandle] = &sinfo
|
||||
ss.byTheirPerm[sinfo.theirPermPub] = &sinfo.myHandle
|
||||
return &sinfo
|
||||
@ -355,6 +362,7 @@ func (ss *sessions) handlePing(ping *sessionPing) {
|
||||
for i := range conn.nodeMask {
|
||||
conn.nodeMask[i] = 0xFF
|
||||
}
|
||||
conn.session.startWorkers(conn.cancel)
|
||||
ss.listener.conn <- conn
|
||||
}
|
||||
ss.listenerMutex.Unlock()
|
||||
@ -418,3 +426,150 @@ func (ss *sessions) reset() {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//////////////////////////// Worker Functions Below ////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func (sinfo *sessionInfo) startWorkers(cancel util.Cancellation) {
|
||||
sinfo.cancel = cancel
|
||||
go sinfo.recvWorker()
|
||||
go sinfo.sendWorker()
|
||||
}
|
||||
|
||||
func (sinfo *sessionInfo) recvWorker() {
|
||||
// TODO move theirNonce etc into a struct that gets stored here, passed in over a channel
|
||||
// Since there's no reason for anywhere else in the session code to need to *read* it...
|
||||
// Only needs to be updated from the outside if a ping resets it...
|
||||
// That would get rid of the need to take a mutex for the sessionFunc
|
||||
var callbacks []chan func()
|
||||
doRecv := func(p *wire_trafficPacket) {
|
||||
var bs []byte
|
||||
var err error
|
||||
var k crypto.BoxSharedKey
|
||||
sessionFunc := func() {
|
||||
if !sinfo.nonceIsOK(&p.Nonce) {
|
||||
err = ConnError{errors.New("packet dropped due to invalid nonce"), false, true, false, 0}
|
||||
return
|
||||
}
|
||||
k = sinfo.sharedSesKey
|
||||
}
|
||||
sinfo.doFunc(sessionFunc)
|
||||
if err != nil {
|
||||
util.PutBytes(p.Payload)
|
||||
return
|
||||
}
|
||||
var isOK bool
|
||||
ch := make(chan func(), 1)
|
||||
poolFunc := func() {
|
||||
bs, isOK = crypto.BoxOpen(&k, p.Payload, &p.Nonce)
|
||||
callback := func() {
|
||||
util.PutBytes(p.Payload)
|
||||
if !isOK {
|
||||
util.PutBytes(bs)
|
||||
return
|
||||
}
|
||||
sessionFunc = func() {
|
||||
if k != sinfo.sharedSesKey || !sinfo.nonceIsOK(&p.Nonce) {
|
||||
// The session updated in the mean time, so return an error
|
||||
err = ConnError{errors.New("session updated during crypto operation"), false, true, false, 0}
|
||||
return
|
||||
}
|
||||
sinfo.updateNonce(&p.Nonce)
|
||||
sinfo.time = time.Now()
|
||||
sinfo.bytesRecvd += uint64(len(bs))
|
||||
}
|
||||
sinfo.doFunc(sessionFunc)
|
||||
if err != nil {
|
||||
// Not sure what else to do with this packet, I guess just drop it
|
||||
util.PutBytes(bs)
|
||||
} else {
|
||||
// Pass the packet to the buffer for Conn.Read
|
||||
sinfo.recv <- bs
|
||||
}
|
||||
}
|
||||
ch <- callback
|
||||
}
|
||||
// Send to the worker and wait for it to finish
|
||||
util.WorkerGo(poolFunc)
|
||||
callbacks = append(callbacks, ch)
|
||||
}
|
||||
for {
|
||||
for len(callbacks) > 0 {
|
||||
select {
|
||||
case f := <-callbacks[0]:
|
||||
callbacks = callbacks[1:]
|
||||
f()
|
||||
case <-sinfo.cancel.Finished():
|
||||
return
|
||||
case p := <-sinfo.fromRouter:
|
||||
doRecv(p)
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-sinfo.cancel.Finished():
|
||||
return
|
||||
case p := <-sinfo.fromRouter:
|
||||
doRecv(p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sinfo *sessionInfo) sendWorker() {
|
||||
// TODO move info that this worker needs here, send updates via a channel
|
||||
// Otherwise we need to take a mutex to avoid races with update()
|
||||
var callbacks []chan func()
|
||||
doSend := func(bs []byte) {
|
||||
var p wire_trafficPacket
|
||||
var k crypto.BoxSharedKey
|
||||
sessionFunc := func() {
|
||||
sinfo.bytesSent += uint64(len(bs))
|
||||
p = wire_trafficPacket{
|
||||
Coords: append([]byte(nil), sinfo.coords...),
|
||||
Handle: sinfo.theirHandle,
|
||||
Nonce: sinfo.myNonce,
|
||||
}
|
||||
sinfo.myNonce.Increment()
|
||||
k = sinfo.sharedSesKey
|
||||
}
|
||||
// Get the mutex-protected info needed to encrypt the packet
|
||||
sinfo.doFunc(sessionFunc)
|
||||
ch := make(chan func(), 1)
|
||||
poolFunc := func() {
|
||||
// Encrypt the packet
|
||||
p.Payload, _ = crypto.BoxSeal(&k, bs, &p.Nonce)
|
||||
packet := p.encode()
|
||||
// The callback will send the packet
|
||||
callback := func() {
|
||||
// Cleanup
|
||||
util.PutBytes(bs)
|
||||
util.PutBytes(p.Payload)
|
||||
// Send the packet
|
||||
sinfo.core.router.out(packet)
|
||||
}
|
||||
ch <- callback
|
||||
}
|
||||
// Send to the worker and wait for it to finish
|
||||
util.WorkerGo(poolFunc)
|
||||
callbacks = append(callbacks, ch)
|
||||
}
|
||||
for {
|
||||
for len(callbacks) > 0 {
|
||||
select {
|
||||
case f := <-callbacks[0]:
|
||||
callbacks = callbacks[1:]
|
||||
f()
|
||||
case <-sinfo.cancel.Finished():
|
||||
return
|
||||
case bs := <-sinfo.send:
|
||||
doSend(bs)
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-sinfo.cancel.Finished():
|
||||
return
|
||||
case bs := <-sinfo.send:
|
||||
doSend(bs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,11 @@
|
||||
package yggdrasil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/yggdrasil-network/yggdrasil-go/src/util"
|
||||
)
|
||||
@ -13,9 +15,8 @@ var _ = linkInterfaceMsgIO(&stream{})
|
||||
|
||||
type stream struct {
|
||||
rwc io.ReadWriteCloser
|
||||
inputBuffer []byte // Incoming packet stream
|
||||
frag [2 * streamMsgSize]byte // Temporary data read off the underlying rwc, on its way to the inputBuffer
|
||||
outputBuffer [2 * streamMsgSize]byte // Temporary data about to be written to the rwc
|
||||
inputBuffer *bufio.Reader
|
||||
outputBuffer net.Buffers
|
||||
}
|
||||
|
||||
func (s *stream) close() error {
|
||||
@ -30,19 +31,23 @@ func (s *stream) init(rwc io.ReadWriteCloser) {
|
||||
// TODO have this also do the metadata handshake and create the peer struct
|
||||
s.rwc = rwc
|
||||
// TODO call something to do the metadata exchange
|
||||
s.inputBuffer = bufio.NewReaderSize(s.rwc, 2*streamMsgSize)
|
||||
}
|
||||
|
||||
// writeMsg writes a message with stream padding, and is *not* thread safe.
|
||||
func (s *stream) writeMsg(bs []byte) (int, error) {
|
||||
buf := s.outputBuffer[:0]
|
||||
buf = append(buf, streamMsg[:]...)
|
||||
buf = wire_put_uint64(uint64(len(bs)), buf)
|
||||
padLen := len(buf)
|
||||
buf = append(buf, bs...)
|
||||
buf = append(buf, streamMsg[:])
|
||||
l := wire_put_uint64(uint64(len(bs)), util.GetBytes())
|
||||
defer util.PutBytes(l)
|
||||
buf = append(buf, l)
|
||||
padLen := len(buf[0]) + len(buf[1])
|
||||
buf = append(buf, bs)
|
||||
totalLen := padLen + len(bs)
|
||||
var bn int
|
||||
for bn < len(buf) {
|
||||
n, err := s.rwc.Write(buf[bn:])
|
||||
bn += n
|
||||
for bn < totalLen {
|
||||
n, err := buf.WriteTo(s.rwc)
|
||||
bn += int(n)
|
||||
if err != nil {
|
||||
l := bn - padLen
|
||||
if l < 0 {
|
||||
@ -57,26 +62,11 @@ func (s *stream) writeMsg(bs []byte) (int, error) {
|
||||
// readMsg reads a message from the stream, accounting for stream padding, and is *not* thread safe.
|
||||
func (s *stream) readMsg() ([]byte, error) {
|
||||
for {
|
||||
buf := s.inputBuffer
|
||||
msg, ok, err := stream_chopMsg(&buf)
|
||||
switch {
|
||||
case err != nil:
|
||||
// Something in the stream format is corrupt
|
||||
bs, err := s.readMsgFromBuffer()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("message error: %v", err)
|
||||
case ok:
|
||||
// Copy the packet into bs, shift the buffer, and return
|
||||
msg = append(util.GetBytes(), msg...)
|
||||
s.inputBuffer = append(s.inputBuffer[:0], buf...)
|
||||
return msg, nil
|
||||
default:
|
||||
// Wait for the underlying reader to return enough info for us to proceed
|
||||
n, err := s.rwc.Read(s.frag[:])
|
||||
if n > 0 {
|
||||
s.inputBuffer = append(s.inputBuffer, s.frag[:n]...)
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return bs, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -108,34 +98,30 @@ func (s *stream) _recvMetaBytes() ([]byte, error) {
|
||||
return metaBytes, nil
|
||||
}
|
||||
|
||||
// This takes a pointer to a slice as an argument. It checks if there's a
|
||||
// complete message and, if so, slices out those parts and returns the message,
|
||||
// true, and nil. If there's no error, but also no complete message, it returns
|
||||
// nil, false, and nil. If there's an error, it returns nil, false, and the
|
||||
// error, which the reader then handles (currently, by returning from the
|
||||
// reader, which causes the connection to close).
|
||||
func stream_chopMsg(bs *[]byte) ([]byte, bool, error) {
|
||||
// Returns msg, ok, err
|
||||
if len(*bs) < len(streamMsg) {
|
||||
return nil, false, nil
|
||||
// Reads bytes from the underlying rwc and returns 1 full message
|
||||
func (s *stream) readMsgFromBuffer() ([]byte, error) {
|
||||
pad := streamMsg // Copy
|
||||
_, err := io.ReadFull(s.inputBuffer, pad[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if pad != streamMsg {
|
||||
return nil, errors.New("bad message")
|
||||
}
|
||||
for idx := range streamMsg {
|
||||
if (*bs)[idx] != streamMsg[idx] {
|
||||
return nil, false, errors.New("bad message")
|
||||
lenSlice := make([]byte, 0, 10)
|
||||
// FIXME this nextByte stuff depends on wire.go format, kind of ugly to have it here
|
||||
nextByte := byte(0xff)
|
||||
for nextByte > 127 {
|
||||
nextByte, err = s.inputBuffer.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lenSlice = append(lenSlice, nextByte)
|
||||
}
|
||||
msgLen, msgLenLen := wire_decode_uint64((*bs)[len(streamMsg):])
|
||||
msgLen, _ := wire_decode_uint64(lenSlice)
|
||||
if msgLen > streamMsgSize {
|
||||
return nil, false, errors.New("oversized message")
|
||||
return nil, errors.New("oversized message")
|
||||
}
|
||||
msgBegin := len(streamMsg) + msgLenLen
|
||||
msgEnd := msgBegin + int(msgLen)
|
||||
if msgLenLen == 0 || len(*bs) < msgEnd {
|
||||
// We don't have the full message
|
||||
// Need to buffer this and wait for the rest to come in
|
||||
return nil, false, nil
|
||||
}
|
||||
msg := (*bs)[msgBegin:msgEnd]
|
||||
(*bs) = (*bs)[msgEnd:]
|
||||
return msg, true, nil
|
||||
msg := util.ResizeBytes(util.GetBytes(), int(msgLen))
|
||||
_, err = io.ReadFull(s.inputBuffer, msg)
|
||||
return msg, err
|
||||
}
|
||||
|
@ -814,17 +814,23 @@ func (t *switchTable) doWorker() {
|
||||
go func() {
|
||||
// Keep taking packets from the idle worker and sending them to the above whenever it's idle, keeping anything extra in a (fifo, head-drop) buffer
|
||||
var buf [][]byte
|
||||
var size int
|
||||
for {
|
||||
buf = append(buf, <-t.toRouter)
|
||||
bs := <-t.toRouter
|
||||
size += len(bs)
|
||||
buf = append(buf, bs)
|
||||
for len(buf) > 0 {
|
||||
select {
|
||||
case bs := <-t.toRouter:
|
||||
size += len(bs)
|
||||
buf = append(buf, bs)
|
||||
for len(buf) > 32 {
|
||||
for size > int(t.queueTotalMaxSize) {
|
||||
size -= len(buf[0])
|
||||
util.PutBytes(buf[0])
|
||||
buf = buf[1:]
|
||||
}
|
||||
case sendingToRouter <- buf[0]:
|
||||
size -= len(buf[0])
|
||||
buf = buf[1:]
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user