Enable mlperf to use the mixtral-v1 tokenized dataset to avoid a drop in total_weights when splitting data across multiple hosts #1246

Open · wants to merge 2 commits into base: new_mlperf_pipeline
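For context on the "drop in total_weights" the title refers to, here is a rough, self-contained sketch (not part of this PR): when each host tokenizes, packs, and batches its own shard with drop_remainder=True, the number of evaluated sequences, and with it the accumulated total_weights, shrinks as the host count grows; a pre-tokenized split sized and pre-split to divide evenly avoids this. The 5662 figure below comes from the old validation_tokenized_5662seqs split name; the per-host batch size is invented for illustration.

```python
# Back-of-the-envelope illustration only; not part of the PR's code.

def sequences_kept(num_eval_sequences: int, num_hosts: int, per_host_batch: int) -> int:
    """Sequences surviving an even per-host split followed by drop_remainder batching."""
    per_host = num_eval_sequences // num_hosts      # remainder dropped at the host split
    full_batches = per_host // per_host_batch       # remainder dropped by drop_remainder=True
    return full_batches * per_host_batch * num_hosts

for hosts in (1, 2, 4, 8, 16):
    kept = sequences_kept(num_eval_sequences=5662, num_hosts=hosts, per_host_batch=8)
    print(f"{hosts:>2} hosts -> {kept} sequences evaluated")
```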
MaxText/configs/models/mixtral-8x22b-mlperf.yml (3 additions, 3 deletions)
@@ -31,6 +31,6 @@ num_experts: 8
 num_experts_per_tok: 2
 rope_max_timescale: 1_000_000
 decoder_block: "mistral"
-dataset_path: "gs://mlperf-llm-public2"
-dataset_name: "c4/en:3.0.4"
-eval_dataset_name: "c4/en:3.0.4"
+dataset_path: "gs://maxtext-dataset"
+dataset_name: "c4/en:3.0.8"
+eval_dataset_name: "c4/en:3.0.9"
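The new values are ordinary TFDS dataset references, so they can be sanity-checked outside the training pipeline. A minimal sketch, assuming the gs://maxtext-dataset bucket is readable from your environment and both versions are built there:

```python
import tensorflow_datasets as tfds

# Hypothetical sanity check: list the splits of the pre-tokenized C4 builds
# referenced by the updated config before wiring them into the input pipeline.
for name in ("c4/en:3.0.8", "c4/en:3.0.9"):
    builder = tfds.builder(name, data_dir="gs://maxtext-dataset")
    print(name, list(builder.info.splits))
```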
MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py (49 additions, 21 deletions)
@@ -231,14 +231,25 @@ def get_dataset(
   return ds
 
 
+def convert_dtype(x, new_type, columns):
+  for col in columns:
+    x[col] = tf.cast(x[col], new_type)
+  return x
+
+
 def format_fn(x, eos_id: int = 1, pad_id: int = 0):
   """Format function for c4_mlperf."""
   x["inputs"] = x["targets"]
+  if "targets_position" not in x:
+    x['targets_position'] = tf.convert_to_tensor(np.arange(len(x['targets']), dtype=np.int32))
   x["inputs_position"] = x["targets_position"]
   x["targets"] = _shift_left_and_pad(x["targets"], eos_id)
-  x["inputs_segmentation"] = tf.where(
-      tf.logical_and(x["targets"] != eos_id, x["targets"] != pad_id), x["targets_segmentation"], 0
-  )
+  if "targets_segmentation" not in x:
+    x["inputs_segmentation"] = tf.where(tf.logical_and(x["targets"] != eos_id, x["targets"] != pad_id), 1, 0)
+  else:
+    x["inputs_segmentation"] = tf.where(
+        tf.logical_and(x["targets"] != eos_id, x["targets"] != pad_id), x["targets_segmentation"], 0
+    )
   x["targets_segmentation"] = x["inputs_segmentation"]
   return x
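To see what the new branches do for a pre-tokenized, not-yet-packed example (no targets_position or targets_segmentation present), here is a standalone toy walkthrough of the same logic. The _toy_shift_left_and_pad helper is a stand-in for the file's _shift_left_and_pad, assumed to shift the sequence left by one and fill the last slot; it is an assumption for illustration, not a copy of the real helper.

```python
import numpy as np
import tensorflow as tf

def _toy_shift_left_and_pad(ids, pad_val):
    # Stand-in (assumption) for the file's _shift_left_and_pad helper: shift the
    # sequence left by one position and fill the freed last slot with pad_val.
    return tf.concat([ids[1:], tf.constant([pad_val], dtype=ids.dtype)], axis=0)

# A toy pre-tokenized example with neither targets_position nor targets_segmentation,
# i.e. the case the new branches in format_fn now handle.
eos_id, pad_id = 1, 0
x = {"targets": tf.constant([5, 6, 7, eos_id, pad_id, pad_id], dtype=tf.int32)}

x["inputs"] = x["targets"]
if "targets_position" not in x:
    x["targets_position"] = tf.convert_to_tensor(np.arange(len(x["targets"]), dtype=np.int32))
x["inputs_position"] = x["targets_position"]
x["targets"] = _toy_shift_left_and_pad(x["targets"], eos_id)
if "targets_segmentation" not in x:
    # Real (non-EOS, non-pad) target positions get segment id 1; the rest are masked out.
    x["inputs_segmentation"] = tf.where(
        tf.logical_and(x["targets"] != eos_id, x["targets"] != pad_id), 1, 0
    )
x["targets_segmentation"] = x["inputs_segmentation"]

print({k: v.numpy() for k, v in x.items()})
```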

@@ -250,16 +261,21 @@ def preprocess_train_dataset(
     max_target_length: int,
     shuffle_buffer_size: int,
     data_shuffle_seed: int,
+    is_tokenized_dataset: bool = False,
 ) -> tf.data.Dataset:
   """Preprocess the training dataset."""
-  train_ds = train_ds.map(
-      lambda x: tokenizer.TokenizeOp(tokenizer=sp_tokenizer, features=x, data_keys=("targets",)), num_parallel_calls=AUTOTUNE
-  )
+  if not is_tokenized_dataset:
+    train_ds = train_ds.map(
+        lambda x: tokenizer.TokenizeOp(tokenizer=sp_tokenizer, features=x, data_keys=("targets",)), num_parallel_calls=AUTOTUNE
+    )
 
-  train_ds = reduce_concat_tokens(train_ds, feature_key="targets", batch_size=4096)
-  train_ds = split_tokens_to_targets_length(train_ds, max_target_length)
-  train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)
-  train_ds = sequence_packing.pack_dataset(train_ds, max_target_length)
+    train_ds = reduce_concat_tokens(train_ds, feature_key="targets", batch_size=4096)
+    train_ds = split_tokens_to_targets_length(train_ds, max_target_length)
+    train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)
+    train_ds = sequence_packing.pack_dataset(train_ds, max_target_length)
+  else:
+    train_ds = train_ds.map(lambda x: convert_dtype(x, tf.int32, ['targets']))
+    train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)
 
   train_ds = train_ds.map(format_fn, num_parallel_calls=AUTOTUNE)
   train_ds = train_ds.batch(train_global_batch_size_to_load // jax.process_count(), drop_remainder=True)
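A guess at why the tokenized branch needs the convert_dtype pass (the diff itself does not say): pre-tokenized TFDS splits often store token ids as int64, while the downstream formatting and packing code works in int32. A minimal sketch of the cast on a toy dataset, reusing the convert_dtype helper added above:

```python
import tensorflow as tf

def convert_dtype(x, new_type, columns):
    for col in columns:
        x[col] = tf.cast(x[col], new_type)
    return x

# Toy stand-in for a pre-tokenized split whose "targets" arrive as int64.
ds = tf.data.Dataset.from_tensor_slices(
    {"targets": tf.constant([[5, 6, 7, 1], [8, 9, 1, 0]], dtype=tf.int64)}
)
ds = ds.map(lambda x: convert_dtype(x, tf.int32, ["targets"]))
print(ds.element_spec)  # {'targets': TensorSpec(shape=(4,), dtype=tf.int32, name=None)}
```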
@@ -285,8 +301,10 @@ def preprocess_eval_dataset(
     # to avoid padding tokens inserted in group text
     eval_ds = reduce_concat_tokens(eval_ds, feature_key="targets", batch_size=24567)
     eval_ds = split_tokens_to_targets_length(eval_ds, max_target_length)
-  eval_ds = sequence_packing.pack_dataset(eval_ds, max_target_length)
 
+    eval_ds = sequence_packing.pack_dataset(eval_ds, max_target_length)
+  else:
+    eval_ds = eval_ds.map(lambda x: convert_dtype(x, tf.int32, ['targets']))
 
   eval_ds = eval_ds.map(format_fn, num_parallel_calls=AUTOTUNE)

@@ -310,6 +328,11 @@ def make_c4_mlperf_train_iterator(
     process_indices,
 ):
   """Make train iterator of customized C4 dataset for mlperf gpt3 training."""
+  if config.dataset_name == "c4/en:3.0.8":
+    is_tokenized_dataset = True
+  else:
+    is_tokenized_dataset = False
+
   train_ds = get_dataset(
       dataset_name=config.dataset_name,
       split="train2",
@@ -318,17 +341,23 @@ def make_c4_mlperf_train_iterator(
       enable_data_shuffling=config.enable_data_shuffling,
       data_shuffle_seed=config.data_shuffle_seed,
   )
-  train_ds = rekey(train_ds, {"inputs": None, "targets": "text"})
 
-  sp_tokenizer = get_tokenizer(config.tokenizer_path, config.add_bos, config.add_eos)
+  if not is_tokenized_dataset:
+    train_ds = rekey(train_ds, {"inputs": None, "targets": "text"})
+    sp_tokenizer = get_tokenizer(config.tokenizer_path, config.add_bos, config.add_eos)
+  else:
+    sp_tokenizer = None
+
   train_ds = preprocess_train_dataset(
       train_ds,
       sp_tokenizer=sp_tokenizer,
       train_global_batch_size_to_load=config.global_batch_size_to_load,
       max_target_length=config.max_target_length,
       shuffle_buffer_size=128,
       data_shuffle_seed=config.data_shuffle_seed,
+      is_tokenized_dataset=is_tokenized_dataset,
   )
 
   train_multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(train_ds, global_mesh)
   return train_multihost_gen

@@ -339,23 +368,23 @@ def make_c4_mlperf_eval_iterator(
     process_indices,
 ):
   """Make eval iterator of customized C4 dataset for mlperf gpt3 training."""
-  if config.eval_dataset_name == "c4/en:3.0.5":
+  if config.eval_dataset_name == "c4/en:3.0.9":
     is_tokenized_dataset = True
   elif config.eval_dataset_name == "c4/en:3.0.4":
     is_tokenized_dataset = False
   else:
-    raise ValueError(f"{config.eval_dataset_name=} should be one of ('c4/en:3.0.4', 'c4/en:3.0.5')")
+    raise ValueError(f"{config.eval_dataset_name=} should be one of ('c4/en:3.0.4', 'c4/en:3.0.9')")
   if is_tokenized_dataset:
     eval_ds = get_dataset(
         dataset_name=config.eval_dataset_name,
-        split="validation_tokenized_5662seqs",
+        split="validation_tokenized_388seq",
         dataloading_host_index=process_indices.index(jax.process_index()),
         dataloading_host_count=len(process_indices),
         enable_data_shuffling=False,
     )
-    # note validation_tokenized_5662seqs split is pre tokenized, reduce_concated and split to target_length
-    # mainly to avoid eval sequences change depending on the number of hosts
     eval_ds = rekey(eval_ds, {"inputs": None, "targets": "ids"})
     sp_tokenizer = None
+    # note: the validation_tokenized_388seq split is pre-tokenized, reduce_concated and split to a length of 32768,
+    # mainly to avoid eval sequences changing depending on the number of hosts
   else:
     eval_ds = get_dataset(
         dataset_name=config.eval_dataset_name,
@@ -367,8 +396,7 @@ def make_c4_mlperf_eval_iterator(
 
     eval_ds = rekey(eval_ds, {"inputs": None, "targets": "text"})
 
-    sp_tokenizer = get_tokenizer(config.tokenizer_path, config.add_bos, config.add_eos)
-
+    sp_tokenizer = get_tokenizer(config.tokenizer_path, config.add_bos, config.add_eos)
 
   eval_ds = preprocess_eval_dataset(
       eval_ds,
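For readers skimming the diff: rekey (defined earlier in this file) is what lets both paths feed the same downstream code, mapping the pre-tokenized split's "ids" feature, or the raw split's "text" feature, onto "targets". A toy equivalent, assuming the usual t5-style semantics where a None mapping yields an empty string (an assumption about the helper, not a copy of it):

```python
import tensorflow as tf

def toy_rekey(x, key_map):
    # Assumed semantics: new_key -> old_key; a None old_key becomes an empty string.
    return {new: (x[old] if old else tf.constant("")) for new, old in key_map.items()}

pre_tokenized_example = {"ids": tf.constant([5, 6, 7, 1], dtype=tf.int32)}
raw_text_example = {"text": tf.constant("hello world")}

print(toy_rekey(pre_tokenized_example, {"inputs": None, "targets": "ids"}))
print(toy_rekey(raw_text_example, {"inputs": None, "targets": "text"}))
```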