-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain.py
66 lines (53 loc) · 1.98 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
Adapted from Nakata, S., Mori, Y. & Tanaka, S.
End-to-end protein–ligand complex structure generation with diffusion-based generative models.
BMC Bioinformatics 24, 233 (2023).
https://doi.org/10.1186/s12859-023-05354-5
Repository: https://github.com/shuyana/DiffusionProteinLigand
"""
import os
import warnings
from argparse import ArgumentParser
from pathlib import Path
from shutil import rmtree
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from ProteinReDiff.data import PDBDataModule
from ProteinReDiff.model import ProteinReDiffModel
def _is_rank_zero():
    """Return True unless a launcher environment variable marks this process as a non-zero rank.

    Distributed launchers (torchrun, Lightning's DDP subprocess launcher,
    SLURM, jsrun) may export one or more of these variables. A non-zero value
    in any of them means this process is not the global rank-zero process.
    When none are set (a plain single-process run), the process is treated as
    rank zero.
    """
    rank_vars = ("RANK", "LOCAL_RANK", "NODE_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    return all(int(os.environ.get(name, "0")) == 0 for name in rank_vars)


def main(args):
    """Train a ProteinReDiffModel on a PDBDataModule using parsed CLI arguments.

    The output directory ``args.save_dir`` is deleted and recreated so that
    each run starts empty. It is then used as the Trainer's
    ``default_root_dir``, with a ModelCheckpoint callback monitoring
    ``val_loss``.

    Args:
        args: ``argparse.Namespace`` built by this script's parser. It must
            provide ``seed`` and ``save_dir`` (a ``Path`` or ``str``) plus the
            flags consumed by ``PDBDataModule.from_argparse_args``,
            ``ProteinReDiffModel`` and ``pl.Trainer.from_argparse_args``.

    Raises:
        NotADirectoryError: if ``args.save_dir`` exists but is not a directory.
    """
    pl.seed_everything(args.seed, workers=True)

    # The CLI supplies a Path; also accept a plain string for programmatic calls.
    save_dir = Path(args.save_dir)

    # With a DDP strategy, Lightning re-launches this script for the extra
    # processes. Without this guard, every one of them would rmtree the
    # directory rank 0 has already created (and may be writing to), and the
    # ranks could race each other. Only the rank-zero process wipes it.
    if _is_rank_zero():
        if save_dir.is_dir():
            rmtree(save_dir)
        elif save_dir.exists():
            raise NotADirectoryError(f"save_dir exists and is not a directory: {save_dir}")
    # exist_ok: non-zero ranks may find the directory already recreated by rank 0.
    save_dir.mkdir(parents=True, exist_ok=True)

    datamodule = PDBDataModule.from_argparse_args(args)
    model = ProteinReDiffModel(args)

    # NOTE: keyword arguments passed here override same-named values in
    # `args`, so --accelerator/--precision/--strategy/--max_epochs/
    # --default_root_dir given on the command line have no effect.
    trainer = pl.Trainer.from_argparse_args(
        args,
        accelerator="auto",
        precision=16,
        strategy="ddp_find_unused_parameters_false",
        callbacks=[
            ModelCheckpoint(
                filename="{epoch:03d}-{val_loss:.2f}",
                monitor="val_loss",
                save_top_k=3,
                save_last=True,
            )
        ],
        default_root_dir=save_dir,
        max_epochs=-1,  # no epoch limit; training stops only via other Trainer limits or interruption
    )
    trainer.fit(model, datamodule=datamodule)
if __name__ == "__main__":
    # Build the CLI: the data module, the model and the Lightning Trainer each
    # register their own flags, in this order, before the script-level ones.
    cli = ArgumentParser()
    for component in (PDBDataModule, ProteinReDiffModel, pl.Trainer):
        cli = component.add_argparse_args(cli)

    cli.add_argument("--seed", type=int, default=1234)
    # NOTE(review): --num_gpus is not read anywhere in this script; presumably
    # it is consumed through `args` by ProteinReDiffModel or PDBDataModule.
    # Verify before removing.
    cli.add_argument("--num_gpus", type=int, default=1)
    cli.add_argument("--save_dir", type=Path, required=True)
    parsed = cli.parse_args()

    # Suppress UserWarnings whose message starts with "Detected call of".
    # Background:
    # https://github.com/Lightning-AI/lightning/issues/5558#issuecomment-1199306489
    warnings.filterwarnings(action="ignore", message="Detected call of", category=UserWarning)

    main(parsed)