Skip to content

Commit

Permalink
Generic timerwheel (#804)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus authored Jan 18, 2023
1 parent c177126 commit 5278b6f
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 431 deletions.
24 changes: 10 additions & 14 deletions connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ type connectionManager struct {
inLock *sync.RWMutex
out map[iputil.VpnIp]struct{}
outLock *sync.RWMutex
TrafficTimer *SystemTimerWheel
TrafficTimer *LockingTimerWheel[iputil.VpnIp]
intf *Interface

pendingDeletion map[iputil.VpnIp]int
pendingDeletionLock *sync.RWMutex
pendingDeletionTimer *SystemTimerWheel
pendingDeletionTimer *LockingTimerWheel[iputil.VpnIp]

checkInterval int
pendingDeletionInterval int
Expand All @@ -40,11 +40,11 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
inLock: &sync.RWMutex{},
out: make(map[iputil.VpnIp]struct{}),
outLock: &sync.RWMutex{},
TrafficTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
TrafficTimer: NewLockingTimerWheel[iputil.VpnIp](time.Millisecond*500, time.Second*60),
intf: intf,
pendingDeletion: make(map[iputil.VpnIp]int),
pendingDeletionLock: &sync.RWMutex{},
pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
pendingDeletionTimer: NewLockingTimerWheel[iputil.VpnIp](time.Millisecond*500, time.Second*60),
checkInterval: checkInterval,
pendingDeletionInterval: pendingDeletionInterval,
l: l,
Expand Down Expand Up @@ -160,15 +160,13 @@ func (n *connectionManager) Run(ctx context.Context) {
}

func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) {
n.TrafficTimer.advance(now)
n.TrafficTimer.Advance(now)
for {
ep := n.TrafficTimer.Purge()
if ep == nil {
vpnIp, has := n.TrafficTimer.Purge()
if !has {
break
}

vpnIp := ep.(iputil.VpnIp)

// Check for traffic coming back in from this host.
traf := n.CheckIn(vpnIp)

Expand Down Expand Up @@ -214,15 +212,13 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
}

func (n *connectionManager) HandleDeletionTick(now time.Time) {
n.pendingDeletionTimer.advance(now)
n.pendingDeletionTimer.Advance(now)
for {
ep := n.pendingDeletionTimer.Purge()
if ep == nil {
vpnIp, has := n.pendingDeletionTimer.Purge()
if !has {
break
}

vpnIp := ep.(iputil.VpnIp)

hostinfo, err := n.hostMap.QueryVpnIp(vpnIp)
if err != nil {
n.l.Debugf("Not found in hostmap: %s", vpnIp)
Expand Down
6 changes: 4 additions & 2 deletions firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ type FirewallConntrack struct {
sync.Mutex

Conns map[firewall.Packet]*conn
TimerWheel *TimerWheel
TimerWheel *TimerWheel[firewall.Packet]
}

type FirewallTable struct {
Expand Down Expand Up @@ -145,7 +145,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
return &Firewall{
Conntrack: &FirewallConntrack{
Conns: make(map[firewall.Packet]*conn),
TimerWheel: NewTimerWheel(min, max),
TimerWheel: NewTimerWheel[firewall.Packet](min, max),
},
InRules: newFirewallTable(),
OutRules: newFirewallTable(),
Expand Down Expand Up @@ -510,6 +510,7 @@ func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) {
conntrack := f.Conntrack
conntrack.Lock()
if _, ok := conntrack.Conns[fp]; !ok {
conntrack.TimerWheel.Advance(time.Now())
conntrack.TimerWheel.Add(fp, timeout)
}

Expand Down Expand Up @@ -537,6 +538,7 @@ func (f *Firewall) evict(p firewall.Packet) {

// Timeout is in the future, re-add the timer
if newT > 0 {
conntrack.TimerWheel.Advance(time.Now())
conntrack.TimerWheel.Add(p, newT)
return
}
Expand Down
11 changes: 5 additions & 6 deletions handshake_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type HandshakeManager struct {
lightHouse *LightHouse
outside *udp.Conn
config HandshakeConfig
OutboundHandshakeTimer *SystemTimerWheel
OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp]
messageMetrics *MessageMetrics
metricInitiated metrics.Counter
metricTimedOut metrics.Counter
Expand All @@ -65,7 +65,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
outside: outside,
config: config,
trigger: make(chan iputil.VpnIp, config.triggerBuffer),
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
messageMetrics: config.messageMetrics,
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
Expand All @@ -90,13 +90,12 @@ func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
}

func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) {
c.OutboundHandshakeTimer.advance(now)
c.OutboundHandshakeTimer.Advance(now)
for {
ep := c.OutboundHandshakeTimer.Purge()
if ep == nil {
vpnIp, has := c.OutboundHandshakeTimer.Purge()
if !has {
break
}
vpnIp := ep.(iputil.VpnIp)
c.handleOutbound(vpnIp, f, false)
}
}
Expand Down
4 changes: 2 additions & 2 deletions handshake_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
}

func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
for _, i := range tw.wheel {
func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
for _, i := range tw.t.wheel {
n := i.Head
for n != nil {
c++
Expand Down
95 changes: 63 additions & 32 deletions timeout.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
package nebula

import (
"sync"
"time"

"github.com/slackhq/nebula/firewall"
)

// How many timer objects should be cached
const timerCacheMax = 50000

var emptyFWPacket = firewall.Packet{}

type TimerWheel struct {
type TimerWheel[T any] struct {
// Current tick
current int

Expand All @@ -26,31 +23,38 @@ type TimerWheel struct {
wheelDuration time.Duration

// The actual wheel which is just a set of singly linked lists, head/tail pointers
wheel []*TimeoutList
wheel []*TimeoutList[T]

// Singly linked list of items that have timed out of the wheel
expired *TimeoutList
expired *TimeoutList[T]

// Item cache to avoid garbage collect
itemCache *TimeoutItem
itemCache *TimeoutItem[T]
itemsCached int
}

type LockingTimerWheel[T any] struct {
m sync.Mutex
t *TimerWheel[T]
}

// TimeoutList Represents a tick in the wheel
type TimeoutList struct {
Head *TimeoutItem
Tail *TimeoutItem
type TimeoutList[T any] struct {
Head *TimeoutItem[T]
Tail *TimeoutItem[T]
}

// TimeoutItem Represents an item within a tick
type TimeoutItem struct {
Packet firewall.Packet
Next *TimeoutItem
type TimeoutItem[T any] struct {
Item T
Next *TimeoutItem[T]
}

// NewTimerWheel Builds a timer wheel and identifies the tick duration and wheel duration from the provided values
// Purge must be called once per entry to actually remove anything
func NewTimerWheel(min, max time.Duration) *TimerWheel {
// The TimerWheel does not handle concurrency on its own.
// Locks around access to it must be used if multiple routines are manipulating it.
func NewTimerWheel[T any](min, max time.Duration) *TimerWheel[T] {
//TODO provide an error
//if min >= max {
// return nil
Expand All @@ -61,26 +65,31 @@ func NewTimerWheel(min, max time.Duration) *TimerWheel {
// timeout
wLen := int((max / min) + 2)

tw := TimerWheel{
tw := TimerWheel[T]{
wheelLen: wLen,
wheel: make([]*TimeoutList, wLen),
wheel: make([]*TimeoutList[T], wLen),
tickDuration: min,
wheelDuration: max,
expired: &TimeoutList{},
expired: &TimeoutList[T]{},
}

for i := range tw.wheel {
tw.wheel[i] = &TimeoutList{}
tw.wheel[i] = &TimeoutList[T]{}
}

return &tw
}

// Add will add a firewall.Packet to the wheel in it's proper timeout
func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem {
// Check and see if we should progress the tick
tw.advance(time.Now())
// NewLockingTimerWheel is version of TimerWheel that is safe for concurrent use with a small performance penalty
func NewLockingTimerWheel[T any](min, max time.Duration) *LockingTimerWheel[T] {
return &LockingTimerWheel[T]{
t: NewTimerWheel[T](min, max),
}
}

// Add will add an item to the wheel in its proper timeout.
// Caller should Advance the wheel prior to ensure the proper slot is used.
func (tw *TimerWheel[T]) Add(v T, timeout time.Duration) *TimeoutItem[T] {
i := tw.findWheel(timeout)

// Try to fetch off the cache
Expand All @@ -90,11 +99,11 @@ func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem
tw.itemsCached--
ti.Next = nil
} else {
ti = &TimeoutItem{}
ti = &TimeoutItem[T]{}
}

// Relink and return
ti.Packet = v
ti.Item = v
if tw.wheel[i].Tail == nil {
tw.wheel[i].Head = ti
tw.wheel[i].Tail = ti
Expand All @@ -106,9 +115,12 @@ func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem
return ti
}

func (tw *TimerWheel) Purge() (firewall.Packet, bool) {
// Purge removes and returns the first available expired item from the wheel and the 2nd argument is true.
// If no item is available then an empty T is returned and the 2nd argument is false.
func (tw *TimerWheel[T]) Purge() (T, bool) {
if tw.expired.Head == nil {
return emptyFWPacket, false
var na T
return na, false
}

ti := tw.expired.Head
Expand All @@ -128,11 +140,11 @@ func (tw *TimerWheel) Purge() (firewall.Packet, bool) {
tw.itemsCached++
}

return ti.Packet, true
return ti.Item, true
}

// advance will move the wheel forward by proper number of ticks. The caller _should_ lock the wheel before calling this
func (tw *TimerWheel) findWheel(timeout time.Duration) (i int) {
// findWheel find the next position in the wheel for the provided timeout given the current tick
func (tw *TimerWheel[T]) findWheel(timeout time.Duration) (i int) {
if timeout < tw.tickDuration {
// Can't track anything below the set resolution
timeout = tw.tickDuration
Expand All @@ -154,8 +166,9 @@ func (tw *TimerWheel) findWheel(timeout time.Duration) (i int) {
return tick
}

// advance will lock and move the wheel forward by proper number of ticks.
func (tw *TimerWheel) advance(now time.Time) {
// Advance will move the wheel forward by the appropriate number of ticks for the provided time and all items
// passed over will be moved to the expired list. Calling Purge is necessary to remove them entirely.
func (tw *TimerWheel[T]) Advance(now time.Time) {
if tw.lastTick == nil {
tw.lastTick = &now
}
Expand Down Expand Up @@ -192,3 +205,21 @@ func (tw *TimerWheel) advance(now time.Time) {
newTick := tw.lastTick.Add(tw.tickDuration * time.Duration(adv))
tw.lastTick = &newTick
}

func (lw *LockingTimerWheel[T]) Add(v T, timeout time.Duration) *TimeoutItem[T] {
lw.m.Lock()
defer lw.m.Unlock()
return lw.t.Add(v, timeout)
}

func (lw *LockingTimerWheel[T]) Purge() (T, bool) {
lw.m.Lock()
defer lw.m.Unlock()
return lw.t.Purge()
}

func (lw *LockingTimerWheel[T]) Advance(now time.Time) {
lw.m.Lock()
defer lw.m.Unlock()
lw.t.Advance(now)
}
Loading

0 comments on commit 5278b6f

Please sign in to comment.