diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a6c5c7b --- /dev/null +++ b/.gitignore @@ -0,0 +1,99 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover + +# Translations +*.mo +*.pot + +# Django stuff: +*.log + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# DotEnv configuration +.env + +# Database +*.db +*.rdb + +# Pycharm +.idea + +# VS Code +.vscode/ + +# Spyder +.spyproject/ + +# Jupyter NB Checkpoints +.ipynb_checkpoints/ + +# exclude data from source control by default +io +io/ + +# Mac OS-specific storage files +.DS_Store + +# vim +*.swp +*.swo + +# Mypy cache +.mypy_cache/ + +# Other +.history +.history/ +logs/ +venv*/ +venv_gpu/ +lightning_logs +.DS_Store \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..7d3619b --- /dev/null +++ b/README.md @@ -0,0 +1,83 @@ +# Geometric GNN Dojo + +*The Geometric GNN Dojo* is a pedagogical resource for beginners and experts to explore the design space of **Graph Neural Networks for geometric graphs**. + + + +
+ +Check out the accompanying paper ['On the Expressive Power of Geometric Graph Neural Networks'](https://www.chaitjo.com/publication/joshi-2022-expressive/), which characterises the expressive power and theoretical limitations of geometric GNNs through the lens of geometric graph isomorphism. +> Chaitanya K. Joshi*, Cristian Bodnar*, Simon V. Mathis, Taco Cohen, and Pietro Liò. On the Expressive Power of Geometric Graph Neural Networks. *NeurIPS 2022 Workshop on Symmetry and Geometry in Neural Representations.* +> +>[PDF]() | [Slides](https://www.chaitjo.com/publication/joshi-2022-expressive/Geometric_GNNs_Slides.pdf) | [Video](https://youtu.be/VKj5wzZsoK4) + + +## Architectures + +The `/src` directory provides unified implementations of several popular geometric GNN architectures: +- Invariant GNNs: [SchNet](https://arxiv.org/abs/1706.08566), [DimeNet](https://arxiv.org/abs/2003.03123) +- Equivariant GNNs using cartesian vectors: [E(n) Equivariant GNN](https://proceedings.mlr.press/v139/satorras21a.html), [GVP-GNN](https://arxiv.org/abs/2009.01411) +- Equivariant GNNs using spherical tensors: [Tensor Field Network](https://arxiv.org/abs/1802.08219), [MACE](http://arxiv.org/abs/2206.07697) + +## Experiments + +The `/experiments` directory contains notebooks with synthetic experiments to highlight practical challenges in building powerful geometric GNNs: +- `kchains.ipynb`: Distinguishing k-chains, which test a model's ability to propagate geometric information non-locally and demonstrate oversquashing with increased depth. +- `rotsym.ipynb`: Rotationally symmetric structures, which test a layer's ability to identify neighbourhood orientation and highlight the utility of higher order tensors in equivariant GNNs. +- `incompleteness.ipynb`: Counterexamples from [Pozdnyakov et al.](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.125.166001), which test a layer's ability to create distinguishing fingerprints for local neighbourhoods and highlight the need for higher order scalarisation. + + + +## Installation + +```bash +# Create new conda environment +conda create -n pyg python=3.8 + +# Install PyTorch (Check CUDA version!) +conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch + +# Install PyG +conda install pyg -c pyg -c conda-forge + +# Install other dependencies +pip3 install e3nn==0.4.4 +conda install matplotlib pandas networkx +pip3 install ipdb ase +conda install jupyterlab -c conda-forge +``` + + + +## Directory Structure and Usage + +``` +. +├── README.md +| +├── experiments # Synthetic experiments +│ ├── incompleteness.ipynb # Experiment on counterexamples from Pozdnyakov et al. +│ ├── kchains.ipynb # Experiment on k-chains +│ └── rotsym.ipynb # Experiment on rotationally symmetric structures +| +└── src # Geometric GNN models library + ├── models.py # Models built using layers + ├── gvp_layers.py # Layers for GVP-GNN + ├── egnn_layers.py # Layers for E(n) Equivariant GNN + ├── tfn_layers.py # Layers for Tensor Field Networks + ├── modules # Layers for MACE + └── utils # Helper functions for training, plotting, etc. +``` + + + +## Citation + +``` +@article{joshi2022expressive, + title={On the Expressive Power of Geometric Graph Neural Networks}, + author={Joshi, Chaitanya K. and Bodnar, Cristian and Mathis, Simon V. 
and Cohen, Taco and Liò, Pietro},
+  journal={NeurIPS Workshop on Symmetry and Geometry in Neural Representations},
+  year={2022},
+}
+```
diff --git a/experiments/__init__.py b/experiments/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/experiments/fig/incompleteness.png b/experiments/fig/incompleteness.png
new file mode 100644
index 0000000..8a01efe
Binary files /dev/null and b/experiments/fig/incompleteness.png differ
diff --git a/experiments/fig/kchains.png b/experiments/fig/kchains.png
new file mode 100644
index 0000000..7706110
Binary files /dev/null and b/experiments/fig/kchains.png differ
diff --git a/experiments/fig/rotsym.png b/experiments/fig/rotsym.png
new file mode 100644
index 0000000..c20fffa
Binary files /dev/null and b/experiments/fig/rotsym.png differ
diff --git a/experiments/incompleteness.ipynb b/experiments/incompleteness.ipynb
new file mode 100644
index 0000000..67ecdbc
--- /dev/null
+++ b/experiments/incompleteness.ipynb
@@ -0,0 +1,504 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Identifying neighbourhood fingerprints: counterexamples from [Pozdnyakov et al., 2020](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.125.166001)\n",
+    "\n",
+    "*Background:*\n",
+    "Geometric GNNs identify local neighbourhoods around nodes via **'neighbourhood fingerprints'** or scalarisations, where local geometric information from subsets of neighbours is aggregated to compute invariant scalars. The number of neighbours involved in computing the scalars is termed the **body order**.\n",
+    "The ideal neighbourhood fingerprint would perfectly identify neighbourhoods, which requires arbitrarily high body order.\n",
+    "\n",
+    "*Experiment:*\n",
+    "To demonstrate the practical implications of scalarisation body order, we evaluate geometric GNN layers on their ability to discriminate counterexamples from [Pozdnyakov et al., 2020](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.125.166001).\n",
+    "Each counterexample consists of a pair of local neighbourhoods that are **indistinguishable** when comparing their set of $k$-body scalars, i.e. geometric GNN layers with body order $k$ cannot distinguish the neighbourhoods.\n",
+    "The 3-body counterexample corresponds to Fig.1(b) in Pozdnyakov et al., 2020, 4-body chiral to Fig.2(e), and 4-body non-chiral to Fig.2(f); the 2-body counterexample is based on the two local neighbourhoods in our running example.\n",
+    "In this notebook, we train single layer geometric GNNs to distinguish the counterexamples using updated scalar features. 
\n", + "\n", + "![Counterexamples from Pozdnyakov et al., 2020](fig/incompleteness.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import sys\n", + "sys.path.append('../')\n", + "\n", + "import random\n", + "import numpy as np\n", + "import torch\n", + "from torch.nn import functional as F\n", + "import torch_geometric\n", + "from torch_geometric.data import Data, Batch\n", + "from torch_geometric.loader import DataLoader\n", + "from torch_geometric.utils import is_undirected, to_undirected, remove_self_loops, to_dense_adj, dense_to_sparse\n", + "import e3nn\n", + "from e3nn import o3\n", + "from functools import partial\n", + "\n", + "print(\"PyTorch version {}\".format(torch.__version__))\n", + "print(\"PyG version {}\".format(torch_geometric.__version__))\n", + "print(\"e3nn version {}\".format(e3nn.__version__))\n", + "\n", + "from src.utils.plot_utils import plot_2d, plot_3d\n", + "from src.utils.train_utils import run_experiment\n", + "from src.models import MPNNModel, EGNNModel, GVPGNNModel, TFNModel, SchNetModel, DimeNetPPModel, MACEModel\n", + "\n", + "# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)\n", + "# print(f\"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}\")\n", + "# print(f\"Is MPS available? {torch.backends.mps.is_available()}\")\n", + "\n", + "# Set the device\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "# device = torch.device(\"mps\" if torch.backends.mps.is_available() else \"cpu\")\n", + "# device = torch.device(\"cpu\")\n", + "print(f\"Using device: {device}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Two-body counterexample\n", + "\n", + "Pair of local neighbourhoods that are indistinguishable when comparing their set of $2$-body scalars, i.e. the unordered set of pairwise distances." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_two_body_envs():\n", + " dataset = []\n", + "\n", + " # Environment 0\n", + " # atoms = torch.LongTensor([ 0, 1, 2 ])\n", + " atoms = torch.LongTensor([ 0, 0, 0 ])\n", + " edge_index = torch.LongTensor([ [0, 0], [1, 2] ])\n", + " pos = torch.FloatTensor([ \n", + " [0, 0, 0],\n", + " [5, 0, 0],\n", + " [3, 0, 4]\n", + " ])\n", + " y = torch.LongTensor([0]) # Label 0\n", + " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", + " data1.edge_index = to_undirected(data1.edge_index)\n", + " dataset.append(data1)\n", + " \n", + " # Environment 1\n", + " # atoms = torch.LongTensor([ 0, 1, 2 ])\n", + " atoms = torch.LongTensor([ 0, 0, 0 ])\n", + " edge_index = torch.LongTensor([ [0, 0], [1, 2] ])\n", + " pos = torch.FloatTensor([ \n", + " [0, 0, 0],\n", + " [5, 0, 0],\n", + " [-5, 0, 0]\n", + " ])\n", + " y = torch.LongTensor([1]) # Label 1\n", + " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", + " data2.edge_index = to_undirected(data2.edge_index)\n", + " dataset.append(data2)\n", + " \n", + " return dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create dataset\n", + "dataset = create_two_body_envs()\n", + "for data in dataset:\n", + " plot_3d(data, lim=5)\n", + "\n", + "# Set model\n", + "model_name = \"egnn\"\n", + "\n", + "# Create dataloaders\n", + "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n", + "val_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n", + "test_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n", + "\n", + "num_layers = 1\n", + "correlation = 2\n", + "model = {\n", + " \"mpnn\": MPNNModel,\n", + " \"schnet\": SchNetModel,\n", + " \"dimenet\": DimeNetPPModel,\n", + " \"egnn\": EGNNModel,\n", + " \"gvp\": GVPGNNModel,\n", + " \"tfn\": TFNModel,\n", + " \"mace\": partial(MACEModel, correlation=correlation),\n", + "}[model_name](num_layers=num_layers, in_dim=1, out_dim=2)\n", + "\n", + "best_val_acc, test_acc, train_time = run_experiment(\n", + " model, \n", + " dataloader,\n", + " val_loader, \n", + " test_loader,\n", + " n_epochs=100,\n", + " n_times=10,\n", + " device=device,\n", + " verbose=False\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Three-body counterexample\n", + "\n", + "Pair of local neighbourhoods that are indistinguishable when comparing their set of $3$-body scalars, i.e. the unordered set of pairwise distances as well as angles." 
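+    ,
+    "\n",
+    "Here distances *and* angles agree across the pair, so any layer whose fingerprint stops at $3$-body scalars is expected to map both neighbourhoods to the same embedding. The cell below exposes MACE's `correlation` hyperparameter, which controls the body order of its scalarisation, so one can probe when the pair becomes separable."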
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_three_body_envs():\n", + " dataset = []\n", + "\n", + " a_x, a_y, a_z = 5, 0, 5\n", + " b_x, b_y, b_z = 5, 5, 5\n", + " c_x, c_y, c_z = 0, 5, 5\n", + " \n", + " # Environment 0\n", + " # atoms = torch.LongTensor([ 0, 1, 2, 3, 4 ])\n", + " atoms = torch.LongTensor([ 0, 0, 0, 0, 0 ])\n", + " edge_index = torch.LongTensor([ [0, 0, 0, 0], [1, 2, 3, 4] ])\n", + " pos = torch.FloatTensor([ \n", + " [0, 0, 0],\n", + " [a_x, a_y, a_z],\n", + " [+b_x, +b_y, b_z],\n", + " [-b_x, -b_y, b_z],\n", + " [c_x, +c_y, c_z],\n", + " ])\n", + " y = torch.LongTensor([0]) # Label 0\n", + " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", + " data1.edge_index = to_undirected(data1.edge_index)\n", + " dataset.append(data1)\n", + " \n", + " # Environment 1\n", + " # atoms = torch.LongTensor([ 0, 1, 2, 3, 4 ])\n", + " atoms = torch.LongTensor([ 0, 0, 0, 0, 0 ])\n", + " edge_index = torch.LongTensor([ [0, 0, 0, 0], [1, 2, 3, 4] ])\n", + " pos = torch.FloatTensor([ \n", + " [0, 0, 0],\n", + " [a_x, a_y, a_z],\n", + " [+b_x, +b_y, b_z],\n", + " [-b_x, -b_y, b_z],\n", + " [c_x, -c_y, c_z],\n", + " ])\n", + " y = torch.LongTensor([1]) # Label 1\n", + " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", + " data2.edge_index = to_undirected(data2.edge_index)\n", + " dataset.append(data2)\n", + " \n", + " return dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create dataset\n", + "dataset = create_three_body_envs()\n", + "for data in dataset:\n", + " plot_3d(data, lim=5)\n", + "\n", + "# Set model\n", + "model_name = \"egnn\"\n", + "\n", + "# Create dataloaders\n", + "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n", + "val_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n", + "test_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n", + "\n", + "num_layers = 1\n", + "correlation = 3\n", + "model = {\n", + " \"mpnn\": MPNNModel,\n", + " \"schnet\": SchNetModel,\n", + " \"dimenet\": DimeNetPPModel,\n", + " \"egnn\": EGNNModel,\n", + " \"gvp\": GVPGNNModel,\n", + " \"tfn\": TFNModel,\n", + " \"mace\": partial(MACEModel, correlation=correlation),\n", + "}[model_name](num_layers=num_layers, in_dim=1, out_dim=2)\n", + "\n", + "best_val_acc, test_acc, train_time = run_experiment(\n", + " model, \n", + " dataloader,\n", + " val_loader, \n", + " test_loader,\n", + " n_epochs=100,\n", + " n_times=10,\n", + " device=device,\n", + " verbose=False\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Four-body non-chiral counterexample\n", + "\n", + "Pair of local neighbourhoods that are indistinguishable when comparing their set of $4$-body scalars without considering chirality/handedness, i.e. the unordered set of pairwise distances, angles, and quadruplet scalars." 
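+    ,
+    "\n",
+    "In the construction below, the triple of `b` points is rotated about the $y$-axis by a fixed angle via `o3.matrix_y`, and the two environments differ only in the sign of the $y$-coordinate of the final point; per the setup above, $4$-body scalars cannot separate them, which is why the cell below runs MACE with `correlation = 4` and `max_ell = 3`."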
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_four_body_nonchiral_envs():\n", + " dataset = []\n", + "\n", + " a1_x, a1_y, a1_z = 3, 2, -4\n", + " a2_x, a2_y, a2_z = 0, 2, 5\n", + " a3_x, a3_y, a3_z = -3, 2, -4\n", + " b1_x, b1_y, b1_z = 3, -2, -4\n", + " b2_x, b2_y, b2_z = 0, -2, 5\n", + " b3_x, b3_y, b3_z = -3, -2, -4\n", + " c_x, c_y, c_z = 0, 5, 0\n", + "\n", + " angle = 2 * torch.pi / 10 # random angle\n", + " Q = o3.matrix_y(torch.tensor(angle)).numpy()\n", + "\n", + " # Environment 0\n", + " # atoms = torch.LongTensor([ 0, 1, 1, 1, 1, 1, 1, 2 ])\n", + " atoms = torch.LongTensor([ 0, 0, 0, 0, 0, 0, 0, 0 ])\n", + " edge_index = torch.LongTensor([ [0, 0, 0, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7] ])\n", + " pos = torch.FloatTensor([ \n", + " [0, 0, 0],\n", + " [a1_x, a1_y, a1_z],\n", + " [a2_x, a2_y, a2_z],\n", + " [a3_x, a3_y, a3_z],\n", + " [b1_x, b1_y, b1_z] @ Q,\n", + " [b2_x, b2_y, b2_z] @ Q,\n", + " [b3_x, b3_y, b3_z] @ Q,\n", + " [c_x, +c_y, c_z],\n", + " ])\n", + " y = torch.LongTensor([0]) # Label 0\n", + " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", + " data1.edge_index = to_undirected(data1.edge_index)\n", + " dataset.append(data1)\n", + " \n", + " # Environment 1\n", + " # atoms = torch.LongTensor([ 0, 1, 1, 1, 1, 1, 1, 2 ])\n", + " atoms = torch.LongTensor([ 0, 0, 0, 0, 0, 0, 0, 0 ])\n", + " edge_index = torch.LongTensor([ [0, 0, 0, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7] ])\n", + " pos = torch.FloatTensor([ \n", + " [0, 0, 0],\n", + " [a1_x, a1_y, a1_z],\n", + " [a2_x, a2_y, a2_z],\n", + " [a3_x, a3_y, a3_z],\n", + " [b1_x, b1_y, b1_z] @ Q,\n", + " [b2_x, b2_y, b2_z] @ Q,\n", + " [b3_x, b3_y, b3_z] @ Q,\n", + " [c_x, -c_y, c_z],\n", + " ])\n", + " y = torch.LongTensor([1]) # Label 1\n", + " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", + " data2.edge_index = to_undirected(data2.edge_index)\n", + " dataset.append(data2)\n", + " \n", + " return dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create dataset\n", + "dataset = create_four_body_nonchiral_envs()\n", + "for data in dataset:\n", + " plot_3d(data, lim=5)\n", + "\n", + "# Set model\n", + "model_name = \"mace\"\n", + "\n", + "# Create dataloaders\n", + "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n", + "val_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n", + "test_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n", + "\n", + "num_layers = 1\n", + "correlation = 4\n", + "model = {\n", + " \"mpnn\": MPNNModel,\n", + " \"schnet\": SchNetModel,\n", + " \"dimenet\": DimeNetPPModel,\n", + " \"egnn\": EGNNModel,\n", + " \"gvp\": GVPGNNModel,\n", + " \"tfn\": TFNModel,\n", + " \"mace\": partial(MACEModel, max_ell=3, correlation=correlation),\n", + "}[model_name](num_layers=num_layers, in_dim=1, out_dim=2)\n", + "\n", + "best_val_acc, test_acc, train_time = run_experiment(\n", + " model, \n", + " dataloader,\n", + " val_loader, \n", + " test_loader,\n", + " n_epochs=100,\n", + " n_times=1,\n", + " device=device,\n", + " verbose=False\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Four-body chiral counterexample\n", + "\n", + "Pair of local neighbourhoods that are indistinguishable when comparing their set of $4$-body scalars when considering chirality/handedness, i.e. the unordered set of pairwise distances, angles, and quadruplet scalars." 
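+    ,
+    "\n",
+    "The two environments below are related by a reflection (the sign of the $y$-coordinate of the final point is flipped), so separating them requires features that are sensitive to handedness; scalarisations built purely from distances and angles are expected to fail here."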
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def create_four_body_chiral_envs():\n",
+    "    dataset = []\n",
+    "\n",
+    "    a1_x, a1_y, a1_z = 3, 0, -4\n",
+    "    a2_x, a2_y, a2_z = 0, 0, 5\n",
+    "    a3_x, a3_y, a3_z = -3, 0, -4\n",
+    "    c_x, c_y, c_z = 0, 5, 0\n",
+    "\n",
+    "    # Environment 0\n",
+    "    # atoms = torch.LongTensor([ 0, 1, 1, 1, 2 ])\n",
+    "    atoms = torch.LongTensor([ 0, 0, 0, 0, 0 ])\n",
+    "    edge_index = torch.LongTensor([ [0, 0, 0, 0], [1, 2, 3, 4] ])\n",
+    "    pos = torch.FloatTensor([ \n",
+    "        [0, 0, 0],\n",
+    "        [a1_x, a1_y, a1_z],\n",
+    "        [a2_x, a2_y, a2_z],\n",
+    "        [a3_x, a3_y, a3_z],\n",
+    "        [c_x, +c_y, c_z],\n",
+    "    ])\n",
+    "    y = torch.LongTensor([0]) # Label 0\n",
+    "    data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
+    "    data1.edge_index = to_undirected(data1.edge_index)\n",
+    "    dataset.append(data1)\n",
+    "    \n",
+    "    # Environment 1\n",
+    "    # atoms = torch.LongTensor([ 0, 1, 1, 1, 2 ])\n",
+    "    atoms = torch.LongTensor([ 0, 0, 0, 0, 0 ])\n",
+    "    edge_index = torch.LongTensor([ [0, 0, 0, 0], [1, 2, 3, 4] ])\n",
+    "    pos = torch.FloatTensor([ \n",
+    "        [0, 0, 0],\n",
+    "        [a1_x, a1_y, a1_z],\n",
+    "        [a2_x, a2_y, a2_z],\n",
+    "        [a3_x, a3_y, a3_z],\n",
+    "        [c_x, -c_y, c_z],\n",
+    "    ])\n",
+    "    y = torch.LongTensor([1]) # Label 1\n",
+    "    data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
+    "    data2.edge_index = to_undirected(data2.edge_index)\n",
+    "    dataset.append(data2)\n",
+    "    \n",
+    "    return dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create dataset\n",
+    "dataset = create_four_body_chiral_envs()\n",
+    "for data in dataset:\n",
+    "    plot_3d(data, lim=5)\n",
+    "\n",
+    "# Set model\n",
+    "model_name = \"egnn\"\n",
+    "\n",
+    "# Create dataloaders\n",
+    "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
+    "val_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n",
+    "test_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n",
+    "\n",
+    "num_layers = 1\n",
+    "correlation = 4\n",
+    "model = {\n",
+    "    \"mpnn\": MPNNModel,\n",
+    "    \"schnet\": SchNetModel,\n",
+    "    \"dimenet\": DimeNetPPModel,\n",
+    "    \"egnn\": EGNNModel,\n",
+    "    \"gvp\": GVPGNNModel,\n",
+    "    \"tfn\": TFNModel,\n",
+    "    \"mace\": partial(MACEModel, correlation=correlation),\n",
+    "}[model_name](num_layers=num_layers, in_dim=1, out_dim=2)\n",
+    "\n",
+    "best_val_acc, test_acc, train_time = run_experiment(\n",
+    "    model, \n",
+    "    dataloader,\n",
+    "    val_loader, \n",
+    "    test_loader,\n",
+    "    n_epochs=100,\n",
+    "    n_times=10,\n",
+    "    device=device,\n",
+    "    verbose=False\n",
+    ")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.13"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "94aa676993820a604ac86f7af94f5432e989a749d5dd43e18f9507de2e8c2897"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/experiments/kchains.ipynb b/experiments/kchains.ipynb
new file mode 100644
index 0000000..65945cc
--- /dev/null
+++ b/experiments/kchains.ipynb
@@ -0,0 +1,209 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Propagating 
geometric information: $k$-chains\n",
+    "\n",
+    "*Background:*\n",
+    "In geometric GNNs, **geometric information**, such as the relative orientation of local neighbourhoods, is propagated via summing features from multiple layers in fixed dimensional spaces. \n",
+    "The ideal architecture can be run for any number of layers to perfectly propagate geometric information without loss of information.\n",
+    "In practice, stacking geometric GNN layers may lead to distortion or **loss of information from distant nodes**.\n",
+    "\n",
+    "*Experiment:*\n",
+    "To study the practical implications of depth in propagating geometric information beyond local neighbourhoods, we consider **$k$-chain geometric graphs** which generalise the examples from [Schütt et al., 2021](https://arxiv.org/abs/2102.03150). \n",
+    "Each pair of $k$-chains consists of $k+2$ nodes with $k$ nodes arranged in a line and differentiated by the orientation of the $2$ end points.\n",
+    "Thus, $k$-chain graphs are $(\lfloor \frac{k}{2} \rfloor + 1)$-hop distinguishable, and $(\lfloor \frac{k}{2} \rfloor + 1)$ geometric GNN iterations should be theoretically sufficient to distinguish them.\n",
+    "In this notebook, we train equivariant and invariant geometric GNNs with an increasing number of layers to distinguish $k$-chains.\n",
+    "\n",
+    "![k-chains](fig/kchains.png)\n",
+    "\n",
+    "*Results:*\n",
+    "- Despite the supposed simplicity of the task, especially for small chain lengths, we find that popular equivariant GNNs such as E-GNN and TFN may require **more iterations** than theoretically sufficient.\n",
+    "- Notably, as the length of the chain grows beyond $k=4$, all equivariant GNNs tend to lose performance and require more than $(\lfloor \frac{k}{2} \rfloor + 1)$ iterations to solve the task.\n",
+    "- Invariant GNNs are **unable** to distinguish $k$-chains.\n",
+    "\n",
+    "These results point to preliminary evidence of the **oversquashing** phenomenon for geometric GNNs.\n",
+    "These issues are most evident for E-GNN, which uses a single vector feature to aggregate and propagate geometric information."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "\n",
+    "import sys\n",
+    "sys.path.append('../')\n",
+    "\n",
+    "import random\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "from torch.nn import functional as F\n",
+    "import torch_geometric\n",
+    "from torch_geometric.data import Data, Batch\n",
+    "from torch_geometric.loader import DataLoader\n",
+    "from torch_geometric.utils import is_undirected, to_undirected, remove_self_loops, to_dense_adj, dense_to_sparse\n",
+    "import e3nn\n",
+    "from e3nn import o3\n",
+    "from functools import partial\n",
+    "\n",
+    "print(\"PyTorch version {}\".format(torch.__version__))\n",
+    "print(\"PyG version {}\".format(torch_geometric.__version__))\n",
+    "print(\"e3nn version {}\".format(e3nn.__version__))\n",
+    "\n",
+    "from src.utils.plot_utils import plot_2d, plot_3d\n",
+    "from src.utils.train_utils import run_experiment\n",
+    "from src.models import MPNNModel, EGNNModel, GVPGNNModel, TFNModel, SchNetModel, DimeNetPPModel, MACEModel\n",
+    "\n",
+    "# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)\n",
+    "# print(f\"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}\")\n",
+    "# print(f\"Is MPS available? 
{torch.backends.mps.is_available()}\")\n", + "\n", + "# Set the device\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "# device = torch.device(\"mps\" if torch.backends.mps.is_available() else \"cpu\")\n", + "# device = torch.device(\"cpu\")\n", + "print(f\"Using device: {device}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_kchains(k):\n", + " assert k >= 2\n", + " \n", + " dataset = []\n", + "\n", + " # Graph 0\n", + " atoms = torch.LongTensor( [0] + [0] + [0]*(k-1) + [0] )\n", + " edge_index = torch.LongTensor( [ [i for i in range((k+2) - 1)], [i for i in range(1, k+2)] ] )\n", + " pos = torch.FloatTensor(\n", + " [[-4, -3, 0]] + \n", + " [[0, 5*i , 0] for i in range(k)] + \n", + " [[4, 5*(k-1) + 3, 0]]\n", + " )\n", + " center_of_mass = torch.mean(pos, dim=0)\n", + " pos = pos - center_of_mass\n", + " y = torch.LongTensor([0]) # Label 0\n", + " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", + " data1.edge_index = to_undirected(data1.edge_index)\n", + " dataset.append(data1)\n", + " \n", + " # Graph 1\n", + " atoms = torch.LongTensor( [0] + [0] + [0]*(k-1) + [0] )\n", + " edge_index = torch.LongTensor( [ [i for i in range((k+2) - 1)], [i for i in range(1, k+2)] ] )\n", + " pos = torch.FloatTensor(\n", + " [[4, -3, 0]] + \n", + " [[0, 5*i , 0] for i in range(k)] + \n", + " [[4, 5*(k-1) + 3, 0]]\n", + " )\n", + " center_of_mass = torch.mean(pos, dim=0)\n", + " pos = pos - center_of_mass\n", + " y = torch.LongTensor([1]) # Label 1\n", + " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", + " data2.edge_index = to_undirected(data2.edge_index)\n", + " dataset.append(data2)\n", + " \n", + " return dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "k = 4\n", + "\n", + "# Create dataset\n", + "dataset = create_kchains(k=k)\n", + "for data in dataset:\n", + " # plot_2d(data, lim=5*k)\n", + " plot_3d(data, lim=5*k)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set model\n", + "model_name = \"tfn\"\n", + "\n", + "# Create dataloaders\n", + "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n", + "val_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n", + "test_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n", + "\n", + "for num_layers in range(k // 2 , k + 3):\n", + "\n", + " print(f\"\\nNumber of layers: {num_layers}\")\n", + " \n", + " model = {\n", + " \"mpnn\": MPNNModel,\n", + " \"schnet\": SchNetModel,\n", + " \"dimenet\": DimeNetPPModel,\n", + " \"egnn\": EGNNModel,\n", + " \"gvp\": GVPGNNModel,\n", + " \"tfn\": TFNModel,\n", + " \"mace\": partial(MACEModel, correlation=2),\n", + " }[model_name](num_layers=num_layers, in_dim=1, out_dim=2)\n", + " \n", + " best_val_acc, test_acc, train_time = run_experiment(\n", + " model, \n", + " dataloader,\n", + " val_loader, \n", + " test_loader,\n", + " n_epochs=100,\n", + " n_times=10,\n", + " device=device,\n", + " verbose=False\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": 
"python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "94aa676993820a604ac86f7af94f5432e989a749d5dd43e18f9507de2e8c2897" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiments/rotsym.ipynb b/experiments/rotsym.ipynb new file mode 100644 index 0000000..97ad8a4 --- /dev/null +++ b/experiments/rotsym.ipynb @@ -0,0 +1,184 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Identifying neighbourhood orientation: rotationally symmetric structures\n", + "\n", + "*Background:*\n", + "Rotationally equivariant geometric GNNs aggregate local geometric information via summing together the neighbourhood geometric features, which are either **cartesian vectors** or **higher order spherical tensors**. \n", + "The ideal geometric GNN would injectively aggregate local geometric infromation to perfectly identify neighbourhood identities, orientations, etc.\n", + "In practice, the choice of basis (cartesian vs. spherical) comes with tradeoffs between tractability and empirical performance.\n", + "\n", + "*Experiment:*\n", + "In this notebook, we study how rotational symmetries interact with tensor order in equivariant GNNs. \n", + "We evaluate equivariant layers on their ability to distinguish the orientation of **structures with rotational symmetry**. \n", + "An [$L$-fold symmetric structure](https://en.wikipedia.org/wiki/Rotational_symmetry) does not change when rotated by an angle $\\frac{2\\pi}{L}$ around a point (in 2D) or axis (3D).\n", + "We consider two *distinct* rotated versions of each $L$-fold symmetric structure and train single layer equivariant GNNs to classify the two orientations using the updated geometric features.\n", + "\n", + "![Rotationally symmetric structures](fig/rotsym.png)\n", + "\n", + "*Result:*\n", + "- **We find that layers using order $L$ tensors are unable to identify the orientation of structures with rotation symmetry higher than $L$-fold.** This observation may be attributed to **spherical harmonics**, which are used as the underlying orthonormal basis and are rotationally symmetric themselves.\n", + "- Layers such as E-GNN and GVP-GNN using **cartesian vectors** (corresponding to tensor order 1) are popular as working with higher order tensors can be computationally intractable for many applications. However, E-GNN and GVP-GNN are particularly poor at disciminating orientation of rotationally symmetric structures. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import sys\n", + "sys.path.append('../')\n", + "\n", + "import random\n", + "import math\n", + "import numpy as np\n", + "import torch\n", + "from torch.nn import functional as F\n", + "import torch_geometric\n", + "from torch_geometric.data import Data, Batch\n", + "from torch_geometric.loader import DataLoader\n", + "from torch_geometric.utils import is_undirected, to_undirected, remove_self_loops, to_dense_adj, dense_to_sparse\n", + "import e3nn\n", + "from e3nn import o3\n", + "from functools import partial\n", + "\n", + "print(\"PyTorch version {}\".format(torch.__version__))\n", + "print(\"PyG version {}\".format(torch_geometric.__version__))\n", + "print(\"e3nn version {}\".format(e3nn.__version__))\n", + "\n", + "from src.utils.plot_utils import plot_2d, plot_3d\n", + "from src.utils.train_utils import run_experiment\n", + "from src.models import MPNNModel, EGNNModel, GVPGNNModel, TFNModel, SchNetModel, DimeNetPPModel, MACEModel\n", + "\n", + "# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)\n", + "# print(f\"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}\")\n", + "# print(f\"Is MPS available? {torch.backends.mps.is_available()}\")\n", + "\n", + "# Set the device\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "# device = torch.device(\"mps\" if torch.backends.mps.is_available() else \"cpu\")\n", + "# device = torch.device(\"cpu\")\n", + "print(f\"Using device: {device}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_rotsym_envs(fold=3):\n", + " dataset = []\n", + "\n", + " # Environment 0\n", + " atoms = torch.LongTensor([ 0 ] + [ 0 ] * fold)\n", + " edge_index = torch.LongTensor( [ [0] * fold, [i for i in range(1, fold+1)] ] )\n", + " x = torch.Tensor([1,0,0])\n", + " pos = [\n", + " torch.Tensor([0,0,0]), # origin\n", + " x, # first spoke \n", + " ]\n", + " for count in range(1, fold):\n", + " R = o3.matrix_z(torch.Tensor([2*math.pi/fold * count])).squeeze(0)\n", + " pos.append(x @ R.T)\n", + " pos = torch.stack(pos)\n", + " y = torch.LongTensor([0]) # Label 0\n", + " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", + " data1.edge_index = to_undirected(data1.edge_index)\n", + " dataset.append(data1)\n", + " \n", + " # Environment 1\n", + " q = 2*math.pi/(fold + random.randint(1, fold))\n", + " assert q < 2*math.pi/fold\n", + " Q = o3.matrix_z(torch.Tensor([q])).squeeze(0)\n", + " pos = pos @ Q.T\n", + " y = torch.LongTensor([1]) # Label 1\n", + " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", + " data2.edge_index = to_undirected(data2.edge_index)\n", + " dataset.append(data2)\n", + " \n", + " return dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set parameters\n", + "model_name = \"tfn\"\n", + "correlation = 2\n", + "max_ell = 5\n", + "fold = 3\n", + "\n", + "# Create dataset\n", + "dataset = create_rotsym_envs(fold)\n", + "for data in dataset:\n", + " plot_2d(data, lim=1)\n", + "\n", + "# Create dataloaders\n", + "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n", + "val_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n", + "test_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n", 
+ "\n", + "num_layers = 1\n", + "model = {\n", + " \"mpnn\": MPNNModel,\n", + " \"schnet\": SchNetModel,\n", + " \"dimenet\": DimeNetPPModel,\n", + " \"egnn\": EGNNModel,\n", + " \"gvp\": GVPGNNModel,\n", + " \"tfn\": partial(TFNModel, max_ell=max_ell, scalar_pred=False),\n", + " \"mace\": partial(MACEModel, max_ell=max_ell, correlation=correlation, scalar_pred=False),\n", + "}[model_name](num_layers=num_layers, in_dim=1, out_dim=2)\n", + "\n", + "best_val_acc, test_acc, train_time = run_experiment(\n", + " model, \n", + " dataloader,\n", + " val_loader, \n", + " test_loader,\n", + " n_epochs=100,\n", + " n_times=1,\n", + " device=device,\n", + " verbose=False\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "94aa676993820a604ac86f7af94f5432e989a749d5dd43e18f9507de2e8c2897" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/egnn_layers.py b/src/egnn_layers.py new file mode 100644 index 0000000..e98b051 --- /dev/null +++ b/src/egnn_layers.py @@ -0,0 +1,152 @@ +import torch +from torch.nn import Linear, ReLU, SiLU, Sequential +from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool +from torch_scatter import scatter + + +class EGNNLayer(MessagePassing): + def __init__(self, emb_dim, activation="relu", norm="layer", aggr="add"): + """E(n) Equivariant GNN Layer + + Paper: E(n) Equivariant Graph Neural Networks, Satorras et al. 
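+        (PMLR page: https://proceedings.mlr.press/v139/satorras21a.html)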
+
+        Args:
+            emb_dim: (int) - hidden dimension `d`
+            activation: (str) - non-linearity within MLPs (swish/relu)
+            norm: (str) - normalisation layer (layer/batch)
+            aggr: (str) - aggregation function `\oplus` (sum/mean/max)
+        """
+        # Set the aggregation function
+        super().__init__(aggr=aggr)
+
+        self.emb_dim = emb_dim
+        self.activation = {"swish": SiLU(), "relu": ReLU()}[activation]
+        self.norm = {"layer": torch.nn.LayerNorm, "batch": torch.nn.BatchNorm1d}[norm]
+
+        # MLP `\psi_h` for computing messages `m_ij`
+        self.mlp_msg = Sequential(
+            Linear(2 * emb_dim + 1, emb_dim),
+            self.norm(emb_dim),
+            self.activation,
+            Linear(emb_dim, emb_dim),
+            self.norm(emb_dim),
+            self.activation,
+        )
+        # MLP `\psi_x` for computing messages `\overrightarrow{m}_ij`
+        self.mlp_pos = Sequential(
+            Linear(emb_dim, emb_dim), self.norm(emb_dim), self.activation, Linear(emb_dim, 1)
+        )
+        # MLP `\phi` for computing updated node features `h_i^{l+1}`
+        self.mlp_upd = Sequential(
+            Linear(2 * emb_dim, emb_dim),
+            self.norm(emb_dim),
+            self.activation,
+            Linear(emb_dim, emb_dim),
+            self.norm(emb_dim),
+            self.activation,
+        )
+
+    def forward(self, h, pos, edge_index):
+        """
+        Args:
+            h: (n, d) - initial node features
+            pos: (n, 3) - initial node coordinates
+            edge_index: (2, e) - pairs of edges (i, j)
+        Returns:
+            out: [(n, d), (n, 3)] - updated node features and coordinates
+        """
+        out = self.propagate(edge_index, h=h, pos=pos)
+        return out
+
+    def message(self, h_i, h_j, pos_i, pos_j):
+        # Compute messages
+        pos_diff = pos_i - pos_j
+        dists = torch.norm(pos_diff, dim=-1).unsqueeze(1)
+        msg = torch.cat([h_i, h_j, dists], dim=-1)
+        msg = self.mlp_msg(msg)
+        # Scale magnitude of displacement vector
+        pos_diff = pos_diff * self.mlp_pos(msg)  # optionally clamp, e.g. torch.clamp(pos_diff, min=-100, max=100)
+        return msg, pos_diff
+
+    def aggregate(self, inputs, index):
+        msgs, pos_diffs = inputs
+        # Aggregate messages
+        msg_aggr = scatter(msgs, index, dim=self.node_dim, reduce=self.aggr)
+        # Aggregate displacement vectors
+        pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, reduce="mean")
+        return msg_aggr, pos_aggr
+
+    def update(self, aggr_out, h, pos):
+        msg_aggr, pos_aggr = aggr_out
+        upd_out = self.mlp_upd(torch.cat([h, msg_aggr], dim=-1))
+        upd_pos = pos + pos_aggr
+        return upd_out, upd_pos
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"
+
+
+class MPNNLayer(MessagePassing):
+    def __init__(self, emb_dim, activation="relu", norm="layer", aggr="add"):
+        """Vanilla Message Passing GNN layer
+
+        Args:
+            emb_dim: (int) - hidden dimension `d`
+            activation: (str) - non-linearity within MLPs (swish/relu)
+            norm: (str) - normalisation layer (layer/batch)
+            aggr: (str) - aggregation function `\oplus` (sum/mean/max)
+        """
+        # Set the aggregation function
+        super().__init__(aggr=aggr)
+
+        self.emb_dim = emb_dim
+        self.activation = {"swish": SiLU(), "relu": ReLU()}[activation]
+        self.norm = {"layer": torch.nn.LayerNorm, "batch": torch.nn.BatchNorm1d}[norm]
+
+        # MLP `\psi_h` for computing messages `m_ij`
+        self.mlp_msg = Sequential(
+            Linear(2 * emb_dim, emb_dim),
+            self.norm(emb_dim),
+            self.activation,
+            Linear(emb_dim, emb_dim),
+            self.norm(emb_dim),
+            self.activation,
+        )
+        # MLP `\phi` for computing updated node features `h_i^{l+1}`
+        self.mlp_upd = Sequential(
+            Linear(2 * emb_dim, emb_dim),
+            self.norm(emb_dim),
+            self.activation,
+            Linear(emb_dim, emb_dim),
+            self.norm(emb_dim),
+            self.activation,
+        )
+
+    def forward(self, h, edge_index):
+        """
+        Args:
+            h: (n, d) - initial node features
+            
edge_index: (2, e) - pairs of edges (i, j)
+        Returns:
+            out: (n, d) - updated node features
+        """
+        out = self.propagate(edge_index, h=h)
+        return out
+
+    def message(self, h_i, h_j):
+        # Compute messages
+        msg = torch.cat([h_i, h_j], dim=-1)
+        msg = self.mlp_msg(msg)
+        return msg
+
+    def aggregate(self, inputs, index):
+        # Aggregate messages
+        msg_aggr = scatter(inputs, index, dim=self.node_dim, reduce=self.aggr)
+        return msg_aggr
+
+    def update(self, aggr_out, h):
+        upd_out = self.mlp_upd(torch.cat([h, aggr_out], dim=-1))
+        return upd_out
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"
diff --git a/src/gvp_layers.py b/src/gvp_layers.py
new file mode 100644
index 0000000..e8f2916
--- /dev/null
+++ b/src/gvp_layers.py
@@ -0,0 +1,388 @@
+###########################################################################################
+# Implementation of Geometric Vector Perceptron layers
+#
+# Papers:
+# (1) Learning from Protein Structure with Geometric Vector Perceptrons,
+#     by B Jing, S Eismann, P Suriana, RJL Townshend, and RO Dror
+# (2) Equivariant Graph Neural Networks for 3D Macromolecular Structure,
+#     by B Jing, S Eismann, P Soni, and RO Dror
+#
+# Original repository: https://github.com/drorlab/gvp-pytorch
+###########################################################################################
+
+import torch, functools
+from torch import nn
+import torch.nn.functional as F
+from torch_geometric.nn import MessagePassing
+from torch_scatter import scatter_add
+
+def tuple_sum(*args):
+    '''
+    Sums any number of tuples (s, V) elementwise.
+    '''
+    return tuple(map(sum, zip(*args)))
+
+def tuple_cat(*args, dim=-1):
+    '''
+    Concatenates any number of tuples (s, V) elementwise.
+
+    :param dim: dimension along which to concatenate when viewed
+                as the `dim` index for the scalar-channel tensors.
+                This means that `dim=-1` will be applied as
+                `dim=-2` for the vector-channel tensors.
+    '''
+    dim %= len(args[0][0].shape)
+    s_args, v_args = list(zip(*args))
+    return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim)
+
+def tuple_index(x, idx):
+    '''
+    Indexes into a tuple (s, V) along the first dimension.
+
+    :param idx: any object which can be used to index into a `torch.Tensor`
+    '''
+    return x[0][idx], x[1][idx]
+
+def randn(n, dims, device="cpu"):
+    '''
+    Returns random tuples (s, V) drawn elementwise from a normal distribution.
+
+    :param n: number of data points
+    :param dims: tuple of dimensions (n_scalar, n_vector)
+
+    :return: (s, V) with s.shape = (n, n_scalar) and
+             V.shape = (n, n_vector, 3)
+    '''
+    return torch.randn(n, dims[0], device=device), \
+           torch.randn(n, dims[1], 3, device=device)
+
+def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
+    '''
+    L2 norm of tensor clamped above a minimum value `eps`.
+
+    :param sqrt: if `False`, returns the square of the L2 norm
+    '''
+    out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
+    return torch.sqrt(out) if sqrt else out
+
+def _split(x, nv):
+    '''
+    Splits a merged representation of (s, V) back into a tuple.
+    Should be used only with `_merge(s, V)` and only if the tuple
+    representation cannot be used. 
+ + :param x: the `torch.Tensor` returned from `_merge` + :param nv: the number of vector channels in the input to `_merge` + ''' + v = torch.reshape(x[..., -3*nv:], x.shape[:-1] + (nv, 3)) + s = x[..., :-3*nv] + return s, v + +def _merge(s, v): + ''' + Merges a tuple (s, V) into a single `torch.Tensor`, where the + vector channels are flattened and appended to the scalar channels. + Should be used only if the tuple representation cannot be used. + Use `_split(x, nv)` to reverse. + ''' + v = torch.reshape(v, v.shape[:-2] + (3*v.shape[-2],)) + return torch.cat([s, v], -1) + +class GVP(nn.Module): + ''' + Geometric Vector Perceptron. See manuscript and README.md + for more details. + + :param in_dims: tuple (n_scalar, n_vector) + :param out_dims: tuple (n_scalar, n_vector) + :param h_dim: intermediate number of vector channels, optional + :param activations: tuple of functions (scalar_act, vector_act) + :param vector_gate: whether to use vector gating. + (vector_act will be used as sigma^+ in vector gating if `True`) + ''' + def __init__(self, in_dims, out_dims, h_dim=None, + activations=(F.relu, torch.sigmoid), vector_gate=False): + super(GVP, self).__init__() + self.si, self.vi = in_dims + self.so, self.vo = out_dims + self.vector_gate = vector_gate + if self.vi: + self.h_dim = h_dim or max(self.vi, self.vo) + self.wh = nn.Linear(self.vi, self.h_dim, bias=False) + self.ws = nn.Linear(self.h_dim + self.si, self.so) + if self.vo: + self.wv = nn.Linear(self.h_dim, self.vo, bias=False) + if self.vector_gate: self.wsv = nn.Linear(self.so, self.vo) + else: + self.ws = nn.Linear(self.si, self.so) + + self.scalar_act, self.vector_act = activations + self.dummy_param = nn.Parameter(torch.empty(0)) + + def forward(self, x): + ''' + :param x: tuple (s, V) of `torch.Tensor`, + or (if vectors_in is 0), a single `torch.Tensor` + :return: tuple (s, V) of `torch.Tensor`, + or (if vectors_out is 0), a single `torch.Tensor` + ''' + if self.vi: + s, v = x + v = torch.transpose(v, -1, -2) + vh = self.wh(v) + vn = _norm_no_nan(vh, axis=-2) + s = self.ws(torch.cat([s, vn], -1)) + if self.vo: + v = self.wv(vh) + v = torch.transpose(v, -1, -2) + if self.vector_gate: + if self.vector_act: + gate = self.wsv(self.vector_act(s)) + else: + gate = self.wsv(s) + v = v * torch.sigmoid(gate).unsqueeze(-1) + elif self.vector_act: + v = v * self.vector_act( + _norm_no_nan(v, axis=-1, keepdims=True)) + else: + s = self.ws(x) + if self.vo: + v = torch.zeros(s.shape[0], self.vo, 3, + device=self.dummy_param.device) + if self.scalar_act: + s = self.scalar_act(s) + + return (s, v) if self.vo else s + +class _VDropout(nn.Module): + ''' + Vector channel dropout where the elements of each + vector channel are dropped together. + ''' + def __init__(self, drop_rate): + super(_VDropout, self).__init__() + self.drop_rate = drop_rate + self.dummy_param = nn.Parameter(torch.empty(0)) + + def forward(self, x): + ''' + :param x: `torch.Tensor` corresponding to vector channels + ''' + device = self.dummy_param.device + if not self.training: + return x + mask = torch.bernoulli( + (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device) + ).unsqueeze(-1) + x = mask * x / (1 - self.drop_rate) + return x + +class Dropout(nn.Module): + ''' + Combined dropout for tuples (s, V). + Takes tuples (s, V) as input and as output. 
+ ''' + def __init__(self, drop_rate): + super(Dropout, self).__init__() + self.sdropout = nn.Dropout(drop_rate) + self.vdropout = _VDropout(drop_rate) + + def forward(self, x): + ''' + :param x: tuple (s, V) of `torch.Tensor`, + or single `torch.Tensor` + (will be assumed to be scalar channels) + ''' + if type(x) is torch.Tensor: + return self.sdropout(x) + s, v = x + return self.sdropout(s), self.vdropout(v) + +class LayerNorm(nn.Module): + ''' + Combined LayerNorm for tuples (s, V). + Takes tuples (s, V) as input and as output. + ''' + def __init__(self, dims): + super(LayerNorm, self).__init__() + self.s, self.v = dims + self.scalar_norm = nn.LayerNorm(self.s) + + def forward(self, x): + ''' + :param x: tuple (s, V) of `torch.Tensor`, + or single `torch.Tensor` + (will be assumed to be scalar channels) + ''' + if not self.v: + return self.scalar_norm(x) + s, v = x + vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False) + vn = torch.sqrt(torch.mean(vn, dim=-2, keepdim=True)) + return self.scalar_norm(s), v / vn + +class GVPConv(MessagePassing): + ''' + Graph convolution / message passing with Geometric Vector Perceptrons. + Takes in a graph with node and edge embeddings, + and returns new node embeddings. + + This does NOT do residual updates and pointwise feedforward layers + ---see `GVPConvLayer`. + + :param in_dims: input node embedding dimensions (n_scalar, n_vector) + :param out_dims: output node embedding dimensions (n_scalar, n_vector) + :param edge_dims: input edge embedding dimensions (n_scalar, n_vector) + :param n_layers: number of GVPs in the message function + :param module_list: preconstructed message function, overrides n_layers + :param aggr: should be "add" if some incoming edges are masked, as in + a masked autoregressive decoder architecture, otherwise "mean" + :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs + :param vector_gate: whether to use vector gating. 
+ (vector_act will be used as sigma^+ in vector gating if `True`) + ''' + def __init__(self, in_dims, out_dims, edge_dims, + n_layers=3, module_list=None, aggr="mean", + activations=(F.relu, torch.sigmoid), vector_gate=False): + super(GVPConv, self).__init__(aggr=aggr) + self.si, self.vi = in_dims + self.so, self.vo = out_dims + self.se, self.ve = edge_dims + + GVP_ = functools.partial(GVP, + activations=activations, vector_gate=vector_gate) + + module_list = module_list or [] + if not module_list: + if n_layers == 1: + module_list.append( + GVP_((2*self.si + self.se, 2*self.vi + self.ve), + (self.so, self.vo), activations=(None, None))) + else: + module_list.append( + GVP_((2*self.si + self.se, 2*self.vi + self.ve), out_dims) + ) + for i in range(n_layers - 2): + module_list.append(GVP_(out_dims, out_dims)) + module_list.append(GVP_(out_dims, out_dims, + activations=(None, None))) + self.message_func = nn.Sequential(*module_list) + + def forward(self, x, edge_index, edge_attr): + ''' + :param x: tuple (s, V) of `torch.Tensor` + :param edge_index: array of shape [2, n_edges] + :param edge_attr: tuple (s, V) of `torch.Tensor` + ''' + x_s, x_v = x + message = self.propagate(edge_index, + s=x_s, v=x_v.reshape(x_v.shape[0], 3*x_v.shape[1]), + edge_attr=edge_attr) + return _split(message, self.vo) + + def message(self, s_i, v_i, s_j, v_j, edge_attr): + v_j = v_j.view(v_j.shape[0], v_j.shape[1]//3, 3) + v_i = v_i.view(v_i.shape[0], v_i.shape[1]//3, 3) + message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i)) + message = self.message_func(message) + return _merge(*message) + + +class GVPConvLayer(nn.Module): + ''' + Full graph convolution / message passing layer with + Geometric Vector Perceptrons. Residually updates node embeddings with + aggregated incoming messages, applies a pointwise feedforward + network to node embeddings, and returns updated node embeddings. + + To only compute the aggregated messages, see `GVPConv`. + + :param node_dims: node embedding dimensions (n_scalar, n_vector) + :param edge_dims: input edge embedding dimensions (n_scalar, n_vector) + :param n_message: number of GVPs to use in message function + :param n_feedforward: number of GVPs to use in feedforward function + :param drop_rate: drop probability in all dropout layers + :param autoregressive: if `True`, this `GVPConvLayer` will be used + with a different set of input node embeddings for messages + where src >= dst + :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs + :param vector_gate: whether to use vector gating. 
(vector_act will be used as sigma^+ in vector gating if `True`)
+    '''
+    def __init__(self, node_dims, edge_dims,
+                 n_message=3, n_feedforward=2, drop_rate=.1,
+                 autoregressive=False,
+                 activations=(F.relu, torch.sigmoid), vector_gate=False,
+                 residual=True):
+
+        super(GVPConvLayer, self).__init__()
+        self.conv = GVPConv(node_dims, node_dims, edge_dims, n_message,
+                            aggr="add" if autoregressive else "mean",
+                            activations=activations, vector_gate=vector_gate)
+        GVP_ = functools.partial(GVP,
+                                 activations=activations, vector_gate=vector_gate)
+        self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)])
+        self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])
+
+        ff_func = []
+        if n_feedforward == 1:
+            ff_func.append(GVP_(node_dims, node_dims, activations=(None, None)))
+        else:
+            hid_dims = 4*node_dims[0], 2*node_dims[1]
+            ff_func.append(GVP_(node_dims, hid_dims))
+            for i in range(n_feedforward-2):
+                ff_func.append(GVP_(hid_dims, hid_dims))
+            ff_func.append(GVP_(hid_dims, node_dims, activations=(None, None)))
+        self.ff_func = nn.Sequential(*ff_func)
+        self.residual = residual
+
+    def forward(self, x, edge_index, edge_attr,
+                autoregressive_x=None, node_mask=None):
+        '''
+        :param x: tuple (s, V) of `torch.Tensor`
+        :param edge_index: array of shape [2, n_edges]
+        :param edge_attr: tuple (s, V) of `torch.Tensor`
+        :param autoregressive_x: tuple (s, V) of `torch.Tensor`.
+                If not `None`, will be used as src node embeddings
+                for forming messages where src >= dst. The current node
+                embeddings `x` will still be the base of the update and the
+                pointwise feedforward.
+        :param node_mask: array of type `bool` to index into the first
+                dim of node embeddings (s, V). If not `None`, only
+                these nodes will be updated.
+        '''
+
+        if autoregressive_x is not None:
+            src, dst = edge_index
+            mask = src < dst
+            edge_index_forward = edge_index[:, mask]
+            edge_index_backward = edge_index[:, ~mask]
+            edge_attr_forward = tuple_index(edge_attr, mask)
+            edge_attr_backward = tuple_index(edge_attr, ~mask)
+
+            dh = tuple_sum(
+                self.conv(x, edge_index_forward, edge_attr_forward),
+                self.conv(autoregressive_x, edge_index_backward, edge_attr_backward)
+            )
+
+            count = scatter_add(torch.ones_like(dst), dst,
+                                dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1)
+
+            dh = dh[0] / count, dh[1] / count.unsqueeze(-1)
+
+        else:
+            dh = self.conv(x, edge_index, edge_attr)
+
+        if node_mask is not None:
+            x_ = x
+            x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)
+
+        x = self.norm[0](tuple_sum(x, self.dropout[0](dh))) if self.residual else dh
+
+        dh = self.ff_func(x)
+        x = self.norm[1](tuple_sum(x, self.dropout[1](dh))) if self.residual else dh
+
+        if node_mask is not None:
+            x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
+            x = x_
+        return x
\ No newline at end of file
diff --git a/src/models.py b/src/models.py
new file mode 100644
index 0000000..0875fc9
--- /dev/null
+++ b/src/models.py
@@ -0,0 +1,521 @@
+from typing import Callable, Optional, Union
+import torch
+from torch.nn import functional as F
+import torch_geometric
+from torch_geometric.nn import SchNet, DimeNetPlusPlus, global_add_pool, global_mean_pool
+import torch_scatter
+from torch_scatter import scatter
+from e3nn import o3
+
+from src.modules.blocks import (
+    EquivariantProductBasisBlock,
+    RadialEmbeddingBlock,
+)
+from src.modules.irreps_tools import reshape_irreps
+
+from src.egnn_layers import MPNNLayer, EGNNLayer
+from src.tfn_layers import TensorProductConvLayer
+import src.gvp_layers as gvp
+
+
+class 
MACEModel(torch.nn.Module): + def __init__( + self, + r_max=10.0, + num_bessel=8, + num_polynomial_cutoff=5, + max_ell=2, + correlation=3, + num_layers=5, + emb_dim=64, + in_dim=1, + out_dim=1, + aggr="sum", + pool="sum", + residual=True, + scalar_pred=True + ): + super().__init__() + self.r_max = r_max + self.emb_dim = emb_dim + self.num_layers = num_layers + self.residual = residual + self.scalar_pred = scalar_pred + # Embedding + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + ) + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + + # Embedding lookup for initial node features + self.emb_in = torch.nn.Embedding(in_dim, emb_dim) + + self.convs = torch.nn.ModuleList() + self.prods = torch.nn.ModuleList() + self.reshapes = torch.nn.ModuleList() + hidden_irreps = (sh_irreps * emb_dim).sort()[0].simplify() + irrep_seq = [ + o3.Irreps(f'{emb_dim}x0e'), + # o3.Irreps(f'{emb_dim}x0e + {emb_dim}x1o + {emb_dim}x2e'), + # o3.Irreps(f'{emb_dim//2}x0e + {emb_dim//2}x0o + {emb_dim//2}x1e + {emb_dim//2}x1o + {emb_dim//2}x2e + {emb_dim//2}x2o'), + hidden_irreps + ] + for i in range(num_layers): + in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)] + out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)] + conv = TensorProductConvLayer( + in_irreps=in_irreps, + out_irreps=out_irreps, + sh_irreps=sh_irreps, + edge_feats_dim=self.radial_embedding.out_dim, + hidden_dim=emb_dim, + gate=False, + aggr=aggr, + ) + self.convs.append(conv) + self.reshapes.append(reshape_irreps(out_irreps)) + prod = EquivariantProductBasisBlock( + node_feats_irreps=out_irreps, + target_irreps=out_irreps, + correlation=correlation, + element_dependent=False, + num_elements=in_dim, + use_sc=residual + ) + self.prods.append(prod) + + # Global pooling/readout function + self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool] + + if self.scalar_pred: + # Predictor MLP + self.pred = torch.nn.Sequential( + torch.nn.Linear(emb_dim, emb_dim), + torch.nn.ReLU(), + torch.nn.Linear(emb_dim, out_dim) + ) + else: + self.pred = torch.nn.Linear(hidden_irreps.dim, out_dim) + + def forward(self, batch): + h = self.emb_in(batch.atoms) # (n,) -> (n, d) + + # Edge features + vectors = batch.pos[batch.edge_index[0]] - batch.pos[batch.edge_index[1]] # [n_edges, 3] + lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding(lengths) + + for conv, reshape, prod in zip(self.convs, self.reshapes, self.prods): + # Message passing layer + h_update = conv(h, batch.edge_index, edge_attrs, edge_feats) + # Update node features + sc = F.pad(h, (0, h_update.shape[-1] - h.shape[-1])) + h = prod(reshape(h_update), sc, None) + + if self.scalar_pred: + # Select only scalars for prediction + h = h[:,:self.emb_dim] + out = self.pool(h, batch.batch) # (n, d) -> (batch_size, d) + return self.pred(out) # (batch_size, out_dim) + + +class TFNModel(torch.nn.Module): + def __init__( + self, + r_max=10.0, + num_bessel=8, + num_polynomial_cutoff=5, + max_ell=2, + num_layers=5, + emb_dim=64, + in_dim=1, + out_dim=1, + aggr="sum", + pool="sum", + residual=True, + scalar_pred=True + ): + super().__init__() + self.r_max = r_max + self.emb_dim = emb_dim + self.num_layers = num_layers + self.residual = residual + self.scalar_pred = scalar_pred + # Embedding + 
self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + ) + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + + # Embedding lookup for initial node features + self.emb_in = torch.nn.Embedding(in_dim, emb_dim) + + self.convs = torch.nn.ModuleList() + hidden_irreps = (sh_irreps * emb_dim).sort()[0].simplify() + irrep_seq = [ + o3.Irreps(f'{emb_dim}x0e'), + # o3.Irreps(f'{emb_dim}x0e + {emb_dim}x1o + {emb_dim}x2e'), + # o3.Irreps(f'{emb_dim//2}x0e + {emb_dim//2}x0o + {emb_dim//2}x1e + {emb_dim//2}x1o + {emb_dim//2}x2e + {emb_dim//2}x2o'), + hidden_irreps + ] + for i in range(num_layers): + in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)] + out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)] + conv = TensorProductConvLayer( + in_irreps=in_irreps, + out_irreps=out_irreps, + sh_irreps=sh_irreps, + edge_feats_dim=self.radial_embedding.out_dim, + hidden_dim=emb_dim, + gate=True, + aggr=aggr, + ) + self.convs.append(conv) + + # Global pooling/readout function + self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool] + + if self.scalar_pred: + # Predictor MLP + self.pred = torch.nn.Sequential( + torch.nn.Linear(emb_dim, emb_dim), + torch.nn.ReLU(), + torch.nn.Linear(emb_dim, out_dim) + ) + else: + self.pred = torch.nn.Linear(hidden_irreps.dim, out_dim) + + def forward(self, batch): + h = self.emb_in(batch.atoms) # (n,) -> (n, d) + + # Edge features + vectors = batch.pos[batch.edge_index[0]] - batch.pos[batch.edge_index[1]] # [n_edges, 3] + lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding(lengths) + + for conv in self.convs: + # Message passing layer + h_update = conv(h, batch.edge_index, edge_attrs, edge_feats) + + # Update node features + h = h_update + F.pad(h, (0, h_update.shape[-1] - h.shape[-1])) if self.residual else h_update + + if self.scalar_pred: + # Select only scalars for prediction + h = h[:,:self.emb_dim] + out = self.pool(h, batch.batch) # (n, d) -> (batch_size, d) + return self.pred(out) # (batch_size, out_dim) + + +class GVPGNNModel(torch.nn.Module): + def __init__( + self, + r_max=10.0, + num_bessel=8, + num_polynomial_cutoff=5, + num_layers=5, + emb_dim=64, + in_dim=1, + out_dim=1, + aggr="sum", + pool="sum", + residual=True + ): + super().__init__() + _DEFAULT_V_DIM = (emb_dim, emb_dim) + _DEFAULT_E_DIM = (emb_dim, 1) + activations = (F.relu, None) + + self.r_max = r_max + self.emb_dim = emb_dim + self.num_layers = num_layers + # Embedding + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + ) + self.emb_in = torch.nn.Embedding(in_dim, emb_dim) + self.W_e = torch.nn.Sequential( + gvp.LayerNorm((self.radial_embedding.out_dim, 1)), + gvp.GVP((self.radial_embedding.out_dim, 1), _DEFAULT_E_DIM, + activations=(None, None), vector_gate=True) + ) + self.W_v = torch.nn.Sequential( + gvp.LayerNorm((emb_dim, 0)), + gvp.GVP((emb_dim, 0), _DEFAULT_V_DIM, + activations=(None, None), vector_gate=True) + ) + + # Stack of GNN layers + self.layers = torch.nn.ModuleList( + gvp.GVPConvLayer(_DEFAULT_V_DIM, _DEFAULT_E_DIM, + activations=activations, vector_gate=True, + residual=residual) + for _ in range(num_layers)) + + self.W_out = torch.nn.Sequential( + gvp.LayerNorm(_DEFAULT_V_DIM), + 
gvp.GVP(_DEFAULT_V_DIM, (emb_dim, 0), + activations=activations, vector_gate=True) + ) + + # Global pooling/readout function + self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool] + + # Predictor MLP + self.pred = torch.nn.Sequential( + torch.nn.Linear(emb_dim, emb_dim), + torch.nn.ReLU(), + torch.nn.Linear(emb_dim, out_dim) + ) + + def forward(self, batch): + + # Edge features + vectors = batch.pos[batch.edge_index[0]] - batch.pos[batch.edge_index[1]] # [n_edges, 3] + lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] + + h_V = self.emb_in(batch.atoms) # (n,) -> (n, d) + h_E = (self.radial_embedding(lengths), torch.nan_to_num(torch.div(vectors, lengths)).unsqueeze_(-2)) + + h_V = self.W_v(h_V) + h_E = self.W_e(h_E) + + for layer in self.layers: + h_V = layer(h_V, batch.edge_index, h_E) + + out = self.W_out(h_V) + + out = self.pool(out, batch.batch) # (n, d) -> (batch_size, d) + return self.pred(out) # (batch_size, out_dim) + + +class EGNNModel(torch.nn.Module): + def __init__( + self, + num_layers=5, + emb_dim=128, + in_dim=1, + out_dim=1, + activation="relu", + norm="layer", + aggr="sum", + pool="sum", + residual=True + ): + """E(n) Equivariant GNN model + + Args: + num_layers: (int) - number of message passing layers + emb_dim: (int) - hidden dimension + in_dim: (int) - initial node feature dimension + out_dim: (int) - output number of classes + activation: (str) - non-linearity within MLPs (swish/relu) + norm: (str) - normalisation layer (layer/batch) + aggr: (str) - aggregation function `\oplus` (sum/mean/max) + pool: (str) - global pooling function (sum/mean) + residual: (bool) - whether to use residual connections + """ + super().__init__() + + # Embedding lookup for initial node features + self.emb_in = torch.nn.Embedding(in_dim, emb_dim) + + # Stack of GNN layers + self.convs = torch.nn.ModuleList() + for layer in range(num_layers): + self.convs.append(EGNNLayer(emb_dim, activation, norm, aggr)) + + # Global pooling/readout function + self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool] + + # Predictor MLP + self.pred = torch.nn.Sequential( + torch.nn.Linear(emb_dim, emb_dim), + torch.nn.ReLU(), + torch.nn.Linear(emb_dim, out_dim) + ) + self.residual = residual + + def forward(self, batch): + + h = self.emb_in(batch.atoms) # (n,) -> (n, d) + pos = batch.pos # (n, 3) + + for conv in self.convs: + # Message passing layer + h_update, pos_update = conv(h, pos, batch.edge_index) + + # Update node features (n, d) -> (n, d) + h = h + h_update if self.residual else h_update + + # Update node coordinates (no residual) (n, 3) -> (n, 3) + pos = pos_update + + out = self.pool(h, batch.batch) # (n, d) -> (batch_size, d) + return self.pred(out) # (batch_size, out_dim) + + +class MPNNModel(torch.nn.Module): + def __init__( + self, + num_layers=5, + emb_dim=128, + in_dim=1, + out_dim=1, + activation="relu", + norm="layer", + aggr="sum", + pool="sum", + residual=True + ): + """Vanilla Message Passing GNN model + + Args: + num_layers: (int) - number of message passing layers + emb_dim: (int) - hidden dimension + in_dim: (int) - initial node feature dimension + out_dim: (int) - output number of classes + activation: (str) - non-linearity within MLPs (swish/relu) + norm: (str) - normalisation layer (layer/batch) + aggr: (str) - aggregation function `\oplus` (sum/mean/max) + pool: (str) - global pooling function (sum/mean) + residual: (bool) - whether to use residual connections + """ + super().__init__() + + # Embedding lookup for initial node 
features + self.emb_in = torch.nn.Embedding(in_dim, emb_dim) + + # Stack of GNN layers + self.convs = torch.nn.ModuleList() + for layer in range(num_layers): + self.convs.append(MPNNLayer(emb_dim, activation, norm, aggr)) + + # Global pooling/readout function + self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool] + + # Predictor MLP + self.pred = torch.nn.Sequential( + torch.nn.Linear(emb_dim, emb_dim), + torch.nn.ReLU(), + torch.nn.Linear(emb_dim, out_dim) + ) + self.residual = residual + + def forward(self, batch): + + h = self.emb_in(batch.atoms) # (n,) -> (n, d) + + for conv in self.convs: + # Message passing layer and residual connection + h = h + conv(h, batch.edge_index) if self.residual else conv(h, batch.edge_index) + + out = self.pool(h, batch.batch) # (n, d) -> (batch_size, d) + return self.pred(out) # (batch_size, out_dim) + + +class SchNetModel(SchNet): + def __init__( + self, + hidden_channels: int = 128, + in_dim: int = 1, + out_dim: int = 1, + num_filters: int = 128, + num_layers: int = 6, + num_gaussians: int = 50, + cutoff: float = 10, + max_num_neighbors: int = 32, + readout: str = 'add', + dipole: bool = False, + mean: Optional[float] = None, + std: Optional[float] = None, + atomref: Optional[torch.Tensor] = None, + ): + super().__init__(hidden_channels, num_filters, num_layers, num_gaussians, cutoff, max_num_neighbors, readout, dipole, mean, std, atomref) + + # Overwrite the final prediction layer to output `out_dim` values + self.lin2 = torch.nn.Linear(hidden_channels // 2, out_dim) + + def forward(self, batch): + h = self.embedding(batch.atoms) + + row, col = batch.edge_index + edge_weight = (batch.pos[row] - batch.pos[col]).norm(dim=-1) + edge_attr = self.distance_expansion(edge_weight) + + for interaction in self.interactions: + h = h + interaction(h, batch.edge_index, edge_weight, edge_attr) + + h = self.lin1(h) + h = self.act(h) + h = self.lin2(h) + + out = scatter(h, batch.batch, dim=0, reduce=self.readout) + return out + + +class DimeNetPPModel(DimeNetPlusPlus): + def __init__( + self, + hidden_channels: int = 128, + in_dim: int = 1, + out_dim: int = 1, + num_layers: int = 4, + int_emb_size: int = 64, + basis_emb_size: int = 8, + out_emb_channels: int = 256, + num_spherical: int = 7, + num_radial: int = 6, + cutoff: float = 10, + max_num_neighbors: int = 32, + envelope_exponent: int = 5, + num_before_skip: int = 1, + num_after_skip: int = 2, + num_output_layers: int = 3, + act: Union[str, Callable] = 'swish' + ): + super().__init__(hidden_channels, out_dim, num_layers, int_emb_size, basis_emb_size, out_emb_channels, num_spherical, num_radial, cutoff, max_num_neighbors, envelope_exponent, num_before_skip, num_after_skip, num_output_layers, act) + + def forward(self, batch): + + i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets( + batch.edge_index, num_nodes=batch.atoms.size(0)) + + # Calculate distances. + dist = (batch.pos[i] - batch.pos[j]).pow(2).sum(dim=-1).sqrt() + + # Calculate angles. + pos_i = batch.pos[idx_i] + pos_ji, pos_ki = batch.pos[idx_j] - pos_i, batch.pos[idx_k] - pos_i + a = (pos_ji * pos_ki).sum(dim=-1) + b = torch.cross(pos_ji, pos_ki).norm(dim=-1) + angle = torch.atan2(b, a) + + rbf = self.rbf(dist) + sbf = self.sbf(dist, angle, idx_kj) + + # Embedding block. + x = self.emb(batch.atoms, rbf, i, j) + P = self.output_blocks[0](x, rbf, i, num_nodes=batch.pos.size(0)) + + # Interaction blocks. 
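+ # Each interaction block refines the edge embeddings using the radial (rbf) + # and spherical (sbf) bases, and its paired output block maps them to + # per-node terms accumulated into P, so every depth contributes to the readout.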
+ for interaction_block, output_block in zip(self.interaction_blocks, + self.output_blocks[1:]): + x = interaction_block(x, rbf, sbf, idx_kj, idx_ji) + P += output_block(x, rbf, i) + + return P.sum(dim=0) if batch is None else scatter(P, batch.batch, dim=0) diff --git a/src/modules/__init__.py b/src/modules/__init__.py new file mode 100644 index 0000000..3769f38 --- /dev/null +++ b/src/modules/__init__.py @@ -0,0 +1,67 @@ +########################################################################################### +# This directory contains an implementation of MACE, with minor adaptations +# +# Paper: MACE: Higher Order Equivariant Message Passing Neural Networks +# for Fast and Accurate Force Fields, Batatia et al. +# +# Original repository: https://github.com/ACEsuit/mace +########################################################################################### + +from typing import Callable, Dict, Optional, Type + +import torch + +from .blocks import ( + AgnosticNonlinearInteractionBlock, + AgnosticResidualNonlinearInteractionBlock, + AtomicEnergiesBlock, + EquivariantProductBasisBlock, + InteractionBlock, + LinearNodeEmbeddingBlock, + LinearReadoutBlock, + NonLinearReadoutBlock, + RadialEmbeddingBlock, + RealAgnosticInteractionBlock, + RealAgnosticResidualInteractionBlock, + ResidualElementDependentInteractionBlock, + ScaleShiftBlock, +) +from .radial import BesselBasis, PolynomialCutoff +from .symmetric_contraction import SymmetricContraction + +interaction_classes: Dict[str, Type[InteractionBlock]] = { + "AgnosticNonlinearInteractionBlock": AgnosticNonlinearInteractionBlock, + "ResidualElementDependentInteractionBlock": ResidualElementDependentInteractionBlock, + "AgnosticResidualNonlinearInteractionBlock": AgnosticResidualNonlinearInteractionBlock, + "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, + "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, +} + +gate_dict: Dict[str, Optional[Callable]] = { + "abs": torch.abs, + "tanh": torch.tanh, + "silu": torch.nn.functional.silu, + "None": None, +} + +# Export only the names actually defined in this trimmed-down module +__all__ = [ + "AtomicEnergiesBlock", + "RadialEmbeddingBlock", + "LinearNodeEmbeddingBlock", + "LinearReadoutBlock", + "NonLinearReadoutBlock", + "EquivariantProductBasisBlock", + "ScaleShiftBlock", + "InteractionBlock", + "AgnosticNonlinearInteractionBlock", + "AgnosticResidualNonlinearInteractionBlock", + "RealAgnosticInteractionBlock", + "RealAgnosticResidualInteractionBlock", + "ResidualElementDependentInteractionBlock", + "PolynomialCutoff", + "BesselBasis", + "SymmetricContraction", + "interaction_classes", + "gate_dict", +] diff --git a/src/modules/blocks.py b/src/modules/blocks.py new file mode 100644 index 0000000..54cbb83 --- /dev/null +++ b/src/modules/blocks.py @@ -0,0 +1,549 @@ +########################################################################################### +# Elementary Blocks for Building O(3) Equivariant Higher Order Message Passing Neural Networks +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the ASL License (see ASL.md) +########################################################################################### + +from abc import ABC, abstractmethod +from typing import Callable, Dict, Optional, Tuple, Union + +import numpy as np +import torch.nn.functional +from e3nn import nn, o3 + +# from mace.tools.scatter import scatter_sum +from torch_scatter import scatter_sum + +from .irreps_tools import ( + linear_out_irreps, + reshape_irreps, + tp_out_irreps_with_instructions, +) +from .radial import 
BesselBasis, PolynomialCutoff +from .symmetric_contraction import SymmetricContraction + + +class LinearNodeEmbeddingBlock(torch.nn.Module): + def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps): + super().__init__() + self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out) + + def forward( + self, node_attrs: torch.Tensor, # [n_nodes, irreps] + ): + return self.linear(node_attrs) + + +class LinearReadoutBlock(torch.nn.Module): + def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps = o3.Irreps("0e")): + super().__init__() + self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + return self.linear(x) # [n_nodes, irreps_out] + + +class NonLinearReadoutBlock(torch.nn.Module): + def __init__( + self, irreps_in: o3.Irreps, MLP_irreps: o3.Irreps, + gate: Optional[Callable], irreps_out: o3.Irreps = o3.Irreps("0e") + ): + super().__init__() + self.hidden_irreps = MLP_irreps + self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps) + self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) + self.linear_2 = o3.Linear(irreps_in=self.hidden_irreps, irreps_out=irreps_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + x = self.non_linearity(self.linear_1(x)) + return self.linear_2(x) # [n_nodes, irreps_out] + + +class AtomicEnergiesBlock(torch.nn.Module): + atomic_energies: torch.Tensor + + def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]): + super().__init__() + assert len(atomic_energies.shape) == 1 + + self.register_buffer( + "atomic_energies", + torch.tensor(atomic_energies, dtype=torch.get_default_dtype()), + ) # [n_elements, ] + + def forward( + self, x: torch.Tensor # one-hot of elements [..., n_elements] + ) -> torch.Tensor: # [..., ] + return torch.matmul(x, self.atomic_energies) + + def __repr__(self): + formatted_energies = ", ".join([f"{x:.4f}" for x in self.atomic_energies]) + return f"{self.__class__.__name__}(energies=[{formatted_energies}])" + + +class RadialEmbeddingBlock(torch.nn.Module): + def __init__(self, r_max: float, num_bessel: int, num_polynomial_cutoff: int): + super().__init__() + self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel) + self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff) + self.out_dim = num_bessel + + def forward( + self, edge_lengths: torch.Tensor, # [n_edges, 1] + ): + bessel = self.bessel_fn(edge_lengths) # [n_edges, n_basis] + cutoff = self.cutoff_fn(edge_lengths) # [n_edges, 1] + return bessel * cutoff # [n_edges, n_basis] + + +class EquivariantProductBasisBlock(torch.nn.Module): + def __init__( + self, + node_feats_irreps: o3.Irreps, + target_irreps: o3.Irreps, + correlation: Union[int, Dict[str, int]], + element_dependent: bool = True, + use_sc: bool = True, + batch_norm: bool = False, + num_elements: Optional[int] = None, + ) -> None: + super().__init__() + + self.use_sc = use_sc + self.symmetric_contractions = SymmetricContraction( + irreps_in=node_feats_irreps, + irreps_out=target_irreps, + correlation=correlation, + element_dependent=element_dependent, + num_elements=num_elements, + ) + # Update linear + self.linear = o3.Linear( + target_irreps, target_irreps, internal_weights=True, shared_weights=True, + ) + self.batch_norm = nn.BatchNorm(target_irreps) if batch_norm else None + + def forward( + self, node_feats: torch.Tensor, sc: Optional[torch.Tensor], node_attrs: 
Optional[torch.Tensor] + ) -> torch.Tensor: + node_feats = self.symmetric_contractions(node_feats, node_attrs) + out = self.linear(node_feats) + if self.batch_norm: + out = self.batch_norm(out) + if self.use_sc: + out = out + sc + return out + + +class InteractionBlock(ABC, torch.nn.Module): + def __init__( + self, + node_attrs_irreps: o3.Irreps, + node_feats_irreps: o3.Irreps, + edge_attrs_irreps: o3.Irreps, + edge_feats_irreps: o3.Irreps, + target_irreps: o3.Irreps, + hidden_irreps: o3.Irreps, + avg_num_neighbors: float, + ) -> None: + super().__init__() + self.node_attrs_irreps = node_attrs_irreps + self.node_feats_irreps = node_feats_irreps + self.edge_attrs_irreps = edge_attrs_irreps + self.edge_feats_irreps = edge_feats_irreps + self.target_irreps = target_irreps + self.hidden_irreps = hidden_irreps + self.avg_num_neighbors = avg_num_neighbors + + self._setup() + + @abstractmethod + def _setup(self) -> None: + raise NotImplementedError + + @abstractmethod + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + +nonlinearities = {1: torch.nn.SiLU(), -1: torch.nn.Tanh()} + + +class TensorProductWeightsBlock(torch.nn.Module): + def __init__(self, num_elements: int, num_edge_feats: int, num_feats_out: int): + super().__init__() + + weights = torch.empty( + (num_elements, num_edge_feats, num_feats_out), + dtype=torch.get_default_dtype(), + ) + torch.nn.init.xavier_uniform_(weights) + self.weights = torch.nn.Parameter(weights) + + def forward( + self, + sender_or_receiver_node_attrs: torch.Tensor, # assumes that the node attributes are one-hot encoded + edge_feats: torch.Tensor, + ): + return torch.einsum( + "be, ba, aek -> bk", edge_feats, sender_or_receiver_node_attrs, self.weights + ) + + def __repr__(self): + return ( + f'{self.__class__.__name__}(shape=({", ".join(str(s) for s in self.weights.shape)}), ' + f"weights={np.prod(self.weights.shape)})" + ) + + +class ResidualElementDependentInteractionBlock(InteractionBlock): + def _setup(self) -> None: + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + self.conv_tp_weights = TensorProductWeightsBlock( + num_elements=self.node_attrs_irreps.num_irreps, + num_edge_feats=self.edge_feats_irreps.num_irreps, + num_feats_out=self.conv_tp.weight_numel, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) + self.irreps_out = self.irreps_out.simplify() + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out + ) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + sender, receiver = edge_index + num_nodes = node_feats.shape[0] + sc = self.skip_tp(node_feats, node_attrs) + node_feats = 
self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(node_attrs[sender], edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + return message + sc # [n_nodes, irreps] + + +class AgnosticNonlinearInteractionBlock(InteractionBlock): + def _setup(self) -> None: + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + 3 * [64] + [self.conv_tp.weight_numel], torch.nn.SiLU(), + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) + self.irreps_out = self.irreps_out.simplify() + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.irreps_out, self.node_attrs_irreps, self.irreps_out + ) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + sender, receiver = edge_index + num_nodes = node_feats.shape[0] + tp_weights = self.conv_tp_weights(edge_feats) + node_feats = self.linear_up(node_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + message = self.skip_tp(message, node_attrs) + return message # [n_nodes, irreps] + + +class AgnosticResidualNonlinearInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + 3 * [64] + [self.conv_tp.weight_numel], torch.nn.SiLU(), + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) + self.irreps_out = self.irreps_out.simplify() + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out + ) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + 
edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + sender, receiver = edge_index + num_nodes = node_feats.shape[0] + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + message = message + sc + return message # [n_nodes, irreps] + + +class RealAgnosticInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + 3 * [64] + [self.conv_tp.weight_numel], torch.nn.SiLU(), + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.irreps_out, self.node_attrs_irreps, self.irreps_out + ) + self.reshape = reshape_irreps(self.irreps_out) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sender, receiver = edge_index + num_nodes = node_feats.shape[0] + + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + message = self.skip_tp(message, node_attrs) + return ( + self.reshape(message), + None, + ) # [n_nodes, channels, (lmax + 1)**2] + + +class RealAgnosticResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + 3 * [64] + [self.conv_tp.weight_numel], torch.nn.SiLU(), + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + 
self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + ) + self.reshape = reshape_irreps(self.irreps_out) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sender, receiver = edge_index + num_nodes = node_feats.shape[0] + + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + +class ScaleShiftBlock(torch.nn.Module): + def __init__(self, scale: float, shift: float): + super().__init__() + self.register_buffer( + "scale", torch.tensor(scale, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + "shift", torch.tensor(shift, dtype=torch.get_default_dtype()) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.scale * x + self.shift + + def __repr__(self): + return ( + f"{self.__class__.__name__}(scale={self.scale:.6f}, shift={self.shift:.6f})" + ) diff --git a/src/modules/cg.py b/src/modules/cg.py new file mode 100644 index 0000000..b1fc237 --- /dev/null +++ b/src/modules/cg.py @@ -0,0 +1,133 @@ +########################################################################################### +# Higher Order Real Clebsch Gordan (based on e3nn by Mario Geiger) +# Authors: Ilyes Batatia +# This program is distributed under the ASL License (see ASL.md) +########################################################################################### + +import collections +from typing import List, Union + +import torch +from e3nn import o3 + +# Based on e3nn + +_TP = collections.namedtuple("_TP", "op, args") +_INPUT = collections.namedtuple("_INPUT", "tensor, start, stop") + + +def _wigner_nj( + irrepss: List[o3.Irreps], + normalization: str = "component", + filter_ir_mid=None, + dtype=None, +): + irrepss = [o3.Irreps(irreps) for irreps in irrepss] + if filter_ir_mid is not None: + filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid] + + if len(irrepss) == 1: + (irreps,) = irrepss + ret = [] + e = torch.eye(irreps.dim, dtype=dtype) + i = 0 + for mul, ir in irreps: + for _ in range(mul): + sl = slice(i, i + ir.dim) + ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])] + i += ir.dim + return ret + + *irrepss_left, irreps_right = irrepss + ret = [] + for ir_left, path_left, C_left in _wigner_nj( + irrepss_left, + normalization=normalization, + filter_ir_mid=filter_ir_mid, + dtype=dtype, + ): + i = 0 + for mul, ir in irreps_right: + for ir_out in ir_left * ir: + if filter_ir_mid is not None and ir_out not in filter_ir_mid: + continue + + C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype) + if normalization == "component": + C *= ir_out.dim ** 0.5 + if normalization == "norm": + C *= ir_left.dim ** 0.5 * ir.dim ** 0.5 + + C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) + C = C.reshape( + ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim + ) + for u in range(mul): + E = torch.zeros( + ir_out.dim, + *(irreps.dim for irreps in irrepss_left), + irreps_right.dim, + dtype=dtype, + ) + sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim) + E[..., sl] = C + ret += [ + ( + ir_out, + _TP( + 
op=(ir_left, ir, ir_out), + args=( + path_left, + _INPUT(len(irrepss_left), sl.start, sl.stop), + ), + ), + E, + ) + ] + i += mul * ir.dim + return sorted(ret, key=lambda x: x[0]) + + +def U_matrix_real( + irreps_in: Union[str, o3.Irreps], + irreps_out: Union[str, o3.Irreps], + correlation: int, + normalization: str = "component", + filter_ir_mid=None, + dtype=None, +): + irreps_out = o3.Irreps(irreps_out) + irrepss = [o3.Irreps(irreps_in)] * correlation + if correlation == 4: + filter_ir_mid = [ + (0, 1), + (1, -1), + (2, 1), + (3, -1), + (4, 1), + (5, -1), + (6, 1), + (7, -1), + (8, 1), + (9, -1), + (10, 1), + (11, -1), + ] + wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype) + current_ir = wigners[0][0] + out = [] + stack = torch.tensor([]) + + for ir, _, base_o3 in wigners: + if ir in irreps_out and ir == current_ir: + stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1) + last_ir = current_ir + elif ir in irreps_out and ir != current_ir: + if len(stack) != 0: + out += [last_ir, stack] + stack = base_o3.squeeze().unsqueeze(-1) + current_ir, last_ir = ir, ir + else: + current_ir = ir + out += [last_ir, stack] + return out diff --git a/src/modules/irreps_tools.py b/src/modules/irreps_tools.py new file mode 100644 index 0000000..09a1758 --- /dev/null +++ b/src/modules/irreps_tools.py @@ -0,0 +1,97 @@ +########################################################################################### +# Elementary tools for handling irreducible representations +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the ASL License (see ASL.md) +########################################################################################### + +from typing import List, Tuple + +import torch +from e3nn import o3 +from e3nn.util.jit import compile_mode + + +# Based on mir-group/nequip +def tp_out_irreps_with_instructions( + irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps +) -> Tuple[o3.Irreps, List]: + trainable = True + + # Collect possible irreps and their instructions + irreps_out_list: List[Tuple[int, o3.Irreps]] = [] + instructions = [] + for i, (mul, ir_in) in enumerate(irreps1): + for j, (_, ir_edge) in enumerate(irreps2): + for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2 + if ir_out in target_irreps: + k = len(irreps_out_list) # instruction index + irreps_out_list.append((mul, ir_out)) + instructions.append((i, j, k, "uvu", trainable)) + + # We sort the output irreps of the tensor product so that we can simplify them + # when they are provided to the second o3.Linear + irreps_out = o3.Irreps(irreps_out_list) + irreps_out, permut, _ = irreps_out.sort() + + # Permute the output indexes of the instructions to match the sorted irreps: + instructions = [ + (i_in1, i_in2, permut[i_out], mode, train) + for i_in1, i_in2, i_out, mode, train in instructions + ] + + return irreps_out, instructions + + +def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps: + # Assuming simplified irreps + irreps_mid = [] + for _, ir_in in irreps: + found = False + + for mul, ir_out in target_irreps: + if ir_in == ir_out: + irreps_mid.append((mul, ir_out)) + found = True + break + + if not found: + raise RuntimeError(f"{ir_in} not in {target_irreps}") + + return o3.Irreps(irreps_mid) + + +@compile_mode("script") +class reshape_irreps(torch.nn.Module): + def __init__(self, irreps: o3.Irreps) -> None: + super().__init__() + self.irreps = irreps + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + ix = 0 + out = [] 
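+ # Slice the flat irreps tensor block-by-block: each (mul, ir) segment of + # width mul * ir.dim is reshaped to [batch, mul, ir.dim], so the channel + # and m-component axes are explicit before the symmetric contraction.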
+ batch, _ = tensor.shape + for mul, ir in self.irreps: + d = ir.dim + field = tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr] + ix += mul * d + field = field.reshape(batch, mul, d) + out.append(field) + return torch.cat(out, dim=-1) + + +def irreps2gate(irreps): + irreps_scalars = [] + irreps_gated = [] + for mul, ir in irreps: + if ir.l == 0 and ir.p == 1: + irreps_scalars.append((mul, ir)) + else: + irreps_gated.append((mul, ir)) + irreps_scalars = o3.Irreps(irreps_scalars).simplify() + irreps_gated = o3.Irreps(irreps_gated).simplify() + if irreps_gated.dim > 0: + ir = '0e' + else: + ir = None + irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated]).simplify() + return irreps_scalars, irreps_gates, irreps_gated diff --git a/src/modules/model.py b/src/modules/model.py new file mode 100644 index 0000000..4595b52 --- /dev/null +++ b/src/modules/model.py @@ -0,0 +1,171 @@ +from typing import Callable, Optional, Type +import torch +from torch_scatter import scatter +from e3nn import o3 + +from src.modules.blocks import ( + EquivariantProductBasisBlock, + InteractionBlock, + LinearNodeEmbeddingBlock, + LinearReadoutBlock, + NonLinearReadoutBlock, + RadialEmbeddingBlock, +) +from src.modules import ( + interaction_classes, + gate_dict +) + + +class OriginalMACEModel(torch.nn.Module): + def __init__( + self, + r_max: float = 10.0, + num_bessel: int = 8, + num_polynomial_cutoff: int = 5, + max_ell: int = 2, + interaction_cls: Type[InteractionBlock] = interaction_classes["RealAgnosticResidualInteractionBlock"], + interaction_cls_first: Type[InteractionBlock] = interaction_classes["RealAgnosticInteractionBlock"], + num_interactions: int = 2, + num_elements: int = 1, + hidden_irreps: o3.Irreps = o3.Irreps("64x0e + 64x1o + 64x2e"), + MLP_irreps: o3.Irreps = o3.Irreps("64x0e"), + irreps_out: o3.Irreps = o3.Irreps("1x0e"), + avg_num_neighbors: int = 1, + correlation: int = 3, + gate: Optional[Callable] = gate_dict["silu"], + num_layers=2, + in_dim=1, + out_dim=1, + ): + super().__init__() + self.r_max = r_max + self.num_elements = num_elements + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + + # Interactions and readout + self.atomic_energies_fn = LinearReadoutBlock(node_feats_irreps, irreps_out) + + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer for proper E0 + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = 
EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation, + element_dependent=True, + num_elements=num_elements, + use_sc=use_sc_first, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append(LinearReadoutBlock(hidden_irreps, irreps_out)) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + hidden_irreps_out = str( + hidden_irreps[0] + ) # Select only scalars for last layer + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation, + element_dependent=True, + num_elements=num_elements, + use_sc=True + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate, irreps_out) + ) + else: + self.readouts.append(LinearReadoutBlock(hidden_irreps, irreps_out)) + + def forward(self, batch): + # MACE expects one-hot-ified input + batch.atoms.unsqueeze_(-1) + shape = batch.atoms.shape[:-1] + (self.num_elements,) + node_attrs = torch.zeros(shape, device=batch.atoms.device).view(shape) + node_attrs.scatter_(dim=-1, index=batch.atoms, value=1) + + # Node embeddings + node_feats = self.node_embedding(node_attrs) + node_e0 = self.atomic_energies_fn(node_feats) + e0 = scatter(node_e0, batch.batch, dim=0, reduce="sum") # [n_graphs, irreps_out] + + # Edge features + vectors = batch.pos[batch.edge_index[0]] - batch.pos[batch.edge_index[1]] # [n_edges, 3] + lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding(lengths) + + # Interactions + energies = [e0] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=node_attrs, + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=batch.edge_index, + ) + node_feats = product( + node_feats=node_feats, sc=sc, node_attrs=node_attrs + ) + node_energies = readout(node_feats).squeeze(-1) # [n_nodes, irreps_out] + energy = scatter(node_energies, batch.batch, dim=0, reduce="sum") # [n_graphs, irreps_out] + energies.append(energy) + + # Sum over energy contributions + contributions = torch.stack(energies, dim=-1) + total_energy = torch.sum(contributions, dim=-1) # [n_graphs, irreps_out] + + return total_energy diff --git a/src/modules/radial.py b/src/modules/radial.py new file mode 100644 index 0000000..c8686a7 --- /dev/null +++ b/src/modules/radial.py @@ -0,0 +1,84 @@ +########################################################################################### +# Radial basis and cutoff +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the ASL License (see ASL.md) +########################################################################################### + +import numpy as np +import torch + + +class BesselBasis(torch.nn.Module): + """ + Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for Molecular Graphs; ICLR 2020. 
+ Equation (7) + """ + + def __init__(self, r_max: float, num_basis=8, trainable=False): + super().__init__() + + bessel_weights = ( + np.pi + / r_max + * torch.linspace( + start=1.0, + end=num_basis, + steps=num_basis, + dtype=torch.get_default_dtype(), + ) + ) + if trainable: + self.bessel_weights = torch.nn.Parameter(bessel_weights) + else: + self.register_buffer("bessel_weights", bessel_weights) + + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + "prefactor", + torch.tensor(np.sqrt(2.0 / r_max), dtype=torch.get_default_dtype()), + ) + + def forward(self, x: torch.Tensor,) -> torch.Tensor: # [..., 1] + numerator = torch.sin(self.bessel_weights * x) # [..., num_basis] + return self.prefactor * (numerator / x) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={len(self.bessel_weights)}, " + f"trainable={self.bessel_weights.requires_grad})" + ) + + +class PolynomialCutoff(torch.nn.Module): + """ + Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for Molecular Graphs; ICLR 2020. + Equation (8) + """ + + p: torch.Tensor + r_max: torch.Tensor + + def __init__(self, r_max: float, p=6): + super().__init__() + self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # yapf: disable + envelope = ( + 1.0 + - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / self.r_max, self.p) + + self.p * (self.p + 2.0) * torch.pow(x / self.r_max, self.p + 1) + - (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.r_max, self.p + 2) + ) + # yapf: enable + + # noinspection PyUnresolvedReferences + return envelope * (x < self.r_max).type(torch.get_default_dtype()) + + def __repr__(self): + return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})" diff --git a/src/modules/symmetric_contraction.py b/src/modules/symmetric_contraction.py new file mode 100644 index 0000000..85da8a0 --- /dev/null +++ b/src/modules/symmetric_contraction.py @@ -0,0 +1,188 @@ +########################################################################################### +# Implementation of the symmetric contraction algorithm presented in the MACE paper +# (Batatia et al, MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields , Eq.10 and 11) +# Authors: Ilyes Batatia +# This program is distributed under the ASL License (see ASL.md) +########################################################################################### + +from typing import Dict, Optional, Union + +import torch +import torch.fx +from e3nn import o3 +from e3nn.util.codegen import CodeGenMixin +from e3nn.util.jit import compile_mode +from opt_einsum import contract + +from .cg import U_matrix_real + + +@compile_mode("script") +class SymmetricContraction(CodeGenMixin, torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + correlation: Union[int, Dict[str, int]], + irrep_normalization: str = "component", + path_normalization: str = "element", + internal_weights: Optional[bool] = None, + shared_weights: Optional[torch.Tensor] = None, + element_dependent: Optional[bool] = None, + num_elements: Optional[int] = None, + ) -> None: + super().__init__() + + if irrep_normalization is None: + irrep_normalization = "component" + + if path_normalization is None: + path_normalization = "element" + + 
assert irrep_normalization in ["component", "norm", "none"] + assert path_normalization in ["element", "path", "none"] + + self.irreps_in = o3.Irreps(irreps_in) + self.irreps_out = o3.Irreps(irreps_out) + + del irreps_in, irreps_out + + if not isinstance(correlation, tuple): + corr = correlation + correlation = {} + for irrep_out in self.irreps_out: + correlation[irrep_out] = corr + + assert shared_weights or not internal_weights + + if internal_weights is None: + internal_weights = True + + if element_dependent is None: + element_dependent = True + + self.internal_weights = internal_weights + self.shared_weights = shared_weights + + del internal_weights, shared_weights + + self.contractions = torch.nn.ModuleDict() + for irrep_out in self.irreps_out: + self.contractions[str(irrep_out)] = Contraction( + irreps_in=self.irreps_in, + irrep_out=o3.Irreps(str(irrep_out.ir)), + correlation=correlation[irrep_out], + internal_weights=self.internal_weights, + element_dependent=element_dependent, + num_elements=num_elements, + weights=self.shared_weights, + ) + + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): + outs = [] + for irrep in self.irreps_out: + outs.append(self.contractions[str(irrep)](x, y)) + return torch.cat(outs, dim=-1) + + +class Contraction(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + irrep_out: o3.Irreps, + correlation: int, + internal_weights: bool = True, + element_dependent: bool = True, + num_elements: Optional[int] = None, + weights: Optional[torch.Tensor] = None, + ) -> None: + super().__init__() + + self.element_dependent = element_dependent + self.num_features = irreps_in.count((0, 1)) + self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in]) + self.correlation = correlation + dtype = torch.get_default_dtype() + for nu in range(1, correlation + 1): + U_matrix = U_matrix_real( + irreps_in=self.coupling_irreps, + irreps_out=irrep_out, + correlation=nu, + dtype=dtype, + )[-1] + self.register_buffer(f"U_matrix_{nu}", U_matrix) + + if element_dependent: + # Tensor contraction equations + self.equation_main = "...ik,ekc,bci,be -> bc..." + self.equation_weighting = "...k,ekc,be->bc..." + self.equation_contract = "bc...i,bci->bc..." + if internal_weights: + # Create weight for product basis + self.weights = torch.nn.ParameterDict({}) + for i in range(1, correlation + 1): + num_params = self.U_tensors(i).size()[-1] + w = torch.nn.Parameter( + torch.randn(num_elements, num_params, self.num_features) + / num_params + ) + self.weights[str(i)] = w + else: + self.register_buffer("weights", weights) + + else: + # Tensor contraction equations + self.equation_main = "...ik,kc,bci -> bc..." + self.equation_weighting = "...k,kc->c..." + self.equation_contract = "bc...i,bci->bc..." 
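+ # In these einsum strings, b indexes nodes, c feature channels, i the + # flattened irrep components, and k the U-matrix basis dimension; the + # element-dependent branch above additionally carries e for the one-hot + # element axis.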
+ if internal_weights: + # Create weight for product basis + self.weights = torch.nn.ParameterDict({}) + for i in range(1, correlation + 1): + num_params = self.U_tensors(i).size()[-1] + w = torch.nn.Parameter( + torch.randn(num_params, self.num_features) / num_params + ) + self.weights[str(i)] = w + else: + self.register_buffer("weights", weights) + + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): + if self.element_dependent: + out = contract( + self.equation_main, + self.U_tensors(self.correlation), + self.weights[str(self.correlation)], + x, + y, + ) # TODO: use optimize library and cuTENSOR # pylint: disable=fixme + for corr in range(self.correlation - 1, 0, -1): + c_tensor = contract( + self.equation_weighting, + self.U_tensors(corr), + self.weights[str(corr)], + y, + ) + c_tensor = c_tensor + out + out = contract(self.equation_contract, c_tensor, x) + + else: + out = contract( + self.equation_main, + self.U_tensors(self.correlation), + self.weights[str(self.correlation)], + x, + ) # TODO: use optimize library and cuTENSOR # pylint: disable=fixme + for corr in range(self.correlation - 1, 0, -1): + c_tensor = contract( + self.equation_weighting, + self.U_tensors(corr), + self.weights[str(corr)], + ) + c_tensor = c_tensor + out + out = contract(self.equation_contract, c_tensor, x) + resize_shape = torch.prod(torch.tensor(out.shape[1:])) + return out.view(out.shape[0], resize_shape) + + def U_tensors(self, nu): + return self._buffers[f"U_matrix_{nu}"] diff --git a/src/tfn_layers.py b/src/tfn_layers.py new file mode 100644 index 0000000..eb2173d --- /dev/null +++ b/src/tfn_layers.py @@ -0,0 +1,89 @@ +import torch +from torch_scatter import scatter + +import e3nn +from e3nn import o3 +from e3nn import nn + +from src.modules.irreps_tools import irreps2gate + + +class TensorProductConvLayer(torch.nn.Module): + def __init__( + self, + in_irreps, + out_irreps, + sh_irreps, + edge_feats_dim, + hidden_dim, + aggr="add", + batch_norm=False, + gate=True + ): + """Tensor Field Network GNN Layer + + Implements a Tensor Field Network equivariant GNN layer for higher-order tensors, using e3nn. + Implementation adapted from: https://github.com/gcorso/DiffDock/ + + Paper: Tensor Field Networks, Thomas, Smidt et al. 
+ + Args: + in_irreps: (e3nn.o3.Irreps) Input irreps dimensions + out_irreps: (e3nn.o3.Irreps) Output irreps dimensions + sh_irreps: (e3nn.o3.Irreps) Spherical harmonic irreps dimensions + edge_feats_dim: (int) Edge feature dimensions + hidden_dim: (int) Hidden dimension of MLP for computing tensor product weights + aggr: (str) Message passing aggregator + batch_norm: (bool) Whether to apply equivariant batch norm + gate: (bool) Whether to apply gated non-linearity + """ + super().__init__() + self.in_irreps = in_irreps + self.out_irreps = out_irreps + self.sh_irreps = sh_irreps + self.edge_feats_dim = edge_feats_dim + self.aggr = aggr + + if gate: + # Optionally apply gated non-linearity + irreps_scalars, irreps_gates, irreps_gated = irreps2gate(o3.Irreps(out_irreps)) + act_scalars = [torch.nn.functional.silu for _, ir in irreps_scalars] + act_gates = [torch.sigmoid for _, ir in irreps_gates] + if irreps_gated.num_irreps == 0: + self.gate = nn.Activation(out_irreps, acts=[torch.nn.functional.silu]) + else: + self.gate = nn.Gate( + irreps_scalars, act_scalars, # scalar + irreps_gates, act_gates, # gates (scalars) + irreps_gated # gated tensors + ) + # Output irreps for the tensor product must be updated + self.out_irreps = out_irreps = self.gate.irreps_in + else: + self.gate = None + + # Tensor product over edges to construct messages + self.tp = o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=False) + + # MLP used to compute weights of tensor product + self.fc = torch.nn.Sequential( + torch.nn.Linear(edge_feats_dim, hidden_dim), + torch.nn.ReLU(), + torch.nn.Linear(hidden_dim, self.tp.weight_numel) + ) + + # Optional equivariant batch norm + self.batch_norm = nn.BatchNorm(out_irreps) if batch_norm else None + + def forward(self, node_attr, edge_index, edge_attr, edge_feat): + src, dst = edge_index + # Compute messages + tp = self.tp(node_attr[dst], edge_attr, self.fc(edge_feat)) + # Aggregate messages + out = scatter(tp, src, dim=0, reduce=self.aggr) + # Optionally apply gated non-linearity and/or batch norm + if self.gate: + out = self.gate(out) + if self.batch_norm: + out = self.batch_norm(out) + return out diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/plot_utils.py b/src/utils/plot_utils.py new file mode 100644 index 0000000..27abd2c --- /dev/null +++ b/src/utils/plot_utils.py @@ -0,0 +1,80 @@ +import numpy as np +from torch_geometric.utils import to_networkx +import matplotlib.pyplot as plt + + +def plot_2d(data, lim=10): + # The graph to visualize + G = to_networkx(data) + pos = data.pos.numpy() + + # Extract node and edge positions from the layout + node_xyz = np.array([pos[v, :2] for v in sorted(G)]) + edge_xyz = np.array([(pos[u, :2], pos[v, :2]) for u, v in G.edges()]) + + # Create the 2D figure + fig = plt.figure() + ax = fig.add_subplot(111) + + # Plot the nodes - alpha is scaled by "depth" automatically + ax.scatter(*node_xyz.T, s=100, c=data.atoms.numpy(), cmap="rainbow") + + # Plot the edges + for vizedge in edge_xyz: + ax.plot(*vizedge.T, color="tab:gray") + + # Turn gridlines off + # ax.grid(False) + + # Suppress tick labels + # for dim in (ax.xaxis, ax.yaxis, ax.zaxis): + # dim.set_ticks([]) + + # Set axes labels and limits + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_xlim([-lim, lim]) + ax.set_ylim([-lim, lim]) + ax.set_aspect('equal', 'box') + + # fig.tight_layout() + plt.show() + + +def plot_3d(data, lim=10): + # The graph to visualize + G = 
to_networkx(data) + pos = data.pos.numpy() + + # Extract node and edge positions from the layout + node_xyz = np.array([pos[v] for v in sorted(G)]) + edge_xyz = np.array([(pos[u], pos[v]) for u, v in G.edges()]) + + # Create the 3D figure + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + + # Plot the nodes - alpha is scaled by "depth" automatically + ax.scatter(*node_xyz.T, s=100, c=data.atoms.numpy(), cmap="rainbow") + + # Plot the edges + for vizedge in edge_xyz: + ax.plot(*vizedge.T, color="tab:gray") + + # Turn gridlines off + # ax.grid(False) + + # Suppress tick labels + # for dim in (ax.xaxis, ax.yaxis, ax.zaxis): + # dim.set_ticks([]) + + # Set axes labels and limits + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_zlabel("z") + ax.set_xlim([-lim, lim]) + ax.set_ylim([-lim, lim]) + ax.set_zlim([-lim, lim]) + + # fig.tight_layout() + plt.show() diff --git a/src/utils/train_utils.py b/src/utils/train_utils.py new file mode 100644 index 0000000..0630b6d --- /dev/null +++ b/src/utils/train_utils.py @@ -0,0 +1,116 @@ +import time +import random +from tqdm import tqdm +import numpy as np +from sklearn.metrics import accuracy_score + +import torch +import torch.nn.functional as F + + +def seed(seed=0): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def train(model, train_loader, optimizer, device): + model.train() + loss_all = 0 + for batch in train_loader: + batch = batch.to(device) + optimizer.zero_grad() + y_pred = model(batch) + loss = F.cross_entropy(y_pred, batch.y) + loss.backward() + loss_all += loss.item() * batch.num_graphs + optimizer.step() + return loss_all / len(train_loader.dataset) + + +def eval(model, loader, device): + model.eval() + y_pred = [] + y_true = [] + for batch in loader: + batch = batch.to(device) + with torch.no_grad(): + y_pred.append(model(batch).detach().cpu()) + y_true.append(batch.y.detach().cpu()) + return accuracy_score( + torch.concat(y_true, dim=0), + np.argmax(torch.concat(y_pred, dim=0), axis=1) + ) * 100 # return percentage + + +def _run_experiment(model, train_loader, val_loader, test_loader, n_epochs=100, verbose=True, device='cpu'): + total_param = 0 + for param in model.parameters(): + total_param += np.prod(list(param.data.size())) + model = model.to(device) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode='max', factor=0.9, patience=25, min_lr=0.00001) + + if verbose: + print(f"Running experiment for {type(model).__name__}.") + # print("\nModel architecture:") + # print(model) + print(f'Total parameters: {total_param}') + print("\nStart training:") + + best_val_acc = None + perf_per_epoch = [] # Track Test/Val performance vs. epoch (for plotting) + t = time.time() + for epoch in range(1, n_epochs+1): + # Train model for one epoch, return avg. training loss + loss = train(model, train_loader, optimizer, device) + + # Evaluate model on validation set + val_acc = eval(model, val_loader, device) + + if best_val_acc is None or val_acc >= best_val_acc: + # Evaluate model on test set if validation metric improves + test_acc = eval(model, test_loader, device) + best_val_acc = val_acc + + # Read the current learning rate before logging (the scheduler below may change it) + lr = optimizer.param_groups[0]['lr'] + if epoch % 10 == 0 and verbose: + print(f'Epoch: {epoch:03d}, LR: {lr:.5f}, Loss: {loss:.5f}, ' + f'Val Acc: {val_acc:.3f}, Test Acc: {test_acc:.3f}') + + perf_per_epoch.append((test_acc, val_acc, epoch, type(model).__name__)) + scheduler.step(val_acc) + + t = time.time() - t + train_time = t + if verbose: + print(f"\nDone! Training took {train_time:.2f}s. Best validation accuracy: {best_val_acc:.3f}, corresponding test accuracy: {test_acc:.3f}.") + + return best_val_acc, test_acc, train_time, perf_per_epoch + + +def run_experiment(model, train_loader, val_loader, test_loader, n_epochs=100, n_times=100, verbose=False, device='cpu'): + print(f"Running experiment for {type(model).__name__} ({device}).") + + best_val_acc_list = [] + test_acc_list = [] + train_time_list = [] + for idx in tqdm(range(n_times)): + seed(idx) # set random seed + best_val_acc, test_acc, train_time, _ = _run_experiment(model, train_loader, val_loader, test_loader, n_epochs, verbose, device) + best_val_acc_list.append(best_val_acc) + test_acc_list.append(test_acc) + train_time_list.append(train_time) + + print(f'\nDone! Averaged over {n_times} runs: \n ' + f'- Training time: {np.mean(train_time_list):.2f}s ± {np.std(train_time_list):.2f}. \n ' + f'- Best validation accuracy: {np.mean(best_val_acc_list):.3f} ± {np.std(best_val_acc_list):.3f}. \n' + f'- Test accuracy: {np.mean(test_acc_list):.1f} ± {np.std(test_acc_list):.1f}. \n')
 + + return best_val_acc_list, test_acc_list, train_time_list
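For reference, the snippet below sketches how these training utilities and the models in `src/models.py` compose end-to-end. It is illustrative only: the `dataset` placeholder and the train/val/test split sizes are assumptions, not part of this diff; the model and `run_experiment` arguments follow the signatures defined above.

```python
# Minimal usage sketch (assumes `dataset` is a list of torch_geometric.data.Data
# objects with `atoms`, `pos`, `edge_index` and `y` fields, as used throughout).
import torch
from torch_geometric.loader import DataLoader

from src.models import EGNNModel
from src.utils.train_utils import run_experiment

train_loader = DataLoader(dataset[:80], batch_size=8, shuffle=True)  # hypothetical split
val_loader = DataLoader(dataset[80:90], batch_size=8)
test_loader = DataLoader(dataset[90:], batch_size=8)

model = EGNNModel(num_layers=5, emb_dim=128, in_dim=1, out_dim=2)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Averages best-validation/test accuracy and training time over `n_times` seeds.
val_accs, test_accs, times = run_experiment(
    model, train_loader, val_loader, test_loader,
    n_epochs=100, n_times=10, device=device,
)
```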