diff --git a/MaxText/configs/models/mixtral-8x22b-mlperf.yml b/MaxText/configs/models/mixtral-8x22b-mlperf.yml
index d4684496c..2bf6b351a 100644
--- a/MaxText/configs/models/mixtral-8x22b-mlperf.yml
+++ b/MaxText/configs/models/mixtral-8x22b-mlperf.yml
@@ -31,6 +31,6 @@ num_experts: 8
 num_experts_per_tok: 2
 rope_max_timescale: 1_000_000
 decoder_block: "mistral"
-dataset_path: "gs://mlperf-llm-public2"
-dataset_name: "c4/en:3.0.4"
-eval_dataset_name: "c4/en:3.0.4"
\ No newline at end of file
+dataset_path: "gs://maxtext-dataset"
+dataset_name: "c4/en:3.0.8"
+eval_dataset_name: "c4/en:3.0.9"
diff --git a/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py b/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py
index 4717227b9..2de324f66 100644
--- a/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py
+++ b/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py
@@ -231,14 +231,25 @@ def get_dataset(
   return ds
 
 
+def convert_dtype(x, new_type, columns):
+  for col in columns:
+    x[col] = tf.cast(x[col], new_type)
+  return x
+
+
 def format_fn(x, eos_id: int = 1, pad_id: int = 0):
   """Format function for c4_mlperf."""
   x["inputs"] = x["targets"]
+  if "targets_position" not in x:
+    x['targets_position'] = tf.convert_to_tensor(np.arange(len(x['targets']), dtype=np.int32))
   x["inputs_position"] = x["targets_position"]
   x["targets"] = _shift_left_and_pad(x["targets"], eos_id)
-  x["inputs_segmentation"] = tf.where(
-      tf.logical_and(x["targets"] != eos_id, x["targets"] != pad_id), x["targets_segmentation"], 0
-  )
+  if "targets_segmentation" not in x:
+    x["inputs_segmentation"] = tf.where(tf.logical_and(x["targets"] != eos_id, x["targets"] != pad_id), 1, 0)
+  else:
+    x["inputs_segmentation"] = tf.where(
+        tf.logical_and(x["targets"] != eos_id, x["targets"] != pad_id), x["targets_segmentation"], 0
+    )
   x["targets_segmentation"] = x["inputs_segmentation"]
   return x
 
@@ -250,16 +261,21 @@ def preprocess_train_dataset(
     max_target_length: int,
     shuffle_buffer_size: int,
     data_shuffle_seed: int,
+    is_tokenized_dataset: bool = False,
 ) -> tf.data.Dataset:
   """Preprocess the training dataset."""
-  train_ds = train_ds.map(
-      lambda x: tokenizer.TokenizeOp(tokenizer=sp_tokenizer, features=x, data_keys=("targets",)), num_parallel_calls=AUTOTUNE
-  )
+  if not is_tokenized_dataset:
+    train_ds = train_ds.map(
+        lambda x: tokenizer.TokenizeOp(tokenizer=sp_tokenizer, features=x, data_keys=("targets",)), num_parallel_calls=AUTOTUNE
+    )
 
-  train_ds = reduce_concat_tokens(train_ds, feature_key="targets", batch_size=4096)
-  train_ds = split_tokens_to_targets_length(train_ds, max_target_length)
-  train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)
-  train_ds = sequence_packing.pack_dataset(train_ds, max_target_length)
+    train_ds = reduce_concat_tokens(train_ds, feature_key="targets", batch_size=4096)
+    train_ds = split_tokens_to_targets_length(train_ds, max_target_length)
+    train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)
+    train_ds = sequence_packing.pack_dataset(train_ds, max_target_length)
+  else:
+    train_ds = train_ds.map(lambda x : convert_dtype(x, tf.int32, ['targets']))
+    train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)
   train_ds = train_ds.map(format_fn, num_parallel_calls=AUTOTUNE)
   train_ds = train_ds.batch(train_global_batch_size_to_load // jax.process_count(), drop_remainder=True)
@@ -285,8 +301,10 @@ def preprocess_eval_dataset(
     # to avoid padding tokens inserted in group text
     eval_ds = reduce_concat_tokens(eval_ds, feature_key="targets", batch_size=24567)
     eval_ds = split_tokens_to_targets_length(eval_ds, max_target_length)
+    eval_ds = sequence_packing.pack_dataset(eval_ds, max_target_length)
 
-  eval_ds = sequence_packing.pack_dataset(eval_ds, max_target_length)
+  else:
+    eval_ds = eval_ds.map(lambda x : convert_dtype(x, tf.int32, ['targets']))
 
   eval_ds = eval_ds.map(format_fn, num_parallel_calls=AUTOTUNE)
@@ -310,6 +328,11 @@ def make_c4_mlperf_train_iterator(
     process_indices,
 ):
   """Make train iterator of customized C4 dataset for mlperf gpt3 training."""
+  if config.dataset_name == "c4/en:3.0.8":
+    is_tokenized_dataset = True
+  else:
+    is_tokenized_dataset = False
+
   train_ds = get_dataset(
       dataset_name=config.dataset_name,
       split="train2",
@@ -318,9 +341,13 @@
       enable_data_shuffling=config.enable_data_shuffling,
       data_shuffle_seed=config.data_shuffle_seed,
   )
-  train_ds = rekey(train_ds, {"inputs": None, "targets": "text"})
-  sp_tokenizer = get_tokenizer(config.tokenizer_path, config.add_bos, config.add_eos)
+  if not is_tokenized_dataset:
+    train_ds = rekey(train_ds, {"inputs": None, "targets": "text"})
+    sp_tokenizer = get_tokenizer(config.tokenizer_path, config.add_bos, config.add_eos)
+  else:
+    sp_tokenizer = None
+
   train_ds = preprocess_train_dataset(
       train_ds,
       sp_tokenizer=sp_tokenizer,
@@ -328,7 +355,9 @@
       max_target_length=config.max_target_length,
       shuffle_buffer_size=128,
       data_shuffle_seed=config.data_shuffle_seed,
+      is_tokenized_dataset=is_tokenized_dataset
   )
+
   train_multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(train_ds, global_mesh)
   return train_multihost_gen
@@ -339,23 +368,23 @@ def make_c4_mlperf_eval_iterator(
     process_indices,
 ):
   """Make eval iterator of customized C4 dataset for mlperf gpt3 training."""
-  if config.eval_dataset_name == "c4/en:3.0.5":
+  if config.eval_dataset_name == "c4/en:3.0.9":
     is_tokenized_dataset = True
   elif config.eval_dataset_name == "c4/en:3.0.4":
     is_tokenized_dataset = False
   else:
-    raise ValueError(f"{config.eval_dataset_name=} should be one of ('c4/en:3.0.4', 'c4/en:3.0.5')")
+    raise ValueError(f"{config.eval_dataset_name=} should be one of ('c4/en:3.0.4', 'c4/en:3.0.9')")
 
   if is_tokenized_dataset:
     eval_ds = get_dataset(
         dataset_name=config.eval_dataset_name,
-        split="validation_tokenized_5662seqs",
+        split="validation_tokenized_388seq",
         dataloading_host_index=process_indices.index(jax.process_index()),
         dataloading_host_count=len(process_indices),
         enable_data_shuffling=False,
     )
-    # note validation_tokenized_5662seqs split is pre tokenized, reduce_concated and split to target_length
-    # mainly to avoid eval sequences change depending on the number of hosts
-    eval_ds = rekey(eval_ds, {"inputs": None, "targets": "ids"})
+    sp_tokenizer = None
+    # note validation_tokenized_388seq split is pre tokenized, reduce_concated and split to length of 32768
+    # mainly to avoid eval sequences change depending on the number of hosts
   else:
     eval_ds = get_dataset(
         dataset_name=config.eval_dataset_name,
@@ -367,8 +396,7 @@
     eval_ds = rekey(eval_ds, {"inputs": None, "targets": "text"})
 
-  sp_tokenizer = get_tokenizer(config.tokenizer_path, config.add_bos, config.add_eos)
-
+    sp_tokenizer = get_tokenizer(config.tokenizer_path, config.add_bos, config.add_eos)
   eval_ds = preprocess_eval_dataset(
       eval_ds,
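
For reference, below is a minimal standalone sketch of the new pre-tokenized path, not part of the patch itself. convert_dtype and format_fn are copied (with added comments) from the diff above; _shift_left_and_pad is a simplified stand-in for the module's helper, and the token ids are invented for illustration.

# Standalone sketch of the pre-tokenized path added by this patch.
import numpy as np
import tensorflow as tf


def convert_dtype(x, new_type, columns):
  # Cast the listed feature columns to the requested dtype (as in the patch).
  for col in columns:
    x[col] = tf.cast(x[col], new_type)
  return x


def _shift_left_and_pad(tensor, pad_val):
  # Simplified stand-in: drop the first token and append pad_val so "targets"
  # become next-token labels for "inputs".
  return tf.concat([tensor[1:], tf.constant([pad_val], dtype=tensor.dtype)], axis=0)


def format_fn(x, eos_id: int = 1, pad_id: int = 0):
  """Format function for c4_mlperf (as in the diff above)."""
  x["inputs"] = x["targets"]
  if "targets_position" not in x:
    # Pre-tokenized examples carry no packing metadata, so positions are synthesized.
    x["targets_position"] = tf.convert_to_tensor(np.arange(len(x["targets"]), dtype=np.int32))
  x["inputs_position"] = x["targets_position"]
  x["targets"] = _shift_left_and_pad(x["targets"], eos_id)
  if "targets_segmentation" not in x:
    # Likewise, segmentation is derived directly from eos/pad positions.
    x["inputs_segmentation"] = tf.where(tf.logical_and(x["targets"] != eos_id, x["targets"] != pad_id), 1, 0)
  else:
    x["inputs_segmentation"] = tf.where(
        tf.logical_and(x["targets"] != eos_id, x["targets"] != pad_id), x["targets_segmentation"], 0
    )
  x["targets_segmentation"] = x["inputs_segmentation"]
  return x


# A pre-tokenized record only carries raw token ids (often stored as int64);
# convert_dtype casts them to int32 and format_fn fills in positions and segmentation.
example = {"targets": tf.constant([5, 6, 7, 1, 8, 9], dtype=tf.int64)}
example = convert_dtype(example, tf.int32, ["targets"])
example = format_fn(example)
print(example["inputs"].numpy())               # [5 6 7 1 8 9]
print(example["targets"].numpy())              # [6 7 1 8 9 1]
print(example["inputs_position"].numpy())      # [0 1 2 3 4 5]
print(example["inputs_segmentation"].numpy())  # [1 1 0 1 1 0]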