diff --git a/MaxText/tests/maxengine_test.py b/MaxText/tests/maxengine_test.py index d2de3e031..2ac5ba430 100644 --- a/MaxText/tests/maxengine_test.py +++ b/MaxText/tests/maxengine_test.py @@ -122,19 +122,31 @@ def test_basic_prefill(self): self.assertNotEqual(prefill_result["tokens"], jnp.array([0])) self.assertTrue(jnp.array_equal(first_token.data.size, 3)) - def test_sampling(self): + def test_greedy_sampling(self): selected_logits = jnp.array([[[-3.46875, -4.90625, 4.125, 4.0625, -2.6875, 1.1953125, 8.2, 0.345, 0.0]]]) rng = jax.random.PRNGKey(0) # all results should be the same in this configuration, this is only for functionality test greedy_logit = inference_utils.sampling(selected_logits, rng, "greedy") self.assertEqual(selected_logits[0, 0, greedy_logit[0, 0]], 8.2) + def test_weighted_sampling(self): + selected_logits = jnp.array([[[-3.46875, -4.90625, 4.125, 4.0625, -2.6875, 1.1953125, 8.2, 0.345, 0.0]]]) + rng = jax.random.PRNGKey(0) + weighted_logits = inference_utils.sampling(selected_logits, rng, "weighted", temperature=2) self.assertEqual(selected_logits[0, 0, weighted_logits[0, 0]], 8.2) + def test_nucleus_sampling(self): + selected_logits = jnp.array([[[-3.46875, -4.90625, 4.125, 4.0625, -2.6875, 1.1953125, 8.2, 0.345, 0.0]]]) + rng = jax.random.PRNGKey(0) + nucleus_logits = inference_utils.sampling(selected_logits, rng, "nucleus", nucleus_topp=3) self.assertEqual(selected_logits[0, 0, nucleus_logits[0, 0]], 8.2) + def test_topk_sampling(self): + selected_logits = jnp.array([[[-3.46875, -4.90625, 4.125, 4.0625, -2.6875, 1.1953125, 8.2, 0.345, 0.0]]]) + rng = jax.random.PRNGKey(0) + topk_logits = inference_utils.sampling(selected_logits, rng, "topk", topk=3) self.assertEqual(selected_logits[0, 0, topk_logits[0, 0]], 8.2)