Skip to content

Commit

Permalink
convert logics.ItemToItem to interface
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz committed Jan 18, 2025
1 parent 9a367e0 commit c46137c
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 33 deletions.
2 changes: 1 addition & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ type NeighborsConfig struct {

// ItemToItemConfig configures a single item-to-item recommender.
//
// Name identifies the recommender, Type selects its implementation (restricted
// by the `oneof` validation tag), and Column is an expression validated by the
// `item_expr` tag.
type ItemToItemConfig struct {
Name string `mapstructure:"name" json:"name"`
Type string `mapstructure:"type" json:"type" validate:"oneof=embedding"`
Type string `mapstructure:"type" json:"type" validate:"oneof=embedding tags"`
Column string `mapstructure:"column" json:"column" validate:"item_expr"`
}

Expand Down
176 changes: 150 additions & 26 deletions logics/item_to_item.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,53 +15,98 @@
package logics

import (
"errors"
"github.com/chewxy/math32"
"github.com/expr-lang/expr"
"github.com/expr-lang/expr/vm"
"github.com/samber/lo"
"github.com/zhenghaoz/gorse/base/floats"
"github.com/zhenghaoz/gorse/base/log"
"github.com/zhenghaoz/gorse/common/ann"
"github.com/zhenghaoz/gorse/config"
"github.com/zhenghaoz/gorse/dataset"
"github.com/zhenghaoz/gorse/storage/cache"
"github.com/zhenghaoz/gorse/storage/data"
"go.uber.org/zap"
"time"
)

type ItemToItem struct {
// ItemToItem is the common interface of item-to-item recommenders. Items are
// offered one at a time through Push; PopAll then reports the scored
// neighbors of each item that was accepted.
type ItemToItem interface {
// Items returns the IDs of the items held by the recommender.
Items() []string
// Push offers an item to the recommender. It returns nothing, so callers
// cannot tell from the call whether the item was accepted.
Push(item data.Item)
// PopAll invokes callback with an item ID and the scores computed for it.
PopAll(callback func(itemId string, score []cache.Score))
}

// NewItemToItem builds the item-to-item recommender selected by cfg.Type.
//
// "embedding" and "tags" are the supported types; any other value yields an
// error and a nil recommender. n and timestamp are handed to the chosen
// constructor unchanged.
func NewItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time) (ItemToItem, error) {
	if cfg.Type == "embedding" {
		return newEmbeddingItemToItem(cfg, n, timestamp)
	}
	if cfg.Type == "tags" {
		return newTagsItemToItem(cfg, n, timestamp)
	}
	return nil, errors.New("invalid item-to-item type")
}

// baseItemToItem holds the state shared by the item-to-item implementations:
// the compiled column expression, the nearest-neighbor index over element
// type T, and the IDs of the indexed items.
type baseItemToItem[T any] struct {
// name is copied from the recommender configuration.
name string
// n is the neighbor count; PopAll searches for n+1 results per item.
n int
// timestamp is stamped on every score produced by PopAll.
timestamp time.Time
// columnFunc is the compiled column expression evaluated against each item.
columnFunc *vm.Program
index *ann.HNSW[float32]
index *ann.HNSW[T]
// items maps an index position to its item ID (PopAll reads items[position]).
items []string
dimension int
}

func NewItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time) (*ItemToItem, error) {
// Items returns the IDs of the items pushed so far. The returned slice is the
// internal one, not a copy.
func (b *baseItemToItem[T]) Items() []string {
return b.items
}

// PopAll searches the index once per stored item and passes the item's ID and
// the converted search results to callback.
//
// Each search asks for n+1 results. The first search error is logged and ends
// the walk, so later items never reach callback.
func (b *baseItemToItem[T]) PopAll(callback func(itemId string, score []cache.Score)) {
	// Converts one (index position, value) search result into a cache score.
	toScore := func(neighbor lo.Tuple2[int, float32], _ int) cache.Score {
		return cache.Score{
			Id:        b.items[neighbor.A],
			Score:     float64(neighbor.B),
			Timestamp: b.timestamp,
		}
	}
	for position, itemID := range b.items {
		neighbors, err := b.index.SearchIndex(position, b.n+1, true)
		if err != nil {
			log.Logger().Error("failed to search index", zap.Error(err))
			return
		}
		callback(itemID, lo.Map(neighbors, toScore))
	}
}

// embeddingItemToItem is the item-to-item recommender over float32 vectors.
type embeddingItemToItem struct {
baseItemToItem[float32]
// dimension is the expected vector length; zero until the first non-empty
// vector is pushed.
dimension int
}

// newEmbeddingItemToItem compiles cfg.Column against an environment exposing
// a single "item" variable and returns a recommender whose index measures
// Euclidean distance. A compile failure is returned as the error.
func newEmbeddingItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time) (ItemToItem, error) {
// Compile column expression
columnFunc, err := expr.Compile(cfg.Column, expr.Env(map[string]any{
"item": data.Item{},
}))
if err != nil {
return nil, err
}
return &ItemToItem{
return &embeddingItemToItem{baseItemToItem: baseItemToItem[float32]{
name: cfg.Name,
n: n,
timestamp: timestamp,
columnFunc: columnFunc,
index: ann.NewHNSW[float32](floats.Euclidean),
}, nil
}}, nil
}

