diff --git a/src/yggdrasil/api.go b/src/yggdrasil/api.go index 4245f43..80f669b 100644 --- a/src/yggdrasil/api.go +++ b/src/yggdrasil/api.go @@ -280,7 +280,14 @@ func (c *Core) ConnDialer() (*Dialer, error) { // "Listen" configuration item, e.g. // tcp://a.b.c.d:e func (c *Core) ListenTCP(uri string) (*TcpListener, error) { - return c.link.tcp.listen(uri) + return c.link.tcp.listen(uri, nil) +} + +// ListenTLS starts a new TLS listener. The input URI should match that of the +// "Listen" configuration item, e.g. +// tls://a.b.c.d:e +func (c *Core) ListenTLS(uri string) (*TcpListener, error) { + return c.link.tcp.listen(uri, c.link.tcp.tls.forListener) } // NodeID gets the node ID. This is derived from your router encryption keys. diff --git a/src/yggdrasil/link.go b/src/yggdrasil/link.go index 98c080c..1710e20 100644 --- a/src/yggdrasil/link.go +++ b/src/yggdrasil/link.go @@ -93,9 +93,11 @@ func (l *link) call(uri string, sintf string) error { pathtokens := strings.Split(strings.Trim(u.Path, "/"), "/") switch u.Scheme { case "tcp": - l.tcp.call(u.Host, nil, sintf) + l.tcp.call(u.Host, nil, sintf, nil) case "socks": - l.tcp.call(pathtokens[0], u.Host, sintf) + l.tcp.call(pathtokens[0], u.Host, sintf, nil) + case "tls": + l.tcp.call(u.Host, nil, sintf, l.tcp.tls.forDialer) default: return errors.New("unknown call scheme: " + u.Scheme) } @@ -109,7 +111,10 @@ func (l *link) listen(uri string) error { } switch u.Scheme { case "tcp": - _, err := l.tcp.listen(u.Host) + _, err := l.tcp.listen(u.Host, nil) + return err + case "tls": + _, err := l.tcp.listen(u.Host, l.tcp.tls.forListener) return err default: return errors.New("unknown listen scheme: " + u.Scheme) diff --git a/src/yggdrasil/tcp.go b/src/yggdrasil/tcp.go index 66f708c..c456959 100644 --- a/src/yggdrasil/tcp.go +++ b/src/yggdrasil/tcp.go @@ -39,6 +39,7 @@ type tcp struct { listeners map[string]*TcpListener calls map[string]struct{} conns map[linkInfo](chan struct{}) + tls tcptls } // TcpListener is a stoppable TCP listener interface. These are typically @@ -47,9 +48,15 @@ type tcp struct { // multicast interfaces. type TcpListener struct { Listener net.Listener + upgrade *TcpUpgrade stop chan struct{} } +type TcpUpgrade struct { + upgrade func(c net.Conn) (net.Conn, error) + name string +} + func (l *TcpListener) Stop() { defer func() { recover() }() close(l.stop) @@ -81,6 +88,7 @@ func (t *tcp) getAddr() *net.TCPAddr { // Initializes the struct. func (t *tcp) init(l *link) error { t.link = l + t.tls.init(t) t.mutex.Lock() t.calls = make(map[string]struct{}) t.conns = make(map[linkInfo](chan struct{})) @@ -90,12 +98,17 @@ func (t *tcp) init(l *link) error { t.link.core.config.Mutex.RLock() defer t.link.core.config.Mutex.RUnlock() for _, listenaddr := range t.link.core.config.Current.Listen { - if listenaddr[:6] != "tcp://" { + switch listenaddr[:6] { + case "tcp://": + if _, err := t.listen(listenaddr[6:], nil); err != nil { + return err + } + case "tls://": + if _, err := t.listen(listenaddr[6:], t.tls.forListener); err != nil { + return err + } + default: t.link.core.log.Errorln("Failed to add listener: listener", listenaddr, "is not correctly formatted, ignoring") - continue - } - if _, err := t.listen(listenaddr[6:]); err != nil { - return err } } @@ -119,18 +132,21 @@ func (t *tcp) reconfigure() { t.link.core.config.Mutex.RUnlock() if len(added) > 0 || len(deleted) > 0 { for _, a := range added { - if a[:6] != "tcp://" { + switch a[:6] { + case "tcp://": + if _, err := t.listen(a[6:], nil); err != nil { + t.link.core.log.Errorln("Error adding TCP", a[6:], "listener:", err) + } + case "tls://": + if _, err := t.listen(a[6:], t.tls.forListener); err != nil { + t.link.core.log.Errorln("Error adding TLS", a[6:], "listener:", err) + } + default: t.link.core.log.Errorln("Failed to add listener: listener", a, "is not correctly formatted, ignoring") - continue - } - if _, err := t.listen(a[6:]); err != nil { - t.link.core.log.Errorln("Error adding TCP", a[6:], "listener:", err) - } else { - t.link.core.log.Infoln("Started TCP listener:", a[6:]) } } for _, d := range deleted { - if d[:6] != "tcp://" { + if d[:6] != "tcp://" && d[:6] != "tls://" { t.link.core.log.Errorln("Failed to delete listener: listener", d, "is not correctly formatted, ignoring") continue } @@ -146,7 +162,7 @@ func (t *tcp) reconfigure() { } } -func (t *tcp) listen(listenaddr string) (*TcpListener, error) { +func (t *tcp) listen(listenaddr string, upgrade *TcpUpgrade) (*TcpListener, error) { var err error ctx := context.Background() @@ -157,6 +173,7 @@ func (t *tcp) listen(listenaddr string) (*TcpListener, error) { if err == nil { l := TcpListener{ Listener: listener, + upgrade: upgrade, stop: make(chan struct{}), } t.waitgroup.Add(1) @@ -204,7 +221,7 @@ func (t *tcp) listener(l *TcpListener, listenaddr string) { return } t.waitgroup.Add(1) - go t.handler(sock, true, nil) + go t.handler(sock, true, nil, l.upgrade) } } @@ -222,11 +239,15 @@ func (t *tcp) startCalling(saddr string) bool { // If the dial is successful, it launches the handler. // When finished, it removes the outgoing call, so reconnection attempts can be made later. // This all happens in a separate goroutine that it spawns. -func (t *tcp) call(saddr string, options interface{}, sintf string) { +func (t *tcp) call(saddr string, options interface{}, sintf string, upgrade *TcpUpgrade) { go func() { callname := saddr + callproto := "TCP" + if upgrade != nil { + callproto = strings.ToUpper(upgrade.name) + } if sintf != "" { - callname = fmt.Sprintf("%s/%s", saddr, sintf) + callname = fmt.Sprintf("%s/%s/%s", callproto, saddr, sintf) } if !t.startCalling(callname) { return @@ -261,7 +282,7 @@ func (t *tcp) call(saddr string, options interface{}, sintf string) { return } t.waitgroup.Add(1) - t.handler(conn, false, saddr) + t.handler(conn, false, saddr, nil) } else { dst, err := net.ResolveTCPAddr("tcp", saddr) if err != nil { @@ -322,18 +343,28 @@ func (t *tcp) call(saddr string, options interface{}, sintf string) { } conn, err = dialer.Dial("tcp", dst.String()) if err != nil { - t.link.core.log.Debugln("Failed to dial TCP:", err) + t.link.core.log.Debugf("Failed to dial %s: %s", callproto, err) return } t.waitgroup.Add(1) - t.handler(conn, false, nil) + t.handler(conn, false, nil, upgrade) } }() } -func (t *tcp) handler(sock net.Conn, incoming bool, options interface{}) { +func (t *tcp) handler(sock net.Conn, incoming bool, options interface{}, upgrade *TcpUpgrade) { defer t.waitgroup.Done() // Happens after sock.close defer sock.Close() + var upgraded bool + if upgrade != nil { + var err error + if sock, err = upgrade.upgrade(sock); err != nil { + t.link.core.log.Errorln("TCP handler upgrade failed:", err) + return + } else { + upgraded = true + } + } t.setExtraOptions(sock) stream := stream{} stream.init(sock) @@ -344,8 +375,13 @@ func (t *tcp) handler(sock net.Conn, incoming bool, options interface{}) { local, _, _ = net.SplitHostPort(sock.LocalAddr().String()) remote, _, _ = net.SplitHostPort(socksaddr) } else { - name = "tcp://" + sock.RemoteAddr().String() - proto = "tcp" + if upgraded { + proto = upgrade.name + name = proto + "://" + sock.RemoteAddr().String() + } else { + proto = "tcp" + name = proto + "://" + sock.RemoteAddr().String() + } local, _, _ = net.SplitHostPort(sock.LocalAddr().String()) remote, _, _ = net.SplitHostPort(sock.RemoteAddr().String()) } diff --git a/src/yggdrasil/tls.go b/src/yggdrasil/tls.go new file mode 100644 index 0000000..78fe3a9 --- /dev/null +++ b/src/yggdrasil/tls.go @@ -0,0 +1,92 @@ +package yggdrasil + +import ( + "bytes" + "crypto/ed25519" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/hex" + "encoding/pem" + "log" + "math/big" + "net" + "time" +) + +type tcptls struct { + tcp *tcp + config *tls.Config + forDialer *TcpUpgrade + forListener *TcpUpgrade +} + +func (t *tcptls) init(tcp *tcp) { + t.tcp = tcp + t.forDialer = &TcpUpgrade{ + upgrade: t.upgradeDialer, + name: "tls", + } + t.forListener = &TcpUpgrade{ + upgrade: t.upgradeListener, + name: "tls", + } + + edpriv := make(ed25519.PrivateKey, ed25519.PrivateKeySize) + copy(edpriv[:], tcp.link.core.sigPriv[:]) + + certBuf := &bytes.Buffer{} + + pubtemp := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: hex.EncodeToString(tcp.link.core.sigPub[:]), + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24 * 365), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derbytes, err := x509.CreateCertificate(rand.Reader, &pubtemp, &pubtemp, edpriv.Public(), edpriv) + if err != nil { + log.Fatalf("Failed to create certificate: %s", err) + } + + if err := pem.Encode(certBuf, &pem.Block{Type: "CERTIFICATE", Bytes: derbytes}); err != nil { + panic("failed to encode certificate into PEM") + } + + cpool := x509.NewCertPool() + cpool.AppendCertsFromPEM(derbytes) + + t.config = &tls.Config{ + RootCAs: cpool, + Certificates: []tls.Certificate{ + tls.Certificate{ + Certificate: [][]byte{derbytes}, + PrivateKey: edpriv, + }, + }, + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS13, + } +} + +func (t *tcptls) upgradeListener(c net.Conn) (net.Conn, error) { + conn := tls.Server(c, t.config) + if err := conn.Handshake(); err != nil { + return c, err + } + return conn, nil +} + +func (t *tcptls) upgradeDialer(c net.Conn) (net.Conn, error) { + conn := tls.Client(c, t.config) + if err := conn.Handshake(); err != nil { + return c, err + } + return conn, nil +}