From 2a212417389fcf8f41fd44c74864adfcc9b9e009 Mon Sep 17 00:00:00 2001
From: Neil Alexander <neilalexander@users.noreply.github.com>
Date: Wed, 11 Oct 2023 19:28:28 +0100
Subject: [PATCH] Multicast passwords

---
 cmd/yggdrasil/main.go               |  1 +
 contrib/mobile/mobile.go            |  1 +
 src/config/config.go                |  1 +
 src/multicast/advertisement.go      | 18 ++++-----
 src/multicast/advertisement_test.go | 10 ++---
 src/multicast/multicast.go          | 61 +++++++++++++++++++++--------
 src/multicast/options.go            |  9 +----
 7 files changed, 62 insertions(+), 39 deletions(-)

diff --git a/cmd/yggdrasil/main.go b/cmd/yggdrasil/main.go
index 09d2f31..2f29476 100644
--- a/cmd/yggdrasil/main.go
+++ b/cmd/yggdrasil/main.go
@@ -230,6 +230,7 @@ func main() {
 				Listen:   intf.Listen,
 				Port:     intf.Port,
 				Priority: uint8(intf.Priority),
+				Password: intf.Password,
 			})
 		}
 		if n.multicast, err = multicast.New(n.core, logger, options...); err != nil {
diff --git a/contrib/mobile/mobile.go b/contrib/mobile/mobile.go
index be1f5ff..eb79430 100644
--- a/contrib/mobile/mobile.go
+++ b/contrib/mobile/mobile.go
@@ -88,6 +88,7 @@ func (m *Yggdrasil) StartJSON(configjson []byte) error {
 				Listen:   intf.Listen,
 				Port:     intf.Port,
 				Priority: uint8(intf.Priority),
+				Password: intf.Password,
 			})
 		}
 		m.multicast, err = multicast.New(m.core, m.logger, options...)
diff --git a/src/config/config.go b/src/config/config.go
index e818f70..bb94b67 100644
--- a/src/config/config.go
+++ b/src/config/config.go
@@ -61,6 +61,7 @@ type MulticastInterfaceConfig struct {
 	Listen   bool
 	Port     uint16
 	Priority uint64 // really uint8, but gobind won't export it
+	Password string
 }
 
 // Generates default configuration and returns a pointer to the resulting
diff --git a/src/multicast/advertisement.go b/src/multicast/advertisement.go
index 69c29b6..d0db8b5 100644
--- a/src/multicast/advertisement.go
+++ b/src/multicast/advertisement.go
@@ -7,21 +7,21 @@ import (
 )
 
 type multicastAdvertisement struct {
-	MajorVersion  uint16
-	MinorVersion  uint16
-	PublicKey     ed25519.PublicKey
-	Port          uint16
-	Discriminator []byte
+	MajorVersion uint16
+	MinorVersion uint16
+	PublicKey    ed25519.PublicKey
+	Port         uint16
+	Hash         []byte
 }
 
 func (m *multicastAdvertisement) MarshalBinary() ([]byte, error) {
-	b := make([]byte, 0, ed25519.PublicKeySize+8+len(m.Discriminator))
+	b := make([]byte, 0, ed25519.PublicKeySize+8+len(m.Hash))
 	b = binary.BigEndian.AppendUint16(b, m.MajorVersion)
 	b = binary.BigEndian.AppendUint16(b, m.MinorVersion)
 	b = append(b, m.PublicKey...)
 	b = binary.BigEndian.AppendUint16(b, m.Port)
-	b = binary.BigEndian.AppendUint16(b, uint16(len(m.Discriminator)))
-	b = append(b, m.Discriminator...)
+	b = binary.BigEndian.AppendUint16(b, uint16(len(m.Hash)))
+	b = append(b, m.Hash...)
 	return b, nil
 }
 
