-
Notifications
You must be signed in to change notification settings - Fork 0
/
contacts_network.py
491 lines (446 loc) · 19.6 KB
/
contacts_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Network for predicting C-beta contacts."""
from absl import logging
import sonnet
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from alphafold_casp13 import asa_output
from alphafold_casp13 import secstruct
from alphafold_casp13 import two_dim_convnet
from alphafold_casp13 import two_dim_resnet
def call_on_tuple(f):
  """Adapts `f` so that it accepts one tuple instead of separate arguments.

  Python 3 removed tuple-parameter unpacking in lambdas, so a lambda can no
  longer destructure its single argument. Wrapping a multi-argument function
  with this helper restores that convenience. An example is `tf.map_fn` over a
  tuple of `elems`, which hands each element to the function as one tuple.

  Args:
    f: A function taking multiple positional arguments.

  Returns:
    A one-argument function `g` with `g(args) == f(*args)`.
  """
  def unpack_and_call(args):
    return f(*args)

  return unpack_and_call


class ContactsNet(sonnet.AbstractModule):
  """A network to go from sequence to distance histograms.

  Pairwise (2D) features are formed from the 1D per-residue inputs, residue
  positions and the 2D inputs. They are fed through a 2D residual network
  that produces per-pair logits with `num_bins or 1` channels, plus one extra
  channel when `distance_multiplier > 0`.

  Optional auxiliary heads (secstruct, ASA output, torsions) are built on a 1D
  embedding pooled from the 2D activations.
  """

  def __init__(self,
               binary_code_bits,
               data_format,
               distance_multiplier,
               features,
               features_forward,
               max_range,
               min_range,
               num_bins,
               reshape_layer,
               resolution_noise_scale,
               scalars,
               targets,
               network_2d_deep,
               torsion_bins=None,
               skip_connect=0,
               position_specific_bias_size=0,
               filters_1d=(),
               collapsed_batch_norm=False,
               is_ca_feature=False,
               asa_multiplier=0.0,
               secstruct_multiplier=0.0,
               torsion_multiplier=0.0,
               name='contacts_net'):
    """Construct position prediction network.

    Args:
      binary_code_bits: Number of bits used to binary-encode residue indices
        as extra 2D input channels (0 disables the encoding).
      data_format: 'NHWC' or 'NCHW'. It is stored, but `_build` always
        requests 'NHWC'.
      distance_multiplier: If > 0, one extra output channel is added.
      features: Names of the input features. This class checks it for
        'aatype' and 'seq_length'.
      features_forward: Channel count added to the output layer's input when
        forwarded features are concatenated (see `_output_from_pre_logits`).
      max_range: Used as the distance range by `quant_threshold`.
      min_range: Distance offset subtracted in `quant_threshold`.
      num_bins: Number of output bins. A falsy value means a single output.
      reshape_layer: If true, the resnet outputs `num_filters` channels and a
        final 1x1 convolution maps them to the output dimension.
      resolution_noise_scale: Stored; not read elsewhere in this class.
      scalars: Stored; not read elsewhere in this class.
      targets: Stored; not read elsewhere in this class.
      network_2d_deep: Config for the 2D resnet. Attributes read here are
        num_filters, num_blocks, extra_blocks, num_layers_per_block and
        use_batch_norm.
      torsion_bins: Bins per torsion angle. The torsion head has
        torsion_bins**2 outputs, so this is required if
        torsion_multiplier > 0.
      skip_connect: If true and extra blocks are used, the extra-block
        activations are forwarded to the output layer.
      position_specific_bias_size: Number of diagonal offsets that get a
        learned output bias (0 disables).
      filters_1d: Widths of fully connected layers applied to the pooled 1D
        embedding.
      collapsed_batch_norm: Whether to batch-normalise the pooled 1D
        embedding.
      is_ca_feature: If true and 'aatype' is a feature, the 1D input is
        reduced to its column 7.
      asa_multiplier: If > 0, builds the `asa_output.ASAOutputLayer` head.
      secstruct_multiplier: If > 0, builds the `secstruct.Secstruct` head.
      torsion_multiplier: If > 0, builds the torsion head.
      name: Module name.
    """
    super(ContactsNet, self).__init__(name=name)
    self._filters_1d = filters_1d
    self._collapsed_batch_norm = collapsed_batch_norm
    self._is_ca_feature = is_ca_feature
    self._binary_code_bits = binary_code_bits
    self._data_format = data_format
    self._distance_multiplier = distance_multiplier
    self._features = features
    self._features_forward = features_forward
    self._max_range = max_range
    self._min_range = min_range
    self._num_bins = num_bins
    self._position_specific_bias_size = position_specific_bias_size
    self._reshape_layer = reshape_layer
    self._resolution_noise_scale = resolution_noise_scale
    self._scalars = scalars
    self._torsion_bins = torsion_bins
    self._skip_connect = skip_connect
    self._targets = targets
    self._network_2d_deep = network_2d_deep
    self.asa_multiplier = asa_multiplier
    self.secstruct_multiplier = secstruct_multiplier
    self.torsion_multiplier = torsion_multiplier
    # Create sub-modules and variables inside this module's variable scope.
    with self._enter_variable_scope():
      if self.secstruct_multiplier > 0:
        self._secstruct = secstruct.Secstruct()
      if self.asa_multiplier > 0:
        self._asa = asa_output.ASAOutputLayer()
      if self._position_specific_bias_size:
        # One row of per-output biases for each diagonal offset.
        self._position_specific_bias = tf.get_variable(
            'position_specific_bias',
            [self._position_specific_bias_size, self._num_bins or 1],
            initializer=tf.zeros_initializer())

  def quant_threshold(self, threshold=8.0):
    """Returns the index of the bin containing `threshold` (8A by default).

    Summing the probability mass of all bins below this index gives the
    contact probability.

    Args:
      threshold: The distance threshold.

    Returns:
      Index of bin.
    """
    # Note that this misuses the max_range as the range.
    return int(
        (threshold - self._min_range) * self._num_bins / float(self._max_range))

  def _build(self, crop_size_x=0, crop_size_y=0, placeholders=None):
    """Puts the network into the graph.

    Args:
      crop_size_x: Size of the crop along x. It is used directly as the
        static width of the tiled 2D features. NOTE(review): older docs said
        0 means no cropping; confirm that still holds.
      crop_size_y: Size of the crop along y (see crop_size_x).
      placeholders: A dict with the keys 'crop_placeholder',
        'inputs_1d_placeholder', 'residue_index_placeholder' and
        'inputs_2d_placeholder'. 'crop_placeholder' holds
        [x_start, x_end, y_start, y_end] per example.

    Returns:
      NHWC logits with `output_dimension` channels. That is `num_bins or 1`,
      plus one if `distance_multiplier > 0`. Spatially these are presumably
      [crop_size_y, crop_size_x]; TODO confirm the resnet preserves the size.

    Raises:
      ValueError: If `placeholders` is not given, or if the 1D input has too
        many channels to be collapsed for `is_ca_feature`.
    """
    if placeholders is None:
      raise ValueError('placeholders must be a dict of input placeholders.')
    crop_placeholder = placeholders['crop_placeholder']
    inputs_1d = placeholders['inputs_1d_placeholder']
    if self._is_ca_feature and 'aatype' in self._features:
      num_1d_channels = inputs_1d.shape.as_list()[-1]
      logging.info('Collapsing aatype to is_ca_feature %s', num_1d_channels)
      max_1d_channels = 21 + (1 if 'seq_length' in self._features else 0)
      # Explicit check (not `assert`) so it survives `python -O`.
      if num_1d_channels > max_1d_channels:
        raise ValueError(
            'Expected at most %d 1D input channels for is_ca_feature, got %s.'
            % (max_1d_channels, num_1d_channels))
      # Keep only column 7 of the 1D input.
      inputs_1d = inputs_1d[:, :, 7:8]
    logits = self.compute_outputs(
        inputs_1d=inputs_1d,
        residue_index=placeholders['residue_index_placeholder'],
        inputs_2d=placeholders['inputs_2d_placeholder'],
        crop_x=crop_placeholder[:, 0:2],
        crop_y=crop_placeholder[:, 2:4],
        use_on_the_fly_stats=True,
        crop_size_x=crop_size_x,
        crop_size_y=crop_size_y,
        data_format='NHWC',  # Force NHWC for evals.
    )
    return logits

  def compute_outputs(self, inputs_1d, residue_index, inputs_2d, crop_x, crop_y,
                      use_on_the_fly_stats, crop_size_x, crop_size_y,
                      data_format='NHWC'):
    """Given the inputs for a block, compute the network outputs.

    Enabled auxiliary heads are built as a side effect and stored on `self`:
    `torsion_logits`/`torsion_output`, `asa_logits`, and the secstruct layer.

    Args:
      inputs_1d: Batched per-residue features [B, N, C1].
      residue_index: Batched residue indices [B, N].
      inputs_2d: Batched pairwise features, NHWC.
      crop_x: B x 2 [start, end) crop windows along x.
      crop_y: B x 2 [start, end) crop windows along y.
      use_on_the_fly_stats: Passed as `is_training` to batch norm layers.
      crop_size_x: Crop width.
      crop_size_y: Crop height.
      data_format: Layout used inside the resnet, 'NHWC' or 'NCHW'.

    Returns:
      NHWC logits checked for NaN/Inf.
    """
    output_dimension = self._num_bins or 1
    if self._distance_multiplier > 0:
      output_dimension += 1
    logits, activations = self._build_2d_embedding(
        hidden_1d=inputs_1d,
        residue_index=residue_index,
        inputs_2d=inputs_2d,
        output_dimension=output_dimension,
        use_on_the_fly_stats=use_on_the_fly_stats,
        crop_x=crop_x,
        crop_y=crop_y,
        crop_size_x=crop_size_x, crop_size_y=crop_size_y,
        data_format=data_format)
    logits = tf.debugging.check_numerics(
        logits, 'NaN in resnet activations', name='resnet_activations')

    if (self.secstruct_multiplier > 0 or
        self.asa_multiplier > 0 or
        self.torsion_multiplier > 0):
      # Make a 1d embedding by reducing the 2D activations.
      # We do this in the x direction and the y direction separately.
      collapse_dim = 1
      join_dim = -1
      embedding_1d = tf.concat(
          # First targets are crop_x (axis 2) which we must reduce on axis 1
          [tf.concat([tf.reduce_max(activations, axis=collapse_dim),
                      tf.reduce_mean(activations, axis=collapse_dim)],
                     axis=join_dim),
           # Next targets are crop_y (axis 1) which we must reduce on axis 2
           tf.concat([tf.reduce_max(activations, axis=collapse_dim + 1),
                      tf.reduce_mean(activations, axis=collapse_dim + 1)],
                     axis=join_dim)],
          axis=collapse_dim)  # Join the two crops together.
      if self._collapsed_batch_norm:
        embedding_1d = tf.contrib.layers.batch_norm(
            embedding_1d, is_training=use_on_the_fly_stats,
            fused=True, decay=0.999, scope='collapsed_batch_norm',
            data_format='NHWC')
      for i, nfil in enumerate(self._filters_1d):
        embedding_1d = tf.contrib.layers.fully_connected(
            embedding_1d,
            num_outputs=nfil,
            normalizer_fn=(
                tf.contrib.layers.batch_norm if self._collapsed_batch_norm
                else None),
            normalizer_params={'is_training': use_on_the_fly_stats,
                               'updates_collections': None},
            scope='collapsed_embed_%d' % i)

      if self.torsion_multiplier > 0:
        # Joint distribution over a pair of torsion angles.
        self.torsion_logits = tf.contrib.layers.fully_connected(
            embedding_1d,
            num_outputs=self._torsion_bins * self._torsion_bins,
            activation_fn=None,
            scope='torsion_logits')
        self.torsion_output = tf.nn.softmax(self.torsion_logits)
      if self.secstruct_multiplier > 0:
        self._secstruct.make_layer_new(embedding_1d)
      if self.asa_multiplier > 0:
        self.asa_logits = self._asa.compute_asa_output(embedding_1d)
    return logits

  @staticmethod
  def _crop_and_pad_1d(values, crops, n, crop_size, has_channels, dtype,
                       back_prop):
    """Cuts a [start, end) window out of each example and zero-pads it.

    Parts of the window lying before 0, or at/after `n`, are zero-padded.
    A window with end - start == crop_size therefore always yields
    `crop_size` rows.

    Args:
      values: Batched tensor [B, n], or [B, n, C] if `has_channels`.
      crops: B x 2 tensor of [start, end) windows, one per example.
      n: Sequence length used to compute the trailing padding.
      crop_size: Window length.
      has_channels: Whether `values` has a trailing channel dimension. That
        dimension is left unpadded.
      dtype: dtype of `values`, passed to tf.map_fn.
      back_prop: Whether to backprop through the map_fn.

    Returns:
      Tensor [B, crop_size] or [B, crop_size, C].
    """
    channel_paddings = [[0, 0]] if has_channels else []

    def crop_one(crop, sequence):
      start, end = crop[0], crop[1]
      paddings = [[tf.maximum(0, -start),
                   tf.maximum(0, crop_size - (n - start))]] + channel_paddings
      return tf.pad(sequence[tf.maximum(0, start):end], paddings)

    return tf.map_fn(call_on_tuple(crop_one), elems=(crops, values),
                     dtype=dtype, back_prop=back_prop)

  @staticmethod
  def _concatenate_2d(hidden_1d, residue_index, hidden_2d, crop_x, crop_y,
                      binary_code_bits, crop_size_x, crop_size_y):
    """Stacks 2D inputs with pairwise expansions of 1D features and positions.

    The appended channels are, in order:
      * the y residue index rescaled as (index - 100) / 100;
      * the pairwise index offset (x - y) / 100;
      * optionally, binary codes of the y and x residue indices;
      * the cropped 1D features tiled along y (x features), then along x
        (y features).

    Returns:
      NHWC tensor of depth C2 + 2 + 2 * binary_code_bits + 2 * C1, where C1
      and C2 are the channel counts of `hidden_1d` and `hidden_2d`.
    """
    # Form the pairwise expansion of the 1D embedding
    # And the residue offsets and (one) absolute position.
    with tf.name_scope('Features2D'):
      range_scale = 100.0  # Crude normalization factor.
      n = tf.shape(hidden_1d)[1]
      hidden_1d_cropped_y = ContactsNet._crop_and_pad_1d(
          hidden_1d, crop_y, n, crop_size_y, has_channels=True,
          dtype=tf.float32, back_prop=True)
      range_n_y = ContactsNet._crop_and_pad_1d(
          residue_index, crop_y, n, crop_size_y, has_channels=False,
          dtype=tf.int32, back_prop=False)
      hidden_1d_cropped_x = ContactsNet._crop_and_pad_1d(
          hidden_1d, crop_x, n, crop_size_x, has_channels=True,
          dtype=tf.float32, back_prop=True)
      range_n_x = ContactsNet._crop_and_pad_1d(
          residue_index, crop_x, n, crop_size_x, has_channels=False,
          dtype=tf.int32, back_prop=False)

      n_x = crop_size_x
      n_y = crop_size_y
      offset = (tf.expand_dims(tf.cast(range_n_x, tf.float32), 1) -
                tf.expand_dims(tf.cast(range_n_y, tf.float32), 2)) / range_scale
      position_features = [
          tf.tile(
              tf.reshape(
                  (tf.cast(range_n_y, tf.float32) - range_scale) / range_scale,
                  [-1, n_y, 1, 1]), [1, 1, n_x, 1],
              name='TileRange'),
          tf.tile(
              tf.reshape(offset, [-1, n_y, n_x, 1]), [1, 1, 1, 1],
              name='TileOffset')
      ]
      channels = 2
      if binary_code_bits:
        # Binary coding of position.
        exp_range_n_y = tf.expand_dims(range_n_y, 2)
        bin_y = tf.stop_gradient(
            tf.concat([tf.math.floormod(exp_range_n_y // (1 << i), 2)
                       for i in range(binary_code_bits)], 2))
        exp_range_n_x = tf.expand_dims(range_n_x, 2)
        bin_x = tf.stop_gradient(
            tf.concat([tf.math.floormod(exp_range_n_x // (1 << i), 2)
                       for i in range(binary_code_bits)], 2))
        position_features += [
            tf.tile(
                tf.expand_dims(tf.cast(bin_y, tf.float32), 2), [1, 1, n_x, 1],
                name='TileBinRangey'),
            tf.tile(
                tf.expand_dims(tf.cast(bin_x, tf.float32), 1), [1, n_y, 1, 1],
                name='TileBinRangex')
        ]
        channels += 2 * binary_code_bits

      augmentation_features = position_features + [
          tf.tile(tf.expand_dims(hidden_1d_cropped_x, 1),
                  [1, n_y, 1, 1], name='Tile1Dx'),
          tf.tile(tf.expand_dims(hidden_1d_cropped_y, 2),
                  [1, 1, n_x, 1], name='Tile1Dy')]
      channels += 2 * hidden_1d.shape.as_list()[-1]
      channels += hidden_2d.shape.as_list()[-1]
      hidden_2d = tf.concat(
          [hidden_2d] + augmentation_features, 3, name='Stack2Dfeatures')
      logging.info('2d stacked features are depth %d %s', channels, hidden_2d)
      hidden_2d.set_shape([None, None, None, channels])
      return hidden_2d

  def _build_2d_embedding(self, hidden_1d, residue_index, inputs_2d,
                          output_dimension, use_on_the_fly_stats, crop_x,
                          crop_y, crop_size_x, crop_size_y, data_format):
    """Returns NHWC logits and NHWC preactivations.

    The stacked 2D features optionally go through extra double-width resnet
    blocks ('Deep2DExtra'), then through the main resnet ('Deep2D'), and
    finally through `_output_from_pre_logits`.
    """
    logging.info('2d %s %s', inputs_2d, data_format)
    # Stack with diagonal has already happened.
    hidden_2d = self._concatenate_2d(
        hidden_1d, residue_index, inputs_2d, crop_x, crop_y,
        self._binary_code_bits, crop_size_x, crop_size_y)

    config_2d_deep = self._network_2d_deep
    num_features = hidden_2d.shape.as_list()[3]
    if data_format == 'NCHW':
      logging.info('NCHW shape deep pre %s', hidden_2d)
      hidden_2d = tf.transpose(hidden_2d, perm=[0, 3, 1, 2])
      hidden_2d.set_shape([None, num_features, None, None])
      logging.info('NCHW shape deep post %s', hidden_2d)

    layers_forward = None
    if config_2d_deep.extra_blocks:
      # Optionally put some extra double-size blocks at the beginning.
      with tf.variable_scope('Deep2DExtra'):
        hidden_2d = two_dim_resnet.make_two_dim_resnet(
            input_node=hidden_2d,
            num_residues=None,  # Unused
            num_features=num_features,
            num_predictions=2 * config_2d_deep.num_filters,
            num_channels=2 * config_2d_deep.num_filters,
            num_layers=config_2d_deep.extra_blocks *
            config_2d_deep.num_layers_per_block,
            filter_size=3,
            batch_norm=config_2d_deep.use_batch_norm,
            is_training=use_on_the_fly_stats,
            fancy=True,
            final_non_linearity=True,
            atrou_rates=[1, 2, 4, 8],
            data_format=data_format,
            dropout_keep_prob=1.0
        )
      num_features = 2 * config_2d_deep.num_filters
      if self._skip_connect:
        layers_forward = hidden_2d

    with tf.variable_scope('Deep2D'):
      logging.info('2d hidden shape is %s', str(hidden_2d.shape.as_list()))
      contact_pre_logits = two_dim_resnet.make_two_dim_resnet(
          input_node=hidden_2d,
          num_residues=None,  # Unused
          num_features=num_features,
          num_predictions=(config_2d_deep.num_filters
                           if self._reshape_layer else output_dimension),
          num_channels=config_2d_deep.num_filters,
          num_layers=config_2d_deep.num_blocks *
          config_2d_deep.num_layers_per_block,
          filter_size=3,
          batch_norm=config_2d_deep.use_batch_norm,
          is_training=use_on_the_fly_stats,
          fancy=True,
          final_non_linearity=self._reshape_layer,
          atrou_rates=[1, 2, 4, 8],
          data_format=data_format,
          dropout_keep_prob=1.0
      )
      # NOTE(review): indentation of this file was lost. This output layer is
      # placed inside the 'Deep2D' variable scope; confirm against checkpoint
      # variable names. No forwarded input features are used here (None).
      contact_logits = self._output_from_pre_logits(
          contact_pre_logits, None, layers_forward,
          output_dimension, data_format, crop_x, crop_y, use_on_the_fly_stats)
      if data_format == 'NCHW':
        contact_pre_logits = tf.transpose(contact_pre_logits, perm=[0, 2, 3, 1])
    # Both of these will be NHWC
    return contact_logits, contact_pre_logits

  def _output_from_pre_logits(self, contact_pre_logits, features_forward,
                              layers_forward, output_dimension, data_format,
                              crop_x, crop_y, use_on_the_fly_stats):
    """Given pre-logits, compute the final distogram/contact activations.

    If `reshape_layer` is set, any forwarded features and layers are
    concatenated onto the pre-logits. A 1x1 convolution then maps the result
    to `output_dimension` channels. Otherwise the pre-logits are used as the
    logits directly. The result is converted to NHWC, and position-specific
    biases are added if configured.

    Args:
      contact_pre_logits: Resnet output in `data_format` layout.
      features_forward: Forwarded features to concatenate, or None.
      layers_forward: Earlier activations (2 * num_filters channels) to
        concatenate, or None.
      output_dimension: Number of output channels.
      data_format: 'NHWC' or 'NCHW'.
      crop_x: B x 2 crop windows along x.
      crop_y: B x 2 crop windows along y.
      use_on_the_fly_stats: Passed as `is_training` to batch norm.

    Returns:
      NHWC logits.
    """
    config_2d_deep = self._network_2d_deep
    if self._reshape_layer:
      in_channels = config_2d_deep.num_filters
      concat_features = [contact_pre_logits]
      if features_forward is not None:
        concat_features.append(features_forward)
        in_channels += self._features_forward
      if layers_forward is not None:
        concat_features.append(layers_forward)
        in_channels += 2 * config_2d_deep.num_filters
      if len(concat_features) > 1:
        contact_pre_logits = tf.concat(concat_features,
                                       1 if data_format == 'NCHW' else 3)

      contact_logits = two_dim_convnet.make_conv_layer(
          contact_pre_logits,
          in_channels=in_channels,
          out_channels=output_dimension,
          layer_name='output_reshape_1x1h',
          filter_size=1,
          filter_size_2=1,
          non_linearity=False,
          batch_norm=config_2d_deep.use_batch_norm,
          is_training=use_on_the_fly_stats,
          data_format=data_format)
    else:
      contact_logits = contact_pre_logits

    if data_format == 'NCHW':
      contact_logits = tf.transpose(contact_logits, perm=[0, 2, 3, 1])

    if self._position_specific_bias_size:
      # Make 2D pos-specific biases: NHWC.
      biases = build_crops_biases(
          self._position_specific_bias_size,
          self._position_specific_bias, crop_x, crop_y, back_prop=True)
      contact_logits += biases

    # Will be NHWC.
    return contact_logits

  def update_crop_fetches(self, fetches):
    """Add auxiliary outputs for a crop to the fetches.

    Adds 'secstruct_probs', 'asa_output' and 'torsion_probs' for each head
    whose multiplier is > 0. `fetches` is modified in place.
    """
    if self.secstruct_multiplier > 0:
      fetches['secstruct_probs'] = self._secstruct.get_q8_probs()
    if self.asa_multiplier > 0:
      fetches['asa_output'] = self._asa.asa_output
    if self.torsion_multiplier > 0:
      fetches['torsion_probs'] = self.torsion_output


def build_crops_biases(bias_size, raw_biases, crop_x, crop_y, back_prop):
  """Lays out offset-specific biases as a 2D bias map for each crop.

  `raw_biases` stores one row of biases per diagonal offset 0..bias_size-1.
  Position (y, x) of a crop uses the row for offset (x_abs - y_abs), where the
  absolute coordinates include the crop start. Offsets past the stored range
  reuse the last stored row. Negative (below-diagonal) offsets mirror the
  positive ones.

  Args:
    bias_size: how many bias variables we're storing.
    raw_biases: the bias variable, [bias_size, num_outputs].
    crop_x: B x 2 array of [start, end) columns for the batch.
    crop_y: B x 2 array of [start, end) rows for the batch.
    back_prop: whether to backprop through the map_fn.

  Returns:
    Biases of shape [B, crop_size_y, crop_size_x, num_outputs]. The crop
    sizes are the largest crop extents in the batch.
  """
  # Largest |offset| any crop in the batch can reach.
  furthest_offset = tf.reduce_max(
      tf.maximum(tf.abs(crop_x[:, 1] - crop_y[:, 0]),
                 tf.abs(crop_y[:, 1] - crop_x[:, 0])))
  padded_bias_size = tf.maximum(bias_size, furthest_offset)

  # Grow to padded_bias_size rows by repeating the final stored row.
  repeated_tail = tf.tile(raw_biases[-1:, :],
                          [padded_bias_size - bias_size, 1])
  one_sided = tf.concat([raw_biases, repeated_tail], axis=0)
  # Prepend a mirror image without the 0th row, so negative offsets are
  # covered. Offset 0 now sits at index padded_bias_size - 1.
  mirrored = tf.reverse(one_sided[1:, :], axis=[0])
  biases = tf.concat([mirrored, one_sided], axis=0)

  # Offset of each crop's top-left element: B x 1.
  start_diag = crop_x[:, 0:1] - crop_y[:, 0:1]
  crop_size_x = tf.reduce_max(crop_x[:, 1] - crop_x[:, 0])
  crop_size_y = tf.reduce_max(crop_y[:, 1] - crop_y[:, 0])
  # Each row further down lowers the offset by one: 1 x crop_size_y.
  row_step = tf.expand_dims(-tf.range(0, crop_size_y), 0)
  # Offset of the first element of every row, flattened: B*crop_size_y.
  row_offsets = tf.reshape(start_diag + row_step, [-1])
  logging.info('row_offsets %s', row_offsets)
  # Turn offsets into indices into `biases`.
  row_offsets += padded_bias_size - 1

  def row_slice(first_index):
    return biases[first_index:first_index + crop_size_x, :]

  # B*crop_size_y x crop_size_x x num_outputs.
  cropped_biases = tf.map_fn(row_slice, elems=row_offsets, dtype=tf.float32,
                             back_prop=back_prop)
  logging.info('cropped_biases %s', cropped_biases)
  num_outputs = tf.shape(raw_biases)[-1]
  return tf.reshape(cropped_biases,
                    [-1, crop_size_y, crop_size_x, num_outputs])