Skip to content

Commit

Permalink
restructure project so that OSS version can be published to PYPI
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 550071706
  • Loading branch information
j2kun authored and copybara-github committed Jul 21, 2023
1 parent d891597 commit 6622e55
Show file tree
Hide file tree
Showing 44 changed files with 350 additions and 627 deletions.
9 changes: 5 additions & 4 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@ on:
pull_request:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true
jobs:
build-and-test:
runs-on: ubuntu-latest
runs-on:
labels: ubuntu-20.04-16core
steps:
- name: Check out repository code
uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab # pin@v3
Expand All @@ -20,8 +24,6 @@ jobs:
path: |
~/.cache/bazel
key: ${{ runner.os }}-bazel-${{ hashFiles('.bazelversion', '.bazelrc', 'WORKSPACE', 'requirements.txt') }}
restore-keys: |
${{ runner.os }}-bazel-

- name: "Run `bazel build`"
run: |
Expand All @@ -33,6 +35,5 @@ jobs:
--test_output=errors \
--test_size_filters=small \
--test_timeout=1800 \
--jobs=2 \
--experimental_ui_max_stdouterr_bytes=10485760 \
//...
8 changes: 4 additions & 4 deletions .github/workflows/periodic_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ name: Build and Test
permissions: read-all
on:
schedule:
# Every day at 1 AM
- cron: '0 1 * * *'
# Every week on Sunday at 1 AM
- cron: '0 1 * * 0'
jobs:
build-and-test:
runs-on: ubuntu-latest
runs-on:
labels: ubuntu-20.04-16core
steps:
- name: Check out repository code
uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab # pin@v3
Expand All @@ -21,6 +22,5 @@ jobs:
--test_output=errors \
--test_size_filters=medium,large \
--test_timeout=3600 \
--jobs=1 \
--experimental_ui_max_stdouterr_bytes=10485760 \
//...
22 changes: 22 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,25 @@ bazel-bin
bazel-jaxite
bazel-out
bazel-testlogs

# Compiled python modules.
*.pyc

# Byte-compiled
_pycache__/
.cache/

# Poetry, setuptools, PyPI distribution artifacts.
/*.egg-info
.eggs/
build/
dist/

# Type checking
.pytype/

# Other
*.DS_Store

# PyCharm
.idea
296 changes: 270 additions & 26 deletions BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# An FHE cryptosystem built in JAX

load("@rules_python//python:defs.bzl", "py_library")
load("@jaxite//bazel:test_oss.bzl", "cpu_gpu_tpu_test", "gpu_tpu_test", "tpu_test")
load("@rules_python//python:defs.bzl", "py_test")
load("@rules_license//rules:license.bzl", "license")

package(
Expand All @@ -20,40 +22,282 @@ license(

licenses(["notice"])

exports_files(["LICENSE"])

# a single-source build dependency that gives the whole (non-test) jaxite
# source tree
# source tree; note we chose the style of putting all test rules below, because
# glob does not recurse into subdirectories with BUILD files in them.
py_library(
name = "jaxite",
srcs = glob(
[
"jaxite_lib/*.py",
"jaxite_bool/*.py",
],
["**/*.py"],
exclude = [
"**/*_test.py",
"**/test_util.py",
],
),
visibility = ["//visibility:public"],
deps = [
"@jaxite//jaxite_bool",
"@jaxite//jaxite_bool:bool_encoding",
"@jaxite//jaxite_bool:bool_params",
"@jaxite//jaxite_bool:lut",
"@jaxite//jaxite_bool:type_converters",
"@jaxite//jaxite_lib:bootstrap",
"@jaxite//jaxite_lib:decomposition",
"@jaxite//jaxite_lib:encoding",
"@jaxite//jaxite_lib:key_switch",
"@jaxite//jaxite_lib:lwe",
"@jaxite//jaxite_lib:matrix_utils",
"@jaxite//jaxite_lib:parameters",
"@jaxite//jaxite_lib:random_source",
"@jaxite//jaxite_lib:rgsw",
"@jaxite//jaxite_lib:rlwe",
"@jaxite//jaxite_lib:test_polynomial",
"@jaxite//jaxite_lib:test_utils",
"@jaxite//jaxite_lib:types",
visibility = [":internal"],
deps = [
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
],
)

