
Commit 1b1e237
Commit message: tests
1 parent b471dcd commit 1b1e237

3 files changed: +104, -124 lines


notebooks/04-ume-multimodal-embeddings.ipynb

+55 -5
@@ -11,15 +11,51 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 1,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/Users/zadorozk/Desktop/code/lobster/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+" from .autonotebook import tqdm as notebook_tqdm\n"
+]
+},
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Supported modalities: ['SMILES', 'amino_acid', 'nucleotide', '3d_coordinates']\n",
+"Vocab size: 1536\n"
+]
+}
+],
 "source": [
 "from lobster.model import Ume\n",
 "\n",
-"checkpoint = \"<your checkpoint>\"\n",
+"ume = Ume()\n",
 "\n",
-"ume = Ume(checkpoint, freeze=True)"
+"print(f\"Supported modalities: {ume.modalities}\")\n",
+"print(f\"Vocab size: {len(ume.get_vocab())}\")"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Load from checkpoint"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"checkpoint = \"ume-checkpoints/last.ckpt\" # Replace with the correct checkpoint path\n",
+"\n",
+"ume = Ume.load_from_checkpoint(checkpoint)"
 ]
 },
 {
@@ -208,8 +244,22 @@
 }
 ],
 "metadata": {
+"kernelspec": {
+"display_name": ".venv",
+"language": "python",
+"name": "python3"
+},
 "language_info": {
-"name": "python"
+"codemirror_mode": {
+"name": "ipython",
+"version": 3
+},
+"file_extension": ".py",
+"mimetype": "text/x-python",
+"name": "python",
+"nbconvert_exporter": "python",
+"pygments_lexer": "ipython3",
+"version": "3.12.9"
 }
 },
 "nbformat": 4,

tests/lobster/model/test__ume.py

+25 -90
@@ -24,56 +24,12 @@ def dna_examples():
 class TestUme:
     """Tests for the Universal Molecular Encoder (Ume) class"""

-    @patch("lobster.model._ume.FlexBERT.load_from_checkpoint")
-    def test_frozen_parameters(self, mock_load_checkpoint):
-        """Test that parameters are frozen when freeze=True"""
-        # Create mock model
-        mock_model = MagicMock()
-        mock_params = [torch.nn.Parameter(torch.randn(10, 10))]
-        mock_model.model.parameters.return_value = mock_params
-        mock_load_checkpoint.return_value = mock_model
-
-        # Create Ume with frozen parameters
-        ume = Ume("dummy_checkpoint.ckpt", freeze=True)
-
-        # Verify that load_from_checkpoint was called
-        mock_load_checkpoint.assert_called_once_with("dummy_checkpoint.ckpt")
-
-        # Verify that parameters were accessed
-        mock_model.model.parameters.assert_called()
-
-        # Verify freeze attribute is True
-        assert ume.freeze is True
-
-    @patch("lobster.model._ume.FlexBERT.load_from_checkpoint")
-    def test_unfrozen_parameters(self, mock_load_checkpoint):
-        """Test that parameters are not frozen when freeze=False"""
-        # Create mock model
-        mock_model = MagicMock()
-        mock_params = [torch.nn.Parameter(torch.randn(10, 10))]
-        mock_model.model.parameters.return_value = mock_params
-        mock_load_checkpoint.return_value = mock_model
-
-        # Create Ume without freezing parameters
-        ume = Ume("dummy_checkpoint.ckpt", freeze=False)
-
-        # Verify freeze attribute is False
-        assert ume.freeze is False
-
-        # Verify that parameters were not frozen
-        mock_model.model.parameters.assert_not_called()
-
-    @patch("lobster.model._ume.FlexBERT.load_from_checkpoint")
     @patch("lobster.model._ume.UmeSmilesTokenizerFast")
     @patch("lobster.model._ume.UmeAminoAcidTokenizerFast")
     @patch("lobster.model._ume.UmeNucleotideTokenizerFast")
     @patch("lobster.model._ume.UmeLatentGenerator3DCoordTokenizerFast")
-    def test_tokenizer_initialization(self, mock_coord, mock_nucleotide, mock_amino, mock_smiles, mock_load_checkpoint):
+    def test_tokenizer_initialization(self, mock_coord, mock_nucleotide, mock_amino, mock_smiles):
         """Test that tokenizers are initialized during __init__"""
-        # Set up model mock
-        mock_model = MagicMock()
-        mock_load_checkpoint.return_value = mock_model
-
         # Setup tokenizer mocks
         mock_smiles_instance = MagicMock()
         mock_amino_instance = MagicMock()
@@ -86,74 +42,44 @@ def test_tokenizer_initialization(self, mock_coord, mock_nucleotide, mock_amino,
         mock_coord.return_value = mock_coord_instance

         # Create Ume instance
-        ume = Ume("dummy_checkpoint.ckpt")
+        ume = Ume()

-        # Verify each tokenizer was instantiated exactly once
         mock_smiles.assert_called_once()
         mock_amino.assert_called_once()
         mock_nucleotide.assert_called_once()
         mock_coord.assert_called_once()

-        # Verify tokenizers were stored in the dictionary
         assert ume.tokenizers[Modality.SMILES] == mock_smiles_instance
         assert ume.tokenizers[Modality.AMINO_ACID] == mock_amino_instance
         assert ume.tokenizers[Modality.NUCLEOTIDE] == mock_nucleotide_instance
         assert ume.tokenizers[Modality.COORDINATES_3D] == mock_coord_instance

     @patch("lobster.model._ume.FlexBERT.load_from_checkpoint")
-    @patch("lobster.model._ume.UmeSmilesTokenizerFast")
-    @patch("lobster.model._ume.UmeAminoAcidTokenizerFast")
-    @patch("lobster.model._ume.UmeNucleotideTokenizerFast")
-    @patch("lobster.model._ume.UmeLatentGenerator3DCoordTokenizerFast")
-    def test_get_tokenizer(self, mock_coord, mock_nucleotide, mock_amino, mock_smiles, mock_load_checkpoint):
+    def test_get_tokenizer(self, mock_load_checkpoint):
         """Test getting tokenizers for different modalities"""
-        # Set up model mock
-        mock_model = MagicMock()
-        mock_load_checkpoint.return_value = mock_model
+        ume = Ume()

-        # Setup tokenizer mocks
-        mock_smiles_instance = MagicMock()
-        mock_amino_instance = MagicMock()
-        mock_nucleotide_instance = MagicMock()
-        mock_coord_instance = MagicMock()
+        mock_tokenizers = {}
+        for modality in Modality:
+            mock_tokenizers[modality] = MagicMock()

-        mock_smiles.return_value = mock_smiles_instance
-        mock_amino.return_value = mock_amino_instance
-        mock_nucleotide.return_value = mock_nucleotide_instance
-        mock_coord.return_value = mock_coord_instance
+        ume.tokenizers = mock_tokenizers

-        # Create Ume instance
-        ume = Ume("dummy_checkpoint.ckpt")
-
-        # Test each modality
         modality_map = {
-            "SMILES": mock_smiles_instance,
-            "amino_acid": mock_amino_instance,
-            "nucleotide": mock_nucleotide_instance,
-            "3d_coordinates": mock_coord_instance,
+            "SMILES": Modality.SMILES,
+            "amino_acid": Modality.AMINO_ACID,
+            "nucleotide": Modality.NUCLEOTIDE,
+            "3d_coordinates": Modality.COORDINATES_3D,
         }

-        for modality, mock_instance in modality_map.items():
-            # Get tokenizer - this should now return the pre-instantiated tokenizer
-            tokenizer = ume.get_tokenizer(["test"], modality)
+        for modality_str, modality_enum in modality_map.items():
+            tokenizer = ume.get_tokenizer(modality_str)

-            # Verify the returned tokenizer is our mock instance
-            assert tokenizer == mock_instance
-
-            # Verify that no new tokenizer is instantiated (count should remain at 1)
-            if modality == "SMILES":
-                assert mock_smiles.call_count == 1
-            elif modality == "amino_acid":
-                assert mock_amino.call_count == 1
-            elif modality == "nucleotide":
-                assert mock_nucleotide.call_count == 1
-            elif modality == "3d_coordinates":
-                assert mock_coord.call_count == 1
+            assert tokenizer == mock_tokenizers[modality_enum]

     @patch("lobster.model._ume.FlexBERT.load_from_checkpoint")
     def test_get_embeddings_basic(self, mock_load_checkpoint, smiles_examples, protein_examples, dna_examples):
         """Test basic embedding functionality for all modalities"""
-        # Mock model with controlled output
         mock_model = MagicMock()
         mock_model.max_length = 512
         mock_model.device = torch.device("cpu")
@@ -169,7 +95,7 @@ def mock_tokens_to_latents(**kwargs):
         mock_load_checkpoint.return_value = mock_model

         # Create Ume instance
-        ume = Ume("dummy_checkpoint.ckpt")
+        ume = Ume.load_from_checkpoint("dummy_checkpoint.ckpt")

         # Test for each modality
         modalities = ["SMILES", "amino_acid", "nucleotide"]
@@ -193,6 +119,15 @@ def mock_tokens_to_latents(**kwargs):
             embeddings = ume.get_embeddings(test_inputs[modality], modality)
             assert embeddings.shape == (batch_size, 768)

+            # Verify tokenizer was called with the correct inputs
+            mock_tokenizer.assert_called_with(
+                test_inputs[modality],
+                return_tensors="pt",
+                padding="max_length",
+                truncation=True,
+                max_length=mock_model.max_length,
+            )
+
             # Test token-level embeddings
             token_embeddings = ume.get_embeddings(test_inputs[modality], modality, aggregate=False)
             assert token_embeddings.shape == (batch_size, seq_len, 768)
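
The rewritten tests reflect the new Ume API: the constructor no longer takes a checkpoint or freeze flag, get_tokenizer takes only the modality string, and get_embeddings returns either pooled or per-token embeddings. A rough usage sketch inferred from these tests (the checkpoint path and input sequences are illustrative, and the 768-dimensional size comes from the mocked model in the test, not a documented guarantee):

    from lobster.model import Ume

    ume = Ume.load_from_checkpoint("dummy_checkpoint.ckpt")  # illustrative path

    smiles = ["CCO", "CC(=O)O"]  # illustrative inputs

    # Tokenizer lookup now takes just the modality string
    tokenizer = ume.get_tokenizer("SMILES")

    # Pooled embeddings: shape (batch_size, hidden_dim), e.g. (2, 768) in the mocked test
    embeddings = ume.get_embeddings(smiles, "SMILES")

    # Token-level embeddings: shape (batch_size, seq_len, hidden_dim)
    token_embeddings = ume.get_embeddings(smiles, "SMILES", aggregate=False)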

tests/lobster/tokenization/test__ume_tokenizers.py

+24 -29
@@ -14,9 +14,8 @@ def test_add_reserved_tokens():
             "<eos>",
             "<unk>",
             "<pad>",
-            "<reserved_special_token_0>",
-            "<reserved_special_token_1>",
-            "<reserved_special_token_2>",
+            "<extra_special_token_0>",
+            "<extra_special_token_1>",
         ],
         "amino_acid_tokenizer": ["A"],  # 1 amino acid tokens
         "smiles_tokenizer": ["C", "O"],  # 2 SMILES tokens
@@ -31,20 +30,18 @@
         "<eos>",
         "<unk>",
         "<pad>",
-        "<reserved_special_token_0>",  # reserved from special tokens
-        "<reserved_special_token_1>",  # reserved from special tokens
-        "<reserved_special_token_2>",  # reserved from special tokens
+        "<extra_special_token_0>",  # reserved from special tokens
+        "<extra_special_token_1>",  # reserved from special tokens
         "A",
     ]
     assert result["smiles_tokenizer"] == [
         "<cls>",
         "<eos>",
         "<unk>",
         "<pad>",
-        "<reserved_special_token_0>",  # reserved from special tokens
-        "<reserved_special_token_1>",  # reserved from special tokens
-        "<reserved_special_token_2>",  # reserved from special tokens
-        "<reserved_special_token_3>",  # reserved for amino acids
+        "<extra_special_token_0>",  # reserved from special tokens
+        "<extra_special_token_1>",  # reserved from special tokens
+        "<reserved_for_amino_acids_special_token_2>",  # reserved for amino acids
         "C",
         "O",
     ]
@@ -53,12 +50,11 @@
         "<eos>",
         "<unk>",
         "<pad>",
-        "<reserved_special_token_0>",  # reserved from special tokens
-        "<reserved_special_token_1>",  # reserved from special tokens
-        "<reserved_special_token_2>",  # reserved from special tokens
-        "<reserved_special_token_3>",  # reserved for amino acids
-        "<reserved_special_token_4>",  # reserved for SMILES
-        "<reserved_special_token_5>",  # reserved for SMILES
+        "<extra_special_token_0>",  # reserved from special tokens
+        "<extra_special_token_1>",  # reserved from special tokens
+        "<reserved_for_amino_acids_special_token_2>",  # reserved from special tokens
+        "<reserved_for_smiles_special_token_3>",  # reserved for SMILES
+        "<reserved_for_smiles_special_token_4>",  # reserved for SMILES
         "A",
         "C",
         "G",
@@ -68,15 +64,14 @@
         "<eos>",
         "<unk>",
         "<pad>",
-        "<reserved_special_token_0>",  # reserved from special tokens
-        "<reserved_special_token_1>",  # reserved from special tokens
-        "<reserved_special_token_2>",  # reserved from special tokens
-        "<reserved_special_token_3>",  # reserved for amino acids
-        "<reserved_special_token_4>",  # reserved for SMILES
-        "<reserved_special_token_5>",  # reserved for SMILES
-        "<reserved_special_token_6>",  # reserved for nucleotides
-        "<reserved_special_token_7>",  # reserved for nucleotides
-        "<reserved_special_token_8>",  # reserved for nucleotides
+        "<extra_special_token_0>",  # reserved from special tokens
+        "<extra_special_token_1>",  # reserved from special tokens
+        "<reserved_for_amino_acids_special_token_2>",  # reserved from special tokens
+        "<reserved_for_smiles_special_token_3>",  # reserved for SMILES
+        "<reserved_for_smiles_special_token_4>",  # reserved for SMILES
+        "<reserved_for_nucleotides_special_token_5>",  # reserved for nucleotides
+        "<reserved_for_nucleotides_special_token_6>",  # reserved for nucleotides
+        "<reserved_for_nucleotides_special_token_7>",  # reserved for nucleotides
         "X1",
         "Y1",
        "Z1",
@@ -87,22 +82,22 @@
 def test_ume_aminio_acid_tokenizer():
     tokenizer = UmeAminoAcidTokenizerFast()
     assert tokenizer.tokenize("VYF") == ["V", "Y", "F"]
-    assert tokenizer.encode("VYF", padding="do_not_pad", add_special_tokens=True) == [0, 23, 35, 34, 2]
+    assert tokenizer.encode("VYF", padding="do_not_pad", add_special_tokens=True) == [0, 28, 40, 39, 2]


 def test_ume_smiles_tokenizer():
     tokenizer = UmeSmilesTokenizerFast()
     assert tokenizer.tokenize("CCO") == ["C", "C", "O"]
-    assert tokenizer.encode("CCO", padding="do_not_pad", add_special_tokens=True) == [0, 46, 46, 49, 2]
+    assert tokenizer.encode("CCO", padding="do_not_pad", add_special_tokens=True) == [0, 52, 52, 56, 2]


 def test_ume_nucleotide_tokenizer():
     tokenizer = UmeNucleotideTokenizerFast()
     assert tokenizer.tokenize("ACGT") == ["A", "C", "G", "T"]
-    assert tokenizer.encode("ACGT", padding="do_not_pad", add_special_tokens=True) == [0, 623, 624, 625, 626, 2]
+    assert tokenizer.encode("ACGT", padding="do_not_pad", add_special_tokens=True) == [0, 1272, 1273, 1274, 1275, 2]


 def test_ume_latent_generator_tokenizer():
     tokenizer = UmeLatentGenerator3DCoordTokenizerFast()
     assert tokenizer.tokenize("gd fh ds") == ["gd", "fh", "ds"]
-    assert tokenizer.encode("gd fh ds", padding="do_not_pad", add_special_tokens=True) == [0, 816, 794, 753, 2]
+    assert tokenizer.encode("gd fh ds", padding="do_not_pad", add_special_tokens=True) == [0, 1465, 1443, 1402, 2]
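
The renamed reserved tokens and shifted token IDs reflect a larger shared vocabulary (1536 tokens per the notebook output), in which each modality's tokens occupy a distinct ID range after the shared special tokens. A small sketch of the behavior these tests pin down, assuming the tokenizers are importable from the lobster.tokenization package (mirroring the test module's location; only the encode calls above are confirmed by the diff):

    # Assumed import path; the model tests patch these classes via lobster.model._ume
    from lobster.tokenization import UmeAminoAcidTokenizerFast, UmeNucleotideTokenizerFast

    aa_tok = UmeAminoAcidTokenizerFast()
    # Amino-acid IDs sit near the start of the shared vocabulary
    assert aa_tok.encode("VYF", padding="do_not_pad", add_special_tokens=True) == [0, 28, 40, 39, 2]

    nt_tok = UmeNucleotideTokenizerFast()
    # Nucleotide IDs sit higher, past the amino-acid and SMILES ranges
    assert nt_tok.encode("ACGT", padding="do_not_pad", add_special_tokens=True) == [0, 1272, 1273, 1274, 1275, 2]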
