Skip to content

Commit

Permalink
add context support to loader function
Browse files Browse the repository at this point in the history
  • Loading branch information
phuslu committed Mar 31, 2024
1 parent 7c362e5 commit 5290c3f
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 31 deletions.
7 changes: 4 additions & 3 deletions example_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lru_test

import (
"context"
"time"
"unsafe"

Expand All @@ -23,16 +24,16 @@ func ExampleWithHasher() {
}

func ExampleWithLoader() {
loader := func(key string) (int, time.Duration, error) {
loader := func(ctx context.Context, key string) (int, time.Duration, error) {
return 42, time.Hour, nil
}

cache := lru.NewTTLCache[string, int](4096, lru.WithLoader[string, int](loader))

println(cache.Get("a"))
println(cache.Get("b"))
println(cache.GetOrLoad("a", nil))
println(cache.GetOrLoad("b", func(key string) (int, time.Duration, error) { return 100, 0, nil }))
println(cache.GetOrLoad(context.Background(), "a", nil))
println(cache.GetOrLoad(context.Background(), "b", func(context.Context, string) (int, time.Duration, error) { return 100, 0, nil }))
println(cache.Get("a"))
println(cache.Get("b"))
}
Expand Down
7 changes: 4 additions & 3 deletions lru_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package lru

import (
"context"
"unsafe"
)

Expand All @@ -13,7 +14,7 @@ type LRUCache[K comparable, V any] struct {
mask uint32
hasher func(key unsafe.Pointer, seed uintptr) uintptr
seed uintptr
loader func(key K) (value V, err error)
loader func(ctx context.Context, key K) (value V, err error)
group singleflight_Group[K, V]
}

Expand Down Expand Up @@ -73,7 +74,7 @@ func (c *LRUCache[K, V]) Get(key K) (value V, ok bool) {
}

// GetOrLoad returns value for key, call loader function by singleflight if value was not in cache.
func (c *LRUCache[K, V]) GetOrLoad(key K, loader func(key K) (value V, err error)) (value V, err error, ok bool) {
func (c *LRUCache[K, V]) GetOrLoad(ctx context.Context, key K, loader func(context.Context, K) (V, error)) (value V, err error, ok bool) {
hash := uint32(c.hasher(noescape(unsafe.Pointer(&key)), c.seed))
value, ok = c.shards[hash&c.mask].Get(hash, key)
if !ok {
Expand All @@ -85,7 +86,7 @@ func (c *LRUCache[K, V]) GetOrLoad(key K, loader func(key K) (value V, err error
return
}
value, err, ok = c.group.Do(key, func() (V, error) {
v, err := loader(key)
v, err := loader(ctx, key)
if err != nil {
return v, err
}
Expand Down
19 changes: 10 additions & 9 deletions lru_cache_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lru

import (
"context"
"fmt"
"math/rand"
"runtime"
Expand Down Expand Up @@ -233,31 +234,31 @@ func TestLRUCacheSliding(t *testing.T) {

func TestLRUCacheLoader(t *testing.T) {
cache := NewLRUCache[string, int](1024)
if v, err, ok := cache.GetOrLoad("a", nil); ok || err == nil || v != 0 {
if v, err, ok := cache.GetOrLoad(context.Background(), "a", nil); ok || err == nil || v != 0 {
t.Errorf("cache.GetOrLoad(\"a\", nil) again should be return error: %v, %v, %v", v, err, ok)
}

cache = NewLRUCache[string, int](1024, WithLoader[string, int](func(key string) (int, error) {
cache = NewLRUCache[string, int](1024, WithLoader[string, int](func(ctx context.Context, key string) (int, error) {
if key == "" {
return 0, fmt.Errorf("invalid key: %v", key)
}
i := int(key[0] - 'a' + 1)
return i, nil
}))

if v, err, ok := cache.GetOrLoad("", nil); ok || err == nil || v != 0 {
if v, err, ok := cache.GetOrLoad(context.Background(), "", nil); ok || err == nil || v != 0 {
t.Errorf("cache.GetOrLoad(\"a\", nil) again should be return error: %v, %v, %v", v, err, ok)
}

if v, err, ok := cache.GetOrLoad("b", nil); ok || err != nil || v != 2 {
if v, err, ok := cache.GetOrLoad(context.Background(), "b", nil); ok || err != nil || v != 2 {
t.Errorf("cache.GetOrLoad(\"b\", nil) again should be return 2: %v, %v, %v", v, err, ok)
}

if v, err, ok := cache.GetOrLoad("a", nil); ok || err != nil || v != 1 {
if v, err, ok := cache.GetOrLoad(context.Background(), "a", nil); ok || err != nil || v != 1 {
t.Errorf("cache.GetOrLoad(\"a\", nil) should be return 1: %v, %v, %v", v, err, ok)
}

if v, err, ok := cache.GetOrLoad("a", nil); !ok || err != nil || v != 1 {
if v, err, ok := cache.GetOrLoad(context.Background(), "a", nil); !ok || err != nil || v != 1 {
t.Errorf("cache.GetOrLoad(\"a\", nil) again should be return 1: %v, %v, %v", v, err, ok)
}
}
Expand All @@ -270,7 +271,7 @@ func TestLRUCacheLoaderPanic(t *testing.T) {
}
}
}()
_ = NewLRUCache[string, int](1024, WithLoader[string, int](func(key string) (int, time.Duration, error) {
_ = NewLRUCache[string, int](1024, WithLoader[string, int](func(ctx context.Context, key string) (int, time.Duration, error) {
return 1, time.Hour, nil
}))
t.Errorf("should be panic above")
Expand All @@ -279,7 +280,7 @@ func TestLRUCacheLoaderPanic(t *testing.T) {
func TestLRUCacheLoaderSingleflight(t *testing.T) {
var loads uint32

cache := NewLRUCache[string, int](1024, WithLoader[string, int](func(key string) (int, error) {
cache := NewLRUCache[string, int](1024, WithLoader[string, int](func(ctx context.Context, key string) (int, error) {
atomic.AddUint32(&loads, 1)
time.Sleep(100 * time.Millisecond)
return int(key[0] - 'a' + 1), nil
Expand All @@ -290,7 +291,7 @@ func TestLRUCacheLoaderSingleflight(t *testing.T) {
for i := 0; i < 10; i++ {
go func(i int) {
defer wg.Done()
v, err, ok := cache.GetOrLoad("a", nil)
v, err, ok := cache.GetOrLoad(context.Background(), "a", nil)
if v != 1 || err != nil || !ok {
t.Errorf("a should be set to 1: %v,%v,%v", v, err, ok)
}
Expand Down
7 changes: 4 additions & 3 deletions options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lru

import (
"context"
"errors"
"runtime"
"time"
Expand Down Expand Up @@ -82,7 +83,7 @@ func (o *slidingOption[K, V]) ApplyToTTLCache(c *TTLCache[K, V]) {
var ErrLoaderIsNil = errors.New("loader is nil")

// WithLoader specifies that loader function of LoadingCache.
func WithLoader[K comparable, V any, Loader ~func(key K) (value V, err error) | ~func(key K) (value V, ttl time.Duration, err error)](loader Loader) Option[K, V] {
func WithLoader[K comparable, V any, Loader ~func(ctx context.Context, key K) (value V, err error) | ~func(ctx context.Context, key K) (value V, ttl time.Duration, err error)](loader Loader) Option[K, V] {
return &loaderOption[K, V]{loader: loader}
}

Expand All @@ -91,7 +92,7 @@ type loaderOption[K comparable, V any] struct {
}

func (o *loaderOption[K, V]) ApplyToLRUCache(c *LRUCache[K, V]) {
loader, ok := o.loader.(func(key K) (value V, err error))
loader, ok := o.loader.(func(ctx context.Context, key K) (value V, err error))
if !ok {
panic("not_supported")
}
Expand All @@ -100,7 +101,7 @@ func (o *loaderOption[K, V]) ApplyToLRUCache(c *LRUCache[K, V]) {
}

func (o *loaderOption[K, V]) ApplyToTTLCache(c *TTLCache[K, V]) {
loader, ok := o.loader.(func(key K) (value V, ttl time.Duration, err error))
loader, ok := o.loader.(func(ctx context.Context, key K) (value V, ttl time.Duration, err error))
if !ok {
panic("not_supported")
}
Expand Down
7 changes: 4 additions & 3 deletions ttl_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package lru

import (
"context"
"sync/atomic"
"time"
"unsafe"
Expand All @@ -14,7 +15,7 @@ type TTLCache[K comparable, V any] struct {
mask uint32
hasher func(key unsafe.Pointer, seed uintptr) uintptr
seed uintptr
loader func(key K) (value V, ttl time.Duration, err error)
loader func(ctx context.Context, key K) (value V, ttl time.Duration, err error)
group singleflight_Group[K, V]
}

Expand Down Expand Up @@ -76,7 +77,7 @@ func (c *TTLCache[K, V]) Get(key K) (value V, ok bool) {
}

// GetOrLoad returns value for key, call loader function by singleflight if value was not in cache.
func (c *TTLCache[K, V]) GetOrLoad(key K, loader func(key K) (value V, ttl time.Duration, err error)) (value V, err error, ok bool) {
func (c *TTLCache[K, V]) GetOrLoad(ctx context.Context, key K, loader func(context.Context, K) (V, time.Duration, error)) (value V, err error, ok bool) {
hash := uint32(c.hasher(noescape(unsafe.Pointer(&key)), c.seed))
// value, ok = c.shards[hash&c.mask].Get(hash, key)
value, ok = (*ttlshard[K, V])(unsafe.Add(unsafe.Pointer(&c.shards[0]), uintptr(hash&c.mask)*unsafe.Sizeof(c.shards[0]))).Get(hash, key)
Expand All @@ -89,7 +90,7 @@ func (c *TTLCache[K, V]) GetOrLoad(key K, loader func(key K) (value V, ttl time.
return
}
value, err, ok = c.group.Do(key, func() (V, error) {
v, ttl, err := loader(key)
v, ttl, err := loader(ctx, key)
if err != nil {
return v, err
}
Expand Down
21 changes: 11 additions & 10 deletions ttl_cache_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lru

import (
"context"
"fmt"
"math/rand"
"runtime"
Expand Down Expand Up @@ -239,37 +240,37 @@ func TestTTLCacheHasher(t *testing.T) {

func TestTTLCacheLoader(t *testing.T) {
cache := NewTTLCache[string, int](1024)
if v, err, ok := cache.GetOrLoad("a", nil); ok || err == nil || v != 0 {
if v, err, ok := cache.GetOrLoad(context.Background(), "a", nil); ok || err == nil || v != 0 {
t.Errorf("cache.GetOrLoad(\"a\", nil) again should be return error: %v, %v, %v", v, err, ok)
}

cache = NewTTLCache[string, int](1024, WithLoader[string, int](func(key string) (int, time.Duration, error) {
cache = NewTTLCache[string, int](1024, WithLoader[string, int](func(ctx context.Context, key string) (int, time.Duration, error) {
if key == "" {
return 0, 0, fmt.Errorf("invalid key: %v", key)
}
i := int(key[0] - 'a' + 1)
return i, time.Duration(i) * time.Second, nil
}))

if v, err, ok := cache.GetOrLoad("", nil); ok || err == nil || v != 0 {
if v, err, ok := cache.GetOrLoad(context.Background(), "", nil); ok || err == nil || v != 0 {
t.Errorf("cache.GetOrLoad(\"a\", nil) again should be return error: %v, %v, %v", v, err, ok)
}

if v, err, ok := cache.GetOrLoad("b", nil); ok || err != nil || v != 2 {
if v, err, ok := cache.GetOrLoad(context.Background(), "b", nil); ok || err != nil || v != 2 {
t.Errorf("cache.GetOrLoad(\"b\", nil) again should be return 2: %v, %v, %v", v, err, ok)
}

if v, err, ok := cache.GetOrLoad("a", nil); ok || err != nil || v != 1 {
if v, err, ok := cache.GetOrLoad(context.Background(), "a", nil); ok || err != nil || v != 1 {
t.Errorf("cache.GetOrLoad(\"a\", nil) should be return 1: %v, %v, %v", v, err, ok)
}

if v, err, ok := cache.GetOrLoad("a", nil); !ok || err != nil || v != 1 {
if v, err, ok := cache.GetOrLoad(context.Background(), "a", nil); !ok || err != nil || v != 1 {
t.Errorf("cache.GetOrLoad(\"a\") again should be return 1: %v, %v, %v", v, err, ok)
}

time.Sleep(2 * time.Second)

if v, err, ok := cache.GetOrLoad("a", nil); ok || err != nil || v != 1 {
if v, err, ok := cache.GetOrLoad(context.Background(), "a", nil); ok || err != nil || v != 1 {
t.Errorf("cache.GetOrLoad(\"a\") again should be return 1: %v, %v, %v", v, err, ok)
}
}
Expand All @@ -282,7 +283,7 @@ func TestTTLCacheLoaderPanic(t *testing.T) {
}
}
}()
_ = NewTTLCache[string, int](1024, WithLoader[string, int](func(key string) (int, error) {
_ = NewTTLCache[string, int](1024, WithLoader[string, int](func(ctx context.Context, key string) (int, error) {
return 1, nil
}))
t.Errorf("should be panic above")
Expand All @@ -291,7 +292,7 @@ func TestTTLCacheLoaderPanic(t *testing.T) {
func TestTTLCacheLoaderSingleflight(t *testing.T) {
var loads uint32

cache := NewTTLCache[string, int](1024, WithLoader[string, int](func(key string) (int, time.Duration, error) {
cache := NewTTLCache[string, int](1024, WithLoader[string, int](func(ctx context.Context, key string) (int, time.Duration, error) {
atomic.AddUint32(&loads, 1)
time.Sleep(100 * time.Millisecond)
return int(key[0] - 'a' + 1), time.Hour, nil
Expand All @@ -302,7 +303,7 @@ func TestTTLCacheLoaderSingleflight(t *testing.T) {
for i := 0; i < 10; i++ {
go func(i int) {
defer wg.Done()
v, err, ok := cache.GetOrLoad("a", nil)
v, err, ok := cache.GetOrLoad(context.Background(), "a", nil)
if v != 1 || err != nil || !ok {
t.Errorf("a should be set to 1: %v,%v,%v", v, err, ok)
}
Expand Down

0 comments on commit 5290c3f

Please sign in to comment.