[WIP] Further memory optimization of SPHINX series models #118

Draft · wants to merge 3 commits into base: main
16 changes: 12 additions & 4 deletions accessory/demos/multi_turn_mm_box.py
@@ -22,7 +22,6 @@
from accessory.data.conversation.lib import conv_templates, SeparatorStyle
from PIL import Image, ImageDraw
from accessory.data.transform import get_transform
from segment_anything import sam_model_registry, SamPredictor

import regex as re

@@ -278,7 +277,7 @@ def draw_box_mask_on_image(img: Image, l_name_box_color, predictor):
if edge_s_name in key_point_cache and edge_t_name in key_point_cache:
draw.line([key_point_cache[edge_s_name], key_point_cache[edge_t_name]], fill="green", width=3)

if len(boxes) > 0:
if len(boxes) > 0 and predictor is not None:
img_mask = img.copy()
img_array = np.array(img)
predictor.set_image(img_array)
@@ -312,8 +311,12 @@ def gradio_worker(
of Web UI to be after the start of the model.
"""

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").cuda()
sam_predictor = SamPredictor(sam)
if args.disable_sam:
sam_predictor = None
else:
from segment_anything import sam_model_registry, SamPredictor
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").cuda()
sam_predictor = SamPredictor(sam)

def show_user_input(msg, chatbot, chatbox_display):
return "", chatbot + [[msg, None]], chatbox_display + [[msg, None]]
@@ -474,6 +477,11 @@ def clear():
"--bind_all", action="store_true",
help="Listen to all addresses on the host."
)
parser.add_argument(
"--disable_sam", action="store_true",
help="Do not create SAM model. This saves some GPU memory but object "
"masks will no longer be predicted."
)
args = parser.parse_args()

# check and setup gpu_ids to use
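As a usage note on the new `--disable_sam` flag, here is a minimal, self-contained sketch of the behavior it enables; the flag name and checkpoint file come from the diff above, while the rest is illustrative and not the demo's actual launch code.

```python
# Illustrative sketch only: SAM is built lazily and skipped entirely when the
# --disable_sam flag is passed, so its weights never occupy GPU memory.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--disable_sam", action="store_true")
args = parser.parse_args(["--disable_sam"])  # simulate launching with the flag

if args.disable_sam:
    sam_predictor = None
else:
    # Deferred import: segment_anything is only required when SAM is enabled.
    from segment_anything import sam_model_registry, SamPredictor
    sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").cuda()
    sam_predictor = SamPredictor(sam)

# Downstream drawing code simply skips mask prediction when the predictor is absent.
boxes = []  # boxes parsed from the model output, if any
if len(boxes) > 0 and sam_predictor is not None:
    pass  # set_image(...) and predict masks as before
```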
76 changes: 38 additions & 38 deletions accessory/model/LLM/llama_ens5.py
@@ -22,6 +22,7 @@

import accessory
from accessory.configs import global_configs
from accessory.util.tensor_type import default_tensor_type
if global_configs.USE_FLASH_ATTENTION:
from flash_attn import flash_attn_func

@@ -275,51 +276,50 @@ def __init__(self, args: ModelArgs, with_visual=False):
self.image_words = 0
self.cache_image_words = 0 # for inference
if with_visual:

default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float32)

print("build llama model with qformerv2")
if self.args.load_pretrained_visual_encoder:
self.qformer = Blip2Model.from_pretrained(
"Salesforce/blip2-opt-2.7b", torch_dtype=self.norm.weight.dtype
)
else:
self.qformer = Blip2Model(Blip2Config.from_pretrained(
str(impresources.files(accessory)/'resources/hf/Salesforce/blip2-opt-2.7b/config.json')))
self.qformer.language_projection = None
self.qformer.language_model = None
with default_tensor_type(dtype=torch.float32, device="cpu"):
print("build llama model with qformerv2")
if self.args.load_pretrained_visual_encoder:
self.qformer = Blip2Model.from_pretrained(
"Salesforce/blip2-opt-2.7b", torch_dtype=self.norm.weight.dtype
)
else:
self.qformer = Blip2Model(Blip2Config.from_pretrained(
str(impresources.files(accessory)/'resources/hf/Salesforce/blip2-opt-2.7b/config.json')))
self.qformer.language_projection = None
self.qformer.language_model = None
self.qformer.to(self.norm.weight)

print("build llama model with clip")
if self.args.load_pretrained_visual_encoder:
self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
else:
self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained=None)
self.clip.transformer = None
with default_tensor_type(dtype=torch.float32, device="cpu"):
print("build llama model with clip")
if self.args.load_pretrained_visual_encoder:
self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
else:
self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained=None)
self.clip.transformer = None
self.clip.to(self.norm.weight)

print("build llama model with openclip")
if self.args.load_pretrained_visual_encoder:
self.openclip_convnext_xxl, _, _ = open_clip.create_model_and_transforms(
"convnext_xxlarge", pretrained="laion2b_s34b_b82k_augreg_soup"
)
else:
self.openclip_convnext_xxl, _, _ = open_clip.create_model_and_transforms(
"convnext_xxlarge", pretrained=None
)
self.openclip_convnext_xxl = self.openclip_convnext_xxl.visual.trunk
self.openclip_convnext_xxl.head.global_pool = nn.Identity()
self.openclip_convnext_xxl.head.flatten = nn.Identity()
with default_tensor_type(dtype=torch.float32, device="cpu"):
print("build llama model with openclip")
if self.args.load_pretrained_visual_encoder:
self.openclip_convnext_xxl, _, _ = open_clip.create_model_and_transforms(
"convnext_xxlarge", pretrained="laion2b_s34b_b82k_augreg_soup"
)
else:
self.openclip_convnext_xxl, _, _ = open_clip.create_model_and_transforms(
"convnext_xxlarge", pretrained=None
)
self.openclip_convnext_xxl = self.openclip_convnext_xxl.visual.trunk
self.openclip_convnext_xxl.head.global_pool = nn.Identity()
self.openclip_convnext_xxl.head.flatten = nn.Identity()
self.openclip_convnext_xxl.to(self.norm.weight)

print("build llama model with dinov2")
if self.args.load_pretrained_visual_encoder:
self.dinov2_vitg14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14", pretrained=True)
else:
self.dinov2_vitg14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14", pretrained=False)
with default_tensor_type(dtype=torch.float32, device="cpu"):
print("build llama model with dinov2")
if self.args.load_pretrained_visual_encoder:
self.dinov2_vitg14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14", pretrained=True)
else:
self.dinov2_vitg14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14", pretrained=False)
self.dinov2_vitg14.to(self.norm.weight)
torch.set_default_dtype(default_dtype)

self.qformer_proj = nn.Sequential(
nn.Linear(768, args.dim),
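For context on the pattern above: each visual tower is now built inside a `default_tensor_type(dtype=torch.float32, device="cpu")` block and only afterwards moved to the language model's dtype and device via `.to(self.norm.weight)`, instead of toggling the process-wide default dtype by hand. A minimal sketch of what such a context manager might look like, assuming a PyTorch 2.x style implementation (the actual `accessory.util.tensor_type.default_tensor_type` may differ):

```python
# Hedged sketch, not the repository's implementation: temporarily override the
# default dtype and (on PyTorch >= 2.0) the default device for new tensors.
import contextlib
import torch

@contextlib.contextmanager
def default_tensor_type(dtype=None, device=None):
    old_dtype = torch.get_default_dtype()
    if dtype is not None:
        torch.set_default_dtype(dtype)
    device_ctx = torch.device(device) if device is not None else contextlib.nullcontext()
    try:
        with device_ctx:  # torch.device(...) works as a default-device context manager
            yield
    finally:
        torch.set_default_dtype(old_dtype)

# Usage mirroring the diff: build on CPU in float32, then cast/move in one step.
with default_tensor_type(dtype=torch.float32, device="cpu"):
    proj = torch.nn.Linear(768, 4096)
proj.to(torch.half)  # stand-in for .to(self.norm.weight)
```

Scoping the override per tower also means an exception inside one block can no longer leave the global default dtype permanently changed, which the previous save/restore pattern was vulnerable to.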
160 changes: 133 additions & 27 deletions accessory/util/tensor_parallel.py
@@ -5,6 +5,7 @@

import torch
import torch.nn as nn
import torch.distributed as dist

import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import (
@@ -107,18 +108,102 @@ def _load_checkpoint_and_merge_ranks(
return merged_ckpt


def _load_checkpoint_and_split_rank(
def _load_checkpoint_and_redistribute_general(
ckpt_path: str, ckpt_mp_world_size: int,
weight_parallel_dim: Dict[str, int], verbose: bool, format: str,
) -> OrderedDict[str, torch.Tensor]:
raise NotImplementedError()
mp_rank = fs_init.get_model_parallel_rank()
mp_world_size = fs_init.get_model_parallel_world_size()
mp_src = fs_init.get_model_parallel_src_rank()
mp_group = fs_init.get_model_parallel_group()

local_ckpt_shards = []
for shard_id in range(mp_rank, ckpt_mp_world_size, mp_world_size):
local_ckpt_shards.append(
load_tensor_parallel_shard_state_dict(
ckpt_path, format, shard_id, ckpt_mp_world_size
)
)

def _load_checkpoint_and_redistribute_general(
ckpt_path: str, ckpt_mp_world_size: int,
weight_parallel_dim: Dict[str, int], verbose: bool, format: str,
) -> OrderedDict[str, torch.Tensor]:
raise NotImplementedError()
resharded_ckpt = OrderedDict()

if mp_rank == 0:
tensor_keys = sorted(local_ckpt_shards[0].keys())
tensor_sizes_and_dtypes = []
for key in tensor_keys:
sharded_tensor = local_ckpt_shards[0][key]
if key not in weight_parallel_dim:
tensor_sizes_and_dtypes.append(
(list(sharded_tensor.shape), sharded_tensor.dtype)
)
else:
tensor_size = list(sharded_tensor.shape)
tensor_size[weight_parallel_dim[key]] *= ckpt_mp_world_size
tensor_sizes_and_dtypes.append(
(tensor_size, sharded_tensor.dtype)
)
dist.broadcast_object_list(
[tensor_keys, tensor_sizes_and_dtypes], mp_src, mp_group
)
else:
recv_objs = [None, None]
dist.broadcast_object_list(recv_objs, mp_src, mp_group)
tensor_keys, tensor_sizes_and_dtypes = recv_objs

for key, (size, dtype) in zip(tensor_keys, tensor_sizes_and_dtypes):
if key not in weight_parallel_dim:
bcast_tensor = (
local_ckpt_shards[0][key].cuda()
if mp_rank == 0 else
torch.empty(size, dtype=dtype, device="cuda")
)
dist.broadcast(bcast_tensor, mp_src, mp_group)
bcast_tensor = bcast_tensor.cpu()
resharded_ckpt[key] = bcast_tensor

else:
reshard_dim = weight_parallel_dim[key]
merged_tensor = torch.empty(size, dtype=dtype, device="cuda")
assert merged_tensor.size(reshard_dim) % mp_world_size == 0
loader = ShardedTensorLoader(target=merged_tensor,
num_shards=ckpt_mp_world_size,
shard_dim=reshard_dim)

sharded_size = size.copy()
sharded_size[reshard_dim] //= ckpt_mp_world_size

for round_idx in range(
(ckpt_mp_world_size - 1) // mp_world_size + 1
):
round_st = round_idx * mp_world_size
all_gather_list = [
torch.empty(
[int(round_st + i < ckpt_mp_world_size),
*sharded_size],
dtype=dtype, device="cuda"
)
for i in range(mp_world_size)
]
send_tensor = (
local_ckpt_shards[round_idx][key].cuda().unsqueeze(0)
if round_idx < len(local_ckpt_shards) else
torch.empty([0, *sharded_size], dtype=dtype, device="cuda")
)
assert send_tensor.size() == all_gather_list[mp_rank].size()
dist.all_gather(all_gather_list, send_tensor, mp_group)
for shard_id in range(
round_idx * mp_world_size,
min((round_idx + 1) * mp_world_size, ckpt_mp_world_size),
):
shard_value = all_gather_list[shard_id % mp_world_size][0]
loader.load_shard(shard_id, shard_value)

assert loader.is_complete()
resharded_ckpt[key] = merged_tensor.chunk(
mp_world_size, reshard_dim
)[mp_rank].cpu()

return resharded_ckpt
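To make the index arithmetic in the general redistribution path easier to follow, here is a small standalone illustration (not part of the PR; the world sizes are arbitrary example values):

```python
# Standalone illustration of the shard-to-rank mapping and all_gather rounds
# used above, with example values ckpt_mp_world_size=5 and mp_world_size=2.
ckpt_mp_world_size, mp_world_size = 5, 2

# Each model-parallel rank loads every mp_world_size-th checkpoint shard.
for mp_rank in range(mp_world_size):
    local_shards = list(range(mp_rank, ckpt_mp_world_size, mp_world_size))
    print(f"rank {mp_rank} loads checkpoint shards {local_shards}")
# rank 0 loads checkpoint shards [0, 2, 4]
# rank 1 loads checkpoint shards [1, 3]

# Shards are exchanged in rounds of all_gather, mp_world_size shards per round;
# a rank with no shard left in a round contributes an empty [0, ...] tensor.
num_rounds = (ckpt_mp_world_size - 1) // mp_world_size + 1
for round_idx in range(num_rounds):
    st = round_idx * mp_world_size
    shard_ids = list(range(st, min(st + mp_world_size, ckpt_mp_world_size)))
    print(f"round {round_idx} gathers checkpoint shards {shard_ids}")
# round 0 gathers checkpoint shards [0, 1]
# round 1 gathers checkpoint shards [2, 3]
# round 2 gathers checkpoint shards [4]
```

After the final round each tensor-parallel weight has been reassembled along its parallel dimension, and the merged tensor is re-chunked into mp_world_size pieces, of which the local rank keeps only its own slice on the CPU.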


def get_tensor_parallel_shards_file_name(
@@ -200,6 +285,10 @@ def load_tensor_parallel_model_state_dict(
Returns:
OrderedDict[str, torch.Tensor]: The model state_dict local to the
model parallel rank of the current process.

.. note::
This function is synchronous within the model parallel group and must
be called by all the group workers at the same time.
"""
def print_if_verbose(*args, **kwargs):
if verbose:
@@ -244,10 +333,6 @@ def print_if_verbose(*args, **kwargs):
local_state_dict = _load_checkpoint_and_merge_ranks(
path, ckpt_mp_world_size, weight_parallel_dim, verbose, format
)
elif mp_world_size % ckpt_mp_world_size == 0:
local_state_dict = _load_checkpoint_and_split_rank(
path, ckpt_mp_world_size, weight_parallel_dim, verbose, format
)
else:
local_state_dict = _load_checkpoint_and_redistribute_general(
path, ckpt_mp_world_size, weight_parallel_dim, verbose, format
@@ -280,6 +365,10 @@ def load_tensor_parallel_model(
being the missing keys and the second being the unexpected keys,
following the same convention as
``torch.nn.Module.load_state_dict``.

.. note::
This function is synchronous within the model parallel group and must
be called by all the group workers at the same time.
"""
assert not format.endswith("_diff"), (
"A *_diff checkpoint must be used together with the corresponding "
@@ -327,12 +416,12 @@ def infer_checkpoint_format_and_mp_size(path: str) -> str:
raise NotImplementedError(f"Multiple matched format detected: "
f"{inferred_format} and {format}.")
if inferred_format is None:
folder_contents = ", ".join(
[x if os.path.isfile(os.path.join(path, x)) else x + " (not a file)"
for x in os.listdir(path)]
)
folder_contents = ", ".join([
x if os.path.isfile(os.path.join(path, x)) else x + " (not a file)"
for x in os.listdir(path)
])
raise NotImplementedError(
f"Files in the given folder do not match any format. "
f"Files in the given folder do not match any format. "
f"Contents in the folder: [{folder_contents}]."
)

@@ -408,19 +497,23 @@ def load_tensor_parallel_model_list(
for debugging purposes). The default is ``False``.

Returns:
Tuple[List[str], List[str]]: Returns two lists of strings, the first
being the missing keys and the second being the unexpected keys,
following the same convention as
``torch.nn.Module.load_state_dict``. A key is deemed missing if it
does not occur in any of the checkpoints in the list, and is deemed
unexpected if it is unexpected to the model and has appeared in any
one of the checkpoints in the list.
Dict: A dictionary with two keys ``missing_keys`` and
``unexpected_keys``. A key is deemed missing if it does not occur
in any of the checkpoints in the list, and is deemed unexpected if
it is unexpected to the model and has appeared in any one of the
checkpoints in the list.

.. note::
This function is synchronous within the model parallel group and must
be called by all the group workers at the same time.
"""
existing_keys, missing_keys, unexpected_keys = set(), set(model.state_dict().keys()), set()
missing_keys = set(model.state_dict().keys())
existing_keys, unexpected_keys = set(), set()

for i, path in enumerate(path_list):
inferred_format, _ = infer_checkpoint_format_and_mp_size(path)
print(f"Loading from checkpoint at: {path} ({i + 1} of "
f"{len(path_list)}, format is \"{inferred_format})\"")
f"{len(path_list)}, format is \"{inferred_format}\")")
assert i != 0 or not inferred_format.endswith("_diff"), (
"The first checkpoint in the list cannot be a *_diff checkpoint."
)
@@ -443,7 +536,10 @@ def load_tensor_parallel_model_list(
missing_keys.intersection_update(step_missing_keys)
unexpected_keys.update(step_unexpected_keys)

return {"missing_keys": list(missing_keys), "unexpected_keys": list(unexpected_keys)}
return {
"missing_keys": list(missing_keys),
"unexpected_keys": list(unexpected_keys),
}


def tensor_load_shard(
Expand All @@ -456,7 +552,8 @@ def tensor_load_shard(

Args:
target (``torch.Tensor``): The target tensor to load the values into.
parallel_dim (int): Tensor parallel dimension of the tensor.
parallel_dim (int): Tensor parallel dimension of the tensor. ``-1`` for
replicated tensors.
num_shards (int): Number of tensor parallel shards of the value.
shard_id (int): The shard id of the current value.
value (``torch.Tensor``): The value to be loaded into the target
Expand All @@ -466,14 +563,22 @@ def tensor_load_shard(
If ``add``, the new value is added to the old value.
"""
assert parallel_dim < target.ndim or parallel_dim == -1
if parallel_dim >= 0:
assert target.size(parallel_dim) % num_shards == 0
target_slices = []
expected_shard_shape = []
for i in range(target.ndim):
if i == parallel_dim:
dim_st = target.size(i) // num_shards * shard_id
dim_ed = target.size(i) // num_shards * (shard_id + 1)
target_slices.append(slice(dim_st, dim_ed))
expected_shard_shape.append(dim_ed - dim_st)
else:
target_slices.append(slice(None))
expected_shard_shape.append(target.size(i))
assert value.shape == tuple(expected_shard_shape), (
f"{value.shape} vs. {expected_shard_shape}"
)
if parallel_dim == -1 and shard_id != 0 and mode in ["set", "add"]:
return
if mode == "set":
@@ -505,6 +610,7 @@ def __init__(
loaded into.
num_shards (int): Number of expected shards.
shard_dim (int): The dimension along which the tensor is sharded.
``-1`` for replicated tensors.
mode (str): Supported options are ``set`` and ``add``. If ``set``,
the old value in the target tensor is overridden with the new
value. If ``add``, the new value is added to the old value.
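For reference, a standalone toy example of the slice and shape arithmetic that the new assertion in `tensor_load_shard` checks (shapes and values here are made up for illustration):

```python
# Toy illustration: a target of shape (8, 4) sharded along dim 0 into 4 shards
# means shard_id 2 must have shape (2, 4) and lands in rows 4 and 5 of the target.
import torch

target = torch.zeros(8, 4)
parallel_dim, num_shards, shard_id = 0, 4, 2

target_slices, expected_shard_shape = [], []
for i in range(target.ndim):
    if i == parallel_dim:
        dim_st = target.size(i) // num_shards * shard_id        # 4
        dim_ed = target.size(i) // num_shards * (shard_id + 1)  # 6
        target_slices.append(slice(dim_st, dim_ed))
        expected_shard_shape.append(dim_ed - dim_st)
    else:
        target_slices.append(slice(None))
        expected_shard_shape.append(target.size(i))

value = torch.ones(*expected_shard_shape)      # a shard of the correct shape
assert value.shape == tuple(expected_shard_shape)
target[tuple(target_slices)] = value           # "set" mode; "add" would use +=
```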