func (i *ItemToItem) Push(item data.Item) {
func (e *embeddingItemToItem) Push(item data.Item) {
// Check if hidden
if item.IsHidden {
return
}
// Evaluate filter function
result, err := expr.Run(i.columnFunc, map[string]any{
result, err := expr.Run(e.columnFunc, map[string]any{
"item": item,
})
if err != nil {
Expand All @@ -76,34 +121,113 @@ func (i *ItemToItem) Push(item data.Item) {
return
}
// Check dimension
if i.dimension == 0 && len(v) > 0 {
i.dimension = len(v)
} else if i.dimension != len(v) {
if e.dimension == 0 && len(v) > 0 {
e.dimension = len(v)
} else if e.dimension != len(v) {
log.Logger().Error("invalid column dimension", zap.Int("dimension", len(v)))
return
}
// Push item
i.items = append(i.items, item.ItemId)
_, err = i.index.Add(v)
e.items = append(e.items, item.ItemId)
_, err = e.index.Add(v)
if err != nil {
log.Logger().Error("failed to add item to index", zap.Error(err))
return
}
}

func (i *ItemToItem) PopAll(callback func(itemId string, score []cache.Score)) {
for index, item := range i.items {
scores, err := i.index.SearchIndex(index, i.n+1, true)
if err != nil {
log.Logger().Error("failed to search index", zap.Error(err))
return
}
callback(item, lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score {
return cache.Score{
Id: i.items[v.A],
Score: float64(v.B),
Timestamp: i.timestamp,
// tagsItemToItem is the item-to-item recommender over lists of tag IDs.
type tagsItemToItem struct {
baseItemToItem[dataset.ID]
// idf holds one weight per tag, indexed by tag ID; IDs outside the slice
// contribute no weight. NOTE(review): presumably inverse document
// frequencies — nothing visible here populates it, confirm with the caller.
idf []float32
}

// newTagsItemToItem compiles cfg.Column against an environment exposing a
// single "item" variable and returns a tag-based recommender. A compile
// failure is returned as the error.
func newTagsItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time) (ItemToItem, error) {
	// Compile column expression
	env := map[string]any{"item": data.Item{}}
	program, err := expr.Compile(cfg.Column, expr.Env(env))
	if err != nil {
		return nil, err
	}
	// The recommender must exist before the index is built, because the index
	// is given the recommender's own distance method.
	recommender := new(tagsItemToItem)
	recommender.baseItemToItem = baseItemToItem[dataset.ID]{
		name:       cfg.Name,
		n:          n,
		timestamp:  timestamp,
		columnFunc: program,
		index:      ann.NewHNSW[dataset.ID](recommender.distance),
	}
	return recommender, nil
}

// Push evaluates the column expression on item and, when the result is a list
// of tag IDs, adds the item to the index.
//
// Hidden items are ignored silently. Expression failures, results of any
// other type and index insertion failures are logged and the item is dropped.
// NOTE(review): the distance function walks both tag lists as if they were
// sorted ascending; nothing here sorts them — confirm the column guarantees it.
func (t *tagsItemToItem) Push(item data.Item) {
	// Check if hidden
	if item.IsHidden {
		return
	}
	// Evaluate column expression
	result, err := expr.Run(t.columnFunc, map[string]any{
		"item": item,
	})
	if err != nil {
		log.Logger().Error("failed to evaluate column expression",
			zap.Any("item", item), zap.Error(err))
		return
	}
	// Check column type
	v, ok := result.([]dataset.ID)
	if !ok {
		log.Logger().Error("invalid column type", zap.Any("column", result))
		return
	}
	// Insert into the index first and record the ID only on success, so that
	// items[i] always names the i-th indexed entry. Recording the ID before a
	// failed insert would shift every later position-to-ID lookup in PopAll.
	if _, err = t.index.Add(v); err != nil {
		log.Logger().Error("failed to add item to index", zap.Error(err))
		return
	}
	t.items = append(t.items, item.ItemId)
}

// distance returns a dissimilarity between two tag lists for use as the
// index metric: smaller means more alike.
//
// The similarity is the weight of the shared tags normalised by the square
// roots of each list's total weight, damped by commonCount/(commonCount+100)
// so that pairs sharing only a few tags are not over-trusted. The result is
// 1 minus that similarity. Lists with no shared tag, or with a non-positive
// total weight, are treated as maximally distant (1); the latter guard also
// prevents a 0/0 division from producing NaN.
func (t *tagsItemToItem) distance(a, b []dataset.ID) float32 {
	const shrinkage = 100

	commonSum, commonCount := t.weightedSumCommonElements(a, b)
	if commonCount == 0 {
		return 1
	}
	weightA, weightB := t.weightedSum(a), t.weightedSum(b)
	if weightA <= 0 || weightB <= 0 {
		// The shrinkage term only bounds the last divisor; the square roots
		// can still be zero (or NaN), so bail out before dividing.
		return 1
	}
	similarity := commonSum * commonCount /
		math32.Sqrt(weightA) /
		math32.Sqrt(weightB) /
		(commonCount + shrinkage)
	return 1 - similarity
}

func (t *tagsItemToItem) weightedSumCommonElements(a, b []dataset.ID) (float32, float32) {
i, j, sum, count := 0, 0, float32(0), float32(0)
for i < len(a) && j < len(b) {
if a[i] == b[j] {
if a[i] >= 0 && int(a[i]) < len(t.idf) {
sum += t.idf[a[i]]
}
}))
count++
i++
j++
} else if a[i] < b[j] {
i++
} else if a[i] > b[j] {
j++
}
}
return sum, count
}

// weightedSum adds up the idf weight of every ID in a. IDs that are negative
// or beyond the end of the idf table are skipped.
func (t *tagsItemToItem) weightedSum(a []dataset.ID) float32 {
	total := float32(0)
	for _, id := range a {
		if id < 0 || int(id) >= len(t.idf) {
			continue
		}
		total += t.idf[id]
	}
	return total
}
12 changes: 7 additions & 5 deletions logics/item_to_item_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

func TestColumnFunc(t *testing.T) {
item2item, err := NewItemToItem(config.ItemToItemConfig{
Type: "embedding",
Column: "item.Labels.description",
}, 10, time.Now())
assert.NoError(t, err)
Expand All @@ -37,7 +38,7 @@ func TestColumnFunc(t *testing.T) {
"description": []float32{0.1, 0.2, 0.3},
},
})
assert.Len(t, item2item.items, 1)
assert.Len(t, item2item.Items(), 1)

// Hidden
item2item.Push(data.Item{
Expand All @@ -47,7 +48,7 @@ func TestColumnFunc(t *testing.T) {
"description": []float32{0.1, 0.2, 0.3},
},
})
assert.Len(t, item2item.items, 1)
assert.Len(t, item2item.Items(), 1)

// Dimension does not match
item2item.Push(data.Item{
Expand All @@ -56,7 +57,7 @@ func TestColumnFunc(t *testing.T) {
"description": []float32{0.1, 0.2},
},
})
assert.Len(t, item2item.items, 1)
assert.Len(t, item2item.Items(), 1)

// Type does not match
item2item.Push(data.Item{
Expand All @@ -65,19 +66,20 @@ func TestColumnFunc(t *testing.T) {
"description": "hello",
},
})
assert.Len(t, item2item.items, 1)
assert.Len(t, item2item.Items(), 1)

// Column does not exist
item2item.Push(data.Item{
ItemId: "2",
Labels: []float32{0.1, 0.2, 0.3},
})
assert.Len(t, item2item.items, 1)
assert.Len(t, item2item.Items(), 1)
}

func TestEmbedding(t *testing.T) {
timestamp := time.Now()
item2item, err := NewItemToItem(config.ItemToItemConfig{
Type: "embedding",
Column: "item.Labels.description",
}, 10, timestamp)
assert.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion master/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -1768,7 +1768,7 @@ func (m *Master) updateItemToItem(dataset *dataset.Dataset) error {
defer span.End()

// Build item-to-item recommenders
itemToItemRecommenders := make([]*logics.ItemToItem, 0, len(m.Config.Recommend.ItemToItem))
itemToItemRecommenders := make([]logics.ItemToItem, 0, len(m.Config.Recommend.ItemToItem))
for _, cfg := range m.Config.Recommend.ItemToItem {
recommender, err := logics.NewItemToItem(cfg, m.Config.Recommend.CacheSize, dataset.GetTimestamp())
if err != nil {
Expand Down

0 comments on commit c46137c

Please sign in to comment.