Refactor: Move metrics logic to MetricLogger class #1315

Status: Open. Wants to merge 7 commits into base: main.
43 changes: 0 additions & 43 deletions MaxText/max_utils.py
@@ -41,7 +41,6 @@
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager


import json
import yaml
import flax
from flax.training import train_state
@@ -120,28 +119,6 @@ def close_summary_writer(summary_writer):
summary_writer.close()


def _prepare_metrics_for_json(metrics, step, run_name):
"""Converts metric dictionary into json supported types (e.g. float)"""
metrics_dict = {}
for val in metrics["scalar"]:
metrics_dict[val] = float(metrics["scalar"][val])
metrics_dict["step"] = float(step)
metrics_dict["run_name"] = run_name
return metrics_dict


def write_metrics_locally(metrics, step, config, file, is_training=True):
"""Writes metrics locally for testing"""
if step == 0:
file.truncate(0)

metrics_dict = _prepare_metrics_for_json(metrics, step, config.run_name)
file.write(str(json.dumps(metrics_dict)) + "\n")

if is_training and step == config.steps - 1:
file.close()


def add_config_to_summary_writer(config, summary_writer):
"""Writes config params to tensorboard"""
if jax.process_index() == 0:
@@ -155,26 +132,6 @@ def add_text_to_summary_writer(key, value, summary_writer):
summary_writer.add_text(key, value)


def write_metrics_for_gcs(metrics, step, config, running_metrics, is_training=True):
"""Writes metrics to gcs"""
metrics_dict_step = _prepare_metrics_for_json(metrics, step, config.run_name)
running_metrics.append(metrics_dict_step)
if is_training and (step + 1) % config.log_period == 0 or step == config.steps - 1:
start_step = (step // config.log_period) * config.log_period
metrics_filename = f"metrics_step_{start_step:06}_to_step_{step:06}.txt"
with open(metrics_filename, "w", encoding="utf8") as metrics_for_gcs:
for metrics_step in running_metrics:
metrics_for_gcs.write(str(json.dumps(metrics_step)) + "\n")

metrics_for_gcs.close()
gcs_filename = os.path.join(config.metrics_dir, metrics_filename)
max_logging.log(f"Moving file {metrics_filename} to GCS...")
upload_blob(gcs_filename, metrics_filename)
max_logging.log(f"File {metrics_filename} moved successfully!")
running_metrics = [] # reset running_metrics to empty list
return running_metrics


def write_config_raw_keys_for_gcs(raw_keys):
"""Writes config raw keys to GCS"""
if not raw_keys["save_config_to_gcs"] or jax.process_index() != 0:
128 changes: 128 additions & 0 deletions MaxText/metric_logger.py
@@ -0,0 +1,128 @@
"""
Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# pylint: disable=bare-except, consider-using-generator
"""Logger that saves metrics to a local file, GCS and TensorBoard. """

import json
import os
import jax
import numpy as np

import max_logging
import max_utils


def _prepare_metrics_for_json(metrics, step, run_name):
"""Converts metric dictionary into json supported types (e.g. float)"""
metrics_dict = {val: float(metrics["scalar"][val]) for val in metrics["scalar"]}
metrics_dict["step"] = float(step)
metrics_dict["run_name"] = run_name
return metrics_dict


class MetricLogger:
"""
Logger for saving metrics to a local file, GCS and TensorBoard.
"""

def __init__(self, writer, config):
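# buffered_step / buffered_metrics hold the previous step's metrics until the
# next write_metrics call (see that method's docstring for the double-buffering rationale).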
self.buffered_step = None
self.buffered_metrics = None
self.writer = writer
self.config = config

def write_metrics(self, running_gcs_metrics, metrics, step, is_training=True):
"""Entry point for all metrics writing in Train's Main.

To avoid introducing an unnecessary dependency, we "double buffer" -- we hold
onto the last metrics and step and only publish them when we receive the next
metrics and step. This ensures that JAX can keep queueing train_steps and we
don't block when turning "lazy" JAX arrays into real Python numbers.
"""
metrics_to_write, steps_to_write = None, None
if is_training:
if self.buffered_metrics is not None:
if self.buffered_step is None:
raise ValueError(f"When writing metrics, {self.buffered_step=} was None")
metrics_to_write = self.buffered_metrics
steps_to_write = self.buffered_step
self.buffered_metrics = metrics
self.buffered_step = step
else:
metrics_to_write = metrics
steps_to_write = step

if metrics_to_write:
if self.config.enable_tensorboard:
self.write_metrics_to_tensorboard(metrics_to_write, steps_to_write, is_training)

if self.config.metrics_file:
self.write_metrics_locally(metrics_to_write, steps_to_write)

if self.config.gcs_metrics and jax.process_index() == 0:
running_gcs_metrics = self.write_metrics_for_gcs(metrics_to_write, steps_to_write, running_gcs_metrics, is_training)

def write_metrics_locally(self, metrics, step):
"""Writes metrics locally for testing"""
with open(self.config.metrics_file, "a", encoding="utf8") as local_metrics_file:
if step == 0:
local_metrics_file.truncate(0)

metrics_dict = _prepare_metrics_for_json(metrics, step, self.config.run_name)
local_metrics_file.write(str(json.dumps(metrics_dict)) + "\n")

def write_metrics_for_gcs(self, metrics, step, running_metrics, is_training):
"""Writes metrics to gcs"""
metrics_dict_step = _prepare_metrics_for_json(metrics, step, self.config.run_name)
running_metrics.append(metrics_dict_step)
if is_training and (step + 1) % self.config.log_period == 0 or step == self.config.steps - 1:
start_step = (step // self.config.log_period) * self.config.log_period
metrics_filename = f"metrics_step_{start_step:06}_to_step_{step:06}.txt"
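# e.g. with log_period=100, the flush at step 99 covers steps 0-99 and is written as metrics_step_000000_to_step_000099.txt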
with open(metrics_filename, "w", encoding="utf8") as metrics_for_gcs:
for metrics_step in running_metrics:
metrics_for_gcs.write(str(json.dumps(metrics_step)) + "\n")

gcs_filename = os.path.join(self.config.metrics_dir, metrics_filename)
max_logging.log(f"Moving file {metrics_filename} to GCS...")
max_utils.upload_blob(gcs_filename, metrics_filename)
max_logging.log(f"File {metrics_filename} moved successfully!")
running_metrics = [] # reset running_metrics to empty list
return running_metrics

def write_metrics_to_tensorboard(self, metrics, step, is_training):
"""Writes metrics to TensorBoard"""
with jax.spmd_mode("allow_all"):
if jax.process_index() == 0:
for metric_name in metrics.get("scalar", []):
self.writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
for metric_name in metrics.get("scalars", []):
self.writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)

if is_training:
full_log = step % self.config.log_period == 0

max_logging.log(
f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
f"total_weights: {metrics['scalar']['learning/total_weights']}, "
f"loss: {metrics['scalar']['learning/loss']:.3f}"
)

if full_log and jax.process_index() == 0:
max_logging.log(f"To see full metrics 'tensorboard --logdir={self.config.tensorboard_dir}'")
self.writer.flush()
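
For context, here is a minimal standalone sketch (an illustration only, not code from this PR) of the one-step-delay "double buffering" that MetricLogger.write_metrics describes in its docstring: metrics for step N are only published once step N+1's metrics arrive, so converting lazy JAX arrays into Python numbers never blocks the step that produced them. The BufferedPublisher name and publish_fn callback below are invented for the sketch.

class BufferedPublisher:
  """Standalone illustration of the buffering pattern in MetricLogger.write_metrics."""

  def __init__(self, publish_fn):
    self.publish_fn = publish_fn  # callback that actually writes/publishes metrics
    self.buffered_step = None
    self.buffered_metrics = None

  def write(self, metrics, step):
    # Publish the previous step's metrics (if any), then buffer the new ones.
    if self.buffered_metrics is not None:
      self.publish_fn(self.buffered_metrics, self.buffered_step)
    self.buffered_metrics = metrics
    self.buffered_step = step


if __name__ == "__main__":
  publisher = BufferedPublisher(lambda m, s: print(f"published step {s}: {m}"))
  for step in range(3):
    publisher.write({"learning/loss": 1.0 / (step + 1)}, step)
  # Prints steps 0 and 1 only; step 2 stays buffered until one more write call,
  # which in train.py corresponds to the final write_metrics(...) after the training loop.
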
87 changes: 6 additions & 81 deletions MaxText/train.py
@@ -48,6 +48,8 @@
import pathwaysutils # pylint: disable=unused-import
import tensorflow as tf

from metric_logger import MetricLogger

from vertex_tensorboard import VertexTensorboardManager
# Placeholder: internal

@@ -119,80 +121,6 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr, per_d
metrics["scalar"].update({"learning/current_learning_rate": lr})


_buffered_step = None
_buffered_metrics = None


def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config, is_training=True):
"""Entry point for all metrics writing in Train's Main.
TODO: would be better as a Class in the future (that initialized all state!)

To avoid introducing an unnecessary dependency, we "double buffer" -- we hold
onto the last metrics and step and only publish when we receive a new metrics and step.
The logic is that this ensures that Jax is able to queues train_steps and we
don't block when turning "lazy" Jax arrays into real Python numbers.
"""
metrics_to_write, steps_to_write = None, None
if is_training:
global _buffered_step, _buffered_metrics
if _buffered_metrics is not None:
if _buffered_step is None:
raise ValueError(f"When writing metrics, {_buffered_step=} was none")
metrics_to_write = _buffered_metrics
steps_to_write = _buffered_step
else:
metrics_to_write = metrics
steps_to_write = step

if metrics_to_write:
if config.enable_tensorboard:
write_metrics_to_tensorboard(writer, metrics_to_write, steps_to_write, config, is_training)

if config.metrics_file:
max_utils.write_metrics_locally(metrics_to_write, steps_to_write, config, local_metrics_file, is_training)

if config.gcs_metrics and jax.process_index() == 0:
running_gcs_metrics = max_utils.write_metrics_for_gcs(
metrics_to_write, steps_to_write, config, running_gcs_metrics, is_training
)

if is_training:
_buffered_step = step
_buffered_metrics = metrics


def write_metrics_to_tensorboard(writer, metrics, step, config, is_training=True):
"""Writes metrics to tensorboard"""
with jax.spmd_mode("allow_all"):
if jax.process_index() == 0:
for metric_name in metrics.get("scalar", []):
writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
for metric_name in metrics.get("scalars", []):
writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)

if is_training:
full_log = step % config.log_period == 0

max_logging.log(
f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
f"total_weights: {metrics['scalar']['learning/total_weights']}, "
f"loss: {metrics['scalar']['learning/loss']:.3f}"
)

if full_log and jax.process_index() == 0:
max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
writer.flush()


def clear_buffered_metrics():
global _buffered_step
global _buffered_metrics
_buffered_step = None
_buffered_metrics = None


def save_checkpoint(
checkpoint_manager,
step,
@@ -855,7 +783,6 @@ def train_loop(config, state=None):
else:
p_eval_step = None

local_metrics_file = open(config.metrics_file, "a", encoding="utf8") if config.metrics_file else None
running_gcs_metrics = [] if config.gcs_metrics else None

start_step = get_first_step(state) # this is the start_step for training
@@ -877,6 +804,7 @@
performance_metric_queue = queue.Queue()
gcp_workload_monitor.start_performance_reporting_thread(performance_metric_queue)

metric_logger = MetricLogger(writer, config)
for step in np.arange(start_step, config.steps):
if step == first_profiling_step or prof.should_activate_periodic_profile(step):
optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
@@ -909,7 +837,7 @@
checkpoint_manager.wait_until_finished()
sys.exit()

write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config)
metric_logger.write_metrics(running_gcs_metrics, metrics, step)

if config.dump_hlo and step == start_step:
jax.block_until_ready(state) # Ensure compilation has finished.
@@ -954,9 +882,7 @@
)
if config.use_dpo:
cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] = eval_dpo_reward_accuracy / eval_step_count
write_metrics(
writer, local_metrics_file, running_gcs_metrics, cumulative_eval_metrics, step, config, is_training=False
)
metric_logger.write_metrics(running_gcs_metrics, cumulative_eval_metrics, step, is_training=False)
max_logging.log(
f"average loss after {step=}: {eval_step_count=}, {eval_loss=},"
f" total_weights={cumulative_eval_metrics['scalar']['eval/total_weights']}"
@@ -974,10 +900,9 @@

if checkpoint_manager is not None:
checkpoint_manager.wait_until_finished()
write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, config.steps - 1, config) # final step metrics
metric_logger.write_metrics(running_gcs_metrics, metrics, config.steps - 1) # final step metrics
max_utils.close_summary_writer(writer)
record_goodput(recorder, config, recorder.record_job_end_time if recorder else None)
clear_buffered_metrics()
Collaborator
Is this clear_buffered_metrics() call no longer needed?

Collaborator Author
This method just reset the module-level globals back to None. I have moved those globals to instance attributes on MetricLogger instead, and AFAIK Python's garbage collector will reclaim the memory held by that object once the program exits, so this method shouldn't be required anymore.
@gobbleturk - LMK if I am missing anything here.
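
(Illustrative sketch only, not code from this PR: the difference between the old module-level buffer and the new instance-level buffer.)

# Before: module-level globals survive for the life of the process, so train.py
# needed an explicit clear_buffered_metrics() to reset them.
_buffered_step = None
_buffered_metrics = None

# After: the buffer lives on the MetricLogger instance created inside train_loop(),
# so it is reclaimed with the logger object and no explicit clearing step is needed.
class _BufferHolder:
  def __init__(self):
    self.buffered_step = None
    self.buffered_metrics = None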

with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
# pytype: disable=attribute-error
compiled = p_train_step.lower(state, example_batch, nextrng).compile()