[Bugfix] Update config classes to match old configurations #26

Merged 3 commits on Aug 1, 2024
Changes from all commits
12 changes: 6 additions & 6 deletions vidur/config/device_sku_config.py
@@ -14,19 +14,19 @@ class BaseDeviceSKUConfig(BaseFixedConfig):


 @dataclass
-class A100DeviceSKUConfig(BaseDeviceSKUConfig):
-    fp16_tflops: int = 312
-    total_memory_gb: int = 80
+class A40DeviceSKUConfig(BaseDeviceSKUConfig):
+    fp16_tflops: int = 150
+    total_memory_gb: int = 45

     @staticmethod
     def get_type():
         return DeviceSKUType.A40


 @dataclass
-class A40DeviceSKUConfig(BaseDeviceSKUConfig):
-    fp16_tflops: int = 150
-    total_memory_gb: int = 45
+class A100DeviceSKUConfig(BaseDeviceSKUConfig):
+    fp16_tflops: int = 312
+    total_memory_gb: int = 80

     @staticmethod
     def get_type():
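For context, the bug here was that the two SKU classes carried each other's specs: the class whose get_type() returns DeviceSKUType.A40 advertised A100 numbers (312 TFLOPS, 80 GB) and vice versa. A minimal sketch of the type-to-config lookup this fix restores; DeviceSKUType, the base class, and config_for are stand-ins for illustration, not vidur's actual registry mechanism:

from dataclasses import dataclass
from enum import Enum


# Stand-ins for vidur's DeviceSKUType and BaseFixedConfig.
class DeviceSKUType(Enum):
    A40 = "a40"
    A100 = "a100"


@dataclass
class BaseDeviceSKUConfig:
    fp16_tflops: int = 0
    total_memory_gb: int = 0


@dataclass
class A40DeviceSKUConfig(BaseDeviceSKUConfig):
    fp16_tflops: int = 150
    total_memory_gb: int = 45

    @staticmethod
    def get_type():
        return DeviceSKUType.A40


@dataclass
class A100DeviceSKUConfig(BaseDeviceSKUConfig):
    fp16_tflops: int = 312
    total_memory_gb: int = 80

    @staticmethod
    def get_type():
        return DeviceSKUType.A100


def config_for(sku: DeviceSKUType) -> BaseDeviceSKUConfig:
    # Before the fix, the class whose get_type() returned A40 carried
    # the A100 numbers, so a lookup like this handed back wrong specs.
    for cls in BaseDeviceSKUConfig.__subclasses__():
        if cls.get_type() == sku:
            return cls()
    raise ValueError(f"no config registered for {sku}")


assert config_for(DeviceSKUType.A40).fp16_tflops == 150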
35 changes: 29 additions & 6 deletions vidur/config/model_config.py
@@ -24,7 +24,7 @@ class BaseModelConfig(BaseFixedConfig):
     post_attn_norm: bool
     vocab_size: int
     is_neox_style: Optional[bool] = True
-    rope_theta: Optional[int] = None
+    rope_theta: Optional[float] = None
     rope_scaling: Optional[Dict[str, Any]] = None
     partial_rotary_factor: float = 1.0
     no_tensor_parallel: bool = False
@@ -41,7 +41,7 @@ class Llama2ModelConfig(BaseModelConfig):
     post_attn_norm: bool = True
     vocab_size: int = 32768
     is_neox_style: Optional[bool] = True
-    rope_theta: Optional[int] = 10000.0
+    rope_theta: Optional[float] = 10000
     rope_scaling: Optional[Dict[str, Any]] = None
     partial_rotary_factor: float = 1.0
     no_tensor_parallel: bool = False
@@ -58,6 +58,7 @@ class CodeLlama34BModelConfig(Llama2ModelConfig):
     num_kv_heads: int = 8
     embedding_dim: int = 8192
     mlp_hidden_dim: int = 22016
+    rope_theta: Optional[float] = 1000000

     @staticmethod
     def get_name():
@@ -71,6 +72,7 @@ class Llama2_7BModelConfig(Llama2ModelConfig):
     num_kv_heads: int = 32
     embedding_dim: int = 4096
     mlp_hidden_dim: int = 11008
+    max_position_embeddings: int = 4096

     @staticmethod
     def get_name():
@@ -84,6 +86,7 @@ class Llama2_70BModelConfig(Llama2ModelConfig):
     num_kv_heads: int = 8
     embedding_dim: int = 8192
     mlp_hidden_dim: int = 28672
+    max_position_embeddings: int = 4096

     @staticmethod
     def get_name():
@@ -98,7 +101,7 @@ class Llama3_8BModelConfig(Llama2ModelConfig):
     embedding_dim: int = 4096
     mlp_hidden_dim: int = 14336
     max_position_embeddings: int = 4096
-    rope_theta: Optional[int] = 500000.0
+    rope_theta: Optional[float] = 500000
     vocab_size: int = 128256

     @staticmethod
@@ -114,14 +117,33 @@ class Llama3_70BModelConfig(Llama2ModelConfig):
     embedding_dim: int = 8192
     mlp_hidden_dim: int = 28672
     max_position_embeddings: int = 8192
-    rope_theta: Optional[int] = 500000.0
+    rope_theta: Optional[float] = 500000
     vocab_size: int = 128256

     @staticmethod
     def get_name():
         return "meta-llama/Meta-Llama-3-70B"


+@dataclass
+class InternLMModelConfig(Llama2ModelConfig):
+    max_position_embeddings: int = 4096
+    vocab_size: int = 103168
+
+
+@dataclass
+class InternLM_20BModelConfig(InternLMModelConfig):
+    num_layers: int = 60
+    num_q_heads: int = 40
+    num_kv_heads: int = 40
+    embedding_dim: int = 5120
+    mlp_hidden_dim: int = 13824
+
+    @staticmethod
+    def get_name():
+        return "internlm/internlm-20b"
+
+
 @dataclass
 class InternLM2ModelConfig(Llama2ModelConfig):
     max_position_embeddings: int = 32768
@@ -135,6 +157,7 @@ class InternLM2_20BModelConfig(InternLM2ModelConfig):
     num_kv_heads: int = 8
     embedding_dim: int = 6144
     mlp_hidden_dim: int = 16384
+    rope_theta: Optional[float] = 1000000

     @staticmethod
     def get_name():
@@ -157,10 +180,9 @@ class Phi2ModelConfig(Llama2ModelConfig):
     post_attn_norm: bool = False
     vocab_size: int = 51200
     rope_scaling: Optional[Dict[str, Any]] = None
-    rope_theta: Optional[int] = 10000.0
+    rope_theta: Optional[float] = 10000
     partial_rotary_factor: float = 0.4
     no_tensor_parallel: bool = True
-    is_neox_style: bool = True

     @staticmethod
     def get_name():
@@ -185,6 +207,7 @@ class Qwen72BModelConfig(QwenModelConfig):
     num_kv_heads: int = 64
     embedding_dim: int = 8192
     mlp_hidden_dim: int = 24576
+    rope_theta: Optional[float] = 1000000

     @staticmethod
     def get_name():
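Two patterns recur across these hunks: the rope_theta annotation changes from Optional[int] to Optional[float] (the old defaults like 10000.0 were floats, so the int annotation was simply wrong), and each per-model class overrides only the fields where it differs from Llama2ModelConfig, inheriting the rest. A minimal sketch of that override behavior, using a stand-in base class with just two of the fields from the diff rather than vidur's full BaseModelConfig:

from dataclasses import dataclass
from typing import Optional


# Stand-in for the relevant slice of Llama2ModelConfig; values are
# taken from the diff above, all other fields omitted.
@dataclass
class Llama2ModelConfig:
    vocab_size: int = 32768
    rope_theta: Optional[float] = 10000


@dataclass
class CodeLlama34BModelConfig(Llama2ModelConfig):
    # Subclasses redeclare only what differs; the rest is inherited.
    mlp_hidden_dim: int = 22016
    rope_theta: Optional[float] = 1000000


cfg = CodeLlama34BModelConfig()
assert cfg.rope_theta == 1000000  # overridden default
assert cfg.vocab_size == 32768    # inherited from Llama2ModelConfig

Dataclasses do not enforce annotations at runtime, so the Optional[float] fix matters mainly for readers and any tooling that inspects the annotations; the inheritance behavior is what keeps these per-model configs short.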
4 changes: 2 additions & 2 deletions vidur/config/node_sku_config.py
@@ -25,7 +25,7 @@ def get_type():
 @dataclass
 class A100PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig):
     device_sku_type: DeviceSKUType = DeviceSKUType.A100
-    num_devices_per_node: int = 8
+    num_devices_per_node: int = 4

     @staticmethod
     def get_type():
@@ -35,7 +35,7 @@ def get_type():
 @dataclass
 class H100PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig):
     device_sku_type: DeviceSKUType = DeviceSKUType.H100
-    num_devices_per_node: int = 8
+    num_devices_per_node: int = 4

     @staticmethod
     def get_type():
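This last change halves the default device count for the pairwise-NVLink node SKUs from 8 to 4, in line with the old configurations the PR title refers to. Assuming "pairwise" means the node's GPUs are NVLink-connected in adjacent pairs, the device count fixes how many NVLink pairs each node contributes; a toy illustration, where nvlink_pairs is a hypothetical helper and not vidur's API:

# Toy illustration: under a pairwise NVLink topology, devices split
# into adjacent pairs, so num_devices_per_node = 4 yields two pairs.
def nvlink_pairs(num_devices_per_node: int) -> list[tuple[int, int]]:
    assert num_devices_per_node % 2 == 0, "pairwise NVLink needs an even device count"
    return [(d, d + 1) for d in range(0, num_devices_per_node, 2)]


print(nvlink_pairs(4))  # [(0, 1), (2, 3)]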