Refactor to conform to the general practice used #184

Status: Open. The author wants to merge 1 commit into base `master`.
**common/evaluators/qa_evaluator.py** (6 additions, 2 deletions)

```diff
@@ -16,9 +16,13 @@ def get_scores(self):
         for batch in self.data_loader:
             qids.extend(batch.id.detach().cpu().numpy())
-            # Select embedding
-            sent1, sent2 = self.get_sentence_embeddings(batch)
-
-            output = self.model(sent1, sent2, batch.ext_feats, batch.dataset.word_to_doc_cnt, batch.sentence_1_raw, batch.sentence_2_raw)
+            if hasattr(self.model, 'skip_embedding_lookup') and self.model.skip_embedding_lookup:
+                output = self.model(batch.sentence_1, batch.sentence_2, batch.ext_feats, batch.dataset.word_to_doc_cnt, batch.sentence_1_raw, batch.sentence_2_raw)
+            else:
+                sent1, sent2 = self.get_sentence_embeddings(batch)
+                output = self.model(sent1, sent2, batch.ext_feats, batch.dataset.word_to_doc_cnt, batch.sentence_1_raw, batch.sentence_2_raw)
 
             test_cross_entropy_loss += F.cross_entropy(output, batch.label, size_average=False).item()
 
             true_labels.extend(batch.label.detach().cpu().numpy())
```

Review comment (Member), on the `if hasattr(...)` line:

> Why are there two conditions here? What if only one condition is True?
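For context on the guard itself: `hasattr` protects against models that never define the attribute, while the second condition checks its value, so the branch runs only when the flag both exists and is truthy. A minimal standalone sketch (not part of this PR) showing how `getattr` with a default collapses both checks into one expression:

```python
class WithFlag:
    skip_embedding_lookup = True

class WithoutFlag:
    pass  # models that never define the flag

# getattr returns the attribute's value when present, the default otherwise,
# so a single expression covers both of the reviewer's cases.
print(getattr(WithFlag(), 'skip_embedding_lookup', False))    # True: take the raw-indices path
print(getattr(WithoutFlag(), 'skip_embedding_lookup', False)) # False: fall back to embeddings
```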
**common/trainers/qa_trainer.py** (6 additions, 3 deletions)

```diff
@@ -15,10 +15,13 @@ def train_epoch(self, epoch):
         for batch_idx, batch in enumerate(self.train_loader):
             self.optimizer.zero_grad()
 
-            # Select embedding
-            sent1, sent2 = self.get_sentence_embeddings(batch)
+            if hasattr(self.model, 'skip_embedding_lookup') and self.model.skip_embedding_lookup:
+                output = self.model(batch.sentence_1, batch.sentence_2, batch.ext_feats, batch.dataset.word_to_doc_cnt, batch.sentence_1_raw, batch.sentence_2_raw)
+            else:
+                # Select embedding
+                sent1, sent2 = self.get_sentence_embeddings(batch)
+                output = self.model(sent1, sent2, batch.ext_feats, batch.dataset.word_to_doc_cnt, batch.sentence_1_raw, batch.sentence_2_raw)
 
-            output = self.model(sent1, sent2, batch.ext_feats, batch.dataset.word_to_doc_cnt, batch.sentence_1_raw, batch.sentence_2_raw)
             loss = F.nll_loss(output, batch.label, size_average=False)
             total_loss += loss.item()
             loss.backward()
```

Review comment (Member), on the `if hasattr(...)` line:

> Same question.
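Since the same branch now appears in both the trainer and the evaluator, one way to avoid the duplication (a sketch only, not something this PR does; `compute_output` is a hypothetical helper) would be to factor the dispatch into a shared function:

```python
def compute_output(model, batch, get_sentence_embeddings):
    """Run the model on a batch, skipping the embedding lookup when the model asks for raw indices."""
    if getattr(model, 'skip_embedding_lookup', False):
        return model(batch.sentence_1, batch.sentence_2, batch.ext_feats,
                     batch.dataset.word_to_doc_cnt,
                     batch.sentence_1_raw, batch.sentence_2_raw)
    sent1, sent2 = get_sentence_embeddings(batch)
    return model(sent1, sent2, batch.ext_feats,
                 batch.dataset.word_to_doc_cnt,
                 batch.sentence_1_raw, batch.sentence_2_raw)
```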
**sm_cnn/.gitignore** (5 deletions)

This file was deleted.
**sm_cnn/README.md** (1 addition, 7 deletions)

````diff
@@ -12,12 +12,6 @@ Your repository root should be in your `PYTHONPATH` environment variable:
 export PYTHONPATH=$(pwd)
 ```
 
-To create the dataset:
-```bash
-cd Castor/sm_cnn/
-./create_dataset.sh
-```
-
 We use `trec_eval` for evaluation:
 
 ```bash
@@ -39,7 +33,7 @@ You can train the SM model for the 4 following configurations:
 To train on GPU 0 with static configuration:
 
 ```bash
-python train.py --mode static --gpu 0
+python train.py --mode static --device 0
 ```
 
 NB: pass `--no_cuda` to use CPU
````
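For readers updating their scripts: with the new flag, the processor is selected by device index rather than by a separate switch. Per the `--device` help text in the revised args.py (`-1 for CPU`), the CPU-mode invocation would presumably become:

```bash
python train.py --mode static --device -1
```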
File renamed without changes.
**sm_cnn/__main__.py** (127 additions, new file)

```python
import argparse
import logging
from copy import deepcopy

import os
import pprint
import random

import numpy as np
import torch

from common.dataset import DatasetFactory
from common.evaluation import EvaluatorFactory
from common.train import TrainerFactory
from utils.serialization import load_checkpoint
from sm_cnn.model import SMCNN
from sm_cnn.args import get_args


def get_logger():
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    return logger


def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_size, device, keep_results=False):
    saved_model_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, loader, batch_size, device,
                                                           keep_results=keep_results)
    scores, metric_names = saved_model_evaluator.get_scores()
    logger.info('Evaluation metrics for {}'.format(split_name))
    logger.info('\t'.join([' '] + metric_names))
    logger.info('\t'.join([split_name] + list(map(str, scores))))


if __name__ == '__main__':
    # Getting args
    args = get_args()
    config = deepcopy(args)

    # Getting logger
    logger = get_logger()
    logger.info(pprint.pformat(vars(args)))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.device != -1:
        torch.cuda.manual_seed(args.seed)

    # Dealing with device
    if not args.cuda:
        args.gpu = -1
    if torch.cuda.is_available() and args.cuda:
        logger.info("Note: You are using GPU for training")
        torch.cuda.set_device(args.gpu)
        torch.cuda.manual_seed(args.seed)
    if torch.cuda.is_available() and not args.cuda:
        logger.info("Warning: You have Cuda but do not use it. You are using CPU for training")
    device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() and args.device >= 0 else 'cpu')

    if args.dataset not in ('trecqa', 'wikiqa'):
        raise ValueError('Unrecognized dataset')

    dataset_cls, embedding, train_loader, test_loader, dev_loader \
        = DatasetFactory.get_dataset(args.dataset, args.word_vectors_dir, args.word_vectors_file, args.batch_size,
                                     args.device)

    config.questions_num = dataset_cls.VOCAB_SIZE
    config.answers_num = dataset_cls.VOCAB_SIZE
    config.target_class = dataset_cls.NUM_CLASSES
    model = SMCNN(config)

    model = model.to(device)
    embedding = embedding.to(device)

    optimizer = torch.optim.Adadelta(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    train_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, train_loader, args.batch_size,
                                                     args.device)
    test_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, test_loader, args.batch_size,
                                                    args.device)
    dev_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, dev_loader, args.batch_size,
                                                   args.device)

    trainer_config = {
        'optimizer': optimizer,
        'batch_size': args.batch_size,
        'log_interval': args.log_interval,
        'model_outfile': args.model_outfile,
        'patience': args.patience,
        'tensorboard': args.tensorboard,
        'run_label': args.run_label,
        'logger': logger
    }

    trainer = TrainerFactory.get_trainer(args.dataset, model, embedding, train_loader, trainer_config,
                                         train_evaluator, test_evaluator, dev_evaluator)

    if not args.skip_training:
        total_params = 0
        for param in model.parameters():
            size = [s for s in param.size()]
            total_params += np.prod(size)
        logger.info('Total number of parameters: %s', total_params)
        model.static_question_embed.weight.data.copy_(embedding.weight)
        model.nonstatic_question_embed.weight.data.copy_(embedding.weight)
        model.static_answer_embed.weight.data.copy_(embedding.weight)
        model.nonstatic_answer_embed.weight.data.copy_(embedding.weight)

        trainer.train(args.epochs)

    _, _, state_dict, _, _ = load_checkpoint(args.model_outfile)

    for k, tensor in state_dict.items():
        state_dict[k] = tensor.to(device)

    model.load_state_dict(state_dict)
    if dev_loader:
        evaluate_dataset('dev', dataset_cls, model, embedding, dev_loader, args.batch_size, args.device)
    evaluate_dataset('test', dataset_cls, model, embedding, test_loader, args.batch_size, args.device, args.keep_results)
```
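Naming the entry point `__main__.py` is what lets the package be launched with `python -m`, which appears to be the "general practice" the PR title refers to. A possible invocation, assuming the repository root is on `PYTHONPATH` as the README requires (the model output filename here is an arbitrary placeholder):

```bash
python -m sm_cnn my_model.castor --dataset trecqa --device 0
```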
**sm_cnn/args.py** (23 additions, 8 deletions)

```diff
@@ -1,29 +1,44 @@
 from argparse import ArgumentParser
+import os
 
 def get_args():
     parser = ArgumentParser(description="SM CNN")
 
+    parser.add_argument('model_outfile', help='file to save final model')
+    parser.add_argument('--dataset', type=str, help='trecqa|wikiqa', default='trecqa')
-    parser.add_argument('--no_cuda', action='store_false', help='do not use cuda', dest='cuda')
-    parser.add_argument('--gpu', type=int, default=0) # Use -1 for CPU
+    parser.add_argument('--word-vectors-dir', help='word vectors directory',
+                        default=os.path.join(os.pardir, 'Castor-data', 'embeddings', 'word2vec'))
+    parser.add_argument('--word-vectors-file', help='word vectors filename', default='aquaint+wiki.txt.gz.ndim=50.txt')
+    parser.add_argument('--word-vectors-dim', type=int, default=50,
+                        help='number of dimensions of word vectors (default: 50)')
+    parser.add_argument('--skip-training', help='will load pre-trained model', action='store_true')
+    parser.add_argument('--device', type=int, default=0, help='GPU device, -1 for CPU (default: 0)')
+    parser.add_argument('--batch-size', type=int, default=64, help='input batch size for training (default: 64)')
     parser.add_argument('--epochs', type=int, default=30)
-    parser.add_argument('--batch_size', type=int, default=64)
-    parser.add_argument('--mode', type=str, default='static')
-    parser.add_argument('--lr', type=float, default=1.0)
+    parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')
     parser.add_argument('--seed', type=int, default=3435)
-    parser.add_argument('--dataset', type=str, help='TREC|wiki', default='TREC')
     parser.add_argument('--resume_snapshot', type=str, default=None)
     parser.add_argument('--dev_every', type=int, default=30)
     parser.add_argument('--log_every', type=int, default=10)
     parser.add_argument('--patience', type=int, default=50)
     parser.add_argument('--save_path', type=str, default='saves')
     parser.add_argument('--output_channel', type=int, default=100)
     parser.add_argument('--filter_width', type=int, default=5)
     parser.add_argument('--words_dim', type=int, default=50)
     parser.add_argument('--dropout', type=float, default=0.5)
     parser.add_argument('--epoch_decay', type=int, default=15)
     parser.add_argument('--vector_cache', type=str, default='data/word2vec.trecqa.pt')
     parser.add_argument('--trained_model', type=str, default="")
-    parser.add_argument('--weight_decay',type=float, default=1e-5)
+    parser.add_argument('--weight_decay', type=float, default=1e-5)
     parser.add_argument('--onnx', action='store_true', help='export model to onnx')
+    parser.add_argument('--mode', type=str, default='rand')
+    parser.add_argument('--keep-results', action='store_true',
+                        help='store the output score and qrel files into disk for the test set')
+    parser.add_argument('--log-interval', type=int, default=10,
+                        help='how many batches to wait before logging training status (default: 10)')
+    parser.add_argument('--tensorboard', action='store_true', default=False,
+                        help='use TensorBoard to visualize training (default: false)')
+    parser.add_argument('--run-label', type=str, help='label to describe run')
 
     args = parser.parse_args()
     return args
```
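A side note on the renamed flags (`--batch-size`, `--skip-training`, `--log-interval`, and so on): argparse maps hyphens in long option names to underscores when building the namespace, so call sites such as `args.batch_size` and `args.skip_training` in `__main__.py` keep working unchanged. A minimal standalone sketch:

```python
from argparse import ArgumentParser

# argparse derives the attribute name from the long option by replacing
# hyphens with underscores, so '--batch-size' is still read as batch_size.
parser = ArgumentParser()
parser.add_argument('--batch-size', type=int, default=64)
args = parser.parse_args(['--batch-size', '32'])
print(args.batch_size)  # -> 32
```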