Skip to content

Commit 9f2553f

Browse files
committed
refactor: misc variable changes and docstrings
1 parent 81bf988 commit 9f2553f

File tree

3 files changed

+50
-17
lines changed

3 files changed

+50
-17
lines changed

api/__init__.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,17 @@ class SwAeController:
1616
transform: Union[torchvision.transforms.Compose, None] = None
1717
global_sty: Union[torch.Tensor, None] = None
1818
global_tex: Union[torch.Tensor, None] = None
19-
tex_path: str = None
19+
structure_path: Union[str, None] = None
2020
cache: Dict = {}
2121
sty_argumentation: OrderedDict = OrderedDict()
2222

2323
@timing
2424
def __init__(self, name: str) -> None:
25+
"""Initialise the model and other options
26+
27+
Args:
28+
name (str): [description]
29+
"""
2530
self.opt = Global_config(isTrain=False, name=name)
2631
self.model = SwappingAutoencoderModel(self.opt)
2732
self.model.initialize()
@@ -34,29 +39,43 @@ def _get_transform(self) -> torchvision.transforms.Compose:
3439
return get_transform(self.opt, **kwarg)
3540

3641
@timing
37-
def set_size(self, size: int):
38-
if size < 0:
39-
raise ValueError("Can not set negetive size")
42+
def set_size(self, size: int) -> None:
43+
"""Sets the transform to load images at `size`; the output also has width `size`. `size` must be at least 128 and even (the check below does not enforce a multiple of 4).
44+
45+
46+
Args:
47+
size (int): size of the output image.
48+
49+
Raises:
50+
ValueError: if the size is negative, odd, or smaller than 128.
51+
"""
52+
if size < 0 or size % 2 == 1 or size < 128:
53+
raise ValueError("invalid size")
4054
self.load_size = size
4155

4256
# need to reload transforms with new size
4357
self.transform = self._get_transform()
4458

4559
@timing
46-
def load_image(self, path) -> torch.Tensor:
60+
def _load_image(self, path) -> torch.Tensor:
4761
img = Image.open(path).convert("RGB")
4862
if self.transform == None:
4963
self.transform = self._get_transform()
5064
tensor = self.transform(img).unsqueeze(0)
5165
return tensor
5266

5367
@timing
54-
def set_tex(self, tex_path):
55-
if tex_path == None and self.tex_path == tex_path:
68+
def set_structure(self, structure_path: str) -> None:
69+
"""Set the structure image; must be called before `compute()`. Does not cache the image, but sets the noise input for the model.
70+
71+
Args:
72+
structure_path (str): path to the structure image
73+
"""
74+
if structure_path == None and self.structure_path == structure_path:
5675
return
57-
if self.tex_path == None:
58-
self.tex_path = tex_path
59-
source = self.load_image(tex_path).to("cuda")
76+
if self.structure_path == None:
77+
self.structure_path = structure_path
78+
source = self._load_image(structure_path).to("cuda")
6079
with torch.no_grad():
6180
self.model(sample_image=source, command="fix_noise")
6281
self.global_tex, self.global_sty = self.encode(source)
@@ -72,13 +91,19 @@ def encode(self, im: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
7291
def load_encode_cache(self, path: str) -> Tuple[torch.Tensor, torch.Tensor]:
7392
if path in self.cache:
7493
return self.cache[path]
75-
im = self.load_image(path)
94+
im = self._load_image(path)
7695
tex, sty = self.encode(im)
7796
self.cache[path] = tex, sty
7897
return tex, sty
7998

8099
@timing
81-
def mix_style(self, style_path, alpha):
100+
def mix_style(self, style_path: str, alpha: float) -> None:
101+
"""Mixes the style of the given image with the current structure image by a factor of alpha. Caches the encoded image. The actual mixing happens when `compute()` is called.
102+
103+
Args:
104+
style_path (str): Path to the image whose style you want to mix
105+
alpha (float): Value of mix factor. 0 would remove this image from the mix, 1 implies using NONE of the original styles
106+
"""
82107
if alpha == 0:
83108
if style_path in self.sty_argumentation:
84109
del self.sty_argumentation[style_path]
@@ -90,7 +115,12 @@ def mix_style(self, style_path, alpha):
90115
self.sty_argumentation[style_path] = alpha
91116

92117
@timing
93-
def compute(self):
118+
def compute(self) -> torch.Tensor:
119+
"""Computes the output of the operations performed by the mix_style and gives the output image
120+
121+
Returns:
122+
torch.Tensor: output tensor with shape (1, 3, h, w), where `h` and `w` are height and width.
123+
"""
94124
assert self.global_sty != None and self.global_tex != None
95125
torch.cuda.empty_cache()
96126
local_sty = self.global_sty.clone()

api/__main__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
parser = argparse.ArgumentParser(
99
description="Process some images", prog="api", formatter_class=argparse.ArgumentDefaultsHelpFormatter
1010
)
11-
parser.add_argument("--version", action="version", version="%(prog)s 1.0")
11+
12+
__version__ = "1.0"
13+
14+
parser.add_argument("--version", action="version", version=__version__)
1215

1316
parser.add_argument("img1", metavar="Structure", help="Path to Structure one")
1417
parser.add_argument("img2", metavar="Style", help="Path to Style two")
@@ -25,7 +28,7 @@
2528

2629
SAE = SwAeController(args.model)
2730
SAE.set_size(512)
28-
SAE.set_tex(args.img1)
31+
SAE.set_structure(args.img1)
2932
SAE.mix_style(args.img2, args.alpha)
3033

3134
output_image = tensor_to_PIL(SAE.compute()[0])

streamlit_interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def set_images():
4040
"Size:", ("128", "256", "512", "640"), format_func=lambda x: f"{x} px", help="Size of the ouput image"
4141
)
4242
st.session_state.SAE.set_size(int(size))
43-
st.session_state.SAE.set_tex(st.session_state.STR)
43+
st.session_state.SAE.set_structure(st.session_state.STR)
4444
opt = st.sidebar.slider(
4545
"Options to load", 3, len(st.session_state.images), help="No. of option images to load for style mix", step=2
4646
)
@@ -57,7 +57,7 @@ def set_images():
5757
help="Choose the structure image from the options below",
5858
)
5959
st.session_state.STR = IM
60-
st.session_state.SAE.set_tex(IM)
60+
st.session_state.SAE.set_structure(IM)
6161
st.image(st.session_state.STR, "Orignal Structure Image")
6262

6363

0 commit comments

Comments
 (0)