diff --git a/docs/img/GraphCompNet/bar_chamber.png b/docs/img/GraphCompNet/bar_chamber.png
new file mode 100644
index 0000000000..2edbd22c3c
Binary files /dev/null and b/docs/img/GraphCompNet/bar_chamber.png differ
diff --git a/docs/img/GraphCompNet/dl_comp_test-2.png b/docs/img/GraphCompNet/dl_comp_test-2.png
new file mode 100644
index 0000000000..5d48893da8
Binary files /dev/null and b/docs/img/GraphCompNet/dl_comp_test-2.png differ
diff --git a/docs/img/GraphCompNet/figure_dl_predict.png b/docs/img/GraphCompNet/figure_dl_predict.png
new file mode 100644
index 0000000000..e2aa6fd4e3
Binary files /dev/null and b/docs/img/GraphCompNet/figure_dl_predict.png differ
diff --git a/docs/img/GraphCompNet/overall_arch-2.png b/docs/img/GraphCompNet/overall_arch-2.png
new file mode 100644
index 0000000000..0323893d90
Binary files /dev/null and b/docs/img/GraphCompNet/overall_arch-2.png differ
diff --git a/docs/img/GraphCompNet/table1_fig-2.png b/docs/img/GraphCompNet/table1_fig-2.png
new file mode 100644
index 0000000000..1104cd7672
Binary files /dev/null and b/docs/img/GraphCompNet/table1_fig-2.png differ
diff --git a/examples/additive_manufacturing/compensation/README.md b/examples/additive_manufacturing/compensation/README.md
new file mode 100644
index 0000000000..e7451a2b96
--- /dev/null
+++ b/examples/additive_manufacturing/compensation/README.md
@@ -0,0 +1,130 @@
+
+
+# PyTorch version of deformation predictor & compensation
+
+## Introduction
+
+This work addresses shape deviation modeling and compensation in additive manufacturing (AM) to improve geometric accuracy for industrial-scale production. While traditional methods laid the groundwork, recent machine learning (ML) advancements offer better precision. However, challenges remain in generalizing across complex geometries and adapting to position-dependent variations in batch production. We introduce GraphCompNet, a novel framework combining graph-based neural networks with GAN-inspired training to model geometries and incorporate position-specific thermal and mechanical variations. Through a two-stage adversarial process, the framework refines compensated designs, improving accuracy by 35-65% across the print space. This approach enhances AM's real-time, scalable compensation capabilities, paving the way for high-precision, automated manufacturing systems.
+
+
+[//]: # (
)
+
+[//]: # (
)
+
+[//]: # (
)
+
+## Sample results
+
+Prediction & compensation accuracy (mm) to be updated
+
+[//]: # ()
+
+[//]: # (
)
+
+[//]: # (
)
+
+[//]: # (Compensation on Molded fiber dataset:)
+
+[//]: # ()
+[//]: # (Comparison of four sample parts in one print run, the top row illustrates the difference between the design CAD file and the scanned printed part geometry before applying compensation, the bottom row shows the difference between the design CAD file and the scanned printed part geometry after applying compensation using our trained prediction and compensation engine.)
+
+[//]: # ()
+[//]: # ()
+
+[//]: # (
)
+
+[//]: # (
)
+
+## Key requirments
+
+1. ``Torch_Geometric 2.5.1 or above``: PyTorch based geometric/graph neural network library
+
+ - https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html#installation-via-anaconda
+
+ - conda install pyg=*=*cu* -c pyg
+
+2. ``pip install trimesh``
+
+3. ``pip install matplotlib``
+
+4. ``pip install pandas``
+
+5. ``pip install hydra-core --upgrade --pre``
+
+6. ``PyTorch3D``: PyTorch based 3D computer vision library
+
+ - Check requirements from official install page: https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md
+ - when tested, Pytorch3D requires Python 3.8, 3.9 or 3.10
+
+ - ``pip install -U iopath``
+
+ - Install directly from the source ``pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" ``
+
+7. ``pip install torch-cluster``
+
+To test in customized CUDA environment, install compatible torch version compatible with cudatoolkit, i.e.
+
+``pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121``
+
+ Refer to:
+https://pytorch.org/get-started/previous-versions/
+
+Other dependencies for development:
+
+- ``open3d``: pip install open3d, tested version 0.18.0
+- ``torch-cluster``: conda install pytorch-cluster -c pyg
+
+
+
+## Dataset
+- Currently available:
+ - Bar repository [link not working yet](https://drive.google.com/file/d/1inUN4KIg8NOtuwaJa2d1j3tssRGUxgAQ/view?usp=sharing)
+ - Molded-fiber repository [Download sample data](https://drive.google.com/file/d/1inUN4KIg8NOtuwaJa2d1j3tssRGUxgAQ/view?usp=sharing)
+
+- Sample input data folder format:
+
+ - input_data.txt: logs for each row, the build geometry folder
+
+ - /part_folder_i:
+
+ - cad_.txt: contains 3 columns for point location
+
+ - scan_red.csv: contains 3 columns for point location
+
+[//]: # (- Post-processing: )
+
+[//]: # ( )
+[//]: # ( - https://github.azc.ext.hp.com/Shape-Compensation/Shape_compensator)
+
+
+## Training
+
+- To test running with cpu ``Connfig.yaml`` setting (not recommended):
+
+ - `` cuda: False ``
+ - ``use_distributed: False``
+ - ``use_multigpu: False``
+- Gpu training: set params listed above to True
+
+- There are two training codes that need to run in sequential manner.
+1. ``train_dis.py``: This code trains the discriminator (predict part deformations with its position and geometry)
+2. ``train_gen.py``: This code trains the generator (compensate part geometry)
+
+## Inference
+
+[Download pre-trained model checkpoint](https://drive.google.com/file/d/1Htd7MLGgvjmidIGyYquDtLkZe0gSEqRu/view?usp=drive_link)
+
+- Supported 3D formats:
+ - Stereolitography (STL)
+ - Wavefront file (OBJ)
+- How to run:
+ - ``python inference.py``
+
+
+## References
+
+[GraphCompNet: A Position-Aware Model for Predicting and Compensating Shape Deviations in 3D Printing](to be added)
+
+```text
+
+```
diff --git a/examples/additive_manufacturing/compensation/conf/config.yaml b/examples/additive_manufacturing/compensation/conf/config.yaml
new file mode 100644
index 0000000000..337851ffbc
--- /dev/null
+++ b/examples/additive_manufacturing/compensation/conf/config.yaml
@@ -0,0 +1,64 @@
+# ignore_header_test
+# ruff: noqa: E402
+
+# © Copyright 2023 HP Development Company, L.P.
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+general:
+ seed: 1234
+ random_sample: True
+ cuda: True
+ use_distributed: False
+ sync_batch: True
+ use_multigpu: False # Default: False for D, True for G
+
+train_dis_options:
+ model_path: './pretrained/11parts_lr-3/pred_model_0000.pth'
+ log_dir: './pretrained/11parts_lr-3/'
+ save_path: './pretrained/11parts_lr-3/'
+ num_epoch: 15001
+ num_batch: 2
+ learning_rate: 0.0001
+ pretrain: False
+ num_points: 190000
+
+train_gen_options:
+ num_points: 190000
+ pred_model_path: "./pretrained/ocardo_iso_p500k/pred_model_3000.pth" # For D
+ gen_model_path: "./pretrained/ocardo_iso_p500k/pred_model_3000.pth" # For G
+ log_dir: './pretrained/ocardo_iso_p500k/'
+ save_path: './pretrained/ocardo_iso_p500k/'
+ num_epoch: 50001
+ num_batch: 1
+ learning_rate: 0.001
+
+inference_options:
+ seed: 1234
+ num_points: 20000 # 'Num of points to use'
+ data_path: '/home/chenle/codes/DL_prediction_compensation-master/data/molded_fiber/10/cad' # for other dataset: 'input_data_bar_sample','molded_fiber'
+ discriminator_path: './pretrained/pretrained_os/pred_model_3000.pth' # 'discriminator model path'
+ generator_path: './pretrained/pretrained_os/gen_model_46500.pth' # 'generator model path'
+ save_path: './output/test' # 'save output path'
+ save_extra: True # 'exports prediction and additional data csv'
+
+
+data_options:
+ dataset_name: "Ocardo" # choices=['Ocardo', 'Bar']
+ data_path: '/home/chenle/codes/DL_prediction_compensation-master/data/molded_fiber' # for other dataset: 'input_data_bar_sample','molded_fiber'
+ cad_format: 'txt'
+ scan_format: 'csv'
diff --git a/examples/additive_manufacturing/compensation/data_proces/ply_sampling.py b/examples/additive_manufacturing/compensation/data_proces/ply_sampling.py
new file mode 100644
index 0000000000..ec774c5b23
--- /dev/null
+++ b/examples/additive_manufacturing/compensation/data_proces/ply_sampling.py
@@ -0,0 +1,148 @@
+# ignore_header_test
+# ruff: noqa: E402
+
+# © Copyright 2023 HP Development Company, L.P.
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+import numpy as np
+import open3d as o3d
+import pandas as pd
+import torch
+import torch_geometric
+import trimesh
+
+
+def generate_mesh_train(
+ ply_path,
+ scan_pcd_path,
+ save_csv=True,
+ save_mesh_path=None,
+ part_name="bar",
+ part_id="3",
+ export_format="pth",
+ filter_dist=False,
+):
+ """
+ A PLY file is a computer file format for storing 3D data as a collection of polygons.
+ PLY stands for Polygon File Format, and it's also known as the Stanford Triangle Format.
+ PLY files are used to store 3D data from 3D scanners.
+
+ This function load a CAD file in PLY format, or STL format with trimesh:
+ i.e.
+
+ Load the raw scan file sampled points in PCD format, then save the updated scan mesh in OBJ format.
+
+ Parameters:
+ - ply_path = os.path.join(root_data_path, "data_pipeline_bar/remesh98.ply")
+ - scan_pcd_path = os.path.join(root_data_path, "data_pipeline_bar/bar_98/scan/98_SAMPLED_POINTS_aligned.pcd")
+ - save_mesh_path = "test_data_pipeline"
+
+ Return:
+ Saved scan mesh path
+ """
+ os.makedirs(save_mesh_path, exist_ok=True)
+
+ # Load cad mesh from PLY file
+ cad_mesh = trimesh.load(ply_path)
+
+ # Centralize the coordinates
+ cad_pts = torch.FloatTensor(np.asarray(cad_mesh.vertices)) - torch.FloatTensor(
+ cad_mesh.bounds.mean(0)
+ )
+
+ # Load raw scan file in PCD, o3d function to read PointCloud from file
+ scan_pts = o3d.io.read_point_cloud(scan_pcd_path)
+
+ # Centralize the coordinates
+ scan_pts = torch.FloatTensor(np.asarray(scan_pts.points)) - torch.FloatTensor(
+ cad_mesh.bounds.mean(0)
+ )
+
+ # Fined one-to-one matching
+ idx1, idx2 = torch_geometric.nn.knn(scan_pts, cad_pts, 1)
+ new_vert = scan_pts[idx2]
+
+ if filter_dist:
+ dist = torch.sqrt(torch.sum(torch.pow(cad_pts - new_vert, 2), 1))
+ filt = dist > 1.2
+ new_vert[filt] = cad_pts[filt]
+
+ # Updates the scan coordinates to the original CAD mesh
+ scan_mesh = cad_mesh
+ vertices = new_vert + torch.FloatTensor(cad_mesh.bounds.mean(0))
+ scan_mesh.vertices = vertices
+
+ if export_format == "obj":
+ scan_mesh.export(os.path.join(save_mesh_path, "data_out.obj"))
+ elif export_format == "pth":
+ torch.save(vertices, os.path.join(save_mesh_path, f"{part_id}/{part_name}.pth"))
+ else:
+ print("Export format should be OBJ or PTH")
+ exit()
+
+ if save_csv:
+ # save the original CAD points with centralize coordinates
+ np.savetxt(
+ os.path.join(save_mesh_path, f"{part_id}/{part_name}_cad.csv"), cad_pts
+ )
+ # save the mapped scan_pts points with centralize coordinates
+ np.savetxt(
+ os.path.join(save_mesh_path, f"{part_id}/{part_name}_scan.csv"), new_vert
+ )
+
+ return os.path.join(save_mesh_path, "data_out.obj")
+
+
+def generate_mesh_eval(cad_path, comp_out_path, export_path, view=False):
+ """
+ Function to load a 3D object pair (Original design file v.s. Scanned printed / Compensated part),
+ - CAD design in format of OBJ or STL
+ - Scanned printed, or compensated part points, in CSV or TXT
+ Export the Scanned in mesh, OBJ format
+
+ Parameters:
+ - object_name = "bar"
+ - part_id = 5
+ - cad_path = "%s_%d/cad/%s_%d_uptess.obj" % (object_name, part_id, object_name, part_id)
+ - comp_out_path = comp/out__%02d.csv" % (part_id)
+
+ Return:
+ Saved scan mesh path
+ """
+ os.makedirs(export_path, exist_ok=True)
+
+ # Sample design CAD name
+ cad_mesh = trimesh.load(cad_path)
+
+ # Sample scanned printed file, or generated compensated file, in CSV or TXT
+ # change the reading format, if data was saved with other separators, " ", ","
+ scan_pts = pd.read_csv(comp_out_path, sep=",").values
+
+ # Define the new vertices as the scanned printed points coordinates
+ new_vert = torch.FloatTensor(scan_pts)
+
+ # Define the mesh from the Design CAD
+ scan_mesh = cad_mesh
+
+ # Export new mesh
+ scan_mesh.vertices = new_vert
+ scan_mesh.export(os.path.join(export_path, "export_out.obj"))
+ if view:
+ scan_mesh.show()
diff --git a/examples/additive_manufacturing/compensation/dataloader.py b/examples/additive_manufacturing/compensation/dataloader.py
new file mode 100644
index 0000000000..13133417c0
--- /dev/null
+++ b/examples/additive_manufacturing/compensation/dataloader.py
@@ -0,0 +1,241 @@
+# ignore_header_test
+# ruff: noqa: E402
+
+# © Copyright 2023 HP Development Company, L.P.
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+import numpy as np
+import pandas as pd
+import torch
+import torch_geometric
+
+# import open3d as o3d
+from utils import log_string
+
+torch.manual_seed(0)
+
+
+class Bar(torch.utils.data.Dataset):
+ """
+ To import the dataset, you can use files in either .txt or .csv format. Below is the folder structure for sample input data:
+ - input_data.txt: Contains logs, with each row corresponding to a build part geometry
+ - /part_folder_i (aligned with logs in input_data.txt):
+ - cad_.txt: Contains 3 columns, each representing a point's coordinates.
+ - scan_red.csv: Includes 3 columns representing point locations.
+ """
+
+ def __init__(
+ self,
+ data_path="insert data path, default in cfg.data_options.data_path",
+ num_points=50000,
+ partition="train",
+ random_sample=False,
+ transform=None,
+ LOG_FOUT=None,
+ ):
+ self.num_points = num_points
+ self.data_path = data_path
+ log_string(LOG_FOUT, f"Process from data_path: {data_path}")
+
+ self.random_sample = random_sample
+ self.partition = partition
+ if self.partition == "train":
+ lists = [
+ line.rstrip() for line in open(self.data_path + "/24hrs.txt")
+ ] # [28:]
+ elif self.partition == "val":
+ lists = [line.rstrip() for line in open(self.data_path + "/24hrs_val.txt")]
+ self.items = []
+ len_ds = len(lists)
+
+ print("total data_size = %02d" % len_ds)
+ for i in range(len_ds): # load all CAD & scan pairs
+ tag = lists[i].split("/")[-2][2:]
+ cad = torch.FloatTensor(
+ np.loadtxt(lists[i] + "cad%s.txt" % (tag), delimiter="\t")
+ )[:, :3]
+ scan = torch.FloatTensor(
+ np.loadtxt(lists[i] + "scan_res%s.csv" % (tag), delimiter=",")
+ )[:, :3]
+ self.items.append((i + 1, cad, scan))
+
+ def __len__(self):
+ return len(self.items)
+
+ def __getitem__(self, idx):
+ part_id, mesh, scan = self.items[idx]
+ m = torch.mean(mesh)
+
+ # random sampling for 50k points
+ if self.random_sample and (
+ self.partition == "train" or self.partition == "val"
+ ):
+ sample = torch.randint(
+ low=0, high=min(len(mesh), len(scan)) - 1, size=(self.num_points,)
+ )
+
+ # find correspondence between CAD - scan points
+ pts1 = mesh[sample]
+ pts2 = scan[sample]
+ else:
+ pts1 = mesh
+ pts2 = scan
+ with torch.no_grad():
+ edge_index = torch_geometric.nn.knn_graph(torch.FloatTensor(pts1), 20)
+
+ s = pts1.std(0)
+ # output in torch_geometric format
+ out = torch_geometric.data.Data(
+ x=pts1,
+ y=pts2,
+ edge_index=edge_index,
+ m=m,
+ s=s,
+ part_id=part_id,
+ )
+ return out
+
+
+class Ocardo(torch.utils.data.Dataset):
+ """
+
+ :param data_path:
+ # contains the list of paths for dataset
+ # in the input_data.txt file, each row contains the input part data folder
+ # i.e.
+ # /home/DL_engine/input_data/1/
+ # /home/DL_engine/input_data/2/
+ Under each data folder:
+ i.e.
+ - /home/DL_engine/input_data/1/cad{tag_id}.csv
+ - /home/DL_engine/input_data/1/scan_res{tag_id}.csv
+ # for complete description of sample data format, refer to README.md
+ # each part shape scan/ cad: i.e. torch.Size([12775, 3])
+ :param num_points:
+ :param partition:
+ :param random_sample:
+ :param transform:
+ """
+
+ def __init__(
+ self,
+ data_path="./input_data/",
+ num_points=50000,
+ partition="train",
+ random_sample=True,
+ transform=None,
+ LOG_FOUT=None,
+ ):
+ self.num_points = num_points
+ self.data_path = data_path
+ log_string(LOG_FOUT, f"Process from data_path: {data_path}")
+
+ self.random_sample = random_sample
+ # Read each row as the input part data ID
+ lists = [
+ line.rstrip()
+ for line in open(os.path.join(self.data_path, "input_data.txt"))
+ ]
+ log_string(LOG_FOUT, f"read data folder name lists: {lists}")
+
+ self.items = []
+ # Initialize the tranform function to: Converts mesh faces [3, num_faces] to edge indices [2, num_edges] (functional name: face_to_edge).
+ # https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html
+ self.transform = torch_geometric.transforms.FaceToEdge()
+ for i in range(len(lists)):
+ # process for each data build in the input_data file
+ # read the build id
+ tag = lists[i].split("/")[-2]
+ # Read each row from the cad_.txt, store as torch.FloatTensor the point coordinates
+ # cad = torch.FloatTensor(np.loadtxt(f"{self.data_path}/{lists[i]}cad{tag}.txt", delimiter='\t'))[:,:3] #input_data_bar_sample
+
+ log_string(LOG_FOUT, f"{self.data_path}/{lists[i]}scan_res{tag}.csv")
+
+ cad = torch.FloatTensor(
+ pd.read_csv(f"{self.data_path}/{lists[i]}cad{tag}.csv", sep=" ").values
+ ) # molded_fiber
+
+ # Read each row from the scan_res.csv, store as torch.FloatTensor the point coordinates
+ # scan = torch.FloatTensor(np.loadtxt(f"{self.data_path}/{lists[i]}scan_res{tag}.csv", delimiter=','))[:,:3] #input_data_bar_sample
+ scan = torch.FloatTensor(
+ pd.read_csv(
+ f"{self.data_path}/{lists[i]}scan_res{tag}.csv", sep=" "
+ ).values
+ ) # molded_fiber
+
+ self.items.append((i + 1, cad, scan))
+ log_string(LOG_FOUT, f"loaded scan {scan.shape}")
+
+ if not cad.shape == scan.shape:
+ raise Exception("Part CAD and Scan files not match ")
+
+ def __len__(self):
+ return len(self.items)
+
+ def __getitem__(self, idx):
+ """
+ For each item that contains the dataset
+ - part_id (i.e. 1)
+ - mesh(the original cad of size: pt_cnt, 3),
+ - scan(the scan of the printed part of size: pt_cnt, 3)
+ i.e. torch.Size([650, 3]) torch.Size([650, 3])
+
+
+ """
+ part_id, mesh, scan = self.items[idx]
+
+ # todo: reason to compute mean/ what means for mean < 0?
+ # torch.mean(mesh): tensor(-0.8895)
+ # torch.mean(mesh): tensor(4.4421)
+ m = torch.mean(mesh)
+
+ # find correspondence between CAD - scan points
+ if self.random_sample:
+ # Get random sampling index from 0 to self.num_points
+ # i.e. sample id: tensor([ 1653, 27927, 3942, ..., 24202, 1684, 23686])
+ # resulting pts1, pts2 size: [self.num_points, 3], i.e. torch.Size([190000, 3])
+ # todo: if meaningful with the sample >> the pcloud original scanning/ sampling density -> los of duplicated samples ?
+ sample = torch.randint(
+ low=0, high=min(len(mesh), len(scan)) - 1, size=(self.num_points,)
+ )
+
+ # find correspondence between CAD - scan points
+ pts1 = mesh[sample]
+ pts2 = scan[sample]
+ else:
+ pts1 = mesh
+ pts2 = scan
+
+ with torch.no_grad():
+ # taking 10 nodes for nearest neighbors, this lead to the edge numbers to be ~ 10 x sample number,
+ # i.e. sample#=190k, neighbor#=10, edge_index.shape=[2, ~1900k]
+ # i.e.torch.Size([2, 2082861]) / torch.Size([2, 1911460])
+ # knn compueted edge index: tensor([[ 45007, 130923, 79760, ..., 147219, 146399, 132629],
+ # [ 0, 0, 0, ..., 189999, 189999, 189999]])
+ edge_index = torch_geometric.nn.knn_graph(torch.FloatTensor(pts1), 10)
+
+ out = torch_geometric.data.Data(
+ x=pts1,
+ y=pts2,
+ edge_index=edge_index,
+ m=m,
+ part_id=part_id,
+ )
+ return out
diff --git a/examples/additive_manufacturing/compensation/inference.py b/examples/additive_manufacturing/compensation/inference.py
new file mode 100644
index 0000000000..8f0019384f
--- /dev/null
+++ b/examples/additive_manufacturing/compensation/inference.py
@@ -0,0 +1,225 @@
+# ignore_header_test
+# ruff: noqa: E402
+
+# © Copyright 2023 HP Development Company, L.P.
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glob
+import os
+
+import numpy as np
+import torch
+import torch_geometric
+import trimesh
+from hydra import compose, initialize
+
+# from omegaconf import DictConfig, OmegaConf
+from utils import log_string, tic, toc
+
+from physicsnemo.models.dgcnn.dgcnn_compensation import DGCNN
+
+# from physicsnemo.models.dgcnn.dgcnn_compensation import DGCNN_ocardo
+
+
+def main():
+ """
+ With the trained ckpt for both the prediction engine, and the compensation engine, config in cfg.inference_options
+ 1. Load the parts to compensate from path:
+ cfg.inference_options.data_path
+ data path should contain stl / obj files to compensate for
+ 2. From the trained compensation model, compensate the input CAD
+ 3. From the trained deviation prediction model, predict the potential deviation of the compensated CAD, to compare with the original design
+ second:
+ """
+ # Read the configs
+ with initialize(config_path="conf", job_name="test_app"):
+ cfg = compose(config_name="config", overrides=["+db=mysql", "+db.user=me"])
+
+ LOG_FOUT = open(os.path.join(cfg.inference_options.save_path, "log_inf.txt"), "a")
+
+ torch.backends.cudnn.deterministic = True
+ torch.manual_seed(cfg.inference_options.seed)
+ torch.cuda.manual_seed_all(cfg.inference_options.seed)
+ np.random.seed(cfg.inference_options.seed)
+
+ # model setting
+ device = torch.device("cuda" if cfg.general.cuda else "cpu")
+
+ # Initialize compensator
+ compensator = DGCNN().to(device)
+ compensator.load_state_dict(
+ torch.load(cfg.inference_options.generator_path, map_location="cpu"),
+ strict=False,
+ )
+ for g in compensator.parameters():
+ g.requires_grad = False # to avoid computation
+ log_string(LOG_FOUT, "Trained Shape Compensation model loaded")
+
+ # Initialize predictor
+ discriminator = DGCNN().to(device)
+ discriminator.load_state_dict(
+ torch.load(cfg.inference_options.discriminator_path, map_location="cpu"),
+ strict=False,
+ )
+ for p in discriminator.parameters():
+ p.requires_grad = False # to avoid computation
+ log_string(LOG_FOUT, "Trained Shape Prediction model loaded")
+
+ save_path = cfg.inference_options.save_path
+ os.makedirs(save_path, exist_ok=True)
+
+ # Load parts data, in STL or OBJ format
+ # todo: change the path to search all subfolders
+ data_path = cfg.inference_options.data_path
+ parts_ds = glob.glob(data_path + "/*.stl") + glob.glob(
+ data_path + "/*.obj"
+ ) # main processing
+ log_string(LOG_FOUT, f"\n\nStart processing data .. {parts_ds}")
+
+ for part_ in parts_ds:
+ tic()
+ log_string(LOG_FOUT, f"Processing Part {part_}")
+
+ # load data as mesh
+ cad = trimesh.load_mesh(part_)
+ # convert mesh to points by taking vertices
+ pts1 = torch.FloatTensor(np.asarray(cad.vertices))
+
+ # randomize point indices
+ rand = np.arange(len(pts1))
+ np.random.shuffle(rand)
+ pts1_full = pts1[rand]
+ log_string(LOG_FOUT, f"Part full shape: {pts1_full.shape}")
+
+ subsample = cfg.inference_options.num_points
+ n = len(pts1_full) // subsample
+ pred_res = []
+ comp_res = []
+ out_res = []
+ for i in range(n + 1):
+ log_string(LOG_FOUT, f"batching {i}")
+ if i == n and len(pts1_full) % subsample == 0:
+ log_string(LOG_FOUT, "no need to process")
+ continue
+ elif i == n:
+ pts1 = pts1_full[-subsample:]
+ else:
+ pts1 = pts1_full[subsample * i : subsample * (i + 1)]
+
+ edge_index = torch_geometric.nn.knn_graph(pts1, 20)
+ if cfg.general.cuda:
+ log_string(LOG_FOUT, "Use Cuda")
+ pts1 = pts1.squeeze(0).to(device).squeeze(0)
+ edge_index = edge_index.squeeze().to(device)
+
+ # Compensate the original CAD, then predict the deviation from the compensated CAD
+ com = compensator(torch_geometric.data.Data(x=pts1, edge_index=edge_index))
+ out = discriminator(torch_geometric.data.Data(x=com, edge_index=edge_index))
+
+ # Directly predict the deviation from the original CAD
+ pre = discriminator(
+ torch_geometric.data.Data(x=pts1, edge_index=edge_index)
+ )
+
+ if i == n:
+ valid = len(pts1_full) - n * subsample
+ com = com[-valid:]
+ out = out[-valid:]
+ pre = pre[-valid:]
+
+ comp_res.append(com.detach().cpu().numpy())
+ pred_res.append(pre.detach().cpu().numpy())
+ out_res.append(out.detach().cpu().numpy())
+
+ # Concatenate batches data to the final full part
+ cad_comp_pts = np.concatenate(comp_res, 0)
+ cad_pred_pts = np.concatenate(pred_res, 0)
+ cad_outp_pts = np.concatenate(out_res, 0)
+
+ # note that all above results are shuffled one! back to the original order
+ pts1_res = np.zeros((len(pts1_full), 3))
+ cad_comp_res = np.zeros((len(cad_comp_pts), 3))
+ cad_pred_res = np.zeros((len(cad_pred_pts), 3))
+ cad_outp_res = np.zeros((len(cad_outp_pts), 3))
+
+ for i in range(len(cad_comp_pts)):
+ pts1_res[rand[i]] = pts1_full[i]
+ cad_comp_res[rand[i]] = cad_comp_pts[i]
+ cad_pred_res[rand[i]] = cad_pred_pts[i]
+ cad_outp_res[rand[i]] = cad_outp_pts[i]
+
+ # Get input STL model name
+ part_name = os.path.basename(part_)
+ subfolder = save_path + part_name[:-4] # + "_compensated/"
+
+ if not os.path.exists(subfolder):
+ os.mkdir(subfolder)
+
+ np.savetxt(
+ os.path.join(subfolder, part_name[:-4] + "_cad.csv"),
+ pts1_full.cpu().numpy(),
+ fmt="%.8f",
+ delimiter=",",
+ )
+ log_string(LOG_FOUT, "Wrote to CAD mesh")
+
+ np.savetxt(
+ os.path.join(subfolder, part_name[:-4] + "_comp.csv"),
+ cad_comp_res,
+ fmt="%.8f",
+ delimiter=",",
+ )
+ log_string(
+ LOG_FOUT, f"Wrote to Compensated pointcloud: {part_name[:-4]}_comp.csv"
+ )
+
+ cad.vertices = cad_comp_res
+ cad.export(os.path.join(subfolder, part_name[:-4] + "_comp.stl"))
+ log_string(LOG_FOUT, f"Wrote to Compensated mesh: {part_name[:-4]}_comp.stl")
+
+ if cfg.inference_options.save_extra:
+ log_string(LOG_FOUT, "Saving extra files...")
+ np.savetxt(
+ os.path.join(subfolder, part_name[:-4] + "_pred.csv"),
+ cad_pred_res,
+ fmt="%.8f",
+ delimiter=",",
+ )
+ np.savetxt(
+ os.path.join(subfolder, part_name[:-4] + "_outp.csv"),
+ cad_outp_res,
+ fmt="%.8f",
+ delimiter=",",
+ )
+ log_string(
+ LOG_FOUT,
+ "Wrote to deviation from the original CAD, and deviation from the compensated CAD",
+ )
+
+ toc()
+
+
+if __name__ == "__main__":
+
+ with initialize(config_path="conf", job_name="test_app"):
+ cfg = compose(config_name="config", overrides=["+db=mysql", "+db.user=me"])
+
+ distributed_option = cfg.general.use_distributed
+ device = torch.device("cuda" if cfg.general.cuda else "cpu")
+
+ os.makedirs(cfg.inference_options.save_path, exist_ok=True)
+ main()
diff --git a/examples/additive_manufacturing/compensation/losses.py b/examples/additive_manufacturing/compensation/losses.py
new file mode 100644
index 0000000000..253edfd4a4
--- /dev/null
+++ b/examples/additive_manufacturing/compensation/losses.py
@@ -0,0 +1,33 @@
+# ignore_header_test
+# ruff: noqa: E402
+
+# © Copyright 2023 HP Development Company, L.P.
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+
+def l2_dist(pts1, pts2, reduction="mean"):
+ """
+ L2-loss compute, mean of all points' difference
+ """
+ l2_per_batch = torch.mean(torch.sum(torch.pow(pts1 - pts2, 2), -1), -1)
+ if reduction == "mean":
+ return torch.mean(l2_per_batch)
+ else:
+ return l2_per_batch
diff --git a/examples/additive_manufacturing/compensation/train_dis.py b/examples/additive_manufacturing/compensation/train_dis.py
new file mode 100644
index 0000000000..67bf4358c3
--- /dev/null
+++ b/examples/additive_manufacturing/compensation/train_dis.py
@@ -0,0 +1,371 @@
+# ignore_header_test
+# ruff: noqa: E402
+
+# © Copyright 2023 HP Development Company, L.P.
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+# test diff number of devices
+# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
+import numpy as np
+import torch
+import torch.distributed as distributed
+import torch.multiprocessing as mp
+import torch_geometric
+from dataloader import Bar, Ocardo
+from hydra import compose, initialize
+from losses import l2_dist
+from omegaconf import OmegaConf
+from pytorch3d.loss import chamfer_distance
+from torch.nn.parallel import DistributedDataParallel
+from torch.utils.data.distributed import DistributedSampler
+from utils import log_string, tic, toc
+
+from physicsnemo.models.dgcnn.dgcnn_compensation import DGCNN, DGCNN_ocardo
+
+
+# @hydra.main(version_base=None, config_path="conf", config_name="conf")
+def main(rank):
+ """
+ :param rank: id of visible cuda devices, from 0, 1, ... for distributed training,
+ for each parallel run, i.e. rank: 0 ; rank: 1 ; ... etc.
+ :return:
+ """
+
+ # Read the configs
+ global dataset
+ with initialize(config_path="conf", job_name="test_app"):
+ cfg = compose(config_name="config", overrides=["+db=mysql", "+db.user=me"])
+ # define gpu id, dtype:int
+ device = rank
+ world_size = torch.cuda.device_count()
+ print("rank: ", device)
+
+ # Initialize and open the log file
+ LOG_FOUT = open(
+ os.path.join(cfg.train_dis_options.log_dir, "log_train_dis.txt"), "a"
+ )
+ # log_string(LOG_FOUT, OmegaConf.to_yaml(cfg))
+
+ # load data
+ log_string(LOG_FOUT, "Loading data: note it takes time")
+ # todo: optimize the dataloader to potentially one
+ if cfg.data_options.dataset_name == "Ocardo":
+ dataset = Ocardo(
+ data_path=cfg.data_options.data_path,
+ num_points=cfg.train_dis_options.num_points,
+ partition="train",
+ LOG_FOUT=LOG_FOUT,
+ )
+ elif cfg.data_options.dataset_name == "Bar":
+ dataset = Bar(
+ data_path=cfg.data_options.data_path,
+ num_points=cfg.train_dis_options.num_points,
+ partition="train",
+ LOG_FOUT=LOG_FOUT,
+ )
+ log_string(
+ LOG_FOUT, f"Complete data loading, size of the parts read: {len(dataset)}"
+ )
+ # todo: dataset not yet normailzed
+
+ # set up distributed training
+ if cfg.general.use_distributed:
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "12345"
+ distributed.init_process_group("nccl", rank=device, world_size=world_size)
+
+ torch.backends.cudnn.deterministic = True
+ torch.manual_seed(cfg.general.seed)
+ torch.cuda.manual_seed_all(cfg.general.seed)
+ np.random.seed(cfg.general.seed)
+
+ tic()
+ # Initialize and open the log file
+ LOG_FOUT = open(
+ os.path.join(cfg.train_dis_options.log_dir, "log_train_dis.txt"), "a"
+ )
+
+ # dataset for train
+ train_dataset = dataset
+
+ # model initialization
+ model = DGCNN_ocardo() if cfg.data_options.dataset_name == "Ocardo" else DGCNN()
+ log_string(LOG_FOUT, "Initialize model .... \n\n")
+
+ if cfg.general.use_distributed:
+ print("use distributed multi-gpu train")
+ model = model.to(device)
+ model = DistributedDataParallel(model, device_ids=[device])
+
+ if cfg.general.sync_batch:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+ # initialize the Sampler that restricts data loading to a subset of the dataset.
+ train_sampler = DistributedSampler(
+ train_dataset,
+ num_replicas=world_size,
+ rank=device,
+ shuffle=False,
+ drop_last=False,
+ )
+ train_loader = torch_geometric.loader.DataLoader(
+ train_dataset, sampler=train_sampler
+ )
+
+ elif cfg.general.use_multigpu:
+ # todo: define the difference between use_distributed and use_multigpu
+ print("use multi-gpus")
+ # dataloader must be a PyTorch_Geometric list loader
+ train_loader = torch_geometric.loader.DataListLoader(
+ train_dataset,
+ batch_size=cfg.train_dis_options.num_batch,
+ shuffle=True,
+ drop_last=True,
+ )
+ model = torch_geometric.nn.DataParallel(model).to(device)
+ else:
+ # Single Gpu training, or CPU training
+ train_loader = torch_geometric.loader.DataLoader(
+ train_dataset,
+ batch_size=cfg.train_dis_options.num_batch,
+ shuffle=True,
+ drop_last=True,
+ )
+ # todo: test single gpu working
+ # model = model.cuda()
+ model = model.to(device)
+
+ # In case of we have pre-trained setup
+ if cfg.train_dis_options.pretrain:
+ log_string(LOG_FOUT, "Update pre-trained model")
+ if cfg.general.use_distributed:
+ map_location = {"cuda:%d" % 0: "cuda:%d" % device}
+ model.load_state_dict(
+ torch.load(cfg.train_dis_options.model_path, map_location=map_location)
+ )
+ else:
+ model.load_state_dict(torch.load(cfg.train_dis_options.model_path))
+ # optimiser setting
+ optimizer = torch.optim.Adam(
+ model.parameters(), lr=cfg.train_dis_options.learning_rate
+ )
+ # todo: check what the scheduler does, whether can move milestone to config
+ # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1001,2001,3001],gamma=0.5)
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(
+ optimizer, milestones=[50, 100, 200, 500, 1000, 1500, 2001, 3001], gamma=0.5
+ )
+
+ log_string(LOG_FOUT, "Start training ....... ")
+ for ep in range(cfg.train_dis_options.num_epoch):
+ # todo: log the epoch number
+ model.train()
+ # list of log parameters
+ total_train_loss = 0
+ total_chamfer_loss = 0
+ total_oloss = 0
+ total_ocham = 0
+ if device == 0 or not cfg.general.use_distributed:
+ tic()
+
+ # train
+ d_cnt = 0
+ for data in train_loader:
+ # the train_loader length is the length of the data samples (parts), i.e. 11 parts in training
+ d_cnt += 1
+ optimizer.zero_grad()
+ if cfg.general.cuda and not cfg.general.use_multigpu:
+ data = data.to(device)
+ pts1 = data.x
+ elif cfg.general.use_distributed:
+ data = data.to(device)
+ pts1 = data.x
+ elif cfg.general.cuda and cfg.general.use_multigpu:
+ pts1 = torch.cat([d.x for d in data]).reshape(
+ cfg.train_dis_options.num_batch, -1, 3
+ )
+ pts1 = pts1.to(device)
+ optimizer.zero_grad()
+
+ # model ouput
+ out = model(data)
+
+ if cfg.general.use_distributed:
+ pts2 = data.y
+ elif cfg.general.use_multigpu:
+ pts2 = (
+ torch.cat([d.y for d in data])
+ .to(out.device)
+ .reshape(cfg.train_dis_options.num_batch, -1, 3)
+ )
+ else:
+ pts2 = data.y
+ # shape consistency loss
+ # get the predicted distance v.s. the original distance
+ chamfer, _ = chamfer_distance(
+ out.reshape(cfg.train_dis_options.num_batch, -1, 3),
+ pts2.reshape(cfg.train_dis_options.num_batch, -1, 3),
+ )
+ o_chamfer, _ = chamfer_distance(
+ pts1.reshape(cfg.train_dis_options.num_batch, -1, 3),
+ pts2.reshape(cfg.train_dis_options.num_batch, -1, 3),
+ )
+
+ # L1 or L2 distance - dimensional free errors for Naive torch implementation
+ o_loss = l2_dist(
+ pts1.reshape(cfg.train_dis_options.num_batch, -1, 3),
+ pts2.reshape(cfg.train_dis_options.num_batch, -1, 3),
+ )
+ l2_loss = l2_dist(
+ out.reshape(cfg.train_dis_options.num_batch, -1, 3),
+ pts2.reshape(cfg.train_dis_options.num_batch, -1, 3),
+ )
+
+ # loss to backpropagate (weighted to chamfer)
+ loss = l2_loss + chamfer
+ loss.backward()
+
+ # tracking loss
+ total_train_loss += loss.item() - chamfer.item()
+ total_chamfer_loss += chamfer.item()
+ total_oloss += o_loss
+ total_ocham += o_chamfer
+ optimizer.step()
+
+ # syncronise after
+ if cfg.general.use_distributed:
+ distributed.barrier()
+
+ total_avg_train_loss = total_train_loss / (
+ cfg.train_dis_options.num_batch * len(train_loader)
+ )
+ total_avg_chamfer_loss = total_chamfer_loss / (
+ cfg.train_dis_options.num_batch * len(train_loader)
+ )
+ total_avg_oloss = total_oloss / (
+ cfg.train_dis_options.num_batch * len(train_loader)
+ )
+ total_avg_ocham = total_ocham / (
+ cfg.train_dis_options.num_batch * len(train_loader)
+ )
+ if device == 0 or not cfg.general.use_distributed:
+ log_string(
+ LOG_FOUT,
+ "[Epoch %03d] training loss: %.6f, chamfer loss: %.6f, reference1: %.6f, reference2: %.6f"
+ % (
+ ep,
+ total_avg_train_loss,
+ total_avg_chamfer_loss,
+ total_avg_oloss,
+ total_avg_ocham,
+ ),
+ )
+ toc()
+ tic()
+
+ # data save
+ if device == 0 and ep % cfg.train_dis_options.saving_ep_step == 0:
+ print("save weights at epoch %03d" % ep)
+ os.makedirs(cfg.train_gen_options.save_path, exist_ok=True)
+
+ # save
+ if cfg.general.use_distributed:
+ torch.save(
+ model.state_dict(),
+ cfg.train_dis_options.save_path + "pred_model_%04d.pth" % ep,
+ )
+ elif not cfg.general.use_distributed and cfg.general.use_multigpu:
+ torch.save(
+ model.module.state_dict(),
+ cfg.train_dis_options.save_path + "pred_model_%04d.pth" % ep,
+ )
+ else:
+ torch.save(
+ model.state_dict(),
+ cfg.train_dis_options.save_path + "pred_model_%04d.pth" % ep,
+ )
+
+ if not os.path.exists(
+ os.path.join(cfg.train_dis_options.save_path, "results")
+ ):
+ os.mkdir(os.path.join(cfg.train_dis_options.save_path, "results"))
+ np.savetxt(
+ os.path.join(
+ cfg.train_dis_options.save_path, "results/dis_cad__%02d.csv" % ep
+ ),
+ pts1.cpu().reshape(cfg.train_dis_options.num_batch, -1, 3).numpy()[0],
+ fmt="%.8f",
+ delimiter=",",
+ )
+ np.savetxt(
+ os.path.join(
+ cfg.train_dis_options.save_path, "results/dis_scan_%02d.csv" % ep
+ ),
+ pts2.cpu().reshape(cfg.train_dis_options.num_batch, -1, 3).numpy()[0],
+ fmt="%.8f",
+ delimiter=",",
+ )
+ np.savetxt(
+ os.path.join(
+ cfg.train_dis_options.save_path, "results/dis_out__%02d.csv" % ep
+ ),
+ out.detach()
+ .cpu()
+ .reshape(cfg.train_dis_options.num_batch, -1, 3)
+ .numpy()[0],
+ fmt="%.8f",
+ delimiter=",",
+ )
+ scheduler.step()
+ # end training
+ LOG_FOUT.close()
+ if cfg.general.use_distributed:
+ distributed.destroy_process_group()
+
+
+if __name__ == "__main__":
+ with initialize(config_path="conf", job_name="test_app"):
+ cfg = compose(config_name="config", overrides=["+db=mysql", "+db.user=me"])
+
+ distributed_option = cfg.general.use_distributed
+ device = torch.device("cuda" if cfg.general.cuda else "cpu")
+ os.makedirs(cfg.train_dis_options.log_dir, exist_ok=True)
+
+ # todo: add test case, if the log already exist, exit, and remind to rename
+ LOG_FOUT = open(
+ os.path.join(cfg.train_dis_options.log_dir, "log_train_dis.txt"), "a"
+ )
+ log_string(LOG_FOUT, OmegaConf.to_yaml(cfg))
+
+ # run model based on single // data parallel // distributed data parallel
+ if distributed_option:
+ if torch.cuda.is_available():
+ print("A GPU is available!")
+ else:
+ print("No GPU available.")
+ # Get the number of available GPUs
+ world_size = torch.cuda.device_count()
+ log_string(
+ LOG_FOUT,
+ f"distributed_option: data parallel / Cuda device cnt: {world_size}",
+ )
+ mp.spawn(main, nprocs=world_size, join=True)
+ else:
+ log_string(LOG_FOUT, "distributed_option: false\n")
+ main(device)
diff --git a/examples/additive_manufacturing/compensation/train_gen.py b/examples/additive_manufacturing/compensation/train_gen.py
new file mode 100644
index 0000000000..c6fe0e6c9b
--- /dev/null
+++ b/examples/additive_manufacturing/compensation/train_gen.py
@@ -0,0 +1,411 @@
+# ignore_header_test
+# ruff: noqa: E402
+
+# © Copyright 2023 HP Development Company, L.P.
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
+import numpy as np
+import torch
+import torch.distributed as distributed
+import torch.multiprocessing as mp
+import torch_geometric
+from dataloader import Bar, Ocardo
+from hydra import compose, initialize
+from losses import l2_dist
+from omegaconf import OmegaConf
+from pytorch3d.loss import chamfer_distance
+from torch.nn.parallel import DistributedDataParallel
+from torch.utils.data.distributed import DistributedSampler
+from utils import log_string, tic, toc
+
+from physicsnemo.models.dgcnn.dgcnn_compensation import DGCNN, DGCNN_ocardo
+
+
+def main(rank):
+ # def main(rank, world_size, dataset,args):
+ """
+
+ :param rank: number of visible cuda devices, from 0, 1, .. for distributed training
+ :return:
+ """
+
+ # Read the configs
+ with initialize(config_path="conf", job_name="test_app"):
+ cfg = compose(config_name="config", overrides=["+db=mysql", "+db.user=me"])
+ # define gpu id, dtype:int
+ device = rank
+ world_size = torch.cuda.device_count()
+ print("rank: ", device)
+
+ # Initialize and open the log file
+ LOG_FOUT = open(
+ os.path.join(cfg.train_gen_options.log_dir, "log_train_gen.txt"), "a"
+ )
+ log_string(LOG_FOUT, OmegaConf.to_yaml(cfg))
+
+ # load data
+ log_string(LOG_FOUT, "load data: note it takes time")
+ if cfg.data_options.dataset_name == "Ocardo":
+ dataset = Ocardo(
+ data_path=cfg.data_options.data_path,
+ num_points=cfg.train_dis_options.num_points,
+ partition="train",
+ )
+ elif cfg.data_options.dataset_name == "Bar":
+ dataset = Bar(
+ data_path=cfg.data_options.data_path,
+ num_points=cfg.train_dis_options.num_points,
+ partition="train",
+ )
+ print("size of the data %d" % len(dataset))
+
+ # set up distributed training
+ if cfg.general.use_distributed:
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "12355"
+ distributed.init_process_group("nccl", rank=device, world_size=world_size)
+
+ torch.backends.cudnn.deterministic = True
+ torch.manual_seed(cfg.general.seed)
+ torch.cuda.manual_seed_all(cfg.general.seed)
+ np.random.seed(cfg.general.seed)
+
+ # generator
+ train_dataset = dataset
+
+ if cfg.data_options.dataset_name == "Ocardo":
+ generator = DGCNN_ocardo()
+ discriminator = DGCNN_ocardo()
+ else:
+ generator = DGCNN()
+ discriminator = DGCNN()
+ log_string(LOG_FOUT, "Initialize model .... \n\n")
+ # names are analogous to generative adversarial network
+ # note that IT IS NOT A GAN!!!!! it is just an analogy!
+
+ if cfg.general.use_distributed:
+ print("use distributed multi-gpu train")
+ # dataloader
+ train_sampler = DistributedSampler(
+ train_dataset,
+ num_replicas=world_size,
+ rank=rank,
+ shuffle=True,
+ drop_last=True,
+ )
+ train_loader = torch_geometric.loader.DataLoader(
+ train_dataset, sampler=train_sampler
+ )
+
+ # generator
+ generator = generator.to(rank)
+ generator = DistributedDataParallel(generator, device_ids=[rank])
+ map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
+ generator.load_state_dict(
+ torch.load(cfg.train_gen_options.gen_model_path, map_location=map_location)
+ )
+
+ # discriminator
+ discriminator = discriminator.to(rank)
+ discriminator = DistributedDataParallel(discriminator, device_ids=[rank])
+ map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
+ # Load model
+ discriminator.load_state_dict(
+ torch.load(cfg.train_gen_options.pred_model_path, map_location=map_location)
+ )
+ elif cfg.general.use_multigpu:
+ print("use multi-gpus")
+ # todo: check elif , else conditions are same
+ # dataloader init
+ train_loader = torch_geometric.loader.DataListLoader(
+ train_dataset,
+ batch_size=cfg.train_gen_options.num_batch,
+ shuffle=True,
+ drop_last=True,
+ )
+
+ # generator
+ generator.load_state_dict(
+ torch.load(cfg.train_gen_options.gen_model_path, map_location="cpu")
+ )
+ generator = torch_geometric.nn.DataParallel(generator).cuda()
+
+ # discriminator
+ discriminator.load_state_dict(
+ torch.load(cfg.train_gen_options.pred_model_path, map_location="cpu")
+ )
+ discriminator = torch_geometric.nn.DataParallel(discriminator).cuda()
+ else:
+ # dataloader
+ train_loader = torch_geometric.data.DataLoader(
+ train_dataset,
+ batch_size=cfg.train_gen_options.num_batch,
+ shuffle=True,
+ drop_last=True,
+ )
+
+ # generator
+ generator.load_state_dict(
+ torch.load(cfg.train_gen_options.gen_model_path, map_location="cpu")
+ )
+ # todo: test single gpu working
+ # generator = generator.cuda()
+ generator = generator.to(device)
+
+ # discriminator
+ discriminator.load_state_dict(
+ torch.load(cfg.train_gen_options.pred_model_path, map_location="cpu")
+ )
+ # discriminator = discriminator.cuda()
+ discriminator = discriminator.to(device)
+
+ # freeze weights for discriminator
+ for p in discriminator.parameters():
+ p.requires_grad = False # to avoid computation
+
+ optimizer = torch.optim.Adam(
+ generator.parameters(), lr=cfg.train_gen_options.learning_rate
+ )
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(
+ optimizer, milestones=[400, 800, 1200, 1600, 2500], gamma=0.5
+ )
+ steps = 250
+
+ log_string(LOG_FOUT, "Start training ....... ")
+
+ for ep in range(cfg.train_gen_options.num_epoch):
+ total_train_loss = 0
+ total_chamfer_loss = 0
+ total_oloss = 0
+ total_ocham = 0
+ if rank == 0 or not cfg.general.use_distributed:
+ tic()
+
+ # train
+ for data in train_loader:
+ if cfg.general.use_distributed:
+ data = data.to(device)
+ pts1 = data.x
+ pts2 = data.y.cpu()
+ edge_index = data.edge_index
+ elif cfg.general.cuda and not cfg.general.use_multigpu:
+ data = data.to(device)
+ pts1 = data.x
+ # todo: why need to load to cpu
+ pts2 = data.y.cpu()
+ edge_index = data.edge_index
+ elif cfg.general.use_multigpu:
+ pts1 = (
+ torch.cat([d.x for d in data])
+ .reshape(cfg.train_gen_options.num_batch, -1, 3)
+ .reshape(cfg.train_gen_options.num_batch, -1, 3)
+ )
+ pts2 = (
+ torch.cat([d.y for d in data])
+ .reshape(cfg.train_gen_options.num_batch, -1, 3)
+ .reshape(cfg.train_gen_options.num_batch, -1, 3)
+ )
+ pts1 = pts1.to(device)
+ pts2 = pts1.to(device)
+
+ optimizer.zero_grad()
+
+ # compensation
+ com = generator(data)
+ if cfg.general.use_distributed:
+ compensated_data = torch_geometric.data.Data(
+ x=com, edge_index=edge_index
+ )
+ elif cfg.general.use_multigpu:
+ # it has be a list of graph data
+ compensated_data = []
+ tmp = com.reshape(cfg.train_gen_options.num_batch, -1, 3)
+ for ii in range(cfg.train_gen_options.num_batch):
+ d = torch_geometric.data.Data(
+ x=tmp[ii], edge_index=data[ii].edge_index.cuda()
+ )
+ compensated_data.append(d)
+ else:
+ compensated_data = torch_geometric.data.Data(
+ x=com, edge_index=edge_index
+ )
+ # evaluate deformation
+ out = discriminator(compensated_data)
+
+ # reshape for metric computation
+ if cfg.general.cuda and not cfg.general.use_multigpu:
+ pts1 = pts1.reshape(cfg.train_gen_options.num_batch, -1, 3)
+ pts2 = pts2.reshape(cfg.train_gen_options.num_batch, -1, 3)
+
+ # metric (loss fun)
+ # Compute chamfer_distance of the input CAD - D(G(compensated))
+ chamfer, _ = chamfer_distance(
+ out.reshape(cfg.train_gen_options.num_batch, -1, 3),
+ pts1.reshape(cfg.train_gen_options.num_batch, -1, 3),
+ )
+ # Compute chamfer_distance of the input CAD - G(compensated)
+ o_chamfer, _ = chamfer_distance(
+ com.data.reshape(cfg.train_gen_options.num_batch, -1, 3),
+ pts1.reshape(cfg.train_gen_options.num_batch, -1, 3),
+ )
+ # Compute l2_dist
+ l2_loss = l2_dist(
+ out.reshape(cfg.train_gen_options.num_batch, -1, 3),
+ pts1.reshape(cfg.train_gen_options.num_batch, -1, 3),
+ )
+ o_loss = (
+ l2_dist(
+ com.data.reshape(cfg.train_gen_options.num_batch, -1, 3),
+ pts1.reshape(cfg.train_gen_options.num_batch, -1, 3),
+ )
+ .cpu()
+ .numpy()
+ )
+
+ # Min the loss as input CAD - D(G(compensated))
+ loss = l2_loss + chamfer # *2
+ loss.backward()
+
+ total_train_loss += loss.item() - chamfer.item() # *2
+ total_chamfer_loss += chamfer.item()
+ total_oloss += o_loss
+ total_ocham += o_chamfer
+
+ optimizer.step()
+
+ # syncronise after
+ if cfg.general.use_distributed:
+ distributed.barrier()
+ total_avg_train_loss = total_train_loss / (
+ cfg.train_gen_options.num_batch * len(train_loader)
+ )
+ total_avg_chamfer_loss = total_chamfer_loss / (
+ cfg.train_gen_options.num_batch * len(train_loader)
+ )
+ total_avg_oloss = total_oloss / (
+ cfg.train_gen_options.num_batch * len(train_loader)
+ )
+ total_avg_ocham = total_ocham / (
+ cfg.train_gen_options.num_batch * len(train_loader)
+ )
+ if rank == 0 or not cfg.general.use_distributed:
+ log_string(
+ LOG_FOUT,
+ "[Epoch %03d] training loss: %.6f, chamfer loss: %.6f, reference1: %.6f, reference2: %.6f"
+ % (
+ ep,
+ total_avg_train_loss,
+ total_avg_chamfer_loss,
+ total_avg_oloss,
+ total_avg_ocham,
+ ),
+ )
+ toc()
+ tic()
+ # data save
+ if rank == 0 and ep % steps == 0:
+ log_string(LOG_FOUT, "save weights at epoch %03d" % ep)
+ os.makedirs(cfg.train_gen_options.save_path, exist_ok=True)
+
+ # save
+ if cfg.general.use_distributed:
+ torch.save(
+ generator.state_dict(),
+ cfg.train_gen_options.save_path + "gen_model_%04d.pth" % ep,
+ )
+ elif cfg.general.use_multigpu:
+ torch.save(
+ generator.module.state_dict(),
+ cfg.train_gen_options.save_path + "gen_model_%04d.pth" % ep,
+ )
+ else:
+ torch.save(
+ generator.state_dict(),
+ cfg.train_gen_options.save_path + "gen_model_%04d.pth" % ep,
+ )
+
+ os.makedirs(
+ os.path.join(cfg.train_gen_options.save_path, "results2"), exist_ok=True
+ )
+
+ np.savetxt(
+ os.path.join(
+ cfg.train_gen_options.save_path, "results2/cad__%02d.csv" % ep
+ ),
+ pts1.cpu().reshape(cfg.train_gen_options.num_batch, -1, 3).numpy()[0],
+ fmt="%.8f",
+ delimiter=",",
+ )
+ np.savetxt(
+ os.path.join(
+ cfg.train_gen_options.save_path, "results2/scan_%02d.csv" % ep
+ ),
+ pts2.cpu().reshape(cfg.train_gen_options.num_batch, -1, 3).numpy()[0],
+ fmt="%.8f",
+ delimiter=",",
+ )
+ np.savetxt(
+ os.path.join(
+ cfg.train_gen_options.save_path, "results2/comp_%02d.csv" % ep
+ ),
+ com.detach()
+ .cpu()
+ .reshape(cfg.train_gen_options.num_batch, -1, 3)
+ .numpy()[0],
+ fmt="%.8f",
+ delimiter=",",
+ )
+ np.savetxt(
+ os.path.join(
+ cfg.train_gen_options.save_path, "results2/out_%02d.csv" % ep
+ ),
+ out.detach()
+ .cpu()
+ .reshape(cfg.train_gen_options.num_batch, -1, 3)
+ .numpy()[0],
+ fmt="%.8f",
+ delimiter=",",
+ )
+ scheduler.step()
+
+ # end training
+ LOG_FOUT.close()
+ if cfg.general.use_distributed:
+ distributed.destroy_process_group()
+
+
+if __name__ == "__main__":
+ with initialize(config_path="conf", job_name="test_app"):
+ cfg = compose(config_name="config", overrides=["+db=mysql", "+db.user=me"])
+
+ distributed_option = cfg.general.use_distributed
+ device = torch.device("cuda" if cfg.general.cuda else "cpu")
+
+ # run model based on single // data parallel // distributed data parallel
+ if distributed_option:
+ print("distributed data parallel ")
+ world_size = torch.cuda.device_count()
+ print("Cuda device cnt: ", world_size)
+ # mp.spawn(main, args=(world_size, dataset, param), nprocs=world_size, join=True)
+ mp.spawn(main, nprocs=world_size, join=True)
+ else:
+ # main(device,world_size,dataset,param)
+ main(device, cfg)
diff --git a/examples/additive_manufacturing/compensation/utils.py b/examples/additive_manufacturing/compensation/utils.py
new file mode 100644
index 0000000000..6fcaad5585
--- /dev/null
+++ b/examples/additive_manufacturing/compensation/utils.py
@@ -0,0 +1,56 @@
+# ignore_header_test
+# ruff: noqa: E402
+
+# © Copyright 2023 HP Development Company, L.P.
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import time
+
+
+# time measure
+def TicTocGenerator():
+ # Generator that returns time differences
+ ti = 0 # initial time
+ tf = time.time() # final time
+ while True:
+ ti = tf
+ tf = time.time()
+ yield tf - ti # returns the time difference
+
+
+TicToc = TicTocGenerator() # create an instance of the TicTocGen generator
+
+# This will be the main function through which Convolutional filters (Translation invariance+Self-similarity)
+
+
+def toc(tempBool=True):
+ # Prints the time difference yielded by generator instance TicToc
+ tempTimeInterval = next(TicToc)
+ if tempBool:
+ print("Elapsed time: %f seconds.\n" % tempTimeInterval)
+
+
+def tic():
+ # Records a time in TicToc, marks the beginning of a time interval
+ toc(False)
+
+
+def log_string(LOG_FOUT, out_str):
+ LOG_FOUT.write(out_str + "\n")
+ LOG_FOUT.flush()
+ print(out_str)
diff --git a/examples/additive_manufacturing/compensation/visualization.py b/examples/additive_manufacturing/compensation/visualization.py
new file mode 100644
index 0000000000..a65b39be4c
--- /dev/null
+++ b/examples/additive_manufacturing/compensation/visualization.py
@@ -0,0 +1,142 @@
+# ignore_header_test
+# ruff: noqa: E402
+
+# © Copyright 2023 HP Development Company, L.P.
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import matplotlib.pyplot as plt
+import numpy as np
+import open3d as o3d
+import torch
+from mpl_toolkits.mplot3d.art3d import Poly3DCollection
+from pytorch3d.loss import chamfer_distance
+
+
+def stl_to_vertices_and_faces(file_path):
+ """
+ Function to load STL and convert to vertices and triangles
+ """
+ mesh = o3d.io.read_triangle_mesh(file_path)
+ if mesh.is_empty():
+ raise ValueError(f"Mesh at {file_path} is empty or invalid.")
+ verts = np.asarray(mesh.vertices)
+ faces = np.asarray(mesh.triangles)
+ return verts, faces
+
+
+def plot_mesh(mesh, ax, face_color, edge_color, label):
+ """
+ Function to plot meshes
+ """
+ vertices = np.asarray(mesh.vertices)
+ triangles = np.asarray(mesh.triangles)
+ # Add faces with transparency
+ ax.add_collection3d(
+ Poly3DCollection(
+ vertices[triangles], facecolor=face_color, edgecolor=edge_color, alpha=0.3
+ )
+ )
+ ax.scatter(
+ vertices[:, 0], vertices[:, 1], vertices[:, 2], color=edge_color, s=0.1
+ ) # Points
+ ax.set_title(label, fontsize=12)
+ ax.set_box_aspect([1, 1, 1]) # Equal scaling
+ ax.grid(True)
+
+
+# RMS Calculation
+def calculate_rms(source, target):
+ """Calculate Root Mean Square (RMS) error."""
+ diff = source - target
+ return torch.sqrt(torch.mean(diff**2))
+
+
+# Sample usage: Load STL files
+target_mesh_o3d = o3d.io.read_triangle_mesh("/content/cad_4.stl")
+uncompensated_mesh_o3d = o3d.io.read_triangle_mesh("/content/cad_4.stl")
+compensated_mesh_o3d = o3d.io.read_triangle_mesh("/content/cad_4.stl")
+
+# Load STL files for Chamfer Distance and RMS calculation
+target_verts, _ = stl_to_vertices_and_faces("/content/cad_4.stl")
+uncompensated_verts, _ = stl_to_vertices_and_faces("/content/cad_4.stl")
+compensated_verts, _ = stl_to_vertices_and_faces("/content/cad_4.stl")
+
+# Convert vertices to PyTorch tensors
+target_verts_tensor = torch.tensor(target_verts, dtype=torch.float32).unsqueeze(0)
+uncompensated_verts_tensor = torch.tensor(
+ uncompensated_verts, dtype=torch.float32
+).unsqueeze(0)
+compensated_verts_tensor = torch.tensor(
+ compensated_verts, dtype=torch.float32
+).unsqueeze(0)
+
+# Chamfer Distance
+cd_uncomp, _ = chamfer_distance(uncompensated_verts_tensor, target_verts_tensor)
+cd_comp, _ = chamfer_distance(compensated_verts_tensor, target_verts_tensor)
+
+# RMS Error
+min_len = min(len(uncompensated_verts), len(target_verts))
+uncomp_rms = calculate_rms(
+ uncompensated_verts_tensor[:, :min_len, :], target_verts_tensor[:, :min_len, :]
+)
+comp_rms = calculate_rms(
+ compensated_verts_tensor[:, :min_len, :], target_verts_tensor[:, :min_len, :]
+)
+
+# Fitness (normalized metric based on distances)
+fitness_uncomp = 1 - cd_uncomp.item()
+fitness_comp = 1 - cd_comp.item()
+
+# Print evaluation metrics
+print(f"Chamfer Distance (Uncompensated): {cd_uncomp.item()}")
+print(f"Chamfer Distance (Compensated): {cd_comp.item()}")
+print(f"RMS Error (Uncompensated): {uncomp_rms.item()}")
+print(f"RMS Error (Compensated): {comp_rms.item()}")
+print(f"Fitness (Uncompensated): {fitness_uncomp}")
+print(f"Fitness (Compensated): {fitness_comp}")
+
+# Visualization with enhanced plot
+fig = plt.figure(figsize=(18, 6))
+ax1 = fig.add_subplot(131, projection="3d")
+ax2 = fig.add_subplot(132, projection="3d")
+ax3 = fig.add_subplot(133, projection="3d")
+
+plot_mesh(
+ target_mesh_o3d,
+ ax1,
+ face_color="lightcoral",
+ edge_color="red",
+ label="Target\n(Desired Shape)",
+)
+plot_mesh(
+ uncompensated_mesh_o3d,
+ ax2,
+ face_color="lightgreen",
+ edge_color="green",
+ label=f"Uncompensated\nCD: {cd_uncomp.item():.4f}\nRMS: {uncomp_rms.item():.4f}",
+)
+plot_mesh(
+ compensated_mesh_o3d,
+ ax3,
+ face_color="lightblue",
+ edge_color="blue",
+ label=f"Compensated\nCD: {cd_comp.item():.4f}\nRMS: {comp_rms.item():.4f}",
+)
+
+plt.tight_layout()
+plt.show()
diff --git a/examples/additive_manufacturing/sintering_physics/README.md b/examples/additive_manufacturing/sintering_physics/README.md
index 44eb57854f..dcfb602379 100644
--- a/examples/additive_manufacturing/sintering_physics/README.md
+++ b/examples/additive_manufacturing/sintering_physics/README.md
@@ -125,6 +125,10 @@ Then run:
python train.py
```
+### Access to the pre-trained model checkpoint
+
+[Download the pre-trained checkpoint in the paper](https://drive.google.com/file/d/1vxgigx0jz81EhD97uFDZWIdnwnr9IkBp/view?usp=drive_link)
+
## Visualize test result
Change the params in conf/config.yaml:
@@ -158,6 +162,8 @@ python inference.py
## Data
+[To download the pre-processed deformation simulation data in the paper](https://drive.google.com/drive/folders/1TRbj9vg1095aKIcr1izVfXT2DJSjb8jA?usp=drive_link)
+
- Test data
- Same voxel resolution as train
@@ -203,9 +209,26 @@ and scaling for different process parameter configurations.
## Reference
+[Virtual Foundry Graphnet
+for Predicting Metal Sintering Deformation](https://sensors.myu-group.co.jp/sm_pdf/SM3704.pdf)
+
[Learning to Simulate Complex Physics with Graph Networks](https://arxiv.org/abs/2002.09405)
```text
+@article{Chen_2024,
+ title={Virtual Foundry Graphnet for Predicting Metal Sintering Deformation},
+ volume={36},
+ ISSN={2435-0869},
+ url={http://dx.doi.org/10.18494/SAM4883},
+ DOI={10.18494/sam4883},
+ number={7},
+ journal={Sensors and Materials},
+ publisher={MYU K.K.},
+ author={Chen, Rachel (Lei) and Gan, Chuang and Lee, Juheon and Yang,
+ Zijiang and Nabian, Mohammad Amin and Zeng, Jun},
+ year={2024},
+ month=jul, pages={2835} }
+
@inproceedings{sanchezgonzalez2020learning,
title={Learning to Simulate Complex Physics with Graph Networks},
author={Alvaro Sanchez-Gonzalez and
diff --git a/physicsnemo/models/dgcnn/__init__.py b/physicsnemo/models/dgcnn/__init__.py
new file mode 100644
index 0000000000..aa32b1307f
--- /dev/null
+++ b/physicsnemo/models/dgcnn/__init__.py
@@ -0,0 +1,17 @@
+# ignore_header_test
+# ruff: noqa: E402
+
+# © Copyright 2023 HP Development Company, L.P.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .dgcnn_compensation import DGCNN, DGCNN_ocardo
diff --git a/physicsnemo/models/dgcnn/dgcnn_compensation.py b/physicsnemo/models/dgcnn/dgcnn_compensation.py
new file mode 100644
index 0000000000..2b520c79b2
--- /dev/null
+++ b/physicsnemo/models/dgcnn/dgcnn_compensation.py
@@ -0,0 +1,212 @@
+# ignore_header_test
+# ruff: noqa: E402
+
+# © Copyright 2023 HP Development Company, L.P.
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+from torch.nn import Linear as Lin
+from torch.nn import Sequential as Seq
+from torch_geometric.nn import EdgeConv, knn_graph
+
+import physicsnemo # noqa: F401 for docs
+
+from ..meta import ModelMetaData
+from ..module import Module
+
+
+@dataclass
+class MetaData(ModelMetaData):
+ name: str = "GraphComPNet"
+ # Optimization
+ jit: bool = False
+ cuda_graphs: bool = True
+ amp_cpu: bool = False # Reflect padding not supported in bfloat16
+ amp_gpu: bool = False
+ # Inference
+ onnx_cpu: bool = False
+ onnx_gpu: bool = False
+ onnx_runtime: bool = False
+ # Physics informed
+ var_dim: int = 1
+ func_torch: bool = False
+ auto_grad: bool = False
+
+
+def MLP(channels, batch_norm=True):
+ """
+ Set up the MLP layer with NN.linear
+
+ :param channels: channel[0]:in_features
+ channel[1]:out_features
+ :param batch_norm:
+ :return: nn.Sequentially structured MLP model
+ """
+ return nn.Sequential(
+ *[
+ nn.Sequential(
+ nn.Linear(channels[i - 1], channels[i]), nn.ReLU()
+ ) # , nn.BatchNorm1d(channels[i]))
+ for i in range(1, len(channels))
+ ]
+ )
+
+
+class DynamicEdgeConv2(EdgeConv):
+ """
+ A modified pytorch-geometric implementation of EdgeConv:
+ https://pytorch-geometric.readthedocs.io/en/2.6.0/generated/torch_geometric.nn.conv.EdgeConv.html
+ Original Paper: https://arxiv.org/abs/1801.07829
+ """
+
+ def __init__(self, nn, k, aggr="max", **kwargs):
+ """
+
+ :param nn: network architecture
+ :param k:
+ :param aggr:
+ :param kwargs:
+ """
+ super(DynamicEdgeConv2, self).__init__(nn=nn, aggr=aggr, **kwargs)
+
+ if knn_graph is None:
+ raise ImportError("`DynamicEdgeConv` requires `torch-cluster`.")
+
+ self.k = k
+
+ def forward(self, x, edge_index):
+ # edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow, cosine=True)
+ return super(DynamicEdgeConv2, self).forward(x, edge_index)
+
+ def __repr__(self):
+ return "{}(nn={}, k={})".format(self.__class__.__name__, self.nn, self.k)
+
+
+class DGCNN(Module):
+ """
+ A modified of EdgeConv blocks with the DGCNN backbone.
+ Applies convolution to the edge features.
+
+ Parameters
+ ----------
+ 1st EdgeConv MLP: input feature channels=3*2, output feature channels=64
+ 2nd EdgeConv MLP: input feature channels=64*2, output feature channels=128
+ 3rd EdgeConv MLP: input feature channels=128*2, output feature channels=512
+
+ Example
+ -------
+ >>> model = modulus.models.dgcnn.DGCNN()
+ >>> sample_pts = torch.randn(10000, 3)
+ >>> edge_index = torch_geometric.nn.knn_graph(torch.FloatTensor(sample_pts), 10)
+ >>> sample_pts_data = torch_geometric.data.Data(x=sample_pts, edge_index=edge_index)
+ >>> compensated_output = model(sample_pts_data)
+ >>> compensated_output.size()
+
+ Note
+ ----
+ Reference of DGCNN backbone: https://arxiv.org/pdf/1801.07829
+ """
+
+ def __init__(
+ self,
+ k: int = 20,
+ aggr: str = "max",
+ ):
+ if not (k >= 0):
+ raise ValueError("Invalid arch params")
+ super().__init__(meta=MetaData(name="dgcnn"))
+
+ self.conv1 = DynamicEdgeConv2(MLP([3 * 2, 64]), k, aggr)
+ self.conv2 = DynamicEdgeConv2(MLP([64 * 2, 128]), k, aggr)
+ self.conv3 = DynamicEdgeConv2(MLP([128 * 2, 512]), k, aggr)
+
+ self.lin1 = Seq(
+ MLP([512, 256]), # Dropout(0.2), #MLP([512,256]), Dropout(0.2),
+ Lin(256, 3),
+ )
+
+ def forward(self, data):
+ x, edge_index = data.x, data.edge_index
+ x1 = self.conv1(x, edge_index)
+ x2 = self.conv2(x1, edge_index)
+ x3 = self.conv3(x2, edge_index)
+ x4 = self.lin1(x3)
+ return x + x4
+
+
+class DGCNN_ocardo(Module):
+ """
+ Variation of EdgeConv blocks with the DGCNN backbone: https://arxiv.org/pdf/1801.07829
+ Model architecture tuned for optimal performance on the Orcardo dataset
+
+ Parameters
+ ----------
+ 1st EdgeConv MLP: input feature channels=3*2, output feature channels=64
+ 2nd EdgeConv MLP: input feature channels=64*2, output feature channels=64
+ 3rd EdgeConv MLP: input feature channels=64*2, output feature channels=64
+ 4th EdgeConv MLP: input feature channels=64*2, output feature channels=64
+ 5th EdgeConv MLP: input feature channels=64*2, output feature channels=64
+
+ Aggregation func of the last layer: Max
+
+ Note
+ ----
+ Reference of DGCNN backbone: https://arxiv.org/pdf/1801.07829
+ """
+
+ def __init__(
+ self,
+ k: int = 5,
+ aggr: str = "max",
+ ):
+ if not (k >= 0):
+ raise ValueError("Invalid arch params")
+ super().__init__(meta=MetaData(name="dgcnn_orcardo"))
+
+ self.conv1 = DynamicEdgeConv2(MLP([3 * 2, 64]), k, aggr)
+ self.conv2 = DynamicEdgeConv2(MLP([64 * 2, 64]), k, aggr)
+ self.conv3 = DynamicEdgeConv2(MLP([64 * 2, 64]), k, aggr)
+ self.conv4 = DynamicEdgeConv2(MLP([64 * 2, 64]), k, aggr)
+ self.conv5 = DynamicEdgeConv2(MLP([64 * 2, 64]), k, aggr)
+
+ self.lin1 = Seq(
+ MLP([128, 128]), # Dropout(0.2), #MLP([512,256]), Dropout(0.2),
+ # MLP([256, 256]),# Dropout(0.2), #MLP([512,256]), Dropout(0.2),
+ # Lin(256, 3)
+ Lin(128, 3),
+ )
+
+ def forward(self, data):
+ x, edge_index = data.x, data.edge_index
+ n_pts, _ = x.shape
+ x1 = self.conv1(x, edge_index)
+ x2 = self.conv2(x1, edge_index)
+ x3 = self.conv3(x2, edge_index)
+ x4 = self.conv4(x3, edge_index)
+ x5 = self.conv5(x4, edge_index)
+ # extract global feature
+ globals = torch.max(x5, 0, keepdim=True)[0]
+ # concat local and global features
+ feat = torch.cat([x5, globals.repeat(n_pts, 1)], 1)
+ # MLP feature
+ x6 = self.lin1(feat)
+ # residual connection
+ return x + x6
diff --git a/test/models/test_compensation_net.py b/test/models/test_compensation_net.py
new file mode 100644
index 0000000000..088fcdf668
--- /dev/null
+++ b/test/models/test_compensation_net.py
@@ -0,0 +1,88 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+
+import pytest
+import torch
+import torch_geometric
+
+from physicsnemo.models.dgcnn.dgcnn_compensation import DGCNN
+
+from . import common
+
+
+@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+@pytest.mark.parametrize("knn_cnt", [5, 20])
+@pytest.mark.parametrize("sample_pts", [100, 1000, 10000])
+def test_dgcnn_forward(device, knn_cnt, sample_pts):
+ """Test model forward pass"""
+ torch.manual_seed(0)
+ # Construct dgcnn model
+ model = DGCNN(k=knn_cnt, aggr="max").to(device)
+
+ bsize = 2
+ in_pts = torch.randn(bsize, sample_pts, 3).to(device)
+ edge_index = torch_geometric.nn.knn_graph(torch.FloatTensor(in_pts), knn_cnt)
+ invar = torch_geometric.data.Data(x=in_pts, edge_index=edge_index)
+ assert common.validate_forward_accuracy(
+ model,
+ (invar,),
+ file_name=f"dgcnn_k{knn_cnt}_pts{sample_pts}_output.pth",
+ atol=1e-4,
+ )
+
+
+@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+@pytest.mark.parametrize("knn_cnt", [5, 20])
+@pytest.mark.parametrize("sample_pts", [100, 1000, 10000])
+def test_dgcnn_checkpoint(device, knn_cnt, sample_pts):
+ """Test model checkpoint save/load"""
+ torch.manual_seed(0)
+ # Construct dgcnn model
+ model_1 = DGCNN(k=knn_cnt, aggr="max").to(device)
+
+ model_2 = DGCNN(k=knn_cnt, aggr="max").to(device)
+
+ bsize = random.randint(1, 2)
+ in_pts = torch.randn(bsize, sample_pts, 3).to(device)
+ edge_index = torch_geometric.nn.knn_graph(torch.FloatTensor(in_pts), knn_cnt)
+ invar = torch_geometric.data.Data(x=in_pts, edge_index=edge_index)
+
+ assert common.validate_checkpoint(model_1, model_2, (invar,))
+
+
+@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+@pytest.mark.parametrize("knn_cnt", [5, 20])
+@pytest.mark.parametrize("sample_pts", [100, 1000, 10000])
+def test_dgcnn_optimizations(device, knn_cnt, sample_pts):
+ """Test model optimizations"""
+
+ def setup_model():
+ "Sets up fresh model for each optimization test"
+ # Construct dgcnn model
+ model = DGCNN(k=knn_cnt, aggr="max").to(device)
+
+ bsize = 2
+ in_pts = torch.randn(bsize, sample_pts, 3).to(device)
+ edge_index = torch_geometric.nn.knn_graph(torch.FloatTensor(in_pts), knn_cnt)
+ invar = torch_geometric.data.Data(x=in_pts, edge_index=edge_index)
+
+ return model, invar
+
+ # Check AMP
+ model, invar = setup_model()
+ assert common.validate_amp(model, (invar,))