Skip to content

Commit

Permalink
Finish BAT + offline compilation of BSK.
Browse files Browse the repository at this point in the history
This reduces the jax_external_product latency from 19.746 us in the original BMMP-based bootstrapping to 17.535 us, because BAT is more efficient.

However, the jax_bootstrapping latency increases from 9.428 ms (cl/724481568) to 24.253 ms, because no global scheduling optimization has been performed yet. With further optimization, the latency should drop from 9.428 ms to at most 17.535/19.746 * 9.428 ms (about 8.37 ms).

PiperOrigin-RevId: 727551971
  • Loading branch information
JianmingTONG authored and copybara-github committed Feb 16, 2025
1 parent a848652 commit 43e7ae0
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 30 deletions.
2 changes: 1 addition & 1 deletion jaxite/jaxite_lib/blind_rotate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def test_cmux_noise_free(self, control):
rlwe_ct_2 = rlwe.encrypt(rlwe_pt_2, self.rlwe_key, prg=self.noise_free_rng)

cmux_output = bootstrap.cmux(
control_ct, rlwe_ct_1, rlwe_ct_2, self.decomposition_params
control_ct, rlwe_ct_1, rlwe_ct_2, self.decomposition_params,
)
decrypted = rlwe.decrypt(
cmux_output, self.rlwe_key, encoding_params=self.encoding
Expand Down
83 changes: 68 additions & 15 deletions jaxite/jaxite_lib/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from jaxite.jaxite_lib import rlwe
from jaxite.jaxite_lib import types

GEN_BSK_NUM_BATCHES = 15
GEN_BSK_NUM_BATCHES = 20
NON_DIVISIBLE_BATCH_SIZE_WARNING = (
"Expected lwe_sk_dim to be a multiple of %s, but was %s. "
"This is OK for tests with small LWE secret key sizes, but may cause "
Expand All @@ -32,6 +32,7 @@ class BootstrappingKey:

encrypted_lwe_sk_bits: jnp.ndarray
use_bmmp: bool = True
use_bat: bool = False


def gen_bootstrapping_key(
Expand All @@ -40,6 +41,7 @@ def gen_bootstrapping_key(
decomposition_params: decomposition.DecompositionParameters,
prg: random_source.RandomSource,
use_bmmp: bool = True,
use_bat: bool = False,
) -> BootstrappingKey:
"""Generate a bootstrapping key for the given LWE secret key.
Expand Down Expand Up @@ -183,13 +185,43 @@ def process_one_batch(i):
rlwe_sk.modulus_degree,
))

def bat_offline_compile_cggi(mat_a):
  """Expand every 32-bit entry of `mat_a` into a 4x4 uint8 block.

  Each entry is reinterpreted as its four bytes and laid out as a column.
  That column and its cyclic shifts by 1, 2 and 3 positions form the four
  columns of a 4x4 block, and everything above the block's main diagonal is
  then zeroed.

  Args:
    mat_a: The input matrix, with 32-bit elements.

  Returns:
    A uint8 array of shape (*mat_a.shape, 4, 4).
  """
  # Reinterpret each 32-bit element as 4 bytes, stored as a (4, 1) column.
  byte_column = jax.lax.bitcast_convert_type(mat_a, new_dtype=jnp.uint8)
  byte_column = byte_column.reshape(*mat_a.shape, 4, 1)

  # Column k of the block is the byte column cyclically shifted by k.
  shifted_columns = [byte_column]
  shifted_columns.extend(
      jnp.roll(byte_column, shift, axis=-2) for shift in (1, 2, 3)
  )
  block = jnp.concatenate(shifted_columns, axis=-1)

  # Keep only the lower triangle of the trailing 4x4 block.
  return jnp.tril(block)

if use_bmmp:
use_bat = False
else:
if use_bat:
encrypted_lwe_sk_bits = bat_offline_compile_cggi(encrypted_lwe_sk_bits)

return BootstrappingKey(
encrypted_lwe_sk_bits=encrypted_lwe_sk_bits, use_bmmp=use_bmmp
encrypted_lwe_sk_bits=encrypted_lwe_sk_bits,
use_bmmp=use_bmmp,
use_bat=use_bat,
)


@jax.named_call
@functools.partial(jax.jit, static_argnums=(4, 5, 6, 7))
@functools.partial(jax.jit, static_argnums=(4, 5, 6, 7, 8))
def jit_bootstrap(
ciphertext: types.LweCiphertext,
test_poly_ciphertext_message: jnp.ndarray,
Expand All @@ -199,6 +231,7 @@ def jit_bootstrap(
bs_decomposition_params: decomposition.DecompositionParameters,
scheme_params: parameters.SchemeParameters,
bsk_use_bmmp: bool = True,
bsk_use_bat: bool = False,
) -> types.LweCiphertext:
"""Apply functional bootstrap to reduce noise in the input ciphertext.
Expand Down Expand Up @@ -235,9 +268,10 @@ def jit_bootstrap(
test_poly_ciphertext_message,
approx_ciphertext,
bsk_encrypted_lwe_sk_bits,
bsk_use_bmmp,
test_poly_log_coefficient_modulus,
bs_decomposition_params,
bsk_use_bmmp,
bsk_use_bat,
)

extracted = jit_sample_extract(rotated, mod_degree)
Expand Down Expand Up @@ -306,9 +340,10 @@ def bootstrap(
test_poly_ciphertext.message,
approx_ciphertext,
bsk.encrypted_lwe_sk_bits,
bsk.use_bmmp,
test_poly_log_coefficient_modulus,
bs_decomposition_params,
bsk.use_bmmp,
bsk.use_bat,
)
if callback:
callback("rotated", rotated, **kwargs)
Expand Down Expand Up @@ -345,26 +380,33 @@ def external_product(
)


@functools.partial(jax.jit, static_argnames="decomposition_params")
@functools.partial(jax.jit, static_argnums=(2, 3))
def jit_external_product(
rgsw_ct: jnp.ndarray,
rlwe_ct: jnp.ndarray,
decomposition_params: decomposition.DecompositionParameters,
use_bat: bool = False,
) -> rlwe.RlweCiphertext:
"""Compute the external product of an RSGW and RLWE ciphertext."""
decomposed_rlwe = decomposition.decompose_rlwe_ciphertext(
rlwe_ct, decomposition_params
)
return polymul_kernel.negacyclic_vector_matrix_polymul(
decomposed_rlwe, rgsw_ct
)
if use_bat:
return polymul_kernel.negacyclic_vector_matrix_polymul_bat(
decomposed_rlwe, rgsw_ct
)
else:
return polymul_kernel.negacyclic_vector_matrix_polymul(
decomposed_rlwe, rgsw_ct
)


def cmux(
control: rgsw.RgswCiphertext,
eq_zero: rlwe.RlweCiphertext,
neq_zero: rlwe.RlweCiphertext,
decomposition_params: decomposition.DecompositionParameters,
use_bat: bool = False,
) -> rlwe.RlweCiphertext:
"""Compute CMUX: controlled multiplexer.
Expand All @@ -373,6 +415,7 @@ def cmux(
eq_zero: RLWE ciphertext selected if control=0
neq_zero: RLWE ciphertext selected if control=1
decomposition_params: decomposition parameters for the external product
use_bat: whether to use the batched implementation of the external product
Returns:
RlweCiphertext: selected RLWE ciphertext
Expand All @@ -397,7 +440,11 @@ def cmux(
)
modulus_degree = eq_zero.modulus_degree
output = jit_cmux(
control.message, eq_zero.message, neq_zero.message, decomposition_params
control.message,
eq_zero.message,
neq_zero.message,
decomposition_params,
use_bat,
)
return rlwe.RlweCiphertext(
log_coefficient_modulus=eq_zero.log_coefficient_modulus,
Expand All @@ -407,12 +454,13 @@ def cmux(


@jax.named_call
@functools.partial(jax.jit, static_argnames="decomposition_params")
@functools.partial(jax.jit, static_argnums=(3, 4))
def jit_cmux(
control: jnp.ndarray,
eq_zero: jnp.ndarray,
neq_zero: jnp.ndarray,
decomposition_params: decomposition.DecompositionParameters,
use_bat: bool,
) -> rlwe.RlweCiphertext:
"""A jitted cmux."""
return (
Expand All @@ -421,8 +469,9 @@ def jit_cmux(
rgsw_ct=control,
rlwe_ct=neq_zero - eq_zero,
decomposition_params=decomposition_params,
use_bat=use_bat,
)
)
).astype(jnp.uint32)


def blind_rotate(
Expand Down Expand Up @@ -453,9 +502,10 @@ def blind_rotate(
rot_polynomial.message,
coefficient_index,
bsk.encrypted_lwe_sk_bits,
bsk.use_bmmp,
rot_polynomial.log_coefficient_modulus,
decomposition_params,
bsk.use_bmmp,
bsk.use_bat,
)
return rlwe.RlweCiphertext(
log_coefficient_modulus=rot_polynomial.log_coefficient_modulus,
Expand All @@ -465,14 +515,15 @@ def blind_rotate(


@jax.named_call
@functools.partial(jax.jit, static_argnums=(3, 4, 5))
@functools.partial(jax.jit, static_argnums=(3, 4, 5, 6))
def jit_blind_rotate(
rot_polynomial: jnp.ndarray,
coefficient_index: types.LweCiphertext,
bsk: jnp.ndarray,
use_bmmp: bool,
log_coefficient_modulus: int,
decomposition_params: decomposition.DecompositionParameters,
use_bmmp: bool,
use_bat: bool,
) -> rlwe.RlweCiphertext:
"""Rotate an encrypted polynomial `coefficient_index` times."""
# Calculate c' = X^{-b^tilde} * RLWE_s'(v) (for each entry in RLWE_s'(v))
Expand Down Expand Up @@ -513,6 +564,7 @@ def one_external_product(j, c_prime_accum):
rgsw_ct=bmmp_factor,
rlwe_ct=c_prime_accum,
decomposition_params=decomposition_params,
use_bat=False,
)
# c'_mul = c' * X^{a_j^tilde} (for each entry in c')
else:
Expand All @@ -530,6 +582,7 @@ def one_external_product(j, c_prime_accum):
eq_zero=c_prime_accum,
neq_zero=c_prime_mul,
decomposition_params=decomposition_params,
use_bat=use_bat,
)

return jax.lax.fori_loop(0, num_loop_terms, one_external_product, c_prime)
Expand Down
43 changes: 37 additions & 6 deletions jaxite/jaxite_lib/bootstrap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


@parameterized.product(
use_bmmp=[True, False]
use_bmmp=[True, False], use_bat=[True, False],
)
class BootstrapBaseTest(parameterized.TestCase):
"""A base class for running bootstrap tests."""
Expand All @@ -44,7 +44,13 @@ def run_bootstrap_test(
rlwe_rng: random_source.RandomSource,
skip_assert: bool = False,
use_bmmp: bool = True,
use_bat: bool = False,
):
if use_bmmp and use_bat:
self.skipTest(
"bmmp cannot be used with BAT, because BAT requires data to be known"
" at compile time"
)
cleartext = 2**message_bits - 1
test_utils.assert_safe_modulus_switch(
mod_degree, message_bits, lwe_dimension
Expand Down Expand Up @@ -82,6 +88,7 @@ def run_bootstrap_test(
decomposition_params=test_utils.BSK_DECOMP_PARAMS_128_BIT_SECURITY,
prg=rlwe_rng,
use_bmmp=use_bmmp,
use_bat=use_bat,
)
ksk = key_switch.gen_key(
in_key=rlwe.flatten_key(rlwe_key),
Expand Down Expand Up @@ -126,6 +133,7 @@ def run_bootstrap_test(
test_utils.BSK_DECOMP_PARAMS_128_BIT_SECURITY,
scheme_parameters,
bsk.use_bmmp,
bsk.use_bat,
)

self.assertEqual(len(jit_bootstrapped), len(bootstrapped))
Expand Down Expand Up @@ -153,11 +161,20 @@ def run_bootstrap_test(


@parameterized.product(
log_ai_bound=_LOG_AI_BOUNDS, seed=_SEEDS, use_bmmp=[True, False]
log_ai_bound=_LOG_AI_BOUNDS,
seed=_SEEDS,
use_bmmp=[True, False],
use_bat=[True, False],
)
class BootstrapTest(BootstrapBaseTest):

def test_3_bit_bootstrap(self, log_ai_bound, seed, use_bmmp):
def test_3_bit_bootstrap(self, log_ai_bound, seed, use_bmmp, use_bat):
if use_bmmp and use_bat:
self.skipTest(
"bmmp cannot be used with BAT, because BAT requires data to be known"
" at compile time"
)

message_bits = 3
padding_bits = 1
lwe_dimension = 4
Expand All @@ -179,11 +196,19 @@ def test_3_bit_bootstrap(self, log_ai_bound, seed, use_bmmp):
padding_bits=padding_bits,
rlwe_rng=rng,
use_bmmp=use_bmmp,
use_bat=use_bat,
)

def test_3_bit_bootstrap_larger_lwe_dimension(
self, log_ai_bound: int, seed: int, use_bmmp: bool
self, log_ai_bound: int, seed: int, use_bmmp: bool, use_bat: bool
):
if use_bmmp and use_bat:
self.skipTest(
"bmmp cannot be used with BAT, because BAT requires data to be known"
" at compile time"
)
if not use_bmmp:
self.skipTest("BAT does not support larger LWE dimensions")
message_bits = 3
padding_bits = 1
lwe_dimension = 100
Expand All @@ -207,12 +232,17 @@ def test_3_bit_bootstrap_larger_lwe_dimension(
padding_bits=padding_bits,
rlwe_rng=rng,
use_bmmp=use_bmmp,
use_bat=use_bat,
)


def test_3_bit_bootstrap_prod_decomp_params(
self, log_ai_bound: int, seed: int, use_bmmp: bool
self, log_ai_bound: int, seed: int, use_bmmp: bool, use_bat: bool
):
if use_bmmp and use_bat:
self.skipTest(
"bmmp cannot be used with BAT, because BAT requires data to be known"
" at compile time"
)
message_bits = 3
padding_bits = 1
lwe_dimension = 30
Expand All @@ -234,6 +264,7 @@ def test_3_bit_bootstrap_prod_decomp_params(
padding_bits=padding_bits,
rlwe_rng=rng,
use_bmmp=use_bmmp,
use_bat=use_bat,
)


Expand Down
2 changes: 1 addition & 1 deletion jaxite/jaxite_lib/matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ def poly_mul(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
@jax.named_call
@functools.partial(jax.jit, static_argnames="log_modulus")
def monomial_mul(
poly: jnp.ndarray, degree: int, log_modulus: int
poly: jnp.ndarray, degree: jnp.uint32, log_modulus: jnp.uint32
) -> jnp.ndarray:
"""Computes `poly * X^degree mod (X^N + 1)` where N = len(poly).
Expand Down
Loading

0 comments on commit 43e7ae0

Please sign in to comment.