From f791747b9fa4ef27f9916134e98a7c6004f68a60 Mon Sep 17 00:00:00 2001 From: Shruthi Gorantala Date: Wed, 5 Mar 2025 18:04:02 -0800 Subject: [PATCH] Add new utils to type_converter to encode arbitrary sized ints. PiperOrigin-RevId: 733934165 --- .github/workflows/build_and_test.yml | 4 ++-- jaxite/jaxite_bool/type_converters.py | 27 +++++++++++++++++++-------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 0f6db53..e0fe5e0 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -13,13 +13,13 @@ concurrency: jobs: build-and-test: runs-on: - labels: ubuntu-20.04-16core + labels: ubuntu-22.04-16core steps: - name: Check out repository code uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab # pin@v3 - name: Cache bazel build artifacts - uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 # pin@v3.3.1 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # pin@v4.2.0 with: path: | ~/.cache/bazel diff --git a/jaxite/jaxite_bool/type_converters.py b/jaxite/jaxite_bool/type_converters.py index 0c2310c..425e75f 100644 --- a/jaxite/jaxite_bool/type_converters.py +++ b/jaxite/jaxite_bool/type_converters.py @@ -3,24 +3,35 @@ from typing import List, Optional -def bit_slice_to_u8(bit_slice: List[bool]) -> int: - """Given a bit slice of length 8, returns a base-10 int representation.""" - if len(bit_slice) != 8: +def bit_slice_to_uint(bit_slice: List[bool], num_bits: int) -> int: + """Given a bit slice of num_bits, returns a base-10 int representation.""" + if len(bit_slice) != num_bits: raise ValueError(f'Expected an 8-bit representation but got: {bit_slice}.') result = 0 - for i in range(8): + for i in range(num_bits): result |= (int(bit_slice[i])) << i return result +def uint_to_bit_slice(input_int: int, num_bits: int) -> List[bool]: + """Given an integer [0, 255], returns a bitwise representation.""" + result: List[bool] = [False] * num_bits + for i in range(num_bits): + result[i] = ((input_int >> i) & 1) != 0 + return result + + +def bit_slice_to_u8(bit_slice: List[bool]) -> int: + """Given a bit slice of length 8, returns a base-10 int representation.""" + if len(bit_slice) != 8: + raise ValueError(f'Expected an 8-bit representation but got: {bit_slice}.') + return bit_slice_to_uint(bit_slice, 8) + def u8_to_bit_slice(input_int: int) -> List[bool]: """Given an integer [0, 255], returns a bitwise representation.""" if input_int < 0 or input_int > 255: raise ValueError(f'Expected a u8, but got: {input_int}.') - result: List[bool] = [False] * 8 - for i in range(8): - result[i] = ((input_int >> i) & 1) != 0 - return result + return uint_to_bit_slice(input_int, 8) def u8_list_to_bit_slice(input_list: List[int]) -> List[bool]: