Skip to content

Commit

Permalink
Merge pull request #25 from CyberAgentAI/feature/load-data-in-c
Browse files Browse the repository at this point in the history
Feature: Add data path to train method
  • Loading branch information
nsakki55 authored Jun 30, 2023
2 parents 719fa62 + 4cf62b5 commit 88ee3a3
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 8 deletions.
14 changes: 10 additions & 4 deletions ffm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,10 @@ def read_ffm_model(cls, model_path: str) -> "Model":


def train(
train_data: Dataset,
train_data: Optional[Dataset] = None,
train_path: Optional[str] = None,
valid_data: Optional[Dataset] = None,
valid_path: Optional[str] = None,
eta: float = 0.2,
lam: float = 0.00002,
nr_iters: int = 15,
Expand All @@ -107,17 +109,21 @@ def train(
random: bool = True,
nds_rate: float = 1.0,
) -> Model:
tr = (train_data.data, train_data.labels)
iw = train_data.importance_weights
tr, iw = None, None
if train_data is not None:
tr = (train_data.data, train_data.labels)
iw = train_data.importance_weights

va, iwv = None, None
if valid_data is not None:
va = (valid_data.data, valid_data.labels)
iwv = valid_data.importance_weights

weights, best_iteration, normalization, best_va_loss = libffm_train(
tr,
tr=tr,
tr_path=train_path,
va=va,
va_path=valid_path,
iw=iw,
iwv=iwv,
eta=eta,
Expand Down
16 changes: 13 additions & 3 deletions ffm/libffm.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ cdef extern from "ffm.h" namespace "ffm" nogil:
ffm_float best_va_loss

ffm_model *ffm_train_with_validation(ffm_problem *Tr, ffm_problem *Va, ffm_importance_weights *iws, ffm_importance_weights *iwvs, ffm_parameter param);
ffm_problem *ffm_read_problem(char *path);


cdef ffm_problem* make_ffm_prob(X, y):
Expand Down Expand Up @@ -174,8 +175,10 @@ cdef object _train(


def train(
tr,
tr=None,
tr_path=None,
va=None,
va_path=None,
iw=None,
iwv=None,
eta=0.2,
Expand Down Expand Up @@ -205,11 +208,18 @@ def train(
param.nds_rate = nds_rate

cdef:
ffm_problem* tr_ptr = make_ffm_prob(tr[0], tr[1])
ffm_problem* tr_ptr
ffm_problem* va_ptr
ffm_importance_weights *iw_ptr, *iwv_ptr

if va is not None:
if tr_path is not None:
tr_ptr = ffm_read_problem(tr_path.encode("utf-8"))
else:
tr_ptr = make_ffm_prob(tr[0], tr[1])

if va_path is not None:
va_ptr = ffm_read_problem(va_path.encode("utf-8"))
elif va is not None:
va_ptr = make_ffm_prob(va[0], va[1])
else:
va_ptr = NULL
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

setup(
name="ffm",
version="0.3.1",
version="0.4.0",
description="LibFFM Python Package",
long_description="LibFFM Python Package",
install_requires=["numpy"],
Expand Down

0 comments on commit 88ee3a3

Please sign in to comment.