Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable BMMP as on/off switch "use_bmmp". #49

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 88 additions & 48 deletions jaxite/jaxite_lib/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ class BootstrappingKey:
"""An array with row j an RGSW encryption of bit j of an LWE secret key."""

encrypted_lwe_sk_bits: jnp.ndarray
use_bmmp: bool = True


def gen_bootstrapping_key(
lwe_sk: lwe.LweSecretKey,
rgsw_sk: rgsw.RgswSecretKey,
decomposition_params: decomposition.DecompositionParameters,
prg: random_source.RandomSource,
use_bmmp: bool = True,
) -> BootstrappingKey:
"""Generate a bootstrapping key for the given LWE secret key.

Expand All @@ -64,7 +66,11 @@ def gen_bootstrapping_key(
"requires an even lwe_sk dimension, but got "
f"{lwe_sk_dim}"
)
num_bsk_encryptions = lwe_sk_dim + lwe_sk_dim // 2

if use_bmmp:
num_bsk_encryptions = lwe_sk_dim + lwe_sk_dim // 2
else:
num_bsk_encryptions = lwe_sk_dim

ai_samples = prg.uniform(
shape=(
Expand All @@ -81,30 +87,36 @@ def gen_bootstrapping_key(
dtype=jnp.uint32,
)

# Using the improved blind rotate from Bourse-Minelli-Minihold-Paillier
# (BMMP17: https://eprint.iacr.org/2017/1114), a trick uses a larger
# bootstrapping key to reduce the number of external products required by 1/2.
# Rather than encrypt the secret key bits of the LWE key separately, we
# encrypt:
#
# BSK_{3i} = s_{2i} * s_{2i+1},
# BSK_{3i+1} = s_{2i} * (1 − s_{2i+1}),
# BSK_{3i+2} = (1 − s_{2i}) * s_{2i+1}
#
# which enables a bootstrap operation that involves 1/2 as many external
# products, though this causes the bootstrapping key to be 50% larger.
lwe_sk_data = lwe_sk.key_data.astype(jnp.uint32)
bsk_input_pairs = jnp.zeros(num_bsk_encryptions, dtype=jnp.uint32)

bsk_input_pairs = bsk_input_pairs.at[::3].set(
jnp.multiply(lwe_sk_data[::2], lwe_sk_data[1::2])
)
bsk_input_pairs = bsk_input_pairs.at[1::3].set(
jnp.multiply(lwe_sk_data[::2], 1 - lwe_sk_data[1::2])
)
bsk_input_pairs = bsk_input_pairs.at[2::3].set(
jnp.multiply(1 - lwe_sk_data[::2], lwe_sk_data[1::2])
)
# boostrapping key encrypted each of the security key
# BMMP groups multiple secrutiy keys together into a single BSK. (1.5N v.s N)
if use_bmmp:
# Using the improved blind rotate from Bourse-Minelli-Minihold-Paillier
# (BMMP17: https://eprint.iacr.org/2017/1114), a trick uses a larger
# bootstrapping key to reduce the number of external products required by
# 1/2. Rather than encrypt the secret key bits of the LWE key separately,
# we encrypt:
#
# BSK_{3i} = s_{2i} * s_{2i+1},
# BSK_{3i+1} = s_{2i} * (1 − s_{2i+1}),
# BSK_{3i+2} = (1 − s_{2i}) * s_{2i+1}
#
# which enables a bootstrap operation that involves 1/2 as many external
# products, though this causes the bootstrapping key to be 50% larger.
bsk_input_pairs = jnp.zeros(num_bsk_encryptions, dtype=jnp.uint32)
bsk_input_pairs = bsk_input_pairs.at[::3].set(
jnp.multiply(lwe_sk_data[::2], lwe_sk_data[1::2])
)
bsk_input_pairs = bsk_input_pairs.at[1::3].set(
jnp.multiply(lwe_sk_data[::2], 1 - lwe_sk_data[1::2])
)
bsk_input_pairs = bsk_input_pairs.at[2::3].set(
jnp.multiply(1 - lwe_sk_data[::2], lwe_sk_data[1::2])
)
else:
# BSK_{i} = s_{i},
bsk_input_pairs = lwe_sk_data

# Applying vmap to the entire jit_encrypt over all sk bits will exhaust the
# tensor core's memory with prod security parameters. It will eagerly allocate
Expand Down Expand Up @@ -171,11 +183,13 @@ def process_one_batch(i):
rlwe_sk.modulus_degree,
))

return BootstrappingKey(encrypted_lwe_sk_bits=encrypted_lwe_sk_bits)
return BootstrappingKey(
encrypted_lwe_sk_bits=encrypted_lwe_sk_bits, use_bmmp=use_bmmp
)


@jax.named_call
@functools.partial(jax.jit, static_argnums=(4, 5, 6))
@functools.partial(jax.jit, static_argnums=(4, 5, 6, 7))
def jit_bootstrap(
ciphertext: types.LweCiphertext,
test_poly_ciphertext_message: jnp.ndarray,
Expand All @@ -184,6 +198,7 @@ def jit_bootstrap(
ks_decomposition_params: decomposition.DecompositionParameters,
bs_decomposition_params: decomposition.DecompositionParameters,
scheme_params: parameters.SchemeParameters,
bsk_use_bmmp: bool = True,
) -> types.LweCiphertext:
"""Apply functional bootstrap to reduce noise in the input ciphertext.

