Sample notebook: model pretraining (#77)

Co-authored-by: Julien Simon <[email protected]>
Co-authored-by: Tyler-Odenthal <[email protected]>

1 parent a1e1048, commit 5df2d26

Showing 2 changed files with 382 additions and 0 deletions.
@@ -0,0 +1,382 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pretraining a model on Arcee Cloud\n",
"\n",
"In this notebook, you will learn how to run continuous pretraining of a model on Arcee Cloud. In this example, we'll pretrain a Llama-3-8B model on the Energy domain.\n",
"\n",
"In order to run this demo, you need a Starter account on Arcee Cloud. Please see our [pricing](https://www.arcee.ai/pricing) page for details.\n",
"\n",
"The Arcee documentation is available at [docs.arcee.ai](https://docs.arcee.ai/deployment/start-deployment)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prerequisites\n",
"\n",
"Please [sign up](https://app.arcee.ai/account/signup) for Arcee Cloud and create an [API key](https://docs.arcee.ai/getting-arcee-api-key/getting-arcee-api-key).\n",
"\n",
"Then, please update the cell below with your API key. Remember to keep this key safe, and **DON'T COMMIT IT to any of your repositories**."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%env ARCEE_API_KEY=YOUR_API_KEY"
]
},
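{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, to avoid keeping the key in the notebook at all, you can read it interactively and export it yourself. The sketch below only uses the Python standard library and sets the same `ARCEE_API_KEY` environment variable as the cell above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: read the API key interactively instead of hardcoding it in the notebook.\n",
"# This sets the same ARCEE_API_KEY environment variable as the %env cell above.\n",
"import os\n",
"from getpass import getpass\n",
"\n",
"if not os.environ.get(\"ARCEE_API_KEY\") or os.environ[\"ARCEE_API_KEY\"] == \"YOUR_API_KEY\":\n",
"    os.environ[\"ARCEE_API_KEY\"] = getpass(\"Enter your Arcee API key: \")"
]
},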
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a new Python environment (optional but recommended) and install [arcee-python](https://github.com/arcee-ai/arcee-python)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Uncomment the next three lines to create a virtual environment\n",
"#!pip install -q virtualenv\n",
"#!virtualenv -q arcee-cloud\n",
"#!source arcee-cloud/bin/activate\n",
"\n",
"%pip install -q arcee-py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import arcee\n",
"from IPython.display import Image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preparing our dataset\n",
"\n",
"We need a dataset that holds the appropriate domain knowledge for the Energy domain. Arcee Cloud can ingest data in a variety of formats, such as PDF, JSON, XML, TXT, HTML, and CSV. Please check the [documentation](https://docs.arcee.ai/continuous-pretraining/upload-pretraining-data) for an up-to-date list of supported formats.\n",
"\n",
"We assembled a collection of about 300 PDF reports from the [International Energy Agency](https://www.iea.org/analysis?type=report) and the [Energy Reports](https://www.sciencedirect.com/journal/energy-reports) journal. The total size of the dataset is about 1.5 GB, or roughly 16 million tokens. Please note that this is probably too small for efficient pretraining: for real-life applications, we recommend using at least 100 million tokens.\n",
"\n",
"For convenience, we have stored the dataset in this Google Drive [folder](https://drive.google.com/drive/folders/1DX5hIuVfykHqz2gwLTu4MR9R6TTAxiEO?usp=sharing). However, please note that Arcee Cloud requires training datasets to be stored in Amazon S3, so we also uploaded the dataset to a \"customer\" bucket defined below. You will be able to use this bucket to run the rest of this notebook, but you won't be able to list its contents. In real life, you would of course use your own S3 bucket."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset_bucket_name = \"juliensimon-datasets\"\n",
"dataset_name = \"energy-pdf\"\n",
"dataset_s3_uri = f\"s3://{dataset_bucket_name}/{dataset_name}\"\n",
"print(f\"Dataset S3 URI: {dataset_s3_uri}\")"
]
},
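{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you were working with your own data, you would first copy your PDF files to an S3 bucket that you own. Here is a minimal sketch using `boto3`; the bucket name (`your-bucket-name`) and the local `energy-pdf/` folder are placeholders for illustration, not resources created by this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: upload a local folder of PDF files to your own S3 bucket.\n",
"# \"your-bucket-name\" and the local \"energy-pdf\" folder are placeholders.\n",
"from pathlib import Path\n",
"\n",
"import boto3\n",
"\n",
"s3 = boto3.client(\"s3\")\n",
"for pdf_file in Path(\"energy-pdf\").glob(\"*.pdf\"):\n",
"    # Store the objects under the same prefix used in dataset_s3_uri above\n",
"    s3.upload_file(str(pdf_file), \"your-bucket-name\", f\"{dataset_name}/{pdf_file.name}\")"
]
},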
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training code in Arcee Cloud runs in one of Arcee's AWS accounts.\n",
"\n",
"We need to allow this account to access the data stored in the bucket above (which is attached to a different AWS account).\n",
"\n",
"This setup is called \"cross-account access\", and it requires adding a policy to the bucket that allows the Arcee account to read the data it stores.\n",
"\n",
"You'll find more information about cross-account access and bucket policies in the [AWS documentation](https://docs.aws.amazon.com/AmazonS3/latest/userguide/example-walkthroughs-managing-access-example2.html).\n",
"\n",
"If you're unfamiliar with the process, or don't have the required AWS permissions, please contact your AWS administrator."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is the bucket policy applied to the \"customer\" bucket.\n",
"\n",
"It gives Arcee's AWS account `812782781539` read and list permissions on the \"customer\" bucket. Working with your own bucket, you would need to update the `Resource` section with your bucket and prefixes. Then, you would apply this bucket policy to your bucket, using either the AWS console or one of the AWS SDKs.\n",
"\n",
"    import boto3\n",
"    import json\n",
"\n",
"    bucket_policy = {\n",
"        \"Version\": \"2012-10-17\",\n",
"        \"Statement\": [\n",
"            {\n",
"                \"Effect\": \"Allow\",\n",
"                \"Principal\": {\n",
"                    \"AWS\": \"arn:aws:iam::812782781539:root\"\n",
"                },\n",
"                \"Action\": [\n",
"                    \"s3:GetBucketLocation\",\n",
"                    \"s3:ListBucket\",\n",
"                    \"s3:GetObject\",\n",
"                    \"s3:GetObjectAttributes\",\n",
"                    \"s3:GetObjectTagging\"\n",
"                ],\n",
"                \"Resource\": [\n",
"                    \"arn:aws:s3:::juliensimon-datasets\",\n",
"                    \"arn:aws:s3:::juliensimon-datasets/*\"\n",
"                ]\n",
"            }\n",
"        ]\n",
"    }\n",
"\n",
"    policy_string = json.dumps(bucket_policy)\n",
"\n",
"    boto3.client('s3').put_bucket_policy(Bucket=\"juliensimon-datasets\", Policy=policy_string)\n"
]
},
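{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check, you can read the policy back after applying it. The sketch below uses `boto3` and must be run from the AWS account that owns the bucket, so it won't work against the demo \"customer\" bucket, only against your own (`your-bucket-name` is a placeholder)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: read back the bucket policy to confirm it was applied.\n",
"# Run this from the AWS account that owns the bucket; \"your-bucket-name\" is a placeholder.\n",
"import json\n",
"\n",
"import boto3\n",
"\n",
"policy = boto3.client(\"s3\").get_bucket_policy(Bucket=\"your-bucket-name\")\n",
"print(json.dumps(json.loads(policy[\"Policy\"]), indent=2))"
]
},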
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Uploading our dataset\n",
"\n",
"Now that Arcee Cloud can read the training dataset, let's upload it with the `upload_corpus_folder()` API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"help(arcee.upload_corpus_folder)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_name = \"meta-llama/Meta-Llama-3-8B\""
]
},
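{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before uploading, you may want to estimate how many tokens your corpus contains, since we recommend at least 100 million tokens for pretraining. Here is a rough sketch using the Hugging Face `transformers` tokenizer for the base model. It assumes you have already extracted the PDF text into local `.txt` files (the `extracted_text/` folder is a placeholder) and that you have access to the gated Meta-Llama-3-8B repository."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: estimate the corpus token count with the base model's tokenizer.\n",
"# Assumes the PDF text has been extracted to .txt files in a local folder\n",
"# (extracted_text/ is a placeholder) and that the transformers library is installed.\n",
"from pathlib import Path\n",
"\n",
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"total_tokens = 0\n",
"for text_file in Path(\"extracted_text\").glob(\"*.txt\"):\n",
"    total_tokens += len(tokenizer(text_file.read_text(encoding=\"utf-8\"))[\"input_ids\"])\n",
"\n",
"print(f\"Approximate corpus size: {total_tokens:,} tokens\")"
]
},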
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = arcee.upload_corpus_folder(\n",
"    corpus=dataset_name,\n",
"    s3_folder_url=dataset_s3_uri,\n",
"    tokenizer_name=model_name,\n",
"    block_size=8192  # see max_position_embeddings in https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from time import sleep\n",
"\n",
"while True:\n",
"    response = arcee.corpus_status(dataset_name)\n",
"    if response[\"processing_state\"] == \"processing\":\n",
"        print(\"Upload is in progress. Waiting 30 seconds before checking again.\")\n",
"        sleep(30)\n",
"    else:\n",
"        print(response)\n",
"        break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pretraining our model\n",
"\n",
"Once the dataset has been uploaded, we can launch training with the `start_pretraining()` API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"help(arcee.start_pretraining)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pretraining_name = f\"{model_name}-{dataset_name}\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = arcee.start_pretraining(\n",
"    pretraining_name=pretraining_name,\n",
"    corpus=dataset_name,\n",
"    base_model=model_name\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the Arcee Cloud console, we can see that the training job has started. After a few minutes, you should see the training loss decrease, indicating that the model is learning to correctly predict the tokens present in your dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Image(\"model_pretraining_01.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Deploying our trained model\n",
"\n",
"Once training is complete, we can deploy and test the pretrained model. The model hasn't been aligned, so chances are it won't generate anything particularly useful. However, we should still check that it is able to generate text properly.\n",
"\n",
"As part of the Arcee Cloud free tier, model deployment is free of charge and the endpoint is automatically shut down after 2 hours.\n",
"\n",
"Deployment should take 5-7 minutes. Please see the model deployment sample notebook for details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"deployment_name = f\"{model_name}-{dataset_name}\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = arcee.start_deployment(deployment_name=deployment_name, pretraining=pretraining_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"while True:\n",
"    response = arcee.deployment_status(deployment_name)\n",
"    if response[\"deployment_processing_state\"] == \"pending\":\n",
"        print(\"Deployment is in progress. Waiting 60 seconds before checking again.\")\n",
"        sleep(60)\n",
"    else:\n",
"        print(response)\n",
"        break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once the model endpoint is up and running, we can prompt the model with a domain-specific question."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"Is solar a good way to achieve net zero?\"\n",
"\n",
"response = arcee.generate(deployment_name=deployment_name, query=query)\n",
"print(response[\"text\"])"
]
},
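{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also send a few more domain-specific prompts through the same endpoint to get a qualitative feel for what the pretrained (but unaligned) model produces. The questions below are just illustrative examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A few more illustrative domain questions, reusing the same generate() call as above.\n",
"queries = [\n",
"    \"What are the main drivers of global electricity demand?\",\n",
"    \"How does offshore wind compare to onshore wind?\",\n",
"]\n",
"\n",
"for q in queries:\n",
"    response = arcee.generate(deployment_name=deployment_name, query=q)\n",
"    print(\"Q:\", q)\n",
"    print(\"A:\", response[\"text\"])\n",
"    print()"
]
},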
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Stopping our deployment\n",
"\n",
"When we're done working with our model, we should stop the deployment to save resources and avoid unwanted charges.\n",
"\n",
"The `stop_deployment()` API only requires the deployment name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"arcee.stop_deployment(deployment_name=deployment_name)\n",
"arcee.deployment_status(deployment_name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This concludes the model pretraining demonstration. Thank you for your time!\n",
"\n",
"If you'd like to know more about using Arcee Cloud in your organization, please visit the [Arcee website](https://www.arcee.ai), or contact [[email protected]](mailto:[email protected]).\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}