Skip to content

Commit

Permalink
Fix relay (#827)
Browse files Browse the repository at this point in the history
Co-authored-by: Nate Brown <[email protected]>
  • Loading branch information
brad-defined and nbrownus authored Mar 30, 2023
1 parent e28336c commit 2801fb2
Show file tree
Hide file tree
Showing 12 changed files with 276 additions and 103 deletions.
8 changes: 7 additions & 1 deletion connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
break
}

hostinfo, err := n.hostMap.QueryIndex(localIndex)
hostinfo, mainHostInfo, err := n.hostMap.QueryIndexIsPrimary(localIndex)
if err != nil {
n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
n.ClearLocalIndex(localIndex)
Expand All @@ -269,6 +269,12 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {

n.ClearLocalIndex(localIndex)
n.ClearPendingDeletion(localIndex)

if !mainHostInfo {
// This hostinfo is still being used despite not being the primary hostinfo for this vpn ip
// Keep tracking so that we can tear it down when it goes away
n.Out(localIndex)
}
continue
}

Expand Down
2 changes: 1 addition & 1 deletion control.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
hostInfos := []*HostInfo{}
// Grab the hostMap lock to access the Hosts map
c.f.hostMap.Lock()
for _, relayHost := range c.f.hostMap.Hosts {
for _, relayHost := range c.f.hostMap.Indexes {
if _, ok := relayingHosts[relayHost.vpnIp]; !ok {
hostInfos = append(hostInfos, relayHost)
}
Expand Down
118 changes: 92 additions & 26 deletions e2e/handshakes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"
"time"

"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/header"
Expand Down Expand Up @@ -393,47 +394,112 @@ func TestStage1RaceRelays(t *testing.T) {
relayControl.Start()
theirControl.Start()

r.Log("Trigger a handshake to start on both me and relay")
myControl.InjectTunUDPPacket(relayVpnIpNet.IP, 80, 80, []byte("Hi from me"))
relayControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from relay"))
r.Log("Get a tunnel between me and relay")
assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)

r.Log("Get both stage 1 handshake packets")
//TODO: this is where it breaks, we need to get the hs packets for the relay not for the destination
myHsForThem := myControl.GetFromUDP(true)
relayHsForMe := relayControl.GetFromUDP(true)
r.Log("Get a tunnel between them and relay")
assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)

r.Log("Now inject both stage 1 handshake packets")
r.InjectUDPPacket(relayControl, myControl, relayHsForMe)
r.InjectUDPPacket(myControl, relayControl, myHsForThem)
r.Log("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))

r.Log("Route for me until I send a message packet to relay")
r.RouteForAllUntilAfterMsgTypeTo(relayControl, header.Message, header.MessageNone)
r.Log("Wait for a packet from them to me")
p := r.RouteForAllUntilTxTun(myControl)
_ = p

r.Log("My cached packet should be received by relay")
myCachedPacket := relayControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, relayVpnIpNet.IP, 80, 80)
myControl.Stop()
theirControl.Stop()
relayControl.Stop()
//
////TODO: assert hostmaps
}

r.Log("Relays cached packet should be received by me")
relayCachedPacket := r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from relay"), relayCachedPacket, relayVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
func TestStage1RaceRelays2(t *testing.T) {
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
l := NewTestLogger()

// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)

myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})

r.Log("Do a bidirectional tunnel test; me and relay")
relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)

// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()

// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()

r.Log("Get a tunnel between me and relay")
l.Info("Get a tunnel between me and relay")
assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)

r.Log("Create a tunnel between relay and them")
r.Log("Get a tunnel between them and relay")
l.Info("Get a tunnel between them and relay")
assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)

r.RenderHostmaps("Starting hostmaps", myControl, relayControl, theirControl)
r.Log("Trigger a handshake from both them and me via relay to them and me")
l.Info("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))

//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)

r.Log("Trigger a handshake to start from me to them via the relay")
//TODO: if we initiate a handshake from me and then assert the tunnel it will cause a relay control race that can blow up
// this is a problem that exists on master today
//myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
r.Log("Wait for a packet from them to me")
l.Info("Wait for a packet from them to me; myControl")
r.RouteForAllUntilTxTun(myControl)
l.Info("Wait for a packet from them to me; theirControl")
r.RouteForAllUntilTxTun(theirControl)

r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)

