5
0
mirror of https://github.com/cwinfo/yggdrasil-go.git synced 2024-11-22 21:10:29 +00:00

Merge pull request #480 from Arceliar/speedup

Speedup
This commit is contained in:
Neil Alexander 2019-08-05 10:24:54 +01:00 committed by GitHub
commit bd3b42022b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 353 additions and 203 deletions

View File

@ -53,15 +53,14 @@ func (s *tunConn) reader() (err error) {
} }
s.tun.log.Debugln("Starting conn reader for", s.conn.String()) s.tun.log.Debugln("Starting conn reader for", s.conn.String())
defer s.tun.log.Debugln("Stopping conn reader for", s.conn.String()) defer s.tun.log.Debugln("Stopping conn reader for", s.conn.String())
var n int
b := make([]byte, 65535)
for { for {
select { select {
case <-s.stop: case <-s.stop:
return nil return nil
default: default:
} }
if n, err = s.conn.Read(b); err != nil { var bs []byte
if bs, err = s.conn.ReadNoCopy(); err != nil {
if e, eok := err.(yggdrasil.ConnError); eok && !e.Temporary() { if e, eok := err.(yggdrasil.ConnError); eok && !e.Temporary() {
if e.Closed() { if e.Closed() {
s.tun.log.Debugln(s.conn.String(), "TUN/TAP conn read debug:", err) s.tun.log.Debugln(s.conn.String(), "TUN/TAP conn read debug:", err)
@ -70,14 +69,11 @@ func (s *tunConn) reader() (err error) {
} }
return e return e
} }
} else if n > 0 { } else if len(bs) > 0 {
bs := append(util.GetBytes(), b[:n]...) s.tun.send <- bs
select {
case s.tun.send <- bs:
default:
util.PutBytes(bs)
}
s.stillAlive() s.stillAlive()
} else {
util.PutBytes(bs)
} }
} }
} }
@ -96,12 +92,12 @@ func (s *tunConn) writer() error {
select { select {
case <-s.stop: case <-s.stop:
return nil return nil
case b, ok := <-s.send: case bs, ok := <-s.send:
if !ok { if !ok {
return errors.New("send closed") return errors.New("send closed")
} }
// TODO write timeout and close // TODO write timeout and close
if _, err := s.conn.Write(b); err != nil { if err := s.conn.WriteNoCopy(bs); err != nil {
if e, eok := err.(yggdrasil.ConnError); !eok { if e, eok := err.(yggdrasil.ConnError); !eok {
if e.Closed() { if e.Closed() {
s.tun.log.Debugln(s.conn.String(), "TUN/TAP generic write debug:", err) s.tun.log.Debugln(s.conn.String(), "TUN/TAP generic write debug:", err)
@ -112,9 +108,9 @@ func (s *tunConn) writer() error {
// TODO: This currently isn't aware of IPv4 for CKR // TODO: This currently isn't aware of IPv4 for CKR
ptb := &icmp.PacketTooBig{ ptb := &icmp.PacketTooBig{
MTU: int(e.PacketMaximumSize()), MTU: int(e.PacketMaximumSize()),
Data: b[:900], Data: bs[:900],
} }
if packet, err := CreateICMPv6(b[8:24], b[24:40], ipv6.ICMPTypePacketTooBig, 0, ptb); err == nil { if packet, err := CreateICMPv6(bs[8:24], bs[24:40], ipv6.ICMPTypePacketTooBig, 0, ptb); err == nil {
s.tun.send <- packet s.tun.send <- packet
} }
} else { } else {
@ -127,7 +123,6 @@ func (s *tunConn) writer() error {
} else { } else {
s.stillAlive() s.stillAlive()
} }
util.PutBytes(b)
} }
} }
} }

View File

@ -139,8 +139,10 @@ func (tun *TunAdapter) readerPacketHandler(ch chan []byte) {
continue continue
} }
} }
// Shift forward to avoid leaking bytes off the front of the slide when we eventually store it if offset != 0 {
bs = append(recvd[:0], bs...) // Shift forward to avoid leaking bytes off the front of the slice when we eventually store it
bs = append(recvd[:0], bs...)
}
// From the IP header, work out what our source and destination addresses // From the IP header, work out what our source and destination addresses
// and node IDs are. We will need these in order to work out where to send // and node IDs are. We will need these in order to work out where to send
// the packet // the packet
@ -260,11 +262,8 @@ func (tun *TunAdapter) readerPacketHandler(ch chan []byte) {
tun.mutex.Unlock() tun.mutex.Unlock()
if tc != nil { if tc != nil {
for _, packet := range packets { for _, packet := range packets {
select { p := packet // Possibly required because of how range
case tc.send <- packet: tc.send <- p
default:
util.PutBytes(packet)
}
} }
} }
}() }()
@ -274,21 +273,18 @@ func (tun *TunAdapter) readerPacketHandler(ch chan []byte) {
} }
// If we have a connection now, try writing to it // If we have a connection now, try writing to it
if isIn && session != nil { if isIn && session != nil {
select { session.send <- bs
case session.send <- bs:
default:
util.PutBytes(bs)
}
} }
} }
} }
func (tun *TunAdapter) reader() error { func (tun *TunAdapter) reader() error {
recvd := make([]byte, 65535+tun_ETHER_HEADER_LENGTH)
toWorker := make(chan []byte, 32) toWorker := make(chan []byte, 32)
defer close(toWorker) defer close(toWorker)
go tun.readerPacketHandler(toWorker) go tun.readerPacketHandler(toWorker)
for { for {
// Get a slice to store the packet in
recvd := util.ResizeBytes(util.GetBytes(), 65535+tun_ETHER_HEADER_LENGTH)
// Wait for a packet to be delivered to us through the TUN/TAP adapter // Wait for a packet to be delivered to us through the TUN/TAP adapter
n, err := tun.iface.Read(recvd) n, err := tun.iface.Read(recvd)
if err != nil { if err != nil {
@ -298,9 +294,10 @@ func (tun *TunAdapter) reader() error {
panic(err) panic(err)
} }
if n == 0 { if n == 0 {
util.PutBytes(recvd)
continue continue
} }
bs := append(util.GetBytes(), recvd[:n]...) // Send the packet to the worker
toWorker <- bs toWorker <- recvd[:n]
} }
} }

View File

@ -26,27 +26,25 @@ func UnlockThread() {
} }
// This is used to buffer recently used slices of bytes, to prevent allocations in the hot loops. // This is used to buffer recently used slices of bytes, to prevent allocations in the hot loops.
var byteStoreMutex sync.Mutex var byteStore = sync.Pool{New: func() interface{} { return []byte(nil) }}
var byteStore [][]byte
// Gets an empty slice from the byte store. // Gets an empty slice from the byte store.
func GetBytes() []byte { func GetBytes() []byte {
byteStoreMutex.Lock() return byteStore.Get().([]byte)[:0]
defer byteStoreMutex.Unlock()
if len(byteStore) > 0 {
var bs []byte
bs, byteStore = byteStore[len(byteStore)-1][:0], byteStore[:len(byteStore)-1]
return bs
} else {
return nil
}
} }
// Puts a slice in the store. // Puts a slice in the store.
func PutBytes(bs []byte) { func PutBytes(bs []byte) {
byteStoreMutex.Lock() byteStore.Put(bs)
defer byteStoreMutex.Unlock() }
byteStore = append(byteStore, bs)
// Gets a slice of the appropriate length, reusing existing slice capacity when possible
func ResizeBytes(bs []byte, length int) []byte {
if cap(bs) >= length {
return bs[:length]
} else {
return make([]byte, length)
}
} }
// This is a workaround to go's broken timer implementation // This is a workaround to go's broken timer implementation

