-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[add] improve examples and more possibilities
- Loading branch information
1 parent
ded6bbc
commit 3bf8801
Showing
25 changed files
with
7,040 additions
and
465 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from langchain.prompts import PromptTemplate | ||
from quest.model.vllm import VLLM | ||
from quest import Quest, RLHFSuffixProposal | ||
|
||
from quest.reward.base import ( | ||
Reward, | ||
ConstantReward, | ||
BackwardReward, | ||
) | ||
|
||
from quest.index import Uniform | ||
from quest.reward.model import ( | ||
ContextualRewardModel, | ||
) | ||
from datasets import load_dataset | ||
import os | ||
|
||
# Hugging Face dataset id: Anthropic's helpful/harmless preference pairs.
dataset_path = "Anthropic/hh-rlhf"
|
||
|
||
def process_data(entry):
    """Split an hh-rlhf ``chosen`` transcript into a prompt and reference answer.

    The transcript's final paragraph (e.g. ``"Assistant: some reply"``) is cut
    at its first colon: the speaker tag stays in the prompt as an open
    ``"Speaker: "`` cue, and everything after the colon becomes the answer.
    If the last paragraph has no colon, the answer is the empty string.
    """
    turns = entry["chosen"].split("\n\n")

    # partition() keeps any later colons inside the reply, matching a
    # split-on-":" followed by re-joining the tail.
    speaker, _, reply = turns[-1].partition(":")

    # Rebuild the transcript so it ends with an open completion cue.
    turns[-1] = f"{speaker}: "

    return {
        "prompt": "\n\n".join(turns),
        "answer": reply,
    }
|
||
|
||
def main(
    oposite: bool = False,
    beta: float = 0.5,
    steps: int = 25,
    temperature: float = 0.6,
    n: int = 1,
    model_path: str = "meta-llama/Meta-Llama-3-8B",
    reward_model_path: str = "OpenAssistant/reward-model-deberta-v3-large-v2",
):
    """Run Quest suffix sampling on the hh-rlhf test split and print rewards.

    Args:
        oposite: If True, wrap the reward in ``BackwardReward`` so the chain
            optimizes the opposite direction of the reward model.
            NOTE(review): misspelling of "opposite"; kept because fire exposes
            it as the public CLI flag ``--oposite``.
        beta: Inverse-temperature weight passed to ``Quest``.
        steps: Number of chain steps to run.
        temperature: Sampling temperature for the base LLM.
        n: Number of dataset entries (prompts) to process.
        model_path: HF id of the base language model served via vLLM.
        reward_model_path: HF id of the reward model used for scoring.
    """

    ds = load_dataset(
        dataset_path, split="test"
    )

    # Turn each preference pair into {"prompt", "answer"} fields.
    psd = ds.map(process_data)

    # Keep only the first n entries for this run.
    data_iterable = list(psd)[:n]

    model = VLLM(
        model_path=model_path,
        # prompt_template=template,
        download_dir=os.environ.get(
            "HF_HOME", "/tmp/"
        ),
        # Stop on newline so the model produces a single-turn reply.
        stop_tokens=["\n"],
        temperature=temperature,
    )

    # Reward model that scores a completion against its prompt context.
    reward = ContextualRewardModel(
        model_path=reward_model_path
    )

    # Build the per-example contexts the reward model scores against.
    context = [
        model.get_prompt(**data)
        for data in data_iterable
    ]

    reward.set_context(context)

    # Wrap AFTER set_context so the inner reward keeps its contexts.
    if oposite:
        reward = BackwardReward(reward)

    index = Uniform()

    chain = Quest(
        input_data=data_iterable,
        proposal=RLHFSuffixProposal(
            model=model, dist=index
        ),
        reward=reward,
        beta=beta,
    )

    chain_outputs = chain.run(
        steps=steps,
        use_tqdm=True,
    )

    print(data_iterable)

    # Dump (reward, text) for every visited state of the chain.
    for s in chain_outputs.state_path:
        print((s["reward"], s["text"]))
|
||
|
||
if __name__ == "__main__":

    # fire turns main()'s keyword arguments into CLI flags, e.g.
    # `python script.py --steps=50 --oposite=True`.
    import fire

    fire.Fire(main)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from expkit import ExpSetup | ||
from expkit.ops import ( | ||
EvalMean, | ||
EvalLast, | ||
EvalMax, | ||
EvalTotalMean, | ||
EvalMeanLast, | ||
EvalMeanMax, | ||
Operation, | ||
) | ||
from functools import partial | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import seaborn as sns | ||
import re | ||
|
||
# Reference beta used by the rlhf model: 0.0325.
# NOTE(review): N, K, TEMP and N_MEASURAMENTS are not referenced in the
# visible portion of this script — presumably consumed further down; confirm.
N = 1024
K = 64
TEMP = 0.6
# Measurement grid: powers of two up to N.
# NOTE(review): "MEASURAMENTS" is a misspelling of "MEASUREMENTS"; renaming
# the module-level constant could break external references, so only flagged.
N_MEASURAMENTS = [
    1,
    2,
    4,
    8,
    16,
    32,
    64,
    128,
    256,
    512,
    1024,
]

eval_key = (
    "lastnumber"  # alternative key: "crm:hamishivi-tulu-v2"
)
rm_path = "hamishivi/tulu-v2.5-7b-uf-rm"
|
||
|
||
# Lazily load experiment metadata (no instance payloads) from the cluster
# output directory, keeping only runs with 128 steps.
setup = ExpSetup(
    "/gscratch/ark/graf/quest-rlhf/mt-outputs/",
    lazy=True,
    load_instances=False,
).query({"steps": 128})

# Ancestral-sampling baseline runs.
base = setup.query(
    {
        "variant": "ancestral",
    }
)

# Quest MCMC runs.
quest = setup.query({"variant": "quest"})
|
||
|
||
# Per-language-pair sanity report: run counts plus the hyperparameter
# values attached to each quest run.
for pair in ["en-zh", "en-cs"]:
    print("---" * 20)
    print(pair)

    base_meta = base.query({"language_pair": pair}).meta()
    quest_meta = quest.query({"language_pair": pair}).meta()

    print("base")
    print(len(base_meta))
    # NOTE(review): these temperatures come from the *quest* runs even though
    # they are printed under the "base" header — confirm this is intended.
    print([entry["temperature"] for entry in quest_meta])

    print("quest")
    print(len(quest_meta))
    print([entry["beta"] for entry in quest_meta])
Oops, something went wrong.