t.Log("Wait until we remove extra tunnels")
l.Info("Wait until we remove extra tunnels")
l.WithFields(
logrus.Fields{
"myControl": len(myControl.GetHostmap().Indexes),
"theirControl": len(theirControl.GetHostmap().Indexes),
"relayControl": len(relayControl.GetHostmap().Indexes),
}).Info("Waiting for hostinfos to be removed...")
hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
retries := 60
for hostInfos > 6 && retries > 0 {
hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
l.WithFields(
logrus.Fields{
"myControl": len(myControl.GetHostmap().Indexes),
"theirControl": len(theirControl.GetHostmap().Indexes),
"relayControl": len(relayControl.GetHostmap().Indexes),
}).Info("Waiting for hostinfos to be removed...")
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
retries--
}

r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)

myControl.Stop()
theirControl.Stop()
relayControl.Stop()

//
////TODO: assert hostmaps
}
Expand Down
4 changes: 4 additions & 0 deletions e2e/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
"level": l.Level.String(),
},
"timers": m{
"pending_deletion_interval": 4,
"connection_alive_interval": 4,
},
}

if overrides != nil {
Expand Down
20 changes: 13 additions & 7 deletions e2e/router/hostmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,13 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)

hm := c.GetHostmap()
hm.RLock()
defer hm.RUnlock()

// Draw the vpn to index nodes
r += fmt.Sprintf("\t\tsubgraph %s.hosts[\"Hosts (vpn ip to index)\"]\n", clusterName)
for _, vpnIp := range sortedHosts(hm.Hosts) {
hosts := sortedHosts(hm.Hosts)
for _, vpnIp := range hosts {
hi := hm.Hosts[vpnIp]
r += fmt.Sprintf("\t\t\t%v.%v[\"%v\"]\n", clusterName, vpnIp, vpnIp)
lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, vpnIp, clusterName, hi.GetLocalIndex()))
Expand Down Expand Up @@ -94,12 +97,15 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {

// Draw the local index to relay or remote index nodes
r += fmt.Sprintf("\t\tsubgraph indexes.%s[\"Indexes (index to hostinfo)\"]\n", clusterName)
for _, idx := range sortedIndexes(hm.Indexes) {
hi := hm.Indexes[idx]
r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ")
globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
_ = hi
indexes := sortedIndexes(hm.Indexes)
for _, idx := range indexes {
hi, ok := hm.Indexes[idx]
if ok {
r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ")
globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
_ = hi
}
}
r += "\t\tend\n"

Expand Down
14 changes: 13 additions & 1 deletion handshake_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l
Error("Failed to marshal Control message to create relay")
} else {
f.SendMessageToVpnIp(header.Control, 0, *relay, msg, make([]byte, 12), make([]byte, mtu))
c.l.WithFields(logrus.Fields{
"relayFrom": iputil.VpnIp(c.lightHouse.myVpnIp),
"relayTarget": iputil.VpnIp(vpnIp),
"initiatorIdx": existingRelay.LocalIndex,
"vpnIp": *relay}).
Info("send CreateRelayRequest")
}
default:
hostinfo.logger(c.l).
Expand Down Expand Up @@ -267,6 +273,12 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l
Error("Failed to marshal Control message to create relay")
} else {
f.SendMessageToVpnIp(header.Control, 0, *relay, msg, make([]byte, 12), make([]byte, mtu))
c.l.WithFields(logrus.Fields{
"relayFrom": iputil.VpnIp(c.lightHouse.myVpnIp),
"relayTarget": iputil.VpnIp(vpnIp),
"initiatorIdx": idx,
"vpnIp": *relay}).
Info("send CreateRelayRequest")
}
}
}
Expand Down Expand Up @@ -333,7 +345,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
}

