[add] improve examples and more possibilities
goncalorafaria committed Aug 5, 2024
1 parent ded6bbc commit 3bf8801
Showing 25 changed files with 7,040 additions and 465 deletions.
106 changes: 106 additions & 0 deletions examples/jailbreak/test.py
@@ -0,0 +1,106 @@
from langchain.prompts import PromptTemplate
from quest.model.vllm import VLLM
from quest import Quest, RLHFSuffixProposal

from quest.reward.base import (
Reward,
ConstantReward,
BackwardReward,
)

from quest.index import Uniform
from quest.reward.model import (
ContextualRewardModel,
)
from datasets import load_dataset
import os

dataset_path = "Anthropic/hh-rlhf"


def process_data(entry):
    # Split the hh-rlhf "chosen" transcript into turns, peel off the final
    # assistant reply as the answer, and keep a prompt that ends with the
    # bare role tag (e.g. "Assistant: ") so the model regenerates that turn.
    breaks = entry["chosen"].split("\n\n")

    twopb = breaks[-1].split(":")
    descriptor = twopb[0]
    answer = ":".join(twopb[1:])

    breaks[-1] = f"{descriptor}: "
    prompt = "\n\n".join(breaks)

    return {
        "prompt": prompt,
        "answer": answer,
    }
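
# Worked example (comment-only sketch; the entry below is illustrative, in the
# usual hh-rlhf transcript format, and is not taken from the diff). Given
#     {"chosen": "\n\nHuman: Can you help me?\n\nAssistant: Of course."}
# process_data returns
#     {"prompt": "\n\nHuman: Can you help me?\n\nAssistant: ",
#      "answer": " Of course."}
# i.e. the final assistant turn is stripped so the model can regenerate it.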


def main(
    opposite: bool = False,
    beta: float = 0.5,
    steps: int = 25,
    temperature: float = 0.6,
    n: int = 1,
    model_path: str = "meta-llama/Meta-Llama-3-8B",
    reward_model_path: str = "OpenAssistant/reward-model-deberta-v3-large-v2",
):

    ds = load_dataset(dataset_path, split="test")

    psd = ds.map(process_data)

    data_iterable = list(psd)[:n]

    model = VLLM(
        model_path=model_path,
        # prompt_template=template,
        download_dir=os.environ.get("HF_HOME", "/tmp/"),
        stop_tokens=["\n"],
        temperature=temperature,
    )

    reward = ContextualRewardModel(
        model_path=reward_model_path
    )  # scores each response against its prompt context.
    # Alternative for debugging: reward = ConstantReward(1.0)

    context = [model.get_prompt(**data) for data in data_iterable]

    reward.set_context(context)

    if opposite:
        # Flip the objective so the chain moves against what the
        # reward model prefers.
        reward = BackwardReward(reward)

    index = Uniform()

    chain = Quest(
        input_data=data_iterable,
        proposal=RLHFSuffixProposal(model=model, dist=index),
        reward=reward,
        beta=beta,
    )
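
    # The chain treats the current completion as state: at each step,
    # RLHFSuffixProposal resamples a suffix from the model, and the new
    # text is accepted or rejected based on its reward, with beta
    # controlling how sharply the target distribution concentrates on
    # high-reward completions. (A sketch of the Quest sampler; the exact
    # acceptance rule is defined inside the library, not in this diff.)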

    chain_outputs = chain.run(steps=steps, use_tqdm=True)

print(data_iterable)

for s in chain_outputs.state_path:
print((s["reward"], s["text"]))


if __name__ == "__main__":
    import fire

    fire.Fire(main)
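
Since fire.Fire(main) exposes each keyword argument of main as a command-line flag, a typical invocation looks like the following sketch (the flag values are illustrative, not taken from the diff):

    python examples/jailbreak/test.py --opposite=True --beta=0.5 --steps=25 --n=4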
98 changes: 98 additions & 0 deletions examples/mt/analysis.py
@@ -0,0 +1,98 @@
from expkit import ExpSetup
from expkit.ops import (
EvalMean,
EvalLast,
EvalMax,
EvalTotalMean,
EvalMeanLast,
EvalMeanMax,
Operation,
)
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import re

# beta used by the RLHF model: 0.0325
N = 1024
K = 64
TEMP = 0.6
N_MEASUREMENTS = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]

eval_key = "lastnumber"  # alternative: "crm:hamishivi-tulu-v2"
rm_path = "hamishivi/tulu-v2.5-7b-uf-rm"


setup = ExpSetup(
"/gscratch/ark/graf/quest-rlhf/mt-outputs/",
lazy=True,
load_instances=False,
).query({"steps": 128})


# print(setup.meta())


base = setup.query({"variant": "ancestral"})

quest = setup.query({"variant": "quest"})
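
# ExpSetup.query filters the loaded experiments by their metadata, so
# `base` holds the ancestral-sampling runs and `quest` the Quest runs
# (inferred from usage here; expkit's API is not shown in this diff).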


for lp in ["en-zh", "en-cs"]:
    print("---" * 20)
    print(lp)

    print("base")
    print(len(base.query({"language_pair": lp}).meta()))
    print([x["temperature"] for x in quest.query({"language_pair": lp}).meta()])

    print("quest")
    print(len(quest.query({"language_pair": lp}).meta()))
    print([x["beta"] for x in quest.query({"language_pair": lp}).meta()])