Skip to content

Commit

Permalink
simplify shard option
Browse files Browse the repository at this point in the history
  • Loading branch information
phuslu committed Mar 22, 2024
1 parent 6400f3e commit 2d766b8
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 17 deletions.
3 changes: 0 additions & 3 deletions lru_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package lru

import (
"errors"
"sync/atomic"
"unsafe"
)
Expand Down Expand Up @@ -74,8 +73,6 @@ func (c *LRUCache[K, V]) Get(key K) (value V, ok bool) {
return (*lrushard[K, V])(unsafe.Add(unsafe.Pointer(&c.shards[0]), uintptr(hash&c.mask)*unsafe.Sizeof(c.shards[0]))).Get(hash, key)
}

var ErrLoaderIsNil = errors.New("loader is nil")

// GetOrLoad returns value for key, call loader function by singleflight if value was not in cache.
func (c *LRUCache[K, V]) GetOrLoad(key K) (value V, err error, ok bool) {
hash := uint32(c.hasher(noescape(unsafe.Pointer(&key)), c.seed))
Expand Down
24 changes: 10 additions & 14 deletions options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lru

import (
"errors"
"runtime"
"time"
"unsafe"
Expand All @@ -21,32 +22,25 @@ type shardsOption[K comparable, V any] struct {
count uint32
}

func (o *shardsOption[K, V]) ApplyToLRUCache(c *LRUCache[K, V]) {
func (o *shardsOption[K, V]) getcount(maxcount uint32) uint32 {
var shardcount uint32
if o.count == 0 {
shardcount = nextPowOf2(uint32(runtime.GOMAXPROCS(0) * 16))
} else {
shardcount = nextPowOf2(o.count)
}
if maxcount := uint32(len(c.shards)); shardcount > maxcount {
if shardcount > maxcount {
shardcount = maxcount
}
return shardcount
}

c.mask = uint32(shardcount) - 1
func (o *shardsOption[K, V]) ApplyToLRUCache(c *LRUCache[K, V]) {
c.mask = o.getcount(uint32(len(c.shards))) - 1
}

func (o *shardsOption[K, V]) ApplyToTTLCache(c *TTLCache[K, V]) {
var shardcount uint32
if o.count == 0 {
shardcount = nextPowOf2(uint32(runtime.GOMAXPROCS(0) * 16))
} else {
shardcount = nextPowOf2(o.count)
}
if maxcount := uint32(len(c.shards)); shardcount > maxcount {
shardcount = maxcount
}

c.mask = uint32(shardcount) - 1
c.mask = o.getcount(uint32(len(c.shards))) - 1
}

// WithHasher specifies the hasher function of cache.
Expand Down Expand Up @@ -85,6 +79,8 @@ func (o *slidingOption[K, V]) ApplyToTTLCache(c *TTLCache[K, V]) {
}
}

var ErrLoaderIsNil = errors.New("loader is nil")

// WithLoader specifies that loader function of LoadingCache.
func WithLoader[K comparable, V any, Loader func(key K) (value V, err error) | func(key K) (value V, ttl time.Duration, err error)](loader Loader) Option[K, V] {
return &loaderOption[K, V]{loader: loader}
Expand Down

0 comments on commit 2d766b8

Please sign in to comment.