Add megablocks support for MLP MoE #2

Open · wants to merge 2 commits into base: main
114 changes: 77 additions & 37 deletions smoe/models/mixtral/modeling_mixtral.py
@@ -3,15 +3,16 @@
import inspect
import math
import warnings
from packaging import version
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import stk
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.checkpoint
from packaging import version
from megablocks.layers.arguments import Arguments as MBArgs
from megablocks.layers.dmoe import ParallelDroplessMLP
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import (
@@ -1725,6 +1726,19 @@ def forward(self, hidden_states):
return current_hidden_states


class ParallelDroplessMLPWithoutLBLSaving(ParallelDroplessMLP):
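"""Same as ParallelDroplessMLP, but forward() skips saving the load-balancing loss (LBL)."""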
def forward(self, x, expert_weights, top_experts):
in_shape = x.size()
# Compute the experts.
x, _ = self.forward_fn(x, expert_weights, top_experts)
x = x.view(in_shape)
if self.bias is not None:
if self.args.return_bias:
return x, self.bias
return x + self.bias
return x


MISTRAL_ATTENTION_CLASSES = {
"eager": MixtralAttention,
"flash_attention_2": MixtralFlashAttention2,
@@ -1774,6 +1788,27 @@ def __init__(self, config):
for _ in range(self.num_experts)
] # 🔍
)
elif self.moe_type == "megablocks":
config: MixtralConfig
is_fp16 = self.gate.weight.data.dtype == torch.float16
is_bf16 = self.gate.weight.data.dtype == torch.bfloat16
mb_args = MBArgs(
hidden_size=config.hidden_size,
ffn_hidden_size=config.intermediate_size,
num_layers=config.num_hidden_layers,
bias=False,
return_bias=False,
activation_fn=nn.SiLU(),
moe_num_experts=config.num_local_experts,
moe_top_k=config.num_experts_per_tok,
memory_optimized_mlp=False,
mlp_type='glu',
mlp_impl='sparse',
fp16=is_fp16,
bf16=is_bf16,
device=torch.cuda.current_device(),
)
self.experts = ParallelDroplessMLPWithoutLBLSaving(mb_args)
else:
raise NotImplementedError(f"Unsupported moe_type: {self.moe_type}")

@@ -1790,45 +1825,50 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)

# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=self.num_experts
).permute(2, 1, 0)

# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
if self.moe_type == "modulelist":
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)

if (
top_x.shape[0] == 0 and not self.training
): # skipping during training would lead to asynchrony among different GPUs and block the training!
continue
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=self.num_experts
).permute(2, 1, 0)

# in torch it is faster to index using lists than torch tensors
top_x_list = top_x.tolist()
idx_list = idx.tolist()
# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * (
routing_weights[top_x_list, idx_list, None] * self.scale_factor
)
if (
top_x.shape[0] == 0 and not self.training
): # skipping during training would lead to asynchrony among different GPUs and block the training!
continue

# in torch it is faster to index using lists than torch tensors
top_x_list = top_x.tolist()
idx_list = idx.tolist()

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * (
routing_weights[top_x_list, idx_list, None] * self.scale_factor
)

# However, `index_add_` only supports torch tensors for indexing, so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(
0, top_x, current_hidden_states.to(hidden_states.dtype)
)
# However, `index_add_` only supports torch tensors for indexing, so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(
0, top_x, current_hidden_states.to(hidden_states.dtype)
)
elif self.moe_type == "megablocks":
final_hidden_states = self.experts(hidden_states, routing_weights, selected_experts)
else:
raise NotImplementedError(f"Unsupported moe_type: {self.moe_type}")

final_hidden_states = final_hidden_states.reshape(
batch_size, sequence_length, hidden_dim
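
For orientation, here is a minimal sketch of the gating contract the megablocks branch relies on. The sizes are toy values, the softmax/top-k gating is assumed to follow the upstream Mixtral block (consistent with the dtype cast-back visible in the hunk above), and the final fused call is shown only as a comment because it needs a constructed ParallelDroplessMLPWithoutLBLSaving on a CUDA device.

# Minimal sketch (assumed toy shapes): how routing_weights / selected_experts feed the fused expert call.
import torch

tokens, hidden_dim, num_experts, top_k = 16, 128, 8, 2
hidden_states = torch.randn(tokens, hidden_dim)    # already flattened to (batch * seq_len, hidden_dim)
router_logits = torch.randn(tokens, num_experts)   # output of the linear gate

routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights = routing_weights.to(hidden_states.dtype)  # cast back to the input dtype, as in the diff

# "modulelist" path: loop over experts, mask their tokens, and scatter results back with index_add_ (see diff).
# "megablocks" path: one fused call does the token permutation, expert MLPs, and weighted gather:
#     final_hidden_states = self.experts(hidden_states, routing_weights, selected_experts)
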
240 changes: 240 additions & 0 deletions smoe/utils/expert_construction/convert_llama_to_mixtral_mb.py
@@ -0,0 +1,240 @@
"""
Convert the original LLaMA weights into Mixtral weights with megablocks support.
"""

import math
import os.path
import re
import shutil
from collections import defaultdict
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import save_file
from torch.nn import init
from transformers.modeling_utils import dtype_byte_size

from smoe.models.mixtral.configuration_mixtral import MixtralConfig
from smoe.models.mixtral.modeling_mixtral import MixtralForCausalLM
from smoe.utils.io import dump_json, load_json


def is_safetensors_file(filepath):
if isinstance(filepath, str):
filepath = Path(filepath)
string = filepath.name
return re.match(r"model-\d{5}-of-\d{5}.safetensors", string) is not None


FFN_TYPE_MAP = {
"modulelist": {
"gate": "w1",
"down": "w2",
"up": "w3",
},
"megablocks": {
# mlp.(w1|v1|w2)
"gate": "w1", # (ffn_hidden_size x num_experts) x hsz
"up": "v1", # (ffn_hidden_size x num_experts) x hsz
"down": "w2", # (ffn_hidden_size x num_experts) x hsz
},
}


def convert_safetensors(
model_dir,
dump_dir,
num_experts: int,
top_k: int,
scale_factor: float = 1.0,
num_moe_contract_layers: int = 0,
moe_type: str = "modulelist",
neuron_indices: dict = None,
gate_weights: dict = None,
):
# fmt: off
model_folder = Path(model_dir)
dump_folder = Path(dump_dir)
dump_folder.mkdir(parents=True, exist_ok=True)
ffn_type_map = FFN_TYPE_MAP[moe_type]

raw_total_size = -1
tensor_filepaths = []
for filepath in model_folder.glob("*"):
if not os.path.isdir(filepath):
if is_safetensors_file(filepath):
tensor_filepaths.append(filepath)
if filepath.name == "config.json":
config = MixtralConfig.from_pretrained(filepath)
config.architectures = ["MixtralForCausalLM"]
config.num_experts_per_tok = top_k
config.num_local_experts = num_experts
config.router_aux_loss_coef = 1e-2
config.scale_factor = scale_factor
config.moe_type = moe_type
config.num_moe_contract_layers = num_moe_contract_layers
config.intermediate_size = config.intermediate_size // num_experts
config.auto_map = {
"AutoConfig": "configuration_mixtral.MixtralConfig",
"AutoModel": "modeling_mixtral.MixtralModel",
"AutoModelForCausalLM": "modeling_mixtral.MixtralForCausalLM",
}
config.save_pretrained(dump_folder)
for filename in [
"configuration_mixtral.py",
"modeling_mixtral.py",
]:
shutil.copy2(f"smoe/models/mixtral/{filename}", dump_folder / filename)
(dump_folder / "__init__.py").touch()
elif filepath.name == "model.safetensors.index.json":
raw_total_size = load_json(filepath)["metadata"]["total_size"]
else:
# cp to dump_dir
shutil.copy2(filepath, dump_folder / filepath.name)

router_records = set()
weight_map = {}
total_size = 0
total_gate_size = 0
visited_layers = set()
for fi, filepath in enumerate(tensor_filepaths):
with safe_open(filepath, framework="pt", device="cpu") as f:
tensors = {}
contained_layers = set()
for key in f.keys():
tensor = f.get_tensor(key)
if ".mlp." in key:
# preparation
layer_idx, ffn_type = re.search(
r"model.layers.(\d+).mlp.(gate|up|down)_proj.weight", key
).groups()
layer_idx = int(layer_idx)

is_moe = (layer_idx >= num_moe_contract_layers) and (layer_idx < config.num_hidden_layers - num_moe_contract_layers)

if is_moe:
contained_layers.add(layer_idx)

if ffn_type == "down":
hsz, mid = tensor.shape
mid_idx = 1
else:
mid, hsz = tensor.shape
mid_idx = 0

# initialize gate weights
if layer_idx not in router_records:
if gate_weights is None: # use newly initialized gate weights
gate_weight = torch.zeros(num_experts, hsz)
init.kaiming_uniform_(gate_weight, a=math.sqrt(5))
tensors[
f"model.layers.{layer_idx}.block_sparse_moe.gate.weight"
] = gate_weight
else: # use provided gate weights
print(f"Initializing layer {layer_idx} gate weights using {gate_weights[layer_idx]}...")
tensors[
f"model.layers.{layer_idx}.block_sparse_moe.gate.weight"
] = gate_weights[layer_idx].clone()
router_records.add(layer_idx)
new_ffn_type = ffn_type_map[ffn_type]

# initialize expert weights
if moe_type == "modulelist":
expert_size = mid // num_experts
for expert_idx in range(num_experts):
if mid_idx == 0:
if neuron_indices is None: # sequential split
expert_tensor = tensor[expert_idx * expert_size: (expert_idx + 1) * expert_size].clone()
else: # split according to the given indices
this_layer_indices: list = neuron_indices[layer_idx]
print(f"Initializing layer {layer_idx} expert {expert_idx} {ffn_type} using neurons with indices {this_layer_indices[expert_idx]}...")
expert_tensor = tensor[this_layer_indices[expert_idx]].clone()
else:
if neuron_indices is None: # sequential split
expert_tensor = tensor[:, expert_idx * expert_size: (expert_idx + 1) * expert_size].clone()
else: # split according to the given indices
this_layer_indices: list = neuron_indices[layer_idx]
print(f"Initializing layer {layer_idx} expert {expert_idx} {ffn_type} using neurons with indices {this_layer_indices[expert_idx]}...")
expert_tensor = tensor[:, this_layer_indices[expert_idx]].clone()
tensors[
f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.{new_ffn_type}.weight"
] = expert_tensor

elif moe_type == "megablocks":
expert_size = mid // num_experts
tname = f"model.layers.{layer_idx}.block_sparse_moe.experts.mlp.{new_ffn_type}"
if mid_idx == 0:
# up & gate
tensors[tname] = tensor
else:
# down
tensors[tname] = tensor.t()

else:
raise NotImplementedError

else:
tensors[key] = tensor

else:
tensors[key] = tensor

for key in tensors:
tensors[key] = tensors[key].contiguous()
save_file(tensors, dump_folder / filepath.name, metadata={"format": "pt"})
for key, tensor in tensors.items():
weight_size = tensor.numel() * dtype_byte_size(tensor.dtype)
total_size += weight_size
weight_map[key] = filepath.name
if ".block_sparse_moe.gate." in key:
total_gate_size += weight_size
print(key, tensor.shape)

metadata = {"total_size": total_size}
index = {"metadata": metadata, "weight_map": weight_map}
dump_json(index, dump_folder / "model.safetensors.index.json", indent=2)
assert total_size - total_gate_size == raw_total_size


if __name__ == "__main__":
num_experts = 8
top_k = 2

# src_model_dir = "/mnt/petrelfs/share_data/quxiaoye/models/Meta-Llama-3-8B"
src_model_dir = "/mnt/petrelfs/share_data/quxiaoye/models/Meta-Llama-3-8B-Instruct"
# tgt_model_dir_prefix = f"/mnt/petrelfs/share_data/quxiaoye/llama_moe_v2/converted_models/split-sequential-Top{top_k}"


# moe_type = "modulelist"
# tgt_model_dir_prefix = "/mnt/petrelfs/share_data/quxiaoye/llama_moe_v2/converted_models/tzhu_mixtral_mb/ml_8top2"

moe_type = "megablocks"
tgt_model_dir_prefix = "/mnt/petrelfs/share_data/quxiaoye/llama_moe_v2/converted_models/tzhu_mixtral_mb/mb_8top2"

neuron_indices_file = ""
gate_weights_file = ""

print(f"converting {moe_type}")
convert_safetensors(
src_model_dir,
f"{tgt_model_dir_prefix}",
num_experts=num_experts,
top_k=top_k,
moe_type=moe_type,
neuron_indices=None
if neuron_indices_file == ""
else torch.load(neuron_indices_file),
gate_weights=None
if gate_weights_file == ""
else torch.load(gate_weights_file),
)

print(f"testing {moe_type}")
m = MixtralForCausalLM.from_pretrained(f"{tgt_model_dir_prefix}", torch_dtype=torch.bfloat16)

print(f"Re-saving {moe_type}")
m.save_pretrained(f"{tgt_model_dir_prefix}")

print("Done")
# fmt: on
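
For reference, a toy shape sketch of the weight layout the "megablocks" branch of convert_safetensors produces. The sizes are made up; the (num_experts * ffn_hidden_size) x hidden_size layout follows the FFN_TYPE_MAP comments and the tensor.t() applied to the down projection in the script above.

# Toy shape check for the megablocks layout written above (assumed sizes).
import torch

hidden_size = 64
num_experts = 8
ffn_hidden_size = 32                    # per-expert size, i.e. original intermediate_size // num_experts
mid = ffn_hidden_size * num_experts     # original dense LLaMA intermediate size

gate_proj = torch.randn(mid, hidden_size)   # model.layers.N.mlp.gate_proj.weight
up_proj = torch.randn(mid, hidden_size)     # model.layers.N.mlp.up_proj.weight
down_proj = torch.randn(hidden_size, mid)   # model.layers.N.mlp.down_proj.weight

# megablocks stores all three projections as (num_experts * ffn_hidden_size) x hidden_size:
w1 = gate_proj                       # ...block_sparse_moe.experts.mlp.w1
v1 = up_proj                         # ...block_sparse_moe.experts.mlp.v1
w2 = down_proj.t().contiguous()      # ...block_sparse_moe.experts.mlp.w2 (transposed, per the script)

assert w1.shape == v1.shape == w2.shape == (mid, hidden_size)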