29
src/util/workerpool.go Normal file
View File

@ -0,0 +1,29 @@
package util
import "runtime"
var workerPool chan func()
func init() {
maxProcs := runtime.GOMAXPROCS(0)
if maxProcs < 1 {
maxProcs = 1
}
workerPool = make(chan func(), maxProcs)
for idx := 0; idx < maxProcs; idx++ {
go func() {
for f := range workerPool {
f()
}
}()
}
}
// WorkerGo submits a job to a pool of GOMAXPROCS worker goroutines.
// This is meant for short non-blocking functions f() where you could just go f(),
// but you want some kind of backpressure to prevent spawning endless goroutines.
// WorkerGo returns as soon as the function is queued to run, not when it finishes.
// In Yggdrasil, these workers are used for certain cryptographic operations.
func WorkerGo(f func()) {
workerPool <- f
}

View File

@ -82,7 +82,7 @@ func (c *Conn) String() string {
return fmt.Sprintf("conn=%p", c) return fmt.Sprintf("conn=%p", c)
} }
// This should never be called from the router goroutine // This should never be called from the router goroutine, used in the dial functions
func (c *Conn) search() error { func (c *Conn) search() error {
var sinfo *searchInfo var sinfo *searchInfo
var isIn bool var isIn bool
@ -122,6 +122,23 @@ func (c *Conn) search() error {
return nil return nil
} }
// Used in session keep-alive traffic in Conn.Write
func (c *Conn) doSearch() {
routerWork := func() {
// Check to see if there is a search already matching the destination
sinfo, isIn := c.core.searches.searches[*c.nodeID]
if !isIn {
// Nothing was found, so create a new search
searchCompleted := func(sinfo *sessionInfo, e error) {}
sinfo = c.core.searches.newIterSearch(c.nodeID, c.nodeMask, searchCompleted)
c.core.log.Debugf("%s DHT search started: %p", c.String(), sinfo)
}
// Continue the search
sinfo.continueSearch()
}
go func() { c.core.router.admin <- routerWork }()
}
func (c *Conn) getDeadlineCancellation(value *atomic.Value) util.Cancellation { func (c *Conn) getDeadlineCancellation(value *atomic.Value) util.Cancellation {
if deadline, ok := value.Load().(time.Time); ok { if deadline, ok := value.Load().(time.Time); ok {
// A deadline is set, so return a Cancellation that uses it // A deadline is set, so return a Cancellation that uses it
@ -132,123 +149,90 @@ func (c *Conn) getDeadlineCancellation(value *atomic.Value) util.Cancellation {
} }
} }
func (c *Conn) Read(b []byte) (int, error) { // Used internally by Read, the caller is responsible for util.PutBytes when they're done.
// Take a copy of the session object func (c *Conn) ReadNoCopy() ([]byte, error) {
sinfo := c.session
cancel := c.getDeadlineCancellation(&c.readDeadline) cancel := c.getDeadlineCancellation(&c.readDeadline)
defer cancel.Cancel(nil) defer cancel.Cancel(nil)
var bs []byte // Wait for some traffic to come through from the session
for { select {
// Wait for some traffic to come through from the session case <-cancel.Finished():
select { if cancel.Error() == util.CancellationTimeoutError {
case <-cancel.Finished(): return nil, ConnError{errors.New("read timeout"), true, false, false, 0}
if cancel.Error() == util.CancellationTimeoutError { } else {
return 0, ConnError{errors.New("read timeout"), true, false, false, 0} return nil, ConnError{errors.New("session closed"), false, false, true, 0}
} else {
return 0, ConnError{errors.New("session closed"), false, false, true, 0}
}
case p, ok := <-sinfo.recv:
// If the session is closed then do nothing
if !ok {
return 0, ConnError{errors.New("session closed"), false, false, true, 0}
}
var err error
sessionFunc := func() {
defer util.PutBytes(p.Payload)
// If the nonce is bad then drop the packet and return an error
if !sinfo.nonceIsOK(&p.Nonce) {
err = ConnError{errors.New("packet dropped due to invalid nonce"), false, true, false, 0}
return
}
// Decrypt the packet
var isOK bool
bs, isOK = crypto.BoxOpen(&sinfo.sharedSesKey, p.Payload, &p.Nonce)
// Check if we were unable to decrypt the packet for some reason and
// return an error if we couldn't
if !isOK {
err = ConnError{errors.New("packet dropped due to decryption failure"), false, true, false, 0}
return
}
// Update the session
sinfo.updateNonce(&p.Nonce)
sinfo.time = time.Now()
sinfo.bytesRecvd += uint64(len(bs))
}
sinfo.doFunc(sessionFunc)
// Something went wrong in the session worker so abort
if err != nil {
if ce, ok := err.(*ConnError); ok && ce.Temporary() {
continue
}
return 0, err
}
// Copy results to the output slice and clean up
copy(b, bs)
util.PutBytes(bs)
// If we've reached this point then everything went to plan, return the
// number of bytes we populated back into the given slice
return len(bs), nil
} }
case bs := <-c.session.recv:
return bs, nil
} }
} }
func (c *Conn) Write(b []byte) (bytesWritten int, err error) { // Implements net.Conn.Read
sinfo := c.session func (c *Conn) Read(b []byte) (int, error) {
var packet []byte bs, err := c.ReadNoCopy()
written := len(b) if err != nil {
return 0, err
}
n := len(bs)
if len(bs) > len(b) {
n = len(b)
err = ConnError{errors.New("read buffer too small for entire packet"), false, true, false, 0}
}
// Copy results to the output slice and clean up
copy(b, bs)
util.PutBytes(bs)
// Return the number of bytes copied to the slice, along with any error
return n, err
}
// Used internally by Write, the caller must not reuse the argument bytes when no error occurs
func (c *Conn) WriteNoCopy(bs []byte) error {
var err error
sessionFunc := func() { sessionFunc := func() {
// Does the packet exceed the permitted size for the session? // Does the packet exceed the permitted size for the session?
if uint16(len(b)) > sinfo.getMTU() { if uint16(len(bs)) > c.session.getMTU() {
written, err = 0, ConnError{errors.New("packet too big"), true, false, false, int(sinfo.getMTU())} err = ConnError{errors.New("packet too big"), true, false, false, int(c.session.getMTU())}
return return
} }
// Encrypt the packet
payload, nonce := crypto.BoxSeal(&sinfo.sharedSesKey, b, &sinfo.myNonce)
defer util.PutBytes(payload)
// Construct the wire packet to send to the router
p := wire_trafficPacket{
Coords: sinfo.coords,
Handle: sinfo.theirHandle,
Nonce: *nonce,
Payload: payload,
}
packet = p.encode()
sinfo.bytesSent += uint64(len(b))
// The rest of this work is session keep-alive traffic // The rest of this work is session keep-alive traffic
doSearch := func() {
routerWork := func() {
// Check to see if there is a search already matching the destination
sinfo, isIn := c.core.searches.searches[*c.nodeID]
if !isIn {
// Nothing was found, so create a new search
searchCompleted := func(sinfo *sessionInfo, e error) {}
sinfo = c.core.searches.newIterSearch(c.nodeID, c.nodeMask, searchCompleted)
c.core.log.Debugf("%s DHT search started: %p", c.String(), sinfo)
}
// Continue the search
sinfo.continueSearch()
}
go func() { c.core.router.admin <- routerWork }()
}
switch { switch {
case time.Since(sinfo.time) > 6*time.Second: case time.Since(c.session.time) > 6*time.Second:
if sinfo.time.Before(sinfo.pingTime) && time.Since(sinfo.pingTime) > 6*time.Second { if c.session.time.Before(c.session.pingTime) && time.Since(c.session.pingTime) > 6*time.Second {
// TODO double check that the above condition is correct // TODO double check that the above condition is correct
doSearch() c.doSearch()
} else { } else {
sinfo.core.sessions.ping(sinfo) c.core.sessions.ping(c.session)
} }
case sinfo.reset && sinfo.pingTime.Before(sinfo.time): case c.session.reset && c.session.pingTime.Before(c.session.time):
sinfo.core.sessions.ping(sinfo) c.core.sessions.ping(c.session)
default: // Don't do anything, to keep traffic throttled default: // Don't do anything, to keep traffic throttled
} }
} }
sinfo.doFunc(sessionFunc) c.session.doFunc(sessionFunc)
// Give the packet to the router if err == nil {
if written > 0 { cancel := c.getDeadlineCancellation(&c.writeDeadline)
sinfo.core.router.out(packet) defer cancel.Cancel(nil)
select {
case <-cancel.Finished():
if cancel.Error() == util.CancellationTimeoutError {
err = ConnError{errors.New("write timeout"), true, false, false, 0}
} else {
err = ConnError{errors.New("session closed"), false, false, true, 0}
}
case c.session.send <- bs:
}
}
return err
}
// Implements net.Conn.Write
func (c *Conn) Write(b []byte) (int, error) {
written := len(b)
bs := append(util.GetBytes(), b...)
err := c.WriteNoCopy(bs)
if err != nil {
util.PutBytes(bs)
written = 0
} }
// Finally return the number of bytes we wrote
return written, err return written, err
} }

