Skip to content

Commit

Permalink
Implemented simple-view for classification and QA. Added CLI option
Browse files Browse the repository at this point in the history
  • Loading branch information
Thilina Rajapakse committed Jul 28, 2020
1 parent 8cebdd7 commit ffb30b9
Show file tree
Hide file tree
Showing 8 changed files with 304 additions and 32 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## UNRELEASED

- Removed blank string answer in Question Answering predictions

## [0.45.2] - 2020-07-19

### Added
Expand Down
14 changes: 14 additions & 0 deletions bin/simple-viewer
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/bin/bash
cat >run_simple_transformers_streamlit_app.py <<'END_SCRIPT'
#!/usr/bin/env python
from simpletransformers.streamlit.simple_view import streamlit_runner
streamlit_runner()
END_SCRIPT

# Run
streamlit run run_simple_transformers_streamlit_app.py

rm run_simple_transformers_streamlit_app.py
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
long_description_content_type="text/markdown",
url="https://github.com/ThilinaRajapakse/simpletransformers/",
packages=find_packages(),
scripts=["bin/simple-viewer"],
classifiers=[
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
Expand Down
12 changes: 12 additions & 0 deletions simpletransformers/config/model_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class ModelArgs:
manual_seed: int = None
max_grad_norm: float = 1.0
max_seq_length: int = 128
model_type: str = None
model_name: str = None
multiprocessing_chunksize: int = 500
n_gpu: int = 1
no_cache: bool = False
Expand Down Expand Up @@ -104,6 +106,7 @@ class ClassificationArgs(ModelArgs):
Model args for a ClassificationModel
"""

model_class: str = "ClassificationModel"
labels_list: list = field(default_factory=list)
labels_map: dict = field(default_factory=dict)
lazy_delimiter: str = "\t"
Expand All @@ -125,6 +128,7 @@ class MultiLabelClassificationArgs(ModelArgs):
Model args for a MultiLabelClassificationModel
"""

model_class: str = "MultiLabelClassificationModel"
sliding_window: bool = False
stride: float = 0.8
threshold: float = 0.5
Expand All @@ -140,6 +144,7 @@ class NERArgs(ModelArgs):
Model args for a NERModel
"""

model_class: str = "NERModel"
classification_report: bool = False
labels_list: list = field(default_factory=list)
lazy_loading: bool = False
Expand All @@ -152,6 +157,7 @@ class QuestionAnsweringArgs(ModelArgs):
Model args for a QuestionAnsweringModel
"""

model_class: str = "QuestionAnsweringModel"
doc_stride: int = 384
early_stopping_metric: str = "correct"
early_stopping_metric_minimize: bool = False
Expand All @@ -168,6 +174,7 @@ class T5Args(ModelArgs):
Model args for a T5Model
"""

model_class: str = "T5Model"
dataset_class: Dataset = None
do_sample: bool = False
early_stopping: bool = True
Expand All @@ -190,6 +197,7 @@ class LanguageModelingArgs(ModelArgs):
Model args for a LanguageModelingModel
"""

model_class: str = "LanguageModelingModel"
block_size: int = -1
config_name: str = None
dataset_class: Dataset = None
Expand Down Expand Up @@ -219,6 +227,7 @@ class Seq2SeqArgs(ModelArgs):
Model args for a Seq2SeqModel
"""

model_class: str = "Seq2SeqModel"
base_marian_model_name: str = None
dataset_class: Dataset = None
do_sample: bool = False
Expand All @@ -241,6 +250,7 @@ class LanguageGenerationArgs(ModelArgs):
Model args for a LanguageGenerationModel
"""

model_class: str = "LanguageGenerationModel"
do_sample: bool = True
early_stopping: bool = True
evaluate_generated_text: bool = False
Expand All @@ -267,6 +277,7 @@ class ConvAIArgs(ModelArgs):
Model args for a ConvAIModel
"""

model_class: str = "ConvAIModel"
do_sample: bool = True
lm_coef: float = 2.0
max_history: int = 2
Expand All @@ -286,6 +297,7 @@ class MultiModalClassificationArgs(ModelArgs):
Model args for a MultiModalClassificationModel
"""

model_class: str = "MultiModalClassificationModel"
regression: bool = False
num_image_embeds: int = 1
text_label: str = "text"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -944,8 +944,8 @@ def predict(self, to_predict, n_best_size=None):
examples, features, all_results, n_best_size, args.max_answer_length, False, False, True, False,
)

answer_list = [{"id": answer["id"], "answer": answer["answer"]} for answer in answers]
probability_list = [{"id": answer["id"], "probability": answer["probability"]} for answer in answers]
answer_list = [{"id": answer["id"], "answer": answer["answer"][:-1]} for answer in answers]
probability_list = [{"id": answer["id"], "probability": answer["probability"][:-1]} for answer in answers]

return answer_list, probability_list

Expand Down
Empty file.
30 changes: 0 additions & 30 deletions simpletransformers/streamlit/classification.py

This file was deleted.

Loading

0 comments on commit ffb30b9

Please sign in to comment.