diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..594bf3a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,12 @@
+*.pyc
+.vscode
+output
+build
+diff_rasterization/diff_rast.egg-info
+diff_rasterization/dist
+tensorboard_3d
+screenshots
+debug
+wandb
+data
+*.txt
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..c13f991
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..ba8e45f
--- /dev/null
+++ b/README.md
@@ -0,0 +1,109 @@
+# Gaussian-SLAM: Photo-realistic Dense SLAM with Gaussian Splatting
+
+Vladimir Yugay · Yue Li · Theo Gevers · Martin Oswald
+
+## ⚙️ Setting Things Up
+
+Clone the repo:
+
+```
+git clone https://github.com/VladimirYugay/Gaussian-SLAM
+```
+
+Make sure that the gcc and g++ paths on your system are exported:
+
+```
+export CC=
+export CXX=
+```
+
+To find the gcc and g++ paths on your machine, you can use `which gcc` and `which g++`.
+
+
+Then set up the environment from the provided conda environment file:
+
+```
+conda env create -f environment.yml
+conda activate gslam
+```
+We tested our code on RTX 3090 and RTX A6000 GPUs, running Ubuntu 22 and CentOS 7.5 respectively.
+
+## 🔨 Running Gaussian-SLAM
+
+Here we explain how to load the necessary data, configure Gaussian-SLAM for your use case, debug it, and reproduce the results reported in the paper.
+
+
+### Downloading the Data
+We tested our code on the Replica, TUM_RGBD, ScanNet, and ScanNet++ datasets. We also provide scripts for downloading Replica and TUM_RGBD (see `scripts/download_replica.sh` and `scripts/download_tum.sh`).
+For ScanNet and ScanNet++, follow the download procedure described on the respective official dataset pages.
+The config files are named after the sequences that we used for our method.
+
+
+
+### Running the Code
+Start the system with the command:
+
+```
+python run_slam.py configs/<dataset_name>/<config_name>.yaml --input_path <path_to_the_scene> --output_path <output_path>
+```
+For example:
+```
+python run_slam.py configs/Replica/room0.yaml --input_path /home/datasets/Replica/room0 --output_path output/Replica/room0
+```
+You can also configure the input and output paths in the config yaml file.
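+
+The scene configs only set the `data` block and point to a dataset-level base config via `inherit_from`. As a rough illustration of how such a hierarchy can be resolved (a sketch of the general pattern, not necessarily the exact `load_config` implementation in `src/utils/io_utils.py`), a recursive merge looks like this:
+
+```python
+# Sketch: resolve an `inherit_from` config hierarchy with a recursive merge.
+import yaml
+
+
+def deep_update(base: dict, child: dict) -> dict:
+    """Recursively override entries of `base` with entries of `child`."""
+    for key, value in child.items():
+        if isinstance(value, dict) and isinstance(base.get(key), dict):
+            base[key] = deep_update(base[key], value)
+        else:
+            base[key] = value
+    return base
+
+
+def load_config_sketch(path: str) -> dict:
+    with open(path, "r") as f:
+        cfg = yaml.safe_load(f)
+    parent_path = cfg.pop("inherit_from", None)
+    if parent_path is not None:
+        # e.g. configs/Replica/room0.yaml inherits from configs/Replica/replica.yaml
+        cfg = deep_update(load_config_sketch(parent_path), cfg)
+    return cfg
+```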
+
+
+
+### Reproducing the Results
+While we made all parts of our code deterministic, the differentiable rasterizer of Gaussian Splatting is not, so the metrics can differ slightly from run to run. In the paper we report average metrics computed over three seeds: 0, 1, and 2 (see the averaging sketch at the end of this subsection).
+
+You can reproduce the results for a single scene by running:
+
+```
+python run_slam.py configs/<dataset_name>/<config_name>.yaml --input_path <path_to_the_scene> --output_path <output_path>
+```
+
+If you are running on a SLURM cluster, you can reproduce the results for all scenes in a dataset by running the script:
+```
+./scripts/reproduce_sbatch.sh
+```
+Please note that evaluating the ```depth_L1``` metric requires reconstructing the mesh, which in turn requires a headless installation of open3d if you are running on a cluster.
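+
+To aggregate the per-seed numbers, a small helper along the following lines can be used. This is only a sketch: it assumes each run directory contains the metric JSON files written by the evaluator (e.g. `rendering_metrics.json`) with flat numeric entries, and the per-seed output layout shown below is just an example.
+
+```python
+# Sketch: average evaluation metrics over several seeded runs.
+import json
+from collections import defaultdict
+from pathlib import Path
+
+
+def average_metrics(run_dirs, metrics_file="rendering_metrics.json"):
+    sums, counts = defaultdict(float), defaultdict(int)
+    for run_dir in run_dirs:
+        with open(Path(run_dir) / metrics_file, "r") as f:
+            metrics = json.load(f)
+        for key, value in metrics.items():
+            if isinstance(value, (int, float)):  # skip non-numeric entries
+                sums[key] += value
+                counts[key] += 1
+    return {key: sums[key] / counts[key] for key in sums}
+
+
+if __name__ == "__main__":
+    runs = [f"output/Replica/room0_seed{seed}" for seed in (0, 1, 2)]  # hypothetical layout
+    print(average_metrics(runs))
+```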
+
+
+
+### Demo
+We used the camera path tool from the gaussian-splatting-lightning repo to make the fly-through video of the reconstructed scenes. We thank its author for the great work.
+
+
+## 📌 Citation
+
+If you find our paper and code useful, please cite us:
+
+```bib
+@misc{yugay2023gaussianslam,
+ title={Gaussian-SLAM: Photo-realistic Dense SLAM with Gaussian Splatting},
+ author={Vladimir Yugay and Yue Li and Theo Gevers and Martin R. Oswald},
+ year={2023},
+ eprint={2312.10070},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
diff --git a/assets/gaussian_slam.gif b/assets/gaussian_slam.gif
new file mode 100644
index 0000000..61c5d6c
Binary files /dev/null and b/assets/gaussian_slam.gif differ
diff --git a/assets/gaussian_slam.mp4 b/assets/gaussian_slam.mp4
new file mode 100644
index 0000000..a91fe66
Binary files /dev/null and b/assets/gaussian_slam.mp4 differ
diff --git a/configs/Replica/office0.yaml b/configs/Replica/office0.yaml
new file mode 100644
index 0000000..f90f036
--- /dev/null
+++ b/configs/Replica/office0.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/Replica/replica.yaml
+data:
+ scene_name: office0
+ input_path: data/Replica-SLAM/Replica/office0/
+ output_path: output/Replica/office0/
diff --git a/configs/Replica/office1.yaml b/configs/Replica/office1.yaml
new file mode 100644
index 0000000..db44aa1
--- /dev/null
+++ b/configs/Replica/office1.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/Replica/replica.yaml
+data:
+ scene_name: office1
+ input_path: data/Replica-SLAM/Replica/office1/
+ output_path: output/Replica/office1/
diff --git a/configs/Replica/office2.yaml b/configs/Replica/office2.yaml
new file mode 100644
index 0000000..781c672
--- /dev/null
+++ b/configs/Replica/office2.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/Replica/replica.yaml
+data:
+ scene_name: office2
+ input_path: data/Replica-SLAM/Replica/office2/
+ output_path: output/Replica/office2/
diff --git a/configs/Replica/office3.yaml b/configs/Replica/office3.yaml
new file mode 100644
index 0000000..f2c6097
--- /dev/null
+++ b/configs/Replica/office3.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/Replica/replica.yaml
+data:
+ scene_name: office3
+ input_path: data/Replica-SLAM/Replica/office3/
+ output_path: output/Replica/office3/
diff --git a/configs/Replica/office4.yaml b/configs/Replica/office4.yaml
new file mode 100644
index 0000000..63ad6f1
--- /dev/null
+++ b/configs/Replica/office4.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/Replica/replica.yaml
+data:
+ scene_name: office4
+ input_path: data/Replica-SLAM/Replica/office4/
+ output_path: output/Replica/office4/
diff --git a/configs/Replica/replica.yaml b/configs/Replica/replica.yaml
new file mode 100644
index 0000000..57482a4
--- /dev/null
+++ b/configs/Replica/replica.yaml
@@ -0,0 +1,42 @@
+project_name: "Gaussian_SLAM_replica"
+dataset_name: "replica"
+checkpoint_path: null
+use_wandb: False
+frame_limit: -1 # for debugging, set to -1 to disable
+seed: 0
+mapping:
+ new_submap_every: 50
+ map_every: 5
+ iterations: 100
+ new_submap_iterations: 1000
+ new_submap_points_num: 600000
+ new_submap_gradient_points_num: 50000
+ new_frame_sample_size: -1
+ new_points_radius: 0.0000001
+ current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view
+ alpha_thre: 0.6
+ pruning_thre: 0.1
+ submap_using_motion_heuristic: True
+tracking:
+ gt_camera: False
+ w_color_loss: 0.95
+ iterations: 60
+ cam_rot_lr: 0.0002
+ cam_trans_lr: 0.002
+ odometry_type: "odometer" # gt, const_speed, odometer
+ help_camera_initialization: False # temp option to help const_init
+ init_err_ratio: 5
+ odometer_method: "point_to_plane" # hybrid or point_to_plane
+ filter_alpha: False
+ filter_outlier_depth: True
+ alpha_thre: 0.98
+ soft_alpha: True
+ mask_invalid_depth: False
+cam:
+ H: 680
+ W: 1200
+ fx: 600.0
+ fy: 600.0
+ cx: 599.5
+ cy: 339.5
+ depth_scale: 6553.5
diff --git a/configs/Replica/room0.yaml b/configs/Replica/room0.yaml
new file mode 100644
index 0000000..6e2675c
--- /dev/null
+++ b/configs/Replica/room0.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/Replica/replica.yaml
+data:
+ scene_name: room0
+  input_path: data/Replica-SLAM/Replica/room0/
+ output_path: output/Replica/room0/
diff --git a/configs/Replica/room1.yaml b/configs/Replica/room1.yaml
new file mode 100644
index 0000000..6822604
--- /dev/null
+++ b/configs/Replica/room1.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/Replica/replica.yaml
+data:
+ scene_name: room1
+ input_path: data/Replica-SLAM/Replica/room1/
+ output_path: output/Replica/room1/
diff --git a/configs/Replica/room2.yaml b/configs/Replica/room2.yaml
new file mode 100644
index 0000000..272e63c
--- /dev/null
+++ b/configs/Replica/room2.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/Replica/replica.yaml
+data:
+ scene_name: room2
+ input_path: data/Replica-SLAM/Replica/room2/
+ output_path: output/Replica/room2/
diff --git a/configs/ScanNet/scannet.yaml b/configs/ScanNet/scannet.yaml
new file mode 100644
index 0000000..7808a63
--- /dev/null
+++ b/configs/ScanNet/scannet.yaml
@@ -0,0 +1,43 @@
+project_name: "Gaussian_SLAM_scannet"
+dataset_name: "scan_net"
+checkpoint_path: null
+use_wandb: False
+frame_limit: -1 # for debugging, set to -1 to disable
+seed: 0
+mapping:
+ new_submap_every: 50
+ map_every: 1
+ iterations: 100
+ new_submap_iterations: 100
+ new_submap_points_num: 100000
+ new_submap_gradient_points_num: 50000
+ new_frame_sample_size: 30000
+ new_points_radius: 0.0001
+ current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view
+ alpha_thre: 0.6
+ pruning_thre: 0.5
+ submap_using_motion_heuristic: False
+tracking:
+ gt_camera: False
+ w_color_loss: 0.6
+ iterations: 200
+ cam_rot_lr: 0.002
+ cam_trans_lr: 0.01
+ odometry_type: "const_speed" # gt, const_speed, odometer
+ help_camera_initialization: False # temp option to help const_init
+ init_err_ratio: 5
+ odometer_method: "hybrid" # hybrid or point_to_plane
+ filter_alpha: True
+ filter_outlier_depth: True
+ alpha_thre: 0.98
+ soft_alpha: True
+ mask_invalid_depth: True
+cam:
+ H: 480
+ W: 640
+ fx: 577.590698
+ fy: 578.729797
+ cx: 318.905426
+ cy: 242.683609
+ depth_scale: 1000.
+ crop_edge: 12
\ No newline at end of file
diff --git a/configs/ScanNet/scene0000_00.yaml b/configs/ScanNet/scene0000_00.yaml
new file mode 100644
index 0000000..0f59159
--- /dev/null
+++ b/configs/ScanNet/scene0000_00.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/ScanNet/scannet.yaml
+data:
+ input_path: data/scannet/scans/scene0000_00
+ output_path: output/ScanNet/scene0000
+ scene_name: scene0000_00
\ No newline at end of file
diff --git a/configs/ScanNet/scene0059_00.yaml b/configs/ScanNet/scene0059_00.yaml
new file mode 100644
index 0000000..4595156
--- /dev/null
+++ b/configs/ScanNet/scene0059_00.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/ScanNet/scannet.yaml
+data:
+ input_path: data/scannet/scans/scene0059_00
+ output_path: output/ScanNet/scene0059
+ scene_name: scene0059_00
\ No newline at end of file
diff --git a/configs/ScanNet/scene0106_00.yaml b/configs/ScanNet/scene0106_00.yaml
new file mode 100644
index 0000000..5987977
--- /dev/null
+++ b/configs/ScanNet/scene0106_00.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/ScanNet/scannet.yaml
+data:
+ input_path: data/scannet/scans/scene0106_00
+ output_path: output/ScanNet/scene0106
+ scene_name: scene0106_00
\ No newline at end of file
diff --git a/configs/ScanNet/scene0169_00.yaml b/configs/ScanNet/scene0169_00.yaml
new file mode 100644
index 0000000..c5d001d
--- /dev/null
+++ b/configs/ScanNet/scene0169_00.yaml
@@ -0,0 +1,10 @@
+inherit_from: configs/ScanNet/scannet.yaml
+data:
+ input_path: data/scannet/scans/scene0169_00
+ output_path: output/ScanNet/scene0169
+ scene_name: scene0169_00
+cam:
+ fx: 574.540771
+ fy: 577.583740
+ cx: 322.522827
+ cy: 238.558853
\ No newline at end of file
diff --git a/configs/ScanNet/scene0181_00.yaml b/configs/ScanNet/scene0181_00.yaml
new file mode 100644
index 0000000..ab6509c
--- /dev/null
+++ b/configs/ScanNet/scene0181_00.yaml
@@ -0,0 +1,10 @@
+inherit_from: configs/ScanNet/scannet.yaml
+data:
+ input_path: data/scannet/scans/scene0181_00
+ output_path: output/ScanNet/scene0181
+ scene_name: scene0181_00
+cam:
+ fx: 575.547668
+ fy: 577.459778
+ cx: 323.171967
+ cy: 236.417465
diff --git a/configs/ScanNet/scene0207_00.yaml b/configs/ScanNet/scene0207_00.yaml
new file mode 100644
index 0000000..0622473
--- /dev/null
+++ b/configs/ScanNet/scene0207_00.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/ScanNet/scannet.yaml
+data:
+ input_path: data/scannet/scans/scene0207_00
+ output_path: output/ScanNet/scene0207
+ scene_name: scene0207_00
\ No newline at end of file
diff --git a/configs/TUM_RGBD/rgbd_dataset_freiburg1_desk.yaml b/configs/TUM_RGBD/rgbd_dataset_freiburg1_desk.yaml
new file mode 100644
index 0000000..7ab9b6d
--- /dev/null
+++ b/configs/TUM_RGBD/rgbd_dataset_freiburg1_desk.yaml
@@ -0,0 +1,14 @@
+inherit_from: configs/TUM_RGBD/tum_rgbd.yaml
+data:
+ input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg1_desk
+ output_path: output/TUM_RGBD/rgbd_dataset_freiburg1_desk/
+ scene_name: rgbd_dataset_freiburg1_desk
+cam: #intrinsic is different per scene in TUM
+ H: 480
+ W: 640
+ fx: 517.3
+ fy: 516.5
+ cx: 318.6
+ cy: 255.3
+ crop_edge: 50
+ distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633]
\ No newline at end of file
diff --git a/configs/TUM_RGBD/rgbd_dataset_freiburg2_xyz.yaml b/configs/TUM_RGBD/rgbd_dataset_freiburg2_xyz.yaml
new file mode 100644
index 0000000..1206710
--- /dev/null
+++ b/configs/TUM_RGBD/rgbd_dataset_freiburg2_xyz.yaml
@@ -0,0 +1,14 @@
+inherit_from: configs/TUM_RGBD/tum_rgbd.yaml
+data:
+ input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg2_xyz
+ output_path: output/TUM_RGBD/rgbd_dataset_freiburg2_xyz/
+ scene_name: rgbd_dataset_freiburg2_xyz
+cam: #intrinsic is different per scene in TUM
+ H: 480
+ W: 640
+ fx: 520.9
+ fy: 521.0
+ cx: 325.1
+ cy: 249.7
+ crop_edge: 8
+ distortion: [0.2312, -0.7849, -0.0033, -0.0001, 0.9172]
\ No newline at end of file
diff --git a/configs/TUM_RGBD/rgbd_dataset_freiburg3_long_office_household.yaml b/configs/TUM_RGBD/rgbd_dataset_freiburg3_long_office_household.yaml
new file mode 100644
index 0000000..d59a60a
--- /dev/null
+++ b/configs/TUM_RGBD/rgbd_dataset_freiburg3_long_office_household.yaml
@@ -0,0 +1,14 @@
+inherit_from: configs/TUM_RGBD/tum_rgbd.yaml
+data:
+ input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg3_long_office_household/
+ output_path: output/TUM_RGBD/rgbd_dataset_freiburg3_long_office_household/
+ scene_name: rgbd_dataset_freiburg3_long_office_household
+cam: #intrinsic is different per scene in TUM
+ H: 480
+ W: 640
+ fx: 517.3
+ fy: 516.5
+ cx: 318.6
+ cy: 255.3
+ crop_edge: 50
+ distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633]
\ No newline at end of file
diff --git a/configs/TUM_RGBD/tum_rgbd.yaml b/configs/TUM_RGBD/tum_rgbd.yaml
new file mode 100644
index 0000000..398a548
--- /dev/null
+++ b/configs/TUM_RGBD/tum_rgbd.yaml
@@ -0,0 +1,37 @@
+project_name: "Gaussian_SLAM_tumrgbd"
+dataset_name: "tum_rgbd"
+checkpoint_path: null
+use_wandb: False
+frame_limit: -1 # for debugging, set to -1 to disable
+seed: 0
+mapping:
+ new_submap_every: 50
+ map_every: 1
+ iterations: 100
+ new_submap_iterations: 100
+ new_submap_points_num: 100000
+ new_submap_gradient_points_num: 50000
+ new_frame_sample_size: 30000
+ new_points_radius: 0.0001
+ current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view
+ alpha_thre: 0.6
+ pruning_thre: 0.5
+ submap_using_motion_heuristic: True
+tracking:
+ gt_camera: False
+ w_color_loss: 0.6
+ iterations: 200
+ cam_rot_lr: 0.002
+ cam_trans_lr: 0.01
+ odometry_type: "const_speed" # gt, const_speed, odometer
+ help_camera_initialization: False # temp option to help const_init
+ init_err_ratio: 5
+ odometer_method: "hybrid" # hybrid or point_to_plane
+ filter_alpha: False
+ filter_outlier_depth: False
+ alpha_thre: 0.98
+ soft_alpha: True
+ mask_invalid_depth: True
+cam:
+ crop_edge: 16
+ depth_scale: 5000.0
\ No newline at end of file
diff --git a/configs/scannetpp/281bc17764.yaml b/configs/scannetpp/281bc17764.yaml
new file mode 100644
index 0000000..2632660
--- /dev/null
+++ b/configs/scannetpp/281bc17764.yaml
@@ -0,0 +1,14 @@
+inherit_from: configs/scannetpp/scannetpp.yaml
+data:
+ input_path: data/scannetpp/data/281bc17764
+ output_path: output/ScanNetPP/281bc17764
+ scene_name: "281bc17764"
+ use_train_split: True
+ frame_limit: 250
+cam:
+ H: 584
+ W: 876
+ fx: 312.79197434640764
+ fy: 313.48022477591036
+ cx: 438.0
+ cy: 292.0
\ No newline at end of file
diff --git a/configs/scannetpp/2e74812d00.yaml b/configs/scannetpp/2e74812d00.yaml
new file mode 100644
index 0000000..f84df89
--- /dev/null
+++ b/configs/scannetpp/2e74812d00.yaml
@@ -0,0 +1,14 @@
+inherit_from: configs/scannetpp/scannetpp.yaml
+data:
+ input_path: data/scannetpp/data/2e74812d00
+ output_path: output/ScanNetPP/2e74812d00
+ scene_name: "2e74812d00"
+ use_train_split: True
+ frame_limit: 250
+cam:
+ H: 584
+ W: 876
+ fx: 312.0984049779051
+ fy: 312.4823067146056
+ cx: 438.0
+ cy: 292.0
\ No newline at end of file
diff --git a/configs/scannetpp/8b5caf3398.yaml b/configs/scannetpp/8b5caf3398.yaml
new file mode 100644
index 0000000..fd24c20
--- /dev/null
+++ b/configs/scannetpp/8b5caf3398.yaml
@@ -0,0 +1,14 @@
+inherit_from: configs/scannetpp/scannetpp.yaml
+data:
+ input_path: data/scannetpp/data/8b5caf3398
+ output_path: output/ScanNetPP/8b5caf3398
+ scene_name: "8b5caf3398"
+ use_train_split: True
+ frame_limit: 250
+cam:
+ H: 584
+ W: 876
+ fx: 316.3837659917395
+ fy: 319.18649362678593
+ cx: 438.0
+ cy: 292.0
diff --git a/configs/scannetpp/b20a261fdf.yaml b/configs/scannetpp/b20a261fdf.yaml
new file mode 100644
index 0000000..5bb654f
--- /dev/null
+++ b/configs/scannetpp/b20a261fdf.yaml
@@ -0,0 +1,14 @@
+inherit_from: configs/scannetpp/scannetpp.yaml
+data:
+ input_path: data/scannetpp/data/b20a261fdf
+ output_path: output/ScanNetPP/b20a261fdf
+ scene_name: "b20a261fdf"
+ use_train_split: True
+ frame_limit: 250
+cam:
+ H: 584
+ W: 876
+ fx: 312.7099188244687
+ fy: 313.5121746848229
+ cx: 438.0
+ cy: 292.0
\ No newline at end of file
diff --git a/configs/scannetpp/fb05e13ad1.yaml b/configs/scannetpp/fb05e13ad1.yaml
new file mode 100644
index 0000000..ca5d944
--- /dev/null
+++ b/configs/scannetpp/fb05e13ad1.yaml
@@ -0,0 +1,14 @@
+inherit_from: configs/scannetpp/scannetpp.yaml
+data:
+ input_path: data/scannetpp/data/fb05e13ad1
+ output_path: output/ScanNetPP/fb05e13ad1
+ scene_name: "fb05e13ad1"
+ use_train_split: True
+ frame_limit: 250
+cam:
+ H: 584
+ W: 876
+ fx: 231.8197441948914
+ fy: 231.9980523882361
+ cx: 438.0
+ cy: 292.0
\ No newline at end of file
diff --git a/configs/scannetpp/scannetpp.yaml b/configs/scannetpp/scannetpp.yaml
new file mode 100644
index 0000000..6d94db2
--- /dev/null
+++ b/configs/scannetpp/scannetpp.yaml
@@ -0,0 +1,37 @@
+project_name: "Gaussian_SLAM_scannetpp"
+dataset_name: "scannetpp"
+checkpoint_path: null
+use_wandb: False
+frame_limit: -1 # set to -1 to disable
+seed: 0
+mapping:
+ new_submap_every: 100
+ map_every: 2
+ iterations: 500
+ new_submap_iterations: 500
+ new_submap_points_num: 400000
+ new_submap_gradient_points_num: 50000
+ new_frame_sample_size: 100000
+ new_points_radius: 0.00000001
+ current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view
+ alpha_thre: 0.6
+ pruning_thre: 0.5
+ submap_using_motion_heuristic: False
+tracking:
+ gt_camera: False
+ w_color_loss: 0.5
+ iterations: 300
+ cam_rot_lr: 0.002
+ cam_trans_lr: 0.01
+ odometry_type: "const_speed" # gt, const_speed, odometer
+ help_camera_initialization: True
+ init_err_ratio: 50
+ odometer_method: "point_to_plane" # hybrid or point_to_plane
+ filter_alpha: True
+ filter_outlier_depth: True
+ alpha_thre: 0.98
+ soft_alpha: False
+ mask_invalid_depth: True
+cam:
+ crop_edge: 0
+ depth_scale: 1000.0
\ No newline at end of file
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..9cbb192
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,28 @@
+name: gslam
+channels:
+ - pytorch
+ - nvidia
+ - nvidia/label/cuda-12.1.0
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.10
+ - faiss-gpu=1.8.0
+ - cuda-toolkit=12.1
+ - pytorch=2.1.2
+ - pytorch-cuda=12.1
+ - torchvision=0.16.2
+ - pip
+ - pip:
+ - open3d==0.18.0
+ - wandb
+ - trimesh
+ - pytorch_msssim
+ - torchmetrics
+ - tqdm
+ - imageio
+ - opencv-python
+ - plyfile
+ - git+https://github.com/eriksandstroem/evaluate_3d_reconstruction_lib.git@9b3cc08be5440db9c375cc21e3bd65bb4a337db7
+ - git+https://github.com/VladimirYugay/simple-knn.git@c7e51a06a4cd84c25e769fee29ab391fe5d5ff8d
+ - git+https://github.com/VladimirYugay/gaussian_rasterizer.git@9c40173fcc8d9b16778a1a8040295bc2f9ebf129
\ No newline at end of file
diff --git a/run_evaluation.py b/run_evaluation.py
new file mode 100644
index 0000000..33f9b5d
--- /dev/null
+++ b/run_evaluation.py
@@ -0,0 +1,19 @@
+import argparse
+from pathlib import Path
+from src.evaluation.evaluator import Evaluator
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='Arguments to compute the mesh')
+ parser.add_argument('--checkpoint_path', type=str, help='SLAM checkpoint path', default="output/slam/full_experiment/")
+ parser.add_argument('--config_path', type=str, help='Config path', default="")
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = get_args()
+ if args.config_path == "":
+ args.config_path = Path(args.checkpoint_path) / "config.yaml"
+
+ evaluator = Evaluator(Path(args.checkpoint_path), Path(args.config_path))
+ evaluator.run()
diff --git a/run_slam.py b/run_slam.py
new file mode 100644
index 0000000..1d26e3c
--- /dev/null
+++ b/run_slam.py
@@ -0,0 +1,118 @@
+import argparse
+import os
+import time
+import uuid
+
+import wandb
+
+from src.entities.gaussian_slam import GaussianSLAM
+from src.evaluation.evaluator import Evaluator
+from src.utils.io_utils import load_config, log_metrics_to_wandb
+from src.utils.utils import setup_seed
+
+
+def get_args():
+ parser = argparse.ArgumentParser(
+ description='Arguments to compute the mesh')
+ parser.add_argument('config_path', type=str,
+ help='Path to the configuration yaml file')
+ parser.add_argument('--input_path', default="")
+ parser.add_argument('--output_path', default="")
+ parser.add_argument('--track_w_color_loss', type=float)
+ parser.add_argument('--track_alpha_thre', type=float)
+ parser.add_argument('--track_iters', type=int)
+ parser.add_argument('--track_filter_alpha', action='store_true')
+ parser.add_argument('--track_filter_outlier', action='store_true')
+ parser.add_argument('--track_wo_filter_alpha', action='store_true')
+ parser.add_argument("--track_wo_filter_outlier", action="store_true")
+ parser.add_argument("--track_cam_trans_lr", type=float)
+ parser.add_argument('--alpha_seeding_thre', type=float)
+ parser.add_argument('--map_every', type=int)
+ parser.add_argument("--map_iters", type=int)
+ parser.add_argument('--new_submap_every', type=int)
+ parser.add_argument('--project_name', type=str)
+ parser.add_argument('--group_name', type=str)
+ parser.add_argument('--gt_camera', action='store_true')
+ parser.add_argument('--help_camera_initialization', action='store_true')
+ parser.add_argument('--soft_alpha', action='store_true')
+ parser.add_argument('--seed', type=int)
+ parser.add_argument('--submap_using_motion_heuristic', action='store_true')
+ parser.add_argument('--new_submap_points_num', type=int)
+ return parser.parse_args()
+
+
+def update_config_with_args(config, args):
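+    # Command-line arguments, when provided, override the corresponding entries of the YAML config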
+ if args.input_path:
+ config["data"]["input_path"] = args.input_path
+ if args.output_path:
+ config["data"]["output_path"] = args.output_path
+ if args.track_w_color_loss is not None:
+ config["tracking"]["w_color_loss"] = args.track_w_color_loss
+ if args.track_iters is not None:
+        config["tracking"]["iterations"] = args.track_iters
+ if args.track_filter_alpha:
+ config["tracking"]["filter_alpha"] = True
+ if args.track_wo_filter_alpha:
+ config["tracking"]["filter_alpha"] = False
+ if args.track_filter_outlier:
+ config["tracking"]["filter_outlier_depth"] = True
+ if args.track_wo_filter_outlier:
+ config["tracking"]["filter_outlier_depth"] = False
+ if args.track_alpha_thre is not None:
+ config["tracking"]["alpha_thre"] = args.track_alpha_thre
+ if args.map_every:
+ config["mapping"]["map_every"] = args.map_every
+ if args.map_iters:
+ config["mapping"]["iterations"] = args.map_iters
+ if args.new_submap_every:
+ config["mapping"]["new_submap_every"] = args.new_submap_every
+ if args.project_name:
+ config["project_name"] = args.project_name
+ if args.alpha_seeding_thre is not None:
+ config["mapping"]["alpha_thre"] = args.alpha_seeding_thre
+ if args.seed:
+ config["seed"] = args.seed
+ if args.help_camera_initialization:
+ config["tracking"]["help_camera_initialization"] = True
+ if args.soft_alpha:
+ config["tracking"]["soft_alpha"] = True
+ if args.submap_using_motion_heuristic:
+ config["mapping"]["submap_using_motion_heuristic"] = True
+ if args.new_submap_points_num:
+ config["mapping"]["new_submap_points_num"] = args.new_submap_points_num
+ if args.track_cam_trans_lr:
+ config["tracking"]["cam_trans_lr"] = args.track_cam_trans_lr
+ return config
+
+
+if __name__ == "__main__":
+ args = get_args()
+ config = load_config(args.config_path)
+ config = update_config_with_args(config, args)
+
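+    # Setting the environment variable DISABLE_WANDB=true turns off wandb logging regardless of the config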
+ if os.getenv('DISABLE_WANDB') == 'true':
+ config["use_wandb"] = False
+ if config["use_wandb"]:
+ wandb.init(
+ project=config["project_name"],
+ config=config,
+ dir="/home/yli3/scratch/outputs/slam/wandb",
+ group=config["data"]["scene_name"]
+ if not args.group_name
+ else args.group_name,
+ name=f'{config["data"]["scene_name"]}_{time.strftime("%Y%m%d_%H%M%S", time.localtime())}_{str(uuid.uuid4())[:5]}',
+ )
+ wandb.run.log_code(".", include_fn=lambda path: path.endswith(".py"))
+
+ setup_seed(config["seed"])
+ gslam = GaussianSLAM(config)
+ gslam.run()
+
+ evaluator = Evaluator(gslam.output_path, gslam.output_path / "config.yaml")
+ evaluator.run()
+ if config["use_wandb"]:
+ evals = ["rendering_metrics.json",
+ "reconstruction_metrics.json", "ate_aligned.json"]
+ log_metrics_to_wandb(evals, gslam.output_path, "Evaluation")
+ wandb.finish()
+ print("All done.✨")
diff --git a/scripts/download_replica.sh b/scripts/download_replica.sh
new file mode 100644
index 0000000..b225a79
--- /dev/null
+++ b/scripts/download_replica.sh
@@ -0,0 +1,3 @@
+mkdir -p data
+cd data
+git clone https://huggingface.co/datasets/voviktyl/Replica-SLAM
\ No newline at end of file
diff --git a/scripts/download_tum.sh b/scripts/download_tum.sh
new file mode 100644
index 0000000..db5e98b
--- /dev/null
+++ b/scripts/download_tum.sh
@@ -0,0 +1,3 @@
+mkdir -p data
+cd data
+git clone https://huggingface.co/datasets/voviktyl/TUM_RGBD-SLAM
\ No newline at end of file
diff --git a/scripts/reproduce_sbatch.sh b/scripts/reproduce_sbatch.sh
new file mode 100755
index 0000000..09062bc
--- /dev/null
+++ b/scripts/reproduce_sbatch.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+#SBATCH --output=output/logs/%A_%a.log # please change accordingly
+#SBATCH --error=output/logs/%A_%a.log # please change accordingly
+#SBATCH -N 1
+#SBATCH -n 1
+#SBATCH --gpus-per-node=1
+#SBATCH --partition=gpu
+#SBATCH --cpus-per-task=12
+#SBATCH --time=24:00:00
+#SBATCH --array=0-4 # number of scenes, 0-7 for Replica, 0-2 for TUM_RGBD, 0-5 for ScanNet, 0-4 for ScanNet++
+
+dataset="Replica" # set dataset
+if [ "$dataset" == "Replica" ]; then
+ scenes=("room0" "room1" "room2" "office0" "office1" "office2" "office3" "office4")
+ INPUT_PATH="data/Replica-SLAM"
+elif [ "$dataset" == "TUM_RGBD" ]; then
+ scenes=("rgbd_dataset_freiburg1_desk" "rgbd_dataset_freiburg2_xyz" "rgbd_dataset_freiburg3_long_office_household")
+ INPUT_PATH="data/TUM_RGBD-SLAM"
+elif [ "$dataset" == "ScanNet" ]; then
+ scenes=("scene0000_00" "scene0059_00" "scene0106_00" "scene0169_00" "scene0181_00" "scene0207_00")
+ INPUT_PATH="data/scannet/scans"
+elif [ "$dataset" == "ScanNetPP" ]; then
+ scenes=("b20a261fdf" "8b5caf3398" "fb05e13ad1" "2e74812d00" "281bc17764")
+ INPUT_PATH="data/scannetpp/data"
+else
+ echo "Dataset not recognized!"
+ exit 1
+fi
+
+OUTPUT_PATH="output"
+CONFIG_PATH="configs/${dataset}"
+EXPERIMENT_NAME="reproduce"
+SCENE_NAME=${scenes[$SLURM_ARRAY_TASK_ID]}
+
+source # please change accordingly
+conda activate gslam
+
+echo "Job for dataset: $dataset, scene: $SCENE_NAME"
+START_TIME=$(date)
+echo "Starting on: $START_TIME"
+echo "Running on node: $(hostname)"
+
+# Your command to run the experiment
+python run_slam.py "${CONFIG_PATH}/${SCENE_NAME}.yaml" \
+ --input_path "${INPUT_PATH}/${SCENE_NAME}" \
+ --output_path "${OUTPUT_PATH}/${dataset}/${EXPERIMENT_NAME}/${SCENE_NAME}" \
+    --group_name "${EXPERIMENT_NAME}"
+
+echo "Job for scene $SCENE_NAME completed."
+echo "Started at: $START_TIME"
+echo "Finished at: $(date)"
\ No newline at end of file
diff --git a/src/entities/__init__.py b/src/entities/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/entities/arguments.py b/src/entities/arguments.py
new file mode 100644
index 0000000..2523cad
--- /dev/null
+++ b/src/entities/arguments.py
@@ -0,0 +1,94 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import os
+import sys
+from argparse import ArgumentParser, Namespace
+
+
+class GroupParams:
+ pass
+
+
+class ParamGroup:
+ def __init__(self, parser: ArgumentParser, name: str, fill_none=False):
+ group = parser.add_argument_group(name)
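+        # Attributes whose names start with "_" additionally get a one-letter shorthand flag (e.g. "--iterations" plus "-i")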
+ for key, value in vars(self).items():
+ shorthand = False
+ if key.startswith("_"):
+ shorthand = True
+ key = key[1:]
+ t = type(value)
+ value = value if not fill_none else None
+ if shorthand:
+ if t == bool:
+ group.add_argument(
+ "--" + key, ("-" + key[0:1]), default=value, action="store_true")
+ else:
+ group.add_argument(
+ "--" + key, ("-" + key[0:1]), default=value, type=t)
+ else:
+ if t == bool:
+ group.add_argument(
+ "--" + key, default=value, action="store_true")
+ else:
+ group.add_argument("--" + key, default=value, type=t)
+
+ def extract(self, args):
+ group = GroupParams()
+ for arg in vars(args).items():
+ if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
+ setattr(group, arg[0], arg[1])
+ return group
+
+
+class OptimizationParams(ParamGroup):
+ def __init__(self, parser):
+ self.iterations = 30_000
+ self.position_lr_init = 0.0001
+ self.position_lr_final = 0.0000016
+ self.position_lr_delay_mult = 0.01
+ self.position_lr_max_steps = 30_000
+ self.feature_lr = 0.0025
+ self.opacity_lr = 0.05
+ self.scaling_lr = 0.005 # before 0.005
+ self.rotation_lr = 0.001
+ self.percent_dense = 0.01
+ self.lambda_dssim = 0.2
+ self.densification_interval = 100
+ self.opacity_reset_interval = 3000
+ self.densify_from_iter = 500
+ self.densify_until_iter = 15_000
+ self.densify_grad_threshold = 0.0002
+ super().__init__(parser, "Optimization Parameters")
+
+
+def get_combined_args(parser: ArgumentParser):
+ cmdlne_string = sys.argv[1:]
+ cfgfile_string = "Namespace()"
+ args_cmdline = parser.parse_args(cmdlne_string)
+
+ try:
+ cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
+ print("Looking for config file in", cfgfilepath)
+ with open(cfgfilepath) as cfg_file:
+ print("Config file found: {}".format(cfgfilepath))
+ cfgfile_string = cfg_file.read()
+ except TypeError:
+        print("Config file not found")
+ pass
+ args_cfgfile = eval(cfgfile_string)
+
+ merged_dict = vars(args_cfgfile).copy()
+ for k, v in vars(args_cmdline).items():
+ if v is not None:
+ merged_dict[k] = v
+ return Namespace(**merged_dict)
diff --git a/src/entities/datasets.py b/src/entities/datasets.py
new file mode 100644
index 0000000..f54d464
--- /dev/null
+++ b/src/entities/datasets.py
@@ -0,0 +1,270 @@
+import math
+import os
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+import json
+import imageio
+
+
+class BaseDataset(torch.utils.data.Dataset):
+
+ def __init__(self, dataset_config: dict):
+ self.dataset_path = Path(dataset_config["input_path"])
+ self.frame_limit = dataset_config.get("frame_limit", -1)
+ self.dataset_config = dataset_config
+ self.height = dataset_config["H"]
+ self.width = dataset_config["W"]
+ self.fx = dataset_config["fx"]
+ self.fy = dataset_config["fy"]
+ self.cx = dataset_config["cx"]
+ self.cy = dataset_config["cy"]
+
+ self.depth_scale = dataset_config["depth_scale"]
+ self.distortion = np.array(
+ dataset_config['distortion']) if 'distortion' in dataset_config else None
+ self.crop_edge = dataset_config['crop_edge'] if 'crop_edge' in dataset_config else 0
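+        # Cropping the image border reduces the effective resolution and shifts the principal point accordingly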
+ if self.crop_edge:
+ self.height -= 2 * self.crop_edge
+ self.width -= 2 * self.crop_edge
+ self.cx -= self.crop_edge
+ self.cy -= self.crop_edge
+
+ self.fovx = 2 * math.atan(self.width / (2 * self.fx))
+ self.fovy = 2 * math.atan(self.height / (2 * self.fy))
+ self.intrinsics = np.array(
+ [[self.fx, 0, self.cx], [0, self.fy, self.cy], [0, 0, 1]])
+
+ self.color_paths = []
+ self.depth_paths = []
+
+ def __len__(self):
+ return len(self.color_paths) if self.frame_limit < 0 else int(self.frame_limit)
+
+
+class Replica(BaseDataset):
+
+ def __init__(self, dataset_config: dict):
+ super().__init__(dataset_config)
+ self.color_paths = sorted(
+ list((self.dataset_path / "results").glob("frame*.jpg")))
+ self.depth_paths = sorted(
+ list((self.dataset_path / "results").glob("depth*.png")))
+ self.load_poses(self.dataset_path / "traj.txt")
+ print(f"Loaded {len(self.color_paths)} frames")
+
+ def load_poses(self, path):
+ self.poses = []
+ with open(path, "r") as f:
+ lines = f.readlines()
+ for line in lines:
+ c2w = np.array(list(map(float, line.split()))).reshape(4, 4)
+ self.poses.append(c2w.astype(np.float32))
+
+ def __getitem__(self, index):
+ color_data = cv2.imread(str(self.color_paths[index]))
+ color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB)
+ depth_data = cv2.imread(
+ str(self.depth_paths[index]), cv2.IMREAD_UNCHANGED)
+ depth_data = depth_data.astype(np.float32) / self.depth_scale
+ return index, color_data, depth_data, self.poses[index]
+
+
+class TUM_RGBD(BaseDataset):
+ def __init__(self, dataset_config: dict):
+ super().__init__(dataset_config)
+ self.color_paths, self.depth_paths, self.poses = self.loadtum(
+ self.dataset_path, frame_rate=32)
+
+ def parse_list(self, filepath, skiprows=0):
+ """ read list data """
+ return np.loadtxt(filepath, delimiter=' ', dtype=np.unicode_, skiprows=skiprows)
+
+ def associate_frames(self, tstamp_image, tstamp_depth, tstamp_pose, max_dt=0.08):
+ """ pair images, depths, and poses """
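+        # Match each RGB timestamp to the closest depth (and pose) timestamp within max_dt seconds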
+ associations = []
+ for i, t in enumerate(tstamp_image):
+ if tstamp_pose is None:
+ j = np.argmin(np.abs(tstamp_depth - t))
+ if (np.abs(tstamp_depth[j] - t) < max_dt):
+ associations.append((i, j))
+ else:
+ j = np.argmin(np.abs(tstamp_depth - t))
+ k = np.argmin(np.abs(tstamp_pose - t))
+ if (np.abs(tstamp_depth[j] - t) < max_dt) and (np.abs(tstamp_pose[k] - t) < max_dt):
+ associations.append((i, j, k))
+ return associations
+
+ def loadtum(self, datapath, frame_rate=-1):
+ """ read video data in tum-rgbd format """
+ if os.path.isfile(os.path.join(datapath, 'groundtruth.txt')):
+ pose_list = os.path.join(datapath, 'groundtruth.txt')
+ elif os.path.isfile(os.path.join(datapath, 'pose.txt')):
+ pose_list = os.path.join(datapath, 'pose.txt')
+
+ image_list = os.path.join(datapath, 'rgb.txt')
+ depth_list = os.path.join(datapath, 'depth.txt')
+
+ image_data = self.parse_list(image_list)
+ depth_data = self.parse_list(depth_list)
+ pose_data = self.parse_list(pose_list, skiprows=1)
+ pose_vecs = pose_data[:, 1:].astype(np.float64)
+
+ tstamp_image = image_data[:, 0].astype(np.float64)
+ tstamp_depth = depth_data[:, 0].astype(np.float64)
+ tstamp_pose = pose_data[:, 0].astype(np.float64)
+ associations = self.associate_frames(
+ tstamp_image, tstamp_depth, tstamp_pose)
+
+ indicies = [0]
+ for i in range(1, len(associations)):
+ t0 = tstamp_image[associations[indicies[-1]][0]]
+ t1 = tstamp_image[associations[i][0]]
+ if t1 - t0 > 1.0 / frame_rate:
+ indicies += [i]
+
+ images, poses, depths = [], [], []
+ inv_pose = None
+ for ix in indicies:
+ (i, j, k) = associations[ix]
+ images += [os.path.join(datapath, image_data[i, 1])]
+ depths += [os.path.join(datapath, depth_data[j, 1])]
+ c2w = self.pose_matrix_from_quaternion(pose_vecs[k])
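+            # Express all poses relative to the first selected frame so the trajectory starts at the identity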
+ if inv_pose is None:
+ inv_pose = np.linalg.inv(c2w)
+ c2w = np.eye(4)
+ else:
+ c2w = inv_pose@c2w
+ poses += [c2w.astype(np.float32)]
+
+ return images, depths, poses
+
+ def pose_matrix_from_quaternion(self, pvec):
+        """ convert a (t, q) pose vector into a 4x4 pose matrix """
+ from scipy.spatial.transform import Rotation
+
+ pose = np.eye(4)
+ pose[:3, :3] = Rotation.from_quat(pvec[3:]).as_matrix()
+ pose[:3, 3] = pvec[:3]
+ return pose
+
+ def __getitem__(self, index):
+ color_data = cv2.imread(str(self.color_paths[index]))
+ if self.distortion is not None:
+ color_data = cv2.undistort(
+ color_data, self.intrinsics, self.distortion)
+ color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB)
+
+ depth_data = cv2.imread(
+ str(self.depth_paths[index]), cv2.IMREAD_UNCHANGED)
+ depth_data = depth_data.astype(np.float32) / self.depth_scale
+ edge = self.crop_edge
+ if edge > 0:
+ color_data = color_data[edge:-edge, edge:-edge]
+ depth_data = depth_data[edge:-edge, edge:-edge]
+ # Interpolate depth values for splatting
+ return index, color_data, depth_data, self.poses[index]
+
+
+class ScanNet(BaseDataset):
+ def __init__(self, dataset_config: dict):
+ super().__init__(dataset_config)
+ self.color_paths = sorted(list(
+ (self.dataset_path / "color").glob("*.jpg")), key=lambda x: int(os.path.basename(x)[:-4]))
+ self.depth_paths = sorted(list(
+ (self.dataset_path / "depth").glob("*.png")), key=lambda x: int(os.path.basename(x)[:-4]))
+ self.load_poses(self.dataset_path / "pose")
+
+ def load_poses(self, path):
+ self.poses = []
+ pose_paths = sorted(path.glob('*.txt'),
+ key=lambda x: int(os.path.basename(x)[:-4]))
+ for pose_path in pose_paths:
+ with open(pose_path, "r") as f:
+ lines = f.readlines()
+ ls = []
+ for line in lines:
+ ls.append(list(map(float, line.split(' '))))
+ c2w = np.array(ls).reshape(4, 4).astype(np.float32)
+ self.poses.append(c2w)
+
+ def __getitem__(self, index):
+ color_data = cv2.imread(str(self.color_paths[index]))
+ if self.distortion is not None:
+ color_data = cv2.undistort(
+ color_data, self.intrinsics, self.distortion)
+ color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB)
+ color_data = cv2.resize(color_data, (self.dataset_config["W"], self.dataset_config["H"]))
+
+ depth_data = cv2.imread(
+ str(self.depth_paths[index]), cv2.IMREAD_UNCHANGED)
+ depth_data = depth_data.astype(np.float32) / self.depth_scale
+ edge = self.crop_edge
+ if edge > 0:
+ color_data = color_data[edge:-edge, edge:-edge]
+ depth_data = depth_data[edge:-edge, edge:-edge]
+ # Interpolate depth values for splatting
+ return index, color_data, depth_data, self.poses[index]
+
+
+class ScanNetPP(BaseDataset):
+ def __init__(self, dataset_config: dict):
+ super().__init__(dataset_config)
+ self.use_train_split = dataset_config["use_train_split"]
+ self.train_test_split = json.load(open(f"{self.dataset_path}/dslr/train_test_lists.json", "r"))
+ if self.use_train_split:
+ self.image_names = self.train_test_split["train"]
+ else:
+ self.image_names = self.train_test_split["test"]
+ self.load_data()
+
+ def load_data(self):
+ self.poses = []
+ cams_path = self.dataset_path / "dslr" / "nerfstudio" / "transforms_undistorted.json"
+ cams_metadata = json.load(open(str(cams_path), "r"))
+ frames_key = "frames" if self.use_train_split else "test_frames"
+ frames_metadata = cams_metadata[frames_key]
+ frame2idx = {frame["file_path"]: index for index, frame in enumerate(frames_metadata)}
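+        # Axis-flip similarity transform (negates the y and z axes) used to convert ScanNet++ poses between camera coordinate conventions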
+ P = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]).astype(np.float32)
+ for image_name in self.image_names:
+ frame_metadata = frames_metadata[frame2idx[image_name]]
+ # if self.ignore_bad and frame_metadata['is_bad']:
+ # continue
+ color_path = str(self.dataset_path / "dslr" / "undistorted_images" / image_name)
+ depth_path = str(self.dataset_path / "dslr" / "undistorted_depths" / image_name.replace('.JPG', '.png'))
+ self.color_paths.append(color_path)
+ self.depth_paths.append(depth_path)
+ c2w = np.array(frame_metadata["transform_matrix"]).astype(np.float32)
+ c2w = P @ c2w @ P.T
+ self.poses.append(c2w)
+
+ def __len__(self):
+ if self.use_train_split:
+ return len(self.image_names) if self.frame_limit < 0 else int(self.frame_limit)
+ else:
+ return len(self.image_names)
+
+ def __getitem__(self, index):
+
+ color_data = np.asarray(imageio.imread(self.color_paths[index]), dtype=float)
+ color_data = cv2.resize(color_data, (self.width, self.height), interpolation=cv2.INTER_LINEAR)
+ color_data = color_data.astype(np.uint8)
+
+ depth_data = np.asarray(imageio.imread(self.depth_paths[index]), dtype=np.int64)
+ depth_data = cv2.resize(depth_data.astype(float), (self.width, self.height), interpolation=cv2.INTER_NEAREST)
+ depth_data = depth_data.astype(np.float32) / self.depth_scale
+ return index, color_data, depth_data, self.poses[index]
+
+
+def get_dataset(dataset_name: str):
+ if dataset_name == "replica":
+ return Replica
+ elif dataset_name == "tum_rgbd":
+ return TUM_RGBD
+ elif dataset_name == "scan_net":
+ return ScanNet
+ elif dataset_name == "scannetpp":
+ return ScanNetPP
+ raise NotImplementedError(f"Dataset {dataset_name} not implemented")
diff --git a/src/entities/gaussian_model.py b/src/entities/gaussian_model.py
new file mode 100644
index 0000000..fa01f71
--- /dev/null
+++ b/src/entities/gaussian_model.py
@@ -0,0 +1,408 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+from pathlib import Path
+
+import numpy as np
+import open3d as o3d
+import torch
+from plyfile import PlyData, PlyElement
+from simple_knn._C import distCUDA2
+from torch import nn
+
+from src.utils.gaussian_model_utils import (RGB2SH, build_scaling_rotation,
+ get_expon_lr_func, inverse_sigmoid,
+ strip_symmetric)
+
+
+class GaussianModel:
+ def __init__(self, sh_degree: int = 3, isotropic=False):
+ self.gaussian_param_names = [
+ "active_sh_degree",
+ "xyz",
+ "features_dc",
+ "features_rest",
+ "scaling",
+ "rotation",
+ "opacity",
+ "max_radii2D",
+ "xyz_gradient_accum",
+ "denom",
+ "spatial_lr_scale",
+ "optimizer",
+ ]
+ self.max_sh_degree = sh_degree
+ self.active_sh_degree = sh_degree # temp
+ self._xyz = torch.empty(0).cuda()
+ self._features_dc = torch.empty(0).cuda()
+ self._features_rest = torch.empty(0).cuda()
+ self._scaling = torch.empty(0).cuda()
+ self._rotation = torch.empty(0, 4).cuda()
+ self._opacity = torch.empty(0).cuda()
+ self.max_radii2D = torch.empty(0)
+ self.xyz_gradient_accum = torch.empty(0)
+ self.denom = torch.empty(0)
+ self.optimizer = None
+ self.percent_dense = 0
+ self.spatial_lr_scale = 1
+ self.setup_functions()
+ self.isotropic = isotropic
+
+ def restore_from_params(self, params_dict, training_args):
+ self.training_setup(training_args)
+ self.densification_postfix(
+ params_dict["xyz"],
+ params_dict["features_dc"],
+ params_dict["features_rest"],
+ params_dict["opacity"],
+ params_dict["scaling"],
+ params_dict["rotation"])
+
+ def build_covariance_from_scaling_rotation(self, scaling, scaling_modifier, rotation):
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
+ actual_covariance = L @ L.transpose(1, 2)
+ symm = strip_symmetric(actual_covariance)
+ return symm
+
+ def setup_functions(self):
+ self.scaling_activation = torch.exp
+ self.scaling_inverse_activation = torch.log
+ self.opacity_activation = torch.sigmoid
+ self.inverse_opacity_activation = inverse_sigmoid
+ self.rotation_activation = torch.nn.functional.normalize
+
+ def capture_dict(self):
+ return {
+ "active_sh_degree": self.active_sh_degree,
+ "xyz": self._xyz.clone().detach().cpu(),
+ "features_dc": self._features_dc.clone().detach().cpu(),
+ "features_rest": self._features_rest.clone().detach().cpu(),
+ "scaling": self._scaling.clone().detach().cpu(),
+ "rotation": self._rotation.clone().detach().cpu(),
+ "opacity": self._opacity.clone().detach().cpu(),
+ "max_radii2D": self.max_radii2D.clone().detach().cpu(),
+ "xyz_gradient_accum": self.xyz_gradient_accum.clone().detach().cpu(),
+ "denom": self.denom.clone().detach().cpu(),
+ "spatial_lr_scale": self.spatial_lr_scale,
+ "optimizer": self.optimizer.state_dict(),
+ }
+
+ def get_size(self):
+ return self._xyz.shape[0]
+
+ def get_scaling(self):
+ if self.isotropic:
+ scale = self.scaling_activation(self._scaling)[:, 0:1] # Extract the first column
+ scales = scale.repeat(1, 3) # Replicate this column three times
+ return scales
+ return self.scaling_activation(self._scaling)
+
+ def get_rotation(self):
+ return self.rotation_activation(self._rotation)
+
+ def get_xyz(self):
+ return self._xyz
+
+ def get_features(self):
+ features_dc = self._features_dc
+ features_rest = self._features_rest
+ return torch.cat((features_dc, features_rest), dim=1)
+
+ def get_opacity(self):
+ return self.opacity_activation(self._opacity)
+
+ def get_active_sh_degree(self):
+ return self.active_sh_degree
+
+ def get_covariance(self, scaling_modifier=1):
+ return self.build_covariance_from_scaling_rotation(self.get_scaling(), scaling_modifier, self._rotation)
+
+ def add_points(self, pcd: o3d.geometry.PointCloud, global_scale_init=True):
+ fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
+ fused_color = RGB2SH(torch.tensor(
+ np.asarray(pcd.colors)).float().cuda())
+ features = (torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda())
+ features[:, :3, 0] = fused_color
+ features[:, 3:, 1:] = 0.0
+ print("Number of added points: ", fused_point_cloud.shape[0])
+
+ if global_scale_init:
+ global_points = torch.cat((self.get_xyz(),torch.from_numpy(np.asarray(pcd.points)).float().cuda()))
+ dist2 = torch.clamp_min(distCUDA2(global_points), 0.0000001)
+ dist2 = dist2[self.get_size():]
+ else:
+ dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
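+        # Initialize log-scales from the square root of the kNN squared distances (an isotropic start: the same value along all three axes)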
+ scales = torch.log(1.0 * torch.sqrt(dist2))[..., None].repeat(1, 3)
+ # scales = torch.log(0.001 * torch.ones_like(dist2))[..., None].repeat(1, 3)
+ rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
+ rots[:, 0] = 1
+ opacities = inverse_sigmoid(0.5 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
+ new_xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
+ new_features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True))
+ new_features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True))
+ new_scaling = nn.Parameter(scales.requires_grad_(True))
+ new_rotation = nn.Parameter(rots.requires_grad_(True))
+ new_opacities = nn.Parameter(opacities.requires_grad_(True))
+ self.densification_postfix(
+ new_xyz,
+ new_features_dc,
+ new_features_rest,
+ new_opacities,
+ new_scaling,
+ new_rotation,
+ )
+
+ def training_setup(self, training_args):
+ self.percent_dense = training_args.percent_dense
+ self.xyz_gradient_accum = torch.zeros(
+ (self.get_xyz().shape[0], 1), device="cuda"
+ )
+ self.denom = torch.zeros((self.get_xyz().shape[0], 1), device="cuda")
+
+ params = [
+ {"params": [self._xyz], "lr": training_args.position_lr_init, "name": "xyz"},
+ {"params": [self._features_dc], "lr": training_args.feature_lr, "name": "f_dc"},
+ {"params": [self._features_rest], "lr": training_args.feature_lr / 20.0, "name": "f_rest"},
+ {"params": [self._opacity], "lr": training_args.opacity_lr, "name": "opacity"},
+ {"params": [self._scaling], "lr": training_args.scaling_lr, "name": "scaling"},
+ {"params": [self._rotation], "lr": training_args.rotation_lr, "name": "rotation"},
+ ]
+
+ self.optimizer = torch.optim.Adam(params, lr=0.0, eps=1e-15)
+ self.xyz_scheduler_args = get_expon_lr_func(
+ lr_init=training_args.position_lr_init * self.spatial_lr_scale,
+ lr_final=training_args.position_lr_final * self.spatial_lr_scale,
+ lr_delay_mult=training_args.position_lr_delay_mult,
+ max_steps=training_args.position_lr_max_steps,
+ )
+
+ def training_setup_camera(self, cam_rot, cam_trans, cfg):
+ self.xyz_gradient_accum = torch.zeros(
+ (self.get_xyz().shape[0], 1), device="cuda"
+ )
+ self.denom = torch.zeros((self.get_xyz().shape[0], 1), device="cuda")
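+        # During tracking all Gaussian parameters are frozen (lr = 0); only the camera rotation and translation are optimized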
+ params = [
+ {"params": [self._xyz], "lr": 0.0, "name": "xyz"},
+ {"params": [self._features_dc], "lr": 0.0, "name": "f_dc"},
+ {"params": [self._features_rest], "lr": 0.0, "name": "f_rest"},
+ {"params": [self._opacity], "lr": 0.0, "name": "opacity"},
+ {"params": [self._scaling], "lr": 0.0, "name": "scaling"},
+ {"params": [self._rotation], "lr": 0.0, "name": "rotation"},
+ {"params": [cam_rot], "lr": cfg["cam_rot_lr"],
+ "name": "cam_unnorm_rot"},
+ {"params": [cam_trans], "lr": cfg["cam_trans_lr"],
+ "name": "cam_trans"},
+ ]
+ self.optimizer = torch.optim.Adam(params, amsgrad=True)
+ self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+ self.optimizer, "min", factor=0.98, patience=10, verbose=False)
+
+ def construct_list_of_attributes(self):
+ l = ["x", "y", "z", "nx", "ny", "nz"]
+ # All channels except the 3 DC
+ for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]):
+ l.append("f_dc_{}".format(i))
+ for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]):
+ l.append("f_rest_{}".format(i))
+ l.append("opacity")
+ for i in range(self._scaling.shape[1]):
+ l.append("scale_{}".format(i))
+ for i in range(self._rotation.shape[1]):
+ l.append("rot_{}".format(i))
+ return l
+
+ def save_ply(self, path):
+ Path(path).parent.mkdir(parents=True, exist_ok=True)
+
+ xyz = self._xyz.detach().cpu().numpy()
+ normals = np.zeros_like(xyz)
+ f_dc = (
+ self._features_dc.detach()
+ .transpose(1, 2)
+ .flatten(start_dim=1)
+ .contiguous()
+ .cpu()
+ .numpy())
+ f_rest = (
+ self._features_rest.detach()
+ .transpose(1, 2)
+ .flatten(start_dim=1)
+ .contiguous()
+ .cpu()
+ .numpy())
+ opacities = self._opacity.detach().cpu().numpy()
+ if self.isotropic:
+ # tile into shape (P, 3)
+ scale = np.tile(self._scaling.detach().cpu().numpy()[:, 0].reshape(-1, 1), (1, 3))
+ else:
+ scale = self._scaling.detach().cpu().numpy()
+ rotation = self._rotation.detach().cpu().numpy()
+
+ dtype_full = [(attribute, "f4") for attribute in self.construct_list_of_attributes()]
+
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
+ elements[:] = list(map(tuple, attributes))
+ el = PlyElement.describe(elements, "vertex")
+ PlyData([el]).write(path)
+
+ def load_ply(self, path):
+ plydata = PlyData.read(path)
+
+ xyz = np.stack((
+ np.asarray(plydata.elements[0]["x"]),
+ np.asarray(plydata.elements[0]["y"]),
+ np.asarray(plydata.elements[0]["z"])),
+ axis=1)
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
+
+ features_dc = np.zeros((xyz.shape[0], 3, 1))
+ features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
+ features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
+ features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
+
+ extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
+ extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1]))
+ assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
+ for idx, attr_name in enumerate(extra_f_names):
+ features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
+ features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
+
+ scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
+ scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1]))
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
+ for idx, attr_name in enumerate(scale_names):
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+ rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
+ rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
+ for idx, attr_name in enumerate(rot_names):
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+ self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
+ self._features_dc = nn.Parameter(
+ torch.tensor(features_dc, dtype=torch.float, device="cuda")
+ .transpose(1, 2).contiguous().requires_grad_(True))
+ self._features_rest = nn.Parameter(
+ torch.tensor(features_extra, dtype=torch.float, device="cuda")
+ .transpose(1, 2)
+ .contiguous()
+ .requires_grad_(True)
+ )
+ self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
+ self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
+ self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
+
+ self.active_sh_degree = self.max_sh_degree
+
+ def replace_tensor_to_optimizer(self, tensor, name):
+ optimizable_tensors = {}
+ for group in self.optimizer.param_groups:
+ if group["name"] == name:
+ stored_state = self.optimizer.state.get(group["params"][0], None)
+ stored_state["exp_avg"] = torch.zeros_like(tensor)
+ stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
+
+ del self.optimizer.state[group["params"][0]]
+ group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
+ self.optimizer.state[group["params"][0]] = stored_state
+
+ optimizable_tensors[group["name"]] = group["params"][0]
+ return optimizable_tensors
+
+ def _prune_optimizer(self, mask):
+ optimizable_tensors = {}
+ for group in self.optimizer.param_groups:
+ stored_state = self.optimizer.state.get(group["params"][0], None)
+ if stored_state is not None:
+ stored_state["exp_avg"] = stored_state["exp_avg"][mask]
+ stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
+
+ del self.optimizer.state[group["params"][0]]
+ group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
+ self.optimizer.state[group["params"][0]] = stored_state
+ optimizable_tensors[group["name"]] = group["params"][0]
+ else:
+ group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
+ optimizable_tensors[group["name"]] = group["params"][0]
+ return optimizable_tensors
+
+ def prune_points(self, mask):
+ valid_points_mask = ~mask
+ optimizable_tensors = self._prune_optimizer(valid_points_mask)
+
+ self._xyz = optimizable_tensors["xyz"]
+ self._features_dc = optimizable_tensors["f_dc"]
+ self._features_rest = optimizable_tensors["f_rest"]
+ self._opacity = optimizable_tensors["opacity"]
+ self._scaling = optimizable_tensors["scaling"]
+ self._rotation = optimizable_tensors["rotation"]
+
+ self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
+
+ self.denom = self.denom[valid_points_mask]
+ self.max_radii2D = self.max_radii2D[valid_points_mask]
+
+ def cat_tensors_to_optimizer(self, tensors_dict):
+ optimizable_tensors = {}
+ for group in self.optimizer.param_groups:
+ assert len(group["params"]) == 1
+ extension_tensor = tensors_dict[group["name"]]
+ stored_state = self.optimizer.state.get(group["params"][0], None)
+ if stored_state is not None:
+ stored_state["exp_avg"] = torch.cat(
+ (stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
+ stored_state["exp_avg_sq"] = torch.cat(
+ (stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
+
+ del self.optimizer.state[group["params"][0]]
+ group["params"][0] = nn.Parameter(
+ torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
+ self.optimizer.state[group["params"][0]] = stored_state
+
+ optimizable_tensors[group["name"]] = group["params"][0]
+ else:
+ group["params"][0] = nn.Parameter(
+ torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
+ optimizable_tensors[group["name"]] = group["params"][0]
+
+ return optimizable_tensors
+
+ def densification_postfix(self, new_xyz, new_features_dc, new_features_rest,
+ new_opacities, new_scaling, new_rotation):
+ d = {
+ "xyz": new_xyz,
+ "f_dc": new_features_dc,
+ "f_rest": new_features_rest,
+ "opacity": new_opacities,
+ "scaling": new_scaling,
+ "rotation": new_rotation,
+ }
+
+ optimizable_tensors = self.cat_tensors_to_optimizer(d)
+ self._xyz = optimizable_tensors["xyz"]
+ self._features_dc = optimizable_tensors["f_dc"]
+ self._features_rest = optimizable_tensors["f_rest"]
+ self._opacity = optimizable_tensors["opacity"]
+ self._scaling = optimizable_tensors["scaling"]
+ self._rotation = optimizable_tensors["rotation"]
+
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz().shape[0], 1), device="cuda")
+ self.denom = torch.zeros((self.get_xyz().shape[0], 1), device="cuda")
+ self.max_radii2D = torch.zeros(
+ (self.get_xyz().shape[0]), device="cuda")
+
+ def add_densification_stats(self, viewspace_point_tensor, update_filter):
+ self.xyz_gradient_accum[update_filter] += torch.norm(
+ viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True)
+ self.denom[update_filter] += 1
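+
+# Illustrative note: the densification statistics above are intended to be accumulated
+# after a backward pass from the screen-space points tensor returned by the rasterizer,
+# roughly as follows (variable and key names are assumed, not part of this module):
+#
+#   render_pkg = render(gaussians, render_settings)
+#   loss.backward()
+#   gaussians.add_densification_stats(render_pkg["viewspace_points"], render_pkg["visibility_filter"])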
diff --git a/src/entities/gaussian_slam.py b/src/entities/gaussian_slam.py
new file mode 100644
index 0000000..974ff87
--- /dev/null
+++ b/src/entities/gaussian_slam.py
@@ -0,0 +1,159 @@
+""" This module includes the Gaussian-SLAM class, which is responsible for controlling Mapper and Tracker
+ It also decides when to start a new submap and when to update the estimated camera poses.
+"""
+import os
+import pprint
+from argparse import ArgumentParser
+from datetime import datetime
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from src.entities.arguments import OptimizationParams
+from src.entities.datasets import get_dataset
+from src.entities.gaussian_model import GaussianModel
+from src.entities.mapper import Mapper
+from src.entities.tracker import Tracker
+from src.entities.logger import Logger
+from src.utils.io_utils import save_dict_to_ckpt, save_dict_to_yaml
+from src.utils.mapper_utils import exceeds_motion_thresholds
+from src.utils.utils import np2torch, setup_seed, torch2np
+from src.utils.vis_utils import * # noqa - needed for debugging
+
+
+class GaussianSLAM(object):
+
+ def __init__(self, config: dict) -> None:
+
+ self._setup_output_path(config)
+ self.device = "cuda"
+ self.config = config
+
+ self.scene_name = config["data"]["scene_name"]
+ self.dataset_name = config["dataset_name"]
+ self.dataset = get_dataset(config["dataset_name"])({**config["data"], **config["cam"]})
+
+ n_frames = len(self.dataset)
+ frame_ids = list(range(n_frames))
+ self.mapping_frame_ids = frame_ids[::config["mapping"]["map_every"]] + [n_frames - 1]
+
+ self.estimated_c2ws = torch.empty(len(self.dataset), 4, 4)
+ self.estimated_c2ws[0] = torch.from_numpy(self.dataset[0][3])
+
+ save_dict_to_yaml(config, "config.yaml", directory=self.output_path)
+
+ self.submap_using_motion_heuristic = config["mapping"]["submap_using_motion_heuristic"]
+
+ self.keyframes_info = {}
+ self.opt = OptimizationParams(ArgumentParser(description="Training script parameters"))
+
+ if self.submap_using_motion_heuristic:
+ self.new_submap_frame_ids = [0]
+ else:
+ self.new_submap_frame_ids = frame_ids[::config["mapping"]["new_submap_every"]] + [n_frames - 1]
+ self.new_submap_frame_ids.pop(0)
+
+ self.logger = Logger(self.output_path, config["use_wandb"])
+ self.mapper = Mapper(config["mapping"], self.dataset, self.logger)
+ self.tracker = Tracker(config["tracking"], self.dataset, self.logger)
+
+ print('Tracking config')
+ pprint.PrettyPrinter().pprint(config["tracking"])
+ print('Mapping config')
+ pprint.PrettyPrinter().pprint(config["mapping"])
+
+ def _setup_output_path(self, config: dict) -> None:
+ """ Sets up the output path for saving results based on the provided configuration. If the output path is not
+ specified in the configuration, it creates a new directory with a timestamp.
+ Args:
+ config: A dictionary containing the experiment configuration including data and output path information.
+ """
+ if "output_path" not in config["data"]:
+ output_path = Path(config["data"]["output_path"])
+ self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ self.output_path = output_path / self.timestamp
+ else:
+ self.output_path = Path(config["data"]["output_path"])
+ self.output_path.mkdir(exist_ok=True, parents=True)
+ os.makedirs(self.output_path / "mapping_vis", exist_ok=True)
+ os.makedirs(self.output_path / "tracking_vis", exist_ok=True)
+
+ def should_start_new_submap(self, frame_id: int) -> bool:
+ """ Determines whether a new submap should be started based on the motion heuristic or specific frame IDs.
+ Args:
+ frame_id: The ID of the current frame being processed.
+ Returns:
+ A boolean indicating whether to start a new submap.
+ """
+ if self.submap_using_motion_heuristic:
+ if exceeds_motion_thresholds(
+ self.estimated_c2ws[frame_id], self.estimated_c2ws[self.new_submap_frame_ids[-1]],
+ rot_thre=50, trans_thre=0.5):
+ return True
+ elif frame_id in self.new_submap_frame_ids:
+ return True
+ return False
+
+ def start_new_submap(self, frame_id: int, gaussian_model: GaussianModel) -> GaussianModel:
+ """ Initializes a new submap, saving the current submap's checkpoint and resetting the Gaussian model.
+ This function updates the submap count and optionally marks the current frame ID for new submap initiation.
+ Args:
+ frame_id: The ID of the current frame at which the new submap is started.
+ gaussian_model: The current GaussianModel instance to capture and reset for the new submap.
+ Returns:
+ A new, reset GaussianModel instance for the new submap.
+ """
+ gaussian_params = gaussian_model.capture_dict()
+ submap_ckpt_name = str(self.submap_id).zfill(6)
+ submap_ckpt = {
+ "gaussian_params": gaussian_params,
+ "submap_keyframes": sorted(list(self.keyframes_info.keys()))
+ }
+ save_dict_to_ckpt(
+ submap_ckpt, f"{submap_ckpt_name}.ckpt", directory=self.output_path / "submaps")
+ gaussian_model = GaussianModel(0)
+ gaussian_model.training_setup(self.opt)
+ self.mapper.keyframes = []
+ self.keyframes_info = {}
+ if self.submap_using_motion_heuristic:
+ self.new_submap_frame_ids.append(frame_id)
+ self.mapping_frame_ids.append(frame_id)
+ self.submap_id += 1
+ return gaussian_model
+
+ def run(self) -> None:
+ """ Starts the main program flow for Gaussian-SLAM, including tracking and mapping. """
+ setup_seed(self.config["seed"])
+ gaussian_model = GaussianModel(0)
+ gaussian_model.training_setup(self.opt)
+ self.submap_id = 0
+
+ for frame_id in range(len(self.dataset)):
+
+ if frame_id in [0, 1]:
+ estimated_c2w = self.dataset[frame_id][-1]
+ else:
+ estimated_c2w = self.tracker.track(
+ frame_id, gaussian_model,
+ torch2np(self.estimated_c2ws[torch.tensor([0, frame_id - 2, frame_id - 1])]))
+ self.estimated_c2ws[frame_id] = np2torch(estimated_c2w)
+
+ # Reinitialize gaussian model for new segment
+ if self.should_start_new_submap(frame_id):
+ save_dict_to_ckpt(self.estimated_c2ws[:frame_id + 1], "estimated_c2w.ckpt", directory=self.output_path)
+ gaussian_model = self.start_new_submap(frame_id, gaussian_model)
+
+ if frame_id in self.mapping_frame_ids:
+ print("\nMapping frame", frame_id)
+ gaussian_model.training_setup(self.opt)
+ estimate_c2w = torch2np(self.estimated_c2ws[frame_id])
+ new_submap = not bool(self.keyframes_info)
+ opt_dict = self.mapper.map(frame_id, estimate_c2w, gaussian_model, new_submap)
+
+ # Keyframes info update
+ self.keyframes_info[frame_id] = {
+ "keyframe_id": len(self.keyframes_info.keys()),
+ "opt_dict": opt_dict
+ }
+ save_dict_to_ckpt(self.estimated_c2ws[:frame_id + 1], "estimated_c2w.ckpt", directory=self.output_path)
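+
+
+# Illustrative usage sketch (the config layout follows the constructor above; load_config
+# is a hypothetical YAML loader, not part of this module):
+#
+#   config = load_config("configs/Replica/room0.yaml")
+#   slam = GaussianSLAM(config)
+#   slam.run()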
diff --git a/src/entities/logger.py b/src/entities/logger.py
new file mode 100644
index 0000000..7ae2a6b
--- /dev/null
+++ b/src/entities/logger.py
@@ -0,0 +1,157 @@
+""" This module includes the Logger class, which is responsible for logging for both Mapper and the Tracker """
+from pathlib import Path
+from typing import Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import wandb
+
+
+class Logger(object):
+
+ def __init__(self, output_path: Union[Path, str], use_wandb=False) -> None:
+ self.output_path = Path(output_path)
+ (self.output_path / "mapping_vis").mkdir(exist_ok=True, parents=True)
+ self.use_wandb = use_wandb
+
+ def log_tracking_iteration(self, frame_id, cur_pose, gt_quat, gt_trans, total_loss,
+ color_loss, depth_loss, iter, num_iters,
+ wandb_output=False, print_output=False) -> None:
+ """ Logs tracking iteration metrics including pose error, losses, and optionally reports to Weights & Biases.
+ Logs the error between the current pose estimate and ground truth quaternion and translation,
+ as well as various loss metrics. Can output to wandb if enabled and specified, and print to console.
+ Args:
+ frame_id: Identifier for the current frame.
+ cur_pose: The current estimated pose as a tensor (quaternion + translation).
+ gt_quat: Ground truth quaternion.
+ gt_trans: Ground truth translation.
+ total_loss: Total computed loss for the current iteration.
+ color_loss: Computed color loss for the current iteration.
+ depth_loss: Computed depth loss for the current iteration.
+ iter: The current iteration number.
+ num_iters: The total number of iterations planned.
+ wandb_output: Whether to output the log to wandb.
+ print_output: Whether to print the log output.
+ """
+
+ quad_err = torch.abs(cur_pose[:4] - gt_quat).mean().item()
+ trans_err = torch.abs(cur_pose[4:] - gt_trans).mean().item()
+ if self.use_wandb and wandb_output:
+ wandb.log(
+ {
+ "Tracking/idx": frame_id,
+ "Tracking/cam_quad_err": quad_err,
+ "Tracking/cam_position_err": trans_err,
+ "Tracking/total_loss": total_loss.item(),
+ "Tracking/color_loss": color_loss.item(),
+ "Tracking/depth_loss": depth_loss.item(),
+ "Tracking/num_iters": num_iters,
+ })
+ if iter == num_iters - 1:
+ msg = f"frame_id: {frame_id}, cam_quad_err: {quad_err:.5f}, cam_trans_err: {trans_err:.5f} "
+ else:
+ msg = f"iter: {iter}, color_loss: {color_loss.item():.5f}, depth_loss: {depth_loss.item():.5f} "
+ msg = msg + f", cam_quad_err: {quad_err:.5f}, cam_trans_err: {trans_err:.5f}"
+ if print_output:
+ print(msg, flush=True)
+
+ def log_mapping_iteration(self, frame_id, new_pts_num, model_size, iter_opt_time, opt_dict: dict) -> None:
+ """ Logs mapping iteration metrics including the number of new points, model size, and optimization times,
+ and optionally reports to Weights & Biases (wandb).
+ Args:
+ frame_id: Identifier for the current frame.
+ new_pts_num: The number of new points added in the current mapping iteration.
+ model_size: The total size of the model after the current mapping iteration.
+ iter_opt_time: Time taken per optimization iteration.
+ opt_dict: A dictionary containing optimization metrics such as PSNR, color loss, and depth loss.
+ """
+ if self.use_wandb:
+ wandb.log({"Mapping/idx": frame_id,
+ "Mapping/num_total_gs": model_size,
+ "Mapping/num_new_gs": new_pts_num,
+ "Mapping/per_iteration_time": iter_opt_time,
+ "Mapping/psnr_render": opt_dict["psnr_render"],
+ "Mapping/color_loss": opt_dict[frame_id]["color_loss"],
+ "Mapping/depth_loss": opt_dict[frame_id]["depth_loss"]})
+
+ def vis_mapping_iteration(self, frame_id, iter, color, depth, gt_color, gt_depth, seeding_mask=None) -> None:
+ """
+ Visualizes the rendered and ground-truth color/depth images and saves the figure to file.
+
+ Args:
+ frame_id (int): current frame index.
+ iter (int): the iteration number.
+ color, depth: rendered color and depth tensors.
+ gt_color, gt_depth: ground-truth color and depth tensors.
+ seeding_mask: optional mask used in the mapper when adding Gaussians; if given, it is shown in place of the RGB residual.
+ """
+ gt_depth_np = gt_depth.cpu().numpy()
+ gt_color_np = gt_color.cpu().numpy()
+
+ depth_np = depth.detach().cpu().numpy()
+ color = torch.round(color * 255.0) / 255.0
+ color_np = color.detach().cpu().numpy()
+ depth_residual = np.abs(gt_depth_np - depth_np)
+ depth_residual[gt_depth_np == 0.0] = 0.0
+ # clip the residual so that errors of 5 cm or more saturate the colormap
+ depth_residual = np.clip(depth_residual, 0.0, 0.05)
+
+ color_residual = np.abs(gt_color_np - color_np)
+ color_residual[np.squeeze(gt_depth_np == 0.0)] = 0.0
+
+ # Determine Aspect Ratio and Figure Size
+ aspect_ratio = color.shape[1] / color.shape[0]
+ fig_height = 8
+ # Adjust the multiplier as needed for better spacing
+ fig_width = fig_height * aspect_ratio * 1.2
+
+ fig, axs = plt.subplots(2, 3, figsize=(fig_width, fig_height))
+ axs[0, 0].imshow(gt_depth_np, cmap="jet", vmin=0, vmax=6)
+ axs[0, 0].set_title('Input Depth', fontsize=16)
+ axs[0, 0].set_xticks([])
+ axs[0, 0].set_yticks([])
+ axs[0, 1].imshow(depth_np, cmap="jet", vmin=0, vmax=6)
+ axs[0, 1].set_title('Rendered Depth', fontsize=16)
+ axs[0, 1].set_xticks([])
+ axs[0, 1].set_yticks([])
+ axs[0, 2].imshow(depth_residual, cmap="plasma")
+ axs[0, 2].set_title('Depth Residual', fontsize=16)
+ axs[0, 2].set_xticks([])
+ axs[0, 2].set_yticks([])
+ gt_color_np = np.clip(gt_color_np, 0, 1)
+ color_np = np.clip(color_np, 0, 1)
+ color_residual = np.clip(color_residual, 0, 1)
+ axs[1, 0].imshow(gt_color_np, cmap="plasma")
+ axs[1, 0].set_title('Input RGB', fontsize=16)
+ axs[1, 0].set_xticks([])
+ axs[1, 0].set_yticks([])
+ axs[1, 1].imshow(color_np, cmap="plasma")
+ axs[1, 1].set_title('Rendered RGB', fontsize=16)
+ axs[1, 1].set_xticks([])
+ axs[1, 1].set_yticks([])
+ if seeding_mask is not None:
+ axs[1, 2].imshow(seeding_mask, cmap="gray")
+ axs[1, 2].set_title('Densification Mask', fontsize=16)
+ axs[1, 2].set_xticks([])
+ axs[1, 2].set_yticks([])
+ else:
+ axs[1, 2].imshow(color_residual, cmap="plasma")
+ axs[1, 2].set_title('RGB Residual', fontsize=16)
+ axs[1, 2].set_xticks([])
+ axs[1, 2].set_yticks([])
+
+ for ax in axs.flatten():
+ ax.axis('off')
+ fig.tight_layout()
+ plt.subplots_adjust(top=0.90) # Adjust top margin
+ fig_name = str(self.output_path / "mapping_vis" / f'{frame_id:04d}_{iter:04d}.jpg')
+ fig_title = f"Mapper Color/Depth at frame {frame_id:04d} iters {iter:04d}"
+ plt.suptitle(fig_title, y=0.98, fontsize=20)
+ plt.savefig(fig_name, dpi=250, bbox_inches='tight')
+ plt.clf()
+ plt.close()
+ if self.use_wandb:
+ log_title = "Mapping_vis/" + f'{frame_id:04d}_{iter:04d}'
+ wandb.log({log_title: [wandb.Image(fig_name)]})
+ print(f"Saved rendering vis of color/depth at {frame_id:04d}_{iter:04d}.jpg")
diff --git a/src/entities/losses.py b/src/entities/losses.py
new file mode 100644
index 0000000..070ac3f
--- /dev/null
+++ b/src/entities/losses.py
@@ -0,0 +1,140 @@
+from math import exp
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+
+def l1_loss(network_output: torch.Tensor, gt: torch.Tensor, agg="mean") -> torch.Tensor:
+ """
+ Computes the L1 loss, which is the mean absolute error between the network output and the ground truth.
+
+ Args:
+ network_output: The output from the network.
+ gt: The ground truth tensor.
+ agg: The aggregation method to be used. Defaults to "mean".
+ Returns:
+ The computed L1 loss.
+ """
+ l1_loss = torch.abs(network_output - gt)
+ if agg == "mean":
+ return l1_loss.mean()
+ elif agg == "sum":
+ return l1_loss.sum()
+ elif agg == "none":
+ return l1_loss
+ else:
+ raise ValueError("Invalid aggregation method.")
+
+
+def gaussian(window_size: int, sigma: float) -> torch.Tensor:
+ """
+ Creates a 1D Gaussian kernel.
+
+ Args:
+ window_size: The size of the window for the Gaussian kernel.
+ sigma: The standard deviation of the Gaussian kernel.
+
+ Returns:
+ The 1D Gaussian kernel.
+ """
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 /
+ float(2 * sigma ** 2)) for x in range(window_size)])
+ return gauss / gauss.sum()
+
+
+def create_window(window_size: int, channel: int) -> Variable:
+ """
+ Creates a 2D Gaussian window/kernel for SSIM computation.
+
+ Args:
+ window_size: The size of the window to be created.
+ channel: The number of channels in the image.
+
+ Returns:
+ A 2D Gaussian window expanded to match the number of channels.
+ """
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+ _2D_window = _1D_window.mm(
+ _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+ window = Variable(_2D_window.expand(
+ channel, 1, window_size, window_size).contiguous())
+ return window
+
+
+def ssim(img1: torch.Tensor, img2: torch.Tensor, window_size: int = 11, size_average: bool = True) -> torch.Tensor:
+ """
+ Computes the Structural Similarity Index (SSIM) between two images.
+
+ Args:
+ img1: The first image.
+ img2: The second image.
+ window_size: The size of the window to be used in SSIM computation. Defaults to 11.
+ size_average: If True, averages the SSIM over all pixels. Defaults to True.
+
+ Returns:
+ The computed SSIM value.
+ """
+ channel = img1.size(-3)
+ window = create_window(window_size, channel)
+
+ if img1.is_cuda:
+ window = window.cuda(img1.get_device())
+ window = window.type_as(img1)
+
+ return _ssim(img1, img2, window, window_size, channel, size_average)
+
+
+def _ssim(img1: torch.Tensor, img2: torch.Tensor, window: Variable, window_size: int,
+ channel: int, size_average: bool = True) -> torch.Tensor:
+ """
+ Internal function to compute the Structural Similarity Index (SSIM) between two images.
+
+ Args:
+ img1: The first image.
+ img2: The second image.
+ window: The Gaussian window/kernel for SSIM computation.
+ window_size: The size of the window to be used in SSIM computation.
+ channel: The number of channels in the image.
+ size_average: If True, averages the SSIM over all pixels.
+
+ Returns:
+ The computed SSIM value.
+ """
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+
+ sigma1_sq = F.conv2d(img1 * img1, window,
+ padding=window_size // 2, groups=channel) - mu1_sq
+ sigma2_sq = F.conv2d(img2 * img2, window,
+ padding=window_size // 2, groups=channel) - mu2_sq
+ sigma12 = F.conv2d(img1 * img2, window,
+ padding=window_size // 2, groups=channel) - mu1_mu2
+
+ C1 = 0.01 ** 2
+ C2 = 0.03 ** 2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
+ ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+ if size_average:
+ return ssim_map.mean()
+ else:
+ return ssim_map.mean(1).mean(1).mean(1)
+
+
+def isotropic_loss(scaling: torch.Tensor) -> torch.Tensor:
+ """
+ Computes loss enforcing isotropic scaling for the 3D Gaussians
+ Args:
+ scaling: scaling tensor of 3D Gaussians of shape (n, 3)
+ Returns:
+ The computed isotropic loss
+ """
+ mean_scaling = scaling.mean(dim=1, keepdim=True)
+ isotropic_diff = torch.abs(scaling - mean_scaling * torch.ones_like(scaling))
+ return isotropic_diff.mean()
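+
+
+if __name__ == "__main__":
+ # Minimal sanity-check sketch (illustrative only, not used by the SLAM pipeline):
+ # perfectly isotropic scalings give zero isotropic loss, anisotropic ones do not,
+ # and SSIM / L1 of an image with itself are 1 and 0 respectively.
+ iso = torch.full((4, 3), 2.0)
+ aniso = torch.tensor([[1.0, 2.0, 3.0]]) # mean is 2 -> deviations average to 2/3
+ print(isotropic_loss(iso).item(), isotropic_loss(aniso).item())
+ img = torch.rand(1, 3, 32, 32)
+ print(ssim(img, img).item(), l1_loss(img, img).item())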
diff --git a/src/entities/mapper.py b/src/entities/mapper.py
new file mode 100644
index 0000000..23fa990
--- /dev/null
+++ b/src/entities/mapper.py
@@ -0,0 +1,260 @@
+""" This module includes the Mapper class, which is responsible scene mapping: Paragraph 3.2 """
+import time
+from argparse import ArgumentParser
+
+import numpy as np
+import torch
+import torchvision
+
+from src.entities.arguments import OptimizationParams
+from src.entities.datasets import TUM_RGBD, BaseDataset, ScanNet
+from src.entities.gaussian_model import GaussianModel
+from src.entities.logger import Logger
+from src.entities.losses import isotropic_loss, l1_loss, ssim
+from src.utils.mapper_utils import (calc_psnr, compute_camera_frustum_corners,
+ compute_frustum_point_ids,
+ compute_new_points_ids,
+ compute_opt_views_distribution,
+ create_point_cloud, geometric_edge_mask,
+ sample_pixels_based_on_gradient)
+from src.utils.utils import (get_render_settings, np2ptcloud, np2torch,
+ render_gaussian_model, torch2np)
+from src.utils.vis_utils import * # noqa - needed for debugging
+
+
+class Mapper(object):
+ def __init__(self, config: dict, dataset: BaseDataset, logger: Logger) -> None:
+ """ Sets up the mapper parameters
+ Args:
+ config: configuration of the mapper
+ dataset: The dataset object used for extracting camera parameters and reading the data
+ logger: The logger object used for logging the mapping process and saving visualizations
+ """
+ self.config = config
+ self.logger = logger
+ self.dataset = dataset
+ self.iterations = config["iterations"]
+ self.new_submap_iterations = config["new_submap_iterations"]
+ self.new_submap_points_num = config["new_submap_points_num"]
+ self.new_submap_gradient_points_num = config["new_submap_gradient_points_num"]
+ self.new_frame_sample_size = config["new_frame_sample_size"]
+ self.new_points_radius = config["new_points_radius"]
+ self.alpha_thre = config["alpha_thre"]
+ self.pruning_thre = config["pruning_thre"]
+ self.current_view_opt_iterations = config["current_view_opt_iterations"]
+ self.opt = OptimizationParams(ArgumentParser(description="Training script parameters"))
+ self.keyframes = []
+
+ def compute_seeding_mask(self, gaussian_model: GaussianModel, keyframe: dict, new_submap: bool) -> np.ndarray:
+ """
+ Computes a binary mask to identify regions within a keyframe where new Gaussian models should be seeded
+ based on alpha masks or color gradient
+ Args:
+ gaussian_model: The current submap
+ keyframe (dict): Keyframe dict containing color, depth, and render settings
+ new_submap (bool): A boolean indicating whether the seeding occurs in the current submap or in a new submap
+ Returns:
+ np.ndarray: A binary mask of shape (H, W) indicating regions suitable for seeding new 3D Gaussians
+ """
+ seeding_mask = None
+ if new_submap:
+ color_for_mask = (torch2np(keyframe["color"].permute(1, 2, 0)) * 255).astype(np.uint8)
+ seeding_mask = geometric_edge_mask(color_for_mask, RGB=True)
+ else:
+ render_dict = render_gaussian_model(gaussian_model, keyframe["render_settings"])
+ alpha_mask = (render_dict["alpha"] < self.alpha_thre)
+ gt_depth_tensor = keyframe["depth"][None]
+ depth_error = torch.abs(gt_depth_tensor - render_dict["depth"]) * (gt_depth_tensor > 0)
+ depth_error_mask = (render_dict["depth"] > gt_depth_tensor) * (depth_error > 40 * depth_error.median())
+ seeding_mask = alpha_mask | depth_error_mask
+ seeding_mask = torch2np(seeding_mask[0])
+ return seeding_mask
+
+ def seed_new_gaussians(self, gt_color: np.ndarray, gt_depth: np.ndarray, intrinsics: np.ndarray,
+ estimate_c2w: np.ndarray, seeding_mask: np.ndarray, is_new_submap: bool) -> np.ndarray:
+ """
+ Seeds means for the new 3D Gaussian based on ground truth color and depth, camera intrinsics,
+ estimated camera-to-world transformation, a seeding mask, and a flag indicating whether this is a new submap.
+ Args:
+ gt_color: The ground truth color image as a numpy array with shape (H, W, 3).
+ gt_depth: The ground truth depth map as a numpy array with shape (H, W).
+ intrinsics: The camera intrinsics matrix as a numpy array with shape (3, 3).
+ estimate_c2w: The estimated camera-to-world transformation matrix as a numpy array with shape (4, 4).
+ seeding_mask: A binary mask indicating where to seed new Gaussians, with shape (H, W).
+ is_new_submap: Flag indicating whether the seeding is for a new submap (True) or an existing submap (False).
+ Returns:
+ np.ndarray: An array of shape (N, 6) containing the 3D positions and RGB colors of the points where new Gaussians will be initialized
+
+ """
+ pts = create_point_cloud(gt_color, 1.005 * gt_depth, intrinsics, estimate_c2w)
+ flat_gt_depth = gt_depth.flatten()
+ non_zero_depth_mask = flat_gt_depth > 0. # used to filter out pixels with invalid (zero) depth
+ valid_ids = np.flatnonzero(seeding_mask)
+ if is_new_submap:
+ if self.new_submap_points_num < 0:
+ uniform_ids = np.arange(pts.shape[0])
+ else:
+ uniform_ids = np.random.choice(pts.shape[0], self.new_submap_points_num, replace=False)
+ gradient_ids = sample_pixels_based_on_gradient(gt_color, self.new_submap_gradient_points_num)
+ combined_ids = np.concatenate((uniform_ids, gradient_ids))
+ combined_ids = np.concatenate((combined_ids, valid_ids))
+ sample_ids = np.unique(combined_ids)
+ else:
+ if self.new_frame_sample_size < 0 or len(valid_ids) < self.new_frame_sample_size:
+ sample_ids = valid_ids
+ else:
+ sample_ids = np.random.choice(valid_ids, size=self.new_frame_sample_size, replace=False)
+ sample_ids = sample_ids[non_zero_depth_mask[sample_ids]]
+ return pts[sample_ids, :].astype(np.float32)
+
+ def optimize_submap(self, keyframes: list, gaussian_model: GaussianModel, iterations: int = 100) -> dict:
+ """
+ Optimizes the submap by refining the parameters of the 3D Gaussian based on the observations
+ from keyframes observing the submap.
+ Args:
+ keyframes: A list of tuples consisting of frame id and keyframe dictionary
+ gaussian_model: An instance of the GaussianModel class representing the initial state
+ of the Gaussian model to be optimized.
+ iterations: The number of iterations to perform the optimization process. Defaults to 100.
+ Returns:
+ losses_dict: Dictionary with the optimization statistics
+ """
+
+ iteration = 0
+ losses_dict = {}
+
+ current_frame_iters = self.current_view_opt_iterations * iterations
+ distribution = compute_opt_views_distribution(len(keyframes), iterations, current_frame_iters)
+ start_time = time.time()
+ while iteration < iterations + 1:
+ gaussian_model.optimizer.zero_grad(set_to_none=True)
+ keyframe_id = np.random.choice(np.arange(len(keyframes)), p=distribution)
+
+ frame_id, keyframe = keyframes[keyframe_id]
+ render_pkg = render_gaussian_model(gaussian_model, keyframe["render_settings"])
+
+ image, depth = render_pkg["color"], render_pkg["depth"]
+ gt_image = keyframe["color"]
+ gt_depth = keyframe["depth"]
+
+ mask = (gt_depth > 0) & (~torch.isnan(depth)).squeeze(0)
+ color_loss = (1.0 - self.opt.lambda_dssim) * l1_loss(
+ image[:, mask], gt_image[:, mask]) + self.opt.lambda_dssim * (1.0 - ssim(image, gt_image))
+
+ depth_loss = l1_loss(depth[:, mask], gt_depth[mask])
+ reg_loss = isotropic_loss(gaussian_model.get_scaling())
+ total_loss = color_loss + depth_loss + reg_loss
+ total_loss.backward()
+
+ losses_dict[frame_id] = {"color_loss": color_loss.item(),
+ "depth_loss": depth_loss.item(),
+ "total_loss": total_loss.item()}
+
+ with torch.no_grad():
+
+ if iteration == iterations // 2 or iteration == iterations:
+ prune_mask = (gaussian_model.get_opacity()
+ < self.pruning_thre).squeeze()
+ gaussian_model.prune_points(prune_mask)
+
+ # Optimizer step
+ if iteration < iterations:
+ gaussian_model.optimizer.step()
+ gaussian_model.optimizer.zero_grad(set_to_none=True)
+
+ iteration += 1
+ optimization_time = time.time() - start_time
+ losses_dict["optimization_time"] = optimization_time
+ losses_dict["optimization_iter_time"] = optimization_time / iterations
+ return losses_dict
+
+ def grow_submap(self, gt_depth: np.ndarray, estimate_c2w: np.ndarray, gaussian_model: GaussianModel,
+ pts: np.ndarray, filter_cloud: bool) -> int:
+ """
+ Expands the submap by integrating new points from the current keyframe
+ Args:
+ gt_depth: The ground truth depth map for the current keyframe, as a 2D numpy array.
+ estimate_c2w: The estimated camera-to-world transformation matrix for the current keyframe of shape (4x4)
+ gaussian_model (GaussianModel): The Gaussian model representing the current state of the submap.
+ pts: The current set of 3D points in the keyframe of shape (N, 3)
+ filter_cloud: A boolean flag indicating whether to apply filtering to the point cloud to remove
+ outliers or noise before integrating it into the map.
+ Returns:
+ int: The number of points added to the submap
+ """
+ gaussian_points = gaussian_model.get_xyz()
+ camera_frustum_corners = compute_camera_frustum_corners(gt_depth, estimate_c2w, self.dataset.intrinsics)
+ reused_pts_ids = compute_frustum_point_ids(
+ gaussian_points, np2torch(camera_frustum_corners), device="cuda")
+ new_pts_ids = compute_new_points_ids(gaussian_points[reused_pts_ids], np2torch(pts[:, :3]).contiguous(),
+ radius=self.new_points_radius, device="cuda")
+ new_pts_ids = torch2np(new_pts_ids)
+ if new_pts_ids.shape[0] > 0:
+ cloud_to_add = np2ptcloud(pts[new_pts_ids, :3], pts[new_pts_ids, 3:] / 255.0)
+ if filter_cloud:
+ cloud_to_add, _ = cloud_to_add.remove_statistical_outlier(nb_neighbors=40, std_ratio=2.0)
+ gaussian_model.add_points(cloud_to_add)
+ gaussian_model._features_dc.requires_grad = False
+ gaussian_model._features_rest.requires_grad = False
+ print("Gaussian model size", gaussian_model.get_size())
+ return new_pts_ids.shape[0]
+
+ def map(self, frame_id: int, estimate_c2w: np.ndarray, gaussian_model: GaussianModel, is_new_submap: bool) -> dict:
+ """ Calls out the mapping process described in paragraph 3.2
+ The process goes as follows: seed new gaussians -> add to the submap -> optimize the submap
+ Args:
+ frame_id: current keyframe id
+ estimate_c2w (np.ndarray): The estimated camera-to-world transformation matrix of shape (4x4)
+ gaussian_model (GaussianModel): The current Gaussian model of the submap
+ is_new_submap (bool): A boolean flag indicating whether the current frame initiates a new submap
+ Returns:
+ opt_dict: Dictionary with statistics about the optimization process
+ """
+
+ _, gt_color, gt_depth, _ = self.dataset[frame_id]
+ estimate_w2c = np.linalg.inv(estimate_c2w)
+
+ color_transform = torchvision.transforms.ToTensor()
+ keyframe = {
+ "color": color_transform(gt_color).cuda(),
+ "depth": np2torch(gt_depth, device="cuda"),
+ "render_settings": get_render_settings(
+ self.dataset.width, self.dataset.height, self.dataset.intrinsics, estimate_w2c)}
+
+ seeding_mask = self.compute_seeding_mask(gaussian_model, keyframe, is_new_submap)
+ pts = self.seed_new_gaussians(
+ gt_color, gt_depth, self.dataset.intrinsics, estimate_c2w, seeding_mask, is_new_submap)
+
+ filter_cloud = isinstance(self.dataset, (TUM_RGBD, ScanNet)) and not is_new_submap
+
+ new_pts_num = self.grow_submap(gt_depth, estimate_c2w, gaussian_model, pts, filter_cloud)
+
+ max_iterations = self.iterations
+ if is_new_submap:
+ max_iterations = self.new_submap_iterations
+ start_time = time.time()
+ opt_dict = self.optimize_submap([(frame_id, keyframe)] + self.keyframes, gaussian_model, max_iterations)
+ optimization_time = time.time() - start_time
+ print("Optimization time: ", optimization_time)
+
+ self.keyframes.append((frame_id, keyframe))
+
+ # Visualise the mapping for the current frame
+ with torch.no_grad():
+ render_pkg_vis = render_gaussian_model(gaussian_model, keyframe["render_settings"])
+ image_vis, depth_vis = render_pkg_vis["color"], render_pkg_vis["depth"]
+ psnr_value = calc_psnr(image_vis, keyframe["color"]).mean().item()
+ opt_dict["psnr_render"] = psnr_value
+ print(f"PSNR this frame: {psnr_value}")
+ self.logger.vis_mapping_iteration(
+ frame_id, max_iterations,
+ image_vis.clone().detach().permute(1, 2, 0),
+ depth_vis.clone().detach().permute(1, 2, 0),
+ keyframe["color"].permute(1, 2, 0),
+ keyframe["depth"].unsqueeze(-1),
+ seeding_mask=seeding_mask)
+
+ # Log the mapping numbers for the current frame
+ self.logger.log_mapping_iteration(frame_id, new_pts_num, gaussian_model.get_size(),
+ optimization_time/max_iterations, opt_dict)
+ return opt_dict
diff --git a/src/entities/tracker.py b/src/entities/tracker.py
new file mode 100644
index 0000000..7c75c13
--- /dev/null
+++ b/src/entities/tracker.py
@@ -0,0 +1,215 @@
+""" This module includes the Mapper class, which is responsible scene mapping: Paper Section 3.4 """
+from argparse import ArgumentParser
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision
+from scipy.spatial.transform import Rotation as R
+
+from src.entities.arguments import OptimizationParams
+from src.entities.losses import l1_loss
+from src.entities.gaussian_model import GaussianModel
+from src.entities.logger import Logger
+from src.entities.datasets import BaseDataset
+from src.entities.visual_odometer import VisualOdometer
+from src.utils.gaussian_model_utils import build_rotation
+from src.utils.tracker_utils import (compute_camera_opt_params,
+ interpolate_poses, multiply_quaternions,
+ transformation_to_quaternion)
+from src.utils.utils import (get_render_settings, np2torch,
+ render_gaussian_model, torch2np)
+
+
+class Tracker(object):
+ def __init__(self, config: dict, dataset: BaseDataset, logger: Logger) -> None:
+ """ Initializes the Tracker with a given configuration, dataset, and logger.
+ Args:
+ config: Configuration dictionary specifying hyperparameters and operational settings.
+ dataset: The dataset object providing access to the sequence of frames.
+ logger: Logger object for logging the tracking process.
+ """
+ self.dataset = dataset
+ self.logger = logger
+ self.config = config
+ self.filter_alpha = self.config["filter_alpha"]
+ self.filter_outlier_depth = self.config["filter_outlier_depth"]
+ self.alpha_thre = self.config["alpha_thre"]
+ self.soft_alpha = self.config["soft_alpha"]
+ self.mask_invalid_depth_in_color_loss = self.config["mask_invalid_depth"]
+ self.w_color_loss = self.config["w_color_loss"]
+ self.transform = torchvision.transforms.ToTensor()
+ self.opt = OptimizationParams(ArgumentParser(description="Training script parameters"))
+ self.frame_depth_loss = []
+ self.frame_color_loss = []
+ self.odometry_type = self.config["odometry_type"]
+ self.help_camera_initialization = self.config["help_camera_initialization"]
+ self.init_err_ratio = self.config["init_err_ratio"]
+ self.odometer = VisualOdometer(self.dataset.intrinsics, self.config["odometer_method"])
+
+ def compute_losses(self, gaussian_model: GaussianModel, render_settings: dict,
+ opt_cam_rot: torch.Tensor, opt_cam_trans: torch.Tensor,
+ gt_color: torch.Tensor, gt_depth: torch.Tensor, depth_mask: torch.Tensor) -> tuple:
+ """ Computes the tracking losses with respect to ground truth color and depth.
+ Args:
+ gaussian_model: The current state of the Gaussian model of the scene.
+ render_settings: Dictionary containing rendering settings such as image dimensions and camera intrinsics.
+ opt_cam_rot: Optimizable tensor representing the camera's rotation.
+ opt_cam_trans: Optimizable tensor representing the camera's translation.
+ gt_color: Ground truth color image tensor.
+ gt_depth: Ground truth depth image tensor.
+ depth_mask: Binary mask indicating valid depth values in the ground truth depth image.
+ Returns:
+ A tuple containing losses and renders
+ """
+ rel_transform = torch.eye(4).cuda().float()
+ rel_transform[:3, :3] = build_rotation(F.normalize(opt_cam_rot[None]))[0]
+ rel_transform[:3, 3] = opt_cam_trans
+
+ pts = gaussian_model.get_xyz()
+ pts_ones = torch.ones(pts.shape[0], 1).cuda().float()
+ pts4 = torch.cat((pts, pts_ones), dim=1)
+ transformed_pts = (rel_transform @ pts4.T).T[:, :3]
+
+ quat = F.normalize(opt_cam_rot[None])
+ _rotations = multiply_quaternions(gaussian_model.get_rotation(), quat.unsqueeze(0)).squeeze(0)
+
+ render_dict = render_gaussian_model(gaussian_model, render_settings,
+ override_means_3d=transformed_pts, override_rotations=_rotations)
+ rendered_color, rendered_depth = render_dict["color"], render_dict["depth"]
+ alpha_mask = render_dict["alpha"] > self.alpha_thre
+
+ tracking_mask = torch.ones_like(alpha_mask).bool()
+ tracking_mask &= depth_mask
+ depth_err = torch.abs(rendered_depth - gt_depth) * depth_mask
+
+ if self.filter_alpha:
+ tracking_mask &= alpha_mask
+ if self.filter_outlier_depth and torch.median(depth_err) > 0:
+ tracking_mask &= depth_err < 50 * torch.median(depth_err)
+
+ color_loss = l1_loss(rendered_color, gt_color, agg="none")
+ depth_loss = l1_loss(rendered_depth, gt_depth, agg="none") * tracking_mask
+
+ if self.soft_alpha:
+ alpha = render_dict["alpha"] ** 3
+ color_loss *= alpha
+ depth_loss *= alpha
+ if self.mask_invalid_depth_in_color_loss:
+ color_loss *= tracking_mask
+ else:
+ color_loss *= tracking_mask
+
+ color_loss = color_loss.sum()
+ depth_loss = depth_loss.sum()
+
+ return color_loss, depth_loss, rendered_color, rendered_depth, alpha_mask
+
+ def track(self, frame_id: int, gaussian_model: GaussianModel, prev_c2ws: np.ndarray) -> np.ndarray:
+ """
+ Estimates the camera pose for the current frame from the provided image and depth, either taking the ground-truth pose directly
+ or refining an initialization obtained from a constant-speed assumption or visual odometry.
+ Args:
+ frame_id: Index of the current frame being processed.
+ gaussian_model: The current Gaussian model of the scene.
+ prev_c2ws: Array containing the camera-to-world transformation matrices for the frames (0, i - 2, i - 1)
+ Returns:
+ The updated camera-to-world transformation matrix for the current frame.
+ """
+ _, image, depth, gt_c2w = self.dataset[frame_id]
+
+ if (self.help_camera_initialization or self.odometry_type == "odometer") and self.odometer.last_rgbd is None:
+ _, last_image, last_depth, _ = self.dataset[frame_id - 1]
+ self.odometer.update_last_rgbd(last_image, last_depth)
+
+ if self.odometry_type == "gt":
+ return gt_c2w
+ elif self.odometry_type == "const_speed":
+ init_c2w = interpolate_poses(prev_c2ws[1:])
+ elif self.odometry_type == "odometer":
+ odometer_rel = self.odometer.estimate_rel_pose(image, depth)
+ init_c2w = prev_c2ws[-1] @ odometer_rel
+
+ last_c2w = prev_c2ws[-1]
+ last_w2c = np.linalg.inv(last_c2w)
+ init_rel = init_c2w @ np.linalg.inv(last_c2w)
+ init_rel_w2c = np.linalg.inv(init_rel)
+ reference_w2c = last_w2c
+ render_settings = get_render_settings(
+ self.dataset.width, self.dataset.height, self.dataset.intrinsics, reference_w2c)
+ opt_cam_rot, opt_cam_trans = compute_camera_opt_params(init_rel_w2c)
+ gaussian_model.training_setup_camera(opt_cam_rot, opt_cam_trans, self.config)
+
+ gt_color = self.transform(image).cuda()
+ gt_depth = np2torch(depth, "cuda")
+ depth_mask = gt_depth > 0.0
+ gt_trans = np2torch(gt_c2w[:3, 3])
+ gt_quat = np2torch(R.from_matrix(gt_c2w[:3, :3]).as_quat(canonical=True)[[3, 0, 1, 2]])
+ num_iters = self.config["iterations"]
+ current_min_loss = float("inf")
+
+ print(f"\nTracking frame {frame_id}")
+ # Initial loss check
+ color_loss, depth_loss, _, _, _ = self.compute_losses(gaussian_model, render_settings, opt_cam_rot,
+ opt_cam_trans, gt_color, gt_depth, depth_mask)
+ if len(self.frame_color_loss) > 0 and (
+ color_loss.item() > self.init_err_ratio * np.median(self.frame_color_loss)
+ or depth_loss.item() > self.init_err_ratio * np.median(self.frame_depth_loss)
+ ):
+ num_iters *= 2
+ print(f"Higher initial loss, increasing num_iters to {num_iters}")
+ if self.help_camera_initialization and self.odometry_type != "odometer":
+ _, last_image, last_depth, _ = self.dataset[frame_id - 1]
+ self.odometer.update_last_rgbd(last_image, last_depth)
+ odometer_rel = self.odometer.estimate_rel_pose(image, depth)
+ init_c2w = last_c2w @ odometer_rel
+ init_rel = init_c2w @ np.linalg.inv(last_c2w)
+ init_rel_w2c = np.linalg.inv(init_rel)
+ opt_cam_rot, opt_cam_trans = compute_camera_opt_params(init_rel_w2c)
+ gaussian_model.training_setup_camera(opt_cam_rot, opt_cam_trans, self.config)
+ render_settings = get_render_settings(
+ self.dataset.width, self.dataset.height, self.dataset.intrinsics, last_w2c)
+ print(f"re-init with odometer for frame {frame_id}")
+
+ for iter in range(num_iters):
+ color_loss, depth_loss, _, _, _ = self.compute_losses(
+ gaussian_model, render_settings, opt_cam_rot, opt_cam_trans, gt_color, gt_depth, depth_mask)
+
+ total_loss = (self.w_color_loss * color_loss + (1 - self.w_color_loss) * depth_loss)
+ total_loss.backward()
+ gaussian_model.optimizer.step()
+ gaussian_model.optimizer.zero_grad(set_to_none=True)
+
+ with torch.no_grad():
+ if total_loss.item() < current_min_loss:
+ current_min_loss = total_loss.item()
+ best_w2c = torch.eye(4)
+ best_w2c[:3, :3] = build_rotation(F.normalize(opt_cam_rot[None].clone().detach().cpu()))[0]
+ best_w2c[:3, 3] = opt_cam_trans.clone().detach().cpu()
+
+ cur_quat, cur_trans = F.normalize(opt_cam_rot[None].clone().detach()), opt_cam_trans.clone().detach()
+ cur_rel_w2c = torch.eye(4)
+ cur_rel_w2c[:3, :3] = build_rotation(cur_quat)[0]
+ cur_rel_w2c[:3, 3] = cur_trans
+ if iter == num_iters - 1:
+ cur_w2c = torch.from_numpy(reference_w2c) @ best_w2c
+ else:
+ cur_w2c = torch.from_numpy(reference_w2c) @ cur_rel_w2c
+ cur_c2w = torch.inverse(cur_w2c)
+ cur_cam = transformation_to_quaternion(cur_c2w)
+ if (gt_quat * cur_cam[:4]).sum() < 0: # q and -q encode the same rotation; flip the sign so the logged error is meaningful
+ gt_quat *= -1
+ if iter == num_iters - 1:
+ self.frame_color_loss.append(color_loss.item())
+ self.frame_depth_loss.append(depth_loss.item())
+ self.logger.log_tracking_iteration(
+ frame_id, cur_cam, gt_quat, gt_trans, total_loss, color_loss, depth_loss, iter, num_iters,
+ wandb_output=True, print_output=True)
+ elif iter % 20 == 0:
+ self.logger.log_tracking_iteration(
+ frame_id, cur_cam, gt_quat, gt_trans, total_loss, color_loss, depth_loss, iter, num_iters,
+ wandb_output=False, print_output=True)
+
+ final_c2w = torch.inverse(torch.from_numpy(reference_w2c) @ best_w2c)
+ final_c2w[-1, :] = torch.tensor([0., 0., 0., 1.], dtype=final_c2w.dtype, device=final_c2w.device)
+ return torch2np(final_c2w)
diff --git a/src/entities/visual_odometer.py b/src/entities/visual_odometer.py
new file mode 100644
index 0000000..ede9567
--- /dev/null
+++ b/src/entities/visual_odometer.py
@@ -0,0 +1,76 @@
+""" This module includes the Odometer class, which is allows for fast pose estimation from RGBD neighbor frames """
+import numpy as np
+import open3d as o3d
+import open3d.core as o3c
+
+
+class VisualOdometer(object):
+
+ def __init__(self, intrinsics: np.ndarray, method_name="hybrid", device="cuda"):
+ """ Initializes the visual odometry system with specified intrinsics, method, and device.
+ Args:
+ intrinsics: Camera intrinsic parameters.
+ method_name: The name of the odometry computation method to use ('hybrid' or 'point_to_plane').
+ device: The computation device ('cuda' or 'cpu').
+ """
+ device = "CUDA:0" if device == "cuda" else "CPU:0"
+ self.device = o3c.Device(device)
+ self.intrinsics = o3d.core.Tensor(intrinsics, o3d.core.Dtype.Float64)
+ self.last_abs_pose = None
+ self.last_frame = None
+ self.criteria_list = [
+ o3d.t.pipelines.odometry.OdometryConvergenceCriteria(500),
+ o3d.t.pipelines.odometry.OdometryConvergenceCriteria(500),
+ o3d.t.pipelines.odometry.OdometryConvergenceCriteria(500)]
+ self.setup_method(method_name)
+ self.max_depth = 10.0
+ self.depth_scale = 1.0
+ self.last_rgbd = None
+
+ def setup_method(self, method_name: str) -> None:
+ """ Sets up the odometry computation method based on the provided method name.
+ Args:
+ method_name: The name of the odometry method to use ('hybrid' or 'point_to_plane').
+ """
+ if method_name == "hybrid":
+ self.method = o3d.t.pipelines.odometry.Method.Hybrid
+ elif method_name == "point_to_plane":
+ self.method = o3d.t.pipelines.odometry.Method.PointToPlane
+ else:
+ raise ValueError("Odometry method does not exist!")
+
+ def update_last_rgbd(self, image: np.ndarray, depth: np.ndarray) -> None:
+ """ Updates the last RGB-D frame stored in the system with a new RGB-D frame constructed from provided image and depth.
+ Args:
+ image: The new RGB image as a numpy ndarray.
+ depth: The new depth image as a numpy ndarray.
+ """
+ self.last_rgbd = o3d.t.geometry.RGBDImage(
+ o3d.t.geometry.Image(np.ascontiguousarray(
+ image).astype(np.float32)).to(self.device),
+ o3d.t.geometry.Image(np.ascontiguousarray(depth).astype(np.float32)).to(self.device))
+
+ def estimate_rel_pose(self, image: np.ndarray, depth: np.ndarray, init_transform=np.eye(4)):
+ """ Estimates the relative pose of the current frame with respect to the last frame using RGB-D odometry.
+ Args:
+ image: The current RGB image as a numpy ndarray.
+ depth: The current depth image as a numpy ndarray.
+ init_transform: An initial transformation guess as a numpy ndarray. Defaults to the identity matrix.
+ Returns:
+ The relative transformation matrix as a numpy ndarray.
+ """
+ rgbd = o3d.t.geometry.RGBDImage(
+ o3d.t.geometry.Image(np.ascontiguousarray(image).astype(np.float32)).to(self.device),
+ o3d.t.geometry.Image(np.ascontiguousarray(depth).astype(np.float32)).to(self.device))
+ rel_transform = o3d.t.pipelines.odometry.rgbd_odometry_multi_scale(
+ self.last_rgbd, rgbd, self.intrinsics, o3c.Tensor(init_transform),
+ self.depth_scale, self.max_depth, self.criteria_list, self.method)
+ self.last_rgbd = rgbd.clone()
+
+ # Adjust for the coordinate system difference
+ rel_transform = rel_transform.transformation.cpu().numpy()
+ rel_transform[0, [1, 2, 3]] *= -1
+ rel_transform[1, [0, 2, 3]] *= -1
+ rel_transform[2, [0, 1, 3]] *= -1
+
+ return rel_transform
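+
+
+# Illustrative usage sketch (frame variables are assumed): given consecutive RGB-D frames,
+# the odometer returns a 4x4 relative pose between them.
+#
+#   odometer = VisualOdometer(intrinsics, method_name="hybrid", device="cuda")
+#   odometer.update_last_rgbd(prev_image, prev_depth)
+#   rel_pose = odometer.estimate_rel_pose(cur_image, cur_depth)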
diff --git a/src/evaluation/__init__.py b/src/evaluation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/evaluation/evaluate_merged_map.py b/src/evaluation/evaluate_merged_map.py
new file mode 100644
index 0000000..6db536f
--- /dev/null
+++ b/src/evaluation/evaluate_merged_map.py
@@ -0,0 +1,141 @@
+""" This module is responsible for merging submaps. """
+from argparse import ArgumentParser
+
+import faiss
+import numpy as np
+import open3d as o3d
+import torch
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+from src.entities.arguments import OptimizationParams
+from src.entities.gaussian_model import GaussianModel
+from src.entities.losses import isotropic_loss, l1_loss, ssim
+from src.utils.utils import (batch_search_faiss, get_render_settings,
+ np2ptcloud, render_gaussian_model, torch2np)
+
+
+class RenderFrames(Dataset):
+ """A dataset class for loading keyframes along with their estimated camera poses and render settings."""
+ def __init__(self, dataset, render_poses: np.ndarray, height: int, width: int, fx: float, fy: float):
+ self.dataset = dataset
+ self.render_poses = render_poses
+ self.height = height
+ self.width = width
+ self.fx = fx
+ self.fy = fy
+ self.device = "cuda"
+ self.stride = 1
+ if len(dataset) > 1000:
+ self.stride = len(dataset) // 1000
+
+ def __len__(self) -> int:
+ return len(self.dataset) // self.stride
+
+ def __getitem__(self, idx):
+ idx = idx * self.stride
+ color = (torch.from_numpy(
+ self.dataset[idx][1]) / 255.0).float().to(self.device)
+ depth = torch.from_numpy(self.dataset[idx][2]).float().to(self.device)
+ estimate_c2w = self.render_poses[idx]
+ estimate_w2c = np.linalg.inv(estimate_c2w)
+ frame = {
+ "frame_id": idx,
+ "color": color,
+ "depth": depth,
+ "render_settings": get_render_settings(
+ self.width, self.height, self.dataset.intrinsics, estimate_w2c)
+ }
+ return frame
+
+
+def merge_submaps(submaps_paths: list, radius: float = 0.0001, device: str = "cuda") -> o3d.geometry.PointCloud:
+ """ Merge submaps into a single point cloud, which is then used for global map refinement.
+ Args:
+ segments_paths (list): Folder path of the submaps.
+ radius (float, optional): Nearest neighbor distance threshold for adding a point. Defaults to 0.0001.
+ device (str, optional): Defaults to "cuda".
+
+ Returns:
+ o3d.geometry.PointCloud: merged point cloud
+ """
+ pts_index = faiss.IndexFlatL2(3)
+ if device == "cuda":
+ pts_index = faiss.index_cpu_to_gpu(
+ faiss.StandardGpuResources(),
+ 0,
+ faiss.IndexIVFFlat(faiss.IndexFlatL2(3), 3, 500, faiss.METRIC_L2))
+ pts_index.nprobe = 5
+ merged_pts = []
+ print("Merging segments")
+ for submap_path in tqdm(submaps_paths):
+ gaussian_params = torch.load(submap_path)["gaussian_params"]
+ current_pts = gaussian_params["xyz"].to(device).float()
+ pts_index.train(current_pts)
+ distances, _ = batch_search_faiss(pts_index, current_pts, 8)
+ neighbor_num = (distances < radius).sum(axis=1).int()
+ ids_to_include = torch.where(neighbor_num == 0)[0]
+ pts_index.add(current_pts[ids_to_include])
+ merged_pts.append(current_pts[ids_to_include])
+ pts = torch2np(torch.vstack(merged_pts))
+ pt_cloud = np2ptcloud(pts, np.zeros_like(pts))
+
+ # Downsampling if the total number of points is too large
+ if len(pt_cloud.points) > 1_000_000:
+ voxel_size = 0.02
+ pt_cloud = pt_cloud.voxel_down_sample(voxel_size)
+ print(f"Downsampled point cloud to {len(pt_cloud.points)} points")
+ filtered_pt_cloud, _ = pt_cloud.remove_statistical_outlier(nb_neighbors=40, std_ratio=3.0)
+ del pts_index
+ return filtered_pt_cloud
+
+
+def refine_global_map(pt_cloud: o3d.geometry.PointCloud, training_frames: list, max_iterations: int) -> GaussianModel:
+ """Refines a global map based on the merged point cloud and training keyframes frames.
+ Args:
+ pt_cloud (o3d.geometry.PointCloud): The merged point cloud used for refinement.
+ training_frames (list): A list of training frames for map refinement.
+ max_iterations (int): The maximum number of iterations to perform for refinement.
+ Returns:
+ GaussianModel: The refined global map as a Gaussian model.
+ """
+ opt_params = OptimizationParams(ArgumentParser(description="Training script parameters"))
+
+ gaussian_model = GaussianModel(3)
+ gaussian_model.active_sh_degree = 3
+ gaussian_model.training_setup(opt_params)
+ gaussian_model.add_points(pt_cloud)
+
+ iteration = 0
+ for iteration in tqdm(range(max_iterations), desc="Refinement"):
+ training_frame = next(training_frames)
+ gt_color, gt_depth, render_settings = (
+ training_frame["color"].squeeze(0),
+ training_frame["depth"].squeeze(0),
+ training_frame["render_settings"])
+
+ render_dict = render_gaussian_model(gaussian_model, render_settings)
+ rendered_color, rendered_depth = (render_dict["color"].permute(1, 2, 0), render_dict["depth"])
+
+ reg_loss = isotropic_loss(gaussian_model.get_scaling())
+ depth_mask = (gt_depth > 0)
+ color_loss = (1.0 - opt_params.lambda_dssim) * l1_loss(
+ rendered_color[depth_mask, :], gt_color[depth_mask, :]
+ ) + opt_params.lambda_dssim * (1.0 - ssim(rendered_color, gt_color))
+ depth_loss = l1_loss(
+ rendered_depth[:, depth_mask], gt_depth[depth_mask])
+
+ total_loss = color_loss + depth_loss + reg_loss
+ total_loss.backward()
+
+ with torch.no_grad():
+ if iteration % 500 == 0:
+ prune_mask = (gaussian_model.get_opacity() < 0.005).squeeze()
+ gaussian_model.prune_points(prune_mask)
+
+ # Optimizer step
+ gaussian_model.optimizer.step()
+ gaussian_model.optimizer.zero_grad(set_to_none=True)
+ iteration += 1
+
+ return gaussian_model
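+
+
+# Illustrative end-to-end sketch (names assumed; `dataset` is the sequence dataset, `poses`
+# the estimated (N, 4, 4) camera-to-world matrices, and the submap checkpoints are the
+# .ckpt files written by GaussianSLAM under <output_path>/submaps):
+#
+#   merged_cloud = merge_submaps(sorted(submaps_dir.glob("*.ckpt")))
+#   frames = RenderFrames(dataset, poses, dataset.height, dataset.width, fx, fy)
+#   refined_map = refine_global_map(merged_cloud, frame_iterator, 20000)
+#   # where `frame_iterator` yields items of `frames` endlessly (next() is called each iteration)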
diff --git a/src/evaluation/evaluate_reconstruction.py b/src/evaluation/evaluate_reconstruction.py
new file mode 100644
index 0000000..cac9b2b
--- /dev/null
+++ b/src/evaluation/evaluate_reconstruction.py
@@ -0,0 +1,289 @@
+import json
+import random
+from pathlib import Path
+
+import numpy as np
+import open3d as o3d
+import torch
+import trimesh
+from evaluate_3d_reconstruction import run_evaluation
+from tqdm import tqdm
+
+
+def normalize(x):
+ return x / np.linalg.norm(x)
+
+
+def get_align_transformation(rec_meshfile, gt_meshfile):
+ """
+ Get the transformation matrix to align the reconstructed mesh to the ground truth mesh.
+ """
+ o3d_rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile)
+ o3d_gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile)
+ o3d_rec_pc = o3d.geometry.PointCloud(points=o3d_rec_mesh.vertices)
+ o3d_gt_pc = o3d.geometry.PointCloud(points=o3d_gt_mesh.vertices)
+ trans_init = np.eye(4)
+ threshold = 0.1
+ reg_p2p = o3d.pipelines.registration.registration_icp(
+ o3d_rec_pc,
+ o3d_gt_pc,
+ threshold,
+ trans_init,
+ o3d.pipelines.registration.TransformationEstimationPointToPoint(),
+ )
+ transformation = reg_p2p.transformation
+ return transformation
+
+
+def check_proj(points, W, H, fx, fy, cx, cy, c2w):
+ """
+ Check if points can be projected into the camera view.
+
+ Returns:
+ bool: True if any of the points can be projected into the camera view
+
+ """
+ c2w = c2w.copy()
+ c2w[:3, 1] *= -1.0
+ c2w[:3, 2] *= -1.0
+ points = torch.from_numpy(points).cuda().clone()
+ w2c = np.linalg.inv(c2w)
+ w2c = torch.from_numpy(w2c).cuda().float()
+ K = torch.from_numpy(
+ np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]]).reshape(3, 3)
+ ).cuda()
+ ones = torch.ones_like(points[:, 0]).reshape(-1, 1).cuda()
+ homo_points = (
+ torch.cat([points, ones], dim=1).reshape(-1, 4, 1).cuda().float()
+ ) # (N, 4)
+ cam_cord_homo = w2c @ homo_points # (N, 4, 1)=(4,4)*(N, 4, 1)
+ cam_cord = cam_cord_homo[:, :3] # (N, 3, 1)
+ cam_cord[:, 0] *= -1
+ uv = K.float() @ cam_cord.float()
+ z = uv[:, -1:] + 1e-5
+ uv = uv[:, :2] / z
+ uv = uv.float().squeeze(-1).cpu().numpy()
+ edge = 0
+ mask = (
+ (0 <= -z[:, 0, 0].cpu().numpy())
+ & (uv[:, 0] < W - edge)
+ & (uv[:, 0] > edge)
+ & (uv[:, 1] < H - edge)
+ & (uv[:, 1] > edge)
+ )
+ return mask.sum() > 0
+
+
+def get_cam_position(gt_meshfile):
+ mesh_gt = trimesh.load(gt_meshfile)
+ to_origin, extents = trimesh.bounds.oriented_bounds(mesh_gt)
+ extents[2] *= 0.7
+ extents[1] *= 0.7
+ extents[0] *= 0.3
+ transform = np.linalg.inv(to_origin)
+ transform[2, 3] += 0.4
+ return extents, transform
+
+
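+# viewmatrix() assembles a 3x4 camera-to-world matrix from a viewing direction z, an up
+# hint and a camera position pos (a standard look-at construction used for sampling views).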
+def viewmatrix(z, up, pos):
+ vec2 = normalize(z)
+ vec1_avg = up
+ vec0 = normalize(np.cross(vec1_avg, vec2))
+ vec1 = normalize(np.cross(vec2, vec0))
+ m = np.stack([vec0, vec1, vec2, pos], 1)
+ return m
+
+
+def calc_2d_metric(
+ rec_meshfile, gt_meshfile, unseen_gt_pointcloud_file, align=True, n_imgs=1000
+):
+ """
+ 2D reconstruction metric, depth L1 loss.
+
+ """
+ H = 500
+ W = 500
+ focal = 300
+ fx = focal
+ fy = focal
+ cx = H / 2.0 - 0.5
+ cy = W / 2.0 - 0.5
+
+ gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile)
+ rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile)
+ pc_unseen = np.load(unseen_gt_pointcloud_file)
+ if align:
+ transformation = get_align_transformation(rec_meshfile, gt_meshfile)
+ rec_mesh = rec_mesh.transform(transformation)
+
+ # get vacant area inside the room
+ extents, transform = get_cam_position(gt_meshfile)
+
+ vis = o3d.visualization.Visualizer()
+ vis.create_window(width=W, height=H, visible=False)
+ vis.get_render_option().mesh_show_back_face = True
+ errors = []
+ for i in tqdm(range(n_imgs)):
+ while True:
+ # sample view, and check if unseen region is not inside the camera view
+ # if inside, then needs to resample
+ up = [0, 0, -1]
+ origin = trimesh.sample.volume_rectangular(
+ extents, 1, transform=transform)
+ origin = origin.reshape(-1)
+ tx = round(random.uniform(-10000, +10000), 2)
+ ty = round(random.uniform(-10000, +10000), 2)
+ tz = round(random.uniform(-10000, +10000), 2)
+ # the look-at direction is normalized below, so the sampling range only affects the direction
+ target = [tx, ty, tz]
+ target = np.array(target) - np.array(origin)
+ c2w = viewmatrix(target, up, origin)
+ tmp = np.eye(4)
+ tmp[:3, :] = c2w # sample translations
+ c2w = tmp
+ # if unseen points are projected into current view (c2w)
+ seen = check_proj(pc_unseen, W, H, fx, fy, cx, cy, c2w)
+ if not seen:
+ break
+
+ param = o3d.camera.PinholeCameraParameters()
+ param.extrinsic = np.linalg.inv(c2w) # 4x4 numpy array
+
+ param.intrinsic = o3d.camera.PinholeCameraIntrinsic(
+ W, H, fx, fy, cx, cy)
+
+ ctr = vis.get_view_control()
+ ctr.set_constant_z_far(20)
+ ctr.convert_from_pinhole_camera_parameters(param)
+
+ vis.add_geometry(
+ gt_mesh,
+ reset_bounding_box=True,
+ )
+ ctr.convert_from_pinhole_camera_parameters(param)
+ vis.poll_events()
+ vis.update_renderer()
+ gt_depth = vis.capture_depth_float_buffer(True)
+ gt_depth = np.asarray(gt_depth)
+ vis.remove_geometry(
+ gt_mesh,
+ reset_bounding_box=True,
+ )
+
+ vis.add_geometry(
+ rec_mesh,
+ reset_bounding_box=True,
+ )
+ ctr.convert_from_pinhole_camera_parameters(param)
+ vis.poll_events()
+ vis.update_renderer()
+ ours_depth = vis.capture_depth_float_buffer(True)
+ ours_depth = np.asarray(ours_depth)
+ vis.remove_geometry(
+ rec_mesh,
+ reset_bounding_box=True,
+ )
+
+ # filter missing surfaces where depth is 0
+ if (ours_depth > 0).sum() > 0:
+ errors += [
+ np.abs(gt_depth[ours_depth > 0] -
+ ours_depth[ours_depth > 0]).mean()
+ ]
+ else:
+ continue
+
+ errors = np.array(errors)
+ return {"depth l1": errors.mean() * 100}
+
+
+def clean_mesh(mesh):
+ mesh_tri = trimesh.Trimesh(
+ vertices=np.asarray(mesh.vertices),
+ faces=np.asarray(mesh.triangles),
+ vertex_colors=np.asarray(mesh.vertex_colors),
+ )
+ components = trimesh.graph.connected_components(
+ edges=mesh_tri.edges_sorted)
+
+ min_len = 200
+ components_to_keep = [c for c in components if len(c) >= min_len]
+
+ new_vertices = []
+ new_faces = []
+ new_colors = []
+ vertex_count = 0
+ for component in components_to_keep:
+ vertices = mesh_tri.vertices[component]
+ colors = mesh_tri.visual.vertex_colors[component]
+
+ # Create a mapping from old vertex indices to new vertex indices
+ index_mapping = {
+ old_idx: vertex_count + new_idx for new_idx, old_idx in enumerate(component)
+ }
+ vertex_count += len(vertices)
+
+ # Select faces that are part of the current connected component and update vertex indices
+ faces_in_component = mesh_tri.faces[
+ np.any(np.isin(mesh_tri.faces, component), axis=1)
+ ]
+ reindexed_faces = np.vectorize(index_mapping.get)(faces_in_component)
+
+ new_vertices.extend(vertices)
+ new_faces.extend(reindexed_faces)
+ new_colors.extend(colors)
+
+ cleaned_mesh_tri = trimesh.Trimesh(vertices=new_vertices, faces=new_faces)
+ cleaned_mesh_tri.visual.vertex_colors = np.array(new_colors)
+
+ cleaned_mesh_tri.update_faces(cleaned_mesh_tri.nondegenerate_faces())
+ cleaned_mesh_tri.update_faces(cleaned_mesh_tri.unique_faces())
+ print(
+ f"Mesh cleaning (before/after), vertices: {len(mesh_tri.vertices)}/{len(cleaned_mesh_tri.vertices)}, faces: {len(mesh_tri.faces)}/{len(cleaned_mesh_tri.faces)}")
+
+ cleaned_mesh = o3d.geometry.TriangleMesh(
+ o3d.utility.Vector3dVector(cleaned_mesh_tri.vertices),
+ o3d.utility.Vector3iVector(cleaned_mesh_tri.faces),
+ )
+ vertex_colors = np.asarray(cleaned_mesh_tri.visual.vertex_colors)[
+ :, :3] / 255.0
+ cleaned_mesh.vertex_colors = o3d.utility.Vector3dVector(
+ vertex_colors.astype(np.float64)
+ )
+
+ return cleaned_mesh
+
+
+def evaluate_reconstruction(
+ mesh_path: Path,
+ gt_mesh_path: Path,
+ unseen_pc_path: Path,
+ output_path: Path,
+ to_clean=True,
+):
+ if to_clean:
+ mesh = o3d.io.read_triangle_mesh(str(mesh_path))
+ print(mesh)
+ cleaned_mesh = clean_mesh(mesh)
+ cleaned_mesh_path = output_path / "mesh" / "cleaned_mesh.ply"
+ o3d.io.write_triangle_mesh(str(cleaned_mesh_path), cleaned_mesh)
+ mesh_path = cleaned_mesh_path
+
+ result_3d = run_evaluation(
+ str(mesh_path.parts[-1]),
+ str(mesh_path.parent),
+ str(gt_mesh_path).split("/")[-1].split(".")[0],
+ distance_thresh=0.01,
+ full_path_to_gt_ply=gt_mesh_path,
+ icp_align=True,
+ )
+
+ try:
+ result_2d = calc_2d_metric(str(mesh_path), str(gt_mesh_path), str(unseen_pc_path), align=True, n_imgs=1000)
+ except Exception as e:
+ print(e)
+ result_2d = {"depth l1": None}
+
+ result = {**result_3d, **result_2d}
+ with open(str(output_path / "reconstruction_metrics.json"), "w") as f:
+ json.dump(result, f)
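+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only): the paths below are placeholders
+    # following the Replica-SLAM layout used by the evaluator, not a fixed CLI.
+    out = Path("output/Replica/room0")
+    evaluate_reconstruction(
+        mesh_path=out / "mesh" / "final_mesh.ply",
+        gt_mesh_path=Path("data/Replica-SLAM/cull_replica/room0.ply"),
+        unseen_pc_path=Path("data/Replica-SLAM/cull_replica/room0_pc_unseen.npy"),
+        output_path=out,
+    )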
diff --git a/src/evaluation/evaluate_trajectory.py b/src/evaluation/evaluate_trajectory.py
new file mode 100644
index 0000000..f557fbc
--- /dev/null
+++ b/src/evaluation/evaluate_trajectory.py
@@ -0,0 +1,130 @@
+import json
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+class NumpyFloatValuesEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, np.float32):
+ return float(obj)
+ return json.JSONEncoder.default(self, obj)
+
+
+def align(model, data):
+ """Align two trajectories using the method of Horn (closed-form).
+
+ Input:
+ model -- first trajectory (3xn)
+ data -- second trajectory (3xn)
+
+ Output:
+ rot -- rotation matrix (3x3)
+ trans -- translation vector (3x1)
+ trans_error -- translational error per point (1xn)
+
+ """
+ np.set_printoptions(precision=3, suppress=True)
+ model_zerocentered = model - model.mean(1)
+ data_zerocentered = data - data.mean(1)
+
+ W = np.zeros((3, 3))
+ for column in range(model.shape[1]):
+ W += np.outer(model_zerocentered[:,
+ column], data_zerocentered[:, column])
+ U, d, Vh = np.linalg.svd(W.transpose())
+ S = np.matrix(np.identity(3))
+ if (np.linalg.det(U) * np.linalg.det(Vh) < 0):
+ S[2, 2] = -1
+ rot = U * S * Vh
+ trans = data.mean(1) - rot * model.mean(1)
+
+ model_aligned = rot * model + trans
+ alignment_error = model_aligned - data
+
+ trans_error = np.sqrt(
+ np.sum(np.multiply(alignment_error, alignment_error), 0)).A[0]
+
+ return rot, trans, trans_error
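+
+
+# Illustrative sanity check (not part of the evaluation pipeline): for a purely
+# translated copy of the same trajectory, the Horn alignment above should recover
+# an identity rotation and leave (numerically) zero residual error.
+def _example_horn_alignment():
+    rng = np.random.default_rng(0)
+    model = np.matrix(rng.random((3, 10)))            # 3xn "estimated" trajectory
+    data = model + np.matrix([[1.0], [2.0], [3.0]])   # same trajectory, shifted
+    rot, trans, trans_error = align(model, data)
+    assert np.allclose(rot, np.eye(3), atol=1e-6)
+    assert np.allclose(trans_error, 0.0, atol=1e-6)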
+
+
+def align_trajectories(t_pred: np.ndarray, t_gt: np.ndarray):
+ """
+ Args:
+ t_pred: (n, 3) translations
+ t_gt: (n, 3) translations
+ Returns:
+ t_align: (n, 3) aligned translations
+ """
+ t_align = np.matrix(t_pred).transpose()
+ R, t, _ = align(t_align, np.matrix(t_gt).transpose())
+ t_align = R * t_align + t
+ t_align = np.asarray(t_align).T
+ return t_align
+
+
+def pose_error(t_pred: np.ndarray, t_gt: np.ndarray):
+ """
+ Args:
+ t_pred: (n, 3) translations
+ t_gt: (n, 3) translations
+ Returns:
+ dict: error dict
+ """
+ n = t_pred.shape[0]
+ trans_error = np.linalg.norm(t_pred - t_gt, axis=1)
+ return {
+ "compared_pose_pairs": n,
+ "rmse": np.sqrt(np.dot(trans_error, trans_error) / n),
+ "mean": np.mean(trans_error),
+ "median": np.median(trans_error),
+ "std": np.std(trans_error),
+ "min": np.min(trans_error),
+ "max": np.max(trans_error)
+ }
+
+
+def plot_2d(pts, ax=None, color="green", label="None", title="3D Trajectory in 2D"):
+ if ax is None:
+ _, ax = plt.subplots()
+ ax.scatter(pts[:, 0], pts[:, 1], color=color, label=label, s=0.7)
+ ax.set_xlabel('X')
+ ax.set_ylabel('Y')
+ ax.set_title(title)
+ return ax
+
+
+def evaluate_trajectory(estimated_poses: np.ndarray, gt_poses: np.ndarray, output_path: Path):
+ output_path.mkdir(exist_ok=True, parents=True)
+ # Truncate the ground truth trajectory if needed
+ if gt_poses.shape[0] > estimated_poses.shape[0]:
+ gt_poses = gt_poses[:estimated_poses.shape[0]]
+ valid = ~np.any(np.isnan(gt_poses) |
+ np.isinf(gt_poses), axis=(1, 2))
+ gt_poses = gt_poses[valid]
+ estimated_poses = estimated_poses[valid]
+
+ gt_t = gt_poses[:, :3, 3]
+ estimated_t = estimated_poses[:, :3, 3]
+ estimated_t_aligned = align_trajectories(estimated_t, gt_t)
+ ate = pose_error(estimated_t, gt_t)
+ ate_aligned = pose_error(estimated_t_aligned, gt_t)
+
+ with open(str(output_path / "ate.json"), "w") as f:
+ f.write(json.dumps(ate, cls=NumpyFloatValuesEncoder))
+
+ with open(str(output_path / "ate_aligned.json"), "w") as f:
+ f.write(json.dumps(ate_aligned, cls=NumpyFloatValuesEncoder))
+
+ ate_rmse, ate_rmse_aligned = ate["rmse"], ate_aligned["rmse"]
+ ax = plot_2d(
+ estimated_t, label=f"ate-rmse: {round(ate_rmse * 100, 2)} cm", color="orange")
+ ax = plot_2d(estimated_t_aligned, ax,
+ label=f"ate-rsme (aligned): {round(ate_rmse_aligned * 100, 2)} cm", color="lightskyblue")
+ ax = plot_2d(gt_t, ax, label="GT", color="green")
+ ax.legend()
+ plt.savefig(str(output_path / "eval_trajectory.png"), dpi=300)
+ plt.close()
+ print(
+ f"ATE-RMSE: {ate_rmse * 100:.2f} cm, ATE-RMSE (aligned): {ate_rmse_aligned * 100:.2f} cm")
diff --git a/src/evaluation/evaluator.py b/src/evaluation/evaluator.py
new file mode 100644
index 0000000..7170a95
--- /dev/null
+++ b/src/evaluation/evaluator.py
@@ -0,0 +1,276 @@
+""" This module is responsible for evaluating rendering, trajectory and reconstruction metrics"""
+import traceback
+from argparse import ArgumentParser
+from copy import deepcopy
+from itertools import cycle
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import open3d as o3d
+import torch
+import torchvision
+from pytorch_msssim import ms_ssim
+from scipy.ndimage import median_filter
+from torch.utils.data import DataLoader
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+from torchvision.utils import save_image
+from tqdm import tqdm
+
+from src.entities.arguments import OptimizationParams
+from src.entities.datasets import get_dataset
+from src.entities.gaussian_model import GaussianModel
+from src.evaluation.evaluate_merged_map import (RenderFrames, merge_submaps,
+ refine_global_map)
+from src.evaluation.evaluate_reconstruction import evaluate_reconstruction
+from src.evaluation.evaluate_trajectory import evaluate_trajectory
+from src.utils.io_utils import load_config, save_dict_to_json
+from src.utils.mapper_utils import calc_psnr
+from src.utils.utils import (get_render_settings, np2torch,
+ render_gaussian_model, setup_seed, torch2np)
+
+
+def filter_depth_outliers(depth_map, kernel_size=3, threshold=1.0):
+ median_filtered = median_filter(depth_map, size=kernel_size)
+ abs_diff = np.abs(depth_map - median_filtered)
+ outlier_mask = abs_diff > threshold
+ depth_map_filtered = np.where(outlier_mask, median_filtered, depth_map)
+ return depth_map_filtered
+
+
+class Evaluator(object):
+
+ def __init__(self, checkpoint_path, config_path, config=None, save_render=False) -> None:
+ if config is None:
+ self.config = load_config(config_path)
+ else:
+ self.config = config
+ setup_seed(self.config["seed"])
+
+ self.checkpoint_path = Path(checkpoint_path)
+ self.device = "cuda"
+ self.dataset = get_dataset(self.config["dataset_name"])({**self.config["data"], **self.config["cam"]})
+ self.scene_name = self.config["data"]["scene_name"]
+ self.dataset_name = self.config["dataset_name"]
+ self.gt_poses = np.array(self.dataset.poses)
+ self.fx, self.fy = self.dataset.intrinsics[0, 0], self.dataset.intrinsics[1, 1]
+ self.cx, self.cy = self.dataset.intrinsics[0,
+ 2], self.dataset.intrinsics[1, 2]
+ self.width, self.height = self.dataset.width, self.dataset.height
+ self.save_render = save_render
+ if self.save_render:
+ self.render_path = self.checkpoint_path / "rendered_imgs"
+ self.render_path.mkdir(exist_ok=True, parents=True)
+
+ self.estimated_c2w = torch2np(torch.load(self.checkpoint_path / "estimated_c2w.ckpt", map_location=self.device))
+ self.submaps_paths = sorted(list((self.checkpoint_path / "submaps").glob('*')))
+
+ def run_trajectory_eval(self):
+ """ Evaluates the estimated trajectory """
+ print("Running trajectory evaluation...")
+ evaluate_trajectory(self.estimated_c2w, self.gt_poses, self.checkpoint_path)
+
+ def run_rendering_eval(self):
+ """ Renderes the submaps and evaluates the PSNR, LPIPS, SSIM and depth L1 metrics."""
+ print("Running rendering evaluation...")
+ psnr, lpips, ssim, depth_l1 = [], [], [], []
+ color_transform = torchvision.transforms.ToTensor()
+ lpips_model = LearnedPerceptualImagePatchSimilarity(
+ net_type='alex', normalize=True).to(self.device)
+ opt_settings = OptimizationParams(ArgumentParser(
+ description="Training script parameters"))
+
+ submaps_paths = sorted(
+ list((self.checkpoint_path / "submaps").glob('*.ckpt')))
+ for submap_path in tqdm(submaps_paths):
+ submap = torch.load(submap_path, map_location=self.device)
+ gaussian_model = GaussianModel()
+ gaussian_model.training_setup(opt_settings)
+ gaussian_model.restore_from_params(
+ submap["gaussian_params"], opt_settings)
+
+ for keyframe_id in submap["submap_keyframes"]:
+
+ _, gt_color, gt_depth, _ = self.dataset[keyframe_id]
+ gt_color = color_transform(gt_color).to(self.device)
+ gt_depth = np2torch(gt_depth).to(self.device)
+
+ estimate_c2w = self.estimated_c2w[keyframe_id]
+ estimate_w2c = np.linalg.inv(estimate_c2w)
+ render_dict = render_gaussian_model(
+ gaussian_model, get_render_settings(self.width, self.height, self.dataset.intrinsics, estimate_w2c))
+ rendered_color, rendered_depth = render_dict["color"].detach(
+ ), render_dict["depth"][0].detach()
+ rendered_color = torch.clamp(rendered_color, min=0.0, max=1.0)
+ if self.save_render:
+ torchvision.utils.save_image(
+ rendered_color, self.render_path / f"{keyframe_id:05d}.png")
+
+ mse_loss = torch.nn.functional.mse_loss(
+ rendered_color, gt_color)
+ psnr_value = (-10. * torch.log10(mse_loss)).item()
+ lpips_value = lpips_model(
+ rendered_color[None], gt_color[None]).item()
+ ssim_value = ms_ssim(
+ rendered_color[None], gt_color[None], data_range=1.0, size_average=True).item()
+ depth_l1_value = torch.abs(
+ (rendered_depth - gt_depth)).mean().item()
+
+ psnr.append(psnr_value)
+ lpips.append(lpips_value)
+ ssim.append(ssim_value)
+ depth_l1.append(depth_l1_value)
+
+ num_frames = len(psnr)
+ metrics = {
+ "psnr": sum(psnr) / num_frames,
+ "lpips": sum(lpips) / num_frames,
+ "ssim": sum(ssim) / num_frames,
+ "depth_l1_train_view": sum(depth_l1) / num_frames,
+ "num_renders": num_frames
+ }
+ save_dict_to_json(metrics, "rendering_metrics.json",
+ directory=self.checkpoint_path)
+
+ x = list(range(len(psnr)))
+ fig, axs = plt.subplots(1, 3, figsize=(12, 4))
+ axs[0].plot(x, psnr, label="PSNR")
+ axs[0].legend()
+ axs[0].set_title("PSNR")
+ axs[1].plot(x, ssim, label="SSIM")
+ axs[1].legend()
+ axs[1].set_title("SSIM")
+ axs[2].plot(x, depth_l1, label="Depth L1 (Train view)")
+ axs[2].legend()
+ axs[2].set_title("Depth L1 Render")
+ plt.tight_layout()
+ plt.savefig(str(self.checkpoint_path /
+ "rendering_metrics.png"), dpi=300)
+ print(metrics)
+
+ def run_reconstruction_eval(self):
+ """ Reconstructs the mesh, evaluates it, render novel view depth maps from it, and evaluates them as well """
+ print("Running reconstruction evaluation...")
+ if self.config["dataset_name"] != "replica":
+ print("dataset is not supported, skipping reconstruction eval")
+ return
+ (self.checkpoint_path / "mesh").mkdir(exist_ok=True, parents=True)
+ opt_settings = OptimizationParams(ArgumentParser(
+ description="Training script parameters"))
+ intrinsic = o3d.camera.PinholeCameraIntrinsic(
+ self.width, self.height, self.fx, self.fy, self.cx, self.cy)
+ scale = 1.0
+ volume = o3d.pipelines.integration.ScalableTSDFVolume(
+ voxel_length=5.0 * scale / 512.0,
+ sdf_trunc=0.04 * scale,
+ color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8)
+
+ submaps_paths = sorted(list((self.checkpoint_path / "submaps").glob('*.ckpt')))
+ for submap_path in tqdm(submaps_paths):
+ submap = torch.load(submap_path, map_location=self.device)
+ gaussian_model = GaussianModel()
+ gaussian_model.training_setup(opt_settings)
+ gaussian_model.restore_from_params(
+ submap["gaussian_params"], opt_settings)
+
+ for keyframe_id in submap["submap_keyframes"]:
+ estimate_c2w = self.estimated_c2w[keyframe_id]
+ estimate_w2c = np.linalg.inv(estimate_c2w)
+ render_dict = render_gaussian_model(
+ gaussian_model, get_render_settings(self.width, self.height, self.dataset.intrinsics, estimate_w2c))
+ rendered_color, rendered_depth = render_dict["color"].detach(
+ ), render_dict["depth"][0].detach()
+ rendered_color = torch.clamp(rendered_color, min=0.0, max=1.0)
+
+ rendered_color = (
+ torch2np(rendered_color.permute(1, 2, 0)) * 255).astype(np.uint8)
+ rendered_depth = torch2np(rendered_depth)
+ rendered_depth = filter_depth_outliers(
+ rendered_depth, kernel_size=20, threshold=0.1)
+
+ rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
+ o3d.geometry.Image(np.ascontiguousarray(rendered_color)),
+ o3d.geometry.Image(rendered_depth),
+ depth_scale=scale,
+ depth_trunc=30,
+ convert_rgb_to_intensity=False)
+ volume.integrate(rgbd, intrinsic, estimate_w2c)
+
+ o3d_mesh = volume.extract_triangle_mesh()
+ compensate_vector = (-0.0 * scale / 512.0, 2.5 *
+ scale / 512.0, -2.5 * scale / 512.0)
+ o3d_mesh = o3d_mesh.translate(compensate_vector)
+ file_name = self.checkpoint_path / "mesh" / "final_mesh.ply"
+ o3d.io.write_triangle_mesh(str(file_name), o3d_mesh)
+ evaluate_reconstruction(file_name,
+ f"data/Replica-SLAM/cull_replica/{self.scene_name}.ply",
+ f"data/Replica-SLAM/cull_replica/{self.scene_name}_pc_unseen.npy",
+ self.checkpoint_path)
+
+ def run_global_map_eval(self):
+ """ Merges the map, evaluates it over training and novel views """
+ print("Running global map evaluation...")
+
+ training_frames = RenderFrames(self.dataset, self.estimated_c2w, self.height, self.width, self.fx, self.fy)
+ training_frames = DataLoader(training_frames, batch_size=1, shuffle=True)
+ training_frames = cycle(training_frames)
+ merged_cloud = merge_submaps(self.submaps_paths)
+ refined_merged_gaussian_model = refine_global_map(merged_cloud, training_frames, 10000)
+ ply_path = self.checkpoint_path / f"{self.config['data']['scene_name']}_global_map.ply"
+ refined_merged_gaussian_model.save_ply(ply_path)
+ print(f"Refined global map saved to {ply_path}")
+
+ if self.config["dataset_name"] != "scannetpp":
+ print("NVS evaluation is only supported for ScanNet++, skipping")
+ return
+
+ eval_config = deepcopy(self.config)
+ print(f"✨ Eval NVS for scene {self.config['data']['scene_name']}...")
+ (self.checkpoint_path / "nvs_eval").mkdir(exist_ok=True, parents=True)
+ eval_config["data"]["use_train_split"] = False
+ test_set = get_dataset(eval_config["dataset_name"])({**eval_config["data"], **eval_config["cam"]})
+ test_poses = torch.stack([torch.from_numpy(test_set[i][3]) for i in range(len(test_set))], dim=0)
+ test_frames = RenderFrames(test_set, test_poses, self.height, self.width, self.fx, self.fy)
+
+ psnr_list = []
+ for i in tqdm(range(len(test_set))):
+ gt_color, _, render_settings = (
+ test_frames[i]["color"],
+ test_frames[i]["depth"],
+ test_frames[i]["render_settings"])
+ render_dict = render_gaussian_model(refined_merged_gaussian_model, render_settings)
+ rendered_color, _ = (render_dict["color"].permute(1, 2, 0), render_dict["depth"],)
+ rendered_color = torch.clip(rendered_color, 0, 1)
+ save_image(rendered_color.permute(2, 0, 1), self.checkpoint_path / f"nvs_eval/{i:04d}.jpg")
+ psnr = calc_psnr(gt_color, rendered_color).mean()
+ psnr_list.append(psnr.item())
+ print(f"PSNR List: {psnr_list}")
+ print(f"Avg. NVS PSNR: {np.array(psnr_list).mean()}")
+
+ def run(self):
+ """ Runs the general evaluation flow """
+
+ print("Starting evaluation...🍺")
+
+ try:
+ self.run_trajectory_eval()
+ except Exception:
+ print("Could not run trajectory eval")
+ traceback.print_exc()
+
+ try:
+ self.run_rendering_eval()
+ except Exception:
+ print("Could not run rendering eval")
+ traceback.print_exc()
+
+ try:
+ self.run_reconstruction_eval()
+ except Exception:
+ print("Could not run reconstruction eval")
+ traceback.print_exc()
+
+ try:
+ self.run_global_map_eval()
+ except Exception:
+ print("Could not run global map eval")
+ traceback.print_exc()
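+
+
+if __name__ == "__main__":
+    # Minimal command-line sketch (illustrative only): point the evaluator at a
+    # finished run's checkpoint directory and the config it was produced with.
+    parser = ArgumentParser(description="Evaluate a Gaussian-SLAM run")
+    parser.add_argument("checkpoint_path", type=str)
+    parser.add_argument("config_path", type=str)
+    parser.add_argument("--save_render", action="store_true")
+    args = parser.parse_args()
+    Evaluator(args.checkpoint_path, args.config_path, save_render=args.save_render).run()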
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/utils/gaussian_model_utils.py b/src/utils/gaussian_model_utils.py
new file mode 100644
index 0000000..e05ba07
--- /dev/null
+++ b/src/utils/gaussian_model_utils.py
@@ -0,0 +1,212 @@
+# Copyright 2021 The PlenOctree Authors.
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+C0 = 0.28209479177387814
+C1 = 0.4886025119029199
+C2 = [
+ 1.0925484305920792,
+ -1.0925484305920792,
+ 0.31539156525252005,
+ -1.0925484305920792,
+ 0.5462742152960396
+]
+C3 = [
+ -0.5900435899266435,
+ 2.890611442640554,
+ -0.4570457994644658,
+ 0.3731763325901154,
+ -0.4570457994644658,
+ 1.445305721320277,
+ -0.5900435899266435
+]
+C4 = [
+ 2.5033429417967046,
+ -1.7701307697799304,
+ 0.9461746957575601,
+ -0.6690465435572892,
+ 0.10578554691520431,
+ -0.6690465435572892,
+ 0.47308734787878004,
+ -1.7701307697799304,
+ 0.6258357354491761,
+]
+
+
+def eval_sh(deg, sh, dirs):
+ """
+ Evaluate spherical harmonics at unit directions
+ using hardcoded SH polynomials.
+ Works with torch/np/jnp.
+ ... Can be 0 or more batch dimensions.
+ Args:
+ deg: int SH deg. Currently, 0-4 supported
+ sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
+ dirs: jnp.ndarray unit directions [..., 3]
+ Returns:
+ [..., C]
+ """
+ assert deg <= 4 and deg >= 0
+ coeff = (deg + 1) ** 2
+ assert sh.shape[-1] >= coeff
+
+ result = C0 * sh[..., 0]
+ if deg > 0:
+ x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
+ result = (result -
+ C1 * y * sh[..., 1] +
+ C1 * z * sh[..., 2] -
+ C1 * x * sh[..., 3])
+
+ if deg > 1:
+ xx, yy, zz = x * x, y * y, z * z
+ xy, yz, xz = x * y, y * z, x * z
+ result = (result +
+ C2[0] * xy * sh[..., 4] +
+ C2[1] * yz * sh[..., 5] +
+ C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
+ C2[3] * xz * sh[..., 7] +
+ C2[4] * (xx - yy) * sh[..., 8])
+
+ if deg > 2:
+ result = (result +
+ C3[0] * y * (3 * xx - yy) * sh[..., 9] +
+ C3[1] * xy * z * sh[..., 10] +
+ C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
+ C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
+ C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
+ C3[5] * z * (xx - yy) * sh[..., 14] +
+ C3[6] * x * (xx - 3 * yy) * sh[..., 15])
+
+ if deg > 3:
+ result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
+ C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
+ C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
+ C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
+ C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
+ C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
+ C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
+ C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
+ C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
+ return result
+
+
+def RGB2SH(rgb):
+ return (rgb - 0.5) / C0
+
+
+def SH2RGB(sh):
+ return sh * C0 + 0.5
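+
+
+# Illustrative check (not used elsewhere): RGB2SH and SH2RGB are inverses, and a
+# degree-0 eval_sh plus the 0.5 offset (cf. SH2RGB) reproduces the encoded color
+# regardless of the viewing direction.
+def _example_sh_roundtrip():
+    rgb = torch.tensor([[0.2, 0.5, 0.8]])             # (1, 3) color in [0, 1]
+    assert torch.allclose(SH2RGB(RGB2SH(rgb)), rgb)
+    sh = RGB2SH(rgb)[..., None]                        # (1, 3, 1) degree-0 coefficients
+    dirs = torch.tensor([[0.0, 0.0, 1.0]])             # arbitrary unit direction
+    assert torch.allclose(eval_sh(0, sh, dirs) + 0.5, rgb)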
+
+
+def inverse_sigmoid(x):
+ return torch.log(x/(1-x))
+
+
+def get_expon_lr_func(
+ lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
+):
+ """
+ Copied from Plenoxels
+
+ Continuous learning rate decay function. Adapted from JaxNeRF
+ The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
+ is log-linearly interpolated elsewhere (equivalent to exponential decay).
+ If lr_delay_steps>0 then the learning rate will be scaled by some smooth
+ function of lr_delay_mult, such that the initial learning rate is
+ lr_init*lr_delay_mult at the beginning of optimization but will be eased back
+ to the normal learning rate when steps>lr_delay_steps.
+ :param conf: config subtree 'lr' or similar
+ :param max_steps: int, the number of steps during optimization.
+ :return HoF which takes step as input
+ """
+
+ def helper(step):
+ if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
+ # Disable this parameter
+ return 0.0
+ if lr_delay_steps > 0:
+ # A kind of reverse cosine decay.
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
+ 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
+ )
+ else:
+ delay_rate = 1.0
+ t = np.clip(step / max_steps, 0, 1)
+ log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
+ return delay_rate * log_lerp
+
+ return helper
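+
+
+# Illustrative only: the schedule above is log-linear, so halfway through
+# optimization the learning rate equals the geometric mean of lr_init and lr_final.
+def _example_expon_lr_schedule():
+    lr_fn = get_expon_lr_func(lr_init=1e-2, lr_final=1e-4, max_steps=1000)
+    assert np.isclose(lr_fn(0), 1e-2)
+    assert np.isclose(lr_fn(1000), 1e-4)
+    assert np.isclose(lr_fn(500), np.sqrt(1e-2 * 1e-4))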
+
+
+def strip_lowerdiag(L):
+ uncertainty = torch.zeros(
+ (L.shape[0], 6), dtype=torch.float, device="cuda")
+
+ uncertainty[:, 0] = L[:, 0, 0]
+ uncertainty[:, 1] = L[:, 0, 1]
+ uncertainty[:, 2] = L[:, 0, 2]
+ uncertainty[:, 3] = L[:, 1, 1]
+ uncertainty[:, 4] = L[:, 1, 2]
+ uncertainty[:, 5] = L[:, 2, 2]
+ return uncertainty
+
+
+def strip_symmetric(sym):
+ return strip_lowerdiag(sym)
+
+
+def build_rotation(r):
+
+ q = F.normalize(r, p=2, dim=1)
+ R = torch.zeros((q.size(0), 3, 3), device='cuda')
+
+ r = q[:, 0]
+ x = q[:, 1]
+ y = q[:, 2]
+ z = q[:, 3]
+
+ R[:, 0, 0] = 1 - 2 * (y*y + z*z)
+ R[:, 0, 1] = 2 * (x*y - r*z)
+ R[:, 0, 2] = 2 * (x*z + r*y)
+ R[:, 1, 0] = 2 * (x*y + r*z)
+ R[:, 1, 1] = 1 - 2 * (x*x + z*z)
+ R[:, 1, 2] = 2 * (y*z - r*x)
+ R[:, 2, 0] = 2 * (x*z - r*y)
+ R[:, 2, 1] = 2 * (y*z + r*x)
+ R[:, 2, 2] = 1 - 2 * (x*x + y*y)
+ return R
+
+
+def build_scaling_rotation(s, r):
+ L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
+ R = build_rotation(r)
+
+ L[:, 0, 0] = s[:, 0]
+ L[:, 1, 1] = s[:, 1]
+ L[:, 2, 2] = s[:, 2]
+
+ L = R @ L
+ return L
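+
+
+# Small illustrative check (assumes a CUDA device, like the helpers above): the
+# identity quaternion (w, x, y, z) = (1, 0, 0, 0) gives the identity rotation, so
+# the covariance L @ L^T built from scales s is simply diag(s**2).
+def _example_identity_covariance():
+    s = torch.tensor([[0.1, 0.2, 0.3]], device="cuda")
+    q = torch.tensor([[1.0, 0.0, 0.0, 0.0]], device="cuda")
+    L = build_scaling_rotation(s, q)
+    cov = L @ L.transpose(1, 2)
+    assert torch.allclose(cov, torch.diag_embed(s ** 2))
+    # strip_symmetric keeps the six unique entries of each symmetric 3x3 matrix.
+    assert strip_symmetric(cov).shape == (1, 6)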
diff --git a/src/utils/io_utils.py b/src/utils/io_utils.py
new file mode 100644
index 0000000..07ad4fd
--- /dev/null
+++ b/src/utils/io_utils.py
@@ -0,0 +1,149 @@
+import json
+import os
+from pathlib import Path
+from typing import Union
+
+import open3d as o3d
+import torch
+import wandb
+import yaml
+
+
+def mkdir_decorator(func):
+ """A decorator that creates the directory specified in the function's 'directory' keyword
+ argument before calling the function.
+ Args:
+ func: The function to be decorated.
+ Returns:
+ The wrapper function.
+ """
+ def wrapper(*args, **kwargs):
+ output_path = Path(kwargs["directory"])
+ output_path.mkdir(parents=True, exist_ok=True)
+ return func(*args, **kwargs)
+ return wrapper
+
+
+@mkdir_decorator
+def save_clouds(clouds: list, cloud_names: list, *, directory: Union[str, Path]) -> None:
+ """ Saves a list of point clouds to the specified directory, creating the directory if it does not exist.
+ Args:
+ clouds: A list of point cloud objects to be saved.
+ cloud_names: A list of filenames for the point clouds, corresponding by index to the clouds.
+ directory: The directory where the point clouds will be saved.
+ """
+ for cld_name, cloud in zip(cloud_names, clouds):
+ o3d.io.write_point_cloud(str(directory / cld_name), cloud)
+
+
+@mkdir_decorator
+def save_dict_to_ckpt(dictionary, file_name: str, *, directory: Union[str, Path]) -> None:
+ """ Saves a dictionary to a checkpoint file in the specified directory, creating the directory if it does not exist.
+ Args:
+ dictionary: The dictionary to be saved.
+ file_name: The name of the checkpoint file.
+ directory: The directory where the checkpoint file will be saved.
+ """
+ torch.save(dictionary, directory / file_name,
+ _use_new_zipfile_serialization=False)
+
+
+@mkdir_decorator
+def save_dict_to_yaml(dictionary, file_name: str, *, directory: Union[str, Path]) -> None:
+ """ Saves a dictionary to a YAML file in the specified directory, creating the directory if it does not exist.
+ Args:
+ dictionary: The dictionary to be saved.
+ file_name: The name of the YAML file.
+ directory: The directory where the YAML file will be saved.
+ """
+ with open(directory / file_name, "w") as f:
+ yaml.dump(dictionary, f)
+
+
+@mkdir_decorator
+def save_dict_to_json(dictionary, file_name: str, *, directory: Union[str, Path]) -> None:
+ """ Saves a dictionary to a JSON file in the specified directory, creating the directory if it does not exist.
+ Args:
+ dictionary: The dictionary to be saved.
+ file_name: The name of the JSON file.
+ directory: The directory where the JSON file will be saved.
+ """
+ with open(directory / file_name, "w") as f:
+ json.dump(dictionary, f)
+
+
+def load_config(path: str, default_path: str = None) -> dict:
+ """
+ Loads a configuration file and optionally merges it with a default configuration file.
+
+ This function loads a configuration from the given path. If the configuration specifies an inheritance
+ path (`inherit_from`), or if a `default_path` is provided, it loads the base configuration and updates it
+ with the specific configuration.
+
+ Args:
+ path: The path to the specific configuration file.
+ default_path: An optional path to a default configuration file that is loaded if the specific configuration
+ does not specify an inheritance or as a base for the inheritance.
+
+ Returns:
+ A dictionary containing the merged configuration.
+ """
+ # load configuration from per scene/dataset cfg.
+ with open(path, 'r') as f:
+ cfg_special = yaml.full_load(f)
+ inherit_from = cfg_special.get('inherit_from')
+ cfg = dict()
+ if inherit_from is not None:
+ cfg = load_config(inherit_from, default_path)
+ elif default_path is not None:
+ with open(default_path, 'r') as f:
+ cfg = yaml.full_load(f)
+ update_recursive(cfg, cfg_special)
+ return cfg
+
+
+def update_recursive(dict1: dict, dict2: dict) -> None:
+ """ Recursively updates the first dictionary with the contents of the second dictionary.
+
+ This function iterates through `dict2` and updates `dict1` with its contents. If a key from `dict2`
+ exists in `dict1` and its value is also a dictionary, the function updates the value recursively.
+ Otherwise, it overwrites the value in `dict1` with the value from `dict2`.
+
+ Args:
+ dict1: The dictionary to be updated.
+ dict2: The dictionary whose entries are used to update `dict1`.
+
+ Returns:
+ None: The function modifies `dict1` in place.
+ """
+ for k, v in dict2.items():
+ if k not in dict1:
+ dict1[k] = dict()
+ if isinstance(v, dict):
+ update_recursive(dict1[k], v)
+ else:
+ dict1[k] = v
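+
+
+# Illustrative only: update_recursive merges nested dictionaries in place,
+# overwriting leaves while keeping sibling keys; this is the mechanism behind the
+# `inherit_from` handling in load_config above.
+def _example_update_recursive():
+    base = {"cam": {"fx": 600.0, "fy": 600.0}, "seed": 0}
+    override = {"cam": {"fx": 525.0}}
+    update_recursive(base, override)
+    assert base == {"cam": {"fx": 525.0, "fy": 600.0}, "seed": 0}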
+
+
+def log_metrics_to_wandb(json_files: list, output_path: str, section: str = "Evaluation") -> None:
+ """ Logs metrics from JSON files to Weights & Biases under a specified section.
+
+ This function reads metrics from a list of JSON files and logs them to Weights & Biases (wandb).
+ Each metric is prefixed with a specified section name for organized logging.
+
+ Args:
+ json_files: A list of filenames for JSON files containing metrics to be logged.
+ output_path: The directory path where the JSON files are located.
+ section: The section under which to log the metrics in wandb. Defaults to "Evaluation".
+
+ Returns:
+ None: Metrics are logged to wandb and the function does not return a value.
+ """
+ for json_file in json_files:
+ file_path = os.path.join(output_path, json_file)
+ if os.path.exists(file_path):
+ with open(file_path, 'r') as file:
+ metrics = json.load(file)
+ prefixed_metrics = {
+ f"{section}/{key}": value for key, value in metrics.items()}
+ wandb.log(prefixed_metrics)
diff --git a/src/utils/mapper_utils.py b/src/utils/mapper_utils.py
new file mode 100644
index 0000000..7ff5131
--- /dev/null
+++ b/src/utils/mapper_utils.py
@@ -0,0 +1,336 @@
+
+import cv2
+import faiss
+import faiss.contrib.torch_utils
+import numpy as np
+import torch
+
+
+def compute_opt_views_distribution(keyframes_num, iterations_num, current_frame_iter) -> np.ndarray:
+ """ Computes the probability distribution for selecting views based on the current iteration.
+ Args:
+ keyframes_num: The total number of keyframes.
+ iterations_num: The total number of iterations planned.
+ current_frame_iter: The current iteration number.
+ Returns:
+ An array representing the probability distribution of keyframes.
+ """
+ if keyframes_num == 1:
+ return np.array([1.0])
+ prob = np.full(keyframes_num, (iterations_num - current_frame_iter) / (keyframes_num - 1))
+ prob[0] = current_frame_iter
+ prob /= prob.sum()
+ return prob
+
+
+def compute_camera_frustum_corners(depth_map: np.ndarray, pose: np.ndarray, intrinsics: np.ndarray) -> np.ndarray:
+ """ Computes the 3D coordinates of the camera frustum corners based on the depth map, pose, and intrinsics.
+ Args:
+ depth_map: The depth map of the scene.
+ pose: The camera pose matrix.
+ intrinsics: The camera intrinsic matrix.
+ Returns:
+ An array of 3D coordinates for the frustum corners.
+ """
+ height, width = depth_map.shape
+ depth_map = depth_map[depth_map > 0]
+ min_depth, max_depth = depth_map.min(), depth_map.max()
+ corners = np.array(
+ [
+ [0, 0, min_depth],
+ [width, 0, min_depth],
+ [0, height, min_depth],
+ [width, height, min_depth],
+ [0, 0, max_depth],
+ [width, 0, max_depth],
+ [0, height, max_depth],
+ [width, height, max_depth],
+ ]
+ )
+ x = (corners[:, 0] - intrinsics[0, 2]) * corners[:, 2] / intrinsics[0, 0]
+ y = (corners[:, 1] - intrinsics[1, 2]) * corners[:, 2] / intrinsics[1, 1]
+ z = corners[:, 2]
+ corners_3d = np.vstack((x, y, z, np.ones(x.shape[0]))).T
+ corners_3d = pose @ corners_3d.T
+ return corners_3d.T[:, :3]
+
+
+def compute_camera_frustum_planes(frustum_corners: np.ndarray) -> torch.Tensor:
+ """ Computes the planes of the camera frustum from its corners.
+ Args:
+ frustum_corners: An array of 3D coordinates representing the corners of the frustum.
+
+ Returns:
+ A tensor of frustum planes.
+ """
+ # near, far, left, right, top, bottom
+ planes = torch.stack(
+ [
+ torch.cross(
+ frustum_corners[2] - frustum_corners[0],
+ frustum_corners[1] - frustum_corners[0],
+ ),
+ torch.cross(
+ frustum_corners[6] - frustum_corners[4],
+ frustum_corners[5] - frustum_corners[4],
+ ),
+ torch.cross(
+ frustum_corners[4] - frustum_corners[0],
+ frustum_corners[2] - frustum_corners[0],
+ ),
+ torch.cross(
+ frustum_corners[7] - frustum_corners[3],
+ frustum_corners[1] - frustum_corners[3],
+ ),
+ torch.cross(
+ frustum_corners[5] - frustum_corners[1],
+ frustum_corners[3] - frustum_corners[1],
+ ),
+ torch.cross(
+ frustum_corners[6] - frustum_corners[2],
+ frustum_corners[0] - frustum_corners[2],
+ ),
+ ]
+ )
+ D = torch.stack([-torch.dot(plane, frustum_corners[i]) for i, plane in enumerate(planes)])
+ return torch.cat([planes, D[:, None]], dim=1).float()
+
+
+def compute_frustum_aabb(frustum_corners: torch.Tensor):
+ """ Computes a mask indicating which points lie inside a given axis-aligned bounding box (AABB).
+ Args:
+ points: An array of 3D points.
+ min_corner: The minimum corner of the AABB.
+ max_corner: The maximum corner of the AABB.
+ Returns:
+ A boolean array indicating whether each point lies inside the AABB.
+ """
+ return torch.min(frustum_corners, axis=0).values, torch.max(frustum_corners, axis=0).values
+
+
+def points_inside_aabb_mask(points: np.ndarray, min_corner: np.ndarray, max_corner: np.ndarray) -> np.ndarray:
+ """ Computes a mask indicating which points lie inside the camera frustum.
+ Args:
+ points: A tensor of 3D points.
+ frustum_planes: A tensor representing the planes of the frustum.
+ Returns:
+ A boolean tensor indicating whether each point lies inside the frustum.
+ """
+ return (
+ (points[:, 0] >= min_corner[0])
+ & (points[:, 0] <= max_corner[0])
+ & (points[:, 1] >= min_corner[1])
+ & (points[:, 1] <= max_corner[1])
+ & (points[:, 2] >= min_corner[2])
+ & (points[:, 2] <= max_corner[2]))
+
+
+def points_inside_frustum_mask(points: torch.Tensor, frustum_planes: torch.Tensor) -> torch.Tensor:
+ """ Computes a mask indicating which points lie inside the camera frustum.
+ Args:
+ points: A tensor of 3D points.
+ frustum_planes: A tensor representing the planes of the frustum.
+ Returns:
+ A boolean tensor indicating whether each point lies inside the frustum.
+ """
+ num_pts = points.shape[0]
+ ones = torch.ones(num_pts, 1).to(points.device)
+ plane_product = torch.cat([points, ones], axis=1) @ frustum_planes.T
+ return torch.all(plane_product <= 0, axis=1)
+
+
+def compute_frustum_point_ids(pts: torch.Tensor, frustum_corners: torch.Tensor, device: str = "cuda"):
+ """ Identifies points within the camera frustum, optimizing for computation on a specified device.
+ Args:
+ pts: A tensor of 3D points.
+ frustum_corners: A tensor of 3D coordinates representing the corners of the frustum.
+ device: The computation device ("cuda" or "cpu").
+ Returns:
+ Indices of points lying inside the frustum.
+ """
+ if pts.shape[0] == 0:
+ return torch.tensor([], dtype=torch.int64, device=device)
+ # Broad phase
+ pts = pts.to(device)
+ frustum_corners = frustum_corners.to(device)
+
+ min_corner, max_corner = compute_frustum_aabb(frustum_corners)
+ inside_aabb_mask = points_inside_aabb_mask(pts, min_corner, max_corner)
+
+ # Narrow phase
+ frustum_planes = compute_camera_frustum_planes(frustum_corners)
+ frustum_planes = frustum_planes.to(device)
+ inside_frustum_mask = points_inside_frustum_mask(pts[inside_aabb_mask], frustum_planes)
+
+ inside_aabb_mask[inside_aabb_mask == 1] = inside_frustum_mask
+ return torch.where(inside_aabb_mask)[0]
+
+
+def sample_pixels_based_on_gradient(image: np.ndarray, num_samples: int) -> np.ndarray:
+ """ Samples pixel indices based on the gradient magnitude of an image.
+ Args:
+ image: The image from which to sample pixels.
+ num_samples: The number of pixels to sample.
+ Returns:
+ Indices of the sampled pixels.
+ """
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+ grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
+ grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
+ grad_magnitude = cv2.magnitude(grad_x, grad_y)
+
+ # Normalize the gradient magnitude to create a probability map
+ prob_map = grad_magnitude / np.sum(grad_magnitude)
+
+ # Flatten the probability map
+ prob_map_flat = prob_map.flatten()
+
+ # Sample pixel indices based on the probability map
+ sampled_indices = np.random.choice(prob_map_flat.size, size=num_samples, p=prob_map_flat)
+ return sampled_indices.T
+
+
+def compute_new_points_ids(frustum_points: torch.Tensor, new_pts: torch.Tensor,
+ radius: float = 0.03, device: str = "cpu") -> torch.Tensor:
+ """ Having newly initialized points, decides which of them should be added to the submap.
+ For every new point, if there are no neighbors within the radius in the frustum points,
+ it is added to the submap.
+ Args:
+ frustum_points: Points within the current frustum of the active submap of shape (N, 3)
+ new_pts: New 3D Gaussian means which are about to be added to the submap of shape (N, 3)
+ radius: Radius within which the points are considered to be neighbors
+ device: Execution device
+ Returns:
+ Indices of the new points that should be added to the submap of shape (N)
+ """
+ if frustum_points.shape[0] == 0:
+ return torch.arange(new_pts.shape[0])
+ if device == "cpu":
+ pts_index = faiss.IndexFlatL2(3)
+ else:
+ pts_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss.IndexFlatL2(3))
+ frustum_points = frustum_points.to(device)
+ new_pts = new_pts.to(device)
+ pts_index.add(frustum_points)
+
+ split_pos = torch.split(new_pts, 65535, dim=0)
+ distances, ids = [], []
+ for split_p in split_pos:
+ distance, id = pts_index.search(split_p.float(), 8)
+ distances.append(distance)
+ ids.append(id)
+ distances = torch.cat(distances, dim=0)
+ ids = torch.cat(ids, dim=0)
+ neighbor_num = (distances < radius).sum(axis=1).int()
+ pts_index.reset()
+ return torch.where(neighbor_num == 0)[0]
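+
+
+# Illustrative sketch (CPU FAISS, synthetic points): a new point that already has
+# neighbors inside the radius among the frustum points is dropped, while a far-away
+# point is kept. Note that IndexFlatL2 returns squared L2 distances, so `radius` is
+# compared against squared distances.
+def _example_compute_new_points_ids():
+    frustum_points = torch.rand(100, 3) * 0.01        # dense cluster near the origin
+    new_pts = torch.tensor([[0.0, 0.0, 0.0],           # redundant: has close neighbors
+                            [1.0, 1.0, 1.0]])          # novel: far from the cluster
+    keep_ids = compute_new_points_ids(frustum_points, new_pts, radius=0.03, device="cpu")
+    assert keep_ids.tolist() == [1]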
+
+
+def rotation_to_euler(R: torch.Tensor) -> torch.Tensor:
+ """
+ Converts a rotation matrix to Euler angles.
+ Args:
+ R: A rotation matrix.
+ Returns:
+ Euler angles corresponding to the rotation matrix.
+ """
+ sy = torch.sqrt(R[0, 0] ** 2 + R[1, 0] ** 2)
+ singular = sy < 1e-6
+
+ if not singular:
+ x = torch.atan2(R[2, 1], R[2, 2])
+ y = torch.atan2(-R[2, 0], sy)
+ z = torch.atan2(R[1, 0], R[0, 0])
+ else:
+ x = torch.atan2(-R[1, 2], R[1, 1])
+ y = torch.atan2(-R[2, 0], sy)
+ z = 0
+
+ return torch.tensor([x, y, z]) * (180 / np.pi)
+
+
+def exceeds_motion_thresholds(current_c2w: torch.Tensor, last_submap_c2w: torch.Tensor,
+ rot_thre: float = 50, trans_thre: float = 0.5) -> bool:
+ """ Checks if a camera motion exceeds certain rotation and translation thresholds
+ Args:
+ current_c2w: The current camera-to-world transformation matrix.
+ last_submap_c2w: The last submap's camera-to-world transformation matrix.
+ rot_thre: The rotation threshold for triggering a new submap.
+ trans_thre: The translation threshold for triggering a new submap.
+
+ Returns:
+ A boolean indicating whether a new submap is required.
+ """
+ delta_pose = torch.matmul(torch.linalg.inv(last_submap_c2w).float(), current_c2w.float())
+ translation_diff = torch.norm(delta_pose[:3, 3])
+ rot_euler_diff_deg = torch.abs(rotation_to_euler(delta_pose[:3, :3]))
+ exceeds_thresholds = (translation_diff > trans_thre) or torch.any(rot_euler_diff_deg > rot_thre)
+ return exceeds_thresholds.item()
+
+
+def geometric_edge_mask(rgb_image: np.ndarray, dilate: bool = True, RGB: bool = False) -> np.ndarray:
+ """ Computes an edge mask for an RGB image using geometric edges.
+ Args:
+ rgb_image: The RGB image.
+ dilate: Whether to dilate the edges.
+ RGB: Indicates if the image format is RGB (True) or BGR (False).
+ Returns:
+ An edge mask of the input image.
+ """
+ # Convert the image to grayscale as Canny edge detection requires a single channel image
+ gray_image = cv2.cvtColor(
+ rgb_image, cv2.COLOR_BGR2GRAY if not RGB else cv2.COLOR_RGB2GRAY)
+ if gray_image.dtype != np.uint8:
+ gray_image = gray_image.astype(np.uint8)
+ edges = cv2.Canny(gray_image, threshold1=100, threshold2=200, apertureSize=3, L2gradient=True)
+ # Define the structuring element for dilation, you can change the size for a thicker/thinner mask
+ if dilate:
+ kernel = np.ones((2, 2), np.uint8)
+ edges = cv2.dilate(edges, kernel, iterations=1)
+ return edges
+
+
+def calc_psnr(img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
+ """ Calculates the Peak Signal-to-Noise Ratio (PSNR) between two images.
+ Args:
+ img1: The first image.
+ img2: The second image.
+ Returns:
+ The PSNR value.
+ """
+ mse = ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
+ return 20 * torch.log10(1.0 / torch.sqrt(mse))
+
+
+def create_point_cloud(image: np.ndarray, depth: np.ndarray, intrinsics: np.ndarray, pose: np.ndarray) -> np.ndarray:
+ """
+ Creates a point cloud from an image, depth map, camera intrinsics, and pose.
+
+ Args:
+ image: The RGB image of shape (H, W, 3)
+ depth: The depth map of shape (H, W)
+ intrinsics: The camera intrinsic parameters of shape (3, 3)
+ pose: The camera pose of shape (4, 4)
+ Returns:
+ A point cloud of shape (N, 6) with last dimension representing (x, y, z, r, g, b)
+ """
+ height, width = depth.shape
+ # Create a mesh grid of pixel coordinates
+ u, v = np.meshgrid(np.arange(width), np.arange(height))
+ # Convert pixel coordinates to camera coordinates
+ x = (u - intrinsics[0, 2]) * depth / intrinsics[0, 0]
+ y = (v - intrinsics[1, 2]) * depth / intrinsics[1, 1]
+ z = depth
+ # Stack the coordinates together
+ points = np.stack((x, y, z, np.ones_like(z)), axis=-1)
+ # Reshape the coordinates for matrix multiplication
+ points = points.reshape(-1, 4)
+ # Transform points to world coordinates
+ posed_points = pose @ points.T
+ posed_points = posed_points.T[:, :3]
+ # Flatten the image to get colors for each point
+ colors = image.reshape(-1, 3)
+ # Concatenate posed points with their corresponding color
+ point_cloud = np.concatenate((posed_points, colors), axis=-1)
+
+ return point_cloud
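+
+
+# Illustrative only: back-project a tiny constant-depth frame with an identity pose.
+# Every pixel becomes one (x, y, z, r, g, b) row and the z column equals the depth.
+def _example_create_point_cloud():
+    h, w = 4, 5
+    image = np.zeros((h, w, 3), dtype=np.float32)
+    depth = 2.0 * np.ones((h, w), dtype=np.float32)
+    intrinsics = np.array([[100.0, 0.0, w / 2.0],
+                           [0.0, 100.0, h / 2.0],
+                           [0.0, 0.0, 1.0]])
+    cloud = create_point_cloud(image, depth, intrinsics, np.eye(4))
+    assert cloud.shape == (h * w, 6)
+    assert np.allclose(cloud[:, 2], 2.0)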
diff --git a/src/utils/tracker_utils.py b/src/utils/tracker_utils.py
new file mode 100644
index 0000000..b1092b9
--- /dev/null
+++ b/src/utils/tracker_utils.py
@@ -0,0 +1,93 @@
+import numpy as np
+import torch
+from scipy.spatial.transform import Rotation
+from typing import Union
+from src.utils.utils import np2torch
+
+
+def multiply_quaternions(q: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
+ """Performs batch-wise quaternion multiplication.
+
+ Given two quaternions, this function computes their product. The operation is
+ vectorized and can be performed on batches of quaternions.
+
+ Args:
+ q: A tensor representing the first quaternion or a batch of quaternions.
+ Expected shape is (... , 4), where the last dimension contains quaternion components (w, x, y, z).
+ r: A tensor representing the second quaternion or a batch of quaternions with the same shape as q.
+ Returns:
+ A tensor of the same shape as the input tensors, representing the product of the input quaternions.
+ """
+ w0, x0, y0, z0 = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
+ w1, x1, y1, z1 = r[..., 0], r[..., 1], r[..., 2], r[..., 3]
+
+ w = -x1 * x0 - y1 * y0 - z1 * z0 + w1 * w0
+ x = x1 * w0 + y1 * z0 - z1 * y0 + w1 * x0
+ y = -x1 * z0 + y1 * w0 + z1 * x0 + w1 * y0
+ z = x1 * y0 - y1 * x0 + z1 * w0 + w1 * z0
+ return torch.stack((w, x, y, z), dim=-1)
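+
+
+# Illustrative only: composing with the identity quaternion (w, x, y, z) = (1, 0, 0, 0)
+# leaves a rotation unchanged, matching the (w, x, y, z) convention used here.
+def _example_multiply_quaternions():
+    identity = torch.tensor([1.0, 0.0, 0.0, 0.0])
+    q = torch.tensor([0.5, 0.5, 0.5, 0.5])  # a unit quaternion
+    assert torch.allclose(multiply_quaternions(identity, q), q)
+    assert torch.allclose(multiply_quaternions(q, identity), q)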
+
+
+def transformation_to_quaternion(RT: Union[torch.Tensor, np.ndarray]):
+ """ Converts a rotation-translation matrix to a tensor representing quaternion and translation.
+
+ This function takes a 3x4 transformation matrix (rotation and translation) and converts it
+ into a tensor that combines the quaternion representation of the rotation and the translation vector.
+
+ Args:
+ RT: A 3x4 matrix representing the rotation and translation. This can be a NumPy array
+ or a torch.Tensor. If it's a torch.Tensor and resides on a GPU, it will be moved to CPU.
+
+ Returns:
+ A tensor combining the quaternion (in w, x, y, z order) and translation vector. The tensor
+ will be moved to the original device if the input was a GPU tensor.
+ """
+ gpu_id = -1
+ if isinstance(RT, torch.Tensor):
+ if RT.get_device() != -1:
+ gpu_id = RT.get_device()
+ RT = RT.detach().cpu()
+ RT = RT.numpy()
+ R, T = RT[:3, :3], RT[:3, 3]
+
+ rot = Rotation.from_matrix(R)
+ quad = rot.as_quat(canonical=True)
+ quad = np.roll(quad, 1)
+ tensor = np.concatenate([quad, T], 0)
+ tensor = torch.from_numpy(tensor).float()
+ if gpu_id != -1:
+ tensor = tensor.to(gpu_id)
+ return tensor
+
+
+def interpolate_poses(poses: np.ndarray) -> np.ndarray:
+ """ Generates an interpolated pose based on the first two poses in the given array.
+ Args:
+ poses: An array of poses, where each pose is represented by a 4x4 transformation matrix.
+ Returns:
+ A 4x4 numpy ndarray representing the interpolated transformation matrix.
+ """
+ quat_poses = Rotation.from_matrix(poses[:, :3, :3]).as_quat()
+ init_rot = quat_poses[1] + (quat_poses[1] - quat_poses[0])
+ init_trans = poses[1, :3, 3] + (poses[1, :3, 3] - poses[0, :3, 3])
+ init_transformation = np.eye(4)
+ init_transformation[:3, :3] = Rotation.from_quat(init_rot).as_matrix()
+ init_transformation[:3, 3] = init_trans
+ return init_transformation
+
+
+def compute_camera_opt_params(estimate_rel_w2c: np.ndarray) -> tuple:
+ """ Computes the camera's rotation and translation parameters from an world-to-camera transformation matrix.
+ This function extracts the rotation component of the transformation matrix, converts it to a quaternion,
+ and reorders it to match a specific convention. Both rotation and translation parameters are converted
+ to torch Parameters and intended to be optimized in a PyTorch model.
+ Args:
+ estimate_rel_w2c: A 4x4 numpy ndarray representing the estimated world-to-camera transformation matrix.
+ Returns:
+ A tuple containing two torch.nn.Parameters: camera's rotation and camera's translation.
+ """
+ quaternion = Rotation.from_matrix(estimate_rel_w2c[:3, :3]).as_quat(canonical=True)
+ quaternion = quaternion[[3, 0, 1, 2]]
+ opt_cam_rot = torch.nn.Parameter(np2torch(quaternion, "cuda"))
+ opt_cam_trans = torch.nn.Parameter(np2torch(estimate_rel_w2c[:3, 3], "cuda"))
+ return opt_cam_rot, opt_cam_trans
diff --git a/src/utils/utils.py b/src/utils/utils.py
new file mode 100644
index 0000000..631f676
--- /dev/null
+++ b/src/utils/utils.py
@@ -0,0 +1,209 @@
+import os
+import random
+
+import numpy as np
+import open3d as o3d
+import torch
+from gaussian_rasterizer import GaussianRasterizationSettings, GaussianRasterizer
+
+
+def setup_seed(seed: int) -> None:
+ """ Sets the seed for generating random numbers to ensure reproducibility across multiple runs.
+ Args:
+ seed: The seed value to set for random number generators in torch, numpy, and random.
+ """
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def torch2np(tensor: torch.Tensor) -> np.ndarray:
+ """ Converts a PyTorch tensor to a NumPy ndarray.
+ Args:
+ tensor: The PyTorch tensor to convert.
+ Returns:
+ A NumPy ndarray with the same data and dtype as the input tensor.
+ """
+ return tensor.detach().cpu().numpy()
+
+
+def np2torch(array: np.ndarray, device: str = "cpu") -> torch.Tensor:
+ """Converts a NumPy ndarray to a PyTorch tensor.
+ Args:
+ array: The NumPy ndarray to convert.
+ device: The device to which the tensor is sent. Defaults to 'cpu'.
+
+ Returns:
+ A PyTorch tensor with the same data as the input array.
+ """
+ return torch.from_numpy(array).float().to(device)
+
+
+def np2ptcloud(pts: np.ndarray, rgb=None) -> o3d.geometry.PointCloud:
+ """converts numpy array to point cloud
+ Args:
+ pts (ndarray): point cloud
+ Returns:
+ (PointCloud): resulting point cloud
+ """
+ cloud = o3d.geometry.PointCloud()
+ cloud.points = o3d.utility.Vector3dVector(pts)
+ if rgb is not None:
+ cloud.colors = o3d.utility.Vector3dVector(rgb)
+ return cloud
+
+
+def dict2device(dict: dict, device: str = "cpu") -> dict:
+ """Sends all tensors in a dictionary to a specified device.
+ Args:
+ dict: The dictionary containing tensors.
+ device: The device to send the tensors to. Defaults to 'cpu'.
+ Returns:
+ The dictionary with all tensors sent to the specified device.
+ """
+ for k, v in dict.items():
+ if isinstance(v, torch.Tensor):
+ dict[k] = v.to(device)
+ return dict
+
+
+def get_render_settings(w, h, intrinsics, w2c, near=0.01, far=100, sh_degree=0):
+ """
+ Constructs and returns a GaussianRasterizationSettings object for rendering,
+ configured with given camera parameters.
+
+ Args:
+ w (int): The width of the image.
+ h (int): The height of the image.
+ intrinsics (array): 3x3 intrinsic camera matrix.
+ w2c (array): World to camera transformation matrix.
+ near (float, optional): The near plane for the camera. Defaults to 0.01.
+ far (float, optional): The far plane for the camera. Defaults to 100.
+
+ Returns:
+ GaussianRasterizationSettings: Configured settings for Gaussian rasterization.
+ """
+ fx, fy, cx, cy = intrinsics[0, 0], intrinsics[1,
+ 1], intrinsics[0, 2], intrinsics[1, 2]
+ w2c = torch.tensor(w2c).cuda().float()
+ cam_center = torch.inverse(w2c)[:3, 3]
+ viewmatrix = w2c.transpose(0, 1)
+ opengl_proj = torch.tensor([[2 * fx / w, 0.0, -(w - 2 * cx) / w, 0.0],
+ [0.0, 2 * fy / h, -(h - 2 * cy) / h, 0.0],
+ [0.0, 0.0, far /
+ (far - near), -(far * near) / (far - near)],
+ [0.0, 0.0, 1.0, 0.0]], device='cuda').float().transpose(0, 1)
+ full_proj_matrix = viewmatrix.unsqueeze(
+ 0).bmm(opengl_proj.unsqueeze(0)).squeeze(0)
+ return GaussianRasterizationSettings(
+ image_height=h,
+ image_width=w,
+ tanfovx=w / (2 * fx),
+ tanfovy=h / (2 * fy),
+ bg=torch.tensor([0, 0, 0], device='cuda').float(),
+ scale_modifier=1.0,
+ viewmatrix=viewmatrix,
+ projmatrix=full_proj_matrix,
+ sh_degree=sh_degree,
+ campos=cam_center,
+ prefiltered=False,
+ debug=False)
+
+
+def render_gaussian_model(gaussian_model, render_settings,
+ override_means_3d=None, override_means_2d=None,
+ override_scales=None, override_rotations=None,
+ override_opacities=None, override_colors=None):
+ """
+ Renders a Gaussian model with specified rendering settings, allowing for
+ optional overrides of various model parameters.
+
+ Args:
+ gaussian_model: A Gaussian model object that provides methods to get
+ various properties like xyz coordinates, opacity, features, etc.
+ render_settings: Configuration settings for the GaussianRasterizer.
+ override_means_3d (Optional): If provided, these values will override
+ the 3D mean values from the Gaussian model.
+ override_means_2d (Optional): If provided, these values will override
+ the 2D mean values. Defaults to zeros if not provided.
+ override_scales (Optional): If provided, these values will override the
+ scale values from the Gaussian model.
+ override_rotations (Optional): If provided, these values will override
+ the rotation values from the Gaussian model.
+ override_opacities (Optional): If provided, these values will override
+ the opacity values from the Gaussian model.
+ override_colors (Optional): If provided, these values will override the
+ color values from the Gaussian model.
+ Returns:
+ A dictionary containing the rendered color, depth, alpha, radii, and 2D means
+ of the Gaussian model, under the keys 'color', 'depth', 'alpha', 'radii',
+ and 'means2D'.
+ """
+ renderer = GaussianRasterizer(raster_settings=render_settings)
+
+ if override_means_3d is None:
+ means3D = gaussian_model.get_xyz()
+ else:
+ means3D = override_means_3d
+
+ if override_means_2d is None:
+ means2D = torch.zeros_like(
+ means3D, dtype=means3D.dtype, requires_grad=True, device="cuda")
+ means2D.retain_grad()
+ else:
+ means2D = override_means_2d
+
+ if override_opacities is None:
+ opacities = gaussian_model.get_opacity()
+ else:
+ opacities = override_opacities
+
+ shs, colors_precomp = None, None
+ if override_colors is not None:
+ colors_precomp = override_colors
+ else:
+ shs = gaussian_model.get_features()
+
+ render_args = {
+ "means3D": means3D,
+ "means2D": means2D,
+ "opacities": opacities,
+ "colors_precomp": colors_precomp,
+ "shs": shs,
+ "scales": gaussian_model.get_scaling() if override_scales is None else override_scales,
+ "rotations": gaussian_model.get_rotation() if override_rotations is None else override_rotations,
+ "cov3D_precomp": None
+ }
+ color, depth, alpha, radii = renderer(**render_args)
+
+ return {"color": color, "depth": depth, "radii": radii, "means2D": means2D, "alpha": alpha}
+
+
+def batch_search_faiss(indexer, query_points, k):
+ """
+ Perform a batch search on an IndexIVFFlat indexer to circumvent the search size limit of 65535.
+
+ Args:
+ indexer: The FAISS indexer object.
+ query_points: A tensor of query points.
+ k (int): The number of nearest neighbors to find.
+
+ Returns:
+ distances (torch.Tensor): The distances of the nearest neighbors.
+ ids (torch.Tensor): The indices of the nearest neighbors.
+ """
+ split_pos = torch.split(query_points, 65535, dim=0)
+ distances_list, ids_list = [], []
+
+ for split_p in split_pos:
+ distance, id = indexer.search(split_p.float(), k)
+ distances_list.append(distance.clone())
+ ids_list.append(id.clone())
+ distances = torch.cat(distances_list, dim=0)
+ ids = torch.cat(ids_list, dim=0)
+
+ return distances, ids
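+
+
+# Illustrative only: batch_search_faiss chunks the queries so that very large query
+# sets stay under FAISS' 65535-per-search limit, while each query's result is the
+# same as in an unchunked search. faiss is imported locally here since this module
+# does not otherwise depend on it.
+def _example_batch_search_faiss():
+    import faiss
+    import faiss.contrib.torch_utils  # lets FAISS consume torch tensors directly
+
+    index = faiss.IndexFlatL2(3)
+    index.add(torch.rand(1000, 3))
+    distances, ids = batch_search_faiss(index, torch.rand(70000, 3), k=1)
+    assert distances.shape == (70000, 1) and ids.shape == (70000, 1)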
diff --git a/src/utils/vis_utils.py b/src/utils/vis_utils.py
new file mode 100644
index 0000000..78c0cb0
--- /dev/null
+++ b/src/utils/vis_utils.py
@@ -0,0 +1,112 @@
+from collections import OrderedDict
+from copy import deepcopy
+from typing import List, Union
+
+import numpy as np
+import open3d as o3d
+from matplotlib import colors
+
+COLORS_ANSI = OrderedDict({
+ "blue": "\033[94m",
+ "orange": "\033[93m",
+ "green": "\033[92m",
+ "red": "\033[91m",
+ "purple": "\033[95m",
+ "brown": "\033[93m", # No exact match, using yellow
+ "pink": "\033[95m",
+ "gray": "\033[90m",
+ "olive": "\033[93m", # No exact match, using yellow
+ "cyan": "\033[96m",
+ "end": "\033[0m", # Reset color
+})
+
+
+COLORS_MATPLOTLIB = OrderedDict({
+ 'blue': '#1f77b4',
+ 'orange': '#ff7f0e',
+ 'green': '#2ca02c',
+ 'red': '#d62728',
+ 'purple': '#9467bd',
+ 'brown': '#8c564b',
+ 'pink': '#e377c2',
+ 'gray': '#7f7f7f',
+ 'yellow-green': '#bcbd22',
+ 'cyan': '#17becf'
+})
+
+
+COLORS_MATPLOTLIB_RGB = OrderedDict({
+ 'blue': np.array([31, 119, 180]) / 255.0,
+ 'orange': np.array([255, 127, 14]) / 255.0,
+ 'green': np.array([44, 160, 44]) / 255.0,
+ 'red': np.array([214, 39, 40]) / 255.0,
+ 'purple': np.array([148, 103, 189]) / 255.0,
+ 'brown': np.array([140, 86, 75]) / 255.0,
+ 'pink': np.array([227, 119, 194]) / 255.0,
+ 'gray': np.array([127, 127, 127]) / 255.0,
+ 'yellow-green': np.array([188, 189, 34]) / 255.0,
+ 'cyan': np.array([23, 190, 207]) / 255.0
+})
+
+
+def get_color(color_name: str):
+ """ Returns the RGB values of a given color name as a normalized numpy array.
+ Args:
+ color_name: The name of the color. Can be any color name from CSS4_COLORS.
+ Returns:
+ A numpy array representing the RGB values of the specified color, normalized to the range [0, 1].
+ """
+ if color_name == "custom_yellow":
+ return np.asarray([255.0, 204.0, 102.0]) / 255.0
+ if color_name == "custom_blue":
+ return np.asarray([102.0, 153.0, 255.0]) / 255.0
+ assert color_name in colors.CSS4_COLORS
+ return np.asarray(colors.to_rgb(colors.CSS4_COLORS[color_name]))
+
+
+def plot_ptcloud(point_clouds: Union[List, o3d.geometry.PointCloud], show_frame: bool = True):
+ """ Visualizes one or more point clouds, optionally showing the coordinate frame.
+ Args:
+ point_clouds: A single point cloud or a list of point clouds to be visualized.
+ show_frame: If True, displays the coordinate frame in the visualization. Defaults to True.
+ """
+ # rotate down up
+ if not isinstance(point_clouds, list):
+ point_clouds = [point_clouds]
+ if show_frame:
+ mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1, origin=[0, 0, 0])
+ point_clouds = point_clouds + [mesh_frame]
+ o3d.visualization.draw_geometries(point_clouds)
+
+
+def draw_registration_result_original_color(source: o3d.geometry.PointCloud, target: o3d.geometry.PointCloud,
+ transformation: np.ndarray):
+ """ Visualizes the result of a point cloud registration, keeping the original color of the source point cloud.
+ Args:
+ source: The source point cloud.
+ target: The target point cloud.
+ transformation: The transformation matrix applied to the source point cloud.
+ """
+ source_temp = deepcopy(source)
+ source_temp.transform(transformation)
+ o3d.visualization.draw_geometries([source_temp, target])
+
+
+def draw_registration_result(source: o3d.geometry.PointCloud, target: o3d.geometry.PointCloud,
+ transformation: np.ndarray, source_color: str = "blue", target_color: str = "orange"):
+ """ Visualizes the result of a point cloud registration, coloring the source and target point clouds.
+ Args:
+ source: The source point cloud.
+ target: The target point cloud.
+ transformation: The transformation matrix applied to the source point cloud.
+ source_color: The color to apply to the source point cloud. Defaults to "blue".
+ target_color: The color to apply to the target point cloud. Defaults to "orange".
+ """
+ source_temp = deepcopy(source)
+ source_temp.paint_uniform_color(COLORS_MATPLOTLIB_RGB[source_color])
+
+ target_temp = deepcopy(target)
+ target_temp.paint_uniform_color(COLORS_MATPLOTLIB_RGB[target_color])
+
+ source_temp.transform(transformation)
+ o3d.visualization.draw_geometries([source_temp, target_temp])