From 235b64345e16d87ee173352aa44705a697911b89 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 22 Nov 2019 18:34:43 +0000 Subject: [PATCH] Configure addresses and MTUs, fix bugs --- src/tuntap/iface.go | 4 +- src/tuntap/tun_windows.go | 237 +++++++++++++++++++++++++++++++------- 2 files changed, 197 insertions(+), 44 deletions(-) diff --git a/src/tuntap/iface.go b/src/tuntap/iface.go index 9cb5e37..fbb7c86 100644 --- a/src/tuntap/iface.go +++ b/src/tuntap/iface.go @@ -40,8 +40,8 @@ func (w *tunWriter) _write(b []byte) { } }) } - if written-TUN_OFFSET_BYTES != n { - w.tun.log.Errorln("TUN iface write mismatch:", written-TUN_OFFSET_BYTES, "bytes written vs", n, "bytes given") + if written != n { + w.tun.log.Errorln("TUN iface write mismatch:", written, "bytes written vs", n, "bytes given") } } diff --git a/src/tuntap/tun_windows.go b/src/tuntap/tun_windows.go index 6541754..8b4f92c 100644 --- a/src/tuntap/tun_windows.go +++ b/src/tuntap/tun_windows.go @@ -1,70 +1,223 @@ package tuntap import ( + "bytes" "errors" - "fmt" - "os/exec" + "log" + "net" + "runtime" "strings" - "time" + "unsafe" + + "golang.org/x/sys/windows" wgtun "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) // This is to catch Windows platforms // Configures the TUN adapter with the correct IPv6 address and MTU. func (tun *TunAdapter) setup(ifname string, addr string, mtu int) error { - iface, err := wgtun.CreateTUN(ifname, mtu) - if err != nil { - panic(err) - } - tun.iface = iface - if mtu, err := iface.MTU(); err == nil { - tun.mtu = getSupportedMTU(mtu) - } else { - tun.mtu = 0 - } - return tun.setupAddress(addr) + var err error + err = doAsSystem(func() { + iface, err := wgtun.CreateTUN(ifname, mtu) + if err != nil { + panic(err) + } + tun.iface = iface + + if err := tun.setupAddress(addr); err != nil { + tun.log.Errorln("Failed to set up TUN address:", err) + } + if err := tun.setupMTU(getSupportedMTU(mtu)); err != nil { + tun.log.Errorln("Failed to set up TUN MTU:", err) + } + + if mtu, err = iface.MTU(); err == nil { + tun.mtu = mtu + } + }) + return err } // Sets the MTU of the TAP adapter. func (tun *TunAdapter) setupMTU(mtu int) error { - if tun.iface == nil || tun.iface.Name() == "" { - return errors.New("Can't configure MTU as TAP adapter is not present") + if tun.iface == nil || tun.Name() == "" { + return errors.New("Can't configure MTU as TUN adapter is not present") } - // Set MTU - cmd := exec.Command("netsh", "interface", "ipv6", "set", "subinterface", - fmt.Sprintf("interface=%s", tun.iface.Name()), - fmt.Sprintf("mtu=%d", mtu), - "store=active") - tun.log.Debugln("netsh command:", strings.Join(cmd.Args, " ")) - output, err := cmd.CombinedOutput() - if err != nil { - tun.log.Errorln("Windows netsh failed:", err) - tun.log.Traceln(string(output)) - return err + if intf, ok := tun.iface.(*wgtun.NativeTun); ok { + luid := winipcfg.LUID(intf.LUID()) + ipfamily, err := luid.IPInterface(windows.AF_INET6) + if err != nil { + return err + } + + ipfamily.NLMTU = uint32(mtu) + intf.ForceMTU(int(ipfamily.NLMTU)) + ipfamily.UseAutomaticMetric = false + ipfamily.Metric = 0 + ipfamily.DadTransmits = 0 + ipfamily.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled + + if err := ipfamily.Set(); err != nil { + return err + } } - time.Sleep(time.Second) // FIXME artifical delay to give netsh time to take effect + return nil } // Sets the IPv6 address of the TAP adapter. func (tun *TunAdapter) setupAddress(addr string) error { - if tun.iface == nil || tun.iface.Name() == "" { - return errors.New("Can't configure IPv6 address as TAP adapter is not present") + if tun.iface == nil || tun.Name() == "" { + return errors.New("Can't configure IPv6 address as TUN adapter is not present") } - // Set address - cmd := exec.Command("netsh", "interface", "ipv6", "add", "address", - fmt.Sprintf("interface=%s", tun.iface.Name()), - fmt.Sprintf("addr=%s", addr), - "store=active") - tun.log.Debugln("netsh command:", strings.Join(cmd.Args, " ")) - output, err := cmd.CombinedOutput() - if err != nil { - tun.log.Errorln("Windows netsh failed:", err) - tun.log.Traceln(string(output)) - return err + if intf, ok := tun.iface.(*wgtun.NativeTun); ok { + if ipaddr, ipnet, err := net.ParseCIDR(addr); err == nil { + luid := winipcfg.LUID(intf.LUID()) + addresses := append([]net.IPNet{}, net.IPNet{ + IP: ipaddr, + Mask: ipnet.Mask, + }) + + err := luid.SetIPAddressesForFamily(windows.AF_INET6, addresses) + if err == windows.ERROR_OBJECT_ALREADY_EXISTS { + cleanupAddressesOnDisconnectedInterfaces(windows.AF_INET6, addresses) + err = luid.SetIPAddressesForFamily(windows.AF_INET6, addresses) + } + if err != nil { + return err + } + } else { + return err + } + } else { + return errors.New("unable to get NativeTUN") } - time.Sleep(time.Second) // FIXME artifical delay to give netsh time to take effect return nil } + +/* + * doAsSystem + * SPDX-License-Identifier: LGPL-3.0 + * Copyright (C) 2017-2019 Jason A. Donenfeld . All Rights Reserved. + */ +func doAsSystem(f func()) error { + runtime.LockOSThread() + defer func() { + windows.RevertToSelf() + runtime.UnlockOSThread() + }() + privileges := windows.Tokenprivileges{ + PrivilegeCount: 1, + Privileges: [1]windows.LUIDAndAttributes{ + { + Attributes: windows.SE_PRIVILEGE_ENABLED, + }, + }, + } + err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr("SeDebugPrivilege"), &privileges.Privileges[0].Luid) + if err != nil { + return err + } + err = windows.ImpersonateSelf(windows.SecurityImpersonation) + if err != nil { + return err + } + var threadToken windows.Token + err = windows.OpenThreadToken(windows.CurrentThread(), windows.TOKEN_ADJUST_PRIVILEGES, false, &threadToken) + if err != nil { + return err + } + defer threadToken.Close() + err = windows.AdjustTokenPrivileges(threadToken, false, &privileges, uint32(unsafe.Sizeof(privileges)), nil, nil) + if err != nil { + return err + } + + processes, err := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0) + if err != nil { + return err + } + defer windows.CloseHandle(processes) + + processEntry := windows.ProcessEntry32{Size: uint32(unsafe.Sizeof(windows.ProcessEntry32{}))} + pid := uint32(0) + for err = windows.Process32First(processes, &processEntry); err == nil; err = windows.Process32Next(processes, &processEntry) { + if strings.ToLower(windows.UTF16ToString(processEntry.ExeFile[:])) == "winlogon.exe" { + pid = processEntry.ProcessID + break + } + } + if pid == 0 { + return errors.New("unable to find winlogon.exe process") + } + + winlogonProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, pid) + if err != nil { + return err + } + defer windows.CloseHandle(winlogonProcess) + var winlogonToken windows.Token + err = windows.OpenProcessToken(winlogonProcess, windows.TOKEN_IMPERSONATE|windows.TOKEN_DUPLICATE, &winlogonToken) + if err != nil { + return err + } + defer winlogonToken.Close() + var duplicatedToken windows.Token + err = windows.DuplicateTokenEx(winlogonToken, 0, nil, windows.SecurityImpersonation, windows.TokenImpersonation, &duplicatedToken) + if err != nil { + return err + } + defer duplicatedToken.Close() + err = windows.SetThreadToken(nil, duplicatedToken) + if err != nil { + return err + } + f() + return nil +} + +/* + * cleanupAddressesOnDisconnectedInterfaces + * SPDX-License-Identifier: MIT + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ +func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) { + if len(addresses) == 0 { + return + } + includedInAddresses := func(a net.IPNet) bool { + // TODO: this makes the whole algorithm O(n^2). But we can't stick net.IPNet in a Go hashmap. Bummer! + for _, addr := range addresses { + ip := addr.IP + if ip4 := ip.To4(); ip4 != nil { + ip = ip4 + } + mA, _ := addr.Mask.Size() + mB, _ := a.Mask.Size() + if bytes.Equal(ip, a.IP) && mA == mB { + return true + } + } + return false + } + interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault) + if err != nil { + return + } + for _, iface := range interfaces { + if iface.OperStatus == winipcfg.IfOperStatusUp { + continue + } + for address := iface.FirstUnicastAddress; address != nil; address = address.Next { + ip := address.Address.IP() + ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))} + if includedInAddresses(ipnet) { + log.Printf("Cleaning up stale address %s from interface ā€˜%sā€™", ipnet.String(), iface.FriendlyName()) + iface.LUID.DeleteIPAddress(ipnet) + } + } + } +}