Skip to content

Commit

Permalink
update EncReader and EncWriter interface function args to have concre…
Browse files Browse the repository at this point in the history
…te types (#844)

* Update LightHouseHandlerFunc to remove EncWriter param.
* Move EncWriter to interface
* EncReader, too
  • Loading branch information
brad-defined authored Apr 7, 2023
1 parent 3cb4e0e commit 9b03053
Show file tree
Hide file tree
Showing 14 changed files with 69 additions and 54 deletions.
2 changes: 1 addition & 1 deletion handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"github.com/slackhq/nebula/udp"
)

func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H, hostinfo *HostInfo) {
func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H, hostinfo *HostInfo) {
// First remote allow list check before we know the vpnIp
if addr != nil {
if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
Expand Down
25 changes: 11 additions & 14 deletions handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
hostinfo.handshakeStart = time.Now()
}

func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H) {
func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)
Expand Down Expand Up @@ -240,14 +240,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
}
return
} else {
via2 := via.(*ViaSender)
if via2 == nil {
if via == nil {
f.l.Error("Handshake send failed: both addr and via are nil.")
return
}
hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via2.relayHI.vpnIp).
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
return
Expand Down Expand Up @@ -315,14 +314,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
Info("Handshake message sent")
}
} else {
via2 := via.(*ViaSender)
if via2 == nil {
if via == nil {
f.l.Error("Handshake send failed: both addr and via are nil.")
return
}
hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", vpnIp).WithField("relay", via2.relayHI.vpnIp).
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
Expand All @@ -338,7 +336,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
return
}

func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *HostInfo, packet []byte, h *header.H) bool {
func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *HostInfo, packet []byte, h *header.H) bool {
if hostinfo == nil {
// Nothing here to tear down, got a bogus stage 2 packet
return true
Expand Down Expand Up @@ -482,8 +480,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *
if addr != nil {
hostinfo.SetRemote(addr)
} else {
via2 := via.(*ViaSender)
hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
}

// Build up the radix for the firewall if we have subnets in the cert
Expand Down
6 changes: 3 additions & 3 deletions handshake_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
}
}

func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
clockSource := time.NewTicker(c.config.tryInterval)
defer clockSource.Stop()

Expand All @@ -89,7 +89,7 @@ func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
}
}

func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) {
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
c.OutboundHandshakeTimer.Advance(now)
for {
vpnIp, has := c.OutboundHandshakeTimer.Purge()
Expand All @@ -100,7 +100,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.E
}
}

