Skip to content

Commit

Permalink
Create LWEtoJaxiteWord Conversion.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 727560107
  • Loading branch information
JianmingTONG authored and copybara-github committed Feb 16, 2025
1 parent 43e7ae0 commit aa5aaff
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 0 deletions.
16 changes: 16 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,22 @@ tpu_test(
],
)

tpu_test(
name = "add_test",
size = "large",
timeout = "eternal",
srcs = ["jaxite_word/add_test.py"],
shard_count = 3,
deps = [
":jaxite",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

cpu_gpu_tpu_test(
name = "decomposition_test",
size = "small",
Expand Down
58 changes: 58 additions & 0 deletions jaxite_word/add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""TPU kernels for Evaluation of the CKKS algorithm."""

import jax
import jax.numpy as jnp


def jax_add(value_a: jax.Array, value_b: jax.Array, modulus_list: jax.Array):
"""This function processes all degree of the two input polynomials in parallel using multi-trheading.
Assuming the input data type is jax array.
Args:
value_a: the first operand of the addition.
value_b: the second operand of the addition.
modulus_list: the list of moduli for each degree.
Returns:
The result of the addition.
"""
num_elements, _, degree = value_a.shape
modulus_broadcast = jnp.tile(
modulus_list[None, :, None], (num_elements, 1, degree)
)
result = value_a + value_b
return jnp.where(
result > modulus_broadcast, result - modulus_broadcast, result
) # jnp.mod(value_a + value_b, modulus_broadcast)


def vmap_add(
value_a: jax.Array, value_b: jax.Array, modulus_list: jax.Array
):
"""This function processes all degree of the two input polynomials in SIMD using jax.vmap.
Assuming the input data type is jax array.
Args:
value_a: the first operand of the addition.
value_b: the second operand of the addition.
modulus_list: the list of moduli for each degree.
Returns:
The result of the addition.
"""
num_elements, num_towers, degree = value_a.shape
#ToDo: expand api into four dimensions array with num_ciphertexts, num_towers, degree, num_elements
modulus_broadcast = jnp.tile(
modulus_list[None, :, None], (num_elements, 1, degree)
)

def chunk_wise_add(value_a, value_b):
return value_a + value_b

def chunk_wise_subtract(value_a, value_b):
return jnp.where(value_a > value_b, value_a - value_b, value_a)

result = jax.vmap(chunk_wise_add)(value_a, value_b)
return jax.vmap(chunk_wise_subtract)(result, modulus_broadcast)
108 changes: 108 additions & 0 deletions jaxite_word/add_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""A module for operations on test CKKS evaluation kernels including.
- ModAdd
- HEAdd
- HESub
- HEMul
- HERotate
"""

from concurrent import futures
from typing import Any, Callable

import jax
import jax.numpy as jnp
from jaxite.jaxite_word import add

from absl.testing import absltest
from absl.testing import parameterized


ProcessPoolExecutor = futures.ProcessPoolExecutor

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_traceback_filtering", "off")


class CKKSEvalKernelsTest(parameterized.TestCase):
"""A base class for running bootstrap tests."""

def __init__(self, *args, **kwargs):
super(CKKSEvalKernelsTest, self).__init__(*args, **kwargs)
self.debug = False # dsiable it from printing the test input values
self.modulus_element_0_tower_0 = 1152921504606748673
self.modulus_element_0_tower_1 = 268664833
self.modulus_element_0_tower_2 = 557057
self.random_key = jax.random.key(0)

def random(self, shape, modulus_list, dtype=jnp.int32):
assert len(modulus_list) == shape[1]

return jnp.concatenate(
[
jax.random.randint(
self.random_key,
shape=(shape[0], 1, shape[2]),
minval=0,
maxval=bound,
dtype=dtype,
)
for bound in modulus_list
],
axis=1,
)

@parameterized.named_parameters(
dict(
testcase_name="jax_add",
test_target=add.jax_add,
modulus_list=[1152921504606748673, 268664833, 557057],
shape=(2, 3, 16384), # number of elements, number of towers, degree
),
dict(
testcase_name="vmap_add",
test_target=add.vmap_add,
modulus_list=[1152921504606748673, 268664833, 557057],
shape=(2, 3, 16384), # number of elements, number of towers, degree
),
)
def test_add(
self,
test_target: Callable[[Any, Any, Any], Any],
modulus_list=jax.Array,
shape=tuple[int, int, int],
):
"""This function tests the add function using Python native integer data type with arbitrary precision.
This test finishes in 1.05 second.
Args:
test_target: The function to test.
modulus_list: A jax.Array of integers.
shape: A tuple of integers representing the shape of the input arrays.
"""
# Only test a single element to save comparison time,
# Correctness-wise, it's sufficient for add.
value_a = self.random(shape, modulus_list, dtype=jnp.uint64)
value_b = self.random(shape, modulus_list, dtype=jnp.uint64)
assert value_a.shape == shape
assert value_b.shape == shape
result_a_plus_b = []
for element_id in range(value_a.shape[0]):
result_a_plus_b_one_element = []
for tower_id in range(value_a.shape[1]):
add_res = int(value_b[element_id, tower_id, 0]) + int(
value_a[element_id, tower_id, 0]
)
if add_res > modulus_list[tower_id]:
add_res = add_res - modulus_list[tower_id]
result_a_plus_b_one_element.append(add_res)
result_a_plus_b.append(result_a_plus_b_one_element)
result_a_plus_b = jnp.array(result_a_plus_b, dtype=jnp.uint64)
modulus_list = jnp.array(modulus_list, dtype=jnp.uint64)
result = test_target(value_a, value_b, modulus_list)
self.assertEqual(result[:, :, 0].all(), result_a_plus_b.all())


if __name__ == "__main__":
absltest.main()

0 comments on commit aa5aaff

Please sign in to comment.