Skip to content

Commit

Permalink
Enable Global Compilation -- reduce latency from 5.81 seconds down to …
Browse files Browse the repository at this point in the history
…9.71 ms, achieving 374x end-to-end speedup.

PiperOrigin-RevId: 728636628
  • Loading branch information
JianmingTONG authored and copybara-github committed Feb 19, 2025
1 parent 43e7ae0 commit 94dfcfe
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
43 changes: 43 additions & 0 deletions jaxite_ec/finite_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,3 +1054,46 @@ def add_sub_rns_var(*values, moduli_t=utils.RNS_MODULI):
# u2 = 0 or 1, but if u2 = 1 then l < 2**16 - t, so 2**16 - t + t < 2**16
u2, l2 = split_view_32_to_16_8(i1)
return jnp.add(jnp.multiply(u2, moduli_t).astype(jnp.uint16), l2)


def construct_rns_conversion_matrix(
    p_bytes=utils.U8_CHUNK_NUM
):
    """Build the byte-position to RNS conversion matrix.

    Row ``i`` holds ``utils.to_rns(256**i, utils.MODULI)``, i.e. the RNS form
    of the place value of the ``i``-th byte. The ``uint16`` table is then
    passed through ``get_parts`` and the two returned pieces are stacked side
    by side (presumably the low and high bytes of each entry -- confirm
    against ``get_parts``).

    Args:
        p_bytes: Number of byte positions, i.e. number of rows in the table.

    Returns:
        rns_conv_mat: The horizontally stacked conversion matrix.

    Note: this function runs on the CPU of the TPU-VM and cannot be jitted.
    """
    place_value_table = np.zeros(
        (p_bytes, utils.NUM_MODULI), dtype=jnp.uint16
    )
    for byte_pos in range(p_bytes):
        # 256**byte_pos written as a shift: the weight of byte `byte_pos`.
        weight = 1 << (8 * byte_pos)
        place_value_table[byte_pos] = utils.to_rns(weight, utils.MODULI)
    low_part, high_part = get_parts(place_value_table)
    return np.hstack((low_part, high_part))
@jax.named_call
@functools.partial(jax.jit, static_argnames="s_idx")
def convert_to_rns(
    values: jax.Array,
    rns_conv_mat: jax.Array,
    s_idx=utils.NUM_MODULI,
):
    """Convert a batch of big integers to RNS form with one matrix product.

    Args:
        values: Array of bigints; it is reinterpreted as ``uint8`` before the
            product (assumes one bigint per row -- TODO confirm with callers).
        rns_conv_mat: Precomputed conversion matrix.
        s_idx: Column at which the product is split into two halves; the
            second half is shifted left by 8 bits before being added to the
            first. Static under ``jax.jit``.

    Returns:
        rns_values: values in RNS form, as produced by
        ``moduli_rns_red_internal_2u16``.
    """
    byte_view = values.view(jnp.uint8)
    # Accumulate in uint32 so the byte products are not truncated.
    products = jnp.matmul(
        byte_view, rns_conv_mat, preferred_element_type=jnp.uint32
    )
    first_half = products[:, :s_idx]
    second_half = products[:, s_idx:]
    combined = first_half + (second_half << 8)
    return moduli_rns_red_internal_2u16(combined)
17 changes: 16 additions & 1 deletion jaxite_ec/finite_field_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,22 @@ def test_jax_mod_mul_lazy_reduction(self):
c_list[i] % utils.MODULUS_377_INT,
(a_list[i] * b_list[i]) % utils.MODULUS_377_INT,
)
print("testing pass")

def test_jax_convert(self):
    """Round-trip random integers through ``ff.convert_to_rns``.

    Draws ``batch_size`` random integers, converts them to RNS using the
    precomputed conversion matrix, converts the result back to Python
    integers and checks that the original values are recovered.
    """
    batch_size = 16
    a_list = [randint(0, utils.MODULUS_377_INT) for _ in range(batch_size)]
    rns_conv_matrix = ff.construct_rns_conversion_matrix(utils.U8_CHUNK_NUM)
    a_batch = utils.int_list_to_jax_array(
        a_list, base=utils.BASE, array_size=utils.U16_CHUNK_NUM
    )
    print("Test test_jax_convert", end=" ")
    # copybara: session = xprof_session.XprofSession()
    rns_batch = ff.convert_to_rns(a_batch, rns_conv_matrix)
    # copybara: session_id = session.end_session_and_get_session_id()
    # copybara: print(f'session_id: http://xprof/?session_id={session_id}')
    # No fast conversion back unfortunately
    a_int = utils.jax_rns_array_to_int_list(rns_batch)
    # Use the TestCase assertion rather than a bare `assert`: it is not
    # stripped under `python -O` and reports the differing elements on failure.
    self.assertEqual(a_int, a_list)


if __name__ == "__main__":
Expand Down

0 comments on commit 94dfcfe

Please sign in to comment.