Expand Down Expand Up @@ -220,6 +235,7 @@ def jit_bootstrap(
test_poly_ciphertext_message,
approx_ciphertext,
bsk_encrypted_lwe_sk_bits,
bsk_use_bmmp,
test_poly_log_coefficient_modulus,
bs_decomposition_params,
)
Expand Down Expand Up @@ -290,6 +306,7 @@ def bootstrap(
test_poly_ciphertext.message,
approx_ciphertext,
bsk.encrypted_lwe_sk_bits,
bsk.use_bmmp,
test_poly_log_coefficient_modulus,
bs_decomposition_params,
)
Expand Down Expand Up @@ -436,6 +453,7 @@ def blind_rotate(
rot_polynomial.message,
coefficient_index,
bsk.encrypted_lwe_sk_bits,
bsk.use_bmmp,
rot_polynomial.log_coefficient_modulus,
decomposition_params,
)
Expand All @@ -447,11 +465,12 @@ def blind_rotate(


@jax.named_call
@functools.partial(jax.jit, static_argnums=(3, 4))
@functools.partial(jax.jit, static_argnums=(3, 4, 5))
def jit_blind_rotate(
rot_polynomial: jnp.ndarray,
coefficient_index: types.LweCiphertext,
bsk: jnp.ndarray,
use_bmmp: bool,
log_coefficient_modulus: int,
decomposition_params: decomposition.DecompositionParameters,
) -> rlwe.RlweCiphertext:
Expand All @@ -467,30 +486,51 @@ def jit_blind_rotate(
# Using the improved blind rotate from Bourse-Minelli-Minihold-Paillier
# (BMMP17: https://eprint.iacr.org/2017/1114), a trick uses a larger
# bootstrapping key to reduce the number of external products required by 1/2.
num_loop_terms = (coefficient_index.shape[0] - 1) // 2
if use_bmmp:
num_loop_terms = (coefficient_index.shape[0] - 1) // 2
else:
num_loop_terms = coefficient_index.shape[0] - 1

def one_external_product(j, c_prime_accum):
# Doing this computation inside the external product loop improves cache
# locality, resulting in reduced data copying.
power1 = coefficient_index[2 * j] + coefficient_index[2 * j + 1]
power2 = coefficient_index[2 * j]
power3 = coefficient_index[2 * j + 1]
bmmp_factor = (
matrix_utils.scale_by_x_power_n_minus_1(
power1, bsk[3 * j], log_modulus=log_coefficient_modulus
)
+ matrix_utils.scale_by_x_power_n_minus_1(
power2, bsk[3 * j + 1], log_modulus=log_coefficient_modulus
)
+ matrix_utils.scale_by_x_power_n_minus_1(
power3, bsk[3 * j + 2], log_modulus=log_coefficient_modulus
)
).astype(jnp.uint32)
return c_prime_accum + jit_external_product(
rgsw_ct=bmmp_factor,
rlwe_ct=c_prime_accum,
decomposition_params=decomposition_params,
)
if use_bmmp:
# Doing this computation inside the external product loop improves cache
# locality, resulting in reduced data copying.
power1 = coefficient_index[2 * j] + coefficient_index[2 * j + 1]
power2 = coefficient_index[2 * j]
power3 = coefficient_index[2 * j + 1]
bmmp_factor = (
matrix_utils.scale_by_x_power_n_minus_1( # Rotation.
power1, bsk[3 * j], log_modulus=log_coefficient_modulus
)
+ matrix_utils.scale_by_x_power_n_minus_1(
power2, bsk[3 * j + 1], log_modulus=log_coefficient_modulus
)
+ matrix_utils.scale_by_x_power_n_minus_1(
power3, bsk[3 * j + 2], log_modulus=log_coefficient_modulus
)
).astype(jnp.uint32)
return c_prime_accum + jit_external_product(
rgsw_ct=bmmp_factor,
rlwe_ct=c_prime_accum,
decomposition_params=decomposition_params,
)
# c'_mul = c' * X^{a_j^tilde} (for each entry in c')
else:
# where a_j^tilde = coefficient_index[j] #Disabled BMMP
c_prime_mul = matrix_utils.monomial_mul_list(
c_prime_accum,
coefficient_index[j],
log_coefficient_modulus,
).astype(jnp.uint32)

# Update c_prime with the output of the CMUX operation, where either
# `c_prime` or `c_prime * X^{a_j^tilde}` is chosen by `bsk` at index j.
return jit_cmux(
control=bsk[j],
eq_zero=c_prime_accum,
neq_zero=c_prime_mul,
decomposition_params=decomposition_params,
)

return jax.lax.fori_loop(0, num_loop_terms, one_external_product, c_prime)

Expand Down
18 changes: 12 additions & 6 deletions jaxite/jaxite_lib/bootstrap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def run_bootstrap_test(
padding_bits: int,
rlwe_rng: random_source.RandomSource,
skip_assert: bool = False,
use_bmmp: bool = False,
):
cleartext = 2**message_bits - 1
test_utils.assert_safe_modulus_switch(
Expand Down Expand Up @@ -77,6 +78,7 @@ def run_bootstrap_test(
rgsw_sk=rgsw_key,
decomposition_params=test_utils.BSK_DECOMP_PARAMS_128_BIT_SECURITY,
prg=rlwe_rng,
use_bmmp=use_bmmp,
)
ksk = key_switch.gen_key(
in_key=rlwe.flatten_key(rlwe_key),
Expand Down Expand Up @@ -120,6 +122,7 @@ def run_bootstrap_test(
test_utils.KSK_DECOMP_PARAMS_128_BIT_SECURITY,
test_utils.BSK_DECOMP_PARAMS_128_BIT_SECURITY,
scheme_parameters,
bsk.use_bmmp,
)

self.assertEqual(len(jit_bootstrapped), len(bootstrapped))
Expand Down Expand Up @@ -147,12 +150,11 @@ def run_bootstrap_test(


@parameterized.product(
log_ai_bound=_LOG_AI_BOUNDS,
seed=_SEEDS,
log_ai_bound=_LOG_AI_BOUNDS, seed=_SEEDS, use_bmmp=[True, False]
)
class BootstrapTest(BootstrapBaseTest):

def test_3_bit_bootstrap(self, log_ai_bound, seed):
def test_3_bit_bootstrap(self, log_ai_bound, seed, use_bmmp):
message_bits = 3
padding_bits = 1
lwe_dimension = 4
Expand All @@ -173,10 +175,11 @@ def test_3_bit_bootstrap(self, log_ai_bound, seed):
mod_degree=mod_degree,
padding_bits=padding_bits,
rlwe_rng=rng,
use_bmmp=use_bmmp,
)

def test_3_bit_bootstrap_larger_lwe_dimension(
self, log_ai_bound: int, seed: int
self, log_ai_bound: int, seed: int, use_bmmp: bool
):
message_bits = 3
padding_bits = 1
Expand All @@ -200,10 +203,11 @@ def test_3_bit_bootstrap_larger_lwe_dimension(
mod_degree=mod_degree,
padding_bits=padding_bits,
rlwe_rng=rng,
use_bmmp=use_bmmp,
)

@absltest.skip("b/325287870")
def test_6_bit_bootstrap(self, log_ai_bound: int, seed: int):
def test_6_bit_bootstrap(self, log_ai_bound: int, seed: int, use_bmmp: bool):
message_bits = 6
padding_bits = 1
lwe_dimension = 30
Expand All @@ -224,10 +228,11 @@ def test_6_bit_bootstrap(self, log_ai_bound: int, seed: int):
mod_degree=mod_degree,
padding_bits=padding_bits,
rlwe_rng=rng,
use_bmmp=use_bmmp,
)

def test_3_bit_bootstrap_prod_decomp_params(
self, log_ai_bound: int, seed: int
self, log_ai_bound: int, seed: int, use_bmmp: bool
):
message_bits = 3
padding_bits = 1
Expand All @@ -249,6 +254,7 @@ def test_3_bit_bootstrap_prod_decomp_params(
mod_degree=mod_degree,
padding_bits=padding_bits,
rlwe_rng=rng,
use_bmmp=use_bmmp,
)


Expand Down
Loading