diff --git a/README.md b/README.md index 6d008b7..426cb88 100644 --- a/README.md +++ b/README.md @@ -22,12 +22,22 @@ Feel free to open an issue in case there is any question. - `--style-dir ` the top-level directory of the style images (mandatory) - `--output-dir ` the directory where the stylized dataset will be stored (optional, default: `output/`) - `--num-styles ` number of stylizations to create for each content image (optional, default: `1`) + - `--style-map ` an explicit content style map json. If provided, num-styles will be ignored (optional, default: `None`) - `--alpha ` Weight that controls the strength of stylization, should be between 0 and 1 (optional, default: `1`) - `--extensions ...` list of image extensions to scan style and content directory for (optional, default: `png, jpeg, jpg`). Note: this is case sensitive, `--extensions jpg` will not scan for files ending on `.JPG`. Image types must be compatible with PIL's `Image.open()` ([Documentation](https://pillow.readthedocs.io/en/5.1.x/handbook/image-file-formats.html)) - `--content-size ` Minimum size for content images, resulting in scaling of the shorter side of the content image to `N` (optional, default: `0`). Set this to 0 to keep the original image dimensions. - `--style-size ` Minimum size for style images, resulting in scaling of the shorter side of the style image to `N` (optional, default: `512`). Set this to 0 to keep the original image dimensions (for large style images, this will result in high (GPU) memory consumption). - `--crop ` Size for the center crop applied to the content image in order to create a squared image (optional, default 0). Setting this to 0 will disable the cropping. +The chosen styles per content image will be saved in a `content_style_map.json`. This file can be used with the `--style-map` argument to reproduce a specific dataset or to create an explicit content style mapping manually. 
Keys must be file names of existing images in the `--content-dir`; each value must be a list of file names of existing images in the `--style-dir`. + +```json +{ + "content_image_1.jpg": ["style_1.jpg"], + "content_image_2.jpg": ["style_2.jpg", "style_3.jpg"] +} +``` + Here is an example call: ``` diff --git a/models/.gitignore b/models/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/models/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/stylize.py b/stylize.py index ce6b179..9329f2e 100644 --- a/stylize.py +++ b/stylize.py @@ -10,6 +10,7 @@ import torchvision.transforms from torchvision.utils import save_image from tqdm import tqdm +import json parser = argparse.ArgumentParser(description='This script applies the AdaIN style transfer method to arbitrary datasets.') parser.add_argument('--content-dir', type=str, @@ -20,6 +21,8 @@ help='Directory to save the output images') parser.add_argument('--num-styles', type=int, default=1, help='Number of styles to \ create for each image (default: 1)') +parser.add_argument('--style-map', type=str, default=None, help='An explicit content-style map \ + that is used instead of random sampling. Num styles is ignored in this case.') parser.add_argument('--alpha', type=float, default=1.0, help='The weight that controls the degree of \ stylization. Should be between 0 and 1') @@ -87,8 +90,42 @@ def main(): styles += list(style_dir.rglob('*.' 
+ ext)) assert len(styles) > 0, 'No images with specified extensions found in style directory' + style_dir - styles = sorted(styles) - print('Found %d style images in %s' % (len(styles), style_dir)) + style_paths = sorted(styles) + print('Found %d style images in %s' % (len(style_paths), style_dir)) + + # convert to dicts so we can access the actual file paths by the file names + content_files = {} + style_files = {} + + for cp in content_paths: + content_files[cp.name] = cp + + for sp in style_paths: + style_files[sp.name] = sp + + content_filenames = list(content_files.keys()) + style_filenames = list(style_files.keys()) + + # create the content to style mappings + if args.style_map is None: + print('Create new content-style map with %d random styles per content image.' % args.num_styles) + style_map = {} + for c in content_filenames: + style_list = [] + for s in random.sample(style_filenames, args.num_styles): + style_list.append(s) + style_map[c] = style_list + else: + style_map_path = Path(args.style_map).resolve() + print('Load content-style map from %s' % style_map_path) + with open(style_map_path) as f: + style_map = json.load(f) + + # ensure that content and style files exist (e.g. 
when using an explicit content-style-map) + for c, s_list in style_map.items(): + assert c in content_filenames, 'Content file %s not found in content directory %s' % (c, content_dir) + for s in s_list: + assert s in style_filenames, 'Style file %s not found in style directory %s' % (s, style_dir) decoder = net.decoder vgg = net.vgg @@ -114,11 +151,13 @@ def main(): skipped_imgs = [] # actual style transfer as in AdaIN - with tqdm(total=len(content_paths)) as pbar: - for content_path in content_paths: + with tqdm(total=len(style_map.keys())) as pbar: + for c, style_list in style_map.items(): + content_path = content_files[c] try: content_img = Image.open(content_path).convert('RGB') - for style_path in random.sample(styles, args.num_styles): + for s in style_list: + style_path = style_files[s] style_img = Image.open(style_path).convert('RGB') content = content_tf(content_img) @@ -156,6 +195,9 @@ def main(): finally: pbar.update(1) + with open(output_dir.joinpath('content_style_map.json'), 'w') as f: + json.dump(style_map, f) + if(len(skipped_imgs) > 0): with open(output_dir.joinpath('skipped_imgs.txt'), 'w') as f: for item in skipped_imgs: