From d25e7de841299c2e8c58a978e1b7510576531682 Mon Sep 17 00:00:00 2001
From: leeyunjai82
Date: Fri, 10 Mar 2023 13:01:27 +0900
Subject: [PATCH 1/2] update

---
 UGATIT.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/UGATIT.py b/UGATIT.py
index c1859a1c..8be9bddf 100644
--- a/UGATIT.py
+++ b/UGATIT.py
@@ -153,13 +153,13 @@ def train(self):
                 real_A, _ = trainA_iter.next()
             except:
                 trainA_iter = iter(self.trainA_loader)
-                real_A, _ = trainA_iter.next()
+                real_A, _ = next(trainA_iter)

             try:
                 real_B, _ = trainB_iter.next()
             except:
                 trainB_iter = iter(self.trainB_loader)
-                real_B, _ = trainB_iter.next()
+                real_B, _ = next(trainB_iter)

             real_A, real_B = real_A.to(self.device), real_B.to(self.device)

@@ -254,13 +254,13 @@ def train(self):
                         real_A, _ = trainA_iter.next()
                     except:
                         trainA_iter = iter(self.trainA_loader)
-                        real_A, _ = trainA_iter.next()
+                        real_A, _ = next(trainA_iter)

                     try:
                         real_B, _ = trainB_iter.next()
                     except:
                         trainB_iter = iter(self.trainB_loader)
-                        real_B, _ = trainB_iter.next()
+                        real_B, _ = next(trainB_iter)
                     real_A, real_B = real_A.to(self.device), real_B.to(self.device)

                     fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
@@ -293,13 +293,13 @@ def train(self):
                         real_A, _ = testA_iter.next()
                     except:
                         testA_iter = iter(self.testA_loader)
-                        real_A, _ = testA_iter.next()
+                        real_A, _ = next(testA_iter)

                     try:
                         real_B, _ = testB_iter.next()
                     except:
                         testB_iter = iter(self.testB_loader)
-                        real_B, _ = testB_iter.next()
+                        real_B, _ = next(testB_iter)
                     real_A, real_B = real_A.to(self.device), real_B.to(self.device)

                     fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
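A note on the hunks above: only the calls in the except branches are converted to the built-in next(); the calls inside the try blocks still use the Iterator.next() method, which recent PyTorch DataLoader iterators no longer provide, so on those versions the except path (which rebuilds the iterator) ends up running on every fetch. A fully converted version could use a small helper along the following lines. This is an illustrative sketch, not part of the patch, and the fetch_batch name is hypothetical:

    def fetch_batch(data_iter, loader):
        # Return the next (image, label) batch; restart the loader only when it is exhausted.
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            batch = next(data_iter)
        return batch, data_iter

    # Usage inside train(), replacing each try/except pair above:
    #   (real_A, _), trainA_iter = fetch_batch(trainA_iter, self.trainA_loader)
    #   (real_B, _), trainB_iter = fetch_batch(trainB_iter, self.trainB_loader)
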
From 87b85dff4bde1cff309c06a34e043cf09a127664 Mon Sep 17 00:00:00 2001
From: leeyunjai82
Date: Thu, 16 Mar 2023 11:38:00 +0900
Subject: [PATCH 2/2] update

---
 UGATIT.py | 27 +++++++++++++++++++++++++++
 main.py   | 17 ++++++++++++++---
 2 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/UGATIT.py b/UGATIT.py
index 8be9bddf..4ad3ca46 100644
--- a/UGATIT.py
+++ b/UGATIT.py
@@ -5,6 +5,8 @@
 from networks import *
 from utils import *
 from glob import glob
+from PIL import Image
+import cv2

 class UGATIT(object) :
     def __init__(self, args):
@@ -363,8 +365,33 @@ def load(self, dir, step):
         self.disLA.load_state_dict(params['disLA'])
         self.disLB.load_state_dict(params['disLB'])

+    def build_model_for_demo(self):
+        """ Define Generator, Discriminator """
+        self.genA2B = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.ch, n_blocks=self.n_res, img_size=self.img_size, light=self.light).to(self.device)
+        # self.genB2A = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.ch, n_blocks=self.n_res, img_size=self.img_size, light=self.light).to(self.device)
+        # self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
+        # self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
+        # self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
+        # self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
+        params = torch.load('results/selfie2anime_params_latest.pt')
+        self.genA2B.load_state_dict(params['genA2B'])
+
+    def inference(self, d):
+        #d = cv2.imread("/home/circulus/api-test/vision-test/face3.jpg")
+        h, w, _ = d.shape
+        d = cv2.resize(d, (256,256))
+        d = (d)/127.5 -1
+        d = np.transpose(d[np.newaxis,:,:,:], (0,3,1,2)).astype(np.float32)
+        d = torch.from_numpy(d).to(self.device)
+
+        fake_A2B, _, fake_A2B_heatmap = self.genA2B(d)
+        img = cv2.resize(RGB2BGR(tensor2numpy(denorm(fake_A2B[0])))*255.0, (w,h))
+        print(" [*] Load SUCCESS")
+        return img
+
     def test(self):
         model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
+
         if not len(model_list) == 0:
             model_list.sort()
             iter = int(model_list[-1].split('_')[-1].split('.')[0])
diff --git a/main.py b/main.py
index 1bdcfd21..072cf83e 100644
--- a/main.py
+++ b/main.py
@@ -68,16 +68,27 @@ def main():

     # open session
     gan = UGATIT(args)
-    # build graph
-    gan.build_model()
-
     if args.phase == 'train' :
+        # build graph
+        gan.build_model()
         gan.train()
         print(" [*] Training finished!")

     if args.phase == 'test' :
+        gan.build_model()
         gan.test()
         print(" [*] Test finished!")

+    if args.phase == 'demo':
+        gan.build_model_for_demo()
+        #cap = cv2.VideoCapture(0)
+        #_, img = cap.read()
+        img = cv2.imread("/home/circulus/project/Face-Attribute-Classification/img_align_celeba/000156.jpg")
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+        res = gan.inference(img)
+        cv2.imwrite("AAA.jpg", res)
+        print(" [*] Demo finished!")
+
 if __name__ == '__main__':
     main()
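
For reference, a minimal sketch of driving the demo path added above without going through main.py. This is an assumption-laden example, not part of the patch: the input and output file names are placeholders, build_model_for_demo() hard-codes the checkpoint path results/selfie2anime_params_latest.pt, and the generator settings taken from the argparse defaults (--light, --ch, --n_res, --img_size) must match the checkpoint. inference() expects an RGB array and returns a BGR image resized back to the input dimensions, so it can be written out directly with cv2.imwrite, as the patched main.py does:

    import cv2
    from main import parse_args
    from UGATIT import UGATIT

    args = parse_args()                  # argparse defaults unless overridden on the command line
    gan = UGATIT(args)
    gan.build_model_for_demo()           # loads results/selfie2anime_params_latest.pt into genA2B only

    bgr = cv2.imread("input.jpg")                 # placeholder test image (BGR, uint8)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)    # inference() expects RGB, matching the patched main.py
    out = gan.inference(rgb)                      # BGR image at the original input size
    cv2.imwrite("output.jpg", out)                # placeholder output path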