Explicit content style maps for reproducing datasets and manual control #18

Open
wants to merge 7 commits into
base: master
10 changes: 10 additions & 0 deletions README.md
@@ -22,12 +22,22 @@ Feel free to open an issue in case there is any question.
- `--style-dir <STYLE>` the top-level directory of the style images (mandatory)
- `--output-dir <OUTPUT>` the directory where the stylized dataset will be stored (optional, default: `output/`)
- `--num-styles <N>` number of stylizations to create for each content image (optional, default: `1`)
- `--style-map <STYLE_MAP>` path to a JSON file containing an explicit content-to-style map. If provided, `--num-styles` is ignored (optional, default: `None`)
- `--alpha <A>` Weight that controls the strength of stylization, should be between 0 and 1 (optional, default: `1`)
- `--extensions <EX0> <EX1> ...` list of image extensions to scan the style and content directories for (optional, default: `png, jpeg, jpg`). Note: this is case-sensitive, so `--extensions jpg` will not scan for files ending in `.JPG`. Image types must be compatible with PIL's `Image.open()` ([Documentation](https://pillow.readthedocs.io/en/5.1.x/handbook/image-file-formats.html))
- `--content-size <N>` Minimum size for content images, resulting in scaling of the shorter side of the content image to `N` (optional, default: `0`). Set this to 0 to keep the original image dimensions.
- `--style-size <N>` Minimum size for style images, resulting in scaling of the shorter side of the style image to `N` (optional, default: `512`). Set this to 0 to keep the original image dimensions (for large style images, this will result in high (GPU) memory consumption).
- `--crop <N>` Size of the center crop applied to the content image in order to create a square image (optional, default: `0`). Setting this to 0 disables cropping.
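
To make the interaction of the size and crop options concrete, here is a minimal sketch in plain Python that computes the resulting image dimensions. It assumes the shorter side is scaled to `N` with the aspect ratio preserved and truncating integer arithmetic, and that the crop is an `N`×`N` square taken from the center; the actual rounding in `stylize.py` may differ. The helper names `scaled_size` and `center_crop_box` are hypothetical and not part of the script.

```python
def scaled_size(width, height, n):
    """Return (width, height) after scaling the shorter side to n.

    n == 0 keeps the original dimensions, mirroring the documented
    behaviour of `--content-size 0` and `--style-size 0`.
    Assumption: the aspect ratio is preserved and the longer side is truncated.
    """
    if n == 0:
        return width, height
    if width <= height:
        return n, int(n * height / width)
    return int(n * width / height), n


def center_crop_box(width, height, crop):
    """Return the (left, top, right, bottom) box of a centered square crop.

    crop == 0 disables cropping, mirroring `--crop 0`.
    """
    if crop == 0:
        return 0, 0, width, height
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


if __name__ == "__main__":
    # A 1024x768 content image with --content-size 512 and --crop 256.
    w, h = scaled_size(1024, 768, 512)
    print(w, h)                         # 682 512
    print(center_crop_box(w, h, 256))   # (213, 128, 469, 384)
```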

The styles chosen for each content image are saved to `content_style_map.json` in the output directory. Pass this file to `--style-map` to reproduce a specific dataset, or write one by hand to define an explicit content-to-style mapping. Each key must be the file name of an existing image in `--content-dir`, and each value must be a list of file names of existing images in `--style-dir`.

```json
{
  "content_image_1.jpg": ["style_1.jpg"],
  "content_image_2.jpg": ["style_2.jpg", "style_3.jpg"]
}
```
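
A hand-written map can be checked before starting a long stylization run. The following is a minimal sketch, not part of `stylize.py`: the helper names `validate_style_map` and `write_style_map` are hypothetical, and for simplicity it accepts any file found recursively in the two directories, whereas the script only considers files matching `--extensions`.

```python
import json
from pathlib import Path


def validate_style_map(style_map, content_dir, style_dir):
    """Return a list of problems; an empty list means every name was found."""
    # Assumption: files are matched by bare file name, searched recursively.
    content_names = {p.name for p in Path(content_dir).rglob("*") if p.is_file()}
    style_names = {p.name for p in Path(style_dir).rglob("*") if p.is_file()}

    problems = []
    for content_name, styles_for_content in style_map.items():
        if content_name not in content_names:
            problems.append("missing content file: %s" % content_name)
        for style_name in styles_for_content:
            if style_name not in style_names:
                problems.append("missing style file: %s" % style_name)
    return problems


def write_style_map(style_map, path):
    """Write the map as JSON so it can be passed to `--style-map`."""
    with open(path, "w") as f:
        json.dump(style_map, f, indent=2)
```

With a map such as the JSON above loaded into `style_map`, calling `validate_style_map(style_map, "content/", "styles/")` (directory names chosen for illustration) returns an empty list when all referenced files exist.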

Here is an example call:

```
2 changes: 2 additions & 0 deletions models/.gitignore
@@ -0,0 +1,2 @@
*
!.gitignore
52 changes: 47 additions & 5 deletions stylize.py
@@ -10,6 +10,7 @@
 import torchvision.transforms
 from torchvision.utils import save_image
 from tqdm import tqdm
+import json

 parser = argparse.ArgumentParser(description='This script applies the AdaIN style transfer method to arbitrary datasets.')
 parser.add_argument('--content-dir', type=str,
@@ -20,6 +21,8 @@
                     help='Directory to save the output images')
 parser.add_argument('--num-styles', type=int, default=1, help='Number of styles to \
                         create for each image (default: 1)')
+parser.add_argument('--style-map', type=str, default=None, help='An explicit content-style map \
+                        that is used instead of random sampling. Num styles is ignored in this case.')
 parser.add_argument('--alpha', type=float, default=1.0,
                     help='The weight that controls the degree of \
                           stylization. Should be between 0 and 1')
@@ -87,8 +90,42 @@ def main():
         styles += list(style_dir.rglob('*.' + ext))

     assert len(styles) > 0, 'No images with specified extensions found in style directory' + style_dir
-    styles = sorted(styles)
-    print('Found %d style images in %s' % (len(styles), style_dir))
+    style_paths = sorted(styles)
+    print('Found %d style images in %s' % (len(style_paths), style_dir))
+
+    # convert to dicts so we can access the actual file paths by the file names
+    content_files = {}
+    style_files = {}
+
+    for cp in content_paths:
+        content_files[cp.name] = cp
+
+    for sp in style_paths:
+        style_files[sp.name] = sp
+
+    content_filenames = list(content_files.keys())
+    style_filenames = list(style_files.keys())
+
+    # create the content to style mappings
+    if args.style_map is None:
+        print('Create new content-style map with %d random styles per content image.' % args.num_styles)
+        style_map = {}
+        for c in content_filenames:
+            style_list = []
+            for s in random.sample(style_filenames, args.num_styles):
+                style_list.append(s)
+            style_map[c] = style_list
+    else:
+        style_map_path = Path(args.style_map).resolve()
+        print('Load content-style map from %s' % style_map_path)
+        with open(style_map_path) as f:
+            style_map = json.load(f)
+
+    # ensure that content and style files exist (e.g. when using an explicit content-style-map)
+    for c, s_list in style_map.items():
+        assert c in content_filenames, 'Content file %s not found in content directory %s' % (c, content_dir)
+        for s in s_list:
+            assert s in style_filenames, 'Style file %s not found in style directory %s' % (s, style_dir)

     decoder = net.decoder
     vgg = net.vgg
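
The random-sampling branch of this hunk can be tried in isolation. The sketch below is a hedged, standalone restatement of that logic rather than the PR's code: the function name `build_style_map`, the `seed` parameter, and the explicit `ValueError` are assumptions added for illustration. The hunk itself calls `random.sample` on the module-level generator and does not expose a seed.

```python
import random
from pathlib import Path


def build_style_map(content_paths, style_paths, num_styles, seed=None):
    """Map each content file name to `num_styles` distinct style file names.

    As in the hunk above, files are keyed by bare file name, so two files
    with the same name in different subdirectories collapse into one entry.
    """
    # Hypothetical addition: a private generator so a seed gives repeatable maps.
    rng = random.Random(seed)
    style_names = sorted({Path(p).name for p in style_paths})
    if num_styles > len(style_names):
        # random.sample would raise ValueError here anyway; this message is clearer.
        raise ValueError(
            "requested %d styles but only %d style files are available"
            % (num_styles, len(style_names))
        )
    content_names = sorted({Path(p).name for p in content_paths})
    return {name: rng.sample(style_names, num_styles) for name in content_names}


if __name__ == "__main__":
    style_map = build_style_map(
        ["content/a/c1.jpg", "content/b/c2.jpg"],
        ["styles/s1.jpg", "styles/s2.jpg", "styles/s3.jpg"],
        num_styles=2,
        seed=0,
    )
    print(style_map)
```

Sampling without replacement means a content image never receives the same style twice, which is also why `num_styles` cannot exceed the number of available style images.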
@@ -114,11 +151,13 @@ def main():
     skipped_imgs = []

     # actual style transfer as in AdaIN
-    with tqdm(total=len(content_paths)) as pbar:
-        for content_path in content_paths:
+    with tqdm(total=len(style_map.keys())) as pbar:
+        for c, style_list in style_map.items():
+            content_path = content_files[c]
             try:
                 content_img = Image.open(content_path).convert('RGB')
-                for style_path in random.sample(styles, args.num_styles):
+                for s in style_list:
+                    style_path = style_files[s]
                     style_img = Image.open(style_path).convert('RGB')

                     content = content_tf(content_img)
@@ -156,6 +195,9 @@ def main():
             finally:
                 pbar.update(1)

+    with open(output_dir.joinpath('content_style_map.json'), 'w') as f:
+        json.dump(style_map, f)
+
     if(len(skipped_imgs) > 0):
         with open(output_dir.joinpath('skipped_imgs.txt'), 'w') as f:
             for item in skipped_imgs: