Sample notebook: model alignment with DPO (#81)
Co-authored-by: Julien Simon <[email protected]>
Co-authored-by: Tyler-Odenthal <[email protected]>
1 parent 8220dd5 commit e5634ff
Showing 1 changed file with 253 additions and 0 deletions.
@@ -0,0 +1,253 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Aligning a model on Arcee Cloud with Direct Preference Optimization (DPO)\n", | ||
"\n", | ||
"In this notebook, you will learn how to align a model with DPO on Arcee Cloud.\n", | ||
"\n", | ||
"In order to run this demo, you need a Starter account on Arcee Cloud. Please see our [pricing](https://www.arcee.ai/pricing) page for details.\n", | ||
"\n", | ||
"The Arcee documentation is available at [docs.arcee.ai](https://docs.arcee.ai/deployment/start-deployment)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Prerequisites\n", | ||
"\n", | ||
"Please [sign up](https://app.arcee.ai/account/signup) to Arcee Cloud and create an [API key](https://docs.arcee.ai/getting-arcee-api-key/getting-arcee-api-key).\n", | ||
"\n", | ||
"Then, please update the cell below with your API key. Remember to keep this key safe, and **DON'T COMMIT IT to one of your repositories**." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%env ARCEE_API_KEY=YOUR_API_KEY" | ||
] | ||
}, | ||
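{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"If you prefer not to store the key in the notebook source, the optional cell below is a minimal alternative: it prompts for the key at runtime (using only the Python standard library) and writes it to the same `ARCEE_API_KEY` environment variable set above." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Optional alternative: uncomment the next three lines to be prompted for the key at runtime\n", | ||
"#import os\n", | ||
"#from getpass import getpass\n", | ||
"#os.environ[\"ARCEE_API_KEY\"] = getpass(\"Enter your Arcee API key: \")" | ||
] | ||
}, | ||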
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Create a new Python environment (optional but recommended) and install [arcee-python](https://github.com/arcee-ai/arcee-python)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Uncomment the next three lines to create a virtual environment\n", | ||
"#!pip install -q virtualenv\n", | ||
"#!virtualenv -q arcee-cloud\n", | ||
"#!source arcee-cloud/bin/activate\n", | ||
"\n", | ||
"%pip install -qU arcee-py" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import arcee\n", | ||
"import pprint" | ||
] | ||
}, | ||
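{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"As an optional sanity check, the cell below prints the installed version of the `arcee-py` package. It uses only the Python standard library; nothing Arcee-specific is assumed here." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Optional: print the installed arcee-py version\n", | ||
"from importlib.metadata import version\n", | ||
"\n", | ||
"print(version(\"arcee-py\"))" | ||
] | ||
}, | ||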
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Aligning the model\n", | ||
"\n", | ||
"At the moment, the DPO dataset is not configurable. We use the [UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. It consists of 64k prompts, 256k responses from differents LLMs and 380k high-quality feedback provided by GPT-4. \n", | ||
"\n", | ||
"Here, we will run DPO on the [Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) model we tuned for instruction following in the Supervised Fine-Tuning (SFT) notebook. You may remember that we used the [reasoning-share-gpt](https://huggingface.co/datasets/arcee-ai/reasoning-sharegpt) dataset.\n", | ||
"\n", | ||
"We could pick any model available on the Hugging Face hub, or a model we've already worked with on Arcee Cloud.\n", | ||
"\n", | ||
"Let's launch the alignment job with the `start_alignment()` API. It should last between 2 and 2.5 hours." | ||
] | ||
}, | ||
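{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Before launching the job, you can optionally take a quick look at the preference data. The sketch below loads UltraFeedback locally with the Hugging Face `datasets` library; this is not required for the Arcee Cloud job, it downloads the full dataset, and rather than assuming a schema it simply prints the column names and the first record." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Optional: inspect the UltraFeedback preference dataset locally.\n", | ||
"# Uncomment the lines below to run (requires the Hugging Face datasets library).\n", | ||
"#%pip install -qU datasets\n", | ||
"#from datasets import load_dataset\n", | ||
"#\n", | ||
"#dataset = load_dataset(\"openbmb/UltraFeedback\")\n", | ||
"#split = next(iter(dataset.values()))  # pick the first available split\n", | ||
"#print(split.column_names)             # columns provided by the dataset\n", | ||
"#print(split[0])                       # first record" | ||
] | ||
}, | ||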
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"help(arcee.start_alignment)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"alignment_name = \"llama-3-8B-reasoning-share-gpt-dpo\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"response=arcee.start_alignment(alignment_name=alignment_name,\n", | ||
" #hf_model=\"meta-llama/Meta-Llama-3-8B\",\n", | ||
" alignment_model=\"llama-3-8B-reasoning-share-gpt\",\n", | ||
" alignment_type=\"dpo\",\n", | ||
" full_or_peft=\"peft\"\n", | ||
")\n", | ||
"print(response)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from time import sleep\n", | ||
"\n", | ||
"while True:\n", | ||
" response = arcee.alignment_status(alignment_name)\n", | ||
" if response[\"processing_state\"] == \"processing\":\n", | ||
" print(\"Alignment is in progress. Waiting 15 minutes before checking again.\")\n", | ||
" sleep(900)\n", | ||
" else:\n", | ||
" print(response)\n", | ||
" break" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Deploying our aligned model\n", | ||
"\n", | ||
"Once alignment is complete, we can deploy and test the aligned model. As part of the Arcee Cloud free tier, this is free of charge and the endpoint will be automatically shut down after 2 hours.\n", | ||
"\n", | ||
"Deployment should take 5-7 minutes. Please see the model deployment sample notebook for details." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"deployment_name = alignment_name" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"response = arcee.start_deployment(deployment_name=deployment_name, alignment=alignment_name)\n", | ||
"\n", | ||
"while True:\n", | ||
" response = arcee.deployment_status(deployment_name)\n", | ||
" if response[\"deployment_processing_state\"] == \"pending\":\n", | ||
" print(\"Deployment is in progress. Waiting 60 seconds before checking again.\")\n", | ||
" sleep(60)\n", | ||
" else:\n", | ||
" print(response)\n", | ||
" break" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Once the model endpoint is up and running, we can prompt the model with a domain-specific question." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#query = \"Is Pluto a planet? Use markdown.\"\n", | ||
"query = \"I was supposed to fly to NYC but my connecting flight was cancelled. I'm now stuck in Omaha, Nebraska and it's 8PM. I have a meeting in Manhattan tomorrow at 10AM. What is my best option? Use markdown.\"\n", | ||
"\n", | ||
"response = arcee.generate(deployment_name=deployment_name, query=query)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from IPython.display import display, Markdown\n", | ||
"\n", | ||
"display(Markdown(response[\"text\"]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Stopping our deployment\n", | ||
"\n", | ||
"When we're done working with our model, we should stop the deployment to save resources and avoid unwanted charges.\n", | ||
"\n", | ||
"The `stop_deployment()` API only requires the deployment name." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"arcee.stop_deployment(deployment_name=deployment_name)\n", | ||
"arcee.deployment_status(deployment_name)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"This concludes the model alignment demonstration. Thank you for your time!\n", | ||
"\n", | ||
"If you'd like to know more about using Arcee Cloud in your organization, please visit the [Arcee website](https://www.arcee.ai), or contact [[email protected]](mailto:[email protected]).\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |