Skip to content

Commit

Permalink
Split tests
Browse files Browse the repository at this point in the history
  • Loading branch information
liurupeng committed Feb 4, 2025
1 parent b71d41b commit a70946d
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion MaxText/tests/maxengine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,31 @@ def test_basic_prefill(self):
self.assertNotEqual(prefill_result["tokens"], jnp.array([0]))
self.assertTrue(jnp.array_equal(first_token.data.size, 3))

def test_sampling(self):
def test_greedy_sampling(self):
selected_logits = jnp.array([[[-3.46875, -4.90625, 4.125, 4.0625, -2.6875, 1.1953125, 8.2, 0.345, 0.0]]])
rng = jax.random.PRNGKey(0)
# all results should be the same in this configuration, this is only for functionality test
greedy_logit = inference_utils.sampling(selected_logits, rng, "greedy")
self.assertEqual(selected_logits[0, 0, greedy_logit[0, 0]], 8.2)

def test_weighted_sampling(self):
selected_logits = jnp.array([[[-3.46875, -4.90625, 4.125, 4.0625, -2.6875, 1.1953125, 8.2, 0.345, 0.0]]])
rng = jax.random.PRNGKey(0)

weighted_logits = inference_utils.sampling(selected_logits, rng, "weighted", temperature=2)
self.assertEqual(selected_logits[0, 0, weighted_logits[0, 0]], 8.2)

def test_nucleus_sampling(self):
selected_logits = jnp.array([[[-3.46875, -4.90625, 4.125, 4.0625, -2.6875, 1.1953125, 8.2, 0.345, 0.0]]])
rng = jax.random.PRNGKey(0)

nucleus_logits = inference_utils.sampling(selected_logits, rng, "nucleus", nucleus_topp=3)
self.assertEqual(selected_logits[0, 0, nucleus_logits[0, 0]], 8.2)

def test_topk_sampling(self):
selected_logits = jnp.array([[[-3.46875, -4.90625, 4.125, 4.0625, -2.6875, 1.1953125, 8.2, 0.345, 0.0]]])
rng = jax.random.PRNGKey(0)

topk_logits = inference_utils.sampling(selected_logits, rng, "topk", topk=3)
self.assertEqual(selected_logits[0, 0, topk_logits[0, 0]], 8.2)

Expand Down

1 comment on commit a70946d

@xy12181
Copy link
Collaborator

@xy12181 xy12181 commented on a70946d Feb 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Please sign in to comment.