@@ -16,12 +16,17 @@ class SwAeController:
16
16
transform : Union [torchvision .transforms .Compose , None ] = None
17
17
global_sty : Union [torch .Tensor , None ] = None
18
18
global_tex : Union [torch .Tensor , None ] = None
19
- tex_path : str = None
19
+ structure_path : Union [ str , None ] = None
20
20
cache : Dict = {}
21
21
sty_argumentation : OrderedDict = OrderedDict ()
22
22
23
23
@timing
24
24
def __init__ (self , name : str ) -> None :
25
+ """Initilise the model and other options
26
+
27
+ Args:
28
+ name (str): [description]
29
+ """
25
30
self .opt = Global_config (isTrain = False , name = name )
26
31
self .model = SwappingAutoencoderModel (self .opt )
27
32
self .model .initialize ()
@@ -34,29 +39,43 @@ def _get_transform(self) -> torchvision.transforms.Compose:
34
39
return get_transform (self .opt , ** kwarg )
35
40
36
41
@timing
37
- def set_size (self , size : int ):
38
- if size < 0 :
39
- raise ValueError ("Can not set negetive size" )
42
+ def set_size (self , size : int ) -> None :
43
+ """Sets transform to load images with the `size`. Output is also of width `size`. It must be greater than 128 and must be a multiple of 4.
44
+
45
+
46
+ Args:
47
+ size (int): size of the ouput image.
48
+
49
+ Raises:
50
+ ValueError: if the size is not a valid integer.
51
+ """
52
+ if size < 0 or size % 2 == 1 or size < 128 :
53
+ raise ValueError ("invalid size" )
40
54
self .load_size = size
41
55
42
56
# need to reload transforms with new size
43
57
self .transform = self ._get_transform ()
44
58
45
59
@timing
46
- def load_image (self , path ) -> torch .Tensor :
60
+ def _load_image (self , path ) -> torch .Tensor :
47
61
img = Image .open (path ).convert ("RGB" )
48
62
if self .transform == None :
49
63
self .transform = self ._get_transform ()
50
64
tensor = self .transform (img ).unsqueeze (0 )
51
65
return tensor
52
66
53
67
@timing
54
- def set_tex (self , tex_path ):
55
- if tex_path == None and self .tex_path == tex_path :
68
+ def set_structure (self , structure_path : str ) -> None :
69
+ """set the structure, must be called before compute(). Doesn't cache the image. But, sets the noise input for the model
70
+
71
+ Args:
72
+ structure_path (str): path to the structure image
73
+ """
74
+ if structure_path == None and self .structure_path == structure_path :
56
75
return
57
- if self .tex_path == None :
58
- self .tex_path = tex_path
59
- source = self .load_image ( tex_path ).to ("cuda" )
76
+ if self .structure_path == None :
77
+ self .structure_path = structure_path
78
+ source = self ._load_image ( structure_path ).to ("cuda" )
60
79
with torch .no_grad ():
61
80
self .model (sample_image = source , command = "fix_noise" )
62
81
self .global_tex , self .global_sty = self .encode (source )
@@ -72,13 +91,19 @@ def encode(self, im: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
72
91
def load_encode_cache (self , path : str ) -> Tuple [torch .Tensor , torch .Tensor ]:
73
92
if path in self .cache :
74
93
return self .cache [path ]
75
- im = self .load_image (path )
94
+ im = self ._load_image (path )
76
95
tex , sty = self .encode (im )
77
96
self .cache [path ] = tex , sty
78
97
return tex , sty
79
98
80
99
@timing
81
- def mix_style (self , style_path , alpha ):
100
+ def mix_style (self , style_path : str , alpha : float ) -> None :
101
+ """Mixes the style of the image given with the current structure image by the factor of alpha. Caches the encoded image. actual mixing happens when `compute()` is called
102
+
103
+ Args:
104
+ style_path (str): Path to the image whose style you want to mix
105
+ alpha (float): Value of mix factor. 0 would remove this image from the mix, 1 implies using NONE of the original styles
106
+ """
82
107
if alpha == 0 :
83
108
if style_path in self .sty_argumentation :
84
109
del self .sty_argumentation [style_path ]
@@ -90,7 +115,12 @@ def mix_style(self, style_path, alpha):
90
115
self .sty_argumentation [style_path ] = alpha
91
116
92
117
@timing
93
- def compute (self ):
118
+ def compute (self ) -> torch .Tensor :
119
+ """Computes the output of the operations performed by the mix_style and gives the output image
120
+
121
+ Returns:
122
+ torch.Tensor: output tensor with the shape Tensor with shape (1, 3, h, w) where `h` and `w` are height and width.
123
+ """
94
124
assert self .global_sty != None and self .global_tex != None
95
125
torch .cuda .empty_cache ()
96
126
local_sty = self .global_sty .clone ()
0 commit comments