
Commit

chore: add items filter (#107)
yxtay authored Oct 16, 2024
1 parent 5ce6ef2 commit 516155b
Showing 8 changed files with 465 additions and 326 deletions.
75 changes: 70 additions & 5 deletions mf_torch/bentoml/prepare.py
@@ -1,17 +1,20 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeVar

import torch
from pydantic import BaseModel

from mf_torch.bentoml.schemas import ItemCandidate, Query
from mf_torch.bentoml.schemas import ItemCandidate, Query, UserQuery
from mf_torch.data.load import select_fields
from mf_torch.params import (
CHECKPOINT_PATH,
EXPORTED_PROGRAM_PATH,
ITEMS_TABLE_NAME,
LANCE_DB_PATH,
MODEL_NAME,
USERS_TABLE_NAME,
)

if TYPE_CHECKING:
@@ -24,6 +27,9 @@
from mf_torch.lightning import MatrixFactorizationLitModule


T = TypeVar("T", bound=BaseModel)


def load_args(ckpt_path: str | None) -> dict:
if not ckpt_path:
return {"model": {}, "data": {}}
@@ -123,6 +129,55 @@ class ItemSchema(ItemCandidate, LanceModel):
return table


def prepare_users(trainer: Trainer) -> list[dict]:
datamodule: MatrixFactorizationDataModule = trainer.datamodule

interactions: Iterable[dict] = datamodule.get_raw_data(subset="train").map(
partial(
select_fields,
fields=["user_id", "gender", "age", "occupation", "zipcode", "movie_id"],
)
)
interacted: dict[int, dict] = {}
for row in interactions:
user_id = row["user_id"]
movie_id = row.pop("movie_id")

if user_id in interacted:
curr = interacted[user_id]
curr.update(row)
curr["movie_ids"].add(movie_id)
else:
curr = row
curr["movie_ids"] = {movie_id}

interacted[user_id] = curr
return list(interacted.values())


def index_users(
users: Iterable[dict], lance_db_path: str = LANCE_DB_PATH
) -> lancedb.table.LanceTable:
import datetime

import lancedb
from lancedb.pydantic import LanceModel

class UserSchema(UserQuery, LanceModel):
pass

db = lancedb.connect(lance_db_path)
table = db.create_table(
USERS_TABLE_NAME,
data=users,
schema=UserSchema,
mode="overwrite",
)
table.compact_files()
table.cleanup_old_versions(datetime.timedelta(days=1))
return table


def save_model(trainer: Trainer) -> None:
import shutil

@@ -134,20 +189,30 @@ def save_model(trainer: Trainer) -> None:
trainer.model.export_dynamo(model_ref.path_of(EXPORTED_PROGRAM_PATH))


def load_indexed_items() -> list[ItemCandidate]:
def load_lancedb_indexed(table_name: str, schema: T) -> T:
import bentoml
import lancedb
from pydantic import TypeAdapter

lancedb_path = bentoml.models.get(MODEL_NAME).path_of(LANCE_DB_PATH)
tbl = lancedb.connect(lancedb_path).open_table(ITEMS_TABLE_NAME)
return TypeAdapter(list[ItemCandidate]).validate_python(tbl.to_arrow().to_pylist())
tbl = lancedb.connect(lancedb_path).open_table(table_name)
return TypeAdapter(list[schema]).validate_python(tbl.to_arrow().to_pylist())


def load_indexed_items() -> list[ItemCandidate]:
return load_lancedb_indexed(table_name=ITEMS_TABLE_NAME, schema=ItemCandidate)


def load_indexed_users() -> list[UserQuery]:
return load_lancedb_indexed(table_name=USERS_TABLE_NAME, schema=UserQuery)


def main(ckpt_path: str | None) -> None:
trainer = prepare_trainer(ckpt_path)
items = prepare_items(trainer)
index_items(items=items)
users = prepare_users(trainer)
index_users(users=users)
save_model(trainer=trainer)


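The new prepare_users helper folds the raw training interactions down to one record per user, collecting every interacted movie_id into a set so those items can later be excluded from that user's recommendations; index_users then writes the records to the USERS_TABLE_NAME LanceDB table, mirroring the existing items index. Below is a minimal, standalone sketch of the same aggregation over made-up rows (illustrative data only, not the MovieLens records returned by the datamodule):

# Hypothetical interaction rows; in prepare_users() they come from
# datamodule.get_raw_data(subset="train") with select_fields applied.
rows = [
    {"user_id": 1, "gender": "F", "age": 25, "occupation": 4, "zipcode": "94110", "movie_id": 10},
    {"user_id": 1, "gender": "F", "age": 25, "occupation": 4, "zipcode": "94110", "movie_id": 32},
    {"user_id": 2, "gender": "M", "age": 31, "occupation": 7, "zipcode": "10001", "movie_id": 10},
]

interacted: dict[int, dict] = {}
for row in rows:
    row = dict(row)  # copy so the source row is not mutated
    movie_id = row.pop("movie_id")
    curr = interacted.setdefault(row["user_id"], {**row, "movie_ids": set()})
    curr["movie_ids"].add(movie_id)

users = list(interacted.values())
# users[0]["movie_ids"] == {10, 32}; users[1]["movie_ids"] == {10}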
1 change: 1 addition & 0 deletions mf_torch/bentoml/schemas.py
@@ -49,6 +49,7 @@ class UserQuery(BaseModel):
age: int | None = None
occupation: int | None = None
zipcode: str | None = None
movie_ids: list[int] | None = None

def to_query(self: Self, **kwargs: dict[str, int]) -> Query:
from mf_torch.data.load import process_features
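The only schema change is the optional movie_ids field on UserQuery, which carries a user's previously interacted items so the service can turn them into an exclusion list. A small sketch with a stripped-down stand-in model (the real UserQuery has more fields plus a to_query() method):

from pydantic import BaseModel

class UserQueryStub(BaseModel):
    # Stand-in for mf_torch.bentoml.schemas.UserQuery, reduced to the new field.
    user_id: int | None = None
    movie_ids: list[int] | None = None

user = UserQueryStub(user_id=1, movie_ids=[10, 32])

# Mirrors recommend_with_user: previously seen movies are merged into exclude_items.
exclude_items: list[int] = []
if user.movie_ids:
    exclude_items = [*exclude_items, *user.movie_ids]
assert exclude_items == [10, 32]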
91 changes: 73 additions & 18 deletions mf_torch/bentoml/service.py
@@ -14,12 +14,13 @@
ITEMS_TABLE_NAME,
LANCE_DB_PATH,
MODEL_NAME,
USERS_TABLE_NAME,
)


@bentoml.service()
class Embedder:
model_ref = bentoml.models.get(MODEL_NAME)
model_ref = bentoml.models.BentoModel(MODEL_NAME)

@logger.catch(reraise=True)
def __init__(self: Self) -> None:
@@ -35,23 +36,30 @@ def embed(self: Self, query: Query) -> Query:

@bentoml.service()
class ItemIndex:
model_ref = bentoml.models.get(MODEL_NAME)
model_ref = bentoml.models.BentoModel(MODEL_NAME)

@logger.catch(reraise=True)
def __init__(self: Self) -> None:
import lancedb

src_path = self.model_ref.path_of(LANCE_DB_PATH)
self.tbl = lancedb.connect(src_path).open_table(ITEMS_TABLE_NAME)
db = lancedb.connect(src_path)
self.items = db.open_table(ITEMS_TABLE_NAME)
self.users = db.open_table(USERS_TABLE_NAME)
logger.info("items index loaded: {}", src_path)

@bentoml.api()
@logger.catch(reraise=True)
def search(self: Self, query: Query) -> list[ItemCandidate]:
def search(
self: Self, query: Query, exclude_items: list[int]
) -> list[ItemCandidate]:
from pydantic import TypeAdapter

exclude_filter = ", ".join(f"{item}" for item in exclude_items)
exclude_filter = f"movie_id NOT IN ({exclude_filter})"
results_df = (
self.tbl.search(query.embedding)
self.items.search(query.embedding)
.where(exclude_filter, prefilter=True)
.nprobes(20)
.refine_factor(5)
.limit(10)
@@ -68,16 +76,27 @@ def search(self: Self, query: Query) -> list[ItemCandidate]:
def item_id(self: Self, item_id: int) -> ItemCandidate:
from bentoml.exceptions import NotFound

result = self.tbl.search().where(f"movie_id = {item_id}").to_list()
result = self.items.search().where(f"movie_id = {item_id}").to_list()
if len(result) == 0:
msg = f"item not found: {item_id = }"
raise NotFound(msg)
return ItemCandidate.model_validate(result[0])

@bentoml.api()
@logger.catch(reraise=True)
def user_id(self: Self, user_id: int) -> UserQuery:
from bentoml.exceptions import NotFound

result = self.users.search().where(f"user_id = {user_id}").to_list()
if len(result) == 0:
msg = f"user not found: {user_id = }"
raise NotFound(msg)
return UserQuery.model_validate(result[0])


@bentoml.service()
class Service:
model_ref = bentoml.models.get(MODEL_NAME)
model_ref = bentoml.models.BentoModel(MODEL_NAME)
embedder = bentoml.depends(Embedder)
item_index = bentoml.depends(ItemIndex)

@@ -94,36 +113,72 @@ async def embed_query(self: Self, query: Query) -> Query:

@bentoml.api()
@logger.catch(reraise=True)
async def search_items(self: Self, query: Query) -> list[ItemCandidate]:
return await self.item_index.to_async.search(query)
async def search_items(
self: Self, query: Query, exclude_items: list[int] | None = None
) -> list[ItemCandidate]:
exclude_items = exclude_items or []
return await self.item_index.to_async.search(query, exclude_items=exclude_items)

@bentoml.api()
@logger.catch(reraise=True)
async def recommend_with_query(self: Self, query: Query) -> list[ItemCandidate]:
async def recommend_with_query(
self: Self, query: Query, exclude_items: list[int] | None = None
) -> list[ItemCandidate]:
query = await self.embed_query(query)
return await self.search_items(query)
return await self.search_items(query, exclude_items=exclude_items)

@bentoml.api()
@logger.catch(reraise=True)
async def recommend_with_item(self: Self, item: ItemQuery) -> list[ItemCandidate]:
async def recommend_with_item(
self: Self, item: ItemQuery, exclude_items: list[int] | None = None
) -> list[ItemCandidate]:
if item.movie_id:
exclude_items = [*(exclude_items or []), item.movie_id]

query = item.to_query(
num_hashes=self.num_hashes, num_embeddings=self.num_embeddings
)
return await self.recommend_with_query(query)
return await self.recommend_with_query(query, exclude_items=exclude_items)

@bentoml.api()
@logger.catch(reraise=True)
async def recommend_with_item_id(
self: Self, item_id: int, exclude_items: list[int] | None = None
) -> list[ItemCandidate]:
item = await self.item_id(item_id)
return await self.recommend_with_item(item, exclude_items=exclude_items)

@bentoml.api()
@logger.catch(reraise=True)
async def recommend_with_item_id(self: Self, item_id: int) -> list[ItemCandidate]:
item = self.item_index.item_id(item_id)
return await self.recommend_with_item(item)
async def item_id(self: Self, item_id: int) -> ItemCandidate:
return await self.item_index.to_async.item_id(item_id)

@bentoml.api()
@logger.catch(reraise=True)
async def recommend_with_user(self: Self, user: UserQuery) -> list[ItemCandidate]:
async def recommend_with_user(
self: Self, user: UserQuery, exclude_items: list[int] | None = None
) -> list[ItemCandidate]:
if user.movie_ids:
exclude_items = exclude_items or []
exclude_items = [*exclude_items, *user.movie_ids]

query = user.to_query(
num_hashes=self.num_hashes, num_embeddings=self.num_embeddings
)
return await self.recommend_with_query(query)
return await self.recommend_with_query(query, exclude_items=exclude_items)

@bentoml.api()
@logger.catch(reraise=True)
async def recommend_with_user_id(
self: Self, user_id: int, exclude_items: list[int] | None = None
) -> list[ItemCandidate]:
user = await self.user_id(user_id)
return await self.recommend_with_user(user, exclude_items=exclude_items)

@bentoml.api()
@logger.catch(reraise=True)
async def user_id(self: Self, user_id: int) -> UserQuery:
return await self.item_index.to_async.user_id(user_id)

@bentoml.api()
@logger.catch(reraise=True)
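ItemIndex.search now builds a SQL-style predicate from the excluded IDs and applies it with prefilter=True, so LanceDB drops already-seen movies before the vector search rather than after the top-k cut. A standalone sketch of the filter construction; the empty-list guard is an addition for illustration, not part of the commit (the service always passes a list, which may be empty):

def build_exclude_filter(exclude_items: list[int]) -> str | None:
    # e.g. [10, 32, 57] -> "movie_id NOT IN (10, 32, 57)"
    if not exclude_items:
        return None  # "movie_id NOT IN ()" would be an invalid predicate
    ids = ", ".join(f"{item}" for item in exclude_items)
    return f"movie_id NOT IN ({ids})"

assert build_exclude_filter([10, 32, 57]) == "movie_id NOT IN (10, 32, 57)"
assert build_exclude_filter([]) is None

In the Service layer the list is assembled per call: recommend_with_item appends the queried movie's own ID, and recommend_with_user appends the user's movie_ids, before delegating to search_items.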
6 changes: 6 additions & 0 deletions mf_torch/flaml.py
@@ -27,6 +27,8 @@ def get_lightning_args(
# "num_heads": num_heads,
# "dropout": config["dropout"],
"hard_negatives_ratio": hard_negatives_ratio,
"sigma": config["sigma"],
"margin": config["margin"],
"learning_rate": config["learning_rate"],
}
data_args = {
@@ -79,6 +81,8 @@ def flaml_tune() -> flaml.tune.tune.ExperimentAnalysis:
"train_loss": flaml.tune.choice(train_losses),
"use_hard_negatives": flaml.tune.choice([True, False]),
"hard_negatives_ratio": flaml.tune.quniform(0.5, 2.0, 0.01),
"sigma": flaml.tune.lograndint(1, 1000),
"margin": flaml.tune.quniform(-1.0, 1.0, 0.01),
"learning_rate": flaml.tune.qloguniform(0.001, 0.1, 0.001),
}
low_cost_partial_config = {}
@@ -92,6 +96,8 @@ def flaml_tune() -> flaml.tune.tune.ExperimentAnalysis:
"train_loss": "PairwiseHingeLoss",
"use_hard_negatives": False,
"hard_negatives_ratio": 1.0,
"sigma": 1.0,
"margin": 1.0,
"learning_rate": 0.1,
}
return flaml.tune.run(
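The tuning config gains two dimensions, sigma (log-uniform integer in [1, 1000]) and margin (quantized uniform in [-1, 1]), which are passed through to the loss functions. A self-contained sketch of searching just these two dimensions with a dummy objective (the real objective fits the Lightning module and reports a validation retrieval metric):

import flaml.tune

def dummy_objective(config: dict) -> dict:
    # Stand-in metric; the real run trains MatrixFactorizationLitModule
    # with these hyperparameters and returns the validation score.
    score = config["margin"] / config["sigma"]
    return {"score": score}

analysis = flaml.tune.run(
    dummy_objective,
    config={
        "sigma": flaml.tune.lograndint(1, 1000),
        "margin": flaml.tune.quniform(-1.0, 1.0, 0.01),
    },
    metric="score",
    mode="max",
    num_samples=8,
)
print(analysis.best_config)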
13 changes: 11 additions & 2 deletions mf_torch/lightning.py
@@ -42,6 +42,8 @@ def __init__(
dropout: float = 0.0,
normalize: bool = True,
hard_negatives_ratio: float | None = None,
sigma: float = 1.0,
margin: float = 1.0,
learning_rate: float = 0.01,
) -> None:
super().__init__()
@@ -277,7 +279,11 @@ def loss_fns(self: Self) -> torch.nn.ModuleList:
mf_losses.PairwiseLogisticLoss,
]
loss_fns = [
loss_class(hard_negatives_ratio=self.hparams.get("hard_negatives_ratio"))
loss_class(
hard_negatives_ratio=self.hparams.get("hard_negatives_ratio"),
sigma=self.hparams.get("sigma"),
margin=self.hparams.get("margin"),
)
for loss_class in loss_classes
]
return torch.nn.ModuleList(loss_fns)
@@ -440,4 +446,7 @@ def cli_main(

if __name__ == "__main__":
cli_main()
# cli = cli_main(args={"fit": {"trainer": {"overfit_batches": 1}}})
# cli = cli_main(
# args={"fit": {"trainer": {"overfit_batches": 1, "num_sanity_val_steps": 0}}}
# )
# cli.model.export_dynamo_onnx()
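In the LitModule, sigma and margin are plumbed from the hyperparameters into every loss class alongside hard_negatives_ratio. As a rough illustration of what these knobs typically control in pairwise ranking losses (a generic sketch, not the mf_torch loss implementations):

import torch
import torch.nn.functional as F

def pairwise_hinge_loss(pos: torch.Tensor, neg: torch.Tensor, margin: float = 1.0) -> torch.Tensor:
    # Penalize negatives scored within `margin` of their positives.
    return torch.relu(margin - (pos - neg)).mean()

def pairwise_logistic_loss(pos: torch.Tensor, neg: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    # `sigma` scales the score gap inside the softplus, RankNet-style.
    return F.softplus(-sigma * (pos - neg)).mean()

pos = torch.tensor([2.0, 1.5])
neg = torch.tensor([1.0, 1.8])
print(pairwise_hinge_loss(pos, neg, margin=1.0))    # tensor(0.6500)
print(pairwise_logistic_loss(pos, neg, sigma=1.0))  # ≈ tensor(0.5838)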
