Skip to content

Commit

Permalink
v2.3.2
Browse files Browse the repository at this point in the history
  • Loading branch information
sokrypton authored Feb 11, 2023
2 parents 15eb3e0 + cf096a6 commit ba75383
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 29 deletions.
8 changes: 3 additions & 5 deletions alphafold/model/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@
# Internal import (7716).


def get_model_haiku_params(model_name: str, data_dir: str, fuse: bool = True) -> hk.Params:
def get_model_haiku_params(model_name: str,
data_dir: str, fuse: bool = True, to_jnp: bool = True) -> hk.Params:
"""Get the Haiku parameters from a model name."""

path = os.path.join(data_dir, 'params', f'params_{model_name}.npz')

params = np.load(path, allow_pickle=False)

return utils.flat_params_to_haiku(params, fuse=fuse)
return utils.flat_params_to_haiku(params, fuse=fuse, to_jnp=to_jnp)
11 changes: 7 additions & 4 deletions alphafold/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,14 @@ def predict(self,
while r < num_iters:
if self.multimer_mode:
sub_feat = feat
sub_feat["iter"] = np.array(r)
else:
s = r * num_ensemble
e = (r+1) * num_ensemble
sub_feat = jax.tree_map(lambda x:x[s:e], feat)

sub_feat["prev"] = result["prev"]
result = self.apply(self.params, key, sub_feat)
key, sub_key = jax.random.split(key)
result = self.apply(self.params, sub_key, sub_feat)
seq_mask = feat["seq_mask"] if self.multimer_mode else feat["seq_mask"][0]
confidences = get_confidence_metrics(result, mask=seq_mask, rank_by=self.config.model.rank_by)

Expand All @@ -235,13 +235,16 @@ def predict(self,
stop = True
prev_pos = result["prev"]["prev_pos"][:,ca_idx]

result["pae"] = result.pop("predicted_aligned_error")
result.update(confidences)
if prediction_callback is not None: prediction_callback(result, r)

if prediction_callback is not None:
prediction_callback(result, r)

if verbose:
print_line = f"recycle={r} plddt={confidences['mean_plddt']:.3g}"
for k in ["ptm","iptm","diff"]:
if k in confidences: print_line += f" {k}:{confidences[k]:.3g}"
if k in confidences: print_line += f" {k}={confidences[k]:.3g}"
print(print_line)
r += 1
if stop: break
Expand Down
23 changes: 18 additions & 5 deletions alphafold/model/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,14 @@ def get_prev(ret):
}
return new_prev

prev = batch.pop("prev")
batch = jax.tree_map(lambda x:x[0], batch)
prev = batch.pop("prev",None)
if batch["aatype"].ndim == 2:
batch = jax.tree_map(lambda x:x[0], batch)
if prev is None:
L = batch["aatype"].shape[0]
prev = {'prev_msa_first_row': jnp.zeros([L,256]),
'prev_pair': jnp.zeros([L,L,128]),
'prev_pos': jnp.zeros([L,37,3])}
ret = impl(batch={**batch, **prev}, is_training=is_training)
ret["prev"] = get_prev(ret)
if not return_representations:
Expand Down Expand Up @@ -413,8 +419,15 @@ def slice_recycle_idx(x):
compute_loss=compute_loss,
ensemble_representations=ensemble_representations)

emb_config = self.config.embeddings_and_evoformer
ret = do_call(prev=batch.pop("prev"), recycle_idx=0)
emb_config = self.config.embeddings_and_evoformer
prev = batch.pop("prev",None)
if prev is None:
L = num_residues
prev = {'prev_msa_first_row': jnp.zeros([L,256]),
'prev_pair': jnp.zeros([L,L,128]),
'prev_pos': jnp.zeros([L,37,3])}

ret = do_call(prev=prev, recycle_idx=0)
ret["prev"] = get_prev(ret)

if compute_loss:
Expand Down Expand Up @@ -2222,4 +2235,4 @@ def map_fn(batch):
# No gradients if no templates.
embedding *= (jnp.sum(template_mask) > 0.).astype(embedding.dtype)

return embedding
return embedding
12 changes: 1 addition & 11 deletions alphafold/model/modules_multimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,17 +442,7 @@ def apply_network(prev, safe_key):
batch=recycled_batch,
is_training=is_training,
safe_key=safe_key)

#########################################
num_iter = c.num_recycle
def key_body(i, k):
k_ = jax.random.split(k[0])
o = jax.lax.cond(i==num_iter, lambda _:k[0], lambda _:k_[1], None)
return [k_[0],o]
k = safe_key.get()
safe_key = prng.SafeKey(jax.lax.fori_loop(0,batch.pop("iter")+1, key_body, [k,k])[1])
##########################################


ret = apply_network(prev=batch.pop("prev"), safe_key=safe_key)
ret["prev"] = get_prev(ret)

Expand Down
6 changes: 3 additions & 3 deletions alphafold/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
(jnp.sum(mask, axis=axis) * broadcast_factor + eps))


def flat_params_to_haiku(params, fuse=True):
def flat_params_to_haiku(params, fuse=True, to_jnp=True):
"""Convert a dictionary of NumPy arrays to Haiku parameters."""
P = {}
for path, array in params.items():
scope, name = path.split('//')
if scope not in P:
P[scope] = {}
P[scope][name] = jnp.array(array)
P[scope][name] = jnp.array(array) if to_jnp else array
for a in ["evoformer_iteration",
"extra_msa_stack",
"template_embedding/single_template_embedding/template_embedding_iteration",
Expand All @@ -113,7 +113,7 @@ def flat_params_to_haiku(params, fuse=True):
R = P.pop(f"{k}/right_{c}")
P[f"{k}/{c}"] = {}
for d in ["bias","weights"]:
P[f"{k}/{c}"][d] = jnp.concatenate([L[d],R[d]],-1)
P[f"{k}/{c}"][d] = jnp.concatenate([L[d],R[d]],-1) if to_jnp else np.concatenate([L[d],R[d]],-1)
P[f"{k}/center_norm"] = P.pop(f"{k}/center_layer_norm")
P[f"{k}/left_norm_input"] = P.pop(f"{k}/layer_norm_input")

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

setup(
name='alphafold-colabfold',
version='2.3.1',
version='2.3.2',
long_description_content_type='text/markdown',
description='An implementation of the inference pipeline of AlphaFold v2.0.'
'This is a completely new model that was entered as AlphaFold2 in CASP14 '
Expand Down

0 comments on commit ba75383

Please sign in to comment.