Skip to content

Commit

Permalink
upload source code
Browse files Browse the repository at this point in the history
  • Loading branch information
bojunliu0818 committed Jan 20, 2024
1 parent 859334c commit 0979a0d
Show file tree
Hide file tree
Showing 17 changed files with 1,629 additions and 1 deletion.
38 changes: 38 additions & 0 deletions .github/workflows/build-code.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
name: Build TS-DART source code

on:
push:
branches:
- main
pull_request:
branches:
- main

jobs:
build-code:
runs-on: ${{ matrix.os }}
strategy:
max-parallel: 5
matrix:
python-version: ["3.9","3.10"]
os:
- macOS-latest
- ubuntu-latest
- windows-latest
defaults:
run:
shell: bash -el {0}

steps:
- uses: actions/checkout@v3
- uses: conda-incubator/setup-miniconda@v2
with:
auto-update-conda: true
python-version: ${{ matrix.python-version }}
- name: Install required packages
run: |
python -m pip install --upgrade pip
- name: Pip install
run: |
python -m pip install .
9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Byte-compiled / optimized / DLL files
.idea/
__pycache__/

# Distribution / packaging
build/
dist/
*.egg-info/
*.egg/
21 changes: 21 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Bojun Liu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include LICENSE.txt
69 changes: 68 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,68 @@
# ts-dart
# TS-DART: Transition State identification via Dispersion and vAriational principle Regularized neural neTworks

### Abstract

Identifying transitional states is crucial for understanding protein conformational changes that underlie numerous fundamental biological processes. Markov state models (MSMs) constructed from Molecular Dynamics (MD) simulations have demonstrated considerable success in studying protein conformational changes, which are often associated with rare events transiting over free energy barriers. However, it remains challenging for MSMs to identify the transition states, as they group MD conformations into discrete metastable states and do not provide information on transition states lying at the top of free energy barriers between metastable states. Inspired by recent advances in trustworthy artificial intelligence (AI) for detecting out-of-distribution (OOD) data, we present Transition State identification via Dispersion and vAriational principle Regularized neural neTworks (TS-DART). This deep learning approach effectively detects the transition states from MD simulations using hyperspherical embeddings in the latent space. The key insight of TS-DART is to treat the transition state structures as OOD data, recognizing that the transition states are less populated and exhibit a distributional shift from metastable states. Our TS-DART method offers an end-to-end pipeline for identifying transition states from MD simulations. By introducing a dispersion loss function to regularize the hyperspherical latent space, TS-DART can discern transition state conformations that separate multiple metastable states in an MSM. Furthermore, TS-DART provides hyperspherical latent representations that preserve all relevant kinetic geometries of the original dynamics. We demonstrate the power of TS-DART by applying it to a 2D-potential, alanine dipeptide and the translocation of a DNA motor protein on DNA. In all these systems, TS-DART outperforms previous methods in identifying transition states. As TS-DART integrates the dimensionality reduction, state decomposition, and transition state identification in a unified framework, we anticipate that it will be applicable for studying transition states of protein conformational changes.

### Illustration

![figure](./docs/figs/fig2.png)

## Installation from sources

The source code can be installed with a local clone:

```bash
git clone https://github.com/bojunliu0818/ts-dart.git
```

```bash
python -m pip install ./ts-dart
```
## Quick start

### Start with jupyter notebook

Check this file:

```
./ts-dart/example/quadruple-well-example.ipynb
```

### Start with python script (Linux)

```sh
python ./ts-dart/scripts/train_tsdart.py \
--seed 1 \
--device 'cpu' \
--lag_time 10 \
--encoder_sizes 2 20 20 20 10 2 \
--feat_dim 2 \
--n_states 2 \
--beta 0.01 \
--gamma 1 \
--proto_update_factor 0.5 \
--scaling_temperature 0.1 \
--learning_rate 0.001 \
--pretrain 10 \
--n_epochs 20 \
--train_split 0.9 \
--train_batch_size 1000 \
--data_directory ./ts-dart/data \
--saving_directory .
```

