Skip to content

Commit

Permalink
[add] add comet
Browse files Browse the repository at this point in the history
  • Loading branch information
goncalorafaria committed May 30, 2024
1 parent 9a18fdd commit b46821c
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 41 deletions.
5 changes: 0 additions & 5 deletions examples/em/README.md

This file was deleted.

2 changes: 1 addition & 1 deletion examples/mt/run_all.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@



export STEPS=128
export STEPS=8
export TEMPERATURE=0.8

for BETA in 0.05 0.1 0.5 1
Expand Down
52 changes: 30 additions & 22 deletions examples/mt/wmt23_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,10 @@
import numpy as np
import transformers
from datasets import load_dataset

from quest.decoding import Quest

# from models.tm import ABR_LANGUAGE_MAP
from quest.model.vllm import VLLM
from quest.reward.qe import QEModel

# from wmt22_constants import *
from quest.reward.mt import QEModel

ABR_LANGUAGE_MAP = {
"pt": "Portuguese",
Expand All @@ -27,13 +23,11 @@
"zh": "Chinese",
}


transformers.logging.set_verbosity_error()

warnings.filterwarnings("ignore") # Ignore warnings

logging.getLogger().setLevel(logging.ERROR) # Show only errors in logging

logging.basicConfig(level=logging.ERROR)

llms = {
Expand All @@ -52,18 +46,9 @@
}


def generate(
gpu_memory_utilization=0.6,
llm: str = "alma",
beta: float = 0.1,
temperature: float = 0.8,
reward_model_checkpoint="Unbabel/wmt23-cometkiwi-da-xl",
steps: int = 50,
reward_batch_size: int = 8,
device_count: int = 1,
language_pair="en-de",
def load_wmt23_data(
language_pair: str = "en-de",
):

src_lang, tgt_lang = language_pair.split("-")
data = load_dataset("haoranxu/WMT23-Test", language_pair, split="test")
input_data = [
Expand All @@ -76,6 +61,23 @@ def generate(
for sample in data
]

return input_data


def generate(
llm: str = "alma",
beta: float = 0.1,
temperature: float = 0.8,
reward_model_checkpoint="Unbabel/wmt23-cometkiwi-da-xl",
steps: int = 50,
language_pair="en-de",
reward_batch_size: int = 8,
device_count: int = 1,
gpu_memory_utilization=0.6,
):

input_data = load_wmt23_data(language_pair)

model = VLLM(
model_path=llms[llm]["path"],
prompt_template=llms[llm]["prompt"],
Expand All @@ -86,10 +88,13 @@ def generate(
)

reward = QEModel(
reward_model_checkpoint, batch_size=reward_batch_size, device_count=device_count
) # sentiment model.
model_path=reward_model_checkpoint,
batch_size=reward_batch_size,
device_count=device_count,
)

reward.set_sources([sample["source_sentence"] for sample in input_data])
sources = [sample["source_sentence"] for sample in input_data]
reward.set_sources(sources)

output = Quest(
input_data=input_data,
Expand All @@ -98,7 +103,10 @@ def generate(
beta=beta,
).run(steps=steps, use_tqdm=True)

return output.samples
return [
{"outputs": outputs, "source": src}
for outputs, src in zip(output.samples, sources)
]


def main(
Expand Down
2 changes: 1 addition & 1 deletion quest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
from quest.decoding import Quest, QuestRLHF, QuestMetropolis
from quest.reward.base import Reward
from quest.reward.model import RewardModel
from quest.reward.qe import QEModel
from quest.reward.mt import QEModel
from quest.index import Uniform
35 changes: 26 additions & 9 deletions quest/reward/qe.py → quest/reward/mt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@
from typing import List
import os


os.environ["TOKENIZERS_PARALLELISM"] = "false"


class QEModel(Reward):
# translation quality estimation
class CometModel(Reward):

def __init__(
self,
model_path="Unbabel/wmt23-cometkiwi-da-xl",
model_path="Unbabel/XCOMET-XL",
batch_size: int = 32,
device_count=1,
clamp: float = 1e-3,
Expand All @@ -26,10 +24,23 @@ def __init__(
self.device_count = device_count
self.clamp = clamp
self.sources = None
self.references = None

def set_sources(self, sources: List[str]):
self.sources = sources

def set_references(self, references: List[str]):
self.references = references

def make_input(self, candidates: List[str], accepted_indices: List[int]):

data = [
{"src": self.sources[i], "mt": candidates[i], "ref": self.references[i]}
for i in accepted_indices
]

return data

def evaluate(
self, candidates: List[str], accepted_indices=None, **kwargs
) -> List[float]:
Expand All @@ -43,18 +54,24 @@ def evaluate(
List[float]: The list of reward values for each candidate sequence.
"""

if accepted_indices is None:
accepted_indices = list(range(len(candidates)))

assert (
self.sources is not None
), "Please set sources before evaluating candidates."

data = [{"src": self.sources[i], "mt": candidates[i]} for i in accepted_indices]
data = self.make_input(candidates, accepted_indices)

return [
clamp_logit(score, self.clamp)
for score in self.model.predict(
data, batch_size=self.batch_size, gpus=self.device_count
)["scores"]
]


class QEModel(CometModel):

def __init__(self, model_path="Unbabel/wmt23-cometkiwi-da-xl", **kwargs):
super().__init__(model_path=model_path, **kwargs)

def make_input(self, candidates: List[str], accepted_indices: List[int]):
return [{"src": self.sources[i], "mt": candidates[i]} for i in accepted_indices]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
description="A package for sampling from intractable distributions with LLMs.",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/goncalo-faria/quest-decoding",
url="https://github.com/deep-spin/quest-decoding",
packages=setuptools.find_packages(),
install_requires=installation_requirements,
python_requires=">=3.10.0",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_qe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from quest.reward.qe import QEModel
from quest.reward.mt import QEModel


def integrated_test():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_questqe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from quest.model.vllm import VLLM
from quest.decoding import Quest
from quest.index import Uniform
from quest.reward.qe import QEModel
from quest.reward.mt import QEModel
import os


Expand Down

0 comments on commit b46821c

Please sign in to comment.