// Is this a newer handshake?
if existingHostInfo.lastHandshakeTime >= hostinfo.lastHandshakeTime {
if existingHostInfo.lastHandshakeTime >= hostinfo.lastHandshakeTime && !existingHostInfo.ConnectionState.initiator {
return existingHostInfo, ErrExistingHostInfo
}

Expand Down
82 changes: 66 additions & 16 deletions hostmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ type HostMap struct {
l *logrus.Logger
}

// For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay
// struct, make a copy of an existing value, edit the fileds in the copy, and
// then store a pointer to the new copy in both realyForBy* maps.
type RelayState struct {
sync.RWMutex

Expand Down Expand Up @@ -123,13 +126,43 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
func (rs *RelayState) RemoveRelay(localIdx uint32) (iputil.VpnIp, bool) {
rs.Lock()
defer rs.Unlock()
relay, ok := rs.relayForByIdx[localIdx]
r, ok := rs.relayForByIdx[localIdx]
if !ok {
return iputil.VpnIp(0), false
}
delete(rs.relayForByIdx, localIdx)
delete(rs.relayForByIp, relay.PeerIp)
return relay.PeerIp, true
delete(rs.relayForByIp, r.PeerIp)
return r.PeerIp, true
}

func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bool {
rs.Lock()
defer rs.Unlock()
r, ok := rs.relayForByIp[vpnIp]
if !ok {
return false
}
newRelay := *r
newRelay.State = Established
newRelay.RemoteIndex = remoteIdx
rs.relayForByIdx[r.LocalIndex] = &newRelay
rs.relayForByIp[r.PeerIp] = &newRelay
return true
}

func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Relay, bool) {
rs.Lock()
defer rs.Unlock()
r, ok := rs.relayForByIdx[localIdx]
if !ok {
return nil, false
}
newRelay := *r
newRelay.State = Established
newRelay.RemoteIndex = remoteIdx
rs.relayForByIdx[r.LocalIndex] = &newRelay
rs.relayForByIp[r.PeerIp] = &newRelay
return &newRelay, true
}

func (rs *RelayState) QueryRelayForByIp(vpnIp iputil.VpnIp) (*Relay, bool) {
Expand All @@ -145,6 +178,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
r, ok := rs.relayForByIdx[idx]
return r, ok
}

func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
rs.Lock()
defer rs.Unlock()
Expand Down Expand Up @@ -362,25 +396,27 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
hm.unlockedDeleteHostInfo(hostinfo)
hm.Unlock()

// And tear down all the relays going through this host
// And tear down all the relays going through this host, if final
for _, localIdx := range hostinfo.relayState.CopyRelayForIdxs() {
hm.RemoveRelay(localIdx)
}

// And tear down the relays this deleted hostInfo was using to be reached
teardownRelayIdx := []uint32{}
for _, relayIp := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, err := hm.QueryVpnIp(relayIp)
if err != nil {
hm.l.WithError(err).WithField("relay", relayIp).Info("Missing relay host in hostmap")
} else {
if r, ok := relayHostInfo.relayState.QueryRelayForByIp(hostinfo.vpnIp); ok {
teardownRelayIdx = append(teardownRelayIdx, r.LocalIndex)
if final {
// And tear down the relays this deleted hostInfo was using to be reached
teardownRelayIdx := []uint32{}
for _, relayIp := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, err := hm.QueryVpnIp(relayIp)
if err != nil {
hm.l.WithError(err).WithField("relay", relayIp).Info("Missing relay host in hostmap")
} else {
if r, ok := relayHostInfo.relayState.QueryRelayForByIp(hostinfo.vpnIp); ok {
teardownRelayIdx = append(teardownRelayIdx, r.LocalIndex)
}
}
}
}
for _, localIdx := range teardownRelayIdx {
hm.RemoveRelay(localIdx)
for _, localIdx := range teardownRelayIdx {
hm.RemoveRelay(localIdx)
}
}

return final
Expand Down Expand Up @@ -486,6 +522,20 @@ func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) {
return nil, errors.New("unable to find index")
}
}

// Retrieves a HostInfo by Index. Returns whether the HostInfo is primary at time of query.
// This helper exists so that the hostinfo.prev pointer can be read while the hostmap lock is held.
func (hm *HostMap) QueryIndexIsPrimary(index uint32) (*HostInfo, bool, error) {
//TODO: we probably just want to return bool instead of error, or at least a static error
hm.RLock()
if h, ok := hm.Indexes[index]; ok {
hm.RUnlock()
return h, h.prev == nil, nil
} else {
hm.RUnlock()
return nil, false, errors.New("unable to find index")
}
}
func (hm *HostMap) QueryRelayIndex(index uint32) (*HostInfo, error) {
//TODO: we probably just want to return bool instead of error, or at least a static error
hm.RLock()
Expand Down
4 changes: 2 additions & 2 deletions hostmap_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ func (i *HostInfo) GetRemoteIndex() uint32 {
return i.remoteIndexId
}

func (i *HostInfo) GetRelayState() RelayState {
return i.relayState
func (i *HostInfo) GetRelayState() *RelayState {
return &i.relayState
}
Loading

0 comments on commit 2801fb2

Please sign in to comment.