5
0
mirror of https://github.com/cwinfo/yggdrasil-go.git synced 2024-11-09 15:10:27 +00:00

fix core_test.go and a race in setting/using mtu

This commit is contained in:
Arceliar 2021-06-13 13:40:20 -05:00
parent cb81be94ec
commit b34c3230f8
3 changed files with 33 additions and 13 deletions

View File

@ -243,16 +243,20 @@ func (c *Core) MaxMTU() uint64 {
return c.store.maxSessionMTU() return c.store.maxSessionMTU()
} }
// SetMTU can only safely be called after Init and before Start.
func (c *Core) SetMTU(mtu uint64) { func (c *Core) SetMTU(mtu uint64) {
if mtu < 1280 { if mtu < 1280 {
mtu = 1280 mtu = 1280
} }
c.store.mutex.Lock()
c.store.mtu = mtu c.store.mtu = mtu
c.store.mutex.Unlock()
} }
func (c *Core) MTU() uint64 { func (c *Core) MTU() uint64 {
return c.store.mtu c.store.mutex.Lock()
mtu := c.store.mtu
c.store.mutex.Unlock()
return mtu
} }
// Implement io.ReadWriteCloser // Implement io.ReadWriteCloser

View File

@ -43,11 +43,13 @@ func CreateAndConnectTwo(t testing.TB, verbose bool) (nodeA *Core, nodeB *Core)
if err := nodeA.Start(GenerateConfig(), GetLoggerWithPrefix("A: ", verbose)); err != nil { if err := nodeA.Start(GenerateConfig(), GetLoggerWithPrefix("A: ", verbose)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
nodeA.SetMTU(1500)
nodeB = new(Core) nodeB = new(Core)
if err := nodeB.Start(GenerateConfig(), GetLoggerWithPrefix("B: ", verbose)); err != nil { if err := nodeB.Start(GenerateConfig(), GetLoggerWithPrefix("B: ", verbose)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
nodeB.SetMTU(1500)
u, err := url.Parse("tcp://" + nodeA.links.tcp.getAddr().String()) u, err := url.Parse("tcp://" + nodeA.links.tcp.getAddr().String())
if err != nil { if err != nil {
@ -89,8 +91,9 @@ func CreateEchoListener(t testing.TB, nodeA *Core, bufLen int, repeats int) chan
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
buf := make([]byte, bufLen) buf := make([]byte, bufLen)
res := make([]byte, bufLen)
for i := 0; i < repeats; i++ { for i := 0; i < repeats; i++ {
n, from, err := nodeA.ReadFrom(buf) n, err := nodeA.Read(buf)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -99,7 +102,10 @@ func CreateEchoListener(t testing.TB, nodeA *Core, bufLen int, repeats int) chan
t.Error("missing data") t.Error("missing data")
return return
} }
_, err = nodeA.WriteTo(buf, from) copy(res, buf)
copy(res[8:24], buf[24:40])
copy(res[24:40], buf[8:24])
_, err = nodeA.Write(res)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -130,17 +136,20 @@ func TestCore_Start_Transfer(t *testing.T) {
// Send // Send
msg := make([]byte, msgLen) msg := make([]byte, msgLen)
rand.Read(msg) rand.Read(msg[40:])
_, err := nodeB.WriteTo(msg, nodeA.LocalAddr()) msg[0] = 0x60
copy(msg[8:24], nodeB.Address())
copy(msg[24:40], nodeA.Address())
_, err := nodeB.Write(msg)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
buf := make([]byte, msgLen) buf := make([]byte, msgLen)
_, _, err = nodeB.ReadFrom(buf) _, err = nodeB.Read(buf)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(msg, buf) { if !bytes.Equal(msg[40:], buf[40:]) {
t.Fatal("expected echo") t.Fatal("expected echo")
} }
<-done <-done
@ -159,18 +168,22 @@ func BenchmarkCore_Start_Transfer(b *testing.B) {
// Send // Send
msg := make([]byte, msgLen) msg := make([]byte, msgLen)
rand.Read(msg) rand.Read(msg[40:])
msg[0] = 0x60
copy(msg[8:24], nodeB.Address())
copy(msg[24:40], nodeA.Address())
buf := make([]byte, msgLen) buf := make([]byte, msgLen)
b.SetBytes(int64(msgLen)) b.SetBytes(int64(msgLen))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := nodeB.WriteTo(msg, nodeA.LocalAddr()) _, err := nodeB.Write(msg)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
_, _, err = nodeB.ReadFrom(buf) _, err = nodeB.Read(buf)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@ -243,12 +243,15 @@ func (k *keyStore) readPC(p []byte) (int, error) {
if len(bs) < 40 { if len(bs) < 40 {
continue continue
} }
if len(bs) > int(k.mtu) { k.mutex.Lock()
mtu := int(k.mtu)
k.mutex.Unlock()
if len(bs) > mtu {
// Using bs would make it leak off the stack, so copy to buf // Using bs would make it leak off the stack, so copy to buf
buf := make([]byte, 40) buf := make([]byte, 40)
copy(buf, bs) copy(buf, bs)
ptb := &icmp.PacketTooBig{ ptb := &icmp.PacketTooBig{
MTU: int(k.mtu), MTU: mtu,
Data: buf[:40], Data: buf[:40],
} }
if packet, err := CreateICMPv6(buf[8:24], buf[24:40], ipv6.ICMPTypePacketTooBig, 0, ptb); err == nil { if packet, err := CreateICMPv6(buf[8:24], buf[24:40], ipv6.ICMPTypePacketTooBig, 0, ptb); err == nil {