diff --git a/jaxite/jaxite_lib/bootstrap.py b/jaxite/jaxite_lib/bootstrap.py index 5db32a5..d49a399 100644 --- a/jaxite/jaxite_lib/bootstrap.py +++ b/jaxite/jaxite_lib/bootstrap.py @@ -31,6 +31,7 @@ class BootstrappingKey: """An array with row j an RGSW encryption of bit j of an LWE secret key.""" encrypted_lwe_sk_bits: jnp.ndarray + use_bmmp: bool = True def gen_bootstrapping_key( @@ -38,6 +39,7 @@ def gen_bootstrapping_key( rgsw_sk: rgsw.RgswSecretKey, decomposition_params: decomposition.DecompositionParameters, prg: random_source.RandomSource, + use_bmmp: bool = True, ) -> BootstrappingKey: """Generate a bootstrapping key for the given LWE secret key. @@ -64,7 +66,11 @@ def gen_bootstrapping_key( "requires an even lwe_sk dimension, but got " f"{lwe_sk_dim}" ) - num_bsk_encryptions = lwe_sk_dim + lwe_sk_dim // 2 + + if use_bmmp: + num_bsk_encryptions = lwe_sk_dim + lwe_sk_dim // 2 + else: + num_bsk_encryptions = lwe_sk_dim ai_samples = prg.uniform( shape=( @@ -81,30 +87,36 @@ def gen_bootstrapping_key( dtype=jnp.uint32, ) - # Using the improved blind rotate from Bourse-Minelli-Minihold-Paillier - # (BMMP17: https://eprint.iacr.org/2017/1114), a trick uses a larger - # bootstrapping key to reduce the number of external products required by 1/2. - # Rather than encrypt the secret key bits of the LWE key separately, we - # encrypt: - # - # BSK_{3i} = s_{2i} * s_{2i+1}, - # BSK_{3i+1} = s_{2i} * (1 − s_{2i+1}), - # BSK_{3i+2} = (1 − s_{2i}) * s_{2i+1} - # - # which enables a bootstrap operation that involves 1/2 as many external - # products, though this causes the bootstrapping key to be 50% larger. lwe_sk_data = lwe_sk.key_data.astype(jnp.uint32) - bsk_input_pairs = jnp.zeros(num_bsk_encryptions, dtype=jnp.uint32) - bsk_input_pairs = bsk_input_pairs.at[::3].set( - jnp.multiply(lwe_sk_data[::2], lwe_sk_data[1::2]) - ) - bsk_input_pairs = bsk_input_pairs.at[1::3].set( - jnp.multiply(lwe_sk_data[::2], 1 - lwe_sk_data[1::2]) - ) - bsk_input_pairs = bsk_input_pairs.at[2::3].set( - jnp.multiply(1 - lwe_sk_data[::2], lwe_sk_data[1::2]) - ) + # boostrapping key encrypted each of the security key + # BMMP groups multiple secrutiy keys together into a single BSK. (1.5N v.s N) + if use_bmmp: + # Using the improved blind rotate from Bourse-Minelli-Minihold-Paillier + # (BMMP17: https://eprint.iacr.org/2017/1114), a trick uses a larger + # bootstrapping key to reduce the number of external products required by + # 1/2. Rather than encrypt the secret key bits of the LWE key separately, + # we encrypt: + # + # BSK_{3i} = s_{2i} * s_{2i+1}, + # BSK_{3i+1} = s_{2i} * (1 − s_{2i+1}), + # BSK_{3i+2} = (1 − s_{2i}) * s_{2i+1} + # + # which enables a bootstrap operation that involves 1/2 as many external + # products, though this causes the bootstrapping key to be 50% larger. + bsk_input_pairs = jnp.zeros(num_bsk_encryptions, dtype=jnp.uint32) + bsk_input_pairs = bsk_input_pairs.at[::3].set( + jnp.multiply(lwe_sk_data[::2], lwe_sk_data[1::2]) + ) + bsk_input_pairs = bsk_input_pairs.at[1::3].set( + jnp.multiply(lwe_sk_data[::2], 1 - lwe_sk_data[1::2]) + ) + bsk_input_pairs = bsk_input_pairs.at[2::3].set( + jnp.multiply(1 - lwe_sk_data[::2], lwe_sk_data[1::2]) + ) + else: + # BSK_{i} = s_{i}, + bsk_input_pairs = lwe_sk_data # Applying vmap to the entire jit_encrypt over all sk bits will exhaust the # tensor core's memory with prod security parameters. It will eagerly allocate @@ -171,11 +183,13 @@ def process_one_batch(i): rlwe_sk.modulus_degree, )) - return BootstrappingKey(encrypted_lwe_sk_bits=encrypted_lwe_sk_bits) + return BootstrappingKey( + encrypted_lwe_sk_bits=encrypted_lwe_sk_bits, use_bmmp=use_bmmp + ) @jax.named_call -@functools.partial(jax.jit, static_argnums=(4, 5, 6)) +@functools.partial(jax.jit, static_argnums=(4, 5, 6, 7)) def jit_bootstrap( ciphertext: types.LweCiphertext, test_poly_ciphertext_message: jnp.ndarray, @@ -184,6 +198,7 @@ def jit_bootstrap( ks_decomposition_params: decomposition.DecompositionParameters, bs_decomposition_params: decomposition.DecompositionParameters, scheme_params: parameters.SchemeParameters, + bsk_use_bmmp: bool = True, ) -> types.LweCiphertext: """Apply functional bootstrap to reduce noise in the input ciphertext. @@ -220,6 +235,7 @@ def jit_bootstrap( test_poly_ciphertext_message, approx_ciphertext, bsk_encrypted_lwe_sk_bits, + bsk_use_bmmp, test_poly_log_coefficient_modulus, bs_decomposition_params, ) @@ -290,6 +306,7 @@ def bootstrap( test_poly_ciphertext.message, approx_ciphertext, bsk.encrypted_lwe_sk_bits, + bsk.use_bmmp, test_poly_log_coefficient_modulus, bs_decomposition_params, ) @@ -436,6 +453,7 @@ def blind_rotate( rot_polynomial.message, coefficient_index, bsk.encrypted_lwe_sk_bits, + bsk.use_bmmp, rot_polynomial.log_coefficient_modulus, decomposition_params, ) @@ -447,11 +465,12 @@ def blind_rotate( @jax.named_call -@functools.partial(jax.jit, static_argnums=(3, 4)) +@functools.partial(jax.jit, static_argnums=(3, 4, 5)) def jit_blind_rotate( rot_polynomial: jnp.ndarray, coefficient_index: types.LweCiphertext, bsk: jnp.ndarray, + use_bmmp: bool, log_coefficient_modulus: int, decomposition_params: decomposition.DecompositionParameters, ) -> rlwe.RlweCiphertext: @@ -467,30 +486,51 @@ def jit_blind_rotate( # Using the improved blind rotate from Bourse-Minelli-Minihold-Paillier # (BMMP17: https://eprint.iacr.org/2017/1114), a trick uses a larger # bootstrapping key to reduce the number of external products required by 1/2. - num_loop_terms = (coefficient_index.shape[0] - 1) // 2 + if use_bmmp: + num_loop_terms = (coefficient_index.shape[0] - 1) // 2 + else: + num_loop_terms = coefficient_index.shape[0] - 1 def one_external_product(j, c_prime_accum): - # Doing this computation inside the external product loop improves cache - # locality, resulting in reduced data copying. - power1 = coefficient_index[2 * j] + coefficient_index[2 * j + 1] - power2 = coefficient_index[2 * j] - power3 = coefficient_index[2 * j + 1] - bmmp_factor = ( - matrix_utils.scale_by_x_power_n_minus_1( - power1, bsk[3 * j], log_modulus=log_coefficient_modulus - ) - + matrix_utils.scale_by_x_power_n_minus_1( - power2, bsk[3 * j + 1], log_modulus=log_coefficient_modulus - ) - + matrix_utils.scale_by_x_power_n_minus_1( - power3, bsk[3 * j + 2], log_modulus=log_coefficient_modulus - ) - ).astype(jnp.uint32) - return c_prime_accum + jit_external_product( - rgsw_ct=bmmp_factor, - rlwe_ct=c_prime_accum, - decomposition_params=decomposition_params, - ) + if use_bmmp: + # Doing this computation inside the external product loop improves cache + # locality, resulting in reduced data copying. + power1 = coefficient_index[2 * j] + coefficient_index[2 * j + 1] + power2 = coefficient_index[2 * j] + power3 = coefficient_index[2 * j + 1] + bmmp_factor = ( + matrix_utils.scale_by_x_power_n_minus_1( # Rotation. + power1, bsk[3 * j], log_modulus=log_coefficient_modulus + ) + + matrix_utils.scale_by_x_power_n_minus_1( + power2, bsk[3 * j + 1], log_modulus=log_coefficient_modulus + ) + + matrix_utils.scale_by_x_power_n_minus_1( + power3, bsk[3 * j + 2], log_modulus=log_coefficient_modulus + ) + ).astype(jnp.uint32) + return c_prime_accum + jit_external_product( + rgsw_ct=bmmp_factor, + rlwe_ct=c_prime_accum, + decomposition_params=decomposition_params, + ) + # c'_mul = c' * X^{a_j^tilde} (for each entry in c') + else: + # where a_j^tilde = coefficient_index[j] #Disabled BMMP + c_prime_mul = matrix_utils.monomial_mul_list( + c_prime_accum, + coefficient_index[j], + log_coefficient_modulus, + ).astype(jnp.uint32) + + # Update c_prime with the output of the CMUX operation, where either + # `c_prime` or `c_prime * X^{a_j^tilde}` is chosen by `bsk` at index j. + return jit_cmux( + control=bsk[j], + eq_zero=c_prime_accum, + neq_zero=c_prime_mul, + decomposition_params=decomposition_params, + ) return jax.lax.fori_loop(0, num_loop_terms, one_external_product, c_prime) diff --git a/jaxite/jaxite_lib/bootstrap_test.py b/jaxite/jaxite_lib/bootstrap_test.py index e8b26da..53fdb7e 100644 --- a/jaxite/jaxite_lib/bootstrap_test.py +++ b/jaxite/jaxite_lib/bootstrap_test.py @@ -40,6 +40,7 @@ def run_bootstrap_test( padding_bits: int, rlwe_rng: random_source.RandomSource, skip_assert: bool = False, + use_bmmp: bool = False, ): cleartext = 2**message_bits - 1 test_utils.assert_safe_modulus_switch( @@ -77,6 +78,7 @@ def run_bootstrap_test( rgsw_sk=rgsw_key, decomposition_params=test_utils.BSK_DECOMP_PARAMS_128_BIT_SECURITY, prg=rlwe_rng, + use_bmmp=use_bmmp, ) ksk = key_switch.gen_key( in_key=rlwe.flatten_key(rlwe_key), @@ -120,6 +122,7 @@ def run_bootstrap_test( test_utils.KSK_DECOMP_PARAMS_128_BIT_SECURITY, test_utils.BSK_DECOMP_PARAMS_128_BIT_SECURITY, scheme_parameters, + bsk.use_bmmp, ) self.assertEqual(len(jit_bootstrapped), len(bootstrapped)) @@ -147,12 +150,11 @@ def run_bootstrap_test( @parameterized.product( - log_ai_bound=_LOG_AI_BOUNDS, - seed=_SEEDS, + log_ai_bound=_LOG_AI_BOUNDS, seed=_SEEDS, use_bmmp=[True, False] ) class BootstrapTest(BootstrapBaseTest): - def test_3_bit_bootstrap(self, log_ai_bound, seed): + def test_3_bit_bootstrap(self, log_ai_bound, seed, use_bmmp): message_bits = 3 padding_bits = 1 lwe_dimension = 4 @@ -173,10 +175,11 @@ def test_3_bit_bootstrap(self, log_ai_bound, seed): mod_degree=mod_degree, padding_bits=padding_bits, rlwe_rng=rng, + use_bmmp=use_bmmp, ) def test_3_bit_bootstrap_larger_lwe_dimension( - self, log_ai_bound: int, seed: int + self, log_ai_bound: int, seed: int, use_bmmp: bool ): message_bits = 3 padding_bits = 1 @@ -200,10 +203,11 @@ def test_3_bit_bootstrap_larger_lwe_dimension( mod_degree=mod_degree, padding_bits=padding_bits, rlwe_rng=rng, + use_bmmp=use_bmmp, ) @absltest.skip("b/325287870") - def test_6_bit_bootstrap(self, log_ai_bound: int, seed: int): + def test_6_bit_bootstrap(self, log_ai_bound: int, seed: int, use_bmmp: bool): message_bits = 6 padding_bits = 1 lwe_dimension = 30 @@ -224,10 +228,11 @@ def test_6_bit_bootstrap(self, log_ai_bound: int, seed: int): mod_degree=mod_degree, padding_bits=padding_bits, rlwe_rng=rng, + use_bmmp=use_bmmp, ) def test_3_bit_bootstrap_prod_decomp_params( - self, log_ai_bound: int, seed: int + self, log_ai_bound: int, seed: int, use_bmmp: bool ): message_bits = 3 padding_bits = 1 @@ -249,6 +254,7 @@ def test_3_bit_bootstrap_prod_decomp_params( mod_degree=mod_degree, padding_bits=padding_bits, rlwe_rng=rng, + use_bmmp=use_bmmp, )