View File

@ -69,6 +69,7 @@ func (d *Dialer) DialByNodeIDandMask(nodeID, nodeMask *crypto.NodeID) (*Conn, er
defer t.Stop() defer t.Stop()
select { select {
case <-conn.session.init: case <-conn.session.init:
conn.session.startWorkers(conn.cancel)
return conn, nil return conn, nil
case <-t.C: case <-t.C:
conn.Close() conn.Close()

View File

@ -127,7 +127,6 @@ func (r *router) mainLoop() {
r.core.switchTable.doMaintenance() r.core.switchTable.doMaintenance()
r.core.dht.doMaintenance() r.core.dht.doMaintenance()
r.core.sessions.cleanup() r.core.sessions.cleanup()
util.GetBytes() // To slowly drain things
} }
case f := <-r.admin: case f := <-r.admin:
f() f()
@ -166,8 +165,8 @@ func (r *router) handleTraffic(packet []byte) {
return return
} }
select { select {
case sinfo.recv <- &p: // FIXME ideally this should be front drop case sinfo.fromRouter <- &p:
default: case <-sinfo.cancel.Finished():
util.PutBytes(p.Payload) util.PutBytes(p.Payload)
} }
} }

View File

@ -6,11 +6,13 @@ package yggdrasil
import ( import (
"bytes" "bytes"
"errors"
"sync" "sync"
"time" "time"
"github.com/yggdrasil-network/yggdrasil-go/src/address" "github.com/yggdrasil-network/yggdrasil-go/src/address"
"github.com/yggdrasil-network/yggdrasil-go/src/crypto" "github.com/yggdrasil-network/yggdrasil-go/src/crypto"
"github.com/yggdrasil-network/yggdrasil-go/src/util"
) )
// All the information we know about an active session. // All the information we know about an active session.
@ -44,8 +46,11 @@ type sessionInfo struct {
tstamp int64 // ATOMIC - tstamp from their last session ping, replay attack mitigation tstamp int64 // ATOMIC - tstamp from their last session ping, replay attack mitigation
bytesSent uint64 // Bytes of real traffic sent in this session bytesSent uint64 // Bytes of real traffic sent in this session
bytesRecvd uint64 // Bytes of real traffic received in this session bytesRecvd uint64 // Bytes of real traffic received in this session
recv chan *wire_trafficPacket // Received packets go here, picked up by the associated Conn fromRouter chan *wire_trafficPacket // Received packets go here, picked up by the associated Conn
init chan struct{} // Closed when the first session pong arrives, used to signal that the session is ready for initial use init chan struct{} // Closed when the first session pong arrives, used to signal that the session is ready for initial use
cancel util.Cancellation // Used to terminate workers
recv chan []byte
send chan []byte
} }
func (sinfo *sessionInfo) doFunc(f func()) { func (sinfo *sessionInfo) doFunc(f func()) {
@ -222,7 +227,9 @@ func (ss *sessions) createSession(theirPermKey *crypto.BoxPubKey) *sessionInfo {
sinfo.myHandle = *crypto.NewHandle() sinfo.myHandle = *crypto.NewHandle()
sinfo.theirAddr = *address.AddrForNodeID(crypto.GetNodeID(&sinfo.theirPermPub)) sinfo.theirAddr = *address.AddrForNodeID(crypto.GetNodeID(&sinfo.theirPermPub))
sinfo.theirSubnet = *address.SubnetForNodeID(crypto.GetNodeID(&sinfo.theirPermPub)) sinfo.theirSubnet = *address.SubnetForNodeID(crypto.GetNodeID(&sinfo.theirPermPub))
sinfo.recv = make(chan *wire_trafficPacket, 32) sinfo.fromRouter = make(chan *wire_trafficPacket, 1)
sinfo.recv = make(chan []byte, 32)
sinfo.send = make(chan []byte, 32)
ss.sinfos[sinfo.myHandle] = &sinfo ss.sinfos[sinfo.myHandle] = &sinfo
ss.byTheirPerm[sinfo.theirPermPub] = &sinfo.myHandle ss.byTheirPerm[sinfo.theirPermPub] = &sinfo.myHandle
return &sinfo return &sinfo
@ -355,6 +362,7 @@ func (ss *sessions) handlePing(ping *sessionPing) {
for i := range conn.nodeMask { for i := range conn.nodeMask {
conn.nodeMask[i] = 0xFF conn.nodeMask[i] = 0xFF
} }
conn.session.startWorkers(conn.cancel)
ss.listener.conn <- conn ss.listener.conn <- conn
} }
ss.listenerMutex.Unlock() ss.listenerMutex.Unlock()
@ -418,3 +426,150 @@ func (ss *sessions) reset() {
}) })
} }
} }
////////////////////////////////////////////////////////////////////////////////
//////////////////////////// Worker Functions Below ////////////////////////////
////////////////////////////////////////////////////////////////////////////////
func (sinfo *sessionInfo) startWorkers(cancel util.Cancellation) {
sinfo.cancel = cancel
go sinfo.recvWorker()
go sinfo.sendWorker()
}
func (sinfo *sessionInfo) recvWorker() {
// TODO move theirNonce etc into a struct that gets stored here, passed in over a channel
// Since there's no reason for anywhere else in the session code to need to *read* it...
// Only needs to be updated from the outside if a ping resets it...
// That would get rid of the need to take a mutex for the sessionFunc
var callbacks []chan func()
doRecv := func(p *wire_trafficPacket) {
var bs []byte
var err error
var k crypto.BoxSharedKey
sessionFunc := func() {
if !sinfo.nonceIsOK(&p.Nonce) {
err = ConnError{errors.New("packet dropped due to invalid nonce"), false, true, false, 0}
return
}
k = sinfo.sharedSesKey
}
sinfo.doFunc(sessionFunc)
if err != nil {
util.PutBytes(p.Payload)
return
}
var isOK bool
ch := make(chan func(), 1)
poolFunc := func() {
bs, isOK = crypto.BoxOpen(&k, p.Payload, &p.Nonce)
callback := func() {
util.PutBytes(p.Payload)
if !isOK {
util.PutBytes(bs)
return
}
sessionFunc = func() {
if k != sinfo.sharedSesKey || !sinfo.nonceIsOK(&p.Nonce) {
// The session updated in the mean time, so return an error
err = ConnError{errors.New("session updated during crypto operation"), false, true, false, 0}
return
}
sinfo.updateNonce(&p.Nonce)
sinfo.time = time.Now()
sinfo.bytesRecvd += uint64(len(bs))
}
sinfo.doFunc(sessionFunc)
if err != nil {
// Not sure what else to do with this packet, I guess just drop it
util.PutBytes(bs)
} else {
// Pass the packet to the buffer for Conn.Read
sinfo.recv <- bs
}
}
ch <- callback
}
// Send to the worker and wait for it to finish
util.WorkerGo(poolFunc)
callbacks = append(callbacks, ch)
}
for {
for len(callbacks) > 0 {
select {
case f := <-callbacks[0]:
callbacks = callbacks[1:]
f()
case <-sinfo.cancel.Finished():
return
case p := <-sinfo.fromRouter:
doRecv(p)
}
}
select {
case <-sinfo.cancel.Finished():
return
case p := <-sinfo.fromRouter:
doRecv(p)
}
}
}
func (sinfo *sessionInfo) sendWorker() {
// TODO move info that this worker needs here, send updates via a channel
// Otherwise we need to take a mutex to avoid races with update()
var callbacks []chan func()
doSend := func(bs []byte) {
var p wire_trafficPacket
var k crypto.BoxSharedKey
sessionFunc := func() {
sinfo.bytesSent += uint64(len(bs))
p = wire_trafficPacket{
Coords: append([]byte(nil), sinfo.coords...),
Handle: sinfo.theirHandle,
Nonce: sinfo.myNonce,
}
sinfo.myNonce.Increment()
k = sinfo.sharedSesKey
}
// Get the mutex-protected info needed to encrypt the packet
sinfo.doFunc(sessionFunc)
ch := make(chan func(), 1)
poolFunc := func() {
// Encrypt the packet
p.Payload, _ = crypto.BoxSeal(&k, bs, &p.Nonce)
packet := p.encode()
// The callback will send the packet
callback := func() {
// Cleanup
util.PutBytes(bs)
util.PutBytes(p.Payload)
// Send the packet
sinfo.core.router.out(packet)
}
ch <- callback
}
// Send to the worker and wait for it to finish
util.WorkerGo(poolFunc)
callbacks = append(callbacks, ch)
}
for {
for len(callbacks) > 0 {
select {
case f := <-callbacks[0]:
callbacks = callbacks[1:]
f()
case <-sinfo.cancel.Finished():
return
case bs := <-sinfo.send:
doSend(bs)
}
}
select {
case <-sinfo.cancel.Finished():
return
case bs := <-sinfo.send:
doSend(bs)
}
}
}

