Skip to content

Commit

Permalink
Refactor out shared LSTM/GGNN training loop.
Browse files Browse the repository at this point in the history
github.com/ChrisCummins/ProGraML/issues/69
  • Loading branch information
ChrisCummins committed Aug 30, 2020
1 parent 9c4c442 commit 9b06bd6
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 89 deletions.
49 changes: 48 additions & 1 deletion programl/task/dataflow/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import warnings
from typing import Tuple

from labm8.py import app, pbutil
from labm8.py import app, humanize, pbutil
from sklearn.exceptions import UndefinedMetricWarning

from programl.proto import checkpoint_pb2, epoch_pb2
Expand Down Expand Up @@ -208,3 +208,50 @@ def CreateLoggingDirectories(
(log_dir / "checkpoints").mkdir()
(log_dir / "graph_loader").mkdir()
return log_dir


def run_training_loop(
    log_dir, epochs, start_epoch_step, model, val_batches=None, val_graph_count=None
):
    """Run the shared train/validate loop used by the LSTM and GGNN models.

    For each epoch this trains the model, runs a validation pass, prints and
    writes a single-element EpochList proto to `<log_dir>/epochs/`, and saves
    a model checkpoint to `<log_dir>/checkpoints/`.

    Args:
        log_dir: A directory with `epochs/` and `checkpoints/` subdirectories,
            as produced by CreateLoggingDirectories().
        epochs: An iterable of
            (train_graph_count, train_graph_cumsum, train_batches) tuples,
            one element per epoch.
        start_epoch_step: The epoch number assigned to the first element of
            `epochs`. Used both for logging and for naming the on-disk
            epoch / checkpoint files.
        model: The model to train. Must expose RunBatches() and
            SaveCheckpoint().
        val_batches: An iterable of validation batches, passed to
            model.RunBatches() once per epoch.
        val_graph_count: The total number of validation graphs, used for
            logging progress. If None, falls back to FLAGS.val_graph_count.

    Returns:
        The log_dir argument.

    Raises:
        TypeError: If val_batches is not provided.
    """
    # The pre-refactor per-model loops read these from their own local scope;
    # the extracted function must receive them explicitly. Fail fast with a
    # clear message rather than a NameError deep inside the loop.
    if val_batches is None:
        raise TypeError("run_training_loop() requires val_batches")
    if val_graph_count is None:
        # NOTE(review): assumes a module-level FLAGS object (labm8 app flags)
        # with a val_graph_count flag is in scope — confirm against dataflow.py
        # module globals.
        val_graph_count = FLAGS.val_graph_count

    for (
        epoch_step,
        (train_graph_count, train_graph_cumsum, train_batches),
    ) in enumerate(epochs, start=start_epoch_step):
        start_time = time.time()
        # Human-readable cumulative graph count for log prefixes.
        hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"

        train_results = model.RunBatches(
            epoch_pb2.TRAIN,
            train_batches,
            log_prefix=f"Train to {hr_graph_cumsum}",
            total_graph_count=train_graph_count,
        )
        val_results = model.RunBatches(
            epoch_pb2.VAL,
            val_batches,
            log_prefix=f"Val at {hr_graph_cumsum}",
            total_graph_count=val_graph_count,
        )

        # Write the epoch to file as an epoch list. This may seem redundant since
        # epoch list contains a single item, but it means that we can easily
        # concatenate a sequence of these epoch protos to produce a valid epoch
        # list using: `cat *.EpochList.pbtxt > epochs.pbtxt`
        epoch = epoch_pb2.EpochList(
            epoch=[
                epoch_pb2.Epoch(
                    walltime_seconds=time.time() - start_time,
                    epoch_num=epoch_step,
                    train_results=train_results,
                    val_results=val_results,
                )
            ]
        )
        print(epoch, end="")

        epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
        pbutil.ToFile(epoch, epoch_path)
        app.Log(1, "Wrote %s", epoch_path)

        # Checkpoint after every epoch so training can be resumed from the
        # last completed epoch.
        checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
        pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)

    return log_dir
45 changes: 1 addition & 44 deletions programl/task/dataflow/ggnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,50 +173,7 @@ def TrainDataflowGGNN(
)
)

for (
epoch_step,
(train_graph_count, train_graph_cumsum, train_batches),
) in enumerate(epochs, start=start_epoch_step):
start_time = time.time()
hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"

train_results = model.RunBatches(
epoch_pb2.TRAIN,
train_batches,
log_prefix=f"Train to {hr_graph_cumsum}",
total_graph_count=train_graph_count,
)
val_results = model.RunBatches(
epoch_pb2.VAL,
val_batches.batches,
log_prefix=f"Val at {hr_graph_cumsum}",
total_graph_count=val_graph_count,
)

# Write the epoch to file as an epoch list. This may seem redundant since
# epoch list contains a single item, but it means that we can easily
# concatenate a sequence of these epoch protos to produce a valid epoch
# list using: `cat *.EpochList.pbtxt > epochs.pbtxt`
epoch = epoch_pb2.EpochList(
epoch=[
epoch_pb2.Epoch(
walltime_seconds=time.time() - start_time,
epoch_num=epoch_step,
train_results=train_results,
val_results=val_results,
)
]
)
print(epoch, end="")

epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
pbutil.ToFile(epoch, epoch_path)
app.Log(1, "Wrote %s", epoch_path)

checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)

return log_dir
return dataflow.run_training_loop(log_dir, epochs, start_epoch_step, model)


def TestDataflowGGNN(
Expand Down
45 changes: 1 addition & 44 deletions programl/task/dataflow/train_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,50 +160,7 @@ def TrainDataflowLSTM(
)
)

for (
epoch_step,
(train_graph_count, train_graph_cumsum, train_batches),
) in enumerate(epochs, start=start_epoch_step):
start_time = time.time()
hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"

train_results = model.RunBatches(
epoch_pb2.TRAIN,
train_batches,
log_prefix=f"Train to {hr_graph_cumsum}",
total_graph_count=train_graph_count,
)
val_results = model.RunBatches(
epoch_pb2.VAL,
val_batches.batches,
log_prefix=f"Val at {hr_graph_cumsum}",
total_graph_count=FLAGS.val_graph_count,
)

# Write the epoch to file as an epoch list. This may seem redundant since
# epoch list contains a single item, but it means that we can easily
# concatenate a sequence of these epoch protos to produce a valid epoch
# list using: `cat *.EpochList.pbtxt > epochs.pbtxt`
epoch = epoch_pb2.EpochList(
epoch=[
epoch_pb2.Epoch(
walltime_seconds=time.time() - start_time,
epoch_num=epoch_step,
train_results=train_results,
val_results=val_results,
)
]
)
print(epoch, end="")

epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
pbutil.ToFile(epoch, epoch_path)
app.Log(1, "Wrote %s", epoch_path)

checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)

return log_dir
return dataflow.run_training_loop(log_dir, epochs, start_epoch_step, model)


def TestDataflowLSTM(
Expand Down

0 comments on commit 9b06bd6

Please sign in to comment.