@@ -24,56 +24,12 @@ def dna_examples():
24
24
class TestUme :
25
25
"""Tests for the Universal Molecular Encoder (Ume) class"""
26
26
27
- @patch ("lobster.model._ume.FlexBERT.load_from_checkpoint" )
28
- def test_frozen_parameters (self , mock_load_checkpoint ):
29
- """Test that parameters are frozen when freeze=True"""
30
- # Create mock model
31
- mock_model = MagicMock ()
32
- mock_params = [torch .nn .Parameter (torch .randn (10 , 10 ))]
33
- mock_model .model .parameters .return_value = mock_params
34
- mock_load_checkpoint .return_value = mock_model
35
-
36
- # Create Ume with frozen parameters
37
- ume = Ume ("dummy_checkpoint.ckpt" , freeze = True )
38
-
39
- # Verify that load_from_checkpoint was called
40
- mock_load_checkpoint .assert_called_once_with ("dummy_checkpoint.ckpt" )
41
-
42
- # Verify that parameters were accessed
43
- mock_model .model .parameters .assert_called ()
44
-
45
- # Verify freeze attribute is True
46
- assert ume .freeze is True
47
-
48
- @patch ("lobster.model._ume.FlexBERT.load_from_checkpoint" )
49
- def test_unfrozen_parameters (self , mock_load_checkpoint ):
50
- """Test that parameters are not frozen when freeze=False"""
51
- # Create mock model
52
- mock_model = MagicMock ()
53
- mock_params = [torch .nn .Parameter (torch .randn (10 , 10 ))]
54
- mock_model .model .parameters .return_value = mock_params
55
- mock_load_checkpoint .return_value = mock_model
56
-
57
- # Create Ume without freezing parameters
58
- ume = Ume ("dummy_checkpoint.ckpt" , freeze = False )
59
-
60
- # Verify freeze attribute is False
61
- assert ume .freeze is False
62
-
63
- # Verify that parameters were not frozen
64
- mock_model .model .parameters .assert_not_called ()
65
-
66
- @patch ("lobster.model._ume.FlexBERT.load_from_checkpoint" )
67
27
@patch ("lobster.model._ume.UmeSmilesTokenizerFast" )
68
28
@patch ("lobster.model._ume.UmeAminoAcidTokenizerFast" )
69
29
@patch ("lobster.model._ume.UmeNucleotideTokenizerFast" )
70
30
@patch ("lobster.model._ume.UmeLatentGenerator3DCoordTokenizerFast" )
71
- def test_tokenizer_initialization (self , mock_coord , mock_nucleotide , mock_amino , mock_smiles , mock_load_checkpoint ):
31
+ def test_tokenizer_initialization (self , mock_coord , mock_nucleotide , mock_amino , mock_smiles ):
72
32
"""Test that tokenizers are initialized during __init__"""
73
- # Set up model mock
74
- mock_model = MagicMock ()
75
- mock_load_checkpoint .return_value = mock_model
76
-
77
33
# Setup tokenizer mocks
78
34
mock_smiles_instance = MagicMock ()
79
35
mock_amino_instance = MagicMock ()
@@ -86,74 +42,44 @@ def test_tokenizer_initialization(self, mock_coord, mock_nucleotide, mock_amino,
86
42
mock_coord .return_value = mock_coord_instance
87
43
88
44
# Create Ume instance
89
- ume = Ume ("dummy_checkpoint.ckpt" )
45
+ ume = Ume ()
90
46
91
- # Verify each tokenizer was instantiated exactly once
92
47
mock_smiles .assert_called_once ()
93
48
mock_amino .assert_called_once ()
94
49
mock_nucleotide .assert_called_once ()
95
50
mock_coord .assert_called_once ()
96
51
97
- # Verify tokenizers were stored in the dictionary
98
52
assert ume .tokenizers [Modality .SMILES ] == mock_smiles_instance
99
53
assert ume .tokenizers [Modality .AMINO_ACID ] == mock_amino_instance
100
54
assert ume .tokenizers [Modality .NUCLEOTIDE ] == mock_nucleotide_instance
101
55
assert ume .tokenizers [Modality .COORDINATES_3D ] == mock_coord_instance
102
56
103
57
@patch ("lobster.model._ume.FlexBERT.load_from_checkpoint" )
104
- @patch ("lobster.model._ume.UmeSmilesTokenizerFast" )
105
- @patch ("lobster.model._ume.UmeAminoAcidTokenizerFast" )
106
- @patch ("lobster.model._ume.UmeNucleotideTokenizerFast" )
107
- @patch ("lobster.model._ume.UmeLatentGenerator3DCoordTokenizerFast" )
108
- def test_get_tokenizer (self , mock_coord , mock_nucleotide , mock_amino , mock_smiles , mock_load_checkpoint ):
58
+ def test_get_tokenizer (self , mock_load_checkpoint ):
109
59
"""Test getting tokenizers for different modalities"""
110
- # Set up model mock
111
- mock_model = MagicMock ()
112
- mock_load_checkpoint .return_value = mock_model
60
+ ume = Ume ()
113
61
114
- # Setup tokenizer mocks
115
- mock_smiles_instance = MagicMock ()
116
- mock_amino_instance = MagicMock ()
117
- mock_nucleotide_instance = MagicMock ()
118
- mock_coord_instance = MagicMock ()
62
+ mock_tokenizers = {}
63
+ for modality in Modality :
64
+ mock_tokenizers [modality ] = MagicMock ()
119
65
120
- mock_smiles .return_value = mock_smiles_instance
121
- mock_amino .return_value = mock_amino_instance
122
- mock_nucleotide .return_value = mock_nucleotide_instance
123
- mock_coord .return_value = mock_coord_instance
66
+ ume .tokenizers = mock_tokenizers
124
67
125
- # Create Ume instance
126
- ume = Ume ("dummy_checkpoint.ckpt" )
127
-
128
- # Test each modality
129
68
modality_map = {
130
- "SMILES" : mock_smiles_instance ,
131
- "amino_acid" : mock_amino_instance ,
132
- "nucleotide" : mock_nucleotide_instance ,
133
- "3d_coordinates" : mock_coord_instance ,
69
+ "SMILES" : Modality . SMILES ,
70
+ "amino_acid" : Modality . AMINO_ACID ,
71
+ "nucleotide" : Modality . NUCLEOTIDE ,
72
+ "3d_coordinates" : Modality . COORDINATES_3D ,
134
73
}
135
74
136
- for modality , mock_instance in modality_map .items ():
137
- # Get tokenizer - this should now return the pre-instantiated tokenizer
138
- tokenizer = ume .get_tokenizer (["test" ], modality )
75
+ for modality_str , modality_enum in modality_map .items ():
76
+ tokenizer = ume .get_tokenizer (modality_str )
139
77
140
- # Verify the returned tokenizer is our mock instance
141
- assert tokenizer == mock_instance
142
-
143
- # Verify that no new tokenizer is instantiated (count should remain at 1)
144
- if modality == "SMILES" :
145
- assert mock_smiles .call_count == 1
146
- elif modality == "amino_acid" :
147
- assert mock_amino .call_count == 1
148
- elif modality == "nucleotide" :
149
- assert mock_nucleotide .call_count == 1
150
- elif modality == "3d_coordinates" :
151
- assert mock_coord .call_count == 1
78
+ assert tokenizer == mock_tokenizers [modality_enum ]
152
79
153
80
@patch ("lobster.model._ume.FlexBERT.load_from_checkpoint" )
154
81
def test_get_embeddings_basic (self , mock_load_checkpoint , smiles_examples , protein_examples , dna_examples ):
155
82
"""Test basic embedding functionality for all modalities"""
156
- # Mock model with controlled output
157
83
mock_model = MagicMock ()
158
84
mock_model .max_length = 512
159
85
mock_model .device = torch .device ("cpu" )
@@ -169,7 +95,7 @@ def mock_tokens_to_latents(**kwargs):
169
95
mock_load_checkpoint .return_value = mock_model
170
96
171
97
# Create Ume instance
172
- ume = Ume ("dummy_checkpoint.ckpt" )
98
+ ume = Ume . load_from_checkpoint ("dummy_checkpoint.ckpt" )
173
99
174
100
# Test for each modality
175
101
modalities = ["SMILES" , "amino_acid" , "nucleotide" ]
@@ -193,6 +119,15 @@ def mock_tokens_to_latents(**kwargs):
193
119
embeddings = ume .get_embeddings (test_inputs [modality ], modality )
194
120
assert embeddings .shape == (batch_size , 768 )
195
121
122
+ # Verify tokenizer was called with the correct inputs
123
+ mock_tokenizer .assert_called_with (
124
+ test_inputs [modality ],
125
+ return_tensors = "pt" ,
126
+ padding = "max_length" ,
127
+ truncation = True ,
128
+ max_length = mock_model .max_length ,
129
+ )
130
+
196
131
# Test token-level embeddings
197
132
token_embeddings = ume .get_embeddings (test_inputs [modality ], modality , aggregate = False )
198
133
assert token_embeddings .shape == (batch_size , seq_len , 768 )
0 commit comments