From c53b4590789698a2271d226ee897473fe4fc6481 Mon Sep 17 00:00:00 2001 From: kwea123 Date: Mon, 4 Jul 2022 16:07:45 +0900 Subject: [PATCH] gradually increase density threshold to remove floater --- train.py | 13 +++++++++++-- utils.py | 11 ++++++++++- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 8c3ba7bb..0d4197fc 100644 --- a/train.py +++ b/train.py @@ -5,13 +5,14 @@ import numpy as np import cv2 +# data from torch.utils.data import DataLoader from datasets import dataset_dict # models from kornia.utils.grid import create_meshgrid3d from models.networks import NGP -from models.rendering import render +from models.rendering import render, MAX_SAMPLES # optimizer, losses from apex.optimizers import FusedAdam @@ -26,6 +27,8 @@ from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger +from utils import slim_ckpt + import warnings; warnings.filterwarnings("ignore") @@ -103,7 +106,9 @@ def val_dataloader(self): def training_step(self, batch, batch_nb): if self.global_step%self.S == 0: - self.model.update_density_grid(warmup=self.global_step<256) + a_thr = min(self.current_epoch+1, 25)/50 # alpha threshold + self.model.update_density_grid(a_thr*MAX_SAMPLES/(2*3**0.5), + warmup=self.global_step<256) rays, rgb = batch['rays'], batch['rgb'] results = self(rays, split='train') @@ -179,3 +184,7 @@ def validation_epoch_end(self, outputs): precision=16) trainer.fit(system, ckpt_path=hparams.ckpt_path) + + # save slimmed ckpt for the last epoch + ckpt_ = slim_ckpt(f'ckpts/{hparams.exp_name}/epoch={hparams.num_epochs-1}.ckpt') + torch.save(f'ckpts/{hparams.exp_name}/epoch={hparams.num_epochs-1}_slim.ckpt', ckpt_) diff --git a/utils.py b/utils.py index 12fd5a88..e5c10ebc 100644 --- a/utils.py +++ b/utils.py @@ -22,4 +22,13 @@ def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]): model_dict = model.state_dict() checkpoint_ = extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore) model_dict.update(checkpoint_) - model.load_state_dict(model_dict) \ No newline at end of file + model.load_state_dict(model_dict) + + +def slim_ckpt(ckpt_path): + ckpt = torch.load(ckpt_path) + # pop unused parameters + ckpt['state_dict'].pop('weights', None) + ckpt['state_dict'].pop('model.density_grid', None) + ckpt['state_dict'].pop('model.grid_coords', None) + return ckpt['state_dict']