@@ -34,6 +34,6 @@ func (m *multicastAdvertisement) UnmarshalBinary(b []byte) error {
 	m.PublicKey = append(m.PublicKey[:0], b[4:4+ed25519.PublicKeySize]...)
 	m.Port = binary.BigEndian.Uint16(b[4+ed25519.PublicKeySize : 6+ed25519.PublicKeySize])
 	dl := binary.BigEndian.Uint16(b[6+ed25519.PublicKeySize : 8+ed25519.PublicKeySize])
-	m.Discriminator = append(m.Discriminator[:0], b[8+ed25519.PublicKeySize:8+ed25519.PublicKeySize+dl]...)
+	m.Hash = append(m.Hash[:0], b[8+ed25519.PublicKeySize:8+ed25519.PublicKeySize+dl]...)
 	return nil
 }
diff --git a/src/multicast/advertisement_test.go b/src/multicast/advertisement_test.go
index 7132322..9541da6 100644
--- a/src/multicast/advertisement_test.go
+++ b/src/multicast/advertisement_test.go
@@ -13,11 +13,11 @@ func TestMulticastAdvertisementRoundTrip(t *testing.T) {
 	}
 
 	orig := multicastAdvertisement{
-		MajorVersion:  1,
-		MinorVersion:  2,
-		PublicKey:     pk,
-		Port:          3,
-		Discriminator: sk, // any bytes will do
+		MajorVersion: 1,
+		MinorVersion: 2,
+		PublicKey:    pk,
+		Port:         3,
+		Hash:         sk, // any bytes will do
 	}
 
 	ob, err := orig.MarshalBinary()
