Medium severity (4.9) · OSV Advisory · Published Oct 23, 2025 · Updated Apr 15, 2026
CVE-2025-62820
CVE-2025-62820
Description
Slack Nebula before 1.9.7 mishandles CIDR in some configurations and thus accepts arbitrary source IP addresses within the Nebula network.
Affected packages
Versions sourced from the GitHub Security Advisory.
| Package | Affected versions | Patched versions |
|---|---|---|
| github.com/slackhq/nebula (Go) | >= 1.9.4, < 1.9.7 | 1.9.7 |
Affected products
Patches (1)
e264a0ff888c — Switch most everything to netip in prep for ipv6 in the overlay (#1173)
79 files changed · +1896 −2678
allow_list.go+26 −65 modified@@ -2,25 +2,24 @@ package nebula import ( "fmt" - "net" + "net/netip" "regexp" - "github.com/slackhq/nebula/cidr" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) type AllowList struct { // The values of this cidrTree are `bool`, signifying allow/deny - cidrTree *cidr.Tree6[bool] + cidrTree *bart.Table[bool] } type RemoteAllowList struct { AllowList *AllowList // Inside Range Specific, keys of this tree are inside CIDRs and values // are *AllowList - insideAllowLists *cidr.Tree6[*AllowList] + insideAllowLists *bart.Table[*AllowList] } type LocalAllowList struct { @@ -88,7 +87,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) } - tree := cidr.NewTree6[bool]() + tree := new(bart.Table[bool]) // Keep track of the rules we have added for both ipv4 and ipv6 type allowListRules struct { @@ -122,18 +121,20 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue) } - _, ipNet, err := net.ParseCIDR(rawCIDR) + ipNet, err := netip.ParsePrefix(rawCIDR) if err != nil { - return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) + return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err) } + ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()) + // TODO: should we error on duplicate CIDRs in the config? 
- tree.AddCIDR(ipNet, value) + tree.Insert(ipNet, value) - maskBits, maskSize := ipNet.Mask.Size() + maskBits := ipNet.Bits() var rules *allowListRules - if maskSize == 32 { + if ipNet.Addr().Is4() { rules = &rules4 } else { rules = &rules6 @@ -156,17 +157,15 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in if !rules4.defaultSet { if rules4.allValuesMatch { - _, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0") - tree.AddCIDR(zeroCIDR, !rules4.allValues) + tree.Insert(netip.PrefixFrom(netip.IPv4Unspecified(), 0), !rules4.allValues) } else { return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k) } } if !rules6.defaultSet { if rules6.allValuesMatch { - _, zeroCIDR, _ := net.ParseCIDR("::/0") - tree.AddCIDR(zeroCIDR, !rules6.allValues) + tree.Insert(netip.PrefixFrom(netip.IPv6Unspecified(), 0), !rules6.allValues) } else { return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k) } @@ -218,13 +217,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error return nameRules, nil } -func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) { +func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error) { value := c.Get(k) if value == nil { return nil, nil } - remoteAllowRanges := cidr.NewTree6[*AllowList]() + remoteAllowRanges := new(bart.Table[*AllowList]) rawMap, ok := value.(map[interface{}]interface{}) if !ok { @@ -241,45 +240,27 @@ func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error return nil, err } - _, ipNet, err := net.ParseCIDR(rawCIDR) + ipNet, err := netip.ParsePrefix(rawCIDR) if err != nil { - return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) + return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. 
%w", k, rawCIDR, err) } - remoteAllowRanges.AddCIDR(ipNet, allowList) + remoteAllowRanges.Insert(netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()), allowList) } return remoteAllowRanges, nil } -func (al *AllowList) Allow(ip net.IP) bool { - if al == nil { - return true - } - - _, result := al.cidrTree.MostSpecificContains(ip) - return result -} - -func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool { - if al == nil { - return true - } - - _, result := al.cidrTree.MostSpecificContainsIpV4(ip) - return result -} - -func (al *AllowList) AllowIpV6(hi, lo uint64) bool { +func (al *AllowList) Allow(ip netip.Addr) bool { if al == nil { return true } - _, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo) + result, _ := al.cidrTree.Lookup(ip) return result } -func (al *LocalAllowList) Allow(ip net.IP) bool { +func (al *LocalAllowList) Allow(ip netip.Addr) bool { if al == nil { return true } @@ -301,43 +282,23 @@ func (al *LocalAllowList) AllowName(name string) bool { return !al.nameRules[0].Allow } -func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool { +func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool { if al == nil { return true } return al.AllowList.Allow(ip) } -func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool { +func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool { if !al.getInsideAllowList(vpnIp).Allow(ip) { return false } return al.AllowList.Allow(ip) } -func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool { - if al == nil { - return true - } - if !al.getInsideAllowList(vpnIp).AllowIpV4(ip) { - return false - } - return al.AllowList.AllowIpV4(ip) -} - -func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool { - if al == nil { - return true - } - if !al.getInsideAllowList(vpnIp).AllowIpV6(hi, lo) { - return false - } - return al.AllowList.AllowIpV6(hi, lo) -} - -func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList { +func (al 
*RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList { if al.insideAllowLists != nil { - ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp) + inside, ok := al.insideAllowLists.Lookup(vpnIp) if ok { return inside }
allow_list_test.go+21 −21 modified@@ -1,11 +1,11 @@ package nebula import ( - "net" + "net/netip" "regexp" "testing" - "github.com/slackhq/nebula/cidr" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" @@ -18,7 +18,7 @@ func TestNewAllowListFromConfig(t *testing.T) { "192.168.0.0": true, } r, err := newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0") + assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'") assert.Nil(t, r) c.Settings["allowlist"] = map[interface{}]interface{}{ @@ -98,26 +98,26 @@ func TestNewAllowListFromConfig(t *testing.T) { } func TestAllowList_Allow(t *testing.T) { - assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1"))) - - tree := cidr.NewTree6[bool]() - tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true) - tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false) - tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true) - tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true) - tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true) - tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false) - tree.AddCIDR(cidr.Parse("::1/128"), true) - tree.AddCIDR(cidr.Parse("::2/128"), false) + assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1"))) + + tree := new(bart.Table[bool]) + tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true) + tree.Insert(netip.MustParsePrefix("10.0.0.0/8"), false) + tree.Insert(netip.MustParsePrefix("10.42.42.42/32"), true) + tree.Insert(netip.MustParsePrefix("10.42.0.0/16"), true) + tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), true) + tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), false) + tree.Insert(netip.MustParsePrefix("::1/128"), true) + tree.Insert(netip.MustParsePrefix("::2/128"), false) al := &AllowList{cidrTree: tree} - assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1"))) - assert.Equal(t, 
false, al.Allow(net.ParseIP("10.0.0.4"))) - assert.Equal(t, true, al.Allow(net.ParseIP("10.42.42.42"))) - assert.Equal(t, false, al.Allow(net.ParseIP("10.42.42.41"))) - assert.Equal(t, true, al.Allow(net.ParseIP("10.42.0.1"))) - assert.Equal(t, true, al.Allow(net.ParseIP("::1"))) - assert.Equal(t, false, al.Allow(net.ParseIP("::2"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1"))) + assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42"))) + assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1"))) + assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2"))) } func TestLocalAllowList_AllowName(t *testing.T) {
calculated_remote.go+41 −25 modified@@ -1,63 +1,78 @@ package nebula import ( + "encoding/binary" "fmt" "math" "net" + "net/netip" "strconv" - "github.com/slackhq/nebula/cidr" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) // This allows us to "guess" what the remote might be for a host while we wait // for the lighthouse response. See "lighthouse.calculated_remotes" in the // example config file. type calculatedRemote struct { - ipNet net.IPNet - maskIP iputil.VpnIp - mask iputil.VpnIp - port uint32 + ipNet netip.Prefix + mask netip.Prefix + port uint32 } -func newCalculatedRemote(ipNet *net.IPNet, port int) (*calculatedRemote, error) { - // Ensure this is an IPv4 mask that we expect - ones, bits := ipNet.Mask.Size() - if ones == 0 || bits != 32 { - return nil, fmt.Errorf("invalid mask: %v", ipNet) - } +func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) { + masked := maskCidr.Masked() if port < 0 || port > math.MaxUint16 { return nil, fmt.Errorf("invalid port: %d", port) } return &calculatedRemote{ - ipNet: *ipNet, - maskIP: iputil.Ip2VpnIp(ipNet.IP), - mask: iputil.Ip2VpnIp(ipNet.Mask), - port: uint32(port), + ipNet: maskCidr, + mask: masked, + port: uint32(port), }, nil } func (c *calculatedRemote) String() string { return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port) } -func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort { +func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort { // Combine the masked bytes of the "mask" IP with the unmasked bytes // of the overlay IP - masked := (c.maskIP & c.mask) | (ip & ^c.mask) + if c.ipNet.Addr().Is4() { + return c.apply4(ip) + } + return c.apply6(ip) +} + +func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort { + //TODO: IPV6-WORK this can be less crappy + maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) + mask := binary.BigEndian.Uint32(maskb[:]) + + b := c.mask.Addr().As4() + 
maskIp := binary.BigEndian.Uint32(b[:]) + + b = ip.As4() + intIp := binary.BigEndian.Uint32(b[:]) + + return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port} +} - return &Ip4AndPort{Ip: uint32(masked), Port: c.port} +func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort { + //TODO: IPV6-WORK + panic("Can not calculate ipv6 remote addresses") } -func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) { +func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) { value := c.Get(k) if value == nil { return nil, nil } - calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]() + calculatedRemotes := new(bart.Table[[]*calculatedRemote]) rawMap, ok := value.(map[any]any) if !ok { @@ -69,17 +84,18 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calcu return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) } - _, ipNet, err := net.ParseCIDR(rawCIDR) + cidr, err := netip.ParsePrefix(rawCIDR) if err != nil { return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) } + //TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here entry, err := newCalculatedRemotesListFromConfig(rawValue) if err != nil { return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err) } - calculatedRemotes.AddCIDR(ipNet, entry) + calculatedRemotes.Insert(cidr, entry) } return calculatedRemotes, nil @@ -117,7 +133,7 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { if !ok { return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue) } - _, ipNet, err := net.ParseCIDR(rawMask) + maskCidr, err := netip.ParsePrefix(rawMask) if err != nil { return nil, fmt.Errorf("invalid mask: %s", rawMask) } @@ -139,5 +155,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, 
rawValue) } - return newCalculatedRemote(ipNet, port) + return newCalculatedRemote(maskCidr, port) }
calculated_remote_test.go+7 −9 modified@@ -1,27 +1,25 @@ package nebula import ( - "net" + "net/netip" "testing" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestCalculatedRemoteApply(t *testing.T) { - _, ipNet, err := net.ParseCIDR("192.168.1.0/24") + ipNet, err := netip.ParsePrefix("192.168.1.0/24") require.NoError(t, err) c, err := newCalculatedRemote(ipNet, 4242) require.NoError(t, err) - input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182}) + input, err := netip.ParseAddr("10.0.10.182") + assert.NoError(t, err) - expected := &Ip4AndPort{ - Ip: uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})), - Port: 4242, - } + expected, err := netip.ParseAddr("192.168.1.182") + assert.NoError(t, err) - assert.Equal(t, expected, c.Apply(input)) + assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input)) }
cidr/parse.go+0 −10 removed@@ -1,10 +0,0 @@ -package cidr - -import "net" - -// Parse is a convenience function that returns only the IPNet -// This function ignores errors since it is primarily a test helper, the result could be nil -func Parse(s string) *net.IPNet { - _, c, _ := net.ParseCIDR(s) - return c -}
cidr/tree4.go+0 −203 removed@@ -1,203 +0,0 @@ -package cidr - -import ( - "net" - - "github.com/slackhq/nebula/iputil" -) - -type Node[T any] struct { - left *Node[T] - right *Node[T] - parent *Node[T] - hasValue bool - value T -} - -type entry[T any] struct { - CIDR *net.IPNet - Value T -} - -type Tree4[T any] struct { - root *Node[T] - list []entry[T] -} - -const ( - startbit = iputil.VpnIp(0x80000000) -) - -func NewTree4[T any]() *Tree4[T] { - tree := new(Tree4[T]) - tree.root = &Node[T]{} - tree.list = []entry[T]{} - return tree -} - -func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) { - bit := startbit - node := tree.root - next := tree.root - - ip := iputil.Ip2VpnIp(cidr.IP) - mask := iputil.Ip2VpnIp(cidr.Mask) - - // Find our last ancestor in the tree - for bit&mask != 0 { - if ip&bit != 0 { - next = node.right - } else { - next = node.left - } - - if next == nil { - break - } - - bit = bit >> 1 - node = next - } - - // We already have this range so update the value - if next != nil { - addCIDR := cidr.String() - for i, v := range tree.list { - if addCIDR == v.CIDR.String() { - tree.list = append(tree.list[:i], tree.list[i+1:]...) 
- break - } - } - - tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) - node.value = val - node.hasValue = true - return - } - - // Build up the rest of the tree we don't already have - for bit&mask != 0 { - next = &Node[T]{} - next.parent = node - - if ip&bit != 0 { - node.right = next - } else { - node.left = next - } - - bit >>= 1 - node = next - } - - // Final node marks our cidr, set the value - node.value = val - node.hasValue = true - tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) -} - -// Contains finds the first match, which may be the least specific -func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) { - bit := startbit - node := tree.root - - for node != nil { - if node.hasValue { - return true, node.value - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - - } - - return false, value -} - -// MostSpecificContains finds the most specific match -func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) { - bit := startbit - node := tree.root - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - return ok, value -} - -type eachFunc[T any] func(T) bool - -// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete -// The final return value will be true if the provided function returned true -func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool { - bit := startbit - node := tree.root - - for node != nil { - if node.hasValue { - // If the each func returns true then we can exit the loop - if each(node.value) { - return true - } - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - return false -} - -// GetCIDR returns the entry added by the most recent matching AddCIDR call -func (tree 
*Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) { - bit := startbit - node := tree.root - - ip := iputil.Ip2VpnIp(cidr.IP) - mask := iputil.Ip2VpnIp(cidr.Mask) - - // Find our last ancestor in the tree - for node != nil && bit&mask != 0 { - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit = bit >> 1 - } - - if bit&mask == 0 && node != nil { - value = node.value - ok = node.hasValue - } - - return ok, value -} - -// List will return all CIDRs and their current values. Do not modify the contents! -func (tree *Tree4[T]) List() []entry[T] { - return tree.list -}
cidr/tree4_test.go+0 −170 removed@@ -1,170 +0,0 @@ -package cidr - -import ( - "net" - "testing" - - "github.com/slackhq/nebula/iputil" - "github.com/stretchr/testify/assert" -) - -func TestCIDRTree_List(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/16"), "1") - tree.AddCIDR(Parse("1.0.0.0/8"), "2") - tree.AddCIDR(Parse("1.0.0.0/16"), "3") - tree.AddCIDR(Parse("1.0.0.0/16"), "4") - list := tree.List() - assert.Len(t, list, 2) - assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String()) - assert.Equal(t, "2", list[0].Value) - assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String()) - assert.Equal(t, "4", list[1].Value) -} - -func TestCIDRTree_Contains(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.0/24"), "4a") - tree.AddCIDR(Parse("4.1.1.1/32"), "4b") - tree.AddCIDR(Parse("4.1.2.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "1", "1.0.0.0"}, - {true, "1", "1.255.255.255"}, - {true, "2", "2.1.0.0"}, - {true, "2", "2.1.255.255"}, - {true, "3", "3.1.1.0"}, - {true, "3", "3.1.1.255"}, - {true, "4a", "4.1.1.255"}, - {true, "4a", "4.1.1.1"}, - {true, "5", "240.0.0.0"}, - {true, "5", "255.255.255.255"}, - {false, "", "239.0.0.0"}, - {false, "", "4.1.2.2"}, - } - - for _, tt := range tests { - ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } - - tree = NewTree4[string]() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) -} - -func TestCIDRTree_MostSpecificContains(t *testing.T) { - tree := NewTree4[string]() - 
tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.0/24"), "4a") - tree.AddCIDR(Parse("4.1.1.0/30"), "4b") - tree.AddCIDR(Parse("4.1.1.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "1", "1.0.0.0"}, - {true, "1", "1.255.255.255"}, - {true, "2", "2.1.0.0"}, - {true, "2", "2.1.255.255"}, - {true, "3", "3.1.1.0"}, - {true, "3", "3.1.1.255"}, - {true, "4a", "4.1.1.255"}, - {true, "4b", "4.1.1.2"}, - {true, "4c", "4.1.1.1"}, - {true, "5", "240.0.0.0"}, - {true, "5", "255.255.255.255"}, - {false, "", "239.0.0.0"}, - {false, "", "4.1.2.2"}, - } - - for _, tt := range tests { - ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } - - tree = NewTree4[string]() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) -} - -func TestTree4_GetCIDR(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.0/24"), "4a") - tree.AddCIDR(Parse("4.1.1.1/32"), "4b") - tree.AddCIDR(Parse("4.1.2.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - - tests := []struct { - Found bool - Result interface{} - IPNet *net.IPNet - }{ - {true, "1", Parse("1.0.0.0/8")}, - {true, "2", Parse("2.1.0.0/16")}, - {true, "3", Parse("3.1.1.0/24")}, - {true, "4a", Parse("4.1.1.0/24")}, - {true, "4b", Parse("4.1.1.1/32")}, - {true, "4c", Parse("4.1.2.1/32")}, - {true, "5", Parse("254.0.0.0/4")}, - {false, "", Parse("2.0.0.0/8")}, - } - - for _, tt 
:= range tests { - ok, r := tree.GetCIDR(tt.IPNet) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } -} - -func BenchmarkCIDRTree_Contains(b *testing.B) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.1.0.0/16"), "1") - tree.AddCIDR(Parse("1.2.1.1/32"), "1") - tree.AddCIDR(Parse("192.2.1.1/32"), "1") - tree.AddCIDR(Parse("172.2.1.1/32"), "1") - - ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1")) - b.Run("found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Contains(ip) - } - }) - - ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255")) - b.Run("not found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Contains(ip) - } - }) -}
cidr/tree6.go+0 −189 removed@@ -1,189 +0,0 @@ -package cidr - -import ( - "net" - - "github.com/slackhq/nebula/iputil" -) - -const startbit6 = uint64(1 << 63) - -type Tree6[T any] struct { - root4 *Node[T] - root6 *Node[T] -} - -func NewTree6[T any]() *Tree6[T] { - tree := new(Tree6[T]) - tree.root4 = &Node[T]{} - tree.root6 = &Node[T]{} - return tree -} - -func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) { - var node, next *Node[T] - - cidrIP, ipv4 := isIPV4(cidr.IP) - if ipv4 { - node = tree.root4 - next = tree.root4 - - } else { - node = tree.root6 - next = tree.root6 - } - - for i := 0; i < len(cidrIP); i += 4 { - ip := iputil.Ip2VpnIp(cidrIP[i : i+4]) - mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4]) - bit := startbit - - // Find our last ancestor in the tree - for bit&mask != 0 { - if ip&bit != 0 { - next = node.right - } else { - next = node.left - } - - if next == nil { - break - } - - bit = bit >> 1 - node = next - } - - // Build up the rest of the tree we don't already have - for bit&mask != 0 { - next = &Node[T]{} - next.parent = node - - if ip&bit != 0 { - node.right = next - } else { - node.left = next - } - - bit >>= 1 - node = next - } - } - - // Final node marks our cidr, set the value - node.value = val - node.hasValue = true -} - -// Finds the most specific match -func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) { - var node *Node[T] - - wholeIP, ipv4 := isIPV4(ip) - if ipv4 { - node = tree.root4 - } else { - node = tree.root6 - } - - for i := 0; i < len(wholeIP); i += 4 { - ip := iputil.Ip2VpnIp(wholeIP[i : i+4]) - bit := startbit - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if bit == 0 { - break - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - } - - return ok, value -} - -func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) { - bit := startbit - node := tree.root4 - - for node != nil { - if node.hasValue 
{ - value = node.value - ok = true - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - return ok, value -} - -func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) { - ip := hi - node := tree.root6 - - for i := 0; i < 2; i++ { - bit := startbit6 - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if bit == 0 { - break - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - ip = lo - } - - return ok, value -} - -func isIPV4(ip net.IP) (net.IP, bool) { - if len(ip) == net.IPv4len { - return ip, true - } - - if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff { - return ip[12:16], true - } - - return ip, false -} - -func isZeros(p net.IP) bool { - for i := 0; i < len(p); i++ { - if p[i] != 0 { - return false - } - } - return true -}
cidr/tree6_test.go+0 −98 removed@@ -1,98 +0,0 @@ -package cidr - -import ( - "encoding/binary" - "net" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCIDR6Tree_MostSpecificContains(t *testing.T) { - tree := NewTree6[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.1/24"), "4a") - tree.AddCIDR(Parse("4.1.1.1/30"), "4b") - tree.AddCIDR(Parse("4.1.1.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "1", "1.0.0.0"}, - {true, "1", "1.255.255.255"}, - {true, "2", "2.1.0.0"}, - {true, "2", "2.1.255.255"}, - {true, "3", "3.1.1.0"}, - {true, "3", "3.1.1.255"}, - {true, "4a", "4.1.1.255"}, - {true, "4b", "4.1.1.2"}, - {true, "4c", "4.1.1.1"}, - {true, "5", "240.0.0.0"}, - {true, "5", "255.255.255.255"}, - {true, "6a", "1:2:0:4:1:1:1:1"}, - {true, "6b", "1:2:0:4:5:1:1:1"}, - {true, "6c", "1:2:0:4:5:0:0:0"}, - {false, "", "239.0.0.0"}, - {false, "", "4.1.2.2"}, - } - - for _, tt := range tests { - ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP)) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } - - tree = NewTree6[string]() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - tree.AddCIDR(Parse("::/0"), "cool6") - ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0")) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255")) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.MostSpecificContains(net.ParseIP("::")) - assert.True(t, ok) - assert.Equal(t, "cool6", r) - - ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8")) - assert.True(t, ok) - assert.Equal(t, "cool6", r) -} - -func 
TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) { - tree := NewTree6[string]() - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "6a", "1:2:0:4:1:1:1:1"}, - {true, "6b", "1:2:0:4:5:1:1:1"}, - {true, "6c", "1:2:0:4:5:0:0:0"}, - } - - for _, tt := range tests { - ip := net.ParseIP(tt.IP) - hi := binary.BigEndian.Uint64(ip[:8]) - lo := binary.BigEndian.Uint64(ip[8:]) - - ok, r := tree.MostSpecificContainsIpV6(hi, lo) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } -}
connection_manager.go+17 −13 modified@@ -3,15 +3,15 @@ package nebula import ( "bytes" "context" + "encoding/binary" + "net/netip" "sync" "time" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) type trafficDecision int @@ -224,8 +224,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp) var index uint32 - var relayFrom iputil.VpnIp - var relayTo iputil.VpnIp + var relayFrom netip.Addr + var relayTo netip.Addr switch { case ok && existing.State == Established: // This relay already exists in newhostinfo, then do nothing. @@ -235,7 +235,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) index = existing.LocalIndex switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnIp + relayFrom = n.intf.myVpnNet.Addr() relayTo = existing.PeerIp case ForwardingType: relayFrom = existing.PeerIp @@ -260,7 +260,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnIp + relayFrom = n.intf.myVpnNet.Addr() relayTo = r.PeerIp case ForwardingType: relayFrom = r.PeerIp @@ -270,21 +270,25 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } } + //TODO: IPV6-WORK + relayFromB := relayFrom.As4() + relayToB := relayTo.As4() + // Send a CreateRelayRequest to the peer. 
req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, - RelayFromIp: uint32(relayFrom), - RelayToIp: uint32(relayTo), + RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]), + RelayToIp: binary.BigEndian.Uint32(relayToB[:]), } msg, err := req.Marshal() if err != nil { n.l.WithError(err).Error("failed to marshal Control message to migrate relay") } else { n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) n.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(req.RelayFromIp), - "relayTo": iputil.VpnIp(req.RelayToIp), + "relayFrom": req.RelayFromIp, + "relayTo": req.RelayToIp, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, "vpnIp": newhostinfo.vpnIp}). @@ -403,7 +407,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // Let's sort this out. - if current.vpnIp < n.intf.myVpnIp { + if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 { // Only one side should flip primary because if both flip then we may never resolve to a single tunnel. // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. // The remotes vpn ip is lower than mine. I will not flip. @@ -457,12 +461,12 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { } if n.punchy.GetTargetEverything() { - hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) { + hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { n.metricsTxPunchy.Inc(1) n.intf.outside.WriteTo([]byte{1}, addr) }) - } else if hostinfo.remote != nil { + } else if hostinfo.remote.IsValid() { n.metricsTxPunchy.Inc(1) n.intf.outside.WriteTo([]byte{1}, hostinfo.remote) }
connection_manager_test.go+17 −17 modified@@ -5,28 +5,26 @@ import ( "crypto/ed25519" "crypto/rand" "net" + "net/netip" "testing" "time" "github.com/flynn/noise" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" ) -var vpnIp iputil.VpnIp - func newTestLighthouse() *LightHouse { lh := &LightHouse{ l: test.NewLogger(), - addrMap: map[iputil.VpnIp]*RemoteList{}, - queryChan: make(chan iputil.VpnIp, 10), + addrMap: map[netip.Addr]*RemoteList{}, + queryChan: make(chan netip.Addr, 10), } - lighthouses := map[iputil.VpnIp]struct{}{} - staticList := map[iputil.VpnIp]struct{}{} + lighthouses := map[netip.Addr]struct{}{} + staticList := map[netip.Addr]struct{}{} lh.lighthouses.Store(&lighthouses) lh.staticList.Store(&staticList) @@ -37,10 +35,10 @@ func newTestLighthouse() *LightHouse { func Test_NewConnectionManagerTest(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) - preferredRanges := []*net.IPNet{localrange} + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnIp := netip.MustParseAddr("172.1.1.2") + preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects hostMap := newHostMap(l, vpncidr) @@ -120,9 +118,10 @@ func Test_NewConnectionManagerTest(t *testing.T) { func Test_NewConnectionManagerTest2(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - preferredRanges := []*net.IPNet{localrange} + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnIp := 
netip.MustParseAddr("172.1.1.2") + preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects hostMap := newHostMap(l, vpncidr) @@ -211,9 +210,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { IP: net.IPv4(172, 1, 1, 2), Mask: net.IPMask{255, 255, 255, 0}, } - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - preferredRanges := []*net.IPNet{localrange} + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnIp := netip.MustParseAddr("172.1.1.2") + preferredRanges := []netip.Prefix{localrange} hostMap := newHostMap(l, vpncidr) hostMap.preferredRanges.Store(&preferredRanges)
control.go+19 −21 modified@@ -2,17 +2,15 @@ package nebula import ( "context" - "net" + "net/netip" "os" "os/signal" "syscall" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" - "github.com/slackhq/nebula/udp" ) // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching @@ -21,10 +19,10 @@ import ( type controlEach func(h *HostInfo) type controlHostLister interface { - QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo + QueryVpnIp(vpnIp netip.Addr) *HostInfo ForEachIndex(each controlEach) ForEachVpnIp(each controlEach) - GetPreferredRanges() []*net.IPNet + GetPreferredRanges() []netip.Prefix } type Control struct { @@ -39,15 +37,15 @@ type Control struct { } type ControlHostInfo struct { - VpnIp net.IP `json:"vpnIp"` + VpnIp netip.Addr `json:"vpnIp"` LocalIndex uint32 `json:"localIndex"` RemoteIndex uint32 `json:"remoteIndex"` - RemoteAddrs []*udp.Addr `json:"remoteAddrs"` + RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` Cert *cert.NebulaCertificate `json:"cert"` MessageCounter uint64 `json:"messageCounter"` - CurrentRemote *udp.Addr `json:"currentRemote"` - CurrentRelaysToMe []iputil.VpnIp `json:"currentRelaysToMe"` - CurrentRelaysThroughMe []iputil.VpnIp `json:"currentRelaysThroughMe"` + CurrentRemote netip.AddrPort `json:"currentRemote"` + CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"` + CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` } // Start actually runs nebula, this is a nonblocking call. 
To block use Control.ShutdownBlock() @@ -132,7 +130,8 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { } // GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found -func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo { +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo { var hl controlHostLister if pending { hl = c.f.handshakeManager @@ -150,19 +149,21 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH } // SetRemoteForTunnel forces a tunnel to use a specific remote -func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo { +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo { hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) if hostInfo == nil { return nil } - hostInfo.SetRemote(addr.Copy()) + hostInfo.SetRemote(addr) ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges()) return &ch } // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. -func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool { +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) if hostInfo == nil { return false @@ -205,7 +206,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { } // Learn which hosts are being used as relays, so we can shut them down last. 
- relayingHosts := map[iputil.VpnIp]*HostInfo{} + relayingHosts := map[netip.Addr]*HostInfo{} // Grab the hostMap lock to access the Relays map c.f.hostMap.Lock() for _, relayingHost := range c.f.hostMap.Relays { @@ -236,15 +237,16 @@ func (c *Control) Device() overlay.Device { return c.f.inside } -func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { +func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { chi := ControlHostInfo{ - VpnIp: h.vpnIp.ToIP(), + VpnIp: h.vpnIp, LocalIndex: h.localIndexId, RemoteIndex: h.remoteIndexId, RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), CurrentRelaysToMe: h.relayState.CopyRelayIps(), CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(), + CurrentRemote: h.remote, } if h.ConnectionState != nil { @@ -255,10 +257,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { chi.Cert = c.Copy() } - if h.remote != nil { - chi.CurrentRemote = h.remote.Copy() - } - return chi }
control_tester.go+20 −27 modified@@ -4,14 +4,13 @@ package nebula import ( - "net" + "net/netip" "github.com/slackhq/nebula/cert" "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) @@ -50,37 +49,30 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, // InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp // This is necessary if you did not configure static hosts or are not running a lighthouse -func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) { +func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp)) + remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() - iVpnIp := iputil.Ip2VpnIp(vpnIp) - if v4 := toAddr.IP.To4(); v4 != nil { - remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port))) + if toAddr.Addr().Is4() { + remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) } else { - remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))) + remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) } } // InjectRelays will push relayVpnIps into the local lighthouse cache for the vpnIp // This is necessary to inform an initiator of possible relays for communicating with a responder -func (c *Control) InjectRelays(vpnIp net.IP, relayVpnIps []net.IP) { +func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp)) + remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) remoteList.Lock() defer remoteList.Unlock() 
c.f.lightHouse.Unlock() - iVpnIp := iputil.Ip2VpnIp(vpnIp) - uVpnIp := []uint32{} - for _, rVPnIp := range relayVpnIps { - uVpnIp = append(uVpnIp, uint32(iputil.Ip2VpnIp(rVPnIp))) - } - - remoteList.unlockedSetRelay(iVpnIp, iVpnIp, uVpnIp) + remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps) } // GetFromTun will pull a packet off the tun side of nebula @@ -107,13 +99,14 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) { } // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol -func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16, data []byte) { +func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) { + //TODO: IPV6-WORK ip := layers.IPv4{ Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP, - SrcIP: c.f.inside.Cidr().IP, - DstIP: toIp, + SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(), + DstIP: toIp.Unmap().AsSlice(), } udp := layers.UDP{ @@ -138,16 +131,16 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16 c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) } -func (c *Control) GetVpnIp() iputil.VpnIp { - return c.f.myVpnIp +func (c *Control) GetVpnIp() netip.Addr { + return c.f.myVpnNet.Addr() } -func (c *Control) GetUDPAddr() string { - return c.f.outside.(*udp.TesterConn).Addr.String() +func (c *Control) GetUDPAddr() netip.AddrPort { + return c.f.outside.(*udp.TesterConn).Addr } -func (c *Control) KillPendingTunnel(vpnIp net.IP) bool { - hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp)) +func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool { + hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp) if hostinfo == nil { return false } @@ -164,6 +157,6 @@ func (c *Control) GetCert() *cert.NebulaCertificate { return c.f.pki.GetCertState().Certificate } -func (c *Control) ReHandshake(vpnIp iputil.VpnIp) { +func (c *Control) ReHandshake(vpnIp netip.Addr) { 
c.f.handshakeManager.StartHandshake(vpnIp, nil) }
control_test.go+33 −24 modified@@ -2,34 +2,34 @@ package nebula import ( "net" + "net/netip" "reflect" "testing" "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" - "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" ) func TestControl_GetHostInfoByVpnIp(t *testing.T) { l := test.NewLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller - hm := newHostMap(l, &net.IPNet{}) - hm.preferredRanges.Store(&[]*net.IPNet{}) + hm := newHostMap(l, netip.Prefix{}) + hm.preferredRanges.Store(&[]netip.Prefix{}) + + remote1 := netip.MustParseAddrPort("0.0.0.100:4444") + remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444") - remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444) - remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444) ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), + IP: remote1.Addr().AsSlice(), Mask: net.IPMask{255, 255, 255, 0}, } ipNet2 := net.IPNet{ - IP: net.ParseIP("1:2:3:4:5:6:7:8"), + IP: remote2.Addr().AsSlice(), Mask: net.IPMask{255, 255, 255, 0}, } @@ -50,8 +50,12 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { } remotes := NewRemoteList(nil) - remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port))) - remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port))) + remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port())) + remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port())) + + vpnIp, ok := netip.AddrFromSlice(ipNet.IP) + assert.True(t, ok) + hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, @@ -60,14 +64,17 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: 
iputil.Ip2VpnIp(ipNet.IP), + vpnIp: vpnIp, relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) + vpnIp2, ok := netip.AddrFromSlice(ipNet2.IP) + assert.True(t, ok) + hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, @@ -76,10 +83,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: iputil.Ip2VpnIp(ipNet2.IP), + vpnIp: vpnIp2, relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) @@ -91,27 +98,29 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l: logrus.New(), } - thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false) + thi := c.GetHostInfoByVpnIp(vpnIp, false) expectedInfo := ControlHostInfo{ - VpnIp: net.IPv4(1, 2, 3, 4).To4(), + VpnIp: vpnIp, LocalIndex: 201, RemoteIndex: 200, - RemoteAddrs: []*udp.Addr{remote2, remote1}, + RemoteAddrs: []netip.AddrPort{remote2, remote1}, Cert: crt.Copy(), MessageCounter: 0, - CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444), - CurrentRelaysToMe: []iputil.VpnIp{}, - CurrentRelaysThroughMe: []iputil.VpnIp{}, + CurrentRemote: remote1, + CurrentRelaysToMe: []netip.Addr{}, + CurrentRelaysThroughMe: []netip.Addr{}, } // Make sure we don't have any unexpected fields assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) - test.AssertDeepCopyEqual(t, &expectedInfo, thi) + assert.EqualValues(t, &expectedInfo, thi) + //TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here + //test.AssertDeepCopyEqual(t, 
&expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet assert.NotPanics(t, func() { - thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false) + thi = c.GetHostInfoByVpnIp(vpnIp2, false) }) }
dns_server.go+12 −6 modified@@ -3,14 +3,14 @@ package nebula import ( "fmt" "net" + "net/netip" "strconv" "strings" "sync" "github.com/miekg/dns" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) // This whole thing should be rewritten to use context @@ -42,19 +42,21 @@ func (d *dnsRecords) Query(data string) string { } func (d *dnsRecords) QueryCert(data string) string { - ip := net.ParseIP(data[:len(data)-1]) - if ip == nil { + ip, err := netip.ParseAddr(data[:len(data)-1]) + if err != nil { return "" } - iip := iputil.Ip2VpnIp(ip) - hostinfo := d.hostMap.QueryVpnIp(iip) + + hostinfo := d.hostMap.QueryVpnIp(ip) if hostinfo == nil { return "" } + q := hostinfo.GetCert() if q == nil { return "" } + cert := q.Details c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer) return c @@ -80,7 +82,11 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { } case dns.TypeTXT: a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) - b := net.ParseIP(a) + b, err := netip.ParseAddr(a) + if err != nil { + return + } + // We don't answer these queries from non nebula nodes or localhost //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR) if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
e2e/handshakes_test.go+169 −169 modified@@ -5,27 +5,26 @@ package e2e import ( "fmt" - "net" + "net/netip" "testing" "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v2" ) func BenchmarkHotPath(b *testing.B) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Start the servers myControl.Start() @@ -35,7 +34,7 @@ func BenchmarkHotPath(b *testing.B) { r.CancelFlowLogs() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) _ = r.RouteForAllUntilTxTun(theirControl) } @@ -44,19 +43,19 @@ func BenchmarkHotPath(b *testing.B) { } func TestGoodHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := 
NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Start the servers myControl.Start() theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -77,16 +76,16 @@ func TestGoodHandshake(t *testing.T) { myControl.WaitForType(1, 0, theirControl) t.Log("Make sure our host infos are correct") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) t.Log("Get that cached packet and make sure it looks right") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) t.Log("Do a bidirectional tunnel test") r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() @@ -95,20 +94,20 @@ func TestGoodHandshake(t *testing.T) { 
} func TestWrongResponderHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) // The IPs here are chosen on purpose: // The current remote handling will sort by preference, public, and then lexically. // So we need them to have a higher address than evil (we could apply a preference though) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil) - evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil) // Add their real udp addr, which should be tried after evil. - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. 
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl, evilControl) @@ -120,15 +119,15 @@ func TestWrongResponderHandshake(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see our cached packet get sent to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { h := &header.H{} err := h.Parse(p.Data) if err != nil { panic(err) } - if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 { + if p.To == theirUdpAddr && h.Type == 1 { return router.RouteAndExit } @@ -139,18 +138,18 @@ func TestWrongResponderHandshake(t *testing.T) { t.Log("My cached packet should be received by them") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) t.Log("Test the tunnel with them") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Flush all packets from all controllers") r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil") - 
assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil") //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete //TODO: assert hostmaps for everyone @@ -164,13 +163,13 @@ func TestStage1Race(t *testing.T) { // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // But will eventually collapse down to a single tunnel - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -181,8 +180,8 @@ func TestStage1Race(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake to start on both me and them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 
80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) @@ -194,14 +193,14 @@ func TestStage1Race(t *testing.T) { r.Log("Route until they receive a message packet") myCachedPacket := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.Log("Their cached packet should be received by me") theirCachedPacket := r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) r.Log("Do a bidirectional tunnel test") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) myHostmapHosts := myControl.ListHostmapHosts(false) myHostmapIndexes := myControl.ListHostmapIndexes(false) @@ -219,7 +218,7 @@ func TestStage1Race(t *testing.T) { r.Log("Spin until connection manager tears down a tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } @@ -241,13 +240,13 @@ func TestStage1Race(t *testing.T) { } func TestUncleanShutdownRaceLoser(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), 
time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -258,28 +257,28 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.Log("Nuke my hostmap") myHostmap := myControl.GetHostmap() - myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} + myHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{} myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, 
[]byte("Hi from me again")) p = r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(theirControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) if len(theirControl.GetHostmap().Indexes) < start { break } @@ -290,13 +289,13 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { } func TestUncleanShutdownRaceWinner(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, 
theirControl) @@ -307,30 +306,30 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.Log("Nuke my hostmap") theirHostmap := theirControl.GetHostmap() - theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} + theirHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{} theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again")) p = r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("Derp hostmaps", myControl, theirControl) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(myControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) if len(myControl.GetHostmap().Indexes) < start { break } @@ -341,15 +340,15 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { } func TestRelays(t *testing.T) { - ca, 
_, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -361,31 +360,31 @@ func TestRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) 
r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) //TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it } func TestStage1RaceRelays(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, 
[]net.IP{relayVpnIpNet.IP}) - theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -397,14 +396,14 @@ func TestStage1RaceRelays(t *testing.T) { theirControl.Start() r.Log("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) r.Log("Wait for a packet from them to me") p := r.RouteForAllUntilTxTun(myControl) @@ -421,21 +420,21 @@ func TestStage1RaceRelays(t *testing.T) { func TestStage1RaceRelays2(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - 
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) l := NewTestLogger() // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ 
-448,16 +447,16 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Get a tunnel between me and relay") l.Info("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") l.Info("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") l.Info("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) @@ -470,7 +469,7 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) t.Log("Wait until we remove extra tunnels") l.Info("Wait until we remove extra tunnels") @@ -490,15 +489,15 @@ func TestStage1RaceRelays2(t *testing.T) { "theirControl": len(theirControl.GetHostmap().Indexes), "relayControl": len(relayControl.GetHostmap().Indexes), }).Info("Waiting for hostinfos to be removed...") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) 
t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) retries-- } r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) myControl.Stop() theirControl.Stop() @@ -507,16 +506,17 @@ func TestStage1RaceRelays2(t *testing.T) { // ////TODO: assert hostmaps } + func TestRehandshakingRelays(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + 
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -528,11 +528,11 @@ func TestRehandshakingRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, @@ -556,8 +556,8 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") @@ -569,8 +569,8 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if 
len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") @@ -581,29 +581,29 @@ func TestRehandshakingRelays(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } t.Logf("myControl hostinfos got cleaned up!") for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } t.Logf("theirControl hostinfos got cleaned up!") for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -612,15 +612,15 @@ func TestRehandshakingRelays(t *testing.T) { func 
TestRehandshakingRelaysPrimary(t *testing.T) { // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -632,11 +632,11 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - 
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, @@ -660,8 +660,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") @@ -673,8 +673,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") @@ -685,43 +685,43 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + 
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } t.Logf("myControl hostinfos got cleaned up!") for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } t.Logf("theirControl hostinfos got cleaned up!") for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } t.Logf("relayControl hostinfos got cleaned up!") } func TestRehandshaking(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := 
newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -732,7 +732,7 @@ func TestRehandshaking(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) @@ -754,8 +754,8 @@ func TestRehandshaking(t *testing.T) { myConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now break @@ -781,19 +781,19 @@ func TestRehandshaking(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, 
theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) assert.Contains(t, c.Cert.Details.Groups, "new group") // We should only have a single tunnel now on both sides @@ -811,13 +811,13 @@ func TestRehandshaking(t *testing.T) { func TestRehandshakingLoser(t *testing.T) { // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // Should be the one with the new certificate - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + 
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -828,10 +828,10 @@ func TestRehandshakingLoser(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) - tt1 := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) - tt2 := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) + tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) fmt.Println(tt1.LocalIndex, tt2.LocalIndex) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) @@ -854,8 +854,8 @@ func TestRehandshakingLoser(t *testing.T) { theirConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) - theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) _, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"] if theirNewGroup { @@ -882,19 +882,19 @@ func TestRehandshakingLoser(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 
myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) + theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group") // We should only have a single tunnel now on both sides @@ -912,13 +912,13 @@ func TestRaceRegression(t *testing.T) { // This test forces stage 1, stage 2, stage 1 to be received by me from them // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // caused a cross-linked hostinfo - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Start the servers myControl.Start() @@ -932,8 +932,8 @@ func TestRaceRegression(t *testing.T) { //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 
t.Log("Start both handshakes") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) t.Log("Get both stage 1") myStage1ForThem := myControl.GetFromUDP(true) @@ -963,7 +963,7 @@ func TestRaceRegression(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) t.Log("Make sure the tunnel still works") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) myControl.Stop() theirControl.Stop()
e2e/helpers.go+15 −8 modified@@ -4,6 +4,7 @@ import ( "crypto/rand" "io" "net" + "net/netip" "time" "github.com/slackhq/nebula/cert" @@ -12,7 +13,7 @@ import ( ) // NewTestCaCert will generate a CA cert -func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { +func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { pub, priv, err := ed25519.GenerateKey(rand.Reader) if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) @@ -33,11 +34,17 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups [] } if len(ips) > 0 { - nc.Details.Ips = ips + nc.Details.Ips = make([]*net.IPNet, len(ips)) + for i, ip := range ips { + nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())} + } } if len(subnets) > 0 { - nc.Details.Subnets = subnets + nc.Details.Subnets = make([]*net.IPNet, len(subnets)) + for i, ip := range subnets { + nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())} + } } if len(groups) > 0 { @@ -59,7 +66,7 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups [] // NewTestCert will generate a signed certificate with the provided details. 
// Expiry times are defaulted if you do not pass them in -func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { +func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { issuer, err := ca.Sha256Sum() if err != nil { panic(err) @@ -74,12 +81,12 @@ func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, af } pub, rawPriv := x25519Keypair() - + ipb := ip.Addr().AsSlice() nc := &cert.NebulaCertificate{ Details: cert.NebulaCertificateDetails{ - Name: name, - Ips: []*net.IPNet{ip}, - Subnets: subnets, + Name: name, + Ips: []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}}, + //Subnets: subnets, Groups: groups, NotBefore: time.Unix(before.Unix(), 0), NotAfter: time.Unix(after.Unix(), 0),
e2e/helpers_test.go+28 −24 modified@@ -6,7 +6,7 @@ package e2e import ( "fmt" "io" - "net" + "net/netip" "os" "testing" "time" @@ -19,23 +19,30 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v2" ) type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) { +func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) { l := NewTestLogger() - vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} - copy(vpnIpNet.IP, udpIp) - vpnIpNet.IP[1] += 128 - udpAddr := net.UDPAddr{ - IP: udpIp, - Port: 4242, + vpnIpNet, err := netip.ParsePrefix(sVpnIpNet) + if err != nil { + panic(err) + } + + var udpAddr netip.AddrPort + if vpnIpNet.Addr().Is4() { + budpIp := vpnIpNet.Addr().As4() + budpIp[1] -= 128 + udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242) + } else { + budpIp := vpnIpNet.Addr().As16() + budpIp[13] -= 128 + udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) @@ -67,8 +74,8 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u // "try_interval": "1s", //}, "listen": m{ - "host": udpAddr.IP.String(), - "port": udpAddr.Port, + "host": udpAddr.Addr().String(), + "port": udpAddr.Port(), }, "logging": m{ "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name), @@ -102,7 +109,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u panic(err) } - return control, 
vpnIpNet, &udpAddr, c + return control, vpnIpNet, udpAddr, c } type doneCb func() @@ -123,7 +130,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb { } } -func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) { +func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) bPacket := r.RouteForAllUntilTxTun(controlA) @@ -135,23 +142,20 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) } -func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) { +func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) { // Get both host infos - hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false) + hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false) assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA") - hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false) + hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false) assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB") // Check that both vpn and real addr are correct assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A") assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B") - assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A") - assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B") - - assert.Equal(t, addrB.Port, int(hBinA.CurrentRemote.Port), "Host B remote port is wrong in control A") - assert.Equal(t, addrA.Port, int(hAinB.CurrentRemote.Port), "Host A remote port is wrong in 
control B") + assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A") + assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B") // Check that our indexes match assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index") @@ -174,13 +178,13 @@ func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB //checkIndexes("hmB", hmB, hAinB) } -func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, fromPort, toPort uint16) { +func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) assert.NotNil(t, v4, "No ipv4 data found") - assert.Equal(t, fromIp, v4.SrcIP, "Source ip was incorrect") - assert.Equal(t, toIp, v4.DstIP, "Dest ip was incorrect") + assert.Equal(t, fromIp.AsSlice(), []byte(v4.SrcIP), "Source ip was incorrect") + assert.Equal(t, toIp.AsSlice(), []byte(v4.DstIP), "Dest ip was incorrect") udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) assert.NotNil(t, udp, "No udp data found")
e2e/router/hostmap.go+4 −4 modified@@ -5,11 +5,11 @@ package router import ( "fmt" + "net/netip" "sort" "strings" "github.com/slackhq/nebula" - "github.com/slackhq/nebula/iputil" ) type edge struct { @@ -118,14 +118,14 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { return r, globalLines } -func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp { - keys := make([]iputil.VpnIp, 0, len(hosts)) +func sortedHosts(hosts map[netip.Addr]*nebula.HostInfo) []netip.Addr { + keys := make([]netip.Addr, 0, len(hosts)) for key := range hosts { keys = append(keys, key) } sort.SliceStable(keys, func(i, j int) bool { - return keys[i] > keys[j] + return keys[i].Compare(keys[j]) > 0 }) return keys
e2e/router/router.go+36 −63 modified@@ -6,12 +6,11 @@ package router import ( "context" "fmt" - "net" + "net/netip" "os" "path/filepath" "reflect" "sort" - "strconv" "strings" "sync" "testing" @@ -21,26 +20,25 @@ import ( "github.com/google/gopacket/layers" "github.com/slackhq/nebula" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "golang.org/x/exp/maps" ) type R struct { // Simple map of the ip:port registered on a control to the control // Basically a router, right? - controls map[string]*nebula.Control + controls map[netip.AddrPort]*nebula.Control // A map for inbound packets for a control that doesn't know about this address - inNat map[string]*nebula.Control + inNat map[netip.AddrPort]*nebula.Control // A last used map, if an inbound packet hit the inNat map then // all return packets should use the same last used inbound address for the outbound sender // map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver - outNat map[string]net.UDPAddr + outNat map[string]netip.AddrPort // A map of vpn ip to the nebula control it belongs to - vpnControls map[iputil.VpnIp]*nebula.Control + vpnControls map[netip.Addr]*nebula.Control ignoreFlows []ignoreFlow flow []flowEntry @@ -118,10 +116,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { } r := &R{ - controls: make(map[string]*nebula.Control), - vpnControls: make(map[iputil.VpnIp]*nebula.Control), - inNat: make(map[string]*nebula.Control), - outNat: make(map[string]net.UDPAddr), + controls: make(map[netip.AddrPort]*nebula.Control), + vpnControls: make(map[netip.Addr]*nebula.Control), + inNat: make(map[netip.AddrPort]*nebula.Control), + outNat: make(map[string]netip.AddrPort), flow: []flowEntry{}, ignoreFlows: []ignoreFlow{}, fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())), @@ -135,7 +133,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { for _, c := range controls { addr := c.GetUDPAddr() if _, 
ok := r.controls[addr]; ok { - panic("Duplicate listen address: " + addr) + panic("Duplicate listen address: " + addr.String()) } r.vpnControls[c.GetVpnIp()] = c @@ -165,13 +163,13 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { // It does not look at the addr attached to the instance. // If a route is used, this will behave like a NAT for the return path. // Rewriting the source ip:port to what was last sent to from the origin -func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) { +func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) { r.Lock() defer r.Unlock() - inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)) + inAddr := netip.AddrPortFrom(ip, port) if _, ok := r.inNat[inAddr]; ok { - panic("Duplicate listen address inNat: " + inAddr) + panic("Duplicate listen address inNat: " + inAddr.String()) } r.inNat[inAddr] = c } @@ -198,7 +196,7 @@ func (r *R) renderFlow() { panic(err) } - var participants = map[string]struct{}{} + var participants = map[netip.AddrPort]struct{}{} var participantsVals []string fmt.Fprintln(f, "```mermaid") @@ -215,7 +213,7 @@ func (r *R) renderFlow() { continue } participants[addr] = struct{}{} - sanAddr := strings.Replace(addr, ":", "-", 1) + sanAddr := strings.Replace(addr.String(), ":", "-", 1) participantsVals = append(participantsVals, sanAddr) fmt.Fprintf( f, " participant %s as Nebula: %s<br/>UDP: %s\n", @@ -252,9 +250,9 @@ func (r *R) renderFlow() { fmt.Fprintf(f, " %s%s%s: %s(%s), index %v, counter: %v\n", - strings.Replace(p.from.GetUDPAddr(), ":", "-", 1), + strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1), line, - strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), + strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter, ) } @@ -305,7 +303,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) { func (r *R) renderHostmaps(title string) { c := maps.Values(r.controls) 
sort.SliceStable(c, func(i, j int) bool { - return c[i].GetVpnIp() > c[j].GetVpnIp() + return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0 }) s := renderHostmaps(c...) @@ -420,10 +418,8 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) [] // Nope, lets push the sender along case p := <-udpTx: - outAddr := sender.GetUDPAddr() r.Lock() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - c := r.getControl(outAddr, inAddr, p) + c := r.getControl(sender.GetUDPAddr(), p.To, p) if c == nil { r.Unlock() panic("No control for udp tx") @@ -479,10 +475,7 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte { } else { // we are a udp tx, route and continue p := rx.Interface().(*udp.Packet) - outAddr := cm[x].GetUDPAddr() - - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - c := r.getControl(outAddr, inAddr, p) + c := r.getControl(cm[x].GetUDPAddr(), p.To, p) if c == nil { r.Unlock() panic("No control for udp tx") @@ -509,12 +502,10 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { panic(err) } - outAddr := sender.GetUDPAddr() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - receiver := r.getControl(outAddr, inAddr, p) + receiver := r.getControl(sender.GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() - panic("Can't route for host: " + inAddr) + panic("Can't RouteExitFunc for host: " + p.To.String()) } e := whatDo(p, receiver) @@ -590,13 +581,13 @@ func (r *R) InjectUDPPacket(sender, receiver *nebula.Control, packet *udp.Packet // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit` // If the router doesn't have the nebula controller for that address, we panic -func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) { 
+func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr netip.AddrPort, finish ExitType) { if finish == KeepRouting { finish = RouteAndExit } r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType { - if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) { + if p.To == toAddr { return finish } @@ -630,13 +621,10 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { r.Lock() p := rx.Interface().(*udp.Packet) - - outAddr := cm[x].GetUDPAddr() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - receiver := r.getControl(outAddr, inAddr, p) + receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() - panic("Can't route for host: " + inAddr) + panic("Can't RouteForAllExitFunc for host: " + p.To.String()) } e := whatDo(p, receiver) @@ -697,41 +685,25 @@ func (r *R) FlushAll() { p := rx.Interface().(*udp.Packet) - outAddr := cm[x].GetUDPAddr() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - receiver := r.getControl(outAddr, inAddr, p) + receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() - panic("Can't route for host: " + inAddr) + panic("Can't FlushAll for host: " + p.To.String()) } r.Unlock() } } // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change // This is an internal router function, the caller must hold the lock -func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control { - if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok { - p.FromIp = newAddr.IP - p.FromPort = uint16(newAddr.Port) +func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control { + if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok { + p.From = newAddr } c, ok := r.inNat[toAddr] if ok { - sHost, sPort, err := net.SplitHostPort(toAddr) - if err != nil { - panic(err) - } - - port, err := strconv.Atoi(sPort) - if err != nil { - 
panic(err) - } - - r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{ - IP: net.ParseIP(sHost), - Port: port, - } + r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr return c } @@ -746,8 +718,9 @@ func (r *R) formatUdpPacket(p *packet) string { } from := "unknown" - if c, ok := r.vpnControls[iputil.Ip2VpnIp(v4.SrcIP)]; ok { - from = c.GetUDPAddr() + srcAddr, _ := netip.AddrFromSlice(v4.SrcIP) + if c, ok := r.vpnControls[srcAddr]; ok { + from = c.GetUDPAddr().String() } udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) @@ -759,7 +732,7 @@ func (r *R) formatUdpPacket(p *packet) string { return fmt.Sprintf( " %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n", strings.Replace(from, ":", "-", 1), - strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), + strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), udp.SrcPort, udp.DstPort, string(data.Payload()),
firewall.go+58 −42 modified@@ -6,23 +6,23 @@ import ( "errors" "fmt" "hash/fnv" - "net" + "net/netip" "reflect" "strconv" "strings" "sync" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" ) type FirewallInterface interface { - AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error + AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error } type conn struct { @@ -52,8 +52,8 @@ type Firewall struct { DefaultTimeout time.Duration //linux: 600s // Used to ensure we don't emit local packets for ips we don't own - localIps *cidr.Tree4[struct{}] - assignedCIDR *net.IPNet + localIps *bart.Table[struct{}] + assignedCIDR netip.Prefix hasSubnets bool rules string @@ -108,7 +108,7 @@ type FirewallRule struct { Any *firewallLocalCIDR Hosts map[string]*firewallLocalCIDR Groups []*firewallGroups - CIDR *cidr.Tree4[*firewallLocalCIDR] + CIDR *bart.Table[*firewallLocalCIDR] } type firewallGroups struct { @@ -122,7 +122,7 @@ type firewallPort map[int32]*FirewallCA type firewallLocalCIDR struct { Any bool - LocalCIDR *cidr.Tree4[struct{}] + LocalCIDR *bart.Table[struct{}] } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. 
@@ -144,20 +144,28 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D max = defaultTimeout } - localIps := cidr.NewTree4[struct{}]() - var assignedCIDR *net.IPNet + localIps := new(bart.Table[struct{}]) + var assignedCIDR netip.Prefix + var assignedSet bool for _, ip := range c.Details.Ips { - ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}} - localIps.AddCIDR(ipNet, struct{}{}) + //TODO: IPV6-WORK the unmap is a bit unfortunate + nip, _ := netip.AddrFromSlice(ip.IP) + nip = nip.Unmap() + nprefix := netip.PrefixFrom(nip, nip.BitLen()) + localIps.Insert(nprefix, struct{}{}) - if assignedCIDR == nil { + if !assignedSet { // Only grabbing the first one in the cert since any more than that currently has undefined behavior - assignedCIDR = ipNet + assignedCIDR = nprefix + assignedSet = true } } for _, n := range c.Details.Subnets { - localIps.AddCIDR(n, struct{}{}) + nip, _ := netip.AddrFromSlice(n.IP) + ones, _ := n.Mask.Size() + nip = nip.Unmap() + localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{}) } return &Firewall{ @@ -237,15 +245,15 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf } // AddRule properly creates the in memory rule structure for a firewall table. 
-func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error { // Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS // https://github.com/golang/go/issues/14131 sIp := "" - if ip != nil { + if ip.IsValid() { sIp = ip.String() } lIp := "" - if localIp != nil { + if localIp.IsValid() { lIp = localIp.String() } @@ -382,17 +390,17 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) } - var cidr *net.IPNet + var cidr netip.Prefix if r.Cidr != "" { - _, cidr, err = net.ParseCIDR(r.Cidr) + cidr, err = netip.ParsePrefix(r.Cidr) if err != nil { return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err) } } - var localCidr *net.IPNet + var localCidr netip.Prefix if r.LocalCidr != "" { - _, localCidr, err = net.ParseCIDR(r.LocalCidr) + localCidr, err = netip.ParsePrefix(r.LocalCidr) if err != nil { return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err) } @@ -421,7 +429,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * // Make sure remote address matches nebula certificate if remoteCidr := h.remoteCidr; remoteCidr != nil { - ok, _ := remoteCidr.Contains(fp.RemoteIP) + //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different + _, ok := remoteCidr.Lookup(fp.RemoteIP) if !ok { f.metrics(incoming).droppedRemoteIP.Inc(1) return ErrInvalidRemoteIP @@ -435,7 +444,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // Make sure we are supposed to 
be handling this local ip address - ok, _ := f.localIps.Contains(fp.LocalIP) + //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different + _, ok := f.localIps.Lookup(fp.LocalIP) if !ok { f.metrics(incoming).droppedLocalIP.Inc(1) return ErrInvalidLocalIP @@ -589,7 +599,6 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Caller must own the connMutex lock! func (f *Firewall) evict(p firewall.Packet) { - //TODO: report a stat if the tcp rtt tracking was never resolved? // Are we still tracking this conn? conntrack := f.Conntrack t, ok := conntrack.Conns[p] @@ -633,7 +642,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC return false } -func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error { if startPort > endPort { return fmt.Errorf("start port was lower than end port") } @@ -677,12 +686,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer return fp[firewall.PortAny].match(p, c, caPool) } -func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error { +func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error { fr := func() *FirewallRule { return &FirewallRule{ Hosts: make(map[string]*firewallLocalCIDR), Groups: make([]*firewallGroups, 0), - CIDR: cidr.NewTree4[*firewallLocalCIDR](), + CIDR: new(bart.Table[*firewallLocalCIDR]), } } @@ -740,10 +749,10 @@ func (fc *FirewallCA) 
match(p firewall.Packet, c *cert.NebulaCertificate, caPool return fc.CANames[s.Details.Name].match(p, c) } -func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error { +func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error { flc := func() *firewallLocalCIDR { return &firewallLocalCIDR{ - LocalCIDR: cidr.NewTree4[struct{}](), + LocalCIDR: new(bart.Table[struct{}]), } } @@ -780,23 +789,23 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n fr.Hosts[host] = nlc } - if ip != nil { - _, nlc := fr.CIDR.GetCIDR(ip) + if ip.IsValid() { + nlc, _ := fr.CIDR.Get(ip) if nlc == nil { nlc = flc() } err := nlc.addRule(f, localCIDR) if err != nil { return err } - fr.CIDR.AddCIDR(ip, nlc) + fr.CIDR.Insert(ip, nlc) } return nil } -func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool { - if len(groups) == 0 && host == "" && ip == nil { +func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool { + if len(groups) == 0 && host == "" && !ip.IsValid() { return true } @@ -810,7 +819,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool return true } - if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) { + if ip.IsValid() && ip.Bits() == 0 { return true } @@ -853,24 +862,31 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool } } - return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool { - return flc.match(p, c) + matched := false + prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen()) + fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool { + if prefix.Contains(p.RemoteIP) && val.match(p, c) { + matched = true + return false + } + return true }) + return matched } -func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error { - if localIp == nil { +func (flc 
*firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { + if !localIp.IsValid() { if !f.hasSubnets || f.defaultLocalCIDRAny { flc.Any = true return nil } localIp = f.assignedCIDR - } else if localIp.Contains(net.IPv4(0, 0, 0, 0)) { + } else if localIp.Bits() == 0 { flc.Any = true } - flc.LocalCIDR.AddCIDR(localIp, struct{}{}) + flc.LocalCIDR.Insert(localIp, struct{}{}) return nil } @@ -883,7 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate return true } - ok, _ := flc.LocalCIDR.Contains(p.LocalIP) + _, ok := flc.LocalCIDR.Lookup(p.LocalIP) return ok }
firewall/packet.go+3 −4 modified@@ -3,8 +3,7 @@ package firewall import ( "encoding/json" "fmt" - - "github.com/slackhq/nebula/iputil" + "net/netip" ) type m map[string]interface{} @@ -20,8 +19,8 @@ const ( ) type Packet struct { - LocalIP iputil.VpnIp - RemoteIP iputil.VpnIp + LocalIP netip.Addr + RemoteIP netip.Addr LocalPort uint16 RemotePort uint16 Protocol uint8
firewall_test.go+74 −73 modified@@ -5,13 +5,13 @@ import ( "errors" "math" "net" + "net/netip" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) @@ -65,59 +65,62 @@ func TestFirewall_AddRule(t *testing.T) { assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.OutRules) - _, ti, _ := net.ParseCIDR("1.2.3.4/32") + ti, err := netip.ParsePrefix("1.2.3.4/32") + assert.NoError(t, err) - assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, 
"", ti, netip.Prefix{}, "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) - ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti) + _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) - ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti) + _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - _, anyIp, _ := net.ParseCIDR("0.0.0.0/0") - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", "")) + anyIp, err := netip.ParsePrefix("0.0.0.0/0") + assert.NoError(t, err) + + assert.Nil(t, fw.AddRule(false, 
firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", "")) - assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", "")) + assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) } func TestFirewall_Drop(t *testing.T) { @@ -126,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -152,16 +155,16 @@ func TestFirewall_Drop(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr("1.2.3.4"), } h.CreateRemoteCIDR(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) @@ -170,34 +173,34 @@ func TestFirewall_Drop(t *testing.T) { // test remote mismatch oldRemote := p.RemoteIP - p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10)) + p.RemoteIP = netip.MustParseAddr("1.2.3.10") assert.Equal(t, 
fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteIP = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, 
netip.Prefix{}, "ca-good", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) } @@ -207,10 +210,9 @@ func BenchmarkFirewallTable_match(b *testing.B) { TCP: firewallPort{}, } - _, n, _ := net.ParseCIDR("172.1.1.1/32") - goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP) - _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "") - _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "") + pfix := netip.MustParsePrefix("172.1.1.1/32") + _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "") + _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "") cp := cert.NewCAPool() b.Run("fail on proto", func(b *testing.B) { @@ -231,10 +233,9 @@ func BenchmarkFirewallTable_match(b *testing.B) { b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) { c := &cert.NebulaCertificate{} - ip, _, _ := net.ParseCIDR("9.254.254.254/32") - lip := iputil.Ip2VpnIp(ip) + ip := netip.MustParsePrefix("9.254.254.254/32") for n := 0; n < b.N; n++ { - assert.False(b, 
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp)) } }) @@ -262,7 +263,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { }, } for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) } }) @@ -286,7 +287,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { }, } for n := 0; n < b.N; n++ { - assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) + assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) } }) @@ -363,8 +364,8 @@ func TestFirewall_Drop2(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -387,7 +388,7 @@ func TestFirewall_Drop2(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h.CreateRemoteCIDR(&c) @@ -406,7 +407,7 @@ func TestFirewall_Drop2(t *testing.T) { h1.CreateRemoteCIDR(&c1) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups @@ -422,8 
+423,8 @@ func TestFirewall_Drop3(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 1, RemotePort: 1, Protocol: firewall.ProtoUDP, @@ -453,7 +454,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c1, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h1.CreateRemoteCIDR(&c1) @@ -468,7 +469,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c2, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h2.CreateRemoteCIDR(&c2) @@ -483,13 +484,13 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c3, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h3.CreateRemoteCIDR(&c3) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match @@ -508,8 +509,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -534,12 +535,12 @@ func 
TestFirewall_DropConntrackReload(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h.CreateRemoteCIDR(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound @@ -552,7 +553,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -561,7 +562,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -725,13 +726,13 @@ func TestNewFirewallFromConfig(t *testing.T) { conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh") + assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) 
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh") + assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) @@ -747,78 +748,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { mf := &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} 
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with cidr - cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)} + cidr := netip.MustParsePrefix("10.0.0.0/8") conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with local_cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": 
cidr.String()}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, 
addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test Add error conf = config.NewC(l) @@ -871,8 +872,8 @@ type addRuleCall struct { endPort int32 groups []string host string - ip *net.IPNet - localIp *net.IPNet + ip netip.Prefix + localIp netip.Prefix caName string caSha string } @@ -882,7 +883,7 @@ type mockFirewall struct { nextCallReturn error } -func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups 
[]string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error { mf.lastCall = addRuleCall{ incoming: incoming, proto: proto,
go.mod +2 −0 modified
@@ -38,8 +38,10 @@ require (
 
 require (
 	github.com/beorn7/perks v1.0.1 // indirect
+	github.com/bits-and-blooms/bitset v1.13.0 // indirect
 	github.com/cespare/xxhash/v2 v2.2.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/gaissmai/bart v0.11.1 // indirect
 	github.com/google/btree v1.1.2 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/prometheus/client_model v0.5.0 // indirect
go.sum +6 −0 modified
@@ -14,6 +14,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
+github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE=
+github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
 github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
 github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -24,6 +26,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
 github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
+github.com/gaissmai/bart v0.10.0 h1:yCZCYF8xzcRnqDe4jMk14NlJjL1WmMsE7ilBzvuHtiI=
+github.com/gaissmai/bart v0.10.0/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
+github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc=
+github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
handshake_ix.go+42 −16 modified@@ -1,13 +1,12 @@ package nebula import ( + "net/netip" "time" "github.com/flynn/noise" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) // NOISE IX Handshakes @@ -63,7 +62,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { return true } -func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { +func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { certState := f.pki.GetCertState() ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0) // Mark packet 1 as seen so it doesn't show up as missed @@ -99,12 +98,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by e.Info("Invalid certificate from host") return } - vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP) + + vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP) + if !ok { + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) + + if f.l.Level > logrus.DebugLevel { + e = e.WithField("cert", remoteCert) + } + + e.Info("Invalid vpn ip from host") + return + } + + vpnIp = vpnIp.Unmap() certName := remoteCert.Details.Name fingerprint, _ := remoteCert.Sha256Sum() issuer := remoteCert.Details.Issuer - if vpnIp == f.myVpnIp { + if vpnIp == f.myVpnNet.Addr() { f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). 
@@ -113,8 +126,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by return } - if addr != nil { - if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.IP) { + if addr.IsValid() { + if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) { f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } @@ -138,8 +151,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by HandshakePacket: make(map[uint8][]byte, 0), lastHandshakeTime: hs.Details.Time, relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, } @@ -218,7 +231,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by msg = existing.HandshakePacket[2] f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if addr != nil { + if addr.IsValid() { err := f.outside.WriteTo(msg, addr) if err != nil { f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). @@ -284,7 +297,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by // Do the send f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if addr != nil { + if addr.IsValid() { err = f.outside.WriteTo(msg, addr) if err != nil { f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). 
@@ -326,7 +339,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by return } -func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { +func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { if hh == nil { // Nothing here to tear down, got a bogus stage 2 packet return true @@ -336,8 +349,8 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha defer hh.Unlock() hostinfo := hh.hostinfo - if addr != nil { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { + if addr.IsValid() { + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) { f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return false } @@ -389,7 +402,20 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha return true } - vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP) + vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP) + if !ok { + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) + + if f.l.Level > logrus.DebugLevel { + e = e.WithField("cert", remoteCert) + } + + e.Info("Invalid vpn ip from host") + return true + } + + vpnIp = vpnIp.Unmap() certName := remoteCert.Details.Name fingerprint, _ := remoteCert.Sha256Sum() issuer := remoteCert.Details.Issuer @@ -453,7 +479,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha ci.eKey = NewNebulaCipherState(eKey) // Make sure the current udpAddr being used is set for responding - if addr != nil { + if addr.IsValid() { hostinfo.SetRemote(addr) } else { hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
handshake_manager.go+50 −41 modified@@ -6,15 +6,15 @@ import ( "crypto/rand" "encoding/binary" "errors" - "net" + "net/netip" "sync" "time" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" + "golang.org/x/exp/slices" ) const ( @@ -46,32 +46,32 @@ type HandshakeManager struct { // Mutex for interacting with the vpnIps and indexes maps sync.RWMutex - vpnIps map[iputil.VpnIp]*HandshakeHostInfo + vpnIps map[netip.Addr]*HandshakeHostInfo indexes map[uint32]*HandshakeHostInfo mainHostMap *HostMap lightHouse *LightHouse outside udp.Conn config HandshakeConfig - OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp] + OutboundHandshakeTimer *LockingTimerWheel[netip.Addr] messageMetrics *MessageMetrics metricInitiated metrics.Counter metricTimedOut metrics.Counter f *Interface l *logrus.Logger // can be used to trigger outbound handshake for the given vpnIp - trigger chan iputil.VpnIp + trigger chan netip.Addr } type HandshakeHostInfo struct { sync.Mutex - startTime time.Time // Time that we first started trying with this handshake - ready bool // Is the handshake ready - counter int // How many attempts have we made so far - lastRemotes []*udp.Addr // Remotes that we sent to during the previous attempt - packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes + startTime time.Time // Time that we first started trying with this handshake + ready bool // Is the handshake ready + counter int // How many attempts have we made so far + lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt + packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes hostinfo *HostInfo } @@ -103,14 +103,14 @@ func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, 
outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ - vpnIps: map[iputil.VpnIp]*HandshakeHostInfo{}, + vpnIps: map[netip.Addr]*HandshakeHostInfo{}, indexes: map[uint32]*HandshakeHostInfo{}, mainHostMap: mainHostMap, lightHouse: lightHouse, outside: outside, config: config, - trigger: make(chan iputil.VpnIp, config.triggerBuffer), - OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)), + trigger: make(chan netip.Addr, config.triggerBuffer), + OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr](config.tryInterval, hsTimeout(config.retries, config.tryInterval)), messageMetrics: config.messageMetrics, metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil), @@ -134,10 +134,10 @@ func (c *HandshakeManager) Run(ctx context.Context) { } } -func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { +func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { // First remote allow list check before we know the vpnIp - if addr != nil { - if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { + if addr.IsValid() { + if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) { hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } @@ -170,7 +170,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { } } -func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) { +func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered bool) { hh := hm.queryVpnIp(vpnIp) if hh == nil { return @@ -212,7 +212,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } remotes := 
hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) - remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes) + remotesHaveChanged := !slices.Equal(remotes, hh.lastRemotes) // We only care about a lighthouse trigger if we have new remotes to send to. // This is a very specific optimization for a fast lighthouse reply. @@ -234,8 +234,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply - var sentTo []*udp.Addr - hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) { + var sentTo []netip.AddrPort + hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) { hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { @@ -268,13 +268,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { // Don't relay to myself, and don't relay through the host I'm trying to connect to - if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp { + if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() { continue } - relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay) - if relayHostInfo == nil || relayHostInfo.remote == nil { + relayHostInfo := hm.mainHostMap.QueryVpnIp(relay) + if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") - hm.f.Handshake(*relay) + hm.f.Handshake(relay) continue } // Check the relay HostInfo to see if we already established a relay through it @@ -285,12 +285,17 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger 
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) case Requested: hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") + + //TODO: IPV6-WORK + myVpnIpB := hm.f.myVpnNet.Addr().As4() + theirVpnIpB := vpnIp.As4() + // Re-send the CreateRelay request, in case the previous one was lost. m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: existingRelay.LocalIndex, - RelayFromIp: uint32(hm.lightHouse.myVpnIp), - RelayToIp: uint32(vpnIp), + RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), + RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), } msg, err := m.Marshal() if err != nil { @@ -301,10 +306,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // This must send over the hostinfo, not over hm.Hosts[ip] hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.lightHouse.myVpnIp, + "relayFrom": hm.f.myVpnNet.Addr(), "relayTo": vpnIp, "initiatorRelayIndex": existingRelay.LocalIndex, - "relay": *relay}). + "relay": relay}). Info("send CreateRelayRequest") } default: @@ -316,17 +321,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } } else { // No relays exist or requested yet. 
- if relayHostInfo.remote != nil { + if relayHostInfo.remote.IsValid() { idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) if err != nil { hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") } + //TODO: IPV6-WORK + myVpnIpB := hm.f.myVpnNet.Addr().As4() + theirVpnIpB := vpnIp.As4() + m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: idx, - RelayFromIp: uint32(hm.lightHouse.myVpnIp), - RelayToIp: uint32(vpnIp), + RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), + RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), } msg, err := m.Marshal() if err != nil { @@ -336,10 +345,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } else { hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.lightHouse.myVpnIp, + "relayFrom": hm.f.myVpnNet.Addr(), "relayTo": vpnIp, "initiatorRelayIndex": idx, - "relay": *relay}). + "relay": relay}). 
Info("send CreateRelayRequest") } } @@ -355,7 +364,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present // The 2nd argument will be true if the hostinfo is ready to transmit traffic -func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { +func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { hm.mainHostMap.RLock() h, ok := hm.mainHostMap.Hosts[vpnIp] hm.mainHostMap.RUnlock() @@ -372,7 +381,7 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han } // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip -func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo { +func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { hm.Lock() if hh, ok := hm.vpnIps[vpnIp]; ok { @@ -388,8 +397,8 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han vpnIp: vpnIp, HandshakePacket: make(map[uint8][]byte, 0), relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, } @@ -555,7 +564,7 @@ func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { delete(c.vpnIps, hostinfo.vpnIp) if len(c.vpnIps) == 0 { - c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{} + c.vpnIps = map[netip.Addr]*HandshakeHostInfo{} } delete(c.indexes, hostinfo.localIndexId) @@ -570,7 +579,7 @@ func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { } } -func (hm *HandshakeManager) QueryVpnIp(vpnIp 
iputil.VpnIp) *HostInfo { +func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo { hh := hm.queryVpnIp(vpnIp) if hh != nil { return hh.hostinfo @@ -579,7 +588,7 @@ func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { } -func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo { +func (hm *HandshakeManager) queryVpnIp(vpnIp netip.Addr) *HandshakeHostInfo { hm.RLock() defer hm.RUnlock() return hm.vpnIps[vpnIp] @@ -599,7 +608,7 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo { return hm.indexes[index] } -func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet { +func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix { return c.mainHostMap.GetPreferredRanges() }
handshake_manager_test.go+9 −9 modified@@ -1,24 +1,24 @@ package nebula import ( - "net" + "net/netip" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" ) func Test_NewHandshakeManagerVpnIp(t *testing.T) { l := test.NewLogger() - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) - preferredRanges := []*net.IPNet{localrange} + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + ip := netip.MustParseAddr("172.1.1.2") + + preferredRanges := []netip.Prefix{localrange} mainHM := newHostMap(l, vpncidr) mainHM.preferredRanges.Store(&preferredRanges) @@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { assert.NotContains(t, blah.vpnIps, ip) } -func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) { +func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) { for _, i := range tw.t.wheel { n := i.Head for n != nil { @@ -80,7 +80,7 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) { type mockEncWriter struct { } -func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { +func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { return } @@ -92,4 +92,4 @@ func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M return } -func (mw *mockEncWriter) Handshake(vpnIP iputil.VpnIp) {} +func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}
hostmap.go+75 −71 modified@@ -3,18 +3,17 @@ package nebula import ( "errors" "net" + "net/netip" "sync" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) // const ProbeLen = 100 @@ -49,17 +48,17 @@ type Relay struct { State int LocalIndex uint32 RemoteIndex uint32 - PeerIp iputil.VpnIp + PeerIp netip.Addr } type HostMap struct { sync.RWMutex //Because we concurrently read and write to our maps Indexes map[uint32]*HostInfo Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object RemoteIndexes map[uint32]*HostInfo - Hosts map[iputil.VpnIp]*HostInfo - preferredRanges atomic.Pointer[[]*net.IPNet] - vpnCIDR *net.IPNet + Hosts map[netip.Addr]*HostInfo + preferredRanges atomic.Pointer[[]netip.Prefix] + vpnCIDR netip.Prefix l *logrus.Logger } @@ -69,12 +68,12 @@ type HostMap struct { type RelayState struct { sync.RWMutex - relays map[iputil.VpnIp]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer - relayForByIp map[iputil.VpnIp]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info - relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info + relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer + relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info + relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info } -func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) { +func (rs *RelayState) DeleteRelay(ip netip.Addr) { rs.Lock() defer rs.Unlock() delete(rs.relays, ip) @@ -90,33 +89,33 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay { return ret } -func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) 
(*Relay, bool) { +func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() r, ok := rs.relayForByIp[ip] return r, ok } -func (rs *RelayState) InsertRelayTo(ip iputil.VpnIp) { +func (rs *RelayState) InsertRelayTo(ip netip.Addr) { rs.Lock() defer rs.Unlock() rs.relays[ip] = struct{}{} } -func (rs *RelayState) CopyRelayIps() []iputil.VpnIp { +func (rs *RelayState) CopyRelayIps() []netip.Addr { rs.RLock() defer rs.RUnlock() - ret := make([]iputil.VpnIp, 0, len(rs.relays)) + ret := make([]netip.Addr, 0, len(rs.relays)) for ip := range rs.relays { ret = append(ret, ip) } return ret } -func (rs *RelayState) CopyRelayForIps() []iputil.VpnIp { +func (rs *RelayState) CopyRelayForIps() []netip.Addr { rs.RLock() defer rs.RUnlock() - currentRelays := make([]iputil.VpnIp, 0, len(rs.relayForByIp)) + currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp)) for relayIp := range rs.relayForByIp { currentRelays = append(currentRelays, relayIp) } @@ -133,19 +132,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 { return ret } -func (rs *RelayState) RemoveRelay(localIdx uint32) (iputil.VpnIp, bool) { - rs.Lock() - defer rs.Unlock() - r, ok := rs.relayForByIdx[localIdx] - if !ok { - return iputil.VpnIp(0), false - } - delete(rs.relayForByIdx, localIdx) - delete(rs.relayForByIp, r.PeerIp) - return r.PeerIp, true -} - -func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bool { +func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool { rs.Lock() defer rs.Unlock() r, ok := rs.relayForByIp[vpnIp] @@ -175,7 +162,7 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re return &newRelay, true } -func (rs *RelayState) QueryRelayForByIp(vpnIp iputil.VpnIp) (*Relay, bool) { +func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() r, ok := rs.relayForByIp[vpnIp] @@ -189,23 +176,23 @@ func (rs *RelayState) 
QueryRelayForByIdx(idx uint32) (*Relay, bool) { return r, ok } -func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) { +func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { rs.Lock() defer rs.Unlock() rs.relayForByIp[ip] = r rs.relayForByIdx[idx] = r } type HostInfo struct { - remote *udp.Addr + remote netip.AddrPort remotes *RemoteList promoteCounter atomic.Uint32 ConnectionState *ConnectionState remoteIndexId uint32 localIndexId uint32 - vpnIp iputil.VpnIp + vpnIp netip.Addr recvError atomic.Uint32 - remoteCidr *cidr.Tree4[struct{}] + remoteCidr *bart.Table[struct{}] relayState RelayState // HandshakePacket records the packets used to create this hostinfo @@ -227,7 +214,7 @@ type HostInfo struct { lastHandshakeTime uint64 lastRoam time.Time - lastRoamRemote *udp.Addr + lastRoamRemote netip.AddrPort // Used to track other hostinfos for this vpn ip since only 1 can be primary // Synchronised via hostmap lock and not the hostinfo lock. @@ -254,7 +241,7 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap { +func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap { hm := newHostMap(l, vpnCIDR) hm.reload(c, true) @@ -269,24 +256,24 @@ func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *Ho return hm } -func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap { +func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap { return &HostMap{ Indexes: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{}, RemoteIndexes: map[uint32]*HostInfo{}, - Hosts: map[iputil.VpnIp]*HostInfo{}, + Hosts: map[netip.Addr]*HostInfo{}, vpnCIDR: vpnCIDR, l: l, } } func (hm *HostMap) reload(c *config.C, initial bool) { if initial || c.HasChanged("preferred_ranges") { - var preferredRanges []*net.IPNet + var preferredRanges []netip.Prefix rawPreferredRanges := 
c.GetStringSlice("preferred_ranges", []string{}) for _, rawPreferredRange := range rawPreferredRanges { - _, preferredRange, err := net.ParseCIDR(rawPreferredRange) + preferredRange, err := netip.ParsePrefix(rawPreferredRange) if err != nil { hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring") @@ -378,7 +365,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it delete(hm.Hosts, hostinfo.vpnIp) if len(hm.Hosts) == 0 { - hm.Hosts = map[iputil.VpnIp]*HostInfo{} + hm.Hosts = map[netip.Addr]*HostInfo{} } if hostinfo.next != nil { @@ -461,11 +448,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo { } } -func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { +func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo { return hm.queryVpnIp(vpnIp, nil) } -func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) { +func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { hm.RLock() defer hm.RUnlock() @@ -483,7 +470,7 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host return nil, nil, errors.New("unable to find host with relay") } -func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo { +func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { hm.RUnlock() @@ -535,7 +522,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { } } -func (hm *HostMap) GetPreferredRanges() []*net.IPNet { +func (hm *HostMap) GetPreferredRanges() []netip.Prefix { //NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer return *hm.preferredRanges.Load() } @@ -560,23 +547,23 @@ func (hm *HostMap) ForEachIndex(f controlEach) { // 
TryPromoteBest handles re-querying lighthouses and probing for better paths // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! -func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { +func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interface) { c := i.promoteCounter.Add(1) if c%ifce.tryPromoteEvery.Load() == 0 { remote := i.remote // return early if we are already on a preferred remote - if remote != nil { - rIP := remote.IP + if remote.IsValid() { + rIP := remote.Addr() for _, l := range preferredRanges { if l.Contains(rIP) { return } } } - i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) { - if remote != nil && (addr == nil || !preferred) { + i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) { + if remote.IsValid() && (!addr.IsValid() || !preferred) { return } @@ -605,23 +592,23 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate { return nil } -func (i *HostInfo) SetRemote(remote *udp.Addr) { +func (i *HostInfo) SetRemote(remote netip.AddrPort) { // We copy here because we likely got this remote from a source that reuses the object - if !i.remote.Equals(remote) { - i.remote = remote.Copy() - i.remotes.LearnRemote(i.vpnIp, remote.Copy()) + if i.remote != remote { + i.remote = remote + i.remotes.LearnRemote(i.vpnIp, remote) } } // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam // time on the HostInfo will also be updated. 
-func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { - if newRemote == nil { +func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool { + if !newRemote.IsValid() { // relays have nil udp Addrs return false } currentRemote := i.remote - if currentRemote == nil { + if !currentRemote.IsValid() { i.SetRemote(newRemote) return true } @@ -631,19 +618,19 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { newIsPreferred := false for _, l := range hm.GetPreferredRanges() { // return early if we are already on a preferred remote - if l.Contains(currentRemote.IP) { + if l.Contains(currentRemote.Addr()) { return false } - if l.Contains(newRemote.IP) { + if l.Contains(newRemote.Addr()) { newIsPreferred = true } } if newIsPreferred { // Consider this a roaming event i.lastRoam = time.Now() - i.lastRoamRemote = currentRemote.Copy() + i.lastRoamRemote = currentRemote i.SetRemote(newRemote) @@ -666,13 +653,21 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { return } - remoteCidr := cidr.NewTree4[struct{}]() + remoteCidr := new(bart.Table[struct{}]) for _, ip := range c.Details.Ips { - remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) + //TODO: IPV6-WORK what to do when ip is invalid? + nip, _ := netip.AddrFromSlice(ip.IP) + nip = nip.Unmap() + bits, _ := ip.Mask.Size() + remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{}) } for _, n := range c.Details.Subnets { - remoteCidr.AddCIDR(n, struct{}{}) + //TODO: IPV6-WORK what to do when ip is invalid? 
+ nip, _ := netip.AddrFromSlice(n.IP) + nip = nip.Unmap() + bits, _ := n.Mask.Size() + remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{}) } i.remoteCidr = remoteCidr } @@ -697,9 +692,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { // Utility functions -func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP { +func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { //FIXME: This function is pretty garbage - var ips []net.IP + var ips []netip.Addr ifaces, _ := net.Interfaces() for _, i := range ifaces { allow := allowList.AllowName(i.Name) @@ -721,20 +716,29 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP { ip = v.IP } + nip, ok := netip.AddrFromSlice(ip) + if !ok { + if l.Level >= logrus.DebugLevel { + l.WithField("localIp", ip).Debug("ip was invalid for netip") + } + continue + } + nip = nip.Unmap() + //TODO: Filtering out link local for now, this is probably the most correct thing //TODO: Would be nice to filter out SLAAC MAC based ips as well - if ip.IsLoopback() == false && !ip.IsLinkLocalUnicast() { - allow := allowList.Allow(ip) + if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false { + allow := allowList.Allow(nip) if l.Level >= logrus.TraceLevel { - l.WithField("localIp", ip).WithField("allow", allow).Trace("localAllowList.Allow") + l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow") } if !allow { continue } - ips = append(ips, ip) + ips = append(ips, nip) } } } - return &ips + return ips }
hostmap_tester.go+4 −2 modified@@ -5,9 +5,11 @@ package nebula // This file contains functions used to export information to the e2e testing framework -import "github.com/slackhq/nebula/iputil" +import ( + "net/netip" +) -func (i *HostInfo) GetVpnIp() iputil.VpnIp { +func (i *HostInfo) GetVpnIp() netip.Addr { return i.vpnIp }
hostmap_test.go+25 −34 modified@@ -1,7 +1,7 @@ package nebula import ( - "net" + "net/netip" "testing" "github.com/slackhq/nebula/config" @@ -13,26 +13,23 @@ func TestHostMap_MakePrimary(t *testing.T) { l := test.NewLogger() hm := newHostMap( l, - &net.IPNet{ - IP: net.IP{10, 0, 0, 1}, - Mask: net.IPMask{255, 255, 255, 0}, - }, + netip.MustParsePrefix("10.0.0.1/24"), ) f := &Interface{} - h1 := &HostInfo{vpnIp: 1, localIndexId: 1} - h2 := &HostInfo{vpnIp: 1, localIndexId: 2} - h3 := &HostInfo{vpnIp: 1, localIndexId: 3} - h4 := &HostInfo{vpnIp: 1, localIndexId: 4} + h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} + h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} + h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} + h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} hm.unlockedAddHostInfo(h4, f) hm.unlockedAddHostInfo(h3, f) hm.unlockedAddHostInfo(h2, f) hm.unlockedAddHostInfo(h1, f) // Make sure we go h1 -> h2 -> h3 -> h4 - prim := hm.QueryVpnIp(1) + prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -47,7 +44,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h3) // Make sure we go h3 -> h1 -> h2 -> h4 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h3.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -62,7 +59,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -77,7 +74,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) 
// Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -93,20 +90,17 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { l := test.NewLogger() hm := newHostMap( l, - &net.IPNet{ - IP: net.IP{10, 0, 0, 1}, - Mask: net.IPMask{255, 255, 255, 0}, - }, + netip.MustParsePrefix("10.0.0.1/24"), ) f := &Interface{} - h1 := &HostInfo{vpnIp: 1, localIndexId: 1} - h2 := &HostInfo{vpnIp: 1, localIndexId: 2} - h3 := &HostInfo{vpnIp: 1, localIndexId: 3} - h4 := &HostInfo{vpnIp: 1, localIndexId: 4} - h5 := &HostInfo{vpnIp: 1, localIndexId: 5} - h6 := &HostInfo{vpnIp: 1, localIndexId: 6} + h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} + h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} + h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} + h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} + h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5} + h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6} hm.unlockedAddHostInfo(h6, f) hm.unlockedAddHostInfo(h5, f) @@ -122,7 +116,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h) // Make sure we go h1 -> h2 -> h3 -> h4 -> h5 - prim := hm.QueryVpnIp(1) + prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -141,7 +135,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h1.next) // Make sure we go h2 -> h3 -> h4 -> h5 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -159,7 +153,7 @@ 
func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h3.next) // Make sure we go h2 -> h4 -> h5 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -175,7 +169,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h5.next) // Make sure we go h2 -> h4 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -189,7 +183,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h2.next) // Make sure we only have h4 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Nil(t, prim.prev) assert.Nil(t, prim.next) @@ -201,7 +195,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h4.next) // Make sure we have nil - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Nil(t, prim) } @@ -211,14 +205,11 @@ func TestHostMap_reload(t *testing.T) { hm := NewHostMapFromConfig( l, - &net.IPNet{ - IP: net.IP{10, 0, 0, 1}, - Mask: net.IPMask{255, 255, 255, 0}, - }, + netip.MustParsePrefix("10.0.0.1/24"), c, ) - toS := func(ipn []*net.IPNet) []string { + toS := func(ipn []netip.Prefix) []string { var s []string for _, n := range ipn { s = append(s, n.String())
inside.go+20 −24 modified@@ -1,12 +1,13 @@ package nebula import ( + "net/netip" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" - "github.com/slackhq/nebula/udp" ) func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { @@ -19,11 +20,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } // Ignore local broadcast packets - if f.dropLocalBroadcast && fwPacket.RemoteIP == f.localBroadcast { + if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr { return } - if fwPacket.RemoteIP == f.myVpnIp { + if fwPacket.RemoteIP == f.myVpnNet.Addr() { // Immediately forward packets from self to self. // This should only happen on Darwin-based and FreeBSD hosts, which // routes packets from the Nebula IP to the Nebula IP through the Nebula @@ -39,8 +40,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - // Ignore broadcast packets - if f.dropMulticast && isMulticast(fwPacket.RemoteIP) { + // Ignore multicast packets + if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() { return } @@ -64,7 +65,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q) + f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) } else { f.rejectInside(packet, out, q) @@ -113,19 +114,19 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * return } - f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q) + f.sendNoMetrics(header.Message, 0, ci, hostinfo, 
netip.AddrPort{}, out, nb, packet, q) } -func (f *Interface) Handshake(vpnIp iputil.VpnIp) { +func (f *Interface) Handshake(vpnIp netip.Addr) { f.getOrHandshake(vpnIp, nil) } // getOrHandshake returns nil if the vpnIp is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel -func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) { +func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { + if !f.myVpnNet.Contains(vpnIp) { vpnIp = f.inside.RouteFor(vpnIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return nil, false } } @@ -152,11 +153,11 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp return } - f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0) + f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) } // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp -func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { +func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) @@ -182,10 +183,10 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) - f.sendNoMetrics(t, st, ci, hostinfo, nil, p, nb, out, 0) + f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0) } -func 
(f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) { +func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) } @@ -255,12 +256,12 @@ func (f *Interface) SendVia(via *HostInfo, f.connectionManager.RelayUsed(relay.LocalIndex) } -func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) { +func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) { if ci.eKey == nil { //TODO: log warning return } - useRelay := remote == nil && hostinfo.remote == nil + useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() fullOut := out if useRelay { @@ -308,13 +309,13 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType return } - if remote != nil { + if remote.IsValid() { err = f.writers[q].WriteTo(out, remote) if err != nil { hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).Error("Failed to write outgoing packet") } - } else if hostinfo.remote != nil { + } else if hostinfo.remote.IsValid() { err = f.writers[q].WriteTo(out, hostinfo.remote) if err != nil { hostinfo.logger(f.l).WithError(err). @@ -334,8 +335,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } } } - -func isMulticast(ip iputil.VpnIp) bool { - // Class D multicast - return (((ip >> 24) & 0xff) & 0xf0) == 0xe0 -}
interface.go+36 −11 modified@@ -2,10 +2,11 @@ package nebula import ( "context" + "encoding/binary" "errors" "fmt" "io" - "net" + "net/netip" "os" "runtime" "sync/atomic" @@ -16,7 +17,6 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) @@ -63,8 +63,8 @@ type Interface struct { serveDns bool createTime time.Time lightHouse *LightHouse - localBroadcast iputil.VpnIp - myVpnIp iputil.VpnIp + myBroadcastAddr netip.Addr + myVpnNet netip.Prefix dropLocalBroadcast bool dropMulticast bool routines int @@ -102,9 +102,9 @@ type EncWriter interface { out []byte, nocopy bool, ) - SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) + SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) - Handshake(vpnIp iputil.VpnIp) + Handshake(vpnIp netip.Addr) } type sendRecvErrorConfig uint8 @@ -115,10 +115,10 @@ const ( sendRecvErrorPrivate ) -func (s sendRecvErrorConfig) ShouldSendRecvError(ip net.IP) bool { +func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool { switch s { case sendRecvErrorPrivate: - return ip.IsPrivate() + return ip.Addr().IsPrivate() case sendRecvErrorAlways: return true case sendRecvErrorNever: @@ -156,7 +156,27 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { } certificate := c.pki.GetCertState().Certificate - myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP) + + myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP) + if !ok { + return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP) + } + + myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask) + if !ok { 
+ return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask) + } + + myVpnAddr = myVpnAddr.Unmap() + myVpnMask = myVpnMask.Unmap() + + if myVpnAddr.BitLen() != myVpnMask.BitLen() { + return nil, fmt.Errorf("ip address and mask are different lengths in certificate") + } + + ones, _ := certificate.Details.Ips[0].Mask.Size() + myVpnNet := netip.PrefixFrom(myVpnAddr, ones) + ifce := &Interface{ pki: c.pki, hostMap: c.HostMap, @@ -168,14 +188,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { handshakeManager: c.HandshakeManager, createTime: time.Now(), lightHouse: c.lightHouse, - localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask), dropLocalBroadcast: c.DropLocalBroadcast, dropMulticast: c.DropMulticast, routines: c.routines, version: c.version, writers: make([]udp.Conn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), - myVpnIp: myVpnIp, + myVpnNet: myVpnNet, relayManager: c.relayManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -190,6 +209,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } + if myVpnAddr.Is4() { + addr := myVpnNet.Masked().Addr().As4() + binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask)) + ifce.myBroadcastAddr = netip.AddrFrom4(addr) + } + ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait))
iputil/packet.go+2 −0 modified@@ -6,6 +6,8 @@ import ( "golang.org/x/net/ipv4" ) +//TODO: IPV6-WORK can probably delete this + const ( // Need 96 bytes for the largest reject packet: // - 20 byte ipv4 header
iputil/util.go+0 −93 removed@@ -1,93 +0,0 @@ -package iputil - -import ( - "encoding/binary" - "fmt" - "net" - "net/netip" -) - -type VpnIp uint32 - -const maxIPv4StringLen = len("255.255.255.255") - -func (ip VpnIp) String() string { - b := make([]byte, maxIPv4StringLen) - - n := ubtoa(b, 0, byte(ip>>24)) - b[n] = '.' - n++ - - n += ubtoa(b, n, byte(ip>>16&255)) - b[n] = '.' - n++ - - n += ubtoa(b, n, byte(ip>>8&255)) - b[n] = '.' - n++ - - n += ubtoa(b, n, byte(ip&255)) - return string(b[:n]) -} - -func (ip VpnIp) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil -} - -func (ip VpnIp) ToIP() net.IP { - nip := make(net.IP, 4) - binary.BigEndian.PutUint32(nip, uint32(ip)) - return nip -} - -func (ip VpnIp) ToNetIpAddr() netip.Addr { - var nip [4]byte - binary.BigEndian.PutUint32(nip[:], uint32(ip)) - return netip.AddrFrom4(nip) -} - -func Ip2VpnIp(ip []byte) VpnIp { - if len(ip) == 16 { - return VpnIp(binary.BigEndian.Uint32(ip[12:16])) - } - return VpnIp(binary.BigEndian.Uint32(ip)) -} - -func ToNetIpAddr(ip net.IP) (netip.Addr, error) { - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return netip.Addr{}, fmt.Errorf("invalid net.IP: %v", ip) - } - return addr, nil -} - -func ToNetIpPrefix(ipNet net.IPNet) (netip.Prefix, error) { - addr, err := ToNetIpAddr(ipNet.IP) - if err != nil { - return netip.Prefix{}, err - } - ones, bits := ipNet.Mask.Size() - if ones == 0 && bits == 0 { - return netip.Prefix{}, fmt.Errorf("invalid net.IP: %v", ipNet) - } - return netip.PrefixFrom(addr, ones), nil -} - -// ubtoa encodes the string form of the integer v to dst[start:] and -// returns the number of bytes written to dst. The caller must ensure -// that dst has sufficient length. 
-func ubtoa(dst []byte, start int, v byte) int { - if v < 10 { - dst[start] = v + '0' - return 1 - } else if v < 100 { - dst[start+1] = v%10 + '0' - dst[start] = v/10 + '0' - return 2 - } - - dst[start+2] = v%10 + '0' - dst[start+1] = (v/10)%10 + '0' - dst[start] = v/100 + '0' - return 3 -}
iputil/util_test.go+0 −17 removed@@ -1,17 +0,0 @@ -package iputil - -import ( - "net" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestVpnIp_String(t *testing.T) { - assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String()) - assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String()) - assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String()) - assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String()) - assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String()) - assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String()) -}
lighthouse.go+212 −188 modified@@ -7,16 +7,16 @@ import ( "fmt" "net" "net/netip" + "strconv" "sync" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" ) @@ -26,25 +26,18 @@ import ( var ErrHostNotKnown = errors.New("host not known") -type netIpAndPort struct { - ip net.IP - port uint16 -} - type LightHouse struct { //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time sync.RWMutex //Because we concurrently read and write to our maps ctx context.Context amLighthouse bool - myVpnIp iputil.VpnIp - myVpnZeros iputil.VpnIp - myVpnNet *net.IPNet + myVpnNet netip.Prefix punchConn udp.Conn punchy *Punchy // Local cache of answers from light houses // map of vpn Ip to answers - addrMap map[iputil.VpnIp]*RemoteList + addrMap map[netip.Addr]*RemoteList // filters remote addresses allowed for each host // - When we are a lighthouse, this filters what addresses we store and @@ -57,26 +50,26 @@ type LightHouse struct { localAllowList atomic.Pointer[LocalAllowList] // used to trigger the HandshakeManager when we receive HostQueryReply - handshakeTrigger chan<- iputil.VpnIp + handshakeTrigger chan<- netip.Addr // staticList exists to avoid having a bool in each addrMap entry // since static should be rare - staticList atomic.Pointer[map[iputil.VpnIp]struct{}] - lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}] + staticList atomic.Pointer[map[netip.Addr]struct{}] + lighthouses atomic.Pointer[map[netip.Addr]struct{}] interval atomic.Int64 updateCancel context.CancelFunc ifce EncWriter nebulaPort uint32 // 32 bits because protobuf does not have a uint16 - advertiseAddrs atomic.Pointer[[]netIpAndPort] + advertiseAddrs atomic.Pointer[[]netip.AddrPort] // IP's of relays that can be used 
by peers to access me - relaysForMe atomic.Pointer[[]iputil.VpnIp] + relaysForMe atomic.Pointer[[]netip.Addr] - queryChan chan iputil.VpnIp + queryChan chan netip.Addr - calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote + calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote metrics *MessageMetrics metricHolepunchTx metrics.Counter @@ -85,7 +78,7 @@ type LightHouse struct { // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet netip.Prefix, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -98,26 +91,23 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, if err != nil { return nil, util.NewContextualError("Failed to get listening port", nil, err) } - nebulaPort = uint32(uPort.Port) + nebulaPort = uint32(uPort.Port()) } - ones, _ := myVpnNet.Mask.Size() h := LightHouse{ ctx: ctx, amLighthouse: amLighthouse, - myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP), - myVpnZeros: iputil.VpnIp(32 - ones), myVpnNet: myVpnNet, - addrMap: make(map[iputil.VpnIp]*RemoteList), + addrMap: make(map[netip.Addr]*RemoteList), nebulaPort: nebulaPort, punchConn: pc, punchy: p, - queryChan: make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)), + queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), l: l, } - lighthouses := make(map[iputil.VpnIp]struct{}) + lighthouses := make(map[netip.Addr]struct{}) h.lighthouses.Store(&lighthouses) - 
staticList := make(map[iputil.VpnIp]struct{}) + staticList := make(map[netip.Addr]struct{}) h.staticList.Store(&staticList) if c.GetBool("stats.lighthouse_metrics", false) { @@ -147,11 +137,11 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, return &h, nil } -func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} { +func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} { return *lh.staticList.Load() } -func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} { +func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} { return *lh.lighthouses.Load() } @@ -163,15 +153,15 @@ func (lh *LightHouse) GetLocalAllowList() *LocalAllowList { return lh.localAllowList.Load() } -func (lh *LightHouse) GetAdvertiseAddrs() []netIpAndPort { +func (lh *LightHouse) GetAdvertiseAddrs() []netip.AddrPort { return *lh.advertiseAddrs.Load() } -func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp { +func (lh *LightHouse) GetRelaysForMe() []netip.Addr { return *lh.relaysForMe.Load() } -func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] { +func (lh *LightHouse) getCalculatedRemotes() *bart.Table[[]*calculatedRemote] { return lh.calculatedRemotes.Load() } @@ -182,25 +172,40 @@ func (lh *LightHouse) GetUpdateInterval() int64 { func (lh *LightHouse) reload(c *config.C, initial bool) error { if initial || c.HasChanged("lighthouse.advertise_addrs") { rawAdvAddrs := c.GetStringSlice("lighthouse.advertise_addrs", []string{}) - advAddrs := make([]netIpAndPort, 0) + advAddrs := make([]netip.AddrPort, 0) for i, rawAddr := range rawAdvAddrs { - fIp, fPort, err := udp.ParseIPAndPort(rawAddr) + host, sport, err := net.SplitHostPort(rawAddr) if err != nil { return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) } - if fPort == 0 { - fPort = uint16(lh.nebulaPort) + ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", 
host) + if err != nil { + return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) + } + if len(ips) == 0 { + return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, nil) + } + + port, err := strconv.Atoi(sport) + if err != nil { + return util.NewContextualError("Unable to parse port in lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) + } + + if port == 0 { + port = int(lh.nebulaPort) } - if ip4 := fIp.To4(); ip4 != nil && lh.myVpnNet.Contains(fIp) { + //TODO: we could technically insert all returned ips instead of just the first one if a dns lookup was used + ip := ips[0].Unmap() + if lh.myVpnNet.Contains(ip) { lh.l.WithField("addr", rawAddr).WithField("entry", i+1). Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") continue } - advAddrs = append(advAddrs, netIpAndPort{ip: fIp, port: fPort}) + advAddrs = append(advAddrs, netip.AddrPortFrom(ip, uint16(port))) } lh.advertiseAddrs.Store(&advAddrs) @@ -278,8 +283,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.RUnlock() } // Build a new list based on current config. 
- staticList := make(map[iputil.VpnIp]struct{}) - err := lh.loadStaticMap(c, lh.myVpnNet, staticList) + staticList := make(map[netip.Addr]struct{}) + err := lh.loadStaticMap(c, staticList) if err != nil { return err } @@ -303,8 +308,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } if initial || c.HasChanged("lighthouse.hosts") { - lhMap := make(map[iputil.VpnIp]struct{}) - err := lh.parseLighthouses(c, lh.myVpnNet, lhMap) + lhMap := make(map[netip.Addr]struct{}) + err := lh.parseLighthouses(c, lhMap) if err != nil { return err } @@ -323,16 +328,17 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { if len(c.GetStringSlice("relay.relays", nil)) > 0 { lh.l.Info("Ignoring relays from config because am_relay is true") } - relaysForMe := []iputil.VpnIp{} + relaysForMe := []netip.Addr{} lh.relaysForMe.Store(&relaysForMe) case false: - relaysForMe := []iputil.VpnIp{} + relaysForMe := []netip.Addr{} for _, v := range c.GetStringSlice("relay.relays", nil) { lh.l.WithField("relay", v).Info("Read relay from config") - configRIP := net.ParseIP(v) - if configRIP != nil { - relaysForMe = append(relaysForMe, iputil.Ip2VpnIp(configRIP)) + configRIP, err := netip.ParseAddr(v) + //TODO: We could print the error here + if err == nil { + relaysForMe = append(relaysForMe, configRIP) } } lh.relaysForMe.Store(&relaysForMe) @@ -342,21 +348,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return nil } -func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap map[iputil.VpnIp]struct{}) error { +func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error { lhs := c.GetStringSlice("lighthouse.hosts", []string{}) if lh.amLighthouse && len(lhs) != 0 { lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") } for i, host := range lhs { - ip := net.ParseIP(host) - if ip == nil { - return util.NewContextualError("Unable to parse lighthouse host entry", 
m{"host": host, "entry": i + 1}, nil) + ip, err := netip.ParseAddr(host) + if err != nil { + return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) } - if !tunCidr.Contains(ip) { - return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) + if !lh.myVpnNet.Contains(ip) { + return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": lh.myVpnNet}, nil) } - lhMap[iputil.Ip2VpnIp(ip)] = struct{}{} + lhMap[ip] = struct{}{} } if !lh.amLighthouse && len(lhMap) == 0 { @@ -399,7 +405,7 @@ func getStaticMapNetwork(c *config.C) (string, error) { return network, nil } -func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error { +func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struct{}) error { d, err := getStaticMapCadence(c) if err != nil { return err @@ -410,7 +416,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return err } - lookup_timeout, err := getStaticMapLookupTimeout(c) + lookupTimeout, err := getStaticMapLookupTimeout(c) if err != nil { return err } @@ -419,16 +425,15 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList i := 0 for k, v := range shm { - rip := net.ParseIP(fmt.Sprintf("%v", k)) - if rip == nil { - return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, nil) + vpnIp, err := netip.ParseAddr(fmt.Sprintf("%v", k)) + if err != nil { + return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err) } - if !tunCidr.Contains(rip) { - return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": rip, "network": tunCidr.String(), "entry": i + 1}, nil) + if !lh.myVpnNet.Contains(vpnIp) { + return 
util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": lh.myVpnNet, "entry": i + 1}, nil) } - vpnIp := iputil.Ip2VpnIp(rip) vals, ok := v.([]interface{}) if !ok { vals = []interface{}{v} @@ -438,7 +443,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v)) } - err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList) + err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnIp, remoteAddrs, staticList) if err != nil { return err } @@ -448,7 +453,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return nil } -func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList { +func (lh *LightHouse) Query(ip netip.Addr) *RemoteList { if !lh.IsLighthouseIP(ip) { lh.QueryServer(ip) } @@ -462,7 +467,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList { } // QueryServer is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip iputil.VpnIp) { +func (lh *LightHouse) QueryServer(ip netip.Addr) { // Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses if lh.amLighthouse || lh.IsLighthouseIP(ip) { return @@ -471,7 +476,7 @@ func (lh *LightHouse) QueryServer(ip iputil.VpnIp) { lh.queryChan <- ip } -func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList { +func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList { lh.RLock() if v, ok := lh.addrMap[ip]; ok { lh.RUnlock() @@ -488,7 +493,7 @@ func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList { // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing // details from the remote list. 
It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo() -func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) { +func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, error)) (bool, int, error) { lh.RLock() // Do we have an entry in the main cache? if v, ok := lh.addrMap[vpnIp]; ok { @@ -511,7 +516,7 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (in return false, 0, nil } -func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { +func (lh *LightHouse) DeleteVpnIp(vpnIp netip.Addr) { // First we check the static mapping // and do nothing if it is there if _, ok := lh.GetStaticHostList()[vpnIp]; ok { @@ -532,7 +537,7 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it -func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error { +func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error { lh.Lock() am := lh.unlockedGetRemoteList(vpnIp) am.Lock() @@ -553,20 +558,14 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t am.unlockedSetHostnamesResults(hr) for _, addrPort := range hr.GetIPs() { - + if !lh.shouldAdd(vpnIp, addrPort.Addr()) { + continue + } switch { case addrPort.Addr().Is4(): - to := 
NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) - if !lh.unlockedShouldAddV4(vpnIp, to) { - continue - } - am.unlockedPrependV4(lh.myVpnIp, to) + am.unlockedPrependV4(lh.myVpnNet.Addr(), NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) case addrPort.Addr().Is6(): - to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) - if !lh.unlockedShouldAddV6(vpnIp, to) { - continue - } - am.unlockedPrependV6(lh.myVpnIp, to) + am.unlockedPrependV6(lh.myVpnNet.Addr(), NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) } } @@ -578,12 +577,12 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t // addCalculatedRemotes adds any calculated remotes based on the // lighthouse.calculated_remotes configuration. It returns true if any // calculated remotes were added -func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { +func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool { tree := lh.getCalculatedRemotes() if tree == nil { return false } - ok, calculatedRemotes := tree.MostSpecificContains(vpnIp) + calculatedRemotes, ok := tree.Lookup(vpnIp) if !ok { return false } @@ -602,13 +601,13 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { defer am.Unlock() lh.Unlock() - am.unlockedSetV4(lh.myVpnIp, vpnIp, calculated, lh.unlockedShouldAddV4) + am.unlockedSetV4(lh.myVpnNet.Addr(), vpnIp, calculated, lh.unlockedShouldAddV4) return len(calculated) > 0 } // unlockedGetRemoteList assumes you have the lh lock -func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { +func (lh *LightHouse) unlockedGetRemoteList(vpnIp netip.Addr) *RemoteList { am, ok := lh.addrMap[vpnIp] if !ok { am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) }) @@ -617,59 +616,42 @@ func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { return am } -func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool { - switch { - case 
to.Is4(): - ipBytes := to.As4() - ip := iputil.Ip2VpnIp(ipBytes[:]) - allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") - } - if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) { - return false - } - case to.Is6(): - ipBytes := to.As16() - - hi := binary.BigEndian.Uint64(ipBytes[:8]) - lo := binary.BigEndian.Uint64(ipBytes[8:]) - allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow") - } - - // We don't check our vpn network here because nebula does not support ipv6 on the inside - if !allow { - return false - } +func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool { + allow := lh.GetRemoteAllowList().Allow(vpnIp, to) + if lh.l.Level >= logrus.TraceLevel { + lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") + } + if !allow || lh.myVpnNet.Contains(to) { + return false } + return true } // unlockedShouldAddV4 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool { - allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip)) +func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool { + ip := AddrPortFromIp4AndPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") } - if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) { + if !allow || lh.myVpnNet.Contains(ip.Addr()) { return false } return true } // unlockedShouldAddV6 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool { - allow := 
lh.GetRemoteAllowList().AllowIpV6(vpnIp, to.Hi, to.Lo) +func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *Ip6AndPort) bool { + ip := AddrPortFromIp6AndPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") } - // We don't check our vpn network here because nebula does not support ipv6 on the inside - if !allow { + if !allow || lh.myVpnNet.Contains(ip.Addr()) { return false } @@ -683,26 +665,39 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP { return ip } -func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool { +func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool { if _, ok := lh.GetLighthouses()[vpnIp]; ok { return true } return false } -func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta { +func NewLhQueryByInt(vpnIp netip.Addr) *NebulaMeta { + if vpnIp.Is6() { + //TODO: need to support ipv6 + panic("ipv6 is not yet supported") + } + + b := vpnIp.As4() return &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: uint32(VpnIp), + VpnIp: binary.BigEndian.Uint32(b[:]), }, } } -func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort { - ipp := Ip4AndPort{Port: port} - ipp.Ip = uint32(iputil.Ip2VpnIp(ip)) - return &ipp +func AddrPortFromIp4AndPort(ip *Ip4AndPort) netip.AddrPort { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], ip.Ip) + return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ip.Port)) +} + +func AddrPortFromIp6AndPort(ip *Ip6AndPort) netip.AddrPort { + b := [16]byte{} + binary.BigEndian.PutUint64(b[:8], ip.Hi) + binary.BigEndian.PutUint64(b[8:], ip.Lo) + return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ip.Port)) } func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { @@ -713,14 +708,7 @@ func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { } } -func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort { - return &Ip6AndPort{ 
- Hi: binary.BigEndian.Uint64(ip[:8]), - Lo: binary.BigEndian.Uint64(ip[8:]), - Port: port, - } -} - +// TODO: IPV6-WORK we can delete some more of these func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { ip6Addr := ip.As16() return &Ip6AndPort{ @@ -729,17 +717,6 @@ func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { Port: uint32(port), } } -func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr { - ip := ipp.Ip - return udp.NewAddr( - net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)), - uint16(ipp.Port), - ) -} - -func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr { - return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) -} func (lh *LightHouse) startQueryWorker() { if lh.amLighthouse { @@ -761,7 +738,7 @@ func (lh *LightHouse) startQueryWorker() { }() } -func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) { +func (lh *LightHouse) innerQueryServer(ip netip.Addr, nb, out []byte) { if lh.IsLighthouseIP(ip) { return } @@ -812,36 +789,41 @@ func (lh *LightHouse) SendUpdate() { var v6 []*Ip6AndPort for _, e := range lh.GetAdvertiseAddrs() { - if ip := e.ip.To4(); ip != nil { - v4 = append(v4, NewIp4AndPort(e.ip, uint32(e.port))) + if e.Addr().Is4() { + v4 = append(v4, NewIp4AndPortFromNetIP(e.Addr(), e.Port())) } else { - v6 = append(v6, NewIp6AndPort(e.ip, uint32(e.port))) + v6 = append(v6, NewIp6AndPortFromNetIP(e.Addr(), e.Port())) } } lal := lh.GetLocalAllowList() - for _, e := range *localIps(lh.l, lal) { - if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) { + for _, e := range localIps(lh.l, lal) { + if lh.myVpnNet.Contains(e) { continue } // Only add IPs that aren't my VPN/tun IP - if ip := e.To4(); ip != nil { - v4 = append(v4, NewIp4AndPort(e, lh.nebulaPort)) + if e.Is4() { + v4 = append(v4, NewIp4AndPortFromNetIP(e, uint16(lh.nebulaPort))) } else { - v6 = append(v6, NewIp6AndPort(e, lh.nebulaPort)) + v6 = 
append(v6, NewIp6AndPortFromNetIP(e, uint16(lh.nebulaPort))) } } var relays []uint32 for _, r := range lh.GetRelaysForMe() { - relays = append(relays, (uint32)(r)) + //TODO: IPV6-WORK both relays and vpnip need ipv6 support + b := r.As4() + relays = append(relays, binary.BigEndian.Uint32(b[:])) } + //TODO: IPV6-WORK both relays and vpnip need ipv6 support + b := lh.myVpnNet.Addr().As4() + m := &NebulaMeta{ Type: NebulaMeta_HostUpdateNotification, Details: &NebulaMetaDetails{ - VpnIp: uint32(lh.myVpnIp), + VpnIp: binary.BigEndian.Uint32(b[:]), Ip4AndPorts: v4, Ip6AndPorts: v6, RelayVpnIp: relays, @@ -913,12 +895,12 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { } func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc { - return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) { + return func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) { lhh.HandleRequest(rAddr, vpnIp, p, f) } } -func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) { +func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { @@ -956,7 +938,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -967,8 +949,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, //TODO: we can DRY this further reqVpnIp := n.Details.VpnIp + + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + queryVpnIp := netip.AddrFrom4(b) + //TODO: Maybe instead of marshalling into n we marshal into 
a new `r` to not nuke our current request data - found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) { + found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnIp, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostQueryReply n.Details.VpnIp = reqVpnIp @@ -994,8 +982,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification - n.Details.VpnIp = uint32(vpnIp) - + //TODO: IPV6-WORK + b = vpnIp.As4() + n.Details.VpnIp = binary.BigEndian.Uint32(b[:]) lhh.coalesceAnswers(c, n) return n.MarshalTo(lhh.pb) @@ -1011,7 +1000,11 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, } lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0]) + + //TODO: IPV6-WORK + binary.BigEndian.PutUint32(b[:], reqVpnIp) + sendTo := netip.AddrFrom4(b) + w.SendMessageToVpnIp(header.LightHouse, 0, sendTo, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { @@ -1034,34 +1027,52 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { } if c.relay != nil { - n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, c.relay.relay...) + //TODO: IPV6-WORK + relays := make([]uint32, len(c.relay.relay)) + b := [4]byte{} + for i, _ := range relays { + b = c.relay.relay[i].As4() + relays[i] = binary.BigEndian.Uint32(b[:]) + } + n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, relays...) 
} } -func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) { +func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp netip.Addr) { if !lhh.lh.IsLighthouseIP(vpnIp) { return } lhh.lh.Lock() - am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp)) + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + certVpnIp := netip.AddrFrom4(b) + am := lhh.lh.unlockedGetRemoteList(certVpnIp) am.Lock() lhh.lh.Unlock() - certVpnIp := iputil.VpnIp(n.Details.VpnIp) + //TODO: IPV6-WORK am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp) + + //TODO: IPV6-WORK + relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) + for i, _ := range n.Details.RelayVpnIp { + binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) + relays[i] = netip.AddrFrom4(b) + } + am.unlockedSetRelay(vpnIp, certVpnIp, relays) am.Unlock() // Non-blocking attempt to trigger, skip if it would block select { - case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp): + case lhh.lh.handshakeTrigger <- certVpnIp: default: } } -func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { +func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp) @@ -1070,9 +1081,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp } //Simple check that the host sent this not someone else - if n.Details.VpnIp != uint32(vpnIp) { + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + detailsVpnIp := netip.AddrFrom4(b) + if detailsVpnIp != vpnIp { if lhh.l.Level 
>= logrus.DebugLevel { - lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update") + lhh.l.WithField("vpnIp", vpnIp).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") } return } @@ -1082,15 +1097,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp am.Lock() lhh.lh.Unlock() - certVpnIp := iputil.VpnIp(n.Details.VpnIp) - am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp) + am.unlockedSetV4(vpnIp, detailsVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(vpnIp, detailsVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) + + //TODO: IPV6-WORK + relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) + for i, _ := range n.Details.RelayVpnIp { + binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) + relays[i] = netip.AddrFrom4(b) + } + am.unlockedSetRelay(vpnIp, detailsVpnIp, relays) am.Unlock() n = lhh.resetMeta() n.Type = NebulaMeta_HostUpdateNotificationAck - n.Details.VpnIp = uint32(vpnIp) + + //TODO: IPV6-WORK + vpnIpB := vpnIp.As4() + n.Details.VpnIp = binary.BigEndian.Uint32(vpnIpB[:]) ln, err := n.MarshalTo(lhh.pb) if err != nil { @@ -1102,14 +1126,14 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { if !lhh.lh.IsLighthouseIP(vpnIp) { return } empty := []byte{0} - punch := func(vpnPeer *udp.Addr) { - if vpnPeer == nil { + punch := func(vpnPeer netip.AddrPort) { + if !vpnPeer.IsValid() { return } @@ 
-1121,23 +1145,29 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i if lhh.l.Level >= logrus.DebugLevel { //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) - lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp)) + //TODO: IPV6-WORK, make this debug line not suck + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + lhh.l.Debugf("Punching on %d for %v", vpnPeer.Port(), netip.AddrFrom4(b)) } } for _, a := range n.Details.Ip4AndPorts { - punch(NewUDPAddrFromLH4(a)) + punch(AddrPortFromIp4AndPort(a)) } for _, a := range n.Details.Ip6AndPorts { - punch(NewUDPAddrFromLH6(a)) + punch(AddrPortFromIp6AndPort(a)) } // This sends a nebula test packet to the host trying to contact us. In the case // of a double nat or other difficult scenario, this may help establish // a tunnel. if lhh.lh.punchy.GetRespond() { - queryVpnIp := iputil.VpnIp(n.Details.VpnIp) + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + queryVpnIp := netip.AddrFrom4(b) go func() { time.Sleep(lhh.lh.punchy.GetRespondDelay()) if lhh.l.Level >= logrus.DebugLevel { @@ -1150,9 +1180,3 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i }() } } - -// ipMaskContains checks if testIp is contained by ip after applying a cidr. -// zeros is 32 - bits from net.IPMask.Size() -func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool { - return (testIp^ip)>>zeros == 0 -}
lighthouse_test.go+85 −102 modified@@ -2,15 +2,14 @@ package nebula import ( "context" + "encoding/binary" "fmt" - "net" + "net/netip" "testing" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" - "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v2" ) @@ -23,15 +22,17 @@ func TestOldIPv4Only(t *testing.T) { var m Ip4AndPort err := m.Unmarshal(b) assert.NoError(t, err) - assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String()) + ip := netip.MustParseAddr("10.1.1.1") + bp := ip.As4() + assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp()) } func TestNewLhQuery(t *testing.T) { - myIp := net.ParseIP("192.1.1.1") - myIpint := iputil.Ip2VpnIp(myIp) + myIp, err := netip.ParseAddr("192.1.1.1") + assert.NoError(t, err) // Generating a new lh query should work - a := NewLhQueryByInt(myIpint) + a := NewLhQueryByInt(myIp) // The result should be a nebulameta protobuf assert.IsType(t, &NebulaMeta{}, a) @@ -49,7 +50,7 @@ func TestNewLhQuery(t *testing.T) { func Test_lhStaticMapping(t *testing.T) { l := test.NewLogger() - _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16") + myVpnNet := netip.MustParsePrefix("10.128.0.1/16") lh1 := "10.128.0.2" c := config.NewC(l) @@ -68,7 +69,7 @@ func Test_lhStaticMapping(t *testing.T) { func TestReloadLighthouseInterval(t *testing.T) { l := test.NewLogger() - _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16") + myVpnNet := netip.MustParsePrefix("10.128.0.1/16") lh1 := "10.128.0.2" c := config.NewC(l) @@ -83,52 +84,55 @@ func TestReloadLighthouseInterval(t *testing.T) { lh.ifce = &mockEncWriter{} // The first one routine is kicked off by main.go currently, lets make sure that one dies - c.ReloadConfigString("lighthouse:\n interval: 5") + assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5")) assert.Equal(t, int64(5), lh.interval.Load()) // Subsequent calls are killed off by the LightHouse.Reload 
function - c.ReloadConfigString("lighthouse:\n interval: 10") + assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10")) assert.Equal(t, int64(10), lh.interval.Load()) // If this completes then nothing is stealing our reload routine - c.ReloadConfigString("lighthouse:\n interval: 11") + assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11")) assert.Equal(t, int64(11), lh.interval.Load()) } func BenchmarkLighthouseHandleRequest(b *testing.B) { l := test.NewLogger() - _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0") + myVpnNet := netip.MustParsePrefix("10.128.0.1/0") c := config.NewC(l) lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) if !assert.NoError(b, err) { b.Fatal() } - hAddr := udp.NewAddrFromString("4.5.6.7:12345") - hAddr2 := udp.NewAddrFromString("4.5.6.7:12346") - lh.addrMap[3] = NewRemoteList(nil) - lh.addrMap[3].unlockedSetV4( - 3, - 3, + hAddr := netip.MustParseAddrPort("4.5.6.7:12345") + hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") + + vpnIp3 := netip.MustParseAddr("0.0.0.3") + lh.addrMap[vpnIp3] = NewRemoteList(nil) + lh.addrMap[vpnIp3].unlockedSetV4( + vpnIp3, + vpnIp3, []*Ip4AndPort{ - NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)), - NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)), + NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()), + NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()), }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) - rAddr := udp.NewAddrFromString("1.2.2.3:12345") - rAddr2 := udp.NewAddrFromString("1.2.2.3:12346") - lh.addrMap[2] = NewRemoteList(nil) - lh.addrMap[2].unlockedSetV4( - 3, - 3, + rAddr := netip.MustParseAddrPort("1.2.2.3:12345") + rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346") + vpnIp2 := netip.MustParseAddr("0.0.0.3") + lh.addrMap[vpnIp2] = NewRemoteList(nil) + lh.addrMap[vpnIp2].unlockedSetV4( + vpnIp3, + vpnIp3, []*Ip4AndPort{ - NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)), - 
NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)), + NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()), + NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()), }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) mw := &mockEncWriter{} @@ -145,7 +149,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { p, err := req.Marshal() assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, 2, p, mw) + lhh.HandleRequest(rAddr, vpnIp2, p, mw) } }) b.Run("found", func(b *testing.B) { @@ -161,59 +165,59 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, 2, p, mw) + lhh.HandleRequest(rAddr, vpnIp2, p, mw) } }) } func TestLighthouse_Memory(t *testing.T) { l := test.NewLogger() - myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242} - myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242} - myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242} - myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242} - myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242} - myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243} - myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244} - myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245} - myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246} - myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247} - myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248} - myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249} - myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2")) - - theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242} - theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242} - theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242} - theirUdpAddr3 := &udp.Addr{IP: 
net.ParseIP("100.152.0.3"), Port: 4242} - theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242} - theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3")) + myUdpAddr0 := netip.MustParseAddrPort("10.0.0.2:4242") + myUdpAddr1 := netip.MustParseAddrPort("192.168.0.2:4242") + myUdpAddr2 := netip.MustParseAddrPort("172.16.0.2:4242") + myUdpAddr3 := netip.MustParseAddrPort("100.152.0.2:4242") + myUdpAddr4 := netip.MustParseAddrPort("24.15.0.2:4242") + myUdpAddr5 := netip.MustParseAddrPort("192.168.0.2:4243") + myUdpAddr6 := netip.MustParseAddrPort("192.168.0.2:4244") + myUdpAddr7 := netip.MustParseAddrPort("192.168.0.2:4245") + myUdpAddr8 := netip.MustParseAddrPort("192.168.0.2:4246") + myUdpAddr9 := netip.MustParseAddrPort("192.168.0.2:4247") + myUdpAddr10 := netip.MustParseAddrPort("192.168.0.2:4248") + myUdpAddr11 := netip.MustParseAddrPort("192.168.0.2:4249") + myVpnIp := netip.MustParseAddr("10.128.0.2") + + theirUdpAddr0 := netip.MustParseAddrPort("10.0.0.3:4242") + theirUdpAddr1 := netip.MustParseAddrPort("192.168.0.3:4242") + theirUdpAddr2 := netip.MustParseAddrPort("172.16.0.3:4242") + theirUdpAddr3 := netip.MustParseAddrPort("100.152.0.3:4242") + theirUdpAddr4 := netip.MustParseAddrPort("24.15.0.3:4242") + theirVpnIp := netip.MustParseAddr("10.128.0.3") c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) assert.NoError(t, err) lhh := lh.NewRequestHandler() // Test that my first update responds with just that - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh) r := 
newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2) // Ensure we don't accumulate addresses - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3) // Grow it back to 2 - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) // Update a different host and ask about it - newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) + newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) @@ -233,7 +237,7 @@ func TestLighthouse_Memory(t *testing.T) { newLHHostUpdate( myUdpAddr0, myVpnIp, - []*udp.Addr{ + []netip.AddrPort{ myUdpAddr1, myUdpAddr2, myUdpAddr3, @@ -256,10 +260,10 @@ func TestLighthouse_Memory(t *testing.T) { ) // Make sure we won't add ips in our vpn network - bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242} - bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242} - good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242} - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh) + bad1 := netip.MustParseAddrPort("10.128.0.99:4242") + bad2 := netip.MustParseAddrPort("10.128.0.100:4242") + good := netip.MustParseAddrPort("1.128.0.99:4242") + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh) 
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good) } @@ -269,7 +273,7 @@ func TestLighthouse_reload(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) assert.NoError(t, err) nc := map[interface{}]interface{}{ @@ -285,11 +289,13 @@ func TestLighthouse_reload(t *testing.T) { assert.NoError(t, err) } -func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply { +func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { + //TODO: IPV6-WORK + bip := queryVpnIp.As4() req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: uint32(queryVpnIp), + VpnIp: binary.BigEndian.Uint32(bip[:]), }, } @@ -306,17 +312,19 @@ func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh return w.lastReply } -func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) { +func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) { + //TODO: IPV6-WORK + bip := vpnIp.As4() req := &NebulaMeta{ Type: NebulaMeta_HostUpdateNotification, Details: &NebulaMetaDetails{ - VpnIp: uint32(vpnIp), + VpnIp: binary.BigEndian.Uint32(bip[:]), Ip4AndPorts: make([]*Ip4AndPort, len(addrs)), }, } for k, v := range addrs { - req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)} + req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port()) } b, err := 
req.Marshal() @@ -394,16 +402,10 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, // ) //} -func Test_ipMaskContains(t *testing.T) { - assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255")))) - assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1")))) - assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1")))) -} - type testLhReply struct { nebType header.MessageType nebSubType header.MessageSubType - vpnIp iputil.VpnIp + vpnIp netip.Addr msg *NebulaMeta } @@ -414,7 +416,7 @@ type testEncWriter struct { func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { } -func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) { +func (tw *testEncWriter) Handshake(vpnIp netip.Addr) { } func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) { @@ -434,7 +436,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M } } -func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) { +func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) { msg := &NebulaMeta{} err := msg.Unmarshal(p) if tw.metaFilter == nil || msg.Type == *tw.metaFilter { @@ -452,35 +454,16 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess } // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match -func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) { - if !assert.Len(t, have, len(want)) { - return - } - - for k, w := range want { - if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == 
uint32(w.Port)) { - assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have))) - } - } -} - -// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match -func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) { +func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) { if !assert.Len(t, have, len(want)) { return } for k, w := range want { - if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) { - assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have)) + //TODO: IPV6-WORK + h := AddrPortFromIp4AndPort(have[k]) + if !(h == w) { + assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h)) } } } - -func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr { - addrs := make([]*udp.Addr, len(ips)) - for k, v := range ips { - addrs[k] = NewUDPAddrFromLH4(v) - } - return addrs -}
main.go+22 −8 modified@@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "net" + "net/netip" "time" "github.com/sirupsen/logrus" @@ -67,8 +68,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") - // TODO: make sure mask is 4 bytes - tunCidr := certificate.Details.Ips[0] + ones, _ := certificate.Details.Ips[0].Mask.Size() + addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP) + if !ok { + err = util.NewContextualError( + "Invalid ip address in certificate", + m{"vpnIp": certificate.Details.Ips[0].IP}, + nil, + ) + return nil, err + } + tunCidr := netip.PrefixFrom(addr, ones) ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) if err != nil { @@ -150,21 +160,25 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if !configTest { rawListenHost := c.GetString("listen.host", "0.0.0.0") - var listenHost *net.IPAddr + var listenHost netip.Addr if rawListenHost == "[::]" { // Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve. 
- listenHost = &net.IPAddr{IP: net.IPv6zero} + listenHost = netip.IPv6Unspecified() } else { - listenHost, err = net.ResolveIPAddr("ip", rawListenHost) + ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", rawListenHost) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) } + if len(ips) == 0 { + return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) + } + listenHost = ips[0].Unmap() } for i := 0; i < routines; i++ { - l.Infof("listening %q %d", listenHost.IP, port) - udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64)) + l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) + udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) } @@ -178,7 +192,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if err != nil { return nil, util.NewContextualError("Failed to get listening port", nil, err) } - port = int(uPort.Port) + port = int(uPort.Port()) } } }
outside.go+48 −47 modified@@ -4,14 +4,14 @@ import ( "encoding/binary" "errors" "fmt" + "net/netip" "time" "github.com/flynn/noise" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "golang.org/x/net/ipv4" "google.golang.org/protobuf/proto" @@ -21,9 +21,10 @@ const ( minFwPacketLen = 4 ) +// TODO: IPV6-WORK this can likely be removed now func readOutsidePackets(f *Interface) udp.EncReader { return func( - addr *udp.Addr, + addr netip.AddrPort, out []byte, packet []byte, header *header.H, @@ -37,27 +38,25 @@ func readOutsidePackets(f *Interface) udp.EncReader { } } -func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // TODO: best if we return this and let caller log // TODO: Might be better to send the literal []byte("holepunch") packet and ignore that? 
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors if len(packet) > 1 { - f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err) + f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err) } return } //l.Error("in packet ", header, packet[HeaderLen:]) - if addr != nil { - if ip4 := addr.IP.To4(); ip4 != nil { - if ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, iputil.VpnIp(binary.BigEndian.Uint32(ip4))) { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet") - } - return + if ip.IsValid() { + if f.myVpnNet.Contains(ip.Addr()) { + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") } + return } } @@ -77,7 +76,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt switch h.Type { case header.Message: // TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case. - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } @@ -101,7 +100,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt // Successfully validated the thing. Get rid of the Relay header. signedPayload = signedPayload[header.Len:] // Pull the Roaming parts up here, and return in all call paths. - f.handleHostRoaming(hostinfo, addr) + f.handleHostRoaming(hostinfo, ip) // Track usage of both the HostInfo and the Relay for the received & authenticated packet f.connectionManager.In(hostinfo.localIndexId) f.connectionManager.RelayUsed(h.RemoteIndex) @@ -118,7 +117,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case TerminalType: // If I am the target of this relay, process the unwrapped packet // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. 
- f.readOutsidePackets(nil, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) return case ForwardingType: // Find the target HostInfo relay object @@ -148,13 +147,13 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case header.LightHouse: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt lighthouse packet") @@ -163,19 +162,19 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt return } - lhf(addr, hostinfo.vpnIp, d) + lhf(ip, hostinfo.vpnIp, d) // Fallthrough to the bottom to record incoming traffic case header.Test: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). 
Error("Failed to decrypt test packet") @@ -187,7 +186,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt if h.Subtype == header.TestRequest { // This testRequest might be from TryPromoteBest, so we should roam // to the new IP address before responding - f.handleHostRoaming(hostinfo, addr) + f.handleHostRoaming(hostinfo, ip) f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) } @@ -198,34 +197,34 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case header.Handshake: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handshakeManager.HandleIncoming(addr, via, packet, h) + f.handshakeManager.HandleIncoming(ip, via, packet, h) return case header.RecvError: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handleRecvError(addr, h) + f.handleRecvError(ip, h) return case header.CloseTunnel: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } - hostinfo.logger(f.l).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", ip). Info("Close tunnel received, tearing down.") f.closeTunnel(hostinfo) return case header.Control: - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). 
Error("Failed to decrypt Control packet") return @@ -241,11 +240,11 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr) + hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip) return } - f.handleHostRoaming(hostinfo, addr) + f.handleHostRoaming(hostinfo, ip) f.connectionManager.In(hostinfo.localIndexId) } @@ -264,34 +263,34 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) } -func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) { - if addr != nil && !hostinfo.remote.Equals(addr) { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { - hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming") +func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) { + if ip.IsValid() && hostinfo.remote != ip { + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) { + hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming") return } - if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). 
+ hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). Info("Host roamed to new udp ip/port.") hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote - hostinfo.SetRemote(addr) + hostinfo.SetRemote(ip) } } -func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool { +func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool { // If connectionstate exists and the replay protector allows, process packet // Else, send recv errors for 300 seconds after a restart to allow fast reconnection. if ci == nil || !ci.window.Check(f.l, h.MessageCounter) { - if addr != nil { + if addr.IsValid() { f.maybeSendRecvError(addr, h.RemoteIndex) return false } else { @@ -340,8 +339,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { // Firewall packets are locally oriented if incoming { - fp.RemoteIP = iputil.Ip2VpnIp(data[12:16]) - fp.LocalIP = iputil.Ip2VpnIp(data[16:20]) + //TODO: IPV6-WORK + fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16]) + fp.LocalIP, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -350,8 +350,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) } } else { - fp.LocalIP = iputil.Ip2VpnIp(data[12:16]) - fp.RemoteIP = iputil.Ip2VpnIp(data[16:20]) + //TODO: IPV6-WORK + fp.LocalIP, _ = netip.AddrFromSlice(data[12:16]) + fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -425,13 +426,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return true } -func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) { - if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint.IP) { +func (f *Interface) maybeSendRecvError(endpoint 
netip.AddrPort, index uint32) { + if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint) { f.sendRecvError(endpoint, index) } } -func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) { +func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { f.messageMetrics.Tx(header.RecvError, 0, 1) //TODO: this should be a signed message so we can trust that we should drop the index @@ -444,7 +445,7 @@ func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) { } } -func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { +func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", h.RemoteIndex). WithField("udpAddr", addr). @@ -461,7 +462,7 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { return } - if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) { + if hostinfo.remote.IsValid() && hostinfo.remote != addr { f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) return }
outside_test.go+5 −5 modified@@ -2,10 +2,10 @@ package nebula import ( "net" + "net/netip" "testing" "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" "golang.org/x/net/ipv4" ) @@ -55,8 +55,8 @@ func Test_newPacket(t *testing.T) { assert.Nil(t, err) assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP)) - assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2))) - assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1))) + assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2")) + assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1")) assert.Equal(t, p.RemotePort, uint16(3)) assert.Equal(t, p.LocalPort, uint16(4)) @@ -76,8 +76,8 @@ func Test_newPacket(t *testing.T) { assert.Nil(t, err) assert.Equal(t, p.Protocol, uint8(2)) - assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1))) - assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2))) + assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1")) + assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2")) assert.Equal(t, p.RemotePort, uint16(6)) assert.Equal(t, p.LocalPort, uint16(5)) }
overlay/device.go+3 −5 modified@@ -2,16 +2,14 @@ package overlay import ( "io" - "net" - - "github.com/slackhq/nebula/iputil" + "net/netip" ) type Device interface { io.ReadWriteCloser Activate() error - Cidr() *net.IPNet + Cidr() netip.Prefix Name() string - RouteFor(iputil.VpnIp) iputil.VpnIp + RouteFor(netip.Addr) netip.Addr NewMultiQueueReader() (io.ReadWriteCloser, error) }
overlay/route.go+19 −25 modified@@ -1,34 +1,30 @@ package overlay import ( - "bytes" "fmt" "math" "net" + "net/netip" "runtime" "strconv" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) type Route struct { MTU int Metric int - Cidr *net.IPNet - Via *iputil.VpnIp + Cidr netip.Prefix + Via netip.Addr Install bool } // Equal determines if a route that could be installed in the system route table is equal to another // Via is ignored since that is only consumed within nebula itself func (r Route) Equal(t Route) bool { - if !r.Cidr.IP.Equal(t.Cidr.IP) { - return false - } - if !bytes.Equal(r.Cidr.Mask, t.Cidr.Mask) { + if r.Cidr != t.Cidr { return false } if r.Metric != t.Metric { @@ -51,21 +47,21 @@ func (r Route) String() string { return s } -func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) { - routeTree := cidr.NewTree4[iputil.VpnIp]() +func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) { + routeTree := new(bart.Table[netip.Addr]) for _, r := range routes { if !allowMTU && r.MTU > 0 { l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) } - if r.Via != nil { - routeTree.AddCIDR(r.Cidr, *r.Via) + if r.Via.IsValid() { + routeTree.Insert(r.Cidr, r.Via) } } return routeTree, nil } -func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { +func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.routes") @@ -116,12 +112,12 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { MTU: mtu, } - _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) + r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute)) if err != nil { return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err) } - if !ipWithin(network, r.Cidr) { + if 
!network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() { return nil, fmt.Errorf( "entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v", i+1, @@ -136,7 +132,7 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return routes, nil } -func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { +func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.unsafe_routes") @@ -202,18 +198,16 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia) } - nVia := net.ParseIP(via) - if nVia == nil { - return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, via) + viaVpnIp, err := netip.ParseAddr(via) + if err != nil { + return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err) } rRoute, ok := m["route"] if !ok { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1) } - viaVpnIp := iputil.Ip2VpnIp(nVia) - install := true rInstall, ok := m["install"] if ok { @@ -224,18 +218,18 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { } r := Route{ - Via: &viaVpnIp, + Via: viaVpnIp, MTU: mtu, Metric: metric, Install: install, } - _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) + r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute)) if err != nil { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err) } - if ipWithin(network, r.Cidr) { + if network.Contains(r.Cidr.Addr()) { return nil, fmt.Errorf( "entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v", i+1,
overlay/route_test.go+27 −16 modified@@ -2,19 +2,19 @@ package overlay import ( "fmt" - "net" + "net/netip" "testing" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) func Test_parseRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - _, n, _ := net.ParseCIDR("10.0.0.0/24") + n, err := netip.ParsePrefix("10.0.0.0/24") + assert.NoError(t, err) // test no routes config routes, err := parseRoutes(c, n) @@ -67,7 +67,7 @@ func Test_parseRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} routes, err = parseRoutes(c, n) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope") + assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} @@ -112,7 +112,8 @@ func Test_parseRoutes(t *testing.T) { func Test_parseUnsafeRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - _, n, _ := net.ParseCIDR("10.0.0.0/24") + n, err := netip.ParsePrefix("10.0.0.0/24") + assert.NoError(t, err) // test no routes config routes, err := parseUnsafeRoutes(c, n) @@ -157,7 +158,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope") + assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") // missing route c.Settings["tun"] = 
map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} @@ -169,7 +170,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope") + assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // within network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} @@ -252,7 +253,8 @@ func Test_parseUnsafeRoutes(t *testing.T) { func Test_makeRouteTree(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - _, n, _ := net.ParseCIDR("10.0.0.0/24") + n, err := netip.ParsePrefix("10.0.0.0/24") + assert.NoError(t, err) c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, @@ -264,17 +266,26 @@ func Test_makeRouteTree(t *testing.T) { routeTree, err := makeRouteTree(l, routes, true) assert.NoError(t, err) - ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2")) - ok, r := routeTree.MostSpecificContains(ip) + ip, err := netip.ParseAddr("1.0.0.2") + assert.NoError(t, err) + r, ok := routeTree.Lookup(ip) assert.True(t, ok) - assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r) - ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1")) - ok, r = routeTree.MostSpecificContains(ip) + nip, err := netip.ParseAddr("192.168.0.1") + assert.NoError(t, err) + assert.Equal(t, nip, r) + + ip, err = netip.ParseAddr("1.0.0.1") + assert.NoError(t, err) + r, ok = routeTree.Lookup(ip) assert.True(t, ok) - assert.Equal(t, 
iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r) - ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1")) - ok, r = routeTree.MostSpecificContains(ip) + nip, err = netip.ParseAddr("192.168.0.2") + assert.NoError(t, err) + assert.Equal(t, nip, r) + + ip, err = netip.ParseAddr("1.1.0.1") + assert.NoError(t, err) + r, ok = routeTree.Lookup(ip) assert.False(t, ok) }
overlay/tun_android.go+9 −10 modified@@ -6,27 +6,26 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" ) type tun struct { io.ReadWriteCloser fd int - cidr *net.IPNet + cidr netip.Prefix Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -53,12 +52,12 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -87,7 +86,7 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr }
overlay/tun_darwin.go+41 −18 modified@@ -8,15 +8,15 @@ import ( "fmt" "io" "net" + "net/netip" "os" "sync/atomic" "syscall" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" "golang.org/x/sys/unix" @@ -25,10 +25,10 @@ import ( type tun struct { io.ReadWriteCloser Device string - cidr *net.IPNet + cidr netip.Prefix DefaultMTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] linkAddr *netroute.LinkAddr l *logrus.Logger @@ -73,7 +73,7 @@ type ifreqMTU struct { pad [8]byte } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { @@ -172,7 +172,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -188,8 +188,13 @@ func (t *tun) Activate() error { var addr, mask [4]byte - copy(addr[:], t.cidr.IP.To4()) - copy(mask[:], t.cidr.Mask) + if !t.cidr.Addr().Is4() { + //TODO: IPV6-WORK + panic("need ipv6") + } + + addr = t.cidr.Addr().As4() + copy(mask[:], prefixToMask(t.cidr)) s, err := unix.Socket( unix.AF_INET, @@ -329,13 +334,12 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - ok, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, ok := t.routeTree.Load().Lookup(ip) if ok { return r } - - return 0 + return netip.Addr{} } // 
Get the LinkAddr for the interface of the given name @@ -384,13 +388,19 @@ func (t *tun) addRoutes(logErrors bool) error { maskAddr := &netroute.Inet4Addr{} routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - copy(routeAddr.IP[:], r.Cidr.IP.To4()) - copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) + if !r.Cidr.Addr().Is4() { + //TODO: implement ipv6 + panic("Cant handle ipv6 routes yet") + } + + routeAddr.IP = r.Cidr.Addr().As4() + //TODO: we could avoid the copy + copy(maskAddr.IP[:], prefixToMask(r.Cidr)) err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr) if err != nil { @@ -435,8 +445,13 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - copy(routeAddr.IP[:], r.Cidr.IP.To4()) - copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) + if r.Cidr.Addr().Is6() { + //TODO: implement ipv6 + panic("Cant handle ipv6 routes yet") + } + + routeAddr.IP = r.Cidr.Addr().As4() + copy(maskAddr.IP[:], prefixToMask(r.Cidr)) err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr) if err != nil { @@ -536,7 +551,7 @@ func (t *tun) Write(from []byte) (int, error) { return n - 4, err } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -547,3 +562,11 @@ func (t *tun) Name() string { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } + +func prefixToMask(prefix netip.Prefix) []byte { + pLen := 128 + if prefix.Addr().Is4() { + pLen = 32 + } + return net.CIDRMask(prefix.Bits(), pLen) +}
overlay/tun_disabled.go+6 −6 modified@@ -3,7 +3,7 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "strings" "github.com/rcrowley/go-metrics" @@ -13,15 +13,15 @@ import ( type disabledTun struct { read chan []byte - cidr *net.IPNet + cidr netip.Prefix // Track these metrics since we don't have the tun device to do it for us tx metrics.Counter rx metrics.Counter l *logrus.Logger } -func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { +func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { tun := &disabledTun{ cidr: cidr, read: make(chan []byte, queueLen), @@ -43,11 +43,11 @@ func (*disabledTun) Activate() error { return nil } -func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp { - return 0 +func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr { + return netip.Addr{} } -func (t *disabledTun) Cidr() *net.IPNet { +func (t *disabledTun) Cidr() netip.Prefix { return t.cidr }
overlay/tun_freebsd.go+11 −12 modified@@ -9,18 +9,17 @@ import ( "fmt" "io" "io/fs" - "net" + "net/netip" "os" "os/exec" "strconv" "sync/atomic" "syscall" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" ) @@ -48,10 +47,10 @@ type ifreqDestroy struct { type tun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger io.ReadWriteCloser @@ -79,11 +78,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var file *os.File var err error @@ -174,7 +173,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error func (t *tun) Activate() error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) @@ -233,12 +232,12 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return 
r } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -253,7 +252,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue }
overlay/tun.go+5 −5 modified@@ -1,7 +1,7 @@ package overlay import ( - "net" + "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -11,9 +11,9 @@ import ( const DefaultMTU = 1300 // TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) +type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) @@ -25,12 +25,12 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, rout } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { + return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { return newTunFromFd(c, l, *fd, tunCidr) } } -func getAllRoutesFromConfig(c *config.C, cidr *net.IPNet, initial bool) (bool, []Route, error) { +func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) { if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { return false, nil, nil }
overlay/tun_ios.go+9 −10 modified@@ -7,32 +7,31 @@ import ( "errors" "fmt" "io" - "net" + "net/netip" "os" "sync" "sync/atomic" "syscall" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" ) type tun struct { io.ReadWriteCloser - cidr *net.IPNet + cidr netip.Prefix Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ cidr: cidr, @@ -80,8 +79,8 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -143,7 +142,7 @@ func (tr *tunReadCloser) Close() error { return tr.f.Close() } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr }
overlay/tun_linux.go+56 −35 modified@@ -4,19 +4,18 @@ package overlay import ( - "bytes" "fmt" "io" "net" + "net/netip" "os" "strings" "sync/atomic" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" @@ -26,15 +25,15 @@ type tun struct { io.ReadWriteCloser fd int Device string - cidr *net.IPNet + cidr netip.Prefix MaxMTU int DefaultMTU int TXQueueLen int deviceIndex int ioctlFd uintptr Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] routeChan chan struct{} useSystemRoutes bool @@ -65,7 +64,7 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") t, err := newTunGeneric(c, l, file, cidr) @@ -78,7 +77,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) return t, nil } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -123,7 +122,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*t return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr *net.IPNet) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) { t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), @@ -231,8 +230,8 @@ func (t *tun) 
NewMultiQueueReader() (io.ReadWriteCloser, error) { return file, nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -275,8 +274,10 @@ func (t *tun) Activate() error { var addr, mask [4]byte - copy(addr[:], t.cidr.IP.To4()) - copy(mask[:], t.cidr.Mask) + //TODO: IPV6-WORK + addr = t.cidr.Addr().As4() + tmask := net.CIDRMask(t.cidr.Bits(), 32) + copy(mask[:], tmask) s, err := unix.Socket( unix.AF_INET, @@ -364,14 +365,19 @@ func (t *tun) setMTU() { func (t *tun) setDefaultRoute() error { // Default route - dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask} + + dr := &net.IPNet{ + IP: t.cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()), + } + nr := netlink.Route{ LinkIndex: t.deviceIndex, Dst: dr, MTU: t.DefaultMTU, AdvMSS: t.advMSS(Route{}), Scope: unix.RT_SCOPE_LINK, - Src: t.cidr.IP, + Src: net.IP(t.cidr.Addr().AsSlice()), Protocol: unix.RTPROT_KERNEL, Table: unix.RT_TABLE_MAIN, Type: unix.RTN_UNICAST, @@ -392,9 +398,14 @@ func (t *tun) addRoutes(logErrors bool) error { continue } + dr := &net.IPNet{ + IP: r.Cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()), + } + nr := netlink.Route{ LinkIndex: t.deviceIndex, - Dst: r.Cidr, + Dst: dr, MTU: r.MTU, AdvMSS: t.advMSS(r), Scope: unix.RT_SCOPE_LINK, @@ -426,9 +437,14 @@ func (t *tun) removeRoutes(routes []Route) { continue } + dr := &net.IPNet{ + IP: r.Cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()), + } + nr := netlink.Route{ LinkIndex: t.deviceIndex, - Dst: r.Cidr, + Dst: dr, MTU: r.MTU, AdvMSS: t.advMSS(r), Scope: unix.RT_SCOPE_LINK, @@ -447,7 +463,7 @@ func (t *tun) removeRoutes(routes []Route) { } } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -499,7 +515,15 
@@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { return } - if !t.cidr.Contains(r.Gw) { + //TODO: IPV6-WORK what if not ok? + gwAddr, ok := netip.AddrFromSlice(r.Gw) + if !ok { + t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") + return + } + + gwAddr = gwAddr.Unmap() + if !t.cidr.Contains(gwAddr) { // Gateway isn't in our overlay network, ignore t.l.WithField("route", r).Debug("Ignoring route update, not in our network") return @@ -511,28 +535,25 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { return } - newTree := cidr.NewTree4[iputil.VpnIp]() - if r.Type == unix.RTM_NEWROUTE { - for _, oldR := range t.routeTree.Load().List() { - newTree.AddCIDR(oldR.CIDR, oldR.Value) - } + dstAddr, ok := netip.AddrFromSlice(r.Dst.IP) + if !ok { + t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address") + return + } + + ones, _ := r.Dst.Mask.Size() + dst := netip.PrefixFrom(dstAddr, ones) + + newTree := t.routeTree.Load().Clone() + if r.Type == unix.RTM_NEWROUTE { t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route") - newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw)) + newTree.Insert(dst, gwAddr) } else { - gw := iputil.Ip2VpnIp(r.Gw) - for _, oldR := range t.routeTree.Load().List() { - if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw { - // This is the record to delete - t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") - continue - } - - newTree.AddCIDR(oldR.CIDR, oldR.Value) - } + newTree.Delete(dst) + t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") } - t.routeTree.Store(newTree) }
overlay/tun_netbsd.go+14 −15 modified@@ -6,7 +6,7 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "os/exec" "regexp" @@ -15,10 +15,9 @@ import ( "syscall" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" ) @@ -29,10 +28,10 @@ type ifreqDestroy struct { type tun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger io.ReadWriteCloser @@ -59,13 +58,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var file *os.File var err error @@ -109,13 +108,13 @@ func (t *tun) Activate() error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.IP.String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", 
err) @@ -168,12 +167,12 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -188,12 +187,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -214,7 +213,7 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
overlay/tun_openbsd.go+14 −15 modified@@ -6,27 +6,26 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "os/exec" "regexp" "strconv" "sync/atomic" "syscall" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" ) type tun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger io.ReadWriteCloser @@ -43,13 +42,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { deviceName := c.GetString("tun.dev", "") if deviceName == "" { return nil, fmt.Errorf("a device name in the format of tunN must be specified") @@ -127,7 +126,7 @@ func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) Activate() error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) @@ -139,7 +138,7 @@ func (t *tun) Activate() error { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.IP.String()) + cmd = 
exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -149,20 +148,20 @@ func (t *tun) Activate() error { return t.addRoutes(false) } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -183,7 +182,7 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") @@ -194,7 +193,7 @@ func (t *tun) removeRoutes(routes []Route) error { return nil } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr }
overlay/tun_tester.go+9 −10 modified@@ -6,29 +6,28 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) type TestTun struct { Device string - cidr *net.IPNet + cidr netip.Prefix Routes []Route - routeTree *cidr.Tree4[iputil.VpnIp] + routeTree *bart.Table[netip.Addr] l *logrus.Logger closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) { _, routes, err := getAllRoutesFromConfig(c, cidr, true) if err != nil { return nil, err @@ -49,7 +48,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, e }, nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } @@ -87,16 +86,16 @@ func (t *TestTun) Get(block bool) []byte { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.MostSpecificContains(ip) +func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Lookup(ip) return r } func (t *TestTun) Activate() error { return nil } -func (t *TestTun) Cidr() *net.IPNet { +func (t *TestTun) Cidr() netip.Prefix { return t.cidr }
overlay/tun_water_windows.go+11 −11 modified@@ -4,30 +4,30 @@ import ( "fmt" "io" "net" + "net/netip" "os/exec" "strconv" "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" "github.com/songgao/water" ) type waterTun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger f *net.Interface *water.Interface } -func newWaterTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*waterTun, error) { +func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) { // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() t := &waterTun{ cidr: cidr, @@ -70,8 +70,8 @@ func (t *waterTun) Activate() error { `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address", fmt.Sprintf("name=%s", t.Device), "source=static", - fmt.Sprintf("addr=%s", t.cidr.IP), - fmt.Sprintf("mask=%s", net.IP(t.cidr.Mask)), + fmt.Sprintf("addr=%s", t.cidr.Addr()), + fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())), "gateway=none", ).Run() if err != nil { @@ -141,7 +141,7 @@ func (t *waterTun) addRoutes(logErrors bool) error { // Path routes routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } @@ -182,12 +182,12 @@ func (t *waterTun) removeRoutes(routes []Route) { } } -func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *waterTun) Cidr() 
*net.IPNet { +func (t *waterTun) Cidr() netip.Prefix { return t.cidr }
overlay/tun_windows.go+3 −3 modified@@ -5,7 +5,7 @@ package overlay import ( "fmt" - "net" + "net/netip" "os" "path/filepath" "runtime" @@ -15,11 +15,11 @@ import ( "github.com/slackhq/nebula/config" ) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (Device, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (Device, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) { useWintun := true if err := checkWinTunExists(); err != nil { l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
overlay/tun_wintun_windows.go+12 −38 modified@@ -4,15 +4,13 @@ import ( "crypto" "fmt" "io" - "net" "net/netip" "sync/atomic" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" "github.com/slackhq/nebula/wintun" "golang.org/x/sys/windows" @@ -23,11 +21,10 @@ const tunGUIDLabel = "Fixed Nebula Windows GUID v1" type winTun struct { Device string - cidr *net.IPNet - prefix netip.Prefix + cidr netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger tun *wintun.NativeTun @@ -52,22 +49,16 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) { return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil } -func newWinTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*winTun, error) { +func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) { deviceName := c.GetString("tun.dev", "") guid, err := generateGUIDByDeviceName(deviceName) if err != nil { return nil, fmt.Errorf("generate GUID failed: %w", err) } - prefix, err := iputil.ToNetIpPrefix(*cidr) - if err != nil { - return nil, err - } - t := &winTun{ Device: deviceName, cidr: cidr, - prefix: prefix, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -140,7 +131,7 @@ func (t *winTun) reload(c *config.C, initial bool) error { func (t *winTun) Activate() error { luid := winipcfg.LUID(t.tun.LUID()) - err := luid.SetIPAddresses([]netip.Prefix{t.prefix}) + err := luid.SetIPAddresses([]netip.Prefix{t.cidr}) if err != nil { return fmt.Errorf("failed to set address: %w", err) } @@ -159,24 +150,13 @@ func (t *winTun) addRoutes(logErrors bool) error { foundDefault4 := false for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only 
install routes with a via continue } - prefix, err := iputil.ToNetIpPrefix(*r.Cidr) - if err != nil { - retErr := util.NewContextualError("Failed to parse cidr to netip prefix, ignoring route", map[string]interface{}{"route": r}, err) - if logErrors { - retErr.Log(t.l) - continue - } else { - return retErr - } - } - // Add our unsafe route - err = luid.AddRoute(prefix, r.Via.ToNetIpAddr(), uint32(r.Metric)) + err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric)) if err != nil { retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) if logErrors { @@ -190,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error { } if !foundDefault4 { - if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 { + if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 { foundDefault4 = true } } @@ -221,13 +201,7 @@ func (t *winTun) removeRoutes(routes []Route) error { continue } - prefix, err := iputil.ToNetIpPrefix(*r.Cidr) - if err != nil { - t.l.WithError(err).WithField("route", r).Info("Failed to convert cidr to netip prefix") - continue - } - - err = luid.DeleteRoute(prefix, r.Via.ToNetIpAddr()) + err := luid.DeleteRoute(r.Cidr, r.Via) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { @@ -237,12 +211,12 @@ func (t *winTun) removeRoutes(routes []Route) error { return nil } -func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *winTun) Cidr() *net.IPNet { +func (t *winTun) Cidr() netip.Prefix { return t.cidr }
overlay/user.go+7 −8 modified@@ -2,18 +2,17 @@ package overlay import ( "io" - "net" + "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { return NewUserDevice(tunCidr) } -func NewUserDevice(tunCidr *net.IPNet) (Device, error) { +func NewUserDevice(tunCidr netip.Prefix) (Device, error) { // these pipes guarantee each write/read will match 1:1 or, ow := io.Pipe() ir, iw := io.Pipe() @@ -27,7 +26,7 @@ func NewUserDevice(tunCidr *net.IPNet) (Device, error) { } type UserDevice struct { - tunCidr *net.IPNet + tunCidr netip.Prefix outboundReader *io.PipeReader outboundWriter *io.PipeWriter @@ -39,9 +38,9 @@ type UserDevice struct { func (d *UserDevice) Activate() error { return nil } -func (d *UserDevice) Cidr() *net.IPNet { return d.tunCidr } -func (d *UserDevice) Name() string { return "faketun0" } -func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip } +func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr } +func (d *UserDevice) Name() string { return "faketun0" } +func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return d, nil }
pki.go+2 −0 modified@@ -80,6 +80,8 @@ func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { } if !initial { + //TODO: include check for mask equality as well + // did IP in cert change? if so, don't set currentCert := p.cs.Load().Certificate oldIPs := currentCert.Details.Ips
relay_manager.go+52 −31 modified@@ -2,14 +2,15 @@ package nebula import ( "context" + "encoding/binary" "errors" "fmt" + "net/netip" "sync/atomic" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" ) type relayManager struct { @@ -50,7 +51,7 @@ func (rm *relayManager) setAmRelay(v bool) { // AddRelay finds an available relay index on the hostmap, and associates the relay info with it. // relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp. -func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iputil.VpnIp, remoteIdx *uint32, relayType int, state int) (uint32, error) { +func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { hm.Lock() defer hm.Unlock() for i := 0; i < 32; i++ { @@ -113,13 +114,17 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Inter func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) { rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(m.RelayFromIp), - "relayTo": iputil.VpnIp(m.RelayToIp), + "relayFrom": m.RelayFromIp, + "relayTo": m.RelayToIp, "initiatorRelayIndex": m.InitiatorRelayIndex, "responderRelayIndex": m.ResponderRelayIndex, "vpnIp": h.vpnIp}). 
Info("handleCreateRelayResponse") - target := iputil.VpnIp(m.RelayToIp) + target := m.RelayToIp + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], m.RelayToIp) + targetAddr := netip.AddrFrom4(b) relay, err := rm.EstablishRelay(h, m) if err != nil { @@ -136,18 +141,20 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") return } - peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target) + peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) if !ok { rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo") return } if peerRelay.State == PeerRequested { + //TODO: IPV6-WORK + b = peerHostInfo.vpnIp.As4() peerRelay.State = Established resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: peerRelay.LocalIndex, InitiatorRelayIndex: peerRelay.RemoteIndex, - RelayFromIp: uint32(peerHostInfo.vpnIp), + RelayFromIp: binary.BigEndian.Uint32(b[:]), RelayToIp: uint32(target), } msg, err := resp.Marshal() @@ -157,8 +164,8 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * } else { f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), + "relayFrom": resp.RelayFromIp, + "relayTo": resp.RelayToIp, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, "vpnIp": peerHostInfo.vpnIp}). 
@@ -168,9 +175,13 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * } func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) { + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], m.RelayFromIp) + from := netip.AddrFrom4(b) - from := iputil.VpnIp(m.RelayFromIp) - target := iputil.VpnIp(m.RelayToIp) + binary.BigEndian.PutUint32(b[:], m.RelayToIp) + target := netip.AddrFrom4(b) logMsg := rm.l.WithFields(logrus.Fields{ "relayFrom": from, @@ -181,12 +192,12 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. - if from == f.myVpnIp { - logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself") + if from == f.myVpnNet.Addr() { + logMsg.WithField("myIP", from).Error("Discarding relay request from myself") return } // Is the target of the relay me? 
- if target == f.myVpnIp { + if target == f.myVpnNet.Addr() { existingRelay, ok := h.relayState.QueryRelayForByIp(from) if ok { switch existingRelay.State { @@ -219,12 +230,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N return } + //TODO: IPV6-WORK + fromB := from.As4() + targetB := target.As4() + resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: uint32(from), - RelayToIp: uint32(target), + RelayFromIp: binary.BigEndian.Uint32(fromB[:]), + RelayToIp: binary.BigEndian.Uint32(targetB[:]), } msg, err := resp.Marshal() if err != nil { @@ -233,8 +248,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), + //TODO: IPV6-WORK, this used to use the resp object but I am getting lazy now + "relayFrom": from, + "relayTo": target, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, "vpnIp": h.vpnIp}). @@ -253,7 +269,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N f.Handshake(target) return } - if peer.remote == nil { + if !peer.remote.IsValid() { // Only create relays to peers for whom I have a direct connection return } @@ -275,12 +291,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N sendCreateRequest = true } if sendCreateRequest { + //TODO: IPV6-WORK + fromB := h.vpnIp.As4() + targetB := target.As4() + // Send a CreateRelayRequest to the peer. 
req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, - RelayFromIp: uint32(h.vpnIp), - RelayToIp: uint32(target), + RelayFromIp: binary.BigEndian.Uint32(fromB[:]), + RelayToIp: binary.BigEndian.Uint32(targetB[:]), } msg, err := req.Marshal() if err != nil { @@ -289,8 +309,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } else { f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(req.RelayFromIp), - "relayTo": iputil.VpnIp(req.RelayToIp), + //TODO: IPV6-WORK another lazy used to use the req object + "relayFrom": h.vpnIp, + "relayTo": target, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, "vpnIp": target}). @@ -321,12 +342,15 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N "existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") return } + //TODO: IPV6-WORK + fromB := h.vpnIp.As4() + targetB := target.As4() resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: uint32(h.vpnIp), - RelayToIp: uint32(target), + RelayFromIp: binary.BigEndian.Uint32(fromB[:]), + RelayToIp: binary.BigEndian.Uint32(targetB[:]), } msg, err := resp.Marshal() if err != nil { @@ -335,8 +359,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), + //TODO: IPV6-WORK more lazy, used to use resp object + "relayFrom": h.vpnIp, + "relayTo": target, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, "vpnIp": 
h.vpnIp}). @@ -349,7 +374,3 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } } } - -func (rm *relayManager) RemoveRelay(localIdx uint32) { - rm.hostmap.RemoveRelay(localIdx) -}
remote_list.go+73 −93 modified@@ -1,7 +1,6 @@ package nebula import ( - "bytes" "context" "net" "net/netip" @@ -12,16 +11,14 @@ import ( "time" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) // forEachFunc is used to benefit folks that want to do work inside the lock -type forEachFunc func(addr *udp.Addr, preferred bool) +type forEachFunc func(addr netip.AddrPort, preferred bool) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) -type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool -type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool +type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool +type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool // CacheMap is a struct that better represents the lighthouse cache for humans // The string key is the owners vpnIp @@ -30,9 +27,9 @@ type CacheMap map[string]*Cache // Cache is the other part of CacheMap to better represent the lighthouse cache for humans // We don't reason about ipv4 vs ipv6 here type Cache struct { - Learned []*udp.Addr `json:"learned,omitempty"` - Reported []*udp.Addr `json:"reported,omitempty"` - Relay []*net.IP `json:"relay"` + Learned []netip.AddrPort `json:"learned,omitempty"` + Reported []netip.AddrPort `json:"reported,omitempty"` + Relay []netip.Addr `json:"relay"` } //TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion @@ -46,7 +43,7 @@ type cache struct { } type cacheRelay struct { - relay []uint32 + relay []netip.Addr } // cacheV4 stores learned and reported ipv4 records under cache @@ -130,7 +127,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, continue } for _, a := range addrs { - netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{} + netipAddrs[netip.AddrPortFrom(a.Unmap(), hostPort.port)] = struct{}{} } } origSet := 
r.ips.Load() @@ -193,22 +190,22 @@ type RemoteList struct { sync.RWMutex // A deduplicated set of addresses. Any accessor should lock beforehand. - addrs []*udp.Addr + addrs []netip.AddrPort // A set of relay addresses. VpnIp addresses that the remote identified as relays. - relays []*iputil.VpnIp + relays []netip.Addr // These are maps to store v4 and v6 addresses per lighthouse // Map key is the vpnIp of the person that told us about this the cached entries underneath. // For learned addresses, this is the vpnIp that sent the packet - cache map[iputil.VpnIp]*cache + cache map[netip.Addr]*cache hr *hostnamesResults shouldAdd func(netip.Addr) bool // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. // They should not be tried again during a handshake - badRemotes []*udp.Addr + badRemotes []netip.AddrPort // A flag that the cache may have changed and addrs needs to be rebuilt shouldRebuild bool @@ -217,9 +214,9 @@ type RemoteList struct { // NewRemoteList creates a new empty RemoteList func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList { return &RemoteList{ - addrs: make([]*udp.Addr, 0), - relays: make([]*iputil.VpnIp, 0), - cache: make(map[iputil.VpnIp]*cache), + addrs: make([]netip.AddrPort, 0), + relays: make([]netip.Addr, 0), + cache: make(map[netip.Addr]*cache), shouldAdd: shouldAdd, } } @@ -232,7 +229,7 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { // Len locks and reports the size of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { +func (r *RemoteList) Len(preferredRanges []netip.Prefix) int { r.Rebuild(preferredRanges) r.RLock() defer r.RUnlock() @@ -241,18 +238,18 @@ func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { // ForEach locks and will call the forEachFunc for every deduplicated address in the list // The deduplication 
work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) { +func (r *RemoteList) ForEach(preferredRanges []netip.Prefix, forEach forEachFunc) { r.Rebuild(preferredRanges) r.RLock() for _, v := range r.addrs { - forEach(v, isPreferred(v.IP, preferredRanges)) + forEach(v, isPreferred(v.Addr(), preferredRanges)) } r.RUnlock() } // CopyAddrs locks and makes a deep copy of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { +func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort { if r == nil { return nil } @@ -261,9 +258,9 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { r.RLock() defer r.RUnlock() - c := make([]*udp.Addr, len(r.addrs)) + c := make([]netip.AddrPort, len(r.addrs)) for i, v := range r.addrs { - c[i] = v.Copy() + c[i] = v } return c } @@ -272,13 +269,13 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming. 
// It will mark the deduplicated address list as dirty, so do not call it unless new information is available // TODO: this needs to support the allow list list -func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) { +func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) { r.Lock() defer r.Unlock() - if v4 := addr.IP.To4(); v4 != nil { - r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port))) + if remote.Addr().Is4() { + r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port())) } else { - r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port))) + r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port())) } } @@ -293,9 +290,9 @@ func (r *RemoteList) CopyCache() *CacheMap { c := cm[vpnIp] if c == nil { c = &Cache{ - Learned: make([]*udp.Addr, 0), - Reported: make([]*udp.Addr, 0), - Relay: make([]*net.IP, 0), + Learned: make([]netip.AddrPort, 0), + Reported: make([]netip.AddrPort, 0), + Relay: make([]netip.Addr, 0), } cm[vpnIp] = c } @@ -307,28 +304,27 @@ func (r *RemoteList) CopyCache() *CacheMap { if mc.v4 != nil { if mc.v4.learned != nil { - c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned)) + c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned)) } for _, a := range mc.v4.reported { - c.Reported = append(c.Reported, NewUDPAddrFromLH4(a)) + c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a)) } } if mc.v6 != nil { if mc.v6.learned != nil { - c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned)) + c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned)) } for _, a := range mc.v6.reported { - c.Reported = append(c.Reported, NewUDPAddrFromLH6(a)) + c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a)) } } if mc.relay != nil { for _, a := range mc.relay.relay { - nip := iputil.VpnIp(a).ToIP() - c.Relay = append(c.Relay, &nip) + c.Relay = append(c.Relay, a) } } } @@ 
-337,8 +333,8 @@ func (r *RemoteList) CopyCache() *CacheMap { } // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list -func (r *RemoteList) BlockRemote(bad *udp.Addr) { - if bad == nil { +func (r *RemoteList) BlockRemote(bad netip.AddrPort) { + if !bad.IsValid() { // relays can have nil udp Addrs return } @@ -351,20 +347,20 @@ func (r *RemoteList) BlockRemote(bad *udp.Addr) { } // We copy here because we are taking something else's memory and we can't trust everything - r.badRemotes = append(r.badRemotes, bad.Copy()) + r.badRemotes = append(r.badRemotes, bad) // Mark the next interaction must recollect/dedupe r.shouldRebuild = true } // CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list -func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr { +func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort { r.RLock() defer r.RUnlock() - c := make([]*udp.Addr, len(r.badRemotes)) + c := make([]netip.AddrPort, len(r.badRemotes)) for i, v := range r.badRemotes { - c[i] = v.Copy() + c[i] = v } return c } @@ -378,7 +374,7 @@ func (r *RemoteList) ResetBlockedRemotes() { // Rebuild locks and generates the deduplicated address list only if there is work to be done // There is generally no reason to call this directly but it is safe to do so -func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) { +func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) { r.Lock() defer r.Unlock() @@ -394,9 +390,9 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) { } // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list -func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool { +func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool { for _, v := range r.badRemotes { - if v.Equals(remote) { + if v == remote { return true } } @@ -405,14 +401,14 @@ func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool { // 
unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) { +func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { r.shouldRebuild = true r.unlockedGetOrMakeV4(ownerVpnIp).learned = to } // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) { +func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -427,7 +423,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, } } -func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []uint32) { +func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.Addr) { r.shouldRebuild = true c := r.unlockedGetOrMakeRelay(ownerVpnIp) @@ -440,7 +436,7 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnI // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) { +func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -453,14 +449,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) 
unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) { +func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { r.shouldRebuild = true r.unlockedGetOrMakeV6(ownerVpnIp).learned = to } // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) { +func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -477,7 +473,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) { +func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -488,7 +484,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) } } -func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay { +func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp netip.Addr) *cacheRelay { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -503,7 +499,7 @@ func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay // unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established. 
// The caller must dirty the learned address cache if required -func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 { +func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp netip.Addr) *cacheV4 { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -518,7 +514,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 { // unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established. // The caller must dirty the learned address cache if required -func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 { +func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp netip.Addr) *cacheV6 { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -540,14 +536,14 @@ func (r *RemoteList) unlockedCollect() { for _, c := range r.cache { if c.v4 != nil { if c.v4.learned != nil { - u := NewUDPAddrFromLH4(c.v4.learned) + u := AddrPortFromIp4AndPort(c.v4.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v4.reported { - u := NewUDPAddrFromLH4(v) + u := AddrPortFromIp4AndPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -556,14 +552,14 @@ func (r *RemoteList) unlockedCollect() { if c.v6 != nil { if c.v6.learned != nil { - u := NewUDPAddrFromLH6(c.v6.learned) + u := AddrPortFromIp6AndPort(c.v6.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v6.reported { - u := NewUDPAddrFromLH6(v) + u := AddrPortFromIp6AndPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -572,20 +568,15 @@ func (r *RemoteList) unlockedCollect() { if c.relay != nil { for _, v := range c.relay.relay { - ip := iputil.VpnIp(v) - relays = append(relays, &ip) + relays = append(relays, v) } } } dnsAddrs := r.hr.GetIPs() for _, addr := range dnsAddrs { if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { - v6 := addr.Addr().As16() - addrs = append(addrs, &udp.Addr{ - IP: v6[:], - Port: addr.Port(), - }) + 
addrs = append(addrs, addr) } } @@ -595,7 +586,7 @@ func (r *RemoteList) unlockedCollect() { } // unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list -func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { +func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) { n := len(r.addrs) if n < 2 { return @@ -606,8 +597,8 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { b := r.addrs[j] // Preferred addresses first - aPref := isPreferred(a.IP, preferredRanges) - bPref := isPreferred(b.IP, preferredRanges) + aPref := isPreferred(a.Addr(), preferredRanges) + bPref := isPreferred(b.Addr(), preferredRanges) switch { case aPref && !bPref: // If i is preferred and j is not, i is less than j @@ -622,21 +613,21 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { } // ipv6 addresses 2nd - a4 := a.IP.To4() - b4 := b.IP.To4() + a4 := a.Addr().Is4() + b4 := b.Addr().Is4() switch { - case a4 == nil && b4 != nil: + case a4 == false && b4 == true: // If i is v6 and j is v4, i is less than j return true - case a4 != nil && b4 == nil: + case a4 == true && b4 == false: // If j is v6 and i is v4, i is not less than j return false - case a4 != nil && b4 != nil: - // Special case for ipv4, a4 and b4 are not nil - aPrivate := isPrivateIP(a4) - bPrivate := isPrivateIP(b4) + case a4 == true && b4 == true: + // i and j are both ipv4 + aPrivate := a.Addr().IsPrivate() + bPrivate := b.Addr().IsPrivate() switch { case !aPrivate && bPrivate: // If i is a public ip (not private) and j is a private ip, i is less then j @@ -655,10 +646,10 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { } // lexical order of ips 3rd - c := bytes.Compare(a.IP, b.IP) + c := a.Addr().Compare(b.Addr()) if c == 0 { // Ips are the same, Lexical order of ports 4th - return a.Port < b.Port + return a.Port() < b.Port() } // Ip wasn't the same @@ -671,7 +662,7 @@ func (r *RemoteList) 
unlockedSort(preferredRanges []*net.IPNet) { // Deduplicate a, b := 0, 1 for b < n { - if !r.addrs[a].Equals(r.addrs[b]) { + if r.addrs[a] != r.addrs[b] { a++ if a != b { r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a] @@ -693,7 +684,7 @@ func minInt(a, b int) int { } // isPreferred returns true of the ip is contained in the preferredRanges list -func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool { +func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool { //TODO: this would be better in a CIDR6Tree for _, p := range preferredRanges { if p.Contains(ip) { @@ -702,14 +693,3 @@ func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool { } return false } - -var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8") -var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12") -var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16") - -// isPrivateIP returns true if the ip is contained by a rfc 1918 private range -func isPrivateIP(ip net.IP) bool { - //TODO: another great cidrtree option - //TODO: Private for ipv6 or just let it ride? - return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip) -}
remote_list_test.go+98 −89 modified@@ -1,47 +1,47 @@ package nebula import ( - "net" + "encoding/binary" + "net/netip" "testing" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" ) func TestRemoteList_Rebuild(t *testing.T) { rl := NewRemoteList(nil) rl.unlockedSetV4( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip4AndPort{ - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe + newIp4AndPortFromString("70.199.182.92:1475"), // this is duped + newIp4AndPortFromString("172.17.0.182:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is duped + newIp4AndPortFromString("172.18.0.1:10101"), // this is duped + newIp4AndPortFromString("172.18.0.1:10101"), // this is a dupe + newIp4AndPortFromString("172.19.0.1:10101"), + newIp4AndPortFromString("172.31.0.1:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe + newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port + newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) 
rl.unlockedSetV6( - 1, - 1, + netip.MustParseAddr("0.0.0.1"), + netip.MustParseAddr("0.0.0.1"), []*Ip6AndPort{ - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped - NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped - NewIp6AndPort(net.ParseIP("1:100::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe - NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe + newIp6AndPortFromString("[1::1]:1"), // this is duped + newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped + newIp6AndPortFromString("[1:100::1]:1"), + newIp6AndPortFromString("[1::1]:1"), // this is a dupe + newIp6AndPortFromString("[1::1]:2"), // this is a dupe }, - func(iputil.VpnIp, *Ip6AndPort) bool { return true }, + func(netip.Addr, *Ip6AndPort) bool { return true }, ) - rl.Rebuild([]*net.IPNet{}) + rl.Rebuild([]netip.Prefix{}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // ipv6 first, sorted lexically within @@ -59,9 +59,7 @@ func TestRemoteList_Rebuild(t *testing.T) { assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String()) // Now ensure we can hoist ipv4 up - _, ipNet, err := net.ParseCIDR("0.0.0.0/0") - assert.NoError(t, err) - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // ipv4 first, public then private, lexically within them @@ -79,9 +77,7 @@ func TestRemoteList_Rebuild(t *testing.T) { assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String()) // Ensure we can hoist a specific ipv4 range over anything else - _, ipNet, err = net.ParseCIDR("172.17.0.0/16") - assert.NoError(t, err) - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // Preferred ipv4 first @@ -104,132 +100,145 @@ func TestRemoteList_Rebuild(t *testing.T) { func BenchmarkFullRebuild(b 
*testing.B) { rl := NewRemoteList(nil) rl.unlockedSetV4( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip4AndPort{ - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port + newIp4AndPortFromString("70.199.182.92:1475"), + newIp4AndPortFromString("172.17.0.182:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), + newIp4AndPortFromString("172.18.0.1:10101"), + newIp4AndPortFromString("172.19.0.1:10101"), + newIp4AndPortFromString("172.31.0.1:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe + newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip6AndPort{ - NewIp6AndPort(net.ParseIP("1::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port - NewIp6AndPort(net.ParseIP("1:100::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe + newIp6AndPortFromString("[1::1]:1"), + newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port + newIp6AndPortFromString("[1:100::1]:1"), + newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(iputil.VpnIp, *Ip6AndPort) bool { return true }, + func(netip.Addr, *Ip6AndPort) bool { return true }, ) b.Run("no 
preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{}) + rl.Rebuild([]netip.Prefix{}) } }) - _, ipNet, err := net.ParseCIDR("172.17.0.0/16") - assert.NoError(b, err) + ipNet1 := netip.MustParsePrefix("172.17.0.0/16") b.Run("1 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{ipNet1}) } }) - _, ipNet2, err := net.ParseCIDR("70.0.0.0/8") - assert.NoError(b, err) + ipNet2 := netip.MustParsePrefix("70.0.0.0/8") b.Run("2 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + rl.Rebuild([]netip.Prefix{ipNet2}) } }) - _, ipNet3, err := net.ParseCIDR("0.0.0.0/0") - assert.NoError(b, err) + ipNet3 := netip.MustParsePrefix("0.0.0.0/0") b.Run("3 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) } }) } func BenchmarkSortRebuild(b *testing.B) { rl := NewRemoteList(nil) rl.unlockedSetV4( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip4AndPort{ - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port + newIp4AndPortFromString("70.199.182.92:1475"), + newIp4AndPortFromString("172.17.0.182:10101"), + 
newIp4AndPortFromString("172.17.1.1:10101"), + newIp4AndPortFromString("172.18.0.1:10101"), + newIp4AndPortFromString("172.19.0.1:10101"), + newIp4AndPortFromString("172.31.0.1:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe + newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip6AndPort{ - NewIp6AndPort(net.ParseIP("1::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port - NewIp6AndPort(net.ParseIP("1:100::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe + newIp6AndPortFromString("[1::1]:1"), + newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port + newIp6AndPortFromString("[1:100::1]:1"), + newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(iputil.VpnIp, *Ip6AndPort) bool { return true }, + func(netip.Addr, *Ip6AndPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{}) + rl.Rebuild([]netip.Prefix{}) } }) - _, ipNet, err := net.ParseCIDR("172.17.0.0/16") - rl.Rebuild([]*net.IPNet{ipNet}) + ipNet1 := netip.MustParsePrefix("172.17.0.0/16") + rl.Rebuild([]netip.Prefix{ipNet1}) - assert.NoError(b, err) b.Run("1 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{ipNet1}) } }) - _, ipNet2, err := net.ParseCIDR("70.0.0.0/8") - rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + ipNet2 := netip.MustParsePrefix("70.0.0.0/8") + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2}) - assert.NoError(b, err) b.Run("2 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { - rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2}) } }) - _, ipNet3, err := 
net.ParseCIDR("0.0.0.0/0") - rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + ipNet3 := netip.MustParsePrefix("0.0.0.0/0") + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) - assert.NoError(b, err) b.Run("3 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { - rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) } }) } + +func newIp4AndPortFromString(s string) *Ip4AndPort { + a := netip.MustParseAddrPort(s) + v4Addr := a.Addr().As4() + return &Ip4AndPort{ + Ip: binary.BigEndian.Uint32(v4Addr[:]), + Port: uint32(a.Port()), + } +} + +func newIp6AndPortFromString(s string) *Ip6AndPort { + a := netip.MustParseAddrPort(s) + v6Addr := a.Addr().As16() + return &Ip6AndPort{ + Hi: binary.BigEndian.Uint64(v6Addr[:8]), + Lo: binary.BigEndian.Uint64(v6Addr[8:]), + Port: uint32(a.Port()), + } +}
service/service.go+1 −1 modified@@ -91,7 +91,7 @@ func New(config *config.C) (*Service, error) { ipNet := device.Cidr() pa := tcpip.ProtocolAddress{ - AddressWithPrefix: tcpip.AddrFromSlice(ipNet.IP).WithPrefix(), + AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(), Protocol: ipv4.ProtocolNumber, } if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
service/service_test.go+6 −10 modified@@ -4,7 +4,7 @@ import ( "bytes" "context" "errors" - "net" + "net/netip" "testing" "time" @@ -18,12 +18,8 @@ import ( type m map[string]interface{} -func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service { - - vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} - copy(vpnIpNet.IP, udpIp) - - _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) +func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { + _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{}) caB, err := caCrt.MarshalToPEM() if err != nil { panic(err) @@ -83,8 +79,8 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, } func TestService(t *testing.T) { - ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{ + ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, "lighthouse": m{ "am_lighthouse": true, @@ -94,7 +90,7 @@ func TestService(t *testing.T) { "port": 4243, }, }) - b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{ + b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{ "static_host_map": m{ "10.0.0.1": []string{"localhost:4243"}, },
ssh.go+29 −36 modified@@ -7,6 +7,7 @@ import ( "flag" "fmt" "net" + "net/netip" "os" "reflect" "runtime" @@ -18,9 +19,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/sshd" - "github.com/slackhq/nebula/udp" ) type sshListHostMapFlags struct { @@ -431,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er } sort.Slice(hm, func(i, j int) bool { - return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0 + return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0 }) if fs.Json || fs.Pretty { @@ -545,13 +544,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -574,13 +572,12 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -616,13 +613,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be 
parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -636,16 +632,16 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) } - var addr *udp.Addr + var addr netip.AddrPort if flags.Address != "" { - addr = udp.NewAddrFromString(flags.Address) - if addr == nil { + addr, err = netip.ParseAddrPort(flags.Address) + if err != nil { return w.WriteLine("Address could not be parsed") } } hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil) - if addr != nil { + if addr.IsValid() { hostInfo.SetRemote(addr) } @@ -667,18 +663,17 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("No address was provided") } - addr := udp.NewAddrFromString(flags.Address) - if addr == nil { + addr, err := netip.ParseAddrPort(flags.Address) + if err != nil { return w.WriteLine("Address could not be parsed") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -792,13 +787,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit cert := ifce.pki.GetCertState().Certificate if len(a) > 0 { - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -862,14 +856,14 @@ func 
sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr Error error Type string State string - PeerIp iputil.VpnIp + PeerIp netip.Addr LocalIndex uint32 RemoteIndex uint32 - RelayedThrough []iputil.VpnIp + RelayedThrough []netip.Addr } type RelayOutput struct { - NebulaIp iputil.VpnIp + NebulaIp netip.Addr RelayForIps []RelayFor } @@ -952,13 +946,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) }
test/tun.go+5 −7 modified@@ -3,23 +3,21 @@ package test import ( "errors" "io" - "net" - - "github.com/slackhq/nebula/iputil" + "net/netip" ) type NoopTun struct{} -func (NoopTun) RouteFor(iputil.VpnIp) iputil.VpnIp { - return 0 +func (NoopTun) RouteFor(addr netip.Addr) netip.Addr { + return netip.Addr{} } func (NoopTun) Activate() error { return nil } -func (NoopTun) Cidr() *net.IPNet { - return nil +func (NoopTun) Cidr() netip.Prefix { + return netip.Prefix{} } func (NoopTun) Name() string {
timeout_test.go+5 −4 modified@@ -1,6 +1,7 @@ package nebula import ( + "net/netip" "testing" "time" @@ -115,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) { assert.Equal(t, 0, tw.current) fps := []firewall.Packet{ - {LocalIP: 1}, - {LocalIP: 2}, - {LocalIP: 3}, - {LocalIP: 4}, + {LocalIP: netip.MustParseAddr("0.0.0.1")}, + {LocalIP: netip.MustParseAddr("0.0.0.2")}, + {LocalIP: netip.MustParseAddr("0.0.0.3")}, + {LocalIP: netip.MustParseAddr("0.0.0.4")}, } tw.Add(fps[0], time.Second*1)
udp/conn.go+8 −6 modified@@ -1,6 +1,8 @@ package udp import ( + "net/netip" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -9,7 +11,7 @@ import ( const MTU = 9001 type EncReader func( - addr *Addr, + addr netip.AddrPort, out []byte, packet []byte, header *header.H, @@ -22,9 +24,9 @@ type EncReader func( type Conn interface { Rebind() error - LocalAddr() (*Addr, error) + LocalAddr() (netip.AddrPort, error) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) - WriteTo(b []byte, addr *Addr) error + WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) Close() error } @@ -34,13 +36,13 @@ type NoopConn struct{} func (NoopConn) Rebind() error { return nil } -func (NoopConn) LocalAddr() (*Addr, error) { - return nil, nil +func (NoopConn) LocalAddr() (netip.AddrPort, error) { + return netip.AddrPort{}, nil } func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) { return } -func (NoopConn) WriteTo(_ []byte, _ *Addr) error { +func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } func (NoopConn) ReloadConfig(_ *config.C) {
udp/temp.go+3 −2 modified@@ -1,9 +1,10 @@ package udp import ( - "github.com/slackhq/nebula/iputil" + "net/netip" ) //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare -type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte) +// TODO: IPV6-WORK this can likely be removed now +type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte)
udp/udp_all.go+0 −100 removed@@ -1,100 +0,0 @@ -package udp - -import ( - "encoding/json" - "fmt" - "net" - "strconv" -) - -type m map[string]interface{} - -type Addr struct { - IP net.IP - Port uint16 -} - -func NewAddr(ip net.IP, port uint16) *Addr { - addr := Addr{IP: make([]byte, net.IPv6len), Port: port} - copy(addr.IP, ip.To16()) - return &addr -} - -func NewAddrFromString(s string) *Addr { - ip, port, err := ParseIPAndPort(s) - //TODO: handle err - _ = err - return &Addr{IP: ip.To16(), Port: port} -} - -func (ua *Addr) Equals(t *Addr) bool { - if t == nil || ua == nil { - return t == nil && ua == nil - } - return ua.IP.Equal(t.IP) && ua.Port == t.Port -} - -func (ua *Addr) String() string { - if ua == nil { - return "<nil>" - } - - return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port)) -} - -func (ua *Addr) MarshalJSON() ([]byte, error) { - if ua == nil { - return nil, nil - } - - return json.Marshal(m{"ip": ua.IP, "port": ua.Port}) -} - -func (ua *Addr) Copy() *Addr { - if ua == nil { - return nil - } - - nu := Addr{ - Port: ua.Port, - IP: make(net.IP, len(ua.IP)), - } - - copy(nu.IP, ua.IP) - return &nu -} - -type AddrSlice []*Addr - -func (a AddrSlice) Equal(b AddrSlice) bool { - if len(a) != len(b) { - return false - } - - for i := range a { - if !a[i].Equals(b[i]) { - return false - } - } - - return true -} - -func ParseIPAndPort(s string) (net.IP, uint16, error) { - rIp, sPort, err := net.SplitHostPort(s) - if err != nil { - return nil, 0, err - } - - addr, err := net.ResolveIPAddr("ip", rIp) - if err != nil { - return nil, 0, err - } - - iPort, err := strconv.Atoi(sPort) - if err != nil { - return nil, 0, err - } - - return addr.IP, uint16(iPort), nil -}
udp/udp_android.go+2 −1 modified@@ -6,13 +6,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) }
udp/udp_bsd.go+2 −1 modified@@ -9,13 +9,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) }
udp/udp_darwin.go+2 −1 modified@@ -8,13 +8,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) }
udp/udp_generic.go+23 −14 modified@@ -11,6 +11,7 @@ import ( "context" "fmt" "net" + "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -25,7 +26,7 @@ type GenericConn struct { var _ Conn = &GenericConn{} -func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { @@ -37,23 +38,24 @@ func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc) } -func (u *GenericConn) WriteTo(b []byte, addr *Addr) error { - _, err := u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)}) +func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error { + _, err := u.UDPConn.WriteToUDPAddrPort(b, addr) return err } -func (u *GenericConn) LocalAddr() (*Addr, error) { +func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() switch v := a.(type) { case *net.UDPAddr: - addr := &Addr{IP: make([]byte, len(v.IP))} - copy(addr.IP, v.IP) - addr.Port = uint16(v.Port) - return addr, nil + addr, ok := netip.AddrFromSlice(v.IP) + if !ok { + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP) + } + return netip.AddrPortFrom(addr, uint16(v.Port)), nil default: - return nil, fmt.Errorf("LocalAddr returned: %#v", a) + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a) } } @@ -75,19 +77,26 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f buffer := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - udpAddr := &Addr{IP: make([]byte, 16)} nb := make([]byte, 12, 12) for { // Just read one packet at a time - n, rua, err := u.ReadFromUDP(buffer) + 
n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return } - udpAddr.IP = rua.IP - udpAddr.Port = uint16(rua.Port) - r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r( + netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), + plaintext[:0], + buffer[:n], + h, + fwPacket, + lhf, + nb, + q, + cache.Get(u.l), + ) } }
udp/udp_linux.go+43 −32 modified@@ -7,6 +7,7 @@ import ( "encoding/binary" "fmt" "net" + "net/netip" "syscall" "unsafe" @@ -35,10 +36,9 @@ func maybeIPV4(ip net.IP) (net.IP, bool) { return ip, false } -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { - ipV4, isV4 := maybeIPV4(ip) +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { af := unix.AF_INET6 - if isV4 { + if ip.Is4() { af = unix.AF_INET } syscall.ForkLock.RLock() @@ -61,13 +61,13 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) ( //TODO: support multiple listening IPs (for limiting ipv6) var sa unix.Sockaddr - if isV4 { + if ip.Is4() { sa4 := &unix.SockaddrInet4{Port: port} - copy(sa4.Addr[:], ipV4) + sa4.Addr = ip.As4() sa = sa4 } else { sa6 := &unix.SockaddrInet6{Port: port} - copy(sa6.Addr[:], ip.To16()) + sa6.Addr = ip.As16() sa = sa6 } if err = unix.Bind(fd, sa); err != nil { @@ -79,7 +79,7 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) ( //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) //l.Println(v, err) - return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err + return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err } func (u *StdConn) Rebind() error { @@ -102,30 +102,29 @@ func (u *StdConn) GetSendBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) } -func (u *StdConn) LocalAddr() (*Addr, error) { +func (u *StdConn) LocalAddr() (netip.AddrPort, error) { sa, err := unix.Getsockname(u.sysFd) if err != nil { - return nil, err + return netip.AddrPort{}, err } - addr := &Addr{} switch sa := sa.(type) { case *unix.SockaddrInet4: - addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16() - addr.Port = uint16(sa.Port) + return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil + case *unix.SockaddrInet6: - addr.IP = 
sa.Addr[0:] - addr.Port = uint16(sa.Port) - } + return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil - return addr, nil + default: + return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa) + } } func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { plaintext := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - udpAddr := &Addr{} + var ip netip.Addr nb := make([]byte, 12, 12) //TODO: should we track this? @@ -146,12 +145,23 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew //metric.Update(int64(n)) for i := 0; i < n; i++ { if u.isV4 { - udpAddr.IP = names[i][4:8] + ip, _ = netip.AddrFromSlice(names[i][4:8]) + //TODO: IPV6-WORK what is not ok? } else { - udpAddr.IP = names[i][8:24] + ip, _ = netip.AddrFromSlice(names[i][8:24]) + //TODO: IPV6-WORK what is not ok? } - udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) - r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r( + netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), + plaintext[:0], + buffers[i][:msgs[i].Len], + h, + fwPacket, + lhf, + nb, + q, + cache.Get(u.l), + ) } } } @@ -197,19 +207,20 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { } } -func (u *StdConn) WriteTo(b []byte, addr *Addr) error { +func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { if u.isV4 { - return u.writeTo4(b, addr) + return u.writeTo4(b, ip) } - return u.writeTo6(b, addr) + return u.writeTo6(b, ip) } -func (u *StdConn) writeTo6(b []byte, addr *Addr) error { +func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 + rsa.Addr = ip.Addr().As16() + port := ip.Port() // Little Endian -> Network Endian - rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) - copy(rsa.Addr[:], addr.IP.To16()) + rsa.Port = (port >> 8) | ((port & 0xff) << 
8) for { _, _, err := unix.Syscall6( @@ -232,17 +243,17 @@ func (u *StdConn) writeTo6(b []byte, addr *Addr) error { } } -func (u *StdConn) writeTo4(b []byte, addr *Addr) error { - addrV4, isAddrV4 := maybeIPV4(addr.IP) - if !isAddrV4 { +func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { + if !ip.Addr().Is4() { return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") } var rsa unix.RawSockaddrInet4 rsa.Family = unix.AF_INET + rsa.Addr = ip.Addr().As4() + port := ip.Port() // Little Endian -> Network Endian - rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) - copy(rsa.Addr[:], addrV4) + rsa.Port = (port >> 8) | ((port & 0xff) << 8) for { _, _, err := unix.Syscall6(
udp/udp_netbsd.go+2 −1 modified@@ -8,13 +8,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) }
udp/udp_rio_windows.go+22 −21 modified@@ -10,6 +10,7 @@ import ( "fmt" "io" "net" + "net/netip" "sync" "sync/atomic" "syscall" @@ -61,16 +62,14 @@ type RIOConn struct { results [packetsPerRing]winrio.Result } -func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) { +func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) { if !winrio.Initialize() { return nil, errors.New("could not initialize winrio") } u := &RIOConn{l: l} - addr := [16]byte{} - copy(addr[:], ip.To16()) - err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port}) + err := u.bind(&windows.SockaddrInet6{Addr: addr.As16(), Port: port}) if err != nil { return nil, fmt.Errorf("bind: %w", err) } @@ -124,7 +123,6 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew buffer := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - udpAddr := &Addr{IP: make([]byte, 16)} nb := make([]byte, 12, 12) for { @@ -135,11 +133,17 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - udpAddr.IP = rua.Addr[:] - p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port)) - p[0] = byte(rua.Port >> 8) - p[1] = byte(rua.Port) - r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r( + netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), + plaintext[:0], + buffer[:n], + h, + fwPacket, + lhf, + nb, + q, + cache.Get(u.l), + ) } } @@ -231,7 +235,7 @@ retry: return n, ep, nil } -func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { +func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error { if !u.isOpen.Load() { return net.ErrClosed } @@ -274,10 +278,9 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { packet := u.tx.Push() packet.addr.Family = windows.AF_INET6 - p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port)) - p[0] = byte(addr.Port >> 8) - p[1] = byte(addr.Port) - copy(packet.addr.Addr[:], addr.IP.To16()) + 
packet.addr.Addr = ip.Addr().As16() + port := ip.Port() + packet.addr.Port = (port >> 8) | ((port & 0xff) << 8) copy(packet.data[:], buf) dataBuffer := &winrio.Buffer{ @@ -295,17 +298,15 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } -func (u *RIOConn) LocalAddr() (*Addr, error) { +func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { sa, err := windows.Getsockname(u.sock) if err != nil { - return nil, err + return netip.AddrPort{}, err } v6 := sa.(*windows.SockaddrInet6) - return &Addr{ - IP: v6.Addr[:], - Port: uint16(v6.Port), - }, nil + return netip.AddrPortFrom(netip.AddrFrom16(v6.Addr).Unmap(), uint16(v6.Port)), nil + } func (u *RIOConn) Rebind() error {
udp/udp_tester.go+17 −32 modified@@ -4,9 +4,8 @@ package udp import ( - "fmt" "io" - "net" + "net/netip" "sync/atomic" "github.com/sirupsen/logrus" @@ -16,30 +15,24 @@ import ( ) type Packet struct { - ToIp net.IP - ToPort uint16 - FromIp net.IP - FromPort uint16 - Data []byte + To netip.AddrPort + From netip.AddrPort + Data []byte } func (u *Packet) Copy() *Packet { n := &Packet{ - ToIp: make(net.IP, len(u.ToIp)), - ToPort: u.ToPort, - FromIp: make(net.IP, len(u.FromIp)), - FromPort: u.FromPort, - Data: make([]byte, len(u.Data)), + To: u.To, + From: u.From, + Data: make([]byte, len(u.Data)), } - copy(n.ToIp, u.ToIp) - copy(n.FromIp, u.FromIp) copy(n.Data, u.Data) return n } type TesterConn struct { - Addr *Addr + Addr netip.AddrPort RxPackets chan *Packet // Packets to receive into nebula TxPackets chan *Packet // Packets transmitted outside by nebula @@ -48,9 +41,9 @@ type TesterConn struct { l *logrus.Logger } -func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { return &TesterConn{ - Addr: &Addr{ip, uint16(port)}, + Addr: netip.AddrPortFrom(ip, uint16(port)), RxPackets: make(chan *Packet, 10), TxPackets: make(chan *Packet, 10), l: l, @@ -71,7 +64,7 @@ func (u *TesterConn) Send(packet *Packet) { } if u.l.Level >= logrus.DebugLevel { u.l.WithField("header", h). - WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)). + WithField("udpAddr", packet.From). WithField("dataLen", len(packet.Data)). 
Debug("UDP receiving injected packet") } @@ -98,23 +91,18 @@ func (u *TesterConn) Get(block bool) *Packet { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (u *TesterConn) WriteTo(b []byte, addr *Addr) error { +func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { if u.closed.Load() { return io.ErrClosedPipe } p := &Packet{ - Data: make([]byte, len(b), len(b)), - FromIp: make([]byte, 16), - FromPort: u.Addr.Port, - ToIp: make([]byte, 16), - ToPort: addr.Port, + Data: make([]byte, len(b), len(b)), + From: u.Addr, + To: addr, } copy(p.Data, b) - copy(p.ToIp, addr.IP.To16()) - copy(p.FromIp, u.Addr.IP.To16()) - u.TxPackets <- p return nil } @@ -123,17 +111,14 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi plaintext := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - ua := &Addr{IP: make([]byte, 16)} nb := make([]byte, 12, 12) for { p, ok := <-u.RxPackets if !ok { return } - ua.Port = p.FromPort - copy(ua.IP, p.FromIp.To16()) - r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } @@ -144,7 +129,7 @@ func NewUDPStatsEmitter(_ []Conn) func() { return func() {} } -func (u *TesterConn) LocalAddr() (*Addr, error) { +func (u *TesterConn) LocalAddr() (netip.AddrPort, error) { return u.Addr, nil }
udp/udp_windows.go+2 −1 modified@@ -6,12 +6,13 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { if multi { //NOTE: Technically we can support it with RIO but it wouldn't be at the socket level // The udp stack would need to be reworked to hide away the implementation differences between
Vulnerability mechanics
Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
References (5)
News mentions (0)
No linked articles in our index yet.