View File

@ -1,9 +1,11 @@
package yggdrasil package yggdrasil
import ( import (
"bufio"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"github.com/yggdrasil-network/yggdrasil-go/src/util" "github.com/yggdrasil-network/yggdrasil-go/src/util"
) )
@ -13,9 +15,8 @@ var _ = linkInterfaceMsgIO(&stream{})
type stream struct { type stream struct {
rwc io.ReadWriteCloser rwc io.ReadWriteCloser
inputBuffer []byte // Incoming packet stream inputBuffer *bufio.Reader
frag [2 * streamMsgSize]byte // Temporary data read off the underlying rwc, on its way to the inputBuffer outputBuffer net.Buffers
outputBuffer [2 * streamMsgSize]byte // Temporary data about to be written to the rwc
} }
func (s *stream) close() error { func (s *stream) close() error {
@ -30,19 +31,23 @@ func (s *stream) init(rwc io.ReadWriteCloser) {
// TODO have this also do the metadata handshake and create the peer struct // TODO have this also do the metadata handshake and create the peer struct
s.rwc = rwc s.rwc = rwc
// TODO call something to do the metadata exchange // TODO call something to do the metadata exchange
s.inputBuffer = bufio.NewReaderSize(s.rwc, 2*streamMsgSize)
} }
// writeMsg writes a message with stream padding, and is *not* thread safe. // writeMsg writes a message with stream padding, and is *not* thread safe.
func (s *stream) writeMsg(bs []byte) (int, error) { func (s *stream) writeMsg(bs []byte) (int, error) {
buf := s.outputBuffer[:0] buf := s.outputBuffer[:0]
buf = append(buf, streamMsg[:]...) buf = append(buf, streamMsg[:])
buf = wire_put_uint64(uint64(len(bs)), buf) l := wire_put_uint64(uint64(len(bs)), util.GetBytes())
padLen := len(buf) defer util.PutBytes(l)
buf = append(buf, bs...) buf = append(buf, l)
padLen := len(buf[0]) + len(buf[1])
buf = append(buf, bs)
totalLen := padLen + len(bs)
var bn int var bn int
for bn < len(buf) { for bn < totalLen {
n, err := s.rwc.Write(buf[bn:]) n, err := buf.WriteTo(s.rwc)
bn += n bn += int(n)
if err != nil { if err != nil {
l := bn - padLen l := bn - padLen
if l < 0 { if l < 0 {
@ -57,26 +62,11 @@ func (s *stream) writeMsg(bs []byte) (int, error) {
// readMsg reads a message from the stream, accounting for stream padding, and is *not* thread safe. // readMsg reads a message from the stream, accounting for stream padding, and is *not* thread safe.
func (s *stream) readMsg() ([]byte, error) { func (s *stream) readMsg() ([]byte, error) {
for { for {
buf := s.inputBuffer bs, err := s.readMsgFromBuffer()
msg, ok, err := stream_chopMsg(&buf) if err != nil {
switch {
case err != nil:
// Something in the stream format is corrupt
return nil, fmt.Errorf("message error: %v", err) return nil, fmt.Errorf("message error: %v", err)
case ok:
// Copy the packet into bs, shift the buffer, and return
msg = append(util.GetBytes(), msg...)
s.inputBuffer = append(s.inputBuffer[:0], buf...)
return msg, nil
default:
// Wait for the underlying reader to return enough info for us to proceed
n, err := s.rwc.Read(s.frag[:])
if n > 0 {
s.inputBuffer = append(s.inputBuffer, s.frag[:n]...)
} else if err != nil {
return nil, err
}
} }
return bs, err
} }
} }
@ -108,34 +98,30 @@ func (s *stream) _recvMetaBytes() ([]byte, error) {
return metaBytes, nil return metaBytes, nil
} }
// This takes a pointer to a slice as an argument. It checks if there's a // Reads bytes from the underlying rwc and returns 1 full message
// complete message and, if so, slices out those parts and returns the message, func (s *stream) readMsgFromBuffer() ([]byte, error) {
// true, and nil. If there's no error, but also no complete message, it returns pad := streamMsg // Copy
// nil, false, and nil. If there's an error, it returns nil, false, and the _, err := io.ReadFull(s.inputBuffer, pad[:])
// error, which the reader then handles (currently, by returning from the if err != nil {
// reader, which causes the connection to close). return nil, err
func stream_chopMsg(bs *[]byte) ([]byte, bool, error) { } else if pad != streamMsg {
// Returns msg, ok, err return nil, errors.New("bad message")
if len(*bs) < len(streamMsg) {
return nil, false, nil
} }
for idx := range streamMsg { lenSlice := make([]byte, 0, 10)
if (*bs)[idx] != streamMsg[idx] { // FIXME this nextByte stuff depends on wire.go format, kind of ugly to have it here
return nil, false, errors.New("bad message") nextByte := byte(0xff)
for nextByte > 127 {
nextByte, err = s.inputBuffer.ReadByte()
if err != nil {
return nil, err
} }
lenSlice = append(lenSlice, nextByte)
} }
msgLen, msgLenLen := wire_decode_uint64((*bs)[len(streamMsg):]) msgLen, _ := wire_decode_uint64(lenSlice)
if msgLen > streamMsgSize { if msgLen > streamMsgSize {
return nil, false, errors.New("oversized message") return nil, errors.New("oversized message")
} }
msgBegin := len(streamMsg) + msgLenLen msg := util.ResizeBytes(util.GetBytes(), int(msgLen))
msgEnd := msgBegin + int(msgLen) _, err = io.ReadFull(s.inputBuffer, msg)
if msgLenLen == 0 || len(*bs) < msgEnd { return msg, err
// We don't have the full message
// Need to buffer this and wait for the rest to come in
return nil, false, nil
}
msg := (*bs)[msgBegin:msgEnd]
(*bs) = (*bs)[msgEnd:]
return msg, true, nil
} }

View File

@ -814,17 +814,23 @@ func (t *switchTable) doWorker() {
go func() { go func() {
// Keep taking packets from the idle worker and sending them to the above whenever it's idle, keeping anything extra in a (fifo, head-drop) buffer // Keep taking packets from the idle worker and sending them to the above whenever it's idle, keeping anything extra in a (fifo, head-drop) buffer
var buf [][]byte var buf [][]byte
var size int
for { for {
buf = append(buf, <-t.toRouter) bs := <-t.toRouter
size += len(bs)
buf = append(buf, bs)
for len(buf) > 0 { for len(buf) > 0 {
select { select {
case bs := <-t.toRouter: case bs := <-t.toRouter:
size += len(bs)
buf = append(buf, bs) buf = append(buf, bs)
for len(buf) > 32 { for size > int(t.queueTotalMaxSize) {
size -= len(buf[0])
util.PutBytes(buf[0]) util.PutBytes(buf[0])
buf = buf[1:] buf = buf[1:]
} }
case sendingToRouter <- buf[0]: case sendingToRouter <- buf[0]:
size -= len(buf[0])
buf = buf[1:] buf = buf[1:]
} }
} }