Or
```
sh ./ts-dart/scripts/train_tsdart.sh
```

## Reference

Our codebase builds heavily on
- [https://github.com/deeplearning-wisc/cider](https://github.com/deeplearning-wisc/cider)
- [https://github.com/deeptime-ml/deeptime](https://github.com/deeptime-ml/deeptime)

Thanks for open-sourcing!

[Go to Top](#Abstract)
Binary file added data/quadruple-well.npy
Binary file not shown.
Binary file added docs/figs/fig1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
651 changes: 651 additions & 0 deletions example/quadruple-well-example.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[build-system]
requires = ["setuptools",
"wheel"
]
build-backend = "setuptools.build_meta"
203 changes: 203 additions & 0 deletions scripts/train_tsdart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
import argparse
import pprint
import os
import glob
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn

from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split

from tsdart.utils import set_random_seed
from tsdart.model import TSDART, TSDARTLayer, TSDARTModel, TSDARTEstimator
from tsdart.dataprocessing import Preprocessing

parser = argparse.ArgumentParser(description='Training with TS-DART')

parser.add_argument('--seed', default=1, type=int, help='random seed')
parser.add_argument('--device', default='cpu', type=str, help='train the model with gpu or cpu')

parser.add_argument('--lag_time', type=int, help='the lag time used to create transition pairs', required=True)

parser.add_argument('--encoder_sizes', nargs='+', type=int, help='the size of each layer in TS-DART encoder, the size of the last layer represents feat_dim', required=True)
parser.add_argument('--feat_dim', type=int, help='the dimensionality of latent space ((d-1)-hypersphere)', required=True)
parser.add_argument('--n_states', type=int, help='the number of metastable states to consider', required=True)

parser.add_argument('--beta', default=0.01, type=float, help='the weight of dispersion loss')
parser.add_argument('--gamma', default=1, type=float, help='the radius of hypersphere')
parser.add_argument('--proto_update_factor', default=0.5, type=float, help='the update factor to compute state center vectors in EMA algorithm')
parser.add_argument('--scaling_temperature', default=0.1, type=float, help='the scaling factor in despersion loss')

parser.add_argument('--optimizer', default='Adam', type=str, help='the optimizer to train the model')
parser.add_argument('--learning_rate', default=1e-3, type=float, help='the learning rate to training the model')

parser.add_argument('--pretrain', default=10, type=int, help='the number of pretraining epochs with pure VAMP-2 loss optimization')
parser.add_argument('--n_epochs', default=20, type=int, help='the total number of training epochs with VAMP-2 and dispersion loss optimization')
parser.add_argument('--save_model_interval', default=None, type=int, help='save the model every save_epoch')

parser.add_argument('--train_split', default=0.9, type=float, help='the ratio of training dataset size to full dataset size')
parser.add_argument('--train_batch_size', default=1000, type=int, help='the batch size in training dataloader')
parser.add_argument('--val_batch_size', default=None, type=int, help='the batch size in validation dataloader')

parser.add_argument('--data_directory', type=str, help='the directory storing numpy files of trajectories', required=True)
parser.add_argument('--saving_directory', default='.', type=str, help='the saving directory of training results')

args = parser.parse_args()

state = {k: v for k, v in args._get_kwargs()}

date_time = datetime.now().strftime("%m_%d_%H_%M")

args.name = (f"{date_time}_tsdart_lr_{args.learning_rate}_bsz_{args.train_batch_size}_"
f"lag_time_{args.lag_time}_beta_{args.beta}_feat_dim_{args.feat_dim}_n_states_{args.n_states}_"
f"pretrain_{args.pretrain}_n_epochs_{args.n_epochs}")

args.log_directory = args.saving_directory+"/{name}/logs".format(name=args.name)
args.model_directory = args.saving_directory+"/{name}/checkpoints".format(name=args.name)

if not os.path.exists(args.model_directory):
os.makedirs(args.model_directory)
if not os.path.exists(args.log_directory):
os.makedirs(args.log_directory)

with open(os.path.join(args.log_directory, 'train_args.txt'), 'w') as f:
f.write(pprint.pformat(state))

def main():

device = torch.device(args.device)

data = []
np_name_list = []
for np_name in glob.glob(args.data_directory+'/*.npy'):
data.append(np.load(np_name))
np_name_list.append(np_name.rsplit('/')[-1])

set_random_seed(args.seed)

pre = Preprocessing(dtype=np.float32)
dataset = pre.create_dataset(lag_time=args.lag_time,data=data)

val = int(len(dataset)*(1-args.train_split))
train_data, val_data = random_split(dataset, [len(dataset)-val, val])

loader_train = DataLoader(train_data, batch_size=args.train_batch_size, shuffle=True)
if val == 0:
loader_val = DataLoader(train_data, batch_size=args.train_batch_size, shuffle=False)
else:
if args.val_batch_size is None or args.val_batch_size >= len(val_data):
loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)
else:
loader_val = DataLoader(val_data, batch_size=args.val_batch_size, shuffle=False)

lobe = TSDARTLayer(args.encoder_sizes,n_states=args.n_states)
lobe = lobe.to(device=device)

tsdart = TSDART(lobe=lobe, learning_rate=args.learning_rate, device=device, beta=args.beta, feat_dim=args.feat_dim, n_states=args.n_states,
pretrain=args.pretrain, save_model_interval=args.save_model_interval)
tsdart_model = tsdart.fit(loader_train, n_epochs=args.n_epochs, validation_loader=loader_val).fetch_model()

validation_vamp = tsdart.validation_vamp
validation_dis = tsdart.validation_dis
validation_prototypes = tsdart.validation_prototypes

training_vamp = tsdart.training_vamp
training_dis = tsdart.training_dis

np.save((args.model_directory+'/validation_vamp.npy'),validation_vamp)
np.save((args.model_directory+'/validation_dis.npy'),validation_dis)
np.save((args.model_directory+'/validation_prototypes.npy'),validation_prototypes)

np.save((args.model_directory+'/training_vamp.npy'),training_vamp)
np.save((args.model_directory+'/training_dis.npy'),training_dis)

if args.save_model_interval is None:
torch.save(tsdart_model.lobe.state_dict(), args.model_directory+'/model_{}epochs.pytorch'.format(args.n_epochs))

hypersphere_embs = tsdart_model.transform(data=data,return_type='hypersphere_embs')
metastable_states = tsdart_model.transform(data=data,return_type='states')
softmax_probs = tsdart_model.transform(data=data,return_type='probs')

tsdart_estimator = TSDARTEstimator(tsdart_model)
ood_scores = tsdart_estimator.fit(data).ood_scores
state_centers = tsdart_estimator.fit(data).state_centers

dir1 = args.model_directory+'/model_{}epochs_hypersphere_embs'.format(args.n_epochs)
dir2 = args.model_directory+'/model_{}epochs_metastable_states'.format(args.n_epochs)
dir3 = args.model_directory+'/model_{}epochs_softmax_probs'.format(args.n_epochs)
dir4 = args.model_directory+'/model_{}epochs_ood_scores'.format(args.n_epochs)
dir5 = args.model_directory+'/model_{}epochs_state_centers'.format(args.n_epochs)

if not os.path.exists(dir1):
os.makedirs(dir1)
if not os.path.exists(dir2):
os.makedirs(dir2)
if not os.path.exists(dir3):
os.makedirs(dir3)
if not os.path.exists(dir4):
os.makedirs(dir4)
if not os.path.exists(dir5):
os.makedirs(dir5)

np.save((dir5+'/state_centers.npy'),state_centers)

if len(np_name_list) == 1: ### hypersphere_embs etc. is numpy array
np.save((dir1+'/hypersphere_embs_'+np_name_list[0]),hypersphere_embs)
np.save((dir2+'/metastable_states_'+np_name_list[0]),metastable_states)
np.save((dir3+'/softmax_probs_'+np_name_list[0]),softmax_probs)
np.save((dir4+'/ood_scores_'+np_name_list[0]),ood_scores)
else:
for k in range(len(np_name_list)): ### hypersphere_embs etc. is list of numpy arrays
np.save((dir1+'/hypersphere_embs_'+np_name_list[k]),hypersphere_embs[k])
np.save((dir2+'/metastable_states_'+np_name_list[k]),metastable_states[k])
np.save((dir3+'/softmax_probs_'+np_name_list[k]),softmax_probs[k])
np.save((dir4+'/ood_scores_'+np_name_list[k]),ood_scores[k])

else:
for i in range(len(tsdart._save_models)):
torch.save(tsdart._save_models[i].lobe.state_dict(), args.model_directory+'/model_{}epochs.pytorch'.format((i+1)*args.save_model_interval))

hypersphere_embs = tsdart._save_models[i].transform(data=data,return_type='hypersphere_embs')
metastable_states = tsdart._save_models[i].transform(data=data,return_type='states')
softmax_probs = tsdart._save_models[i].transform(data=data,return_type='probs')

tsdart_estimator = TSDARTEstimator(tsdart._save_models[i])
ood_scores = tsdart_estimator.fit(data).ood_scores
state_centers = tsdart_estimator.fit(data).state_centers

dir1 = args.model_directory+'/model_{}epochs_hypersphere_embs'.format((i+1)*args.save_model_interval)
dir2 = args.model_directory+'/model_{}epochs_metastable_states'.format((i+1)*args.save_model_interval)
dir3 = args.model_directory+'/model_{}epochs_softmax_probs'.format((i+1)*args.save_model_interval)
dir4 = args.model_directory+'/model_{}epochs_ood_scores'.format((i+1)*args.save_model_interval)
dir5 = args.model_directory+'/model_{}epochs_state_centers'.format((i+1)*args.save_model_interval)

if not os.path.exists(dir1):
os.makedirs(dir1)
if not os.path.exists(dir2):
os.makedirs(dir2)
if not os.path.exists(dir3):
os.makedirs(dir3)
if not os.path.exists(dir4):
os.makedirs(dir4)
if not os.path.exists(dir5):
os.makedirs(dir5)

np.save((dir5+'/state_centers.npy'),state_centers)

if len(np_name_list) == 1: ### hypersphere_embs etc. is numpy array
np.save((dir1+'/hypersphere_embs_'+np_name_list[0]),hypersphere_embs)
np.save((dir2+'/metastable_states_'+np_name_list[0]),metastable_states)
np.save((dir3+'/softmax_probs_'+np_name_list[0]),softmax_probs)
np.save((dir4+'/ood_scores_'+np_name_list[0]),ood_scores)
else:
for k in range(len(np_name_list)): ### hypersphere_embs etc. is list of numpy arrays
np.save((dir1+'/hypersphere_embs_'+np_name_list[k]),hypersphere_embs[k])
np.save((dir2+'/metastable_states_'+np_name_list[k]),metastable_states[k])
np.save((dir3+'/softmax_probs_'+np_name_list[k]),softmax_probs[k])
np.save((dir4+'/ood_scores_'+np_name_list[k]),ood_scores[k])

if __name__ == '__main__':
main()
18 changes: 18 additions & 0 deletions scripts/train_tsdart.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
python ./ts-dart/scripts/train_tsdart.py \
--seed 1 \
--device 'cpu' \
--lag_time 10 \
--encoder_sizes 2 20 20 20 10 2 \
--feat_dim 2 \
--n_states 2 \
--beta 0.01 \
--gamma 1 \
--proto_update_factor 0.5 \
--scaling_temperature 0.1 \
--learning_rate 0.001 \
--pretrain 10 \
--n_epochs 20 \
--train_split 0.9 \
--train_batch_size 1000 \
--data_directory ./ts-dart/data \
--saving_directory .
Loading

0 comments on commit 0979a0d

Please sign in to comment.