Skip to content

Commit

Permalink
Upgraded distributed tensorflow sample to 2.16 (Azure#3428)
Browse files Browse the repository at this point in the history
* Upgraded distributed tensorflow sample to 2.16
  • Loading branch information
jeff-shepherd authored Oct 29, 2024
1 parent 498445d commit b34bafb
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 6 deletions.
2 changes: 1 addition & 1 deletion cli/jobs/single-step/tensorflow/mnist-distributed/job.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ command: >-
inputs:
epochs: 1
model_dir: outputs/keras-model
environment: azureml:AzureML-tensorflow-2.12-cuda11@latest
environment: azureml:AzureML-tensorflow-2.16-cuda11@latest
compute: azureml:gpu-cluster
resources:
instance_count: 2
Expand Down
15 changes: 13 additions & 2 deletions cli/jobs/single-step/tensorflow/mnist-distributed/src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,17 @@ def write_filepath(filepath, task_type, task_id):
return os.path.join(dirpath, base)


def fix_tf_config():
    """Parse ``TF_CONFIG`` and strip an empty ``"ps"`` list from its cluster.

    Reads the JSON document stored in the ``TF_CONFIG`` environment
    variable. When it has a ``"cluster"`` entry whose ``"ps"`` value has
    length zero, that key is removed and the updated document is written
    back to ``TF_CONFIG``.

    Returns:
        The parsed (and possibly modified) ``TF_CONFIG`` object.

    Raises:
        KeyError: if ``TF_CONFIG`` is not set in the environment.
    """
    # This is necessary for TensorFlow 2.13 and later
    parsed = json.loads(os.environ["TF_CONFIG"])
    if "cluster" not in parsed:
        # Nothing to clean up without a cluster spec.
        return parsed

    cluster_spec = parsed["cluster"]
    has_empty_ps = "ps" in cluster_spec and len(cluster_spec["ps"]) == 0
    if has_empty_ps:
        del cluster_spec["ps"]
        # Re-export so code that reads TF_CONFIG from the environment
        # sees the cleaned cluster spec as well.
        os.environ["TF_CONFIG"] = json.dumps(parsed)
    return parsed


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=3)
Expand All @@ -93,10 +104,10 @@ def main():

args = parser.parse_args()

tf_config = json.loads(os.environ["TF_CONFIG"])
tf_config = fix_tf_config()
num_workers = len(tf_config["cluster"]["worker"])

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Here the batch size scales up by number of workers since
# `tf.data.Dataset.batch` expects the global batch size.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,17 @@ def write_filepath(filepath, task_type, task_id):
return os.path.join(dirpath, base)


def fix_tf_config():
    """Parse ``TF_CONFIG`` and strip an empty ``"ps"`` list from its cluster.

    Reads the JSON document stored in the ``TF_CONFIG`` environment
    variable. When it has a ``"cluster"`` entry whose ``"ps"`` value has
    length zero, that key is removed and the updated document is written
    back to ``TF_CONFIG``.

    Returns:
        The parsed (and possibly modified) ``TF_CONFIG`` object.

    Raises:
        KeyError: if ``TF_CONFIG`` is not set in the environment.
    """
    # This is necessary for TensorFlow 2.13 and later
    parsed = json.loads(os.environ["TF_CONFIG"])
    if "cluster" not in parsed:
        # Nothing to clean up without a cluster spec.
        return parsed

    cluster_spec = parsed["cluster"]
    has_empty_ps = "ps" in cluster_spec and len(cluster_spec["ps"]) == 0
    if has_empty_ps:
        del cluster_spec["ps"]
        # Re-export so code that reads TF_CONFIG from the environment
        # sees the cleaned cluster spec as well.
        os.environ["TF_CONFIG"] = json.dumps(parsed)
    return parsed


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=3)
Expand All @@ -93,10 +104,11 @@ def main():

args = parser.parse_args()

tf_config = json.loads(os.environ["TF_CONFIG"])
tf_config = fix_tf_config()

num_workers = len(tf_config["cluster"]["worker"])

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Here the batch size scales up by number of workers since
# `tf.data.Dataset.batch` expects the global batch size.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
" code=\"./src\", # local path where the code is stored\n",
" command=\"python main.py --epochs ${{inputs.epochs}} --model-dir ${{inputs.model_dir}}\",\n",
" inputs={\"epochs\": 1, \"model_dir\": \"outputs/keras-model\"},\n",
" environment=\"AzureML-tensorflow-2.12-cuda11@latest\",\n",
" environment=\"AzureML-tensorflow-2.16-cuda11@latest\",\n",
" compute=\"cpu-cluster\",\n",
" instance_count=2,\n",
" # distribution = {\"type\": \"mpi\", \"process_count_per_instance\": 1},\n",
Expand Down

0 comments on commit b34bafb

Please sign in to comment.