Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removing batch dimension from default layout maps for Gemma and Llama #2035

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 7 additions & 19 deletions keras_hub/src/models/gemma/gemma_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,7 @@ def get_config(self):
return config

@staticmethod
def get_layout_map(
device_mesh,
model_parallel_dim_name="model",
data_parallel_dim_name="batch",
):
def get_layout_map(device_mesh, model_parallel_dim_name="model"):
"""Get a `keras.distribution.LayoutMap` for model parallel distribution.

The returned `LayoutMap` contains the sharding spec for the gemma
Expand Down Expand Up @@ -257,8 +253,6 @@ def get_layout_map(
distribution.
model_parallel_dim_name: The axis name of the device mesh, where
the weights should be partition on.
data_parallel_dim_name: The axis name of the device mesh, where
the data should be partition on.
Return:
`keras.distribution.LayoutMap` that contains the sharding spec
for all the model weights.
Expand Down Expand Up @@ -286,31 +280,25 @@ def get_layout_map(
f"{model_parallel_dim_name} is not found in the "
f"device_mesh.axis_names. {device_mesh.axis_name=}"
)
if data_parallel_dim_name not in device_mesh.axis_names:
raise ValueError(
f"{data_parallel_dim_name} is not found in the "
f"device_mesh.axis_names. {device_mesh.axis_name=}"
)
# Note that it is possible to further config the mesh to be 3D, eg
# (data, seq, model). We leave it as 2D for now for simplicity.
data_dim = data_parallel_dim_name
model_dim = model_parallel_dim_name
# The sharding config is based on the Gemma team training config.
# See https://arxiv.org/abs/2403.08295
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
layout_map["token_embedding/embeddings"] = (model_dim, None)
layout_map["decoder_block.*attention.*(query|key|value).kernel"] = (
model_dim,
data_dim,
None,
None,
)
layout_map["decoder_block.*attention_output.kernel"] = (
model_dim,
None,
data_dim,
None,
)
layout_map["decoder_block.*ffw_gating.kernel"] = (data_dim, model_dim)
layout_map["decoder_block.*ffw_gating_2.kernel"] = (data_dim, model_dim)
layout_map["decoder_block.*ffw_linear.kernel"] = (model_dim, data_dim)
layout_map["decoder_block.*ffw_gating.kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_gating_2.kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.kernel"] = (model_dim, None)

return layout_map
24 changes: 7 additions & 17 deletions keras_hub/src/models/llama/llama_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,7 @@ def get_config(self):
return config

@staticmethod
def get_layout_map(
device_mesh,
model_parallel_dim_name="model",
data_parallel_dim_name="batch",
):
def get_layout_map(device_mesh, model_parallel_dim_name="model"):
"""Get a `keras.distribution.LayoutMap` for model parallel distribution.

The returned `LayoutMap` contains the sharding spec for the Llama
Expand Down Expand Up @@ -260,44 +256,38 @@ def get_layout_map(
f"{model_parallel_dim_name} is not found in the "
f"device_mesh.axis_names. {device_mesh.axis_name=}"
)
if data_parallel_dim_name not in device_mesh.axis_names:
raise ValueError(
f"{data_parallel_dim_name} is not found in the "
f"device_mesh.axis_names. {device_mesh.axis_name=}"
)
# Note that it is possible to further config the mesh to be 3D, eg
# (data, seq, model). We leave it as 2D for now for simplicity.
data_dim = data_parallel_dim_name
model_dim = model_parallel_dim_name
# The sharding config is based on the Gemma team training config.
# See https://arxiv.org/abs/2403.08295
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
layout_map["token_embedding/embeddings"] = (model_dim, None)
layout_map[
"transformer_layer.*self_attention.*(query|key|value).kernel"
] = (
model_dim,
data_dim,
None,
None,
)
layout_map["transformer_layer.*attention_output.kernel"] = (
model_dim,
None,
data_dim,
None,
)
layout_map[
"transformer_layer.*feedforward_intermediate_dense.kernel"
] = (
data_dim,
None,
model_dim,
)
layout_map["transformer_layer.*feedforward_gate_dense.kernel"] = (
data_dim,
None,
model_dim,
)
layout_map["transformer_layer.*feedforward_output_dense.kernel"] = (
model_dim,
data_dim,
None,
)

return layout_map
Loading