Skip to content

Commit

Permalink
gradually increase density threshold to remove floaters
Browse files Browse the repository at this point in the history
  • Loading branch information
kwea123 committed Jul 4, 2022
1 parent cf7b4db commit c53b459
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
13 changes: 11 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import numpy as np
import cv2

# data
from torch.utils.data import DataLoader
from datasets import dataset_dict

# models
from kornia.utils.grid import create_meshgrid3d
from models.networks import NGP
from models.rendering import render
from models.rendering import render, MAX_SAMPLES

# optimizer, losses
from apex.optimizers import FusedAdam
Expand All @@ -26,6 +27,8 @@
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from utils import slim_ckpt

import warnings; warnings.filterwarnings("ignore")


Expand Down Expand Up @@ -103,7 +106,9 @@ def val_dataloader(self):

def training_step(self, batch, batch_nb):
if self.global_step%self.S == 0:
self.model.update_density_grid(warmup=self.global_step<256)
a_thr = min(self.current_epoch+1, 25)/50 # alpha threshold
self.model.update_density_grid(a_thr*MAX_SAMPLES/(2*3**0.5),
warmup=self.global_step<256)

rays, rgb = batch['rays'], batch['rgb']
results = self(rays, split='train')
Expand Down Expand Up @@ -179,3 +184,7 @@ def validation_epoch_end(self, outputs):
precision=16)

trainer.fit(system, ckpt_path=hparams.ckpt_path)

# save slimmed ckpt for the last epoch
ckpt_ = slim_ckpt(f'ckpts/{hparams.exp_name}/epoch={hparams.num_epochs-1}.ckpt')
# BUG FIX: torch.save takes (obj, f) — object first, destination path second.
# The original call passed the path as the object and the state dict as the
# file target, which raises at runtime.
torch.save(ckpt_, f'ckpts/{hparams.exp_name}/epoch={hparams.num_epochs-1}_slim.ckpt')
11 changes: 10 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,13 @@ def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]):
model_dict = model.state_dict()
checkpoint_ = extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore)
model_dict.update(checkpoint_)
model.load_state_dict(model_dict)
model.load_state_dict(model_dict)


def slim_ckpt(ckpt_path):
    """Return a slimmed ``state_dict`` from a saved training checkpoint.

    Parameters
    ----------
    ckpt_path : str
        Path to a checkpoint file whose payload contains a ``'state_dict'``
        entry (as saved by pytorch-lightning).

    Returns
    -------
    dict
        The checkpoint's ``state_dict`` with entries that are not needed at
        inference time removed (``weights`` and the occupancy-grid buffers,
        which are presumably regenerated when the model is reloaded —
        TODO confirm against NGP.__init__).
    """
    # map_location='cpu' so a GPU-trained checkpoint can be slimmed on a
    # machine without (that) GPU
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = ckpt['state_dict']
    # pop unused parameters; pop(..., None) tolerates keys that are absent
    for k in ('weights', 'model.density_grid', 'model.grid_coords'):
        state_dict.pop(k, None)
    return state_dict

0 comments on commit c53b459

Please sign in to comment.