Skip to content

Commit

Permalink
Use dedicated helper for scaling a matrix by x^n - 1
Browse files Browse the repository at this point in the history
I plan to follow this up with a custom pallas kernel, since the results show
less savings than I would have expected.

Also adds a jax helper to determine the TPU version

PiperOrigin-RevId: 604783729
  • Loading branch information
j2kun authored and copybara-github committed Feb 6, 2024
1 parent e2ece1e commit 26247cf
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 11 deletions.
1 change: 1 addition & 0 deletions .bazelversion
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
6.4.0
2 changes: 1 addition & 1 deletion BUILD
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# An FHE cryptosystem built in JAX

load("@rules_python//python:defs.bzl", "py_library")
load("@jaxite//bazel:test_oss.bzl", "cpu_gpu_tpu_test", "gpu_tpu_test", "tpu_test")
load("@rules_python//python:defs.bzl", "py_test")
load("@jaxite//bazel:test_oss.bzl", "cpu_gpu_tpu_test", "gpu_tpu_test", "tpu_test")
load("@rules_license//rules:license.bzl", "license")

package(
Expand Down
19 changes: 9 additions & 10 deletions jaxite/jaxite_lib/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,10 +468,6 @@ def jit_blind_rotate(
log_coefficient_modulus,
).astype(jnp.uint32)

# rot_polynomial is an rlwe ciphertext, so the second dimension determines the
# degree of the polynomial
degree = rot_polynomial.shape[1]

# Using the improved blind rotate from Bourse-Minelli-Minihold-Paillier
# (BMMP17: https://eprint.iacr.org/2017/1114), a trick uses a larger
# bootstrapping key to reduce the number of external products required by 1/2.
Expand All @@ -481,13 +477,16 @@ def one_bmmp_factor(j):
power1 = coefficient_index[2 * j] + coefficient_index[2 * j + 1]
power2 = coefficient_index[2 * j]
power3 = coefficient_index[2 * j + 1]
poly_term1 = matrix_utils.x_power_n_minus_1(power1, poly_mod_deg=degree)
poly_term2 = matrix_utils.x_power_n_minus_1(power2, poly_mod_deg=degree)
poly_term3 = matrix_utils.x_power_n_minus_1(power3, poly_mod_deg=degree)
return (
matrix_utils.poly_mul_const_matrix(poly_term1, bsk[3 * j])
+ matrix_utils.poly_mul_const_matrix(poly_term2, bsk[3 * j + 1])
+ matrix_utils.poly_mul_const_matrix(poly_term3, bsk[3 * j + 2])
matrix_utils.scale_by_x_power_n_minus_1(
power1, bsk[3 * j], log_modulus=log_coefficient_modulus
)
+ matrix_utils.scale_by_x_power_n_minus_1(
power2, bsk[3 * j + 1], log_modulus=log_coefficient_modulus
)
+ matrix_utils.scale_by_x_power_n_minus_1(
power3, bsk[3 * j + 2], log_modulus=log_coefficient_modulus
)
).astype(jnp.uint32)

bmmp_factors = jax.vmap(one_bmmp_factor, in_axes=(0,), out_axes=0)(
Expand Down
11 changes: 11 additions & 0 deletions jaxite/jaxite_lib/jax_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,14 @@ def f2(*args):
return out

return g


def get_tpu_version() -> int:
"""Returns the numeric version of the TPU, or -1 if not on TPU."""
kind = jax.devices()[0].device_kind
if 'TPU' not in kind:
return -1
if kind.endswith(' lite'):
kind = kind[: -len(' lite')]
assert kind[:-1] == 'TPU v', kind
return int(kind[-1])
30 changes: 30 additions & 0 deletions jaxite/jaxite_lib/matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,39 @@ def monomial_mul(
poly_mul_const_list, in_axes=(None, 0), out_axes=0
)

# Scale the elements of a matrix by a monomial.
monomial_mul_matrix = jax.vmap(
monomial_mul_list, in_axes=(0, None, None), out_axes=0
)


def poly_dot_product(
poly_vec1: jnp.ndarray, poly_vec2: jnp.ndarray
) -> jnp.ndarray:
"""Compute a dot product of two vectors of polynomials."""
return jnp.sum(poly_mul_list(poly_vec1, poly_vec2), axis=0).astype(jnp.uint32)


@functools.partial(jax.jit, static_argnames="log_modulus")
def scale_by_x_power_n_minus_1(
power: jnp.int32, matrix: jnp.ndarray, log_modulus: int
) -> jnp.ndarray:
"""An optimized poly mul for scaling a matrix of polynomials by x^n - 1.
Args:
power: The exponent n of x^n - 1 to scale each matrix entry by
matrix: The matrix to be scaled.
log_modulus: the base-2 logarithm of the polynomial coefficient modulus.
Returns:
An `jnp.ndarray` of the same shape as `matrix`, containing the
entries of `matrix` each scaled by x^n - 1.
"""
x_power_n_part = monomial_mul_matrix(matrix, power, log_modulus)
minus_one_part = -matrix
output = x_power_n_part + minus_one_part

if 0 < log_modulus < 32:
output = jnp.mod(output, jnp.uint32(2) ** log_modulus)

return output
11 changes: 11 additions & 0 deletions jaxite/jaxite_lib/matrix_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,17 @@ def test_i32_as_u8_matmul(self, lhs, rhs):
)
np.testing.assert_array_equal(expected, actual)

@hypothesis.given(strategies.integers(min_value=0, max_value=10), vectors(16))
@hypothesis.settings(deadline=None)
def test_scale_by_x_power_n_minus_1(self, power, poly):
matrix = jnp.tile(jnp.array(list(poly)), reps=jnp.array([8, 8, 1]))
poly_term = matrix_utils.x_power_n_minus_1(power, poly_mod_deg=16)
expected = matrix_utils.poly_mul_const_matrix(poly_term, matrix)
actual = matrix_utils.scale_by_x_power_n_minus_1(
power, matrix, log_modulus=32
)
np.testing.assert_array_equal(expected, actual)


if __name__ == '__main__':
absltest.main()

0 comments on commit 26247cf

Please sign in to comment.