func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) {
func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) {
hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
if err != nil {
return
Expand Down
2 changes: 1 addition & 1 deletion handshake_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
return
}

func (mw *mockEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) {
func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
return
}

Expand Down
7 changes: 2 additions & 5 deletions inside.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,16 +248,13 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C
// nb is a buffer used to store the nonce value, re-used for performance reasons.
// out is a buffer used to store the result of the Encrypt operation
// q indicates which writer to use to send the packet.
func (f *Interface) SendVia(viaIfc interface{},
relayIfc interface{},
func (f *Interface) SendVia(via *HostInfo,
relay *Relay,
ad,
nb,
out []byte,
nocopy bool,
) {
via := viaIfc.(*HostInfo)
relay := relayIfc.(*Relay)

if noiseutil.EncryptLockNeeded {
// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
via.ConnectionState.writeLock.Lock()
Expand Down
15 changes: 14 additions & 1 deletion interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp"
Expand Down Expand Up @@ -89,6 +90,18 @@ type Interface struct {
l *logrus.Logger
}

type EncWriter interface {
SendVia(via *HostInfo,
relay *Relay,
ad,
nb,
out []byte,
nocopy bool,
)
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
Handshake(vpnIp iputil.VpnIp)
}

type sendRecvErrorConfig uint8

const (
Expand Down Expand Up @@ -238,7 +251,7 @@ func (f *Interface) listenOut(i int) {

lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i)
li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
}

func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
Expand Down
22 changes: 14 additions & 8 deletions lighthouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ type LightHouse struct {
interval atomic.Int64
updateCancel context.CancelFunc
updateParentCtx context.Context
updateUdp udp.EncWriter
updateUdp EncWriter
nebulaPort uint32 // 32 bits because protobuf does not have a uint16

advertiseAddrs atomic.Pointer[[]netIpAndPort]
Expand Down Expand Up @@ -382,7 +382,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
return nil
}

func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList {
if !lh.IsLighthouseIP(ip) {
lh.QueryServer(ip, f)
}
Expand All @@ -396,7 +396,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
}

// This is asynchronous so no reply should be expected
func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) {
func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f EncWriter) {
if lh.amLighthouse {
return
}
Expand Down Expand Up @@ -629,7 +629,7 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
}

func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
lh.updateParentCtx = ctx
lh.updateUdp = f

Expand All @@ -655,7 +655,7 @@ func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
}
}

func (lh *LightHouse) SendUpdate(f udp.EncWriter) {
func (lh *LightHouse) SendUpdate(f EncWriter) {
var v4 []*Ip4AndPort
var v6 []*Ip6AndPort

Expand Down Expand Up @@ -760,7 +760,13 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
return lhh.meta
}

func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) {
func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc {
return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) {
lhh.HandleRequest(rAddr, vpnIp, p, f)
}
}

func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) {
n := lhh.resetMeta()
err := n.Unmarshal(p)
if err != nil {
Expand Down Expand Up @@ -795,7 +801,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp,
}
}

func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) {
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) {
// Exit if we don't answer queries
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
Expand Down Expand Up @@ -928,7 +934,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
am.Unlock()
}

func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) {
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
if !lhh.lh.IsLighthouseIP(vpnIp) {
return
}
Expand Down
2 changes: 1 addition & 1 deletion lighthouse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ type testEncWriter struct {
metaFilter *NebulaMeta_MessageType
}

func (tw *testEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) {
func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
}
func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) {
}
Expand Down
20 changes: 18 additions & 2 deletions outside.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,23 @@ const (
minFwPacketLen = 4
)

func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
func readOutsidePackets(f *Interface) udp.EncReader {
return func(
addr *udp.Addr,
out []byte,
packet []byte,
header *header.H,
fwPacket *firewall.Packet,
lhh udp.LightHouseHandlerFunc,
nb []byte,
q int,
localCache firewall.ConntrackCache,
) {
f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
}
}

func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// TODO: best if we return this and let caller log
Expand Down Expand Up @@ -149,7 +165,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by
return
}

lhf(addr, hostinfo.vpnIp, d, f)
lhf(addr, hostinfo.vpnIp, d)

// Fallthrough to the bottom to record incoming traffic

Expand Down
1 change: 0 additions & 1 deletion udp/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ const MTU = 9001

type EncReader func(
addr *Addr,
via interface{},
out []byte,
packet []byte,
header *header.H,
Expand Down
15 changes: 1 addition & 14 deletions udp/temp.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,9 @@
package udp

import (
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
)

type EncWriter interface {
SendVia(via interface{},
relay interface{},
ad,
nb,
out []byte,
nocopy bool,
)
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
Handshake(vpnIp iputil.VpnIp)
}

//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare

type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter)
type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte)
2 changes: 1 addition & 1 deletion udp/udp_generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,6 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall

udpAddr.IP = rua.IP
udpAddr.Port = uint16(rua.Port)
r(udpAddr, nil, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
}
}
2 changes: 1 addition & 1 deletion udp/udp_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
for i := 0; i < n; i++ {
udpAddr.IP = names[i][8:24]
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
r(udpAddr, nil, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion udp/udp_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
}
ua.Port = p.FromPort
copy(ua.IP, p.FromIp.To16())
r(ua, nil, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
}
}

Expand Down

0 comments on commit 9b03053

Please sign in to comment.