diff --git a/augmentedtree/atree.go b/augmentedtree/atree.go index eaa66aa..d90e3ea 100644 --- a/augmentedtree/atree.go +++ b/augmentedtree/atree.go @@ -116,21 +116,21 @@ type tree struct { dummy node } -func (t *tree) Traverse(fn func(id Interval)) { - nodes := []*node{t.root} - - for len(nodes) != 0 { - c := nodes[len(nodes)-1] - nodes = nodes[:len(nodes)-1] +func (t *tree) Traverse(fn func(id Interval) IterationResult) { + c := t.root + nodes := []*node{} + for len(nodes) > 0 || c != nil { if c != nil { - fn(c.interval) - if c.children[0] != nil { - nodes = append(nodes, c.children[0]) - } - if c.children[1] != nil { - nodes = append(nodes, c.children[1]) - } + nodes = append(nodes, c) + c = c.children[0] + continue + } + c = nodes[len(nodes)-1] + nodes = nodes[:len(nodes)-1] + if !fn(c.interval) { + break } + c = c.children[1] } } diff --git a/augmentedtree/atree_test.go b/augmentedtree/atree_test.go index 7c96e67..b172a7c 100644 --- a/augmentedtree/atree_test.go +++ b/augmentedtree/atree_test.go @@ -626,8 +626,9 @@ func TestInsertDuplicateIntervalChildren(t *testing.T) { func TestTraverse(t *testing.T) { tree := newTree(1) - tree.Traverse(func(i Interval) { + tree.Traverse(func(i Interval) IterationResult { assert.Fail(t, `traverse should not be called for empty tree`) + return IterationBreak }) top := 30 @@ -635,8 +636,9 @@ func TestTraverse(t *testing.T) { tree.Add(constructSingleDimensionInterval(int64(i*10), int64((i+1)*10), uint64(i))) } found := map[uint64]bool{} - tree.Traverse(func(id Interval) { + tree.Traverse(func(id Interval) IterationResult { found[id.ID()] = true + return IterationContinue }) for i := 0; i <= top; i++ { if found, _ := found[uint64(i)]; !found { diff --git a/augmentedtree/interface.go b/augmentedtree/interface.go index 3cd4a37..8e381c3 100644 --- a/augmentedtree/interface.go +++ b/augmentedtree/interface.go @@ -54,6 +54,13 @@ type Interval interface { ID() uint64 } +type IterationResult bool + +const ( + IterationContinue IterationResult = true + IterationBreak IterationResult = false +) + // Tree defines the object that is returned from the // tree constructor. We use a Tree interface here because // the returned tree could be a single dimension or many @@ -71,5 +78,5 @@ type Tree interface { Query(interval Interval) Intervals // Traverse will traverse tree and give alls intervals // found in an undefined order - Traverse(func(Interval)) + Traverse(func(Interval) IterationResult) }