diff --git a/src/multicast/multicast.go b/src/multicast/multicast.go
index f58af93..741c431 100644
--- a/src/multicast/multicast.go
+++ b/src/multicast/multicast.go
@@ -3,6 +3,7 @@ package multicast
 import (
 	"bytes"
 	"context"
+	"crypto/ed25519"
 	"encoding/hex"
 	"fmt"
 	"math/rand"
@@ -14,6 +15,7 @@ import (
 	"github.com/gologme/log"
 
 	"github.com/yggdrasil-network/yggdrasil-go/src/core"
+	"golang.org/x/crypto/blake2b"
 	"golang.org/x/net/ipv6"
 )
 
@@ -31,10 +33,8 @@ type Multicast struct {
 	_interfaces map[string]*interfaceInfo
 	_timer      *time.Timer
 	config      struct {
-		_discriminator      []byte
-		_discriminatorMatch func([]byte) bool
-		_groupAddr          GroupAddress
-		_interfaces         map[MulticastInterface]struct{}
+		_groupAddr  GroupAddress
+		_interfaces map[MulticastInterface]struct{}
 	}
 }
 
@@ -45,6 +45,8 @@ type interfaceInfo struct {
 	listen   bool
 	port     uint16
 	priority uint8
+	password []byte
+	hash     []byte
 }
 
 type listenerInfo struct {
@@ -178,6 +180,7 @@ func (m *Multicast) _getAllowedInterfaces() map[string]*interfaceInfo {
 		return nil
 	}
 	// Work out which interfaces to announce on
+	pk := m.core.PublicKey()
 	for _, iface := range allifaces {
 		switch {
 		case iface.Flags&net.FlagUp == 0:
@@ -196,12 +199,23 @@ func (m *Multicast) _getAllowedInterfaces() map[string]*interfaceInfo {
 			if !ifcfg.Regex.MatchString(iface.Name) {
 				continue
 			}
+			hasher, err := blake2b.New512([]byte(ifcfg.Password))
+			if err != nil {
+				continue
+			}
+			if n, err := hasher.Write(pk); err != nil {
+				continue
+			} else if n != ed25519.PublicKeySize {
+				continue
+			}
 			interfaces[iface.Name] = &interfaceInfo{
 				iface:    iface,
 				beacon:   ifcfg.Beacon,
 				listen:   ifcfg.Listen,
 				port:     ifcfg.Port,
 				priority: ifcfg.Priority,
+				password: []byte(ifcfg.Password),
+				hash:     hasher.Sum(nil),
 			}
 			break
 		}
@@ -298,10 +312,13 @@ func (m *Multicast) _announce() {
 			var linfo *listenerInfo
 			if _, ok := m._listeners[iface.Name]; !ok {
 				// No listener was found - let's create one
-				urlString := fmt.Sprintf("tls://[%s]:%d", addrIP, info.port)
-				u, err := url.Parse(urlString)
-				if err != nil {
-					panic(err)
+				v := &url.Values{}
+				v.Add("priority", fmt.Sprintf("%d", info.priority))
+				v.Add("password", string(info.password))
+				u := &url.URL{
+					Scheme:   "tls",
+					Host:     net.JoinHostPort(addrIP.String(), fmt.Sprintf("%d", info.port)),
+					RawQuery: v.Encode(),
 				}
 				if li, err := m.core.Listen(u, iface.Name); err == nil {
 					m.log.Debugln("Started multicasting on", iface.Name)
@@ -324,11 +341,11 @@ func (m *Multicast) _announce() {
 			}
 			addr := linfo.listener.Addr().(*net.TCPAddr)
 			adv := multicastAdvertisement{
-				MajorVersion:  core.ProtocolVersionMajor,
-				MinorVersion:  core.ProtocolVersionMinor,
-				PublicKey:     m.core.PublicKey(),
-				Port:          uint16(addr.Port),
-				Discriminator: m.config._discriminator,
+				MajorVersion: core.ProtocolVersionMajor,
+				MinorVersion: core.ProtocolVersionMinor,
+				PublicKey:    m.core.PublicKey(),
+				Port:         uint16(addr.Port),
+				Hash:         info.hash,
 			}
 			msg, err := adv.MarshalBinary()
 			if err != nil {
@@ -356,6 +373,7 @@ func (m *Multicast) listen() {
 		panic(err)
 	}
 	bs := make([]byte, 2048)
+	hb := make([]byte, 0, blake2b.Size) // Reused to reduce hash allocations
 	for {
 		n, rcm, fromAddr, err := m.sock.ReadFrom(bs)
 		if err != nil {
@@ -386,10 +404,6 @@ func (m *Multicast) listen() {
 			continue
 		case adv.PublicKey.Equal(m.core.PublicKey()):
 			continue
-		case m.config._discriminatorMatch == nil && !bytes.Equal(adv.Discriminator, m.config._discriminator):
-			continue
-		case m.config._discriminatorMatch != nil && !m.config._discriminatorMatch(adv.Discriminator):
-			continue
 		}
 		from := fromAddr.(*net.UDPAddr)
 		from.Port = int(adv.Port)
@@ -398,9 +412,22 @@ func (m *Multicast) listen() {
 			interfaces = m._interfaces
 		})
 		if info, ok := interfaces[from.Zone]; ok && info.listen {
+			hasher, err := blake2b.New512(info.password)
+			if err != nil {
+				continue
+			}
+			if n, err := hasher.Write(adv.PublicKey); err != nil {
+				continue
+			} else if n != ed25519.PublicKeySize {
+				continue
+			}
+			if !bytes.Equal(hasher.Sum(hb[:0]), adv.Hash) {
+				continue
+			}
 			v := &url.Values{}
 			v.Add("key", hex.EncodeToString(adv.PublicKey))
 			v.Add("priority", fmt.Sprintf("%d", info.priority))
+			v.Add("password", string(info.password))
 			u := &url.URL{
 				Scheme:   "tls",
 				Host:     from.String(),
diff --git a/src/multicast/options.go b/src/multicast/options.go
index aa74060..bd9fea5 100644
--- a/src/multicast/options.go
+++ b/src/multicast/options.go
@@ -8,10 +8,6 @@ func (m *Multicast) _applyOption(opt SetupOption) {
 		m.config._interfaces[v] = struct{}{}
 	case GroupAddress:
 		m.config._groupAddr = v
-	case Discriminator:
-		m.config._discriminator = append(m.config._discriminator[:0], v...)
-	case DiscriminatorMatch:
-		m.config._discriminatorMatch = v
 	}
 }
 
@@ -25,13 +21,10 @@ type MulticastInterface struct {
 	Listen   bool
 	Port     uint16
 	Priority uint8
+	Password string
 }
 
 type GroupAddress string
-type Discriminator []byte
-type DiscriminatorMatch func([]byte) bool
 
 func (a MulticastInterface) isSetupOption() {}
 func (a GroupAddress) isSetupOption()       {}
-func (a Discriminator) isSetupOption()      {}
-func (a DiscriminatorMatch) isSetupOption() {}