diff --git a/cmd/yggdrasil/main.go b/cmd/yggdrasil/main.go index 5dba3a5..fb6cccb 100644 --- a/cmd/yggdrasil/main.go +++ b/cmd/yggdrasil/main.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "encoding/hex" "encoding/json" "flag" "fmt" @@ -20,22 +21,21 @@ import ( "github.com/yggdrasil-network/yggdrasil-go/src/admin" "github.com/yggdrasil-network/yggdrasil-go/src/config" + "github.com/yggdrasil-network/yggdrasil-go/src/crypto" "github.com/yggdrasil-network/yggdrasil-go/src/multicast" "github.com/yggdrasil-network/yggdrasil-go/src/tuntap" "github.com/yggdrasil-network/yggdrasil-go/src/yggdrasil" ) -type nodeConfig = config.NodeConfig -type Core = yggdrasil.Core - type node struct { - core Core + core yggdrasil.Core + state *config.NodeState tuntap tuntap.TunAdapter multicast multicast.Multicast admin admin.AdminSocket } -func readConfig(useconf *bool, useconffile *string, normaliseconf *bool) *nodeConfig { +func readConfig(useconf *bool, useconffile *string, normaliseconf *bool) *config.NodeConfig { // Use a configuration file. If -useconf, the configuration will be read // from stdin. If -useconffile, the configuration will be read from the // filesystem. @@ -116,7 +116,7 @@ func main() { logging := flag.String("logging", "info,warn,error", "comma-separated list of logging levels to enable") flag.Parse() - var cfg *nodeConfig + var cfg *config.NodeConfig var err error switch { case *version: @@ -181,18 +181,20 @@ func main() { n := node{} // Now start Yggdrasil - this starts the DHT, router, switch and other core // components needed for Yggdrasil to operate - state, err := n.core.Start(cfg, logger) + n.state, err = n.core.Start(cfg, logger) if err != nil { logger.Errorln("An error occurred during startup") panic(err) } + // Register the session firewall gatekeeper function + n.core.SetSessionGatekeeper(n.sessionFirewall) // Start the admin socket - n.admin.Init(&n.core, state, logger, nil) + n.admin.Init(&n.core, n.state, logger, nil) if err := n.admin.Start(); err != nil { logger.Errorln("An error occurred starting admin socket:", err) } // Start the multicast interface - n.multicast.Init(&n.core, state, logger, nil) + n.multicast.Init(&n.core, n.state, logger, nil) if err := n.multicast.Start(); err != nil { logger.Errorln("An error occurred starting multicast:", err) } @@ -200,7 +202,7 @@ func main() { // Start the TUN/TAP interface if listener, err := n.core.ConnListen(); err == nil { if dialer, err := n.core.ConnDialer(); err == nil { - n.tuntap.Init(state, logger, listener, dialer) + n.tuntap.Init(n.state, logger, listener, dialer) if err := n.tuntap.Start(); err != nil { logger.Errorln("An error occurred starting TUN/TAP:", err) } @@ -251,3 +253,66 @@ func main() { } exit: } + +func (n *node) sessionFirewall(pubkey *crypto.BoxPubKey, initiator bool) bool { + n.state.Mutex.RLock() + defer n.state.Mutex.RUnlock() + + // Allow by default if the session firewall is disabled + if !n.state.Current.SessionFirewall.Enable { + return true + } + + // Prepare for checking whitelist/blacklist + var box crypto.BoxPubKey + // Reject blacklisted nodes + for _, b := range n.state.Current.SessionFirewall.BlacklistEncryptionPublicKeys { + key, err := hex.DecodeString(b) + if err == nil { + copy(box[:crypto.BoxPubKeyLen], key) + if box == *pubkey { + return false + } + } + } + + // Allow whitelisted nodes + for _, b := range n.state.Current.SessionFirewall.WhitelistEncryptionPublicKeys { + key, err := hex.DecodeString(b) + if err == nil { + copy(box[:crypto.BoxPubKeyLen], key) + if box == *pubkey { + return true + } + } + } + + // Allow outbound sessions if appropriate + if n.state.Current.SessionFirewall.AlwaysAllowOutbound { + if initiator { + return true + } + } + + // Look and see if the pubkey is that of a direct peer + var isDirectPeer bool + for _, peer := range n.core.GetPeers() { + if peer.PublicKey == *pubkey { + isDirectPeer = true + break + } + } + + // Allow direct peers if appropriate + if n.state.Current.SessionFirewall.AllowFromDirect && isDirectPeer { + return true + } + + // Allow remote nodes if appropriate + if n.state.Current.SessionFirewall.AllowFromRemote && !isDirectPeer { + return true + } + + // Finally, default-deny if not matching any of the above rules + return false +} diff --git a/src/yggdrasil/api.go b/src/yggdrasil/api.go index 5e58ffa..25f9869 100644 --- a/src/yggdrasil/api.go +++ b/src/yggdrasil/api.go @@ -395,6 +395,19 @@ func (c *Core) GetNodeInfo(keyString, coordString string, nocache bool) (NodeInf return NodeInfoPayload{}, errors.New(fmt.Sprintf("getNodeInfo timeout: %s", keyString)) } +// SetSessionGatekeeper allows you to configure a handler function for deciding +// whether a session should be allowed or not. The default session firewall is +// implemented in this way. The function receives the public key of the remote +// side and a boolean which is true if we initiated the session or false if we +// received an incoming session request. The function should return true to +// allow the session or false to reject it. +func (c *Core) SetSessionGatekeeper(f func(pubkey *crypto.BoxPubKey, initiator bool) bool) { + c.sessions.isAllowedMutex.Lock() + defer c.sessions.isAllowedMutex.Unlock() + + c.sessions.isAllowedHandler = f +} + // SetLogger sets the output logger of the Yggdrasil node after startup. This // may be useful if you want to redirect the output later. func (c *Core) SetLogger(log *log.Logger) { diff --git a/src/yggdrasil/session.go b/src/yggdrasil/session.go index 68b9095..2211847 100644 --- a/src/yggdrasil/session.go +++ b/src/yggdrasil/session.go @@ -6,7 +6,6 @@ package yggdrasil import ( "bytes" - "encoding/hex" "sync" "time" @@ -111,18 +110,20 @@ func (s *sessionInfo) update(p *sessionPing) bool { // Sessions are indexed by handle. // Additionally, stores maps of address/subnet onto keys, and keys onto handles. type sessions struct { - core *Core - listener *Listener - listenerMutex sync.Mutex - reconfigure chan chan error - lastCleanup time.Time - permShared map[crypto.BoxPubKey]*crypto.BoxSharedKey // Maps known permanent keys to their shared key, used by DHT a lot - sinfos map[crypto.Handle]*sessionInfo // Maps (secret) handle onto session info - conns map[crypto.Handle]*Conn // Maps (secret) handle onto connections - byMySes map[crypto.BoxPubKey]*crypto.Handle // Maps mySesPub onto handle - byTheirPerm map[crypto.BoxPubKey]*crypto.Handle // Maps theirPermPub onto handle - addrToPerm map[address.Address]*crypto.BoxPubKey - subnetToPerm map[address.Subnet]*crypto.BoxPubKey + core *Core + listener *Listener + listenerMutex sync.Mutex + reconfigure chan chan error + lastCleanup time.Time + isAllowedHandler func(pubkey *crypto.BoxPubKey, initiator bool) bool // Returns true or false if session setup is allowed + isAllowedMutex sync.RWMutex // Protects the above + permShared map[crypto.BoxPubKey]*crypto.BoxSharedKey // Maps known permanent keys to their shared key, used by DHT a lot + sinfos map[crypto.Handle]*sessionInfo // Maps (secret) handle onto session info + conns map[crypto.Handle]*Conn // Maps (secret) handle onto connections + byMySes map[crypto.BoxPubKey]*crypto.Handle // Maps mySesPub onto handle + byTheirPerm map[crypto.BoxPubKey]*crypto.Handle // Maps theirPermPub onto handle + addrToPerm map[address.Address]*crypto.BoxPubKey + subnetToPerm map[address.Subnet]*crypto.BoxPubKey } // Initializes the session struct. @@ -155,70 +156,17 @@ func (ss *sessions) init(core *Core) { ss.lastCleanup = time.Now() } -// Determines whether the session firewall is enabled. -func (ss *sessions) isSessionFirewallEnabled() bool { - ss.core.config.Mutex.RLock() - defer ss.core.config.Mutex.RUnlock() - - return ss.core.config.Current.SessionFirewall.Enable -} - // Determines whether the session with a given publickey is allowed based on // session firewall rules. func (ss *sessions) isSessionAllowed(pubkey *crypto.BoxPubKey, initiator bool) bool { - ss.core.config.Mutex.RLock() - defer ss.core.config.Mutex.RUnlock() + ss.isAllowedMutex.RLock() + defer ss.isAllowedMutex.RUnlock() - // Allow by default if the session firewall is disabled - if !ss.isSessionFirewallEnabled() { + if ss.isAllowedHandler == nil { return true } - // Prepare for checking whitelist/blacklist - var box crypto.BoxPubKey - // Reject blacklisted nodes - for _, b := range ss.core.config.Current.SessionFirewall.BlacklistEncryptionPublicKeys { - key, err := hex.DecodeString(b) - if err == nil { - copy(box[:crypto.BoxPubKeyLen], key) - if box == *pubkey { - return false - } - } - } - // Allow whitelisted nodes - for _, b := range ss.core.config.Current.SessionFirewall.WhitelistEncryptionPublicKeys { - key, err := hex.DecodeString(b) - if err == nil { - copy(box[:crypto.BoxPubKeyLen], key) - if box == *pubkey { - return true - } - } - } - // Allow outbound sessions if appropriate - if ss.core.config.Current.SessionFirewall.AlwaysAllowOutbound { - if initiator { - return true - } - } - // Look and see if the pubkey is that of a direct peer - var isDirectPeer bool - for _, peer := range ss.core.peers.ports.Load().(map[switchPort]*peer) { - if peer.box == *pubkey { - isDirectPeer = true - break - } - } - // Allow direct peers if appropriate - if ss.core.config.Current.SessionFirewall.AllowFromDirect && isDirectPeer { - return true - } - // Allow remote nodes if appropriate - if ss.core.config.Current.SessionFirewall.AllowFromRemote && !isDirectPeer { - return true - } - // Finally, default-deny if not matching any of the above rules - return false + + return ss.isAllowedHandler(pubkey, initiator) } // Gets the session corresponding to a given handle. @@ -271,6 +219,7 @@ func (ss *sessions) getByTheirSubnet(snet *address.Subnet) (*sessionInfo, bool) // includse initializing session info to sane defaults (e.g. lowest supported // MTU). func (ss *sessions) createSession(theirPermKey *crypto.BoxPubKey) *sessionInfo { + // TODO: this check definitely needs to be moved if !ss.isSessionAllowed(theirPermKey, true) { return nil } @@ -444,12 +393,12 @@ func (ss *sessions) sendPingPong(sinfo *sessionInfo, isPong bool) { func (ss *sessions) handlePing(ping *sessionPing) { // Get the corresponding session (or create a new session) sinfo, isIn := ss.getByTheirPerm(&ping.SendPermPub) - // Check the session firewall - if !isIn && ss.isSessionFirewallEnabled() { - if !ss.isSessionAllowed(&ping.SendPermPub, false) { - return - } + // Check if the session is allowed + // TODO: this check may need to be moved + if !isIn && !ss.isSessionAllowed(&ping.SendPermPub, false) { + return } + // Create the session if it doesn't already exist if !isIn { ss.createSession(&ping.SendPermPub) sinfo, isIn = ss.getByTheirPerm(&ping.SendPermPub)