
Commit 025cb00

committed
feat: accelerate for big model
1 parent 7868ea7 commit 025cb00

File tree

7 files changed: +233, -6 lines


opengpt/factory.py (+11)

@@ -2,6 +2,7 @@
 from typing import Optional, Union

 import torch
+from loguru import logger


 def list_models():
@@ -35,6 +36,8 @@ def create_model_and_transforms(
     # TODO: Add support for loading config based on model name
     model_config = {}

+    logger.debug(f'Loading model: {model_name}')
+
     if model_name == 'OpenFlamingo-9B':
         from .models.flamingo.loading import load_model_and_transforms

@@ -44,5 +47,13 @@ def create_model_and_transforms(
             'tokenizer_name_or_path': 'llama_7B',
         }
         return load_model_and_transforms(**model_config)
+    elif model_name.startswith('facebook/llama'):
+        from .models.llama.loading import load_model_and_tokenizer
+
+        model_config = {
+            'model_name_or_path': 'llama_7B',
+            'tokenizer_name_or_path': 'llama_7B',
+        }
+        return load_model_and_tokenizer(**model_config)
     else:
         raise ValueError(f'Unknown model name: {model_name}')
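
With this change, any model name starting with 'facebook/llama' is routed to the new accelerate-based loader instead of the OpenFlamingo path. A minimal usage sketch, assuming the factory is called with the model_name keyword shown in the diff (the full signature and any other parameters are not visible here and are omitted):

    from opengpt.factory import create_model_and_transforms

    # Illustrative model identifier; for the 'facebook/llama' branch the call
    # returns a (model, tokenizer) pair from opengpt/models/llama/loading.py.
    model, tokenizer = create_model_and_transforms(model_name='facebook/llama-7B')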

opengpt/helper.py (+22)

@@ -1,3 +1,25 @@
+import sys
+
+from loguru import logger
+
+
+def setup_logging(debug: bool):
+    """
+    Setup the log formatter for AnnLite.
+    """
+
+    log_level = 'INFO'
+    if debug:
+        log_level = 'DEBUG'
+
+    logger.remove()
+    logger.add(
+        sys.stdout,
+        colorize=True,
+        level=log_level,
+    )
+
+
 def get_envs():
     from torch.utils import collect_env
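
setup_logging() removes loguru's default sink and re-adds stdout at INFO or DEBUG level. A sketch of how it would typically be wired into a click-based entry point (the --debug flag and main() command below are illustrative, not part of this commit; click is already a project dependency):

    import click

    from opengpt.helper import setup_logging

    @click.command()
    @click.option('--debug', is_flag=True, default=False, help='Enable DEBUG-level logging.')
    def main(debug: bool):
        # Reconfigure loguru before any model loading happens.
        setup_logging(debug)

    if __name__ == '__main__':
        main()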

opengpt/models/flamingo/loading.py (+1, -1)

@@ -2,7 +2,7 @@
 from open_flamingo.src.flamingo_lm import FlamingoLMMixin
 from open_flamingo.src.utils import extend_instance

-from ..hf_model import load_model_and_tokenizer
+from ..llama.loading import load_model_and_tokenizer
 from .modeling import FlamingoModel

opengpt/models/llama/loading.py (+30, -5)

@@ -1,12 +1,37 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from typing import TYPE_CHECKING, Union

+from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

-def hf_load_model_and_tokenizer(model_name_or_path: str, tokenizer_name_or_path: str):
+if TYPE_CHECKING:
+    import torch
+
+from loguru import logger
+
+
+def load_model_and_tokenizer(
+    model_name_or_path: str,
+    tokenizer_name_or_path: str,
+    dtype: Union[str, 'torch.dtype'] = 'torch.float16',
+    **kwargs
+):
     """Load a model and tokenizer from HuggingFace."""
     tokenizer = AutoTokenizer.from_pretrained(
-        tokenizer_name_or_path, local_files_only=False
+        tokenizer_name_or_path, local_files_only=True
     )
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name_or_path, local_files_only=False
+
+    # Create a model and initialize it with empty weights
+    config = AutoConfig.from_pretrained(model_name_or_path, local_files_only=True)
+
+    with init_empty_weights():
+        model = AutoModelForCausalLM.from_config(config)
+
+    # Load the checkpoint and dispatch it to the right devices
+    model = load_checkpoint_and_dispatch(
+        model, model_name_or_path, device_map="auto", dtype=dtype, **kwargs
     )
+
+    # model = AutoModelForCausalLM.from_pretrained(
+    #     model_name_or_path, local_files_only=False
+    # )
     return model, tokenizer
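
This is the Accelerate "big model" loading pattern: init_empty_weights() builds the model skeleton on the meta device without allocating any weight memory, and load_checkpoint_and_dispatch() then streams the checkpoint in and places each submodule on the available GPUs (spilling to CPU if needed) according to device_map="auto". A usage sketch, assuming 'llama_7B' is a local directory containing both the tokenizer files and the sharded checkpoint (local_files_only=True means nothing is downloaded):

    from opengpt.models.llama.loading import load_model_and_tokenizer

    # The directory name is illustrative; it must already exist on disk.
    model, tokenizer = load_model_and_tokenizer(
        model_name_or_path='llama_7B',
        tokenizer_name_or_path='llama_7B',
        dtype='torch.float16',
    )

    # With device_map='auto' the layers are spread across devices; Accelerate's
    # dispatch hooks move inputs to the right device during the forward pass.
    inputs = tokenizer('Hello, my name is', return_tensors='pt')
    outputs = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))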

opengpt/profile.py (+101, new file)

@@ -0,0 +1,101 @@
"""This file contains a few functions to profile the memory usage of the model.

It is not meant to be used in production, but rather to help us debug the memory usage of the model.

The codes are borrowed from https://github.com/huggingface/accelerate/blob/main/benchmarks/measures_util.py
"""

import gc
import threading
import time

import psutil
import torch
from accelerate.utils import compute_module_sizes


class PeakCPUMemory:
    def __init__(self):
        self.process = psutil.Process()
        self.peak_monitoring = False

    def peak_monitor(self):
        self.cpu_memory_peak = -1

        while True:
            self.cpu_memory_peak = max(
                self.process.memory_info().rss, self.cpu_memory_peak
            )

            # can't sleep or will not catch the peak right (this comment is here on purpose)
            if not self.peak_monitoring:
                break

    def start(self):
        self.peak_monitoring = True
        self.thread = threading.Thread(target=self.peak_monitor)
        self.thread.daemon = True
        self.thread.start()

    def stop(self):
        self.peak_monitoring = False
        self.thread.join()
        return self.cpu_memory_peak


cpu_peak_tracker = PeakCPUMemory()


def start_measure():
    # Time
    measures = {"time": time.time()}

    gc.collect()
    torch.cuda.empty_cache()

    # CPU mem
    measures["cpu"] = psutil.Process().memory_info().rss
    cpu_peak_tracker.start()

    # GPU mem
    for i in range(torch.cuda.device_count()):
        measures[str(i)] = torch.cuda.memory_allocated(i)
    torch.cuda.reset_peak_memory_stats()

    return measures


def end_measure(start_measures):
    # Time
    measures = {"time": time.time() - start_measures["time"]}

    gc.collect()
    torch.cuda.empty_cache()

    # CPU mem
    measures["cpu"] = (
        psutil.Process().memory_info().rss - start_measures["cpu"]
    ) / 2**20
    measures["cpu-peak"] = (cpu_peak_tracker.stop() - start_measures["cpu"]) / 2**20

    # GPU mem
    for i in range(torch.cuda.device_count()):
        measures[str(i)] = (
            torch.cuda.memory_allocated(i) - start_measures[str(i)]
        ) / 2**20
        measures[f"{i}-peak"] = (
            torch.cuda.max_memory_allocated(i) - start_measures[str(i)]
        ) / 2**20

    return measures


def log_measures(measures, description):
    print(f"{description}:")
    print(f"- Time: {measures['time']:.2f}s")
    for i in range(torch.cuda.device_count()):
        print(f"- GPU {i} allocated: {measures[str(i)]:.2f}MiB")
        peak = measures[f"{i}-peak"]
        print(f"- GPU {i} peak: {peak:.2f}MiB")
    print(f"- CPU RAM allocated: {measures['cpu']:.2f}MiB")
    print(f"- CPU RAM peak: {measures['cpu-peak']:.2f}MiB")

pyproject.toml (+1)

@@ -39,6 +39,7 @@ build-backend = "poetry.core.masonry.api"
 # Compatible Python versions
 python = ">=3.8"
 torch = ">=1.9.0,<2.0.0" # a meta device requires torch >= 1.9.0
+loguru = "^0.5"
 click = "^8.1.3"
 numpy = "^1.21.2"
 einops = "^0.6.0"

scripts/upload_to_s3.py (+67, new file)

@@ -0,0 +1,67 @@
# pip install boto3 hf_transfer

import os
from math import ceil
from time import time

import boto3
from hf_transfer import multipart_upload

# 10 MiB
CHUNK_SIZE = 10_485_760

s3 = boto3.client("s3")

bucket = "test-hf-transfer-multi-part-upload"
bucket_key = "some_file"

upload = s3.create_multipart_upload(
    ACL="bucket-owner-full-control",
    Bucket=bucket,
    Key=bucket_key,
)
upload_id = upload["UploadId"]
print("created multipart upload")

file_name = "some_file"
file_size = os.stat(file_name).st_size

urls = []
nb_parts = ceil(file_size / CHUNK_SIZE)
for part_number in range(1, nb_parts + 1):
    params = {
        "Bucket": bucket,
        "Key": bucket_key,
        "PartNumber": part_number,
        "UploadId": upload_id,
    }
    urls.append(
        s3.generate_presigned_url(
            ClientMethod="upload_part", Params=params, ExpiresIn=86400
        )
    )
print("prepared parts urls")

print("uploading parts...")
start = time()
responses = multipart_upload(
    file_path=file_name,
    parts_urls=urls,
    chunk_size=CHUNK_SIZE,
    max_files=64,
    parallel_failures=63,
    max_retries=5,
)
print(f"uploaded parts in {time() - start}")

etag_with_parts = []
for part_number, header in enumerate(responses):
    etag = header.get("etag")
    etag_with_parts.append({"ETag": etag, "PartNumber": part_number + 1})

parts = {"Parts": etag_with_parts}

s3.complete_multipart_upload(
    Bucket=bucket, Key=bucket_key, MultipartUpload=parts, UploadId=upload_id
)
print("upload complete")

0 commit comments
