This repository provides a framework for training teacher and student models for molecular property prediction using knowledge distillation (KD). The implementation supports the QM9, ESOL, and FreeSolv datasets and uses graph neural network (GNN) architectures such as SchNet, DimeNet++, and TensorNet.
- Train teacher models on QM9 (first 5 targets).
- Train student models, with or without knowledge distillation, on 10 different QM9 targets and on two experimental MoleculeNet datasets (ESOL/FreeSolv).
- Supports SchNet, DimeNet++, and TensorNet models.
- Implements an uncertainty-weighted combination of L1 loss and cosine-similarity loss for our regression-based KD approach (a minimal sketch follows this list).
- Hyperparameter tuning using Optuna.
- Early stopping & model checkpointing.
- Logging of metrics and best models.
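
The uncertainty-weighted combination balances the two loss terms with learnable weights, in the spirit of Kendall et al.'s multi-task uncertainty weighting. Below is a minimal sketch of that idea, assuming learnable log-variance parameters; the repo's actual implementation lives in `training/total_loss.py` and may differ in details:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UncertaintyWeightedKDLoss(nn.Module):
    """Sketch: uncertainty-weighted sum of an L1 regression loss and a
    cosine-similarity distillation loss (names and weighting scheme are
    assumptions; see training/total_loss.py for the real version)."""

    def __init__(self):
        super().__init__()
        # One learnable log-variance per loss term (assumption).
        self.log_var_l1 = nn.Parameter(torch.zeros(1))
        self.log_var_cos = nn.Parameter(torch.zeros(1))

    def forward(self, student_pred, target, student_emb, teacher_emb):
        # Hard-label regression loss on the property values.
        l1 = F.l1_loss(student_pred, target)
        # Distillation term: align student and teacher embeddings.
        cos = (1.0 - F.cosine_similarity(student_emb, teacher_emb, dim=-1)).mean()
        # Each term is scaled by exp(-log_var) and regularized by log_var,
        # so neither weight can collapse to zero.
        return (torch.exp(-self.log_var_l1) * l1 + self.log_var_l1
                + torch.exp(-self.log_var_cos) * cos + self.log_var_cos)
```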
Knowledge_distillation_for_molecules/
│
├── architectures/              # Model architectures
│   ├── schnet.py               # SchNet architectures
│   ├── dimenetpp.py            # DimeNet++ architectures
│   └── tensornet.py            # TensorNet architectures
├── training/                   # Training scripts
│   ├── train_teacher.py        # Train teacher models
│   ├── train_student.py        # Train student models (with/without KD)
│   ├── total_loss.py           # Loss function (L1 + cosine similarity)
│   └── optimize_hyperparams.py # Optuna-based hyperparameter tuning ✅
├── data/                       # Data processing
│   └── data_loader.py          # Loads QM9, ESOL, FreeSolv datasets
├── utils/                      # Helper functions
│   ├── train_utils.py          # Metrics, reproducibility
│   └── config.py               # Centralized hyperparameters
├── results/                    # Logs, saved models (created during checkpointing)
├── README.md                   # Documentation
└── requirements.txt            # Dependencies
To set up the environment and install dependencies, including torchmd-net:
mamba create -n molecular-learning python=3.9
mamba activate molecular-learning
mamba install pytorch torchvision torchaudio -c pytorch
mamba install pyg -c pyg
mamba install torchmd-net -c conda-forge
pip install -r requirements.txt
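To quickly verify the installation (module import names `torch_geometric` and `torchmdnet` are the usual ones, assumed here):
python -c "import torch, torch_geometric, torchmdnet; print(torch.__version__)"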
Run the following command to train a teacher model:
python training/train_teacher.py
This trains a SchNet, DimeNet++, or TensorNet teacher model on the first five QM9 targets.
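
For reference, the first five QM9 targets can be selected with PyTorch Geometric as sketched below; this is illustrative and not necessarily how `data/data_loader.py` does it:

```python
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader

# Download/load QM9; each data.y holds 19 regression targets per molecule.
dataset = QM9(root="data/qm9")

# Keep only the first 5 targets, matching the teacher setup (illustrative).
class FirstFiveTargets:
    def __call__(self, data):
        data.y = data.y[:, :5]
        return data

dataset.transform = FirstFiveTargets()
loader = DataLoader(dataset, batch_size=32, shuffle=True)
```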
To train the student model, run:
python training/train_student.py
- With KD: the script uses the pre-trained teacher model.
- Without KD: set `USE_KD = False` in `config.py`.
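
For reference, the flags this README refers to might look as follows in `utils/config.py` (values are illustrative; only the names `USE_KD`, `PERFORM_TUNING`, and `LOG_PATH` come from this README):

```python
# utils/config.py (illustrative defaults)
USE_KD = True              # False -> train the student without distillation
PERFORM_TUNING = True      # False -> reuse hyperparameters from config_optimized.json
LOG_PATH = "results/training_log.csv"  # metrics destination (path assumed)
```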
To tune learning rate, batch size, and alpha using Optuna, run:
python training/optimize_hyperparams.py
- The best hyperparameters are saved to `config_optimized.json`.
- To apply them automatically, ensure `PERFORM_TUNING = False` in `config.py`.
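
Under the hood, an Optuna study of roughly this shape is typical; a minimal sketch (search ranges and the `train_and_validate` helper are assumptions, not the repo's `optimize_hyperparams.py`):

```python
import json
import optuna

def objective(trial):
    # Search spaces are illustrative; the repo's script defines its own ranges.
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64, 128])
    alpha = trial.suggest_float("alpha", 0.0, 1.0)  # KD mixing weight
    # Hypothetical helper: trains a student and returns validation loss.
    return train_and_validate(lr=lr, batch_size=batch_size, alpha=alpha)

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=50)

# Persist the best hyperparameters, as described above.
with open("config_optimized.json", "w") as f:
    json.dump(study.best_params, f, indent=2)
```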
- Best models are saved in `results/`.
- Metrics (train loss, validation loss, R²) are logged to the path set by `config.LOG_PATH`.
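
Loading a saved checkpoint then follows the usual PyTorch pattern (the class and file names below are assumptions):

```python
import torch
from architectures.schnet import SchNet  # class name assumed; see architectures/schnet.py

model = SchNet()  # constructor arguments omitted for brevity
state = torch.load("results/best_student.pt", map_location="cpu")  # file name assumed
model.load_state_dict(state)
model.eval()
```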
- Pretrained teacher and student models for QM9, ESOL, and FreeSolv can be downloaded here: pretrained models.
- If you find this repository useful, please cite our work:
@article{Sheshanarayana2025,
author = {R. Sheshanarayana and Fengqi You},
title = {Knowledge Distillation for Molecular Property Prediction: A Scalability Analysis},
journal = {Advanced Science},
year = {2025},
volume = {n/a},
pages = {2503271},
doi = {10.1002/advs.202503271}
}
Feel free to contribute or raise issues! 🚀