From 0cadcfb2e9b9a8ff99094404153f82eed72f9ba8 Mon Sep 17 00:00:00 2001 From: Ziyi Lin Date: Wed, 29 Nov 2023 16:24:44 +0800 Subject: [PATCH 1/3] memory opt experimental commit --- accessory/demos/multi_turn_mm_box.py | 16 ++- accessory/model/LLM/llama_ens5.py | 78 +++++++------- accessory/util/tensor_parallel.py | 148 ++++++++++++++++++++++----- 3 files changed, 171 insertions(+), 71 deletions(-) diff --git a/accessory/demos/multi_turn_mm_box.py b/accessory/demos/multi_turn_mm_box.py index d92f7b8c..7bde525e 100644 --- a/accessory/demos/multi_turn_mm_box.py +++ b/accessory/demos/multi_turn_mm_box.py @@ -22,7 +22,6 @@ from accessory.data.conversation.lib import conv_templates, SeparatorStyle from PIL import Image, ImageDraw from accessory.data.transform import get_transform -from segment_anything import sam_model_registry, SamPredictor import regex as re @@ -278,7 +277,7 @@ def draw_box_mask_on_image(img: Image, l_name_box_color, predictor): if edge_s_name in key_point_cache and edge_t_name in key_point_cache: draw.line([key_point_cache[edge_s_name], key_point_cache[edge_t_name]], fill="green", width=3) - if len(boxes) > 0: + if len(boxes) > 0 and predictor is not None: img_mask = img.copy() img_array = np.array(img) predictor.set_image(img_array) @@ -312,8 +311,12 @@ def gradio_worker( of Web UI to be after the start of the model. """ - sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").cuda() - sam_predictor = SamPredictor(sam) + if args.disable_sam: + sam_predictor = None + else: + from segment_anything import sam_model_registry, SamPredictor + sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").cuda() + sam_predictor = SamPredictor(sam) def show_user_input(msg, chatbot, chatbox_display): return "", chatbot + [[msg, None]], chatbox_display + [[msg, None]] @@ -474,6 +477,11 @@ def clear(): "--bind_all", action="store_true", help="Listen to all addresses on the host." ) + parser.add_argument( + "--disable_sam", action="store_true", + help="Do not create SAM model. This saves some GPU memory but object " + "masks will no longer be predicted." 
+ ) args = parser.parse_args() # check and setup gpu_ids to use diff --git a/accessory/model/LLM/llama_ens5.py b/accessory/model/LLM/llama_ens5.py index 8491b178..3a9dde96 100644 --- a/accessory/model/LLM/llama_ens5.py +++ b/accessory/model/LLM/llama_ens5.py @@ -22,6 +22,7 @@ import accessory from accessory.configs import global_configs +from accessory.util.tensor_type import default_tensor_type if global_configs.USE_FLASH_ATTENTION: from flash_attn import flash_attn_func @@ -275,51 +276,48 @@ def __init__(self, args: ModelArgs, with_visual=False): self.image_words = 0 self.cache_image_words = 0 # for inference if with_visual: + with default_tensor_type(dtype=torch.float32, device="cpu"): + print("build llama model with qformerv2") + if self.args.load_pretrained_visual_encoder: + self.qformer = Blip2Model.from_pretrained( + "Salesforce/blip2-opt-2.7b", torch_dtype=self.norm.weight.dtype + ) + else: + self.qformer = Blip2Model(Blip2Config.from_pretrained( + str(impresources.files(accessory)/'resources/hf/Salesforce/blip2-opt-2.7b/config.json'))) + self.qformer.language_projection = None + self.qformer.language_model = None + + print("build llama model with clip") + if self.args.load_pretrained_visual_encoder: + self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') + else: + self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained=None) + self.clip.transformer = None + + print("build llama model with openclip") + if self.args.load_pretrained_visual_encoder: + self.openclip_convnext_xxl, _, _ = open_clip.create_model_and_transforms( + "convnext_xxlarge", pretrained="laion2b_s34b_b82k_augreg_soup" + ) + else: + self.openclip_convnext_xxl, _, _ = open_clip.create_model_and_transforms( + "convnext_xxlarge", pretrained=None + ) + self.openclip_convnext_xxl = self.openclip_convnext_xxl.visual.trunk + self.openclip_convnext_xxl.head.global_pool = nn.Identity() + self.openclip_convnext_xxl.head.flatten = nn.Identity() + + print("build llama model with dinov2") + if self.args.load_pretrained_visual_encoder: + self.dinov2_vitg14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14", pretrained=True) + else: + self.dinov2_vitg14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14", pretrained=False) - default_dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.float32) - - print("build llama model with qformerv2") - if self.args.load_pretrained_visual_encoder: - self.qformer = Blip2Model.from_pretrained( - "Salesforce/blip2-opt-2.7b", torch_dtype=self.norm.weight.dtype - ) - else: - self.qformer = Blip2Model(Blip2Config.from_pretrained( - str(impresources.files(accessory)/'resources/hf/Salesforce/blip2-opt-2.7b/config.json'))) - self.qformer.language_projection = None - self.qformer.language_model = None self.qformer.to(self.norm.weight) - - print("build llama model with clip") - if self.args.load_pretrained_visual_encoder: - self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') - else: - self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained=None) - self.clip.transformer = None self.clip.to(self.norm.weight) - - print("build llama model with openclip") - if self.args.load_pretrained_visual_encoder: - self.openclip_convnext_xxl, _, _ = open_clip.create_model_and_transforms( - "convnext_xxlarge", pretrained="laion2b_s34b_b82k_augreg_soup" - ) - else: - self.openclip_convnext_xxl, _, _ = open_clip.create_model_and_transforms( - "convnext_xxlarge", pretrained=None 
- ) - self.openclip_convnext_xxl = self.openclip_convnext_xxl.visual.trunk - self.openclip_convnext_xxl.head.global_pool = nn.Identity() - self.openclip_convnext_xxl.head.flatten = nn.Identity() self.openclip_convnext_xxl.to(self.norm.weight) - - print("build llama model with dinov2") - if self.args.load_pretrained_visual_encoder: - self.dinov2_vitg14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14", pretrained=True) - else: - self.dinov2_vitg14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14", pretrained=False) self.dinov2_vitg14.to(self.norm.weight) - torch.set_default_dtype(default_dtype) self.qformer_proj = nn.Sequential( nn.Linear(768, args.dim), diff --git a/accessory/util/tensor_parallel.py b/accessory/util/tensor_parallel.py index 7b247ce3..3e51b49d 100644 --- a/accessory/util/tensor_parallel.py +++ b/accessory/util/tensor_parallel.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn +import torch.distributed as dist import fairscale.nn.model_parallel.initialize as fs_init from fairscale.nn.model_parallel.layers import ( @@ -107,18 +108,102 @@ def _load_checkpoint_and_merge_ranks( return merged_ckpt -def _load_checkpoint_and_split_rank( +def _load_checkpoint_and_redistribute_general( ckpt_path: str, ckpt_mp_world_size: int, weight_parallel_dim: Dict[str, int], verbose: bool, format: str, ) -> OrderedDict[str, torch.Tensor]: - raise NotImplementedError() + mp_rank = fs_init.get_model_parallel_rank() + mp_world_size = fs_init.get_model_parallel_world_size() + mp_src = fs_init.get_model_parallel_src_rank() + mp_group = fs_init.get_model_parallel_group() + local_ckpt_shards = [] + for shard_id in range(mp_rank, ckpt_mp_world_size, mp_world_size): + local_ckpt_shards.append( + load_tensor_parallel_shard_state_dict( + ckpt_path, format, shard_id, ckpt_mp_world_size + ) + ) -def _load_checkpoint_and_redistribute_general( - ckpt_path: str, ckpt_mp_world_size: int, - weight_parallel_dim: Dict[str, int], verbose: bool, format: str, -) -> OrderedDict[str, torch.Tensor]: - raise NotImplementedError() + resharded_ckpt = OrderedDict() + + if mp_rank == 0: + tensor_keys = sorted(local_ckpt_shards[0].keys()) + tensor_sizes_and_dtypes = [] + for key in tensor_keys: + sharded_tensor = local_ckpt_shards[0][key] + if key not in weight_parallel_dim: + tensor_sizes_and_dtypes.append( + (list(sharded_tensor.shape), sharded_tensor.dtype) + ) + else: + tensor_size = list(sharded_tensor.shape) + tensor_size[weight_parallel_dim[key]] *= ckpt_mp_world_size + tensor_sizes_and_dtypes.append( + (tensor_size, sharded_tensor.dtype) + ) + dist.broadcast_object_list( + [tensor_keys, tensor_sizes_and_dtypes], mp_src, mp_group + ) + else: + recv_objs = [None, None] + dist.broadcast_object_list(recv_objs, mp_src, mp_group) + tensor_keys, tensor_sizes_and_dtypes = recv_objs + + for key, (size, dtype) in zip(tensor_keys, tensor_sizes_and_dtypes): + if key not in weight_parallel_dim: + bcast_tensor = ( + local_ckpt_shards[0][key].cuda() + if mp_rank == 0 else + torch.empty(size, dtype=dtype, device="cuda") + ) + dist.broadcast(bcast_tensor, mp_src, mp_group) + bcast_tensor = bcast_tensor.cpu() + resharded_ckpt[key] = bcast_tensor + + else: + reshard_dim = weight_parallel_dim[key] + merged_tensor = torch.empty(size, dtype=dtype, device="cuda") + assert merged_tensor.size(reshard_dim) % mp_world_size == 0 + loader = ShardedTensorLoader(target=merged_tensor, + num_shards=ckpt_mp_world_size, + shard_dim=reshard_dim) + + sharded_size = size.copy() + sharded_size[reshard_dim] //= ckpt_mp_world_size + + 
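+            # Reassemble the full tensor over
+            # ceil(ckpt_mp_world_size / mp_world_size) all_gather rounds: in
+            # round k, model parallel rank r contributes checkpoint shard
+            # k * mp_world_size + r (or a zero-size placeholder once its local
+            # shards are exhausted), every gathered shard is copied into
+            # merged_tensor through the loader, and merged_tensor is finally
+            # re-chunked along reshard_dim to keep only this rank's slice for
+            # the current model parallel world size.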
for round_idx in range( + (ckpt_mp_world_size - 1) // mp_world_size + 1 + ): + round_st = round_idx * mp_world_size + all_gather_list = [ + torch.empty( + [int(round_st + i < ckpt_mp_world_size), + *sharded_size], + dtype=dtype, device="cuda" + ) + for i in range(mp_world_size) + ] + send_tensor = ( + local_ckpt_shards[round_idx][key].cuda().unsqueeze(0) + if round_idx < len(local_ckpt_shards) else + torch.empty([0, *sharded_size], dtype=dtype, device="cuda") + ) + assert send_tensor.size() == all_gather_list[mp_rank].size() + dist.all_gather(all_gather_list, send_tensor, mp_group) + for shard_id in range( + round_idx * mp_world_size, + min((round_idx + 1) * mp_world_size, ckpt_mp_world_size), + ): + shard_value = all_gather_list[shard_id % mp_world_size][0] + loader.load_shard(shard_id, shard_value) + + assert loader.is_complete() + resharded_ckpt[key] = merged_tensor.chunk( + mp_world_size, reshard_dim + )[mp_rank].cpu() + + return resharded_ckpt def get_tensor_parallel_shards_file_name( @@ -244,10 +329,6 @@ def print_if_verbose(*args, **kwargs): local_state_dict = _load_checkpoint_and_merge_ranks( path, ckpt_mp_world_size, weight_parallel_dim, verbose, format ) - elif mp_world_size % ckpt_mp_world_size == 0: - local_state_dict = _load_checkpoint_and_split_rank( - path, ckpt_mp_world_size, weight_parallel_dim, verbose, format - ) else: local_state_dict = _load_checkpoint_and_redistribute_general( path, ckpt_mp_world_size, weight_parallel_dim, verbose, format @@ -327,12 +408,12 @@ def infer_checkpoint_format_and_mp_size(path: str) -> str: raise NotImplementedError(f"Multiple matched format detected: " f"{inferred_format} and {format}.") if inferred_format is None: - folder_contents = ", ".join( - [x if os.path.isfile(os.path.join(path, x)) else x + " (not a file)" - for x in os.listdir(path)] - ) + folder_contents = ", ".join([ + x if os.path.isfile(os.path.join(path, x)) else x + " (not a file)" + for x in os.listdir(path) + ]) raise NotImplementedError( - f"Files in the given folder do not match any format. " + f"Files in the given folder do not match any format. " f"Contents in the folder: [{folder_contents}]." ) @@ -408,19 +489,19 @@ def load_tensor_parallel_model_list( for debugging purposes). The default is ``False``. Returns: - Tuple[List[str], List[str]]: Returns two lists of strings, the first - being the missing keys and the second being the unexpected keys, - following the same convention as - ``torch.nn.Module.load_state_dict``. A key is deemed missing if it - does not occur in any of the checkpoints in the list, and is deemed - unexpected if it is unexpected to the model and has appeared in any - one of the checkpoints in the list. + Dict: A dictionary with two keys ``missing_keys`` and + ``unexpected_keys``. A key is deemed missing if it does not occur + in any of the checkpoints in the list, and is deemed unexpected if + it is unexpected to the model and has appeared in any one of the + checkpoints in the list. 
""" - existing_keys, missing_keys, unexpected_keys = set(), set(model.state_dict().keys()), set() + missing_keys = set(model.state_dict().keys()) + existing_keys, unexpected_keys = set(), set() + for i, path in enumerate(path_list): inferred_format, _ = infer_checkpoint_format_and_mp_size(path) print(f"Loading from checkpoint at: {path} ({i + 1} of " - f"{len(path_list)}, format is \"{inferred_format})\"") + f"{len(path_list)}, format is \"{inferred_format}\")") assert i != 0 or not inferred_format.endswith("_diff"), ( "The first checkpoint in the list cannot be a *_diff checkpoint." ) @@ -443,7 +524,10 @@ def load_tensor_parallel_model_list( missing_keys.intersection_update(step_missing_keys) unexpected_keys.update(step_unexpected_keys) - return {"missing_keys": list(missing_keys), "unexpected_keys": list(unexpected_keys)} + return { + "missing_keys": list(missing_keys), + "unexpected_keys": list(unexpected_keys), + } def tensor_load_shard( @@ -456,7 +540,8 @@ def tensor_load_shard( Args: target (``torch.Tensor``): The target tensor to load the values into. - parallel_dim (int): Tensor parallel dimension of the tensor. + parallel_dim (int): Tensor parallel dimension of the tensor. ``-1`` for + replicated tensors. num_shards (int): Number of tensor parallel shards of the value. shard_id (int): The shard id of the current value. value (``torch.Tensor``): The value to be loaded into the target @@ -466,14 +551,22 @@ def tensor_load_shard( If ``add``, the new value is added to the old value. """ assert parallel_dim < target.ndim or parallel_dim == -1 + if parallel_dim >= 0: + assert target.size(parallel_dim) % num_shards == 0 target_slices = [] + expected_shard_shape = [] for i in range(target.ndim): if i == parallel_dim: dim_st = target.size(i) // num_shards * shard_id dim_ed = target.size(i) // num_shards * (shard_id + 1) target_slices.append(slice(dim_st, dim_ed)) + expected_shard_shape.append(dim_ed - dim_st) else: target_slices.append(slice(None)) + expected_shard_shape.append(target.size(i)) + assert value.shape == tuple(expected_shard_shape), ( + f"{value.shape} vs. {expected_shard_shape}" + ) if parallel_dim == -1 and shard_id != 0 and mode in ["set", "add"]: return if mode == "set": @@ -505,6 +598,7 @@ def __init__( loaded into. num_shards (int): Number of expected shards. shard_dim (int): The dimension along which the tensor is sharded. + ``-1`` for replicated tensors. mode (str): Supported options are ``set`` and ``add``. If ``set``, the old value in the target tensor is overrided with the new value. If ``add``, the new value is added to the old value. From 46ed145278758ddb74bb69c4e8b156e1719be409 Mon Sep 17 00:00:00 2001 From: Ziyi Lin Date: Wed, 29 Nov 2023 19:17:56 +0800 Subject: [PATCH 2/3] tensor parallel load_state_dict add synchronization warning --- accessory/util/tensor_parallel.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/accessory/util/tensor_parallel.py b/accessory/util/tensor_parallel.py index 3e51b49d..ac78c6d7 100644 --- a/accessory/util/tensor_parallel.py +++ b/accessory/util/tensor_parallel.py @@ -285,6 +285,10 @@ def load_tensor_parallel_model_state_dict( Returns: OrderedDict[str, torch.Tensor]: The model state_dict local to the model parallel rank of the current process. + + ..note:: + This function is synchronous within the model parallel group and must + be called by all the group workers at the same time. 
""" def print_if_verbose(*args, **kwargs): if verbose: @@ -361,6 +365,10 @@ def load_tensor_parallel_model( being the missing keys and the second being the unexpected keys, following the same convention as ``torch.nn.Module.load_state_dict``. + + ..note:: + This function is synchronous within the model parallel group and must + be called by all the group workers at the same time. """ assert not format.endswith("_diff"), ( "A *_diff checkpoint must be used together with the corresponding " @@ -494,6 +502,10 @@ def load_tensor_parallel_model_list( in any of the checkpoints in the list, and is deemed unexpected if it is unexpected to the model and has appeared in any one of the checkpoints in the list. + + ..note:: + This function is synchronous within the model parallel group and must + be called by all the group workers at the same time. """ missing_keys = set(model.state_dict().keys()) existing_keys, unexpected_keys = set(), set() From c2a4ccdf5cf6a46f8e53bdf2907988edcdf99a19 Mon Sep 17 00:00:00 2001 From: Ziyi Lin Date: Wed, 29 Nov 2023 19:23:40 +0800 Subject: [PATCH 3/3] move each vision encoder to cuda instantly after creation to reduce peak cpu mem usage --- accessory/model/LLM/llama_ens5.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/accessory/model/LLM/llama_ens5.py b/accessory/model/LLM/llama_ens5.py index 3a9dde96..5077f2e0 100644 --- a/accessory/model/LLM/llama_ens5.py +++ b/accessory/model/LLM/llama_ens5.py @@ -287,14 +287,18 @@ def __init__(self, args: ModelArgs, with_visual=False): str(impresources.files(accessory)/'resources/hf/Salesforce/blip2-opt-2.7b/config.json'))) self.qformer.language_projection = None self.qformer.language_model = None + self.qformer.to(self.norm.weight) + with default_tensor_type(dtype=torch.float32, device="cpu"): print("build llama model with clip") if self.args.load_pretrained_visual_encoder: self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') else: self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained=None) self.clip.transformer = None + self.clip.to(self.norm.weight) + with default_tensor_type(dtype=torch.float32, device="cpu"): print("build llama model with openclip") if self.args.load_pretrained_visual_encoder: self.openclip_convnext_xxl, _, _ = open_clip.create_model_and_transforms( @@ -307,16 +311,14 @@ def __init__(self, args: ModelArgs, with_visual=False): self.openclip_convnext_xxl = self.openclip_convnext_xxl.visual.trunk self.openclip_convnext_xxl.head.global_pool = nn.Identity() self.openclip_convnext_xxl.head.flatten = nn.Identity() + self.openclip_convnext_xxl.to(self.norm.weight) + with default_tensor_type(dtype=torch.float32, device="cpu"): print("build llama model with dinov2") if self.args.load_pretrained_visual_encoder: self.dinov2_vitg14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14", pretrained=True) else: self.dinov2_vitg14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14", pretrained=False) - - self.qformer.to(self.norm.weight) - self.clip.to(self.norm.weight) - self.openclip_convnext_xxl.to(self.norm.weight) self.dinov2_vitg14.to(self.norm.weight) self.qformer_proj = nn.Sequential(