diff --git a/src/crypto/crypto.go b/src/crypto/crypto.go index 14b1218..44d2d8e 100644 --- a/src/crypto/crypto.go +++ b/src/crypto/crypto.go @@ -172,7 +172,7 @@ func BoxOpen(shared *BoxSharedKey, boxed []byte, nonce *BoxNonce) ([]byte, bool) { out := util.GetBytes() - return append(out, boxed...), true + return append(out, boxed...), true //FIXME disabled crypto for benchmarking s := (*[BoxSharedKeyLen]byte)(shared) n := (*[BoxNonceLen]byte)(nonce) unboxed, success := box.OpenAfterPrecomputation(out, boxed, n, s) @@ -185,7 +185,7 @@ func BoxSeal(shared *BoxSharedKey, unboxed []byte, nonce *BoxNonce) ([]byte, *Bo } nonce.Increment() out := util.GetBytes() - return append(out, unboxed...), nonce + return append(out, unboxed...), nonce // FIXME disabled crypto for benchmarking s := (*[BoxSharedKeyLen]byte)(shared) n := (*[BoxNonceLen]byte)(nonce) boxed := box.SealAfterPrecomputation(out, unboxed, n, s) diff --git a/src/yggdrasil/stream.go b/src/yggdrasil/stream.go index 30dd924..011943f 100644 --- a/src/yggdrasil/stream.go +++ b/src/yggdrasil/stream.go @@ -1,6 +1,7 @@ package yggdrasil import ( + "bufio" "errors" "fmt" "io" @@ -13,10 +14,8 @@ import ( var _ = linkInterfaceMsgIO(&stream{}) type stream struct { - rwc io.ReadWriteCloser - inputBuffer []byte // Incoming packet stream - frag [2 * streamMsgSize]byte // Temporary data read off the underlying rwc, on its way to the inputBuffer - //outputBuffer [2 * streamMsgSize]byte // Temporary data about to be written to the rwc + rwc io.ReadWriteCloser + inputBuffer *bufio.Reader outputBuffer net.Buffers } @@ -32,6 +31,7 @@ func (s *stream) init(rwc io.ReadWriteCloser) { // TODO have this also do the metadata handshake and create the peer struct s.rwc = rwc // TODO call something to do the metadata exchange + s.inputBuffer = bufio.NewReaderSize(s.rwc, 2*streamMsgSize) } // writeMsg writes a message with stream padding, and is *not* thread safe. @@ -62,26 +62,11 @@ func (s *stream) writeMsg(bs []byte) (int, error) { // readMsg reads a message from the stream, accounting for stream padding, and is *not* thread safe. func (s *stream) readMsg() ([]byte, error) { for { - buf := s.inputBuffer - msg, ok, err := stream_chopMsg(&buf) - switch { - case err != nil: - // Something in the stream format is corrupt + bs, err := s.readMsgFromBuffer() + if err != nil { return nil, fmt.Errorf("message error: %v", err) - case ok: - // Copy the packet into bs, shift the buffer, and return - msg = append(util.GetBytes(), msg...) - s.inputBuffer = append(s.inputBuffer[:0], buf...) - return msg, nil - default: - // Wait for the underlying reader to return enough info for us to proceed - n, err := s.rwc.Read(s.frag[:]) - if n > 0 { - s.inputBuffer = append(s.inputBuffer, s.frag[:n]...) - } else if err != nil { - return nil, err - } } + return bs, err } } @@ -113,34 +98,30 @@ func (s *stream) _recvMetaBytes() ([]byte, error) { return metaBytes, nil } -// This takes a pointer to a slice as an argument. It checks if there's a -// complete message and, if so, slices out those parts and returns the message, -// true, and nil. If there's no error, but also no complete message, it returns -// nil, false, and nil. If there's an error, it returns nil, false, and the -// error, which the reader then handles (currently, by returning from the -// reader, which causes the connection to close). -func stream_chopMsg(bs *[]byte) ([]byte, bool, error) { - // Returns msg, ok, err - if len(*bs) < len(streamMsg) { - return nil, false, nil +// Reads bytes from the underlying rwc and returns 1 full message +func (s *stream) readMsgFromBuffer() ([]byte, error) { + pad := streamMsg // Copy + _, err := io.ReadFull(s.inputBuffer, pad[:]) + if err != nil { + return nil, err + } else if pad != streamMsg { + return nil, errors.New("bad message") } - for idx := range streamMsg { - if (*bs)[idx] != streamMsg[idx] { - return nil, false, errors.New("bad message") + lenSlice := make([]byte, 0, 10) + // FIXME this nextByte stuff depends on wire.go format, kind of ugly to have it here + nextByte := byte(0xff) + for nextByte > 127 { + nextByte, err = s.inputBuffer.ReadByte() + if err != nil { + return nil, err } + lenSlice = append(lenSlice, nextByte) } - msgLen, msgLenLen := wire_decode_uint64((*bs)[len(streamMsg):]) + msgLen, _ := wire_decode_uint64(lenSlice) if msgLen > streamMsgSize { - return nil, false, errors.New("oversized message") + return nil, errors.New("oversized message") } - msgBegin := len(streamMsg) + msgLenLen - msgEnd := msgBegin + int(msgLen) - if msgLenLen == 0 || len(*bs) < msgEnd { - // We don't have the full message - // Need to buffer this and wait for the rest to come in - return nil, false, nil - } - msg := (*bs)[msgBegin:msgEnd] - (*bs) = (*bs)[msgEnd:] - return msg, true, nil + msg := util.ResizeBytes(util.GetBytes(), int(msgLen)) + _, err = io.ReadFull(s.inputBuffer, msg) + return msg, err }