exports_files(["LICENSE"])
# Test rules are below, though the source files are in subdirectories.
py_library(
name = "test_utils",
srcs = ["jaxite/jaxite_lib/test_utils.py"],
srcs_version = "PY3ONLY",
deps = [
":jaxite",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
],
)

cpu_gpu_tpu_test(
name = "matrix_utils_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_lib/matrix_utils_test.py"],
python_version = "PY3",
shard_count = 3,
srcs_version = "PY3ONLY",
deps = [
":jaxite",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_hypothesis//:pkg",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

cpu_gpu_tpu_test(
name = "decomposition_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_lib/decomposition_test.py"],
python_version = "PY3",
srcs_version = "PY3ONLY",
deps = [
":jaxite",
"@com_google_absl_py//absl/testing:absltest",
"@jaxite_deps_hypothesis//:pkg",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

cpu_gpu_tpu_test(
name = "encoding_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_lib/encoding_test.py"],
python_version = "PY3",
srcs_version = "PY3ONLY",
deps = [
":jaxite",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_hypothesis//:pkg",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
],
)

cpu_gpu_tpu_test(
name = "lwe_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_lib/lwe_test.py"],
python_version = "PY3",
shard_count = 50,
srcs_version = "PY3ONLY",
deps = [
":jaxite",
":test_utils",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_hypothesis//:pkg",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
],
)

cpu_gpu_tpu_test(
name = "rlwe_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_lib/rlwe_test.py"],
python_version = "PY3",
srcs_version = "PY3ONLY",
deps = [
":jaxite",
":test_utils",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_hypothesis//:pkg",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

cpu_gpu_tpu_test(
name = "bootstrap_test",
size = "large",
srcs = ["jaxite/jaxite_lib/bootstrap_test.py"],
python_version = "PY3",
shard_count = 50,
srcs_version = "PY3ONLY",
deps = [
":jaxite",
":test_utils",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

cpu_gpu_tpu_test(
name = "blind_rotate_test",
size = "large",
srcs = ["jaxite/jaxite_lib/blind_rotate_test.py"],
python_version = "PY3",
shard_count = 10,
srcs_version = "PY3ONLY",
deps = [
":jaxite",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_hypothesis//:pkg",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

cpu_gpu_tpu_test(
name = "test_polynomial_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_lib/test_polynomial_test.py"],
python_version = "PY3",
srcs_version = "PY3ONLY",
deps = [
":jaxite",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

cpu_gpu_tpu_test(
name = "key_switch_test",
size = "large",
srcs = ["jaxite/jaxite_lib/key_switch_test.py"],
python_version = "PY3",
shard_count = 50,
srcs_version = "PY3ONLY",
deps = [
":jaxite",
":test_utils",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_hypothesis//:pkg",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
],
)

cpu_gpu_tpu_test(
name = "random_source_test",
srcs = ["jaxite/jaxite_lib/random_source_test.py"],
deps = [
":jaxite",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
],
)

cpu_gpu_tpu_test(
name = "rgsw_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_lib/rgsw_test.py"],
python_version = "PY3",
shard_count = 10,
srcs_version = "PY3ONLY",
deps = [
":jaxite",
":test_utils",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_hypothesis//:pkg",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

py_test(
name = "lut_test",
srcs = ["jaxite/jaxite_bool/lut_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":jaxite",
"@com_google_absl_py//absl/testing:absltest",
],
)

gpu_tpu_test(
name = "jaxite_bool_test",
size = "large",
srcs = ["jaxite/jaxite_bool/jaxite_bool_test.py"],
python_version = "PY3",
shard_count = 50,
srcs_version = "PY3",
deps = [
":jaxite",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
],
)

gpu_tpu_test(
name = "jaxite_bool_multigate_test",
size = "large",
srcs = ["jaxite/jaxite_bool/jaxite_bool_multigate_test.py"],
python_version = "PY3",
shard_count = 20,
srcs_version = "PY3",
deps = [
":jaxite",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
],
)

tpu_test(
name = "pmap_test",
size = "large",
srcs = ["jaxite/jaxite_bool/pmap_test.py"],
python_version = "PY3",
srcs_version = "PY3",
tags = ["manual"],
deps = [
":jaxite",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
],
)
Loading

0 comments on commit 6622e55

Please sign in to comment.