Merge pull request #29 from alan-turing-institute/16-inference-pipeline-for-baskerville

16 inference pipeline for baskerville
J-Dymond authored Nov 29, 2024
2 parents 7df8c51 + 98fbdcc commit 2f15ada
Showing 17 changed files with 1,229 additions and 469 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -163,6 +163,7 @@ slurm_scripts/slurm_logs*
# other
temp
.vscode
outputs
local_notebooks

# test caches
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -1,6 +1,7 @@
ci:
autoupdate_commit_msg: "chore: update pre-commit hooks"
autofix_commit_msg: "style: pre-commit fixes"
skip: [pytest-check]

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
11 changes: 11 additions & 0 deletions config/RTC_configs/roberta-mt5-zero-shot.yaml
@@ -0,0 +1,11 @@
OCR:
specific_task: "image-to-text"
model: "microsoft/trocr-base-handwritten"

translator:
specific_task: "translation_fr_to_en"
model: "ybanas/autotrain-fr-en-translate-51410121895"

classifier:
specific_task: "zero-shot-classification"
model: "claritylab/zero-shot-explicit-binary-bert"
7 changes: 7 additions & 0 deletions config/data_configs/l1_fr_to_en.yaml
@@ -0,0 +1,7 @@
data_dir: "data"

level: 1

lang_pair:
source: "fr"
target: "en"
77 changes: 77 additions & 0 deletions scripts/README.md
@@ -29,3 +29,80 @@ python scripts/eval_topic_classifier.py \
--report_to tensorboard \
--dataset_name validation # change to "test" for test set evaluation
```

## pipeline_inference.py

Run inference with the complete pipeline over the MultiEURLEX dataset. Requires a data config path and a pipeline config path;
these should be `.yaml` files structured as follows:

### Dataset config:

```yaml
data_dir: "data"

level: 1

lang_pair:
source: "fr"
target: "en"
```
### Pipeline config:
```yaml
OCR:
specific_task: "image-to-text"
model: "microsoft/trocr-base-handwritten"

translator:
specific_task: "translation_fr_to_en"
model: "ybanas/autotrain-fr-en-translate-51410121895"

classifier:
specific_task: "zero-shot-classification"
model: "claritylab/zero-shot-explicit-binary-bert"

```
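
For reference, the two files are consumed roughly as follows inside the script; this is a minimal sketch based on `scripts/pipeline_inference.py` added in this commit, using the config paths from this repository as example inputs:

```python
from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
from arc_spice.utils import open_yaml_path
from arc_spice.variational_pipelines.RTC_variational_pipeline import (
    RTCVariationalPipeline,
)

# read both YAML configs into dictionaries
data_config = open_yaml_path("config/data_configs/l1_fr_to_en.yaml")
pipeline_config = open_yaml_path("config/RTC_configs/roberta-mt5-zero-shot.yaml")

# the data config selects the MultiEURLEX level and language pair
data_sets, meta_data = load_multieurlex_for_translation(**data_config)

# the pipeline config names the OCR, translation, and classification models
rtc_variational_pipeline = RTCVariationalPipeline(
    model_pars=pipeline_config, data_pars=meta_data
)
```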

It is called as follows, e.g. from the project root:
```bash
python scripts/pipeline_inference.py [pipeline_config_path] [data_config_path]
```
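
For example, with the configs added in this commit (run from the project root):

```bash
python scripts/pipeline_inference.py config/RTC_configs/roberta-mt5-zero-shot.yaml config/data_configs/l1_fr_to_en.yaml
```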

## single_component_inference.py

Run inference on a single component of the pipeline over the MultiEURLEX dataset. Requires a data config path and a pipeline config path,
which should be `.yaml` files structured as below, plus an additional argument specifying the pipeline stage to evaluate:
one of `ocr`, `translator`, or `classifier`.

### Dataset config:

```yaml
data_dir: "data"

level: 1

lang_pair:
source: "fr"
target: "en"
```
### Pipeline config:
```yaml
OCR:
specific_task: "image-to-text"
model: "microsoft/trocr-base-handwritten"

translator:
specific_task: "translation_fr_to_en"
model: "ybanas/autotrain-fr-en-translate-51410121895"

classifier:
specific_task: "zero-shot-classification"
model: "claritylab/zero-shot-explicit-binary-bert"

```

It is called as follows, e.g. from the project root:
```bash
python scripts/single_component_inference.py [pipeline_config_path] [data_config_path] translator
```
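
For example, to evaluate only the translation component with the configs added in this commit:

```bash
python scripts/single_component_inference.py config/RTC_configs/roberta-mt5-zero-shot.yaml config/data_configs/l1_fr_to_en.yaml translator
```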
50 changes: 50 additions & 0 deletions scripts/pipeline_inference.py
@@ -0,0 +1,50 @@
import json
import os

from jsonargparse import CLI

from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
from arc_spice.eval.inference_utils import ResultsGetter, run_inference
from arc_spice.utils import open_yaml_path
from arc_spice.variational_pipelines.RTC_variational_pipeline import (
RTCVariationalPipeline,
)

OUTPUT_DIR = "outputs"


def main(pipeline_config_pth: str, data_config_pth: str):
"""
Run inference on a given pipeline with provided data config
Args:
pipeline_config_pth: path to pipeline config yaml file
data_config_pth: path to data config yaml file
"""
# initialise pipeline
data_config = open_yaml_path(data_config_pth)
pipeline_config = open_yaml_path(pipeline_config_pth)
data_sets, meta_data = load_multieurlex_for_translation(**data_config)
test_loader = data_sets["test"]
rtc_variational_pipeline = RTCVariationalPipeline(
model_pars=pipeline_config, data_pars=meta_data
)
results_getter = ResultsGetter(meta_data["n_classes"])

test_results = run_inference(
dataloader=test_loader,
pipeline=rtc_variational_pipeline,
results_getter=results_getter,
)

data_name = data_config_pth.split("/")[-1].split(".")[0]
pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
save_loc = f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}"
os.makedirs(save_loc, exist_ok=True)

with open(f"{save_loc}/full_pipeline.json", "w") as save_file:
json.dump(test_results, save_file)


if __name__ == "__main__":
CLI(main)
85 changes: 85 additions & 0 deletions scripts/single_component_inference.py
@@ -0,0 +1,85 @@
"""
Steps:
- Load data
- Load pipeline/model
- Run inference on all test data
- Save outputs of specified model (on clean data)
- Calculate error of specified model (on clean data)
- Save results
- File structure:
- outputs/inference_results/[data_name]/[pipeline_name]/single_component/[model_key].json
"""

import json
import os

from jsonargparse import CLI

from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
from arc_spice.eval.inference_utils import ResultsGetter, run_inference
from arc_spice.utils import open_yaml_path
from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
ClassificationVariationalPipeline,
RecognitionVariationalPipeline,
TranslationVariationalPipeline,
)

OUTPUT_DIR = "outputs"


def main(pipeline_config_pth: str, data_config_pth: str, model_key: str):
"""
Run inference on a given pipeline component with provided data config and model key.
Args:
pipeline_config_pth: path to pipeline config yaml file
data_config_pth: path to data config yaml file
model_key: name of model on which to run inference
"""
# initialise pipeline
data_config = open_yaml_path(data_config_pth)
pipeline_config = open_yaml_path(pipeline_config_pth)
data_sets, meta_data = load_multieurlex_for_translation(**data_config)
test_loader = data_sets["test"]
if model_key == "ocr":
rtc_single_component_pipeline = RecognitionVariationalPipeline(
model_pars=pipeline_config, data_pars=meta_data
)
elif model_key == "translator":
rtc_single_component_pipeline = TranslationVariationalPipeline(
model_pars=pipeline_config, data_pars=meta_data
)
elif model_key == "classifier":
rtc_single_component_pipeline = ClassificationVariationalPipeline(
model_pars=pipeline_config, data_pars=meta_data
)
else:
error_msg = (
"model_key should be: 'ocr', 'translator', or 'classifier'."
f" Given: {model_key}"
)
raise ValueError(error_msg)

results_getter = ResultsGetter(meta_data["n_classes"])

test_results = run_inference(
dataloader=test_loader,
pipeline=rtc_single_component_pipeline,
results_getter=results_getter,
)

data_name = data_config_pth.split("/")[-1].split(".")[0]
pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
save_loc = (
f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
f"single_component"
)
os.makedirs(save_loc, exist_ok=True)

with open(f"{save_loc}/{model_key}.json", "w") as save_file:
json.dump(test_results, save_file)


if __name__ == "__main__":
CLI(main)
5 changes: 2 additions & 3 deletions scripts/variational_RTC_example.py
@@ -91,9 +91,7 @@ def main(rtc_pars):
rtc_variational_pipeline.check_dropout()

# perform variational inference
clean_output, var_output = rtc_variational_pipeline.variational_inference(
test_row["source_text"]
)
clean_output, var_output = rtc_variational_pipeline.variational_inference(test_row)

comet_model = get_comet_model()

@@ -115,4 +113,5 @@ def main(rtc_pars):
"model": "claritylab/zero-shot-explicit-binary-bert",
},
}

main(rtc_pars=rtc_pars)
5 changes: 0 additions & 5 deletions src/arc_spice/eval/classification_error.py
@@ -1,11 +1,6 @@
import torch


def hamming_accuracy(preds: torch.Tensor, class_labels: torch.Tensor) -> torch.Tensor:
# Inverse of the hamming loss (the fraction of labels incorrectly predicted)
return torch.mean((preds.float() == class_labels.float()).float())


def aggregate_score(probs: torch.Tensor) -> torch.Tensor:
# average 'distance' from the predicted class
preds = torch.round(probs).float()