1
+ from os import listdir
2
+ from os .path import join , isdir , isfile
3
+ import numpy as np
4
+ import argparse
5
+ import os
6
+ import torch
7
+ import torch .nn as nn
8
+ from transformers import BertModel , BertTokenizer , BertForMaskedLM
9
+ import glob
10
+ import time
11
+ import re
12
+
13
+ def parsePDB (filename , atom = "CA" ):
14
+ file = open (filename , "r" )
15
+ lines = file .readlines ()
16
+ coords = []
17
+ aas = []
18
+
19
+ cur_resdex = - 1
20
+ aa = ""
21
+ for line in lines :
22
+ if "ATOM" in line :
23
+ if cur_resdex != int (line [22 :26 ]):
24
+ cur_resdex = int (line [22 :26 ])
25
+ new_res = True
26
+ aa = line [17 :20 ]
27
+ aas .append (aa )
28
+ if atom == "CA" and " CA " == line [12 :16 ]:
29
+ xyz = [float (line [30 :38 ]), float (line [38 :46 ]), float (line [46 :54 ])]
30
+ coords .append (xyz )
31
+ elif atom == "CB" :
32
+ if aa == "GLY" and " CA " == line [12 :16 ]:
33
+ xyz = [float (line [30 :38 ]), float (line [38 :46 ]), float (line [46 :54 ])]
34
+ coords .append (xyz )
35
+ elif " CB " == line [12 :16 ]:
36
+ xyz = [float (line [30 :38 ]), float (line [38 :46 ]), float (line [46 :54 ])]
37
+ coords .append (xyz )
38
+ return np .array (coords ), aas
39
+
40
+ ####################
41
+ # INDEXERS/MAPPERS
42
+ ####################
43
+ # Assigning numbers to 3 letter amino acids.
44
+ residues = ['ALA' , 'ARG' , 'ASN' , 'ASP' , 'CYS' , 'GLN' , 'GLU' ,\
45
+ 'GLY' , 'HIS' , 'ILE' , 'LEU' , 'LYS' , 'MET' , 'PHE' ,\
46
+ 'PRO' , 'SER' , 'THR' , 'TRP' , 'TYR' , 'VAL' ]
47
+ residuemap = dict ([(residues [i ], i ) for i in range (len (residues ))])
48
+
49
+ # Mapping 3 letter AA to 1 letter AA (e.g. ALA to A)
50
+ oneletter = ["A" , "R" , "N" , "D" , "C" , \
51
+ "Q" , "E" , "G" , "H" , "I" , \
52
+ "L" , "K" , "M" , "F" , "P" , \
53
+ "S" , "T" , "W" , "Y" , "V" ]
54
+ aanamemap = dict ([(residues [i ], oneletter [i ]) for i in range (len (residues ))])
55
+
56
+ def parse_fasta (filename ,limit = - 1 ):
57
+ '''function to parse fasta'''
58
+ header = []
59
+ sequence = []
60
+ lines = open (filename , "r" )
61
+ for line in lines :
62
+ line = line .rstrip ()
63
+ if line [0 ] == ">" :
64
+ if len (header ) == limit :
65
+ break
66
+ header .append (line [1 :])
67
+ sequence .append ([])
68
+ else :
69
+ sequence [- 1 ].append (line )
70
+ lines .close ()
71
+ sequence = ['' .join (seq ) for seq in sequence ]
72
+ return np .array (header ), np .array (sequence )
73
+
74
+ def main ():
75
+ #####################
76
+ # Parsing arguments
77
+ #####################
78
+ parser = argparse .ArgumentParser (description = "ProtBert embedding generator" ,
79
+ epilog = "v0.0.1" )
80
+ parser .add_argument ("input" ,
81
+ action = "store" ,
82
+ help = "path to input folder" )
83
+
84
+ parser .add_argument ("output" ,
85
+ action = "store" ,
86
+ help = "path to output folder" )
87
+
88
+ parser .add_argument ("--modelpath" ,
89
+ "-modelpath" ,
90
+ action = "store" ,
91
+ default = '/home/justas/Desktop/my_projects/python_runs/models/ProtBert-BFD/' ,
92
+ help = "modelpath (default: /home/justas/Desktop/my_projects/python_runs/models/ProtBert-BFD/" )
93
+
94
+ args = parser .parse_args ()
95
+
96
+ if not isdir (args .output ):
97
+ os .mkdir (args .output )
98
+
99
+ pdbfiles = [i for i in listdir (args .input ) if i .endswith (".pdb" )]
100
+
101
+ for pdbfile in pdbfiles :
102
+ try :
103
+ coords , aas = parsePDB (join (args .input , pdbfile ))
104
+ output = ">" + pdbfile [:- 4 ]+ "\n "
105
+ output += "" .join ([aanamemap [i ] for i in aas ])+ "\n "
106
+ f = open (join (args .output , pdbfile [:- 4 ]+ ".fa" ), "w" )
107
+ f .write (output )
108
+ f .close ()
109
+ except :
110
+ print (pdbfile )
111
+
112
+
113
+ downloadFolderPath = args .modelpath
114
+ modelFolderPath = downloadFolderPath
115
+ modelFilePath = os .path .join (modelFolderPath , 'pytorch_model.bin' )
116
+ configFilePath = os .path .join (modelFolderPath , 'config.json' )
117
+ vocabFilePath = os .path .join (modelFolderPath , 'vocab.txt' )
118
+
119
+ tokenizer = BertTokenizer (vocabFilePath , do_lower_case = False )
120
+
121
+ model = BertForMaskedLM .from_pretrained (modelFolderPath , output_attentions = True )
122
+ device = torch .device ('cuda:0' if torch .cuda .is_available () else 'cpu' )
123
+ model = model .to (device )
124
+ model = model .eval ()
125
+
126
+ INPUT_PATH = args .input
127
+ OUTPUT_PATH = args .output
128
+
129
+ file_list = glob .glob (join (OUTPUT_PATH , "*.fa" ))
130
+ protein_names = []
131
+ for i in file_list :
132
+ name_1 = i .split ("/" )[- 1 ]
133
+ protein_names .append (name_1 [:- 3 ])
134
+
135
+ start = time .time ()
136
+ for i in range (len (protein_names )):
137
+ if i % 100 == 0 :
138
+ print (100 * (i + 1 )/ len (protein_names ))
139
+ a , b = parse_fasta (join (OUTPUT_PATH , f"{ protein_names [i ]} .fa" ))
140
+ sequences_Example = [b [0 ].replace ("" , " " )[1 : - 1 ]]
141
+ sequences_Example = [re .sub (r"[UZOB]" , "X" , sequence ) for sequence in sequences_Example ]
142
+ ids = tokenizer .batch_encode_plus (sequences_Example , add_special_tokens = True , pad_to_max_length = True )
143
+ input_ids = torch .tensor (ids ['input_ids' ]).to (device )
144
+ attention_mask = torch .tensor (ids ['attention_mask' ]).to (device )
145
+
146
+ with torch .no_grad ():
147
+ Z_out = model (input_ids = input_ids , attention_mask = attention_mask )
148
+
149
+ last_layer_attn = np .array ((Z_out [1 ][- 1 ].cpu ().detach ().numpy ())[0 ,:,1 :- 1 ,1 :- 1 ], np .float32 )
150
+
151
+ np .save (join (OUTPUT_PATH , f'bert_{ protein_names [i ]} .npy' ), last_layer_attn )
152
+ print (f'total runtime: { time .time ()- start } seconds' )
153
+
154
+
155
+ if __name__ == "__main__" :
156
+ main ()
0 commit comments