Skip to content

Commit

Permalink
Fix possible panic in the timerwheels (#802)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus authored Jan 12, 2023
1 parent c44da3a commit c177126
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 34 deletions.
12 changes: 6 additions & 6 deletions firewall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,27 @@ func TestNewFirewall(t *testing.T) {

assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)

fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)

fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)

fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)

fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)

fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
}

func TestFirewall_AddRule(t *testing.T) {
Expand Down
13 changes: 7 additions & 6 deletions timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,30 @@ type TimerWheel struct {
itemsCached int
}

// Represents a tick in the wheel
// TimeoutList Represents a tick in the wheel
type TimeoutList struct {
Head *TimeoutItem
Tail *TimeoutItem
}

// Represents an item within a tick
// TimeoutItem Represents an item within a tick
type TimeoutItem struct {
Packet firewall.Packet
Next *TimeoutItem
}

// Builds a timer wheel and identifies the tick duration and wheel duration from the provided values
// NewTimerWheel Builds a timer wheel and identifies the tick duration and wheel duration from the provided values
// Purge must be called once per entry to actually remove anything
func NewTimerWheel(min, max time.Duration) *TimerWheel {
//TODO provide an error
//if min >= max {
// return nil
//}

// Round down and add 1 so we can have the smallest # of ticks in the wheel and still account for a full
// max duration
wLen := int((max / min) + 1)
// Round down and add 2 so we can have the smallest # of ticks in the wheel and still account for a full
// max duration, even if our current tick is at the maximum position and the next item to be added is at maximum
// timeout
wLen := int((max / min) + 2)

tw := TimerWheel{
wheelLen: wLen,
Expand Down
13 changes: 7 additions & 6 deletions timeout_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,30 @@ type SystemTimerWheel struct {
lock sync.Mutex
}

// Represents a tick in the wheel
// SystemTimeoutList Represents a tick in the wheel
type SystemTimeoutList struct {
Head *SystemTimeoutItem
Tail *SystemTimeoutItem
}

// Represents an item within a tick
// SystemTimeoutItem Represents an item within a tick
type SystemTimeoutItem struct {
Item iputil.VpnIp
Next *SystemTimeoutItem
}

// Builds a timer wheel and identifies the tick duration and wheel duration from the provided values
// NewSystemTimerWheel Builds a timer wheel and identifies the tick duration and wheel duration from the provided values
// Purge must be called once per entry to actually remove anything
func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel {
//TODO provide an error
//if min >= max {
// return nil
//}

// Round down and add 1 so we can have the smallest # of ticks in the wheel and still account for a full
// max duration
wLen := int((max / min) + 1)
// Round down and add 2 so we can have the smallest # of ticks in the wheel and still account for a full
// max duration, even if our current tick is at the maximum position and the next item to be added is at maximum
// timeout
wLen := int((max / min) + 2)

tw := SystemTimerWheel{
wheelLen: wLen,
Expand Down
37 changes: 29 additions & 8 deletions timeout_system_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,24 @@ import (
func TestNewSystemTimerWheel(t *testing.T) {
// Make sure we get an object we expect
tw := NewSystemTimerWheel(time.Second, time.Second*10)
assert.Equal(t, 11, tw.wheelLen)
assert.Equal(t, 12, tw.wheelLen)
assert.Equal(t, 0, tw.current)
assert.Nil(t, tw.lastTick)
assert.Equal(t, time.Second*1, tw.tickDuration)
assert.Equal(t, time.Second*10, tw.wheelDuration)
assert.Len(t, tw.wheel, 11)
assert.Len(t, tw.wheel, 12)

// Assert the math is correct
tw = NewSystemTimerWheel(time.Second*3, time.Second*10)
assert.Equal(t, 4, tw.wheelLen)
assert.Equal(t, 5, tw.wheelLen)

tw = NewSystemTimerWheel(time.Second*120, time.Minute*10)
assert.Equal(t, 6, tw.wheelLen)
assert.Equal(t, 7, tw.wheelLen)
}

func TestSystemTimerWheel_findWheel(t *testing.T) {
tw := NewSystemTimerWheel(time.Second, time.Second*10)
assert.Len(t, tw.wheel, 11)
assert.Len(t, tw.wheel, 12)

// Current + tick + 1 since we don't know how far into current we are
assert.Equal(t, 2, tw.findWheel(time.Second*1))
Expand All @@ -38,15 +38,32 @@ func TestSystemTimerWheel_findWheel(t *testing.T) {
assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))

// Make sure we hit that last index
assert.Equal(t, 0, tw.findWheel(time.Second*10))
assert.Equal(t, 11, tw.findWheel(time.Second*10))

// Scale down to max duration
assert.Equal(t, 0, tw.findWheel(time.Second*11))
assert.Equal(t, 11, tw.findWheel(time.Second*11))

tw.current = 1
// Make sure we account for the current position properly
assert.Equal(t, 3, tw.findWheel(time.Second*1))
assert.Equal(t, 1, tw.findWheel(time.Second*10))
assert.Equal(t, 0, tw.findWheel(time.Second*10))

// Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel
for min := time.Duration(1); min < 100; min++ {
for max := min; max < 100; max++ {
tw = NewSystemTimerWheel(min, max)

for current := 0; current < tw.wheelLen; current++ {
tw.current = current
for timeout := time.Duration(0); timeout <= tw.wheelDuration; timeout++ {
tick := tw.findWheel(timeout)
if tick >= tw.wheelLen {
t.Errorf("Min: %v; Max: %v; Wheel len: %v; Current Tick: %v; Insert timeout: %v; Calc tick: %v", min, max, tw.wheelLen, current, timeout, tick)
}
}
}
}
}
}

func TestSystemTimerWheel_Add(t *testing.T) {
Expand Down Expand Up @@ -129,6 +146,10 @@ func TestSystemTimerWheel_Purge(t *testing.T) {
tw.advance(ta)
assert.Equal(t, 10, tw.current)

ta = ta.Add(time.Second * 1)
tw.advance(ta)
assert.Equal(t, 11, tw.current)

ta = ta.Add(time.Second * 1)
tw.advance(ta)
assert.Equal(t, 0, tw.current)
Expand Down
37 changes: 29 additions & 8 deletions timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,24 @@ import (
func TestNewTimerWheel(t *testing.T) {
// Make sure we get an object we expect
tw := NewTimerWheel(time.Second, time.Second*10)
assert.Equal(t, 11, tw.wheelLen)
assert.Equal(t, 12, tw.wheelLen)
assert.Equal(t, 0, tw.current)
assert.Nil(t, tw.lastTick)
assert.Equal(t, time.Second*1, tw.tickDuration)
assert.Equal(t, time.Second*10, tw.wheelDuration)
assert.Len(t, tw.wheel, 11)
assert.Len(t, tw.wheel, 12)

// Assert the math is correct
tw = NewTimerWheel(time.Second*3, time.Second*10)
assert.Equal(t, 4, tw.wheelLen)
assert.Equal(t, 5, tw.wheelLen)

tw = NewTimerWheel(time.Second*120, time.Minute*10)
assert.Equal(t, 6, tw.wheelLen)
assert.Equal(t, 7, tw.wheelLen)
}

func TestTimerWheel_findWheel(t *testing.T) {
tw := NewTimerWheel(time.Second, time.Second*10)
assert.Len(t, tw.wheel, 11)
assert.Len(t, tw.wheel, 12)

// Current + tick + 1 since we don't know how far into current we are
assert.Equal(t, 2, tw.findWheel(time.Second*1))
Expand All @@ -37,15 +37,15 @@ func TestTimerWheel_findWheel(t *testing.T) {
assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))

// Make sure we hit that last index
assert.Equal(t, 0, tw.findWheel(time.Second*10))
assert.Equal(t, 11, tw.findWheel(time.Second*10))

// Scale down to max duration
assert.Equal(t, 0, tw.findWheel(time.Second*11))
assert.Equal(t, 11, tw.findWheel(time.Second*11))

tw.current = 1
// Make sure we account for the current position properly
assert.Equal(t, 3, tw.findWheel(time.Second*1))
assert.Equal(t, 1, tw.findWheel(time.Second*10))
assert.Equal(t, 0, tw.findWheel(time.Second*10))
}

func TestTimerWheel_Add(t *testing.T) {
Expand Down Expand Up @@ -75,6 +75,23 @@ func TestTimerWheel_Add(t *testing.T) {
tw.Add(fp2, time.Second*1)
assert.Nil(t, tw.itemCache)
assert.Equal(t, 0, tw.itemsCached)

// Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel
for min := time.Duration(1); min < 100; min++ {
for max := min; max < 100; max++ {
tw = NewTimerWheel(min, max)

for current := 0; current < tw.wheelLen; current++ {
tw.current = current
for timeout := time.Duration(0); timeout <= tw.wheelDuration; timeout++ {
tick := tw.findWheel(timeout)
if tick >= tw.wheelLen {
t.Errorf("Min: %v; Max: %v; Wheel len: %v; Current Tick: %v; Insert timeout: %v; Calc tick: %v", min, max, tw.wheelLen, current, timeout, tick)
}
}
}
}
}
}

func TestTimerWheel_Purge(t *testing.T) {
Expand Down Expand Up @@ -134,6 +151,10 @@ func TestTimerWheel_Purge(t *testing.T) {
tw.advance(ta)
assert.Equal(t, 10, tw.current)

ta = ta.Add(time.Second * 1)
tw.advance(ta)
assert.Equal(t, 11, tw.current)

ta = ta.Add(time.Second * 1)
tw.advance(ta)
assert.Equal(t, 0, tw.current)
Expand Down

0 comments on commit c177126

Please sign in to comment.