diff --git a/src/yggdrasil/link.go b/src/yggdrasil/link.go index 5c4f0e6..03e6007 100644 --- a/src/yggdrasil/link.go +++ b/src/yggdrasil/link.go @@ -1,7 +1,6 @@ package yggdrasil import ( - "bytes" "encoding/hex" "errors" "fmt" @@ -71,8 +70,8 @@ type linkInterface struct { } type linkOptions struct { - pinnedCurve25519Keys []crypto.BoxPubKey - pinnedEd25519Keys []crypto.SigPubKey + pinnedCurve25519Keys map[crypto.BoxPubKey]struct{} + pinnedEd25519Keys map[crypto.SigPubKey]struct{} } func (l *link) init(c *Core) error { @@ -102,24 +101,22 @@ func (l *link) call(uri string, sintf string) error { pathtokens := strings.Split(strings.Trim(u.Path, "/"), "/") tcpOpts := tcpOptions{} if pubkeys, ok := u.Query()["curve25519"]; ok && len(pubkeys) > 0 { + tcpOpts.pinnedCurve25519Keys = make(map[crypto.BoxPubKey]struct{}) for _, pubkey := range pubkeys { if boxPub, err := hex.DecodeString(pubkey); err != nil { var boxPubKey crypto.BoxPubKey copy(boxPubKey[:], boxPub) - tcpOpts.pinnedCurve25519Keys = append( - tcpOpts.pinnedCurve25519Keys, boxPubKey, - ) + tcpOpts.pinnedCurve25519Keys[boxPubKey] = struct{}{} } } } if pubkeys, ok := u.Query()["ed25519"]; ok && len(pubkeys) > 0 { + tcpOpts.pinnedEd25519Keys = make(map[crypto.SigPubKey]struct{}) for _, pubkey := range pubkeys { if sigPub, err := hex.DecodeString(pubkey); err != nil { var sigPubKey crypto.SigPubKey copy(sigPubKey[:], sigPub) - tcpOpts.pinnedEd25519Keys = append( - tcpOpts.pinnedEd25519Keys, sigPubKey, - ) + tcpOpts.pinnedEd25519Keys[sigPubKey] = struct{}{} } } } @@ -222,22 +219,14 @@ func (intf *linkInterface) handler() error { } // Check if the remote side matches the keys we expected. This is a bit of a weak // check - in future versions we really should check a signature or something like that. - if pinned := intf.options.pinnedCurve25519Keys; len(pinned) > 0 { - allowed := false - for _, key := range pinned { - allowed = allowed || (bytes.Compare(key[:], meta.box[:]) == 0) - } - if !allowed { + if pinned := intf.options.pinnedCurve25519Keys; pinned != nil { + if _, allowed := pinned[meta.box]; !allowed { intf.link.core.log.Errorf("Failed to connect to node: %q sent curve25519 key that does not match pinned keys", intf.name) return fmt.Errorf("failed to connect: host sent curve25519 key that does not match pinned keys") } } - if pinned := intf.options.pinnedEd25519Keys; len(pinned) > 0 { - allowed := false - for _, key := range pinned { - allowed = allowed || (bytes.Compare(key[:], meta.sig[:]) == 0) - } - if !allowed { + if pinned := intf.options.pinnedEd25519Keys; pinned != nil { + if _, allowed := pinned[meta.sig]; !allowed { intf.link.core.log.Errorf("Failed to connect to node: %q sent ed25519 key that does not match pinned keys", intf.name) return fmt.Errorf("failed to connect: host sent ed25519 key that does not match pinned keys") }