v0.2 release (#7)
* add new requirements

* v0.2 code dump

* add colab requirements file

* add wandb to requirements_colab.txt

* division by zero fix for device memory profiling

* update README

* change default of save_dir
mar-muel authored Oct 27, 2022
1 parent 2ec896a commit 7e9f3da
Showing 18 changed files with 1,100 additions and 485 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -161,6 +161,10 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# vim
*.swp
*.swo

# other
generated_images
logging
@@ -169,3 +173,6 @@ data
train.sh
wandb
test_*.py
profile
jax_cache
mem.prof
30 changes: 30 additions & 0 deletions README.md
@@ -42,6 +42,36 @@ Specifically, the features we've added allow for better scaling of [StyleGAN2](h

</details>

## 🏗 Changelog
<details>
<summary>v0.2</summary>

* Better support for class-conditional training, adding per-class moving average statistics to the generator
* Training data can now be split across multiple tfrecord files (placed either directly in `--data_dir` or in a `tfrecords` subdirectory). A `dataset_info.json` file is still required at the `--data_dir` location, containing `width`, `height`, `num_examples`, and a list of `classes` if class-conditional (an example is sketched after this changelog).
* Renaming arg `--load_from_pkl` => `--load_from_ckpt`
* Added `--num_steps` argument to specify a fixed number of steps to run
* Added `--early_stopping_after_steps` argument to stop after n steps of no FID improvement
* Removal of `--bf16` flag and consolidation with `--mixed_precision`.
* Allow layer freezing with `--freeze_g` and `--freeze_d` arguments
* Add `--fmap_max` argument, in order to have better control over feature map dimensions
* Allow disabling of generator and discriminator regularization
* Change checkpointing behaviour from saving every 2k steps to saving every 10k steps and keeping 2 best checkpoints (see `--save_every` and `--keep_n_checkpoints`)
* Add `--metric_cache_location` in order to cache dataset statistics (currently for FID only)
* Log TPU memory usage, shoutout to ayaka14732 for help (see also https://github.com/ayaka14732/jax-smi)
* Visualise model architecture & parameters on startup
* Improve W&B logging (e.g. adding eval snapshots with fixed latents)
* Experimental: Add jax profiling (a generic usage sketch follows this changelog)

</details>
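
As a hedged illustration of the `dataset_info.json` requirement above: only the key names (`width`, `height`, `num_examples`, `classes`) come from the changelog and `data_pipeline.py`; all values below are made up for this sketch.

```python
import json

# Hypothetical dataset_info.json for a class-conditional 256x256 dataset.
# Only the key names are taken from the repository; the values are examples.
dataset_info = {
    'width': 256,
    'height': 256,
    'num_examples': 70000,
    'classes': ['cat', 'dog'],  # only needed for class-conditional training
}

with open('dataset_info.json', 'w') as fout:
    json.dump(dataset_info, fout, indent=2)
```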
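
For the experimental profiling item, the exact flags this repository exposes are not shown in this diff; the sketch below only demonstrates generic `jax.profiler` usage, with the output paths chosen to match the `profile` and `mem.prof` entries added to `.gitignore`.

```python
import jax
import jax.numpy as jnp

# Trace a bit of device work; the trace directory can be opened in TensorBoard.
jax.profiler.start_trace('profile')
x = jnp.ones((1024, 1024))
(x @ x).block_until_ready()  # force the computation to finish inside the trace
jax.profiler.stop_trace()

# Write a device memory snapshot in pprof format, inspectable with bin/pprof.
jax.profiler.save_device_memory_profile('mem.prof')
```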
<details>
<summary>v0.1</summary>

* Enable training on TPUs
* Google Cloud Storage (GCS) integration
* Several quality-of-life improvements

</details>

## 🧑‍🔧 Install
1. Clone the repository:
```sh
Binary file added bin/pprof
34 changes: 23 additions & 11 deletions checkpoint.py
@@ -4,6 +4,10 @@
import builtins
from jax._src.lib import xla_client
import tensorflow as tf
import logging


logger = logging.getLogger(__name__)


# Hack: this is the module reported by this object.
@@ -24,7 +28,7 @@ def pickle_load(filename):
return pickled


def save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=None, keep=2):
def save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=None, keep_best=2, is_best=False):
"""
Saves checkpoint.
@@ -38,7 +42,8 @@ def save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, s
step (int): Current step.
epoch (int): Current epoch.
fid_score (float): FID score corresponding to the checkpoint.
keep (int): Number of checkpoints to keep.
keep_best (int): Number of best checkpoints to keep.
is_best (bool): Whether this checkpoint is a new best model.
"""
state_dict = {'state_G': flax.jax_utils.unreplicate(state_G),
'state_D': flax.jax_utils.unreplicate(state_D),
@@ -49,15 +54,22 @@ def save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, s
'step': step,
'epoch': epoch}

pickle_dump(state_dict, os.path.join(ckpt_dir, f'ckpt_{step}.pickle'))
ckpts = tf.io.gfile.glob(os.path.join(ckpt_dir, '*.pickle'))
if len(ckpts) > keep:
modified_times = {}
for ckpt in ckpts:
stats = tf.io.gfile.stat(ckpt)
modified_times[ckpt] = stats.mtime_nsec
oldest_ckpt = sorted(modified_times, key=modified_times.get)[0]
tf.io.gfile.remove(oldest_ckpt)
if is_best:
f_name = f'ckpt_{step}_best.pickle'
else:
f_name = f'ckpt_{step}.pickle'
f_path = os.path.join(ckpt_dir, f_name)
logger.info(f'Saving checkpoint for step {step:,} to {f_path}')
pickle_dump(state_dict, f_path)
if is_best:
ckpts = tf.io.gfile.glob(os.path.join(ckpt_dir, '*_best.pickle'))
if len(ckpts) > keep_best:
modified_times = {}
for ckpt in ckpts:
stats = tf.io.gfile.stat(ckpt)
modified_times[ckpt] = stats.mtime_nsec
oldest_ckpt = sorted(modified_times, key=modified_times.get)[0]
tf.io.gfile.remove(oldest_ckpt)


def load_checkpoint(filename):
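
Below is a hypothetical caller-side sketch, not code from this commit, of how a training loop could combine the new `keep_best`/`is_best` checkpoint arguments with the `--early_stopping_after_steps` behaviour described in the changelog. Names such as `on_eval` and `last_improvement_step` are invented for illustration.

```python
from checkpoint import save_checkpoint  # checkpoint.py from this repository

best_fid = float('inf')
last_improvement_step = 0

def on_eval(step, epoch, fid_score, early_stopping_after_steps, **ckpt_kwargs):
    """ckpt_kwargs would carry state_G, state_D, params_ema_G, pl_mean, config."""
    global best_fid, last_improvement_step
    # A checkpoint counts as "best" when it improves on the lowest FID so far.
    is_best = fid_score is not None and fid_score < best_fid
    if is_best:
        best_fid = fid_score
        last_improvement_step = step
    save_checkpoint('checkpoints/', step=step, epoch=epoch, fid_score=fid_score,
                    keep_best=2, is_best=is_best, **ckpt_kwargs)
    # Signal the training loop to stop once FID has stalled for long enough.
    return step - last_improvement_step >= early_stopping_after_steps
```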
30 changes: 19 additions & 11 deletions data_pipeline.py
@@ -1,14 +1,9 @@
import tensorflow as tf
import tensorflow_datasets as tfds
import jax
import flax
import numpy as np
from PIL import Image
import os
from typing import Sequence
from tqdm import tqdm
import json
from tqdm import tqdm
import logging

logger = logging.getLogger(__name__)
@@ -24,7 +19,7 @@ def prefetch(dataset, n_prefetch):
return ds_iter


def get_data(data_dir, img_size, img_channels, num_classes, num_local_devices, batch_size, shuffle_buffer=1000):
def get_data(data_dir, img_size, img_channels, num_classes, num_local_devices, batch_size, allow_resolution_mismatch=False, shuffle_buffer=1000):
"""
Args:
@@ -51,16 +46,13 @@ def pre_process(serialized_example):
height = tf.cast(example['height'], dtype=tf.int64)
width = tf.cast(example['width'], dtype=tf.int64)
channels = tf.cast(example['channels'], dtype=tf.int64)

image = tf.io.decode_raw(example['image'], out_type=tf.uint8)
image = tf.reshape(image, shape=[height, width, channels])

image = tf.cast(image, dtype='float32')

image = tf.image.resize(image, size=[img_size, img_size], method='bicubic', antialias=True)
image = tf.image.random_flip_left_right(image)

image = (image - 127.5) / 127.5

label = tf.one_hot(example['label'], num_classes)
return {'image': image, 'label': label}

@@ -75,7 +67,23 @@ def shard(data):
with tf.io.gfile.GFile(os.path.join(data_dir, 'dataset_info.json'), 'r') as fin:
dataset_info = json.load(fin)

ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords'))
# check resolution mismatch
if not allow_resolution_mismatch:
if 'width' in dataset_info and 'height' in dataset_info:
msg = 'Requested resolution {img_size} is different from input data {input_size}.' \
' Provide the flag --allow_resolution_mismatch in order to allow this behaviour.'
assert dataset_info['width'] == img_size, msg.format(img_size=img_size, input_size=dataset_info['width'])
assert dataset_info['height'] == img_size, msg.format(img_size=img_size, input_size=dataset_info['height'])
else:
raise Exception(f'dataset_info.json does not contain keys "height" or "width". Ignore by providing --allow_resolution_mismatch.')

for folder in [data_dir, os.path.join(data_dir, 'tfrecords')]:
ckpt_files = tf.io.gfile.glob(os.path.join(folder, '*.tfrecords'))
if len(ckpt_files) > 0:
break
else:
raise FileNotFoundError(f'Could not find any tfrecord files in {data_dir}')
ds = tf.data.TFRecordDataset(filenames=ckpt_files)
ds = ds.shard(jax.process_count(), jax.process_index())
ds = ds.shuffle(min(dataset_info['num_examples'], shuffle_buffer))
ds = ds.map(pre_process, tf.data.AUTOTUNE)
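
For clarity, here is a standalone sketch isolating the tfrecord discovery rule added to `get_data` above (look in `--data_dir` first, then in its `tfrecords/` subdirectory). It mirrors the for/else logic from the diff rather than adding new behaviour; the helper name `find_tfrecords` is an invention of this sketch.

```python
import os
import tensorflow as tf

def find_tfrecords(data_dir):
    """Return *.tfrecords files from data_dir or its tfrecords/ subdirectory."""
    for folder in [data_dir, os.path.join(data_dir, 'tfrecords')]:
        files = tf.io.gfile.glob(os.path.join(folder, '*.tfrecords'))
        if files:
            return files  # first folder that contains tfrecord files wins
    raise FileNotFoundError(f'Could not find any tfrecord files in {data_dir}')
```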