-
Notifications
You must be signed in to change notification settings - Fork 33
/
run.py
64 lines (51 loc) · 2.31 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
import pickle
import torch
from torchmf import (BaseModule, BPRModule, BasePipeline,
bpr_loss, PairwiseInteractions)
import utils
def explicit():
    """Run the explicit-feedback MovieLens example.

    Loads the default (non-implicit) MovieLens train/test split from
    ``utils``, wraps ``BaseModule`` in a ``BasePipeline`` optimized with Adam,
    and fits it for 40 epochs with a fixed random seed.
    """
    train_set, test_set = utils.get_movielens_train_test_split()
    # Hyperparameters kept together so they can be read (and tuned) at a glance.
    hyperparams = {
        'model': BaseModule,
        'n_factors': 10,
        'batch_size': 1024,
        'dropout_p': 0.02,
        'lr': 0.02,
        'weight_decay': 0.1,
        'optimizer': torch.optim.Adam,
        'n_epochs': 40,
        'verbose': True,
        'random_seed': 2017,
    }
    BasePipeline(train_set, test=test_set, **hyperparams).fit()
def _implicit_pipeline(**overrides):
    """Build the pairwise (BPR) pipeline shared by the implicit examples.

    Both ``implicit`` and ``hogwild`` train on the implicit MovieLens split
    with identical hyperparameters. Keeping that configuration in one place
    prevents the two examples from silently drifting apart.

    Args:
        **overrides: Extra/replacement keyword arguments forwarded to
            ``BasePipeline`` (e.g. ``hogwild=True``).

    Returns:
        The constructed, not-yet-fitted ``BasePipeline``.
    """
    train, test = utils.get_movielens_train_test_split(implicit=True)
    config = dict(
        verbose=True,
        batch_size=1024,
        num_workers=4,
        n_factors=20,
        weight_decay=0,
        dropout_p=0.,
        lr=.2,
        sparse=True,
        optimizer=torch.optim.SGD,
        n_epochs=40,
        random_seed=2017,
        loss_function=bpr_loss,
        model=BPRModule,
        interaction_class=PairwiseInteractions,
        eval_metrics=('auc', 'patk'),
    )
    config.update(overrides)
    return BasePipeline(train, test=test, **config)


def implicit():
    """Run the implicit-feedback example (BPR loss, pairwise interactions)."""
    _implicit_pipeline().fit()


def hogwild():
    """Run the implicit-feedback example with ``hogwild=True``.

    Identical configuration to ``implicit`` apart from the ``hogwild`` flag
    passed to ``BasePipeline``.
    """
    _implicit_pipeline(hogwild=True).fit()
if __name__ == '__main__':
    # Map each CLI value to the example it runs. argparse validates the value
    # against these keys, so a missing or misspelled --example now fails with
    # a usage error and non-zero exit status instead of printing a message
    # and exiting successfully.
    examples = {
        'explicit': explicit,
        'implicit': implicit,
        'hogwild': hogwild,
    }
    parser = argparse.ArgumentParser(description='torchmf')
    parser.add_argument('--example', required=True, choices=list(examples),
                        help='explicit, implicit, or hogwild')
    args = parser.parse_args()
    examples[args.example]()