Implemented hf datasets for seq2seq and qa
Thilina Rajapakse committed Mar 19, 2021
1 parent ca6bae5 commit 1b6c133
Showing 21 changed files with 761 additions and 226 deletions.
40 changes: 29 additions & 11 deletions CHANGELOG.md
@@ -4,40 +4,56 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.61.0] - 2021-03-19

### Added

- Added support for RAG models (in `Seq2Seq`) - docs will be updated soon
- Added support for the Hugging Face Datasets library for memory-efficient training (see the usage sketch after this list). Currently supports:
- Classification (all)
- NER
- Language Modeling
- Seq2Seq
- T5
- QA (Note that HF Datasets might not always work with QAModel)
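
A minimal usage sketch for the new flag (hedged: the CSV path and its `text`/`labels` columns are assumptions for illustration, not part of this commit):

```python
from simpletransformers.classification import ClassificationModel, ClassificationArgs

# Opt in to the Hugging Face Datasets backend introduced in this release.
model_args = ClassificationArgs(
    use_hf_datasets=True,  # tokenize/stream data through the datasets library
    max_seq_length=128,
    num_train_epochs=1,
)

model = ClassificationModel(
    "roberta", "roberta-base", args=model_args, use_cuda=False  # set True if a GPU is available
)

# A file path is handed to the datasets library internally; a pandas
# DataFrame with "text" and "labels" columns should also work.
model.train_model("train.csv")
```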

### Changed
- Switched to using FastTokenizers where possible

## [0.60.8] - 2021-02-12

# Fixed
### Fixed

- Fixed bug in loading cached features with classification models

## [0.60.7] - 2021-02-11

# Changed
### Changed

- Multiprocessing during tokenization is now turned on by default. If you experience any instability, this can be turned off by setting `use_multiprocessing=False`.

## [0.60.6] - 2021-02-05

# Changed
### Changed

- Multiprocessing during tokenization is now turned off by default. You can enable this by setting `use_multiprocessing=True`. However, the latest PyTorch versions seem to be unstable when using multiprocessing.


## [0.60.3] - 2021-02-02

# Changed
### Changed

- Multiprocessing is now turned off by default for evaluation. This is to avoid potential errors when doing evaluation during training. You can enable this by setting `use_multiprocessing_for_evaluation` to `True`.

## [0.60.2] - 2021-02-02

# Fixed
### Fixed

- Fixed bug in ClassificationDataset [mapmeld](https://github.com/mapmeld)

## [0.60.1] - 2021-02-02

# Added
### Added

- Added new NER models:
- ALBERT
@@ -52,27 +68,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [0.60.0] - 2021-02-02

# Added
### Added

- Added class weights support for Longformer classification
- Added new classification models:
- SqueezeBert
- DeBERTa
- MPNet

# Changed
### Changed

- Updated ClassificationModel logic to make it easier to add new models

## [0.51.16] - 2021-01-29

## Fixed
### Fixed

- Fixed bug in LayoutLM classification

## [0.51.15] - 2021-01-24

## Fixed
### Fixed

- Fixed bug in Language Generation models [mapmeld](https://github.com/mapmeld)
- Fixed bug in MBart models [nilboy](https://github.com/nilboy)
@@ -1446,7 +1462,9 @@ Model checkpoint is now saved for all epochs again.

- This CHANGELOG file to hopefully serve as an evolving example of a standardized open source project CHANGELOG.

[0.60.2]: https://github.com/ThilinaRajapakse/simpletransformers/compare/de989b5...HEAD
[0.61.0]: https://github.com/ThilinaRajapakse/simpletransformers/compare/76f1df5...HEAD

[0.60.2]: https://github.com/ThilinaRajapakse/simpletransformers/compare/de989b5...76f1df5

[0.60.1]: https://github.com/ThilinaRajapakse/simpletransformers/compare/6f189e0...de989b5

18 changes: 0 additions & 18 deletions Makefile
@@ -30,24 +30,6 @@ types:
test: clean
pytest tests --cov simpletransformers/classification simpletransformers/ner simpletransformers/question_answering simpletransformers/language_modeling simpletransformers/t5 simpletransformers/seq2seq

test-classification:
pytest tests --cov simpletransformers/classification

test-ner:
pytest tests --cov simpletransformers/ner

test-question_answering:
pytest tests --cov simpletransformers/question_answering

test-language_modeling:
pytest tests --cov simpletransformers/language_modeling

test-t5:
pytest tests --cov simpletransformers/t5

test-seq2seq:
pytest tests --cov simpletransformers/seq2seq

# if this runs through we can be sure the readme is properly shown on pypi
check-readme:
python setup.py check --restructuredtext
18 changes: 12 additions & 6 deletions simpletransformers/classification/classification_model.py
@@ -1293,7 +1293,13 @@ def load_and_cache_examples(
return dataset
else:
dataset = ClassificationDataset(
examples, self.tokenizer, self.args, mode=mode, multi_label=multi_label, output_mode=output_mode
examples,
self.tokenizer,
self.args,
mode=mode,
multi_label=multi_label,
output_mode=output_mode,
no_cache=no_cache,
)
return dataset

@@ -1454,9 +1460,9 @@ def predict(self, to_predict, multi_label=False):
preds = None
out_label_ids = None
for i, batch in enumerate(tqdm(eval_dataloader, disable=args.silent, desc="Running Prediction")):
# batch = tuple(t.to(device) for t in batch)
# batch = tuple(t.to(self.device) for t in batch)
with torch.no_grad():
inputs = self._get_inputs_dict(batch)
inputs = self._get_inputs_dict(batch, no_hf=True)

if self.args.fp16:
with amp.autocast():
@@ -1501,7 +1507,7 @@ def predict(self, to_predict, multi_label=False):
# batch = tuple(t.to(device) for t in batch)

with torch.no_grad():
inputs = self._get_inputs_dict(batch)
inputs = self._get_inputs_dict(batch, no_hf=True)

if self.args.fp16:
with amp.autocast():
@@ -1624,8 +1630,8 @@ def _threshold(self, x, threshold):
def _move_model_to_device(self):
self.model.to(self.device)

def _get_inputs_dict(self, batch):
if self.args.use_hf_datasets:
def _get_inputs_dict(self, batch, no_hf=False):
if self.args.use_hf_datasets and not no_hf:
return {key: value.to(self.device) for key, value in batch.items()}
if isinstance(batch[0], dict):
inputs = {key: value.squeeze(1).to(self.device) for key, value in batch[0].items()}
30 changes: 18 additions & 12 deletions simpletransformers/classification/classification_utils.py
@@ -107,13 +107,16 @@ def preprocess_data_multiprocessing(data):


def preprocess_batch_for_hf_dataset(dataset, tokenizer, max_seq_length):
return tokenizer(
text=dataset["text_a"],
text_pair=dataset["text_b"],
truncation=True,
padding="max_length",
max_length=max_seq_length,
)
if "text_b" in dataset:
return tokenizer(
text=dataset["text_a"],
text_pair=dataset["text_b"],
truncation=True,
padding="max_length",
max_length=max_seq_length,
)
else:
return tokenizer(text=dataset["text"], truncation=True, padding="max_length", max_length=max_seq_length,)
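
For context, a hedged sketch of how a batched preprocessing function like this is typically applied with `datasets.map` (the CSV file and its columns are illustrative assumptions):

```python
from datasets import load_dataset
from transformers import AutoTokenizer

from simpletransformers.classification.classification_utils import (
    preprocess_batch_for_hf_dataset,
)

tokenizer = AutoTokenizer.from_pretrained("roberta-base", use_fast=True)

# Hypothetical input file with "text" and "labels" columns.
dataset = load_dataset("csv", data_files="train.csv")["train"]

# With batched=True, map() passes a dict of column lists, so the function
# sees dataset["text"] (or dataset["text_a"]/["text_b"]) as lists of strings.
dataset = dataset.map(
    lambda batch: preprocess_batch_for_hf_dataset(
        batch, tokenizer=tokenizer, max_seq_length=128
    ),
    batched=True,
)
```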


def preprocess_data(text_a, text_b, labels, tokenizer, max_seq_length):
@@ -147,7 +150,7 @@ def preprocess_data(text_a, text_b, labels, tokenizer, max_seq_length):
# return [tokenized_example, [example.label for example in data]]


def build_classification_dataset(data, tokenizer, args, mode, multi_label, output_mode):
def build_classification_dataset(data, tokenizer, args, mode, multi_label, output_mode, no_cache):
cached_features_file = os.path.join(
args.cache_dir,
"cached_{}_{}_{}_{}_{}".format(mode, args.model_type, args.max_seq_length, len(args.labels_list), len(data),),
@@ -212,17 +215,17 @@ def build_classification_dataset(data, tokenizer, args, mode, multi_label, outpu

data = (examples, labels)

if not args.no_cache:
if not args.no_cache and not no_cache:
logger.info(" Saving features into cached file %s", cached_features_file)
torch.save(data, cached_features_file)

return (examples, labels)


class ClassificationDataset(Dataset):
def __init__(self, data, tokenizer, args, mode, multi_label, output_mode):
def __init__(self, data, tokenizer, args, mode, multi_label, output_mode, no_cache):
self.examples, self.labels = build_classification_dataset(
data, tokenizer, args, mode, multi_label, output_mode
data, tokenizer, args, mode, multi_label, output_mode, no_cache
)

def __len__(self):
@@ -255,7 +258,10 @@ def load_hf_dataset(data, tokenizer, args, multi_label):
batched=True,
)

dataset.set_format(type="pt", columns=["input_ids", "token_type_ids", "attention_mask", "labels"])
if args.model_type in ["bert", "xlnet", "albert", "layoutlm"]:
dataset.set_format(type="pt", columns=["input_ids", "token_type_ids", "attention_mask", "labels"])
else:
dataset.set_format(type="pt", columns=["input_ids", "attention_mask", "labels"])

if isinstance(data, str):
# This is not necessarily a train dataset. The datasets library insists on calling it train.
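
The model-type check above matters because not every fast tokenizer produces `token_type_ids`, and `set_format` will complain if asked for a column that was never created. A quick illustration (model names are only examples):

```python
from transformers import AutoTokenizer

bert_tok = AutoTokenizer.from_pretrained("bert-base-uncased")
distil_tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")

print(list(bert_tok("first sentence", "second sentence").keys()))
# ['input_ids', 'token_type_ids', 'attention_mask']

print(list(distil_tok("first sentence", "second sentence").keys()))
# ['input_ids', 'attention_mask']
```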
(Next changed file; file name not shown in this view.)
@@ -242,8 +242,8 @@ def load_and_cache_examples(
examples, evaluate=evaluate, no_cache=no_cache, multi_label=multi_label, verbose=verbose, silent=silent
)

def compute_metrics(self, preds, labels, eval_examples, multi_label=True, **kwargs):
return super().compute_metrics(preds, labels, eval_examples, multi_label=multi_label, **kwargs)
def compute_metrics(self, preds, model_outputs, labels, eval_examples, multi_label=True, **kwargs):
return super().compute_metrics(preds, model_outputs, labels, eval_examples, multi_label=multi_label, **kwargs)

def predict(self, to_predict, multi_label=True):
return super().predict(to_predict, multi_label=multi_label)
28 changes: 15 additions & 13 deletions simpletransformers/config/model_args.py
@@ -22,6 +22,13 @@ def get_special_tokens():

@dataclass
class ModelArgs:
adafactor_beta1: float = None
adafactor_clip_threshold: float = 1.0
adafactor_decay_rate: float = -0.8
adafactor_eps: tuple = field(default_factory=lambda: (1e-30, 1e-3))
adafactor_relative_step: bool = True
adafactor_scale_parameter: bool = True
adafactor_warmup_init: bool = True
adam_epsilon: float = 1e-8
best_model_dir: str = "outputs/best_model"
cache_dir: str = "cache_dir/"
@@ -38,13 +45,6 @@ class ModelArgs:
early_stopping_metric_minimize: bool = True
early_stopping_patience: int = 3
encoding: str = None
adafactor_eps: tuple = field(default_factory=lambda: (1e-30, 1e-3))
adafactor_clip_threshold: float = 1.0
adafactor_decay_rate: float = -0.8
adafactor_beta1: float = None
adafactor_scale_parameter: bool = True
adafactor_relative_step: bool = True
adafactor_warmup_init: bool = True
eval_batch_size: int = 8
evaluate_during_training: bool = False
evaluate_during_training_silent: bool = True
@@ -70,9 +70,9 @@ class ModelArgs:
optimizer: str = "AdamW"
output_dir: str = "outputs/"
overwrite_output_dir: bool = False
process_count: int = field(default_factory=get_default_process_count)
polynomial_decay_schedule_lr_end: float = 1e-7
polynomial_decay_schedule_power: float = 1.0
process_count: int = field(default_factory=get_default_process_count)
quantized_model: bool = False
reprocess_input_data: bool = True
save_best_model: bool = True
@@ -85,12 +85,13 @@
skip_special_tokens: bool = True
tensorboard_dir: str = None
thread_count: int = None
tokenizer_type: str = None
tokenizer_name: str = None
tokenizer_type: str = None
train_batch_size: int = 8
train_custom_parameters_only: bool = False
use_cached_eval_features: bool = False
use_early_stopping: bool = False
use_hf_datasets: bool = False
use_multiprocessing: bool = True
use_multiprocessing_for_evaluation: bool = True
wandb_kwargs: dict = field(default_factory=dict)
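
Because `use_hf_datasets` now lives on the base `ModelArgs` (the per-task copies are removed further down), every task's args class inherits it. A hedged Seq2Seq sketch, matching the commit message (model choice and toy data are illustrative):

```python
import pandas as pd

from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs

# Toy data; Seq2Seq models expect "input_text" and "target_text" columns.
train_df = pd.DataFrame(
    [["some input text", "some target text"], ["another input", "another target"]],
    columns=["input_text", "target_text"],
)

args = Seq2SeqArgs(use_hf_datasets=True, num_train_epochs=1)

model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="facebook/bart-base",
    args=args,
    use_cuda=False,  # set True if a GPU is available
)
model.train_model(train_df)
```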
@@ -150,7 +151,6 @@ class ClassificationArgs(ModelArgs):
special_tokens_list: list = field(default_factory=list)
stride: float = 0.8
tie_value: int = 1
use_hf_datasets: bool = False


@dataclass
@@ -168,7 +168,6 @@ class MultiLabelClassificationArgs(ModelArgs):
labels_map: dict = field(default_factory=dict)
lazy_loading: bool = False
special_tokens_list: list = field(default_factory=list)
use_hf_datasets: bool = False


@dataclass
@@ -184,7 +183,6 @@ class NERArgs(ModelArgs):
lazy_loading_start_line: int = 0
onnx: bool = False
special_tokens_list: list = field(default_factory=list)
use_hf_datasets: bool = False


@dataclass
@@ -264,7 +262,6 @@ class LanguageModelingArgs(ModelArgs):
special_tokens_list: list = field(default_factory=list)
strip_accents: bool = True
local_rank: int = -1
use_hf_datasets: bool = False

def save(self, output_dir):
os.makedirs(output_dir, exist_ok=True)
@@ -301,15 +298,20 @@ class Seq2SeqArgs(ModelArgs):
do_sample: bool = False
early_stopping: bool = True
evaluate_generated_text: bool = False
faiss_d: int = 768
faiss_m: int = 128
length_penalty: float = 2.0
max_length: int = 20
max_steps: int = -1
num_beams: int = 1
num_return_sequences: int = 1
rag_embed_batch_size: int = 16
repetition_penalty: float = 1.0
top_k: float = None
top_p: float = None
use_multiprocessed_decoding: bool = False
save_knowledge_dataset: bool = True
save_knowledge_dataset_with_checkpoints: bool = False
src_lang: str = "en_XX"
tgt_lang: str = "ro_RO"

(Diffs for the remaining changed files are not shown here.)
