
Commit ee3e8cd

Authored and committed by Naozumi Hiranuma
extractBert code is added
1 parent 6f6ef5e commit ee3e8cd

File tree

1 file changed: +156 -0 lines changed


extractBert.py

@@ -0,0 +1,156 @@
# Extracts ProtBert-BFD last-layer attention maps for each PDB structure in the input folder.
from os import listdir
from os.path import join, isdir, isfile
import numpy as np
import argparse
import os
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer, BertForMaskedLM
import glob
import time
import re


def parsePDB(filename, atom="CA"):
    '''Read CA (or CB) coordinates and 3-letter residue names from a PDB file.'''
    file = open(filename, "r")
    lines = file.readlines()
    coords = []
    aas = []

    cur_resdex = -1
    aa = ""
    for line in lines:
        if "ATOM" in line:
            if cur_resdex != int(line[22:26]):
                cur_resdex = int(line[22:26])
                new_res = True
                aa = line[17:20]
                aas.append(aa)
            if atom == "CA" and " CA " == line[12:16]:
                xyz = [float(line[30:38]), float(line[38:46]), float(line[46:54])]
                coords.append(xyz)
            elif atom == "CB":
                # Glycine has no CB, so fall back to its CA.
                if aa == "GLY" and " CA " == line[12:16]:
                    xyz = [float(line[30:38]), float(line[38:46]), float(line[46:54])]
                    coords.append(xyz)
                elif " CB " == line[12:16]:
                    xyz = [float(line[30:38]), float(line[38:46]), float(line[46:54])]
                    coords.append(xyz)
    return np.array(coords), aas


####################
# INDEXERS/MAPPERS
####################
# Assigning numbers to 3-letter amino acids.
residues = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU',
            'GLY', 'HIS', 'ILE', 'LEU', 'LYS', 'MET', 'PHE',
            'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL']
residuemap = dict([(residues[i], i) for i in range(len(residues))])

# Mapping 3-letter AA to 1-letter AA (e.g. ALA to A)
oneletter = ["A", "R", "N", "D", "C",
             "Q", "E", "G", "H", "I",
             "L", "K", "M", "F", "P",
             "S", "T", "W", "Y", "V"]
aanamemap = dict([(residues[i], oneletter[i]) for i in range(len(residues))])


def parse_fasta(filename, limit=-1):
    '''Parse a FASTA file into arrays of headers and sequences.'''
    header = []
    sequence = []
    lines = open(filename, "r")
    for line in lines:
        line = line.rstrip()
        if line[0] == ">":
            if len(header) == limit:
                break
            header.append(line[1:])
            sequence.append([])
        else:
            sequence[-1].append(line)
    lines.close()
    sequence = [''.join(seq) for seq in sequence]
    return np.array(header), np.array(sequence)


def main():
    #####################
    # Parsing arguments
    #####################
    parser = argparse.ArgumentParser(description="ProtBert embedding generator",
                                     epilog="v0.0.1")
    parser.add_argument("input",
                        action="store",
                        help="path to input folder")

    parser.add_argument("output",
                        action="store",
                        help="path to output folder")

    parser.add_argument("--modelpath",
                        "-modelpath",
                        action="store",
                        default='/home/justas/Desktop/my_projects/python_runs/models/ProtBert-BFD/',
                        help="modelpath (default: /home/justas/Desktop/my_projects/python_runs/models/ProtBert-BFD/)")

    args = parser.parse_args()

    if not isdir(args.output):
        os.mkdir(args.output)

    # Write one FASTA file per input PDB, using the residue order read from the structure.
    pdbfiles = [i for i in listdir(args.input) if i.endswith(".pdb")]

    for pdbfile in pdbfiles:
        try:
            coords, aas = parsePDB(join(args.input, pdbfile))
            output = ">" + pdbfile[:-4] + "\n"
            output += "".join([aanamemap[i] for i in aas]) + "\n"
            f = open(join(args.output, pdbfile[:-4] + ".fa"), "w")
            f.write(output)
            f.close()
        except Exception:
            # Report PDB files that could not be parsed and keep going.
            print(pdbfile)

    downloadFolderPath = args.modelpath
    modelFolderPath = downloadFolderPath
    modelFilePath = os.path.join(modelFolderPath, 'pytorch_model.bin')
    configFilePath = os.path.join(modelFolderPath, 'config.json')
    vocabFilePath = os.path.join(modelFolderPath, 'vocab.txt')

    tokenizer = BertTokenizer(vocabFilePath, do_lower_case=False)

    model = BertForMaskedLM.from_pretrained(modelFolderPath, output_attentions=True)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model = model.eval()

    INPUT_PATH = args.input
    OUTPUT_PATH = args.output

    file_list = glob.glob(join(OUTPUT_PATH, "*.fa"))
    protein_names = []
    for i in file_list:
        name_1 = i.split("/")[-1]
        protein_names.append(name_1[:-3])

    start = time.time()
    for i in range(len(protein_names)):
        if i % 100 == 0:
            print(100 * (i + 1) / len(protein_names))
        # ProtBert expects space-separated residues; map rare residues (U, Z, O, B) to X.
        a, b = parse_fasta(join(OUTPUT_PATH, f"{protein_names[i]}.fa"))
        sequences_Example = [b[0].replace("", " ")[1:-1]]
        sequences_Example = [re.sub(r"[UZOB]", "X", sequence) for sequence in sequences_Example]
        ids = tokenizer.batch_encode_plus(sequences_Example, add_special_tokens=True, pad_to_max_length=True)
        input_ids = torch.tensor(ids['input_ids']).to(device)
        attention_mask = torch.tensor(ids['attention_mask']).to(device)

        with torch.no_grad():
            Z_out = model(input_ids=input_ids, attention_mask=attention_mask)

        # Last-layer attention maps with the [CLS]/[SEP] rows and columns trimmed off.
        last_layer_attn = np.array((Z_out[1][-1].cpu().detach().numpy())[0, :, 1:-1, 1:-1], np.float32)

        np.save(join(OUTPUT_PATH, f'bert_{protein_names[i]}.npy'), last_layer_attn)
    print(f'total runtime: {time.time()-start} seconds')


if __name__ == "__main__":
    main()
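
For context, a minimal sketch of how the saved features might be consumed downstream. The folder and protein name below are hypothetical placeholders, and the script above is assumed to have been run first, e.g. `python extractBert.py pdbs/ features/ --modelpath /path/to/ProtBert-BFD/`:

import numpy as np
from os.path import join

OUTPUT_PATH = "features"   # hypothetical output folder passed to extractBert.py
protein_name = "1abc"      # hypothetical protein ID (taken from <name>.pdb)

# Each .npy file stores the last-layer self-attention maps for one protein,
# shaped (num_heads, L, L): batch element 0 is kept and the first and last
# token positions ([CLS]/[SEP]) are dropped by the indexing in the script.
attn = np.load(join(OUTPUT_PATH, f"bert_{protein_name}.npy"))
print(attn.shape)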
