Skip to content

Commit

Permalink
Fast compile, mpcheck interfaces, improve circuits. (#4)
Browse files Browse the repository at this point in the history
* Split files for faster compilation, add mpcheck interfaces, improve circuits.

* temp patch before new pip package
  • Loading branch information
feltroidprime authored Oct 14, 2024
1 parent b500600 commit 8e10d78
Show file tree
Hide file tree
Showing 63 changed files with 427,915 additions and 422,852 deletions.
142 changes: 84 additions & 58 deletions precompiled_circuits/all_circuits.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import re
from enum import Enum

from garaga.definitions import CurveID
from garaga.precompiled_circuits.compilable_circuits.base import (
BaseModuloCircuit,
cairo1_tests_header,
compilation_mode_to_file_header,
format_cairo_files_in_parallel,
to_snake_case,
Expand Down Expand Up @@ -321,74 +322,99 @@ def main(
CIRCUITS_TO_COMPILE: dict[CircuitID, dict],
compilation_mode: int = 0,
):
"""Compiles and writes all circuits to .cairo files"""

# Ensure the 'codes' dict keys match the filenames used for file creation.
# Using sets to remove potential duplicates
filenames_used = set([v["filename"] for v in CIRCUITS_TO_COMPILE.values()])
codes = {filename: set() for filename in filenames_used}
selector_functions = {filename: set() for filename in filenames_used}
cairo1_tests_functions = {filename: set() for filename in filenames_used}
cairo1_full_function_names = {filename: set() for filename in filenames_used}

files = {
f: open(f"{PRECOMPILED_CIRCUITS_DIR}{f}.cairo", "w") for f in filenames_used
}
"""Compiles and writes all circuits to separate .cairo files"""

# Write the header to each file
HEADER = compilation_mode_to_file_header(compilation_mode)

for file in files.values():
file.write(HEADER)

# Instantiate and compile circuits for each curve
# Dictionary to store compiled circuits and selectors for each filename
compiled_files = {}

for circuit_id, circuit_info in CIRCUITS_TO_COMPILE.items():
for curve_id in circuit_info.get(
"curve_ids", [CurveID.BN254, CurveID.BLS12_381]
):
filename_key = circuit_info["filename"]
compiled_circuits, selectors = compile_circuit(
curve_id,
circuit_info["class"],
circuit_id,
circuit_info["params"],
compilation_mode,
)
codes[filename_key].update(compiled_circuits)
selector_functions[filename_key].update(selectors)

# Write selector functions and compiled circuit codes to their respective files
print(f"Writing circuits and selectors to .cairo files...")
for filename in filenames_used:
if filename in files:
# Write the selector functions for this file
for selector_function in sorted(selector_functions[filename]):
files[filename].write(selector_function)
# Write the compiled circuit codes
for compiled_circuit in sorted(codes[filename]):
files[filename].write(compiled_circuit + "\n")

if compilation_mode == 1:
files[filename].write(cairo1_tests_header() + "\n")
fns_to_import = sorted(cairo1_full_function_names[filename])
if "" in fns_to_import:
fns_to_import.remove("")
files[filename].write(f"use super::{{{','.join(fns_to_import)}}};\n")
for cairo1_test in sorted(cairo1_tests_functions[filename]):
files[filename].write(cairo1_test + "\n")
files[filename].write("}\n")
circuit_class = circuit_info["class"]
params = circuit_info["params"]
print(f"id: {circuit_id}, params: {params}")
temp_instance = circuit_class(
curve_id=CurveID.BN254.value,
compilation_mode=compilation_mode,
**(params[0] if params else {}),
)

else:
print(f"Warning: No file associated with filename '{filename}'")
if temp_instance.circuit.generic_circuit:
# Handle generic circuits (keep selector function and all in one file)
filename = f"{circuit_info['filename']}.cairo"
if filename not in compiled_files:
compiled_files[filename] = {"circuits": set(), "selectors": set()}

# Close all files
for file in files.values():
file.close()
for curve_id in [CurveID.BN254, CurveID.BLS12_381]:
circuits, selectors = compile_circuit(
curve_id, circuit_class, circuit_id, params, compilation_mode
)
compiled_files[filename]["circuits"].update(circuits)
compiled_files[filename]["selectors"].update(selectors)

else:
# Handle non-generic circuits (separate files for each)
for curve_id in [CurveID.BN254, CurveID.BLS12_381]:
if params is None:
params = [None]
for param in params:
param_str = f"_{param[list(param.keys())[0]]}" if param else ""
filename = f"{circuit_id.name.lower()}_{curve_id.name.lower()}{param_str}.cairo"

if filename not in compiled_files:
compiled_files[filename] = {
"circuits": set(),
"selectors": set(),
}

circuits, _ = compile_circuit(
curve_id,
circuit_class,
circuit_id,
[param] if param else None,
compilation_mode,
)
compiled_files[filename]["circuits"].update(circuits)
# compiled_files[filename]["selectors"].update(selectors)

# Write compiled circuits and selectors to files
for filename, content in compiled_files.items():
full_path = f"{PRECOMPILED_CIRCUITS_DIR}{filename}"
print(f"Writing {full_path}...")
with open(full_path, "w") as file:
file.write(HEADER)

# Add comment with available function names for generic circuits
if any("get_" in selector for selector in content["selectors"]):
function_names = set()
for circuit in content["circuits"]:
# Extract function name from the circuit code
match = re.search(r"func\s+(\w+)", circuit)
if match:
function_names.add(match.group(1))

file.write("// Available functions:\n")
for name in sorted(function_names):
file.write(f"// - {name}\n")
file.write("\n")

for selector in content["selectors"]:
file.write(selector + "\n")

for circuit in content["circuits"]:
file.write(circuit + "\n")

# Format all created files
all_files = [
f.removesuffix(".cairo")
for f in os.listdir(PRECOMPILED_CIRCUITS_DIR)
if f.endswith(".cairo")
]
format_cairo_files_in_parallel(
filenames_used, compilation_mode, PRECOMPILED_CIRCUITS_DIR
all_files, compilation_mode, PRECOMPILED_CIRCUITS_DIR
)

return None


Expand Down
1 change: 1 addition & 0 deletions precompiled_circuits/compilable_circuits/fustat_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def _run_circuit_inner(self, input: list[PyFelt]):
m, _, _, _, _ = circuit.multi_pairing_check(n_pairs)

circuit.extend_output(m)

circuit.finalize_circuit()

return circuit
61 changes: 61 additions & 0 deletions src/bls12_381/final_exp.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from starkware.cairo.common.registers import get_fp_and_pc
from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.cairo_builtins import PoseidonBuiltin, ModBuiltin
from definitions import E12D, E6D, is_zero_E6D, one_E6D, one_E12D, bls, TRUE

from precompiled_circuits.final_exp_part_1_bls12_381 import get_BLS12_381_FINAL_EXP_PART_1_circuit
from precompiled_circuits.final_exp_part_2_bls12_381 import get_BLS12_381_FINAL_EXP_PART_2_circuit

from modulo_circuit import (
run_extension_field_modulo_circuit,
run_extension_field_modulo_circuit_continuation,
)
func final_exponentiation{
range_check_ptr,
poseidon_ptr: PoseidonBuiltin*,
range_check96_ptr: felt*,
add_mod_ptr: ModBuiltin*,
mul_mod_ptr: ModBuiltin*,
}(input: E12D*) -> (res: E12D) {
alloc_locals;
let (__fp__, _) = get_fp_and_pc();

local num: E6D = E6D(
v0=input.w0, v1=input.w2, v2=input.w4, v3=input.w6, v4=input.w8, v5=input.w10
);
local den: E6D = E6D(
v0=input.w1, v1=input.w3, v2=input.w5, v3=input.w7, v4=input.w9, v5=input.w11
);
let (local circuit_input: felt*) = alloc();
memcpy(dst=circuit_input, src=cast(&num, felt*), len=24);

let (den_is_zero) = is_zero_E6D(den, bls.CURVE_ID);
if (den_is_zero == TRUE) {
let (local one_E6: E6D) = one_E6D();
memcpy(dst=circuit_input + 24, src=cast(&one_E6, felt*), len=24);
} else {
memcpy(dst=circuit_input + 24, src=cast(&den, felt*), len=24);
}

let (local circuit) = get_BLS12_381_FINAL_EXP_PART_1_circuit();
let (output: felt*, Z: felt) = run_extension_field_modulo_circuit(circuit, circuit_input);
// %{
// part1 = pack_bigint_ptr(memory, ids.output, 4, 2**96, ids.circuit.output_len//4)
// for x in part1:
// print(f"T0/T2/_SUM = {hex(x)}")
// %}
let _sum = [cast(output + 2 * E6D.SIZE, E6D*)];
let (_sum_is_zero) = is_zero_E6D(_sum, bls.CURVE_ID);

if (_sum_is_zero == TRUE) {
let (one_E12: E12D) = one_E12D();
return (res=one_E12);
} else {
let (circuit) = get_BLS12_381_FINAL_EXP_PART_2_circuit();
let (output: felt*, _: felt) = run_extension_field_modulo_circuit_continuation(
circuit, output, Z
);
return (res=[cast(output, E12D*)]);
}
}
47 changes: 47 additions & 0 deletions src/bls12_381/multi_pairing_1.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from starkware.cairo.common.registers import get_fp_and_pc
from starkware.cairo.common.cairo_builtins import PoseidonBuiltin, ModBuiltin, UInt384
from definitions import E12D, E6D, G1G2Pair, TRUE, bls

from precompiled_circuits.multi_miller_loop_bls12_381_1 import (
get_BLS12_381_MULTI_MILLER_LOOP_1_circuit,
)

from modulo_circuit import (
run_extension_field_modulo_circuit,
run_extension_field_modulo_circuit_continuation,
)
from bls12_381.final_exp import final_exponentiation
from ec_ops import all_g1_g2_pairs_are_on_curve

func multi_pairing_1P{
range_check_ptr,
poseidon_ptr: PoseidonBuiltin*,
range_check96_ptr: felt*,
add_mod_ptr: ModBuiltin*,
mul_mod_ptr: ModBuiltin*,
}(input: G1G2Pair*) -> (res: E12D) {
alloc_locals;
let n_pairs = 1;
let (all_on_curve) = all_g1_g2_pairs_are_on_curve(input, n_pairs, bls.CURVE_ID);
assert all_on_curve = TRUE;

let (m) = multi_miller_loop_1P(cast(input, felt*));

let (f) = final_exponentiation(m);

return (res=f);
}

func multi_miller_loop_1P{
range_check_ptr,
poseidon_ptr: PoseidonBuiltin*,
range_check96_ptr: felt*,
add_mod_ptr: ModBuiltin*,
mul_mod_ptr: ModBuiltin*,
}(input: felt*) -> (res: E12D*) {
alloc_locals;
let (__fp__, _) = get_fp_and_pc();
let (circuit) = get_BLS12_381_MULTI_MILLER_LOOP_1_circuit();
let (output: felt*, _) = run_extension_field_modulo_circuit(circuit, input);
return (res=cast(output, E12D*));
}
47 changes: 47 additions & 0 deletions src/bls12_381/multi_pairing_2.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from starkware.cairo.common.registers import get_fp_and_pc
from starkware.cairo.common.cairo_builtins import PoseidonBuiltin, ModBuiltin, UInt384
from definitions import E12D, E6D, G1G2Pair, TRUE, bls

from precompiled_circuits.multi_miller_loop_bls12_381_2 import (
get_BLS12_381_MULTI_MILLER_LOOP_2_circuit,
)

from modulo_circuit import (
run_extension_field_modulo_circuit,
run_extension_field_modulo_circuit_continuation,
)
from bls12_381.final_exp import final_exponentiation
from ec_ops import all_g1_g2_pairs_are_on_curve

func multi_pairing_2P{
range_check_ptr,
poseidon_ptr: PoseidonBuiltin*,
range_check96_ptr: felt*,
add_mod_ptr: ModBuiltin*,
mul_mod_ptr: ModBuiltin*,
}(input: G1G2Pair*) -> (res: E12D) {
alloc_locals;
let n_pairs = 2;
let (all_on_curve) = all_g1_g2_pairs_are_on_curve(input, n_pairs, bls.CURVE_ID);
assert all_on_curve = TRUE;

let (m) = multi_miller_loop_2P(cast(input, felt*));

let (f) = final_exponentiation(m);

return (res=f);
}

func multi_miller_loop_2P{
range_check_ptr,
poseidon_ptr: PoseidonBuiltin*,
range_check96_ptr: felt*,
add_mod_ptr: ModBuiltin*,
mul_mod_ptr: ModBuiltin*,
}(input: felt*) -> (res: E12D*) {
alloc_locals;
let (__fp__, _) = get_fp_and_pc();
let (circuit) = get_BLS12_381_MULTI_MILLER_LOOP_2_circuit();
let (output: felt*, _) = run_extension_field_modulo_circuit(circuit, input);
return (res=cast(output, E12D*));
}
47 changes: 47 additions & 0 deletions src/bls12_381/multi_pairing_3.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from starkware.cairo.common.registers import get_fp_and_pc
from starkware.cairo.common.cairo_builtins import PoseidonBuiltin, ModBuiltin, UInt384
from definitions import E12D, E6D, G1G2Pair, TRUE, bls

from precompiled_circuits.multi_miller_loop_bls12_381_3 import (
get_BLS12_381_MULTI_MILLER_LOOP_3_circuit,
)

from modulo_circuit import (
run_extension_field_modulo_circuit,
run_extension_field_modulo_circuit_continuation,
)
from bls12_381.final_exp import final_exponentiation
from ec_ops import all_g1_g2_pairs_are_on_curve

func multi_pairing_3P{
range_check_ptr,
poseidon_ptr: PoseidonBuiltin*,
range_check96_ptr: felt*,
add_mod_ptr: ModBuiltin*,
mul_mod_ptr: ModBuiltin*,
}(input: G1G2Pair*) -> (res: E12D) {
alloc_locals;
let n_pairs = 3;
let (all_on_curve) = all_g1_g2_pairs_are_on_curve(input, n_pairs, bls.CURVE_ID);
assert all_on_curve = TRUE;

let (m) = multi_miller_loop_3P(cast(input, felt*));

let (f) = final_exponentiation(m);

return (res=f);
}

func multi_miller_loop_3P{
range_check_ptr,
poseidon_ptr: PoseidonBuiltin*,
range_check96_ptr: felt*,
add_mod_ptr: ModBuiltin*,
mul_mod_ptr: ModBuiltin*,
}(input: felt*) -> (res: E12D*) {
alloc_locals;
let (__fp__, _) = get_fp_and_pc();
let (circuit) = get_BLS12_381_MULTI_MILLER_LOOP_3_circuit();
let (output: felt*, _) = run_extension_field_modulo_circuit(circuit, input);
return (res=cast(output, E12D*));
}
Loading

0 comments on commit 8e10d78

Please sign in to comment.