diff --git a/error.go b/error.go index 75b4771..6341501 100644 --- a/error.go +++ b/error.go @@ -18,6 +18,7 @@ var ( ErrDiscardedResponseWriter = errors.New("discarded response writer") ErrInvalidRedirectCode = errors.New("invalid redirect code") ErrNoClientIPStrategy = errors.New("no client ip strategy") + ErrConcurrentAccess = errors.New("concurrent access violation: multiple writes detected on tree") ) // RouteConflictError is a custom error type used to represent conflicts when diff --git a/fox_test.go b/fox_test.go index 7ab2bb4..a758fe0 100644 --- a/fox_test.go +++ b/fox_test.go @@ -3196,6 +3196,58 @@ func TestDataRace(t *testing.T) { wg.Wait() } +func TestTree_RaceDetector(t *testing.T) { + var wg sync.WaitGroup + start, wait := atomicSync() + var raceCount atomic.Uint32 + + tree := New().Tree() + + wg.Add(len(staticRoutes) * 3) + for _, rte := range staticRoutes { + go func() { + wait() + defer func() { + if v := recover(); v != nil { + raceCount.Add(1) + assert.ErrorIs(t, v.(error), ErrConcurrentAccess) + } + wg.Done() + }() + tree.insert(rte.method, rte.path, "", 0, &Route{path: rte.path}) + }() + + go func() { + wait() + defer func() { + if v := recover(); v != nil { + raceCount.Add(1) + assert.ErrorIs(t, v.(error), ErrConcurrentAccess) + } + wg.Done() + }() + tree.update(rte.method, rte.path, "", &Route{path: rte.path}) + }() + + go func() { + wait() + defer func() { + if v := recover(); v != nil { + raceCount.Add(1) + assert.ErrorIs(t, v.(error), ErrConcurrentAccess) + } + wg.Done() + }() + tree.remove(rte.method, rte.path, "") + }() + } + + time.Sleep(500 * time.Millisecond) + start() + wg.Wait() + assert.GreaterOrEqual(t, raceCount.Load(), uint32(1)) +} + func TestConcurrentRequestHandling(t *testing.T) { r := New() diff --git a/tree.go b/tree.go index f185b85..b1eee2c 100644 --- a/tree.go +++ b/tree.go @@ -38,6 +38,7 @@ type Tree struct { sync.Mutex maxParams atomic.Uint32 maxDepth atomic.Uint32 + race atomic.Uint32 } // Handle registers a new handler for the given method and path. On success, it returns the newly registered [Route]. @@ -220,6 +221,11 @@ func (t *Tree) Iter() Iter { // parseRoute before. func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, route *Route) error { // Note that we need a consistent view of the tree during the patching so search must imperatively be locked. + if !t.race.CompareAndSwap(0, 1) { + panic(ErrConcurrentAccess) + } + defer t.race.Store(0) + var rootNode *node nds := *t.nodes.Load() index := findRootNode(method, nds) @@ -394,6 +400,11 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, route *R // update is not safe for concurrent use. func (t *Tree) update(method string, path, catchAllKey string, route *Route) error { // Note that we need a consistent view of the tree during the patching so search must imperatively be locked. + if !t.race.CompareAndSwap(0, 1) { + panic(ErrConcurrentAccess) + } + defer t.race.Store(0) + nds := *t.nodes.Load() index := findRootNode(method, nds) if index < 0 { @@ -427,6 +438,12 @@ func (t *Tree) update(method string, path, catchAllKey string, route *Route) err // remove is not safe for concurrent use. func (t *Tree) remove(method, path, catchAllKey string) bool { + // Note that we need a consistent view of the tree during the patching so search must imperatively be locked. + if !t.race.CompareAndSwap(0, 1) { + panic(ErrConcurrentAccess) + } + defer t.race.Store(0) + nds := *t.nodes.Load() index := findRootNode(method, nds) if index < 0 {