
Commit

add normal network training
YuliangXiu committed Aug 2, 2022
1 parent b571781 commit cf46a6b
Showing 8 changed files with 125 additions and 105 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -37,7 +37,7 @@
<br />

## News :triangular_flag_on_post:
- [2022/07/30] <a href="https://huggingface.co/spaces/Yuliang/ICON" style='padding-left: 0.5rem;'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-orange'></a> <a href='https://colab.research.google.com/drive/1-AWeWhPvCTBX0KfMtgtMk10uPU05ihoA?usp=sharing' style='padding-left: 0.5rem;'><img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Google Colab'></a>
- [2022/07/30] <a href="https://huggingface.co/spaces/Yuliang/ICON" style='padding-left: 0.5rem;'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-orange'></a> <a href='https://colab.research.google.com/drive/1-AWeWhPvCTBX0KfMtgtMk10uPU05ihoA?usp=sharing' style='padding-left: 0.5rem;'><img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Google Colab'></a> are both available.
- [2022/07/26] New cloth-refinement module is released, try `-loop_cloth`.
- [2022/06/13] ETH Zürich students from 3DV course create an add-on for [garment-extraction](docs/garment-extraction.md).
- [2022/05/16] <a href="https://github.com/Arthur151/ROMP">BEV</a> is supported as optional HPS by <a href="https://scholar.google.com/citations?hl=en&user=fkGxgrsAAAAJ">Yu Sun</a>, see [commit #060e265](https://github.com/YuliangXiu/ICON/commit/060e265bd253c6a34e65c9d0a5288c6d7ffaf68e).
2 changes: 2 additions & 0 deletions apps/Normal.py
@@ -10,6 +10,8 @@
torch.backends.cudnn.benchmark = True

logging.getLogger("lightning").setLevel(logging.ERROR)
import warnings
warnings.filterwarnings("ignore")


class Normal(pl.LightningModule):
16 changes: 5 additions & 11 deletions apps/train-normal.py
@@ -6,7 +6,6 @@
from pytorch_lightning import loggers as pl_loggers
import pytorch_lightning as pl
from lib.common.config import get_cfg_defaults
import sys
import os
import os.path as osp
import argparse
@@ -21,8 +20,7 @@
parser.add_argument(
"-cfg", "--config_file", type=str, help="path of the yaml config file"
)
argv = sys.argv[1: sys.argv.index("--")]
args = parser.parse_args(argv)
args = parser.parse_args()
cfg = get_cfg_defaults()
cfg.merge_from_file(args.config_file)
cfg.freeze()
@@ -56,19 +54,14 @@
freq_eval = cfg.fast_dev

trainer_kwargs = {
# 'accelerator': 'dp',
# 'amp_level': 'O2',
# 'precision': 16,
# 'weights_summary': 'top',
# 'stochastic_weight_avg': False,
"gpus": cfg.gpus,
"auto_select_gpus": True,
"reload_dataloaders_every_epoch": True,
"sync_batchnorm": True,
"benchmark": True,
"automatic_optimization": False,
"logger": tb_logger,
"track_grad_norm": -1,
"automatic_optimization": False,
"num_sanity_val_steps": cfg.num_sanity_val_steps,
"checkpoint_callback": checkpoint,
"limit_train_batches": cfg.dataset.train_bsize,
@@ -94,14 +87,15 @@
else freq_eval,
}
)

if cfg.overfit:
cfg_show_list = ["freq_show_train", 200.0, "freq_show_val", 10.0]
else:
cfg_show_list = [
"freq_show_train",
cfg.freq_show_train * train_len / cfg.batch_size,
cfg.freq_show_train * train_len // cfg.batch_size,
"freq_show_val",
max(cfg.freq_show_val * val_len / cfg.batch_size, 1.0),
max(cfg.freq_show_val * val_len // cfg.batch_size, 1.0),
]

cfg.merge_from_list(cfg_show_list)
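For context, this kwargs dict is what apps/train-normal.py ultimately unpacks into a PyTorch Lightning `Trainer`. A minimal sketch of that step is below; it assumes an older pytorch-lightning release (around 1.2.x, which still accepts arguments such as `auto_select_gpus` and `checkpoint_callback`), and the model/dataloader wiring shown in comments is not part of this hunk.

```python
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

# Mirrors the "logger" entry above; save_dir/name are illustrative values.
tb_logger = pl_loggers.TensorBoardLogger(save_dir="./results", name="normal")

trainer_kwargs = {
    "gpus": [0],                 # cfg.gpus in the script
    "auto_select_gpus": True,
    "sync_batchnorm": True,
    "benchmark": True,
    "logger": tb_logger,
}

trainer = pl.Trainer(**trainer_kwargs)
# model = Normal(cfg)        # LightningModule defined in apps/Normal.py (constructor assumed)
# trainer.fit(model)         # dataloaders attached by the module or passed explicitly
```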
41 changes: 41 additions & 0 deletions configs/train/normal.yaml
@@ -0,0 +1,41 @@
name: normal
ckpt_dir: "./data/ckpt/"
resume_path: "./data/ckpt/normal.ckpt"
results_path: "./results"

dataset:
  root: "./data/"
  rotation_num: 36
  train_bsize: 1.0
  val_bsize: 1.0
  test_bsize: 1.0
  types: ["thuman2"]
  scales: [100.0]

net:
  in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))

lr_N: 1e-4
weight_decay: 0.0
momentum: 0.0
batch_size: 4
num_threads: 4
gpus: [0]
test_gpus: [0]

fast_dev: 0
resume: False
test_mode: False
num_sanity_val_steps: 1

optim: Adam

# training (batch=4, set=agora, rot-6)
overfit: False
num_epoch: 20
freq_show_train: 0.1
freq_show_val: 0.01
freq_plot: 0.01
freq_eval: 0.1
schedule: [18]
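
A short sketch of how this config is consumed, mirroring the `get_cfg_defaults()` / `merge_from_file()` calls in apps/train-normal.py above; the printed fields are just examples and assume the repo defaults already define these keys.

```python
from lib.common.config import get_cfg_defaults

cfg = get_cfg_defaults()                            # yacs CfgNode with the repo defaults
cfg.merge_from_file("./configs/train/normal.yaml")  # override defaults with this file
cfg.freeze()

print(cfg.name, cfg.batch_size)                     # e.g. "normal" and 4
print(cfg.dataset.types, cfg.dataset.rotation_num)  # e.g. ['thuman2'] and 36
```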
3 changes: 3 additions & 0 deletions docs/training.md
@@ -40,6 +40,9 @@ CUDA_VISIBLE_DEVICES=0 python -m apps.train -cfg ./configs/train/pifu.yaml

# PaMIR (name: pamir)
CUDA_VISIBLE_DEVICES=0 python -m apps.train -cfg ./configs/train/pamir.yaml

# Normal Network (name: normal)
CUDA_VISIBLE_DEVICES=0 python -m apps.train-normal -cfg ./configs/train/normal.yaml
```

## Tensorboard
144 changes: 53 additions & 91 deletions lib/dataset/NormalDataset.py
@@ -15,9 +15,11 @@
#
# Contact: [email protected]

import random
import os.path as osp
import numpy as np
from PIL import Image
from termcolor import colored
import torchvision.transforms as transforms


@@ -26,14 +28,13 @@ def __init__(self, cfg, split='train'):

self.split = split
self.root = cfg.root
self.bsize = cfg.batch_size
self.overfit = cfg.overfit

self.opt = cfg.dataset
self.datasets = self.opt.types
self.input_size = self.opt.input_size
self.set_splits = self.opt.set_splits
self.scales = self.opt.scales
self.pifu = self.opt.pifu

# input data types and dimensions
self.in_nml = [item[0] for item in cfg.net.in_nml]
@@ -44,19 +45,18 @@ def __init__(self, cfg, split='train'):
if self.split != 'train':
self.rotations = range(0, 360, 120)
else:
self.rotations = np.arange(0, 360, 360 /
self.opt.rotation_num).astype(np.int)
self.rotations = np.arange(
0, 360, 360//self.opt.rotation_num).astype(np.int)

self.datasets_dict = {}

for dataset_id, dataset in enumerate(self.datasets):
dataset_dir = osp.join(self.root, dataset, "smplx")

dataset_dir = osp.join(self.root, dataset)

self.datasets_dict[dataset] = {
"subjects":
np.loadtxt(osp.join(self.root, dataset, "all.txt"), dtype=str),
"path":
dataset_dir,
"scale":
self.scales[dataset_id]
"subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"), dtype=str),
"scale": self.scales[dataset_id]
}

self.subject_list = self.get_subject_list(split)
@@ -81,65 +81,35 @@ def get_subject_list(self, split):

for dataset in self.datasets:

if self.pifu:
txt = osp.join(self.root, dataset, f'{split}_pifu.txt')
else:
txt = osp.join(self.root, dataset, f'{split}.txt')

if osp.exists(txt):
print(f"load from {txt}")
subject_list += sorted(np.loadtxt(txt, dtype=str).tolist())

if self.pifu:
miss_pifu = sorted(
np.loadtxt(osp.join(self.root, dataset,
"miss_pifu.txt"),
dtype=str).tolist())
subject_list = [
subject for subject in subject_list
if subject not in miss_pifu
]
subject_list = [
"renderpeople/" + subject for subject in subject_list
]
split_txt = osp.join(self.root, dataset, f'{split}.txt')

if osp.exists(split_txt):
print(f"load from {split_txt}")
subject_list += np.loadtxt(split_txt, dtype=str).tolist()
else:
train_txt = osp.join(self.root, dataset, 'train.txt')
val_txt = osp.join(self.root, dataset, 'val.txt')
test_txt = osp.join(self.root, dataset, 'test.txt')

print(
f"generate lists of [train, val, test] \n {train_txt} \n {val_txt} \n {test_txt} \n"
)

split_txt = osp.join(self.root, dataset, f'{split}.txt')

subjects = self.datasets_dict[dataset]['subjects']
train_split = int(len(subjects) * self.set_splits[0])
val_split = int(
len(subjects) * self.set_splits[1]) + train_split

with open(train_txt, "w") as f:
f.write("\n".join(dataset + "/" + item
for item in subjects[:train_split]))
with open(val_txt, "w") as f:
f.write("\n".join(
dataset + "/" + item
for item in subjects[train_split:val_split]))
with open(test_txt, "w") as f:
f.write("\n".join(dataset + "/" + item
for item in subjects[val_split:]))

subject_list += sorted(
np.loadtxt(split_txt, dtype=str).tolist())

bug_list = sorted(
np.loadtxt(osp.join(self.root, 'bug.txt'), dtype=str).tolist())

subject_list = [
subject for subject in subject_list if (subject not in bug_list)
]
full_txt = osp.join(self.root, dataset, 'all.txt')
print(f"split {full_txt} into train/val/test")

full_lst = np.loadtxt(full_txt, dtype=str)
full_lst = [dataset+"/"+item for item in full_lst]
[train_lst, test_lst, val_lst] = np.split(
full_lst, [500, 500+5, ])

np.savetxt(full_txt.replace(
"all", "train"), train_lst, fmt="%s")
np.savetxt(full_txt.replace("all", "test"), test_lst, fmt="%s")
np.savetxt(full_txt.replace("all", "val"), val_lst, fmt="%s")

print(f"load from {split_txt}")
subject_list += np.loadtxt(split_txt, dtype=str).tolist()

if self.split != 'test':
subject_list += subject_list[:self.bsize -
len(subject_list) % self.bsize]
print(colored(f"total: {len(subject_list)}", "yellow"))
random.shuffle(subject_list)

# subject_list = ["thuman2/0008"]
return subject_list

def __len__(self):
@@ -155,46 +125,38 @@ def __getitem__(self, index):
mid = index // len(self.rotations)

rotation = self.rotations[rid]

# choose specific test sets
subject = self.subject_list[mid]

subject_render = "/".join(
[subject.split("/")[0] + "_12views",
subject.split("/")[1]])
subject = self.subject_list[mid].split("/")[1]
dataset = self.subject_list[mid].split("/")[0]
render_folder = "/".join([dataset +
f"_{self.opt.rotation_num}views", subject])

# setup paths
data_dict = {
'dataset':
subject.split("/")[0],
'subject':
subject,
'rotation':
rotation,
'image_path':
osp.join(self.root, subject_render, 'render',
f'{rotation:03d}.png')
'dataset': dataset,
'subject': subject,
'rotation': rotation,
'scale': self.datasets_dict[dataset]["scale"],
'image_path': osp.join(self.root, render_folder, 'render', f'{rotation:03d}.png')
}

# image/normal/depth loader
for name, channel in zip(self.in_total, self.in_total_dim):

if name != 'image':
if f'{name}_path' not in data_dict.keys():
data_dict.update({
f'{name}_path':
osp.join(self.root, subject_render, name,
f'{rotation:03d}.png')
f'{name}_path': osp.join(self.root, render_folder, name, f'{rotation:03d}.png')
})

# tensor update
data_dict.update({
name:
self.imagepath2tensor(data_dict[f'{name}_path'],
channel,
inv='depth_B' in name)
name: self.imagepath2tensor(
data_dict[f'{name}_path'], channel, inv=False)
})

path_keys = [
key for key in data_dict.keys() if '_path' in key or '_dir' in key
]

for key in path_keys:
del data_dict[key]

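A hypothetical usage sketch of the reworked dataset, assuming the class is named `NormalDataset` after its file and that the config above has already been loaded as `cfg`; the DataLoader settings are illustrative, not taken from the commit.

```python
from torch.utils.data import DataLoader
from lib.dataset.NormalDataset import NormalDataset

train_set = NormalDataset(cfg, split="train")
# one sample per (subject, rotation); rotations come from cfg.dataset.rotation_num
print(len(train_set))

loader = DataLoader(train_set, batch_size=cfg.batch_size,
                    num_workers=cfg.num_threads, shuffle=True)
batch = next(iter(loader))
# batch is a dict keyed by the names in cfg.net.in_nml ('image', 'T_normal_F',
# 'T_normal_B'), each a [B, 3, H, W] float tensor, plus metadata such as
# 'dataset', 'subject', 'rotation' and 'scale'.
```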
18 changes: 18 additions & 0 deletions lib/dataset/PIFuDataset.py
@@ -1,3 +1,21 @@

# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]


from lib.renderer.mesh import load_fit_body
from lib.dataset.hoppeMesh import HoppeMesh
from lib.dataset.body_model import TetraSMPLModel
4 changes: 2 additions & 2 deletions lib/net/NormalNet.py
@@ -85,8 +85,8 @@ def forward(self, in_tensor):
nmlB = self.netB(torch.cat(inB_list, dim=1))

# ||normal|| == 1
nmlF /= torch.norm(nmlF, dim=1)
nmlB /= torch.norm(nmlB, dim=1)
nmlF = nmlF / torch.norm(nmlF, dim=1, keepdim=True)
nmlB = nmlB / torch.norm(nmlB, dim=1, keepdim=True)

# output: float_arr [-1,1] with [B, C, H, W]

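The `keepdim=True` fix above matters for broadcasting: `torch.norm(x, dim=1)` on a `[B, C, H, W]` tensor returns `[B, H, W]`, which does not line up with the channel dimension, while `keepdim=True` keeps a `[B, 1, H, W]` shape that divides each pixel's channels by its own norm. A small self-contained check (shapes are illustrative):

```python
import torch

nml = torch.randn(4, 3, 8, 8)                    # [B, C, H, W] raw normal prediction
norm = torch.norm(nml, dim=1, keepdim=True)      # [B, 1, H, W], broadcasts over channels
unit = nml / norm                                # per-pixel unit-length normals

# every pixel now has ||normal|| == 1 (up to float error)
assert torch.allclose(unit.norm(dim=1), torch.ones(4, 8, 8), atol=1e-5)

# without keepdim the shapes [4, 3, 8, 8] / [4, 8, 8] do not broadcast and raise
```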
