Skip to content

feat(apollo_infra_utils): add function to format Cairo0 source files #6269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion crates/apollo_infra_utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ license-file.workspace = true
description = "Infrastructure utility."

[features]
testing = ["cached", "colored", "dep:assert-json-diff", "socket2", "toml"]
testing = ["cached", "colored", "dep:assert-json-diff", "socket2", "tempfile", "toml"]

[lints]
workspace = true
Expand All @@ -20,6 +20,7 @@ num_enum.workspace = true
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
socket2 = { workspace = true, optional = true }
tempfile = { workspace = true, optional = true }
thiserror.workspace = true
tokio = { workspace = true, features = ["process", "rt", "time"] }
toml = { workspace = true, optional = true }
Expand All @@ -33,6 +34,7 @@ nix.workspace = true
pretty_assertions.workspace = true
rstest.workspace = true
socket2.workspace = true
tempfile.workspace = true
tokio = { workspace = true, features = ["macros", "rt", "signal", "sync"] }
toml.workspace = true
tracing-subscriber = { workspace = true, features = ["env-filter"] }
57 changes: 40 additions & 17 deletions crates/apollo_infra_utils/src/cairo0_compiler.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(any(test, feature = "testing"))]
use std::io::Write;
use std::path::PathBuf;
use std::process::Command;
use std::sync::LazyLock;
Expand All @@ -10,6 +12,7 @@ pub mod test;

pub const STARKNET_COMPILE_DEPRECATED: &str = "starknet-compile-deprecated";
pub const CAIRO0_COMPILE: &str = "cairo-compile";
pub const CAIRO0_FORMAT: &str = "cairo-format";
pub const EXPECTED_CAIRO0_VERSION: &str = "0.14.0a1";

/// The local python requirements used to determine the cairo0 compiler version.
Expand All @@ -27,12 +30,12 @@ pip install -r {:#?}"#,
});

#[derive(thiserror::Error, Debug)]
pub enum Cairo0CompilerVersionError {
pub enum Cairo0ScriptVersionError {
#[error(
"{compiler} version is not correct: required {required}, got {existing}. Are you in the \
"{script} version is not correct: required {required}, got {existing}. Are you in the \
venv? If not, run the following commands:\n{}", *ENTER_VENV_INSTRUCTIONS
)]
IncorrectVersion { compiler: String, existing: String, required: String },
IncorrectVersion { script: String, existing: String, required: String },
#[error(
"{0}. Are you in the venv? If not, run the following commands:\n{}",
*ENTER_VENV_INSTRUCTIONS
Expand All @@ -43,7 +46,7 @@ pub enum Cairo0CompilerVersionError {
#[derive(thiserror::Error, Debug)]
pub enum Cairo0CompilerError {
#[error(transparent)]
Cairo0CompilerVersion(#[from] Cairo0CompilerVersionError),
Cairo0CompilerVersion(#[from] Cairo0ScriptVersionError),
#[error("Cairo root path not found at {0:?}.")]
CairoRootNotFound(PathBuf),
#[error("Failed to compile the program. Error: {0}.")]
Expand All @@ -56,22 +59,22 @@ pub enum Cairo0CompilerError {
SourceFileNotFound(PathBuf),
}

pub fn cairo0_compilers_correct_version() -> Result<(), Cairo0CompilerVersionError> {
for compiler in [CAIRO0_COMPILE, STARKNET_COMPILE_DEPRECATED] {
let version = match Command::new(compiler).arg("--version").output() {
pub fn cairo0_scripts_correct_version() -> Result<(), Cairo0ScriptVersionError> {
for script in [CAIRO0_COMPILE, CAIRO0_FORMAT, STARKNET_COMPILE_DEPRECATED] {
let version = match Command::new(script).arg("--version").output() {
Ok(output) => String::from_utf8_lossy(&output.stdout).to_string(),
Err(error) => {
return Err(Cairo0CompilerVersionError::CompilerNotFound(format!(
"Failed to get {compiler} version: {error}."
return Err(Cairo0ScriptVersionError::CompilerNotFound(format!(
"Failed to get {script} version: {error}."
)));
}
};
if version.trim().replace("==", " ").split(" ").nth(1).ok_or(
Cairo0CompilerVersionError::CompilerNotFound("No compiler version found.".to_string()),
Cairo0ScriptVersionError::CompilerNotFound("No script version found.".to_string()),
)? != EXPECTED_CAIRO0_VERSION
{
return Err(Cairo0CompilerVersionError::IncorrectVersion {
compiler: compiler.to_string(),
return Err(Cairo0ScriptVersionError::IncorrectVersion {
script: script.to_string(),
existing: version,
required: EXPECTED_CAIRO0_VERSION.to_string(),
});
Expand All @@ -86,7 +89,7 @@ pub fn compile_cairo0_program(
path_to_main: PathBuf,
cairo_root_path: PathBuf,
) -> Result<Vec<u8>, Cairo0CompilerError> {
cairo0_compilers_correct_version()?;
cairo0_scripts_correct_version()?;
if !path_to_main.exists() {
return Err(Cairo0CompilerError::SourceFileNotFound(path_to_main));
}
Expand Down Expand Up @@ -116,17 +119,17 @@ pub fn compile_cairo0_program(

/// Verifies that the required Cairo0 compiler is available; panics if unavailable.
/// For use in tests only. If cairo0 compiler verification is required in business logic, use
/// `crate::cairo0_compiler::cairo0_compilers_correct_version` instead.
/// `crate::cairo0_compiler::cairo0_scripts_correct_version` instead.
#[cfg(any(test, feature = "testing"))]
pub fn verify_cairo0_compiler_deps() {
let specific_error = match cairo0_compilers_correct_version() {
let specific_error = match cairo0_scripts_correct_version() {
Ok(_) => {
return;
}
Err(Cairo0CompilerVersionError::CompilerNotFound(_)) => {
Err(Cairo0ScriptVersionError::CompilerNotFound(_)) => {
"no installed cairo-lang found".to_string()
}
Err(Cairo0CompilerVersionError::IncorrectVersion { existing, .. }) => {
Err(Cairo0ScriptVersionError::IncorrectVersion { existing, .. }) => {
format!("installed version: {existing}")
}
};
Expand All @@ -137,3 +140,23 @@ pub fn verify_cairo0_compiler_deps() {
*ENTER_VENV_INSTRUCTIONS
);
}

/// Runs the Cairo0 formatter on the input source code.
#[cfg(any(test, feature = "testing"))]
pub fn cairo0_format(unformatted: &String) -> String {
verify_cairo0_compiler_deps();

// Dump string to temporary file.
let mut temp_file = tempfile::NamedTempFile::new().unwrap();
temp_file.write_all(unformatted.as_bytes()).unwrap();

// Run formatter.
let mut command = Command::new("cairo-format");
command.arg(temp_file.path().to_str().unwrap());
let format_output = command.output().unwrap();
let stderr_output = String::from_utf8(format_output.stderr).unwrap();
assert!(format_output.status.success(), "{stderr_output}");

// Return formatted file.
String::from_utf8_lossy(format_output.stdout.as_slice()).to_string()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// Autogenerated file.

// Transaction hash prefixes.
const DECLARE_HASH_PREFIX = 'declare';
const DEPLOY_HASH_PREFIX = 'deploy';
const DEPLOY_ACCOUNT_HASH_PREFIX = 'deploy_account';
const INVOKE_HASH_PREFIX = 'invoke';
const L1_HANDLER_HASH_PREFIX = 'l1_handler';

// An entry point offset that indicates that nothing needs to be done.
// Used to implement an empty constructor.
const NOP_ENTRY_POINT_OFFSET = {NOP_ENTRY_POINT_OFFSET};

const ENTRY_POINT_TYPE_EXTERNAL = {ENTRY_POINT_TYPE_EXTERNAL};
const ENTRY_POINT_TYPE_L1_HANDLER = {ENTRY_POINT_TYPE_L1_HANDLER};
const ENTRY_POINT_TYPE_CONSTRUCTOR = {ENTRY_POINT_TYPE_CONSTRUCTOR};

const L1_HANDLER_VERSION = {L1_HANDLER_VERSION};
const L1_HANDLER_L2_GAS_MAX_AMOUNT = {L1_HANDLER_L2_GAS_MAX_AMOUNT};

// Upper bound on the number of elements in a Sierra array.
const SIERRA_ARRAY_LEN_BOUND = {SIERRA_ARRAY_LEN_BOUND}; // 2^32

// get_selector_from_name('constructor').
const CONSTRUCTOR_ENTRY_POINT_SELECTOR = ({CONSTRUCTOR_ENTRY_POINT_SELECTOR});

// get_selector_from_name('__execute__').
const EXECUTE_ENTRY_POINT_SELECTOR = ({EXECUTE_ENTRY_POINT_SELECTOR});

// get_selector_from_name('__validate__').
const VALIDATE_ENTRY_POINT_SELECTOR = ({VALIDATE_ENTRY_POINT_SELECTOR});

// get_selector_from_name('__validate_declare__').
const VALIDATE_DECLARE_ENTRY_POINT_SELECTOR = ({VALIDATE_DECLARE_ENTRY_POINT_SELECTOR});

// get_selector_from_name('__validate_deploy__').
const VALIDATE_DEPLOY_ENTRY_POINT_SELECTOR = ({VALIDATE_DEPLOY_ENTRY_POINT_SELECTOR});

// get_selector_from_name('transfer').
const TRANSFER_ENTRY_POINT_SELECTOR = ({TRANSFER_ENTRY_POINT_SELECTOR});

const DEFAULT_ENTRY_POINT_SELECTOR = {DEFAULT_ENTRY_POINT_SELECTOR};

// OS reserved contract addresses.

// This contract stores the block number -> block hash mapping.
const BLOCK_HASH_CONTRACT_ADDRESS = {BLOCK_HASH_CONTRACT_ADDRESS};
// This contract stores the aliases mapping used for stateful compression.
const ALIAS_CONTRACT_ADDRESS = {ALIAS_CONTRACT_ADDRESS};
// Future reserved contract address.
const RESERVED_CONTRACT_ADDRESS = {RESERVED_CONTRACT_ADDRESS};
// The block number -> block hash mapping is written for the current block number minus this number.
const STORED_BLOCK_HASH_BUFFER = {STORED_BLOCK_HASH_BUFFER};

// Gas constants.

const STEP_GAS_COST = {STEP_GAS_COST};
const RANGE_CHECK_GAS_COST = {RANGE_CHECK_GAS_COST};
const RANGE_CHECK96_GAS_COST = {RANGE_CHECK96_GAS_COST};
const KECCAK_BUILTIN_GAS_COST = {KECCAK_BUILTIN_GAS_COST};
const PEDERSEN_GAS_COST = {PEDERSEN_GAS_COST};
const BITWISE_BUILTIN_GAS_COST = {BITWISE_BUILTIN_GAS_COST};
const ECOP_GAS_COST = {ECOP_GAS_COST};
const POSEIDON_GAS_COST = {POSEIDON_GAS_COST};
const ADD_MOD_GAS_COST = {ADD_MOD_GAS_COST};
const MUL_MOD_GAS_COST = {MUL_MOD_GAS_COST};
const ECDSA_GAS_COST = {ECDSA_GAS_COST};
const MEMORY_HOLE_GAS_COST = {MEMORY_HOLE_GAS_COST};

const DEFAULT_INITIAL_GAS_COST = {DEFAULT_INITIAL_GAS_COST};
const VALIDATE_MAX_SIERRA_GAS = {VALIDATE_MAX_SIERRA_GAS};
const EXECUTE_MAX_SIERRA_GAS = {EXECUTE_MAX_SIERRA_GAS};
const DEFAULT_INITIAL_GAS_COST_NO_L2 = VALIDATE_MAX_SIERRA_GAS + EXECUTE_MAX_SIERRA_GAS;

// Compiler gas costs.

// The initial budget at an entry point. This needs to be high enough to cover the initial get_gas.
// The entry point may refund whatever remains from the initial budget.
const ENTRY_POINT_INITIAL_BUDGET = {ENTRY_POINT_INITIAL_BUDGET};
// The gas cost of each syscall libfunc (this value is hard-coded by the compiler).
// This needs to be high enough to cover OS costs in the case of failure due to out of gas.
const SYSCALL_BASE_GAS_COST = {SYSCALL_BASE_GAS_COST};

// Syscall gas costs.
const CALL_CONTRACT_GAS_COST = {CALL_CONTRACT_GAS_COST};
const DEPLOY_GAS_COST = {DEPLOY_GAS_COST};
const DEPLOY_CALLDATA_FACTOR_GAS_COST = {DEPLOY_CALLDATA_FACTOR_GAS_COST};
const GET_BLOCK_HASH_GAS_COST = {GET_BLOCK_HASH_GAS_COST};
const GET_CLASS_HASH_AT_GAS_COST = {GET_CLASS_HASH_AT_GAS_COST};
const GET_EXECUTION_INFO_GAS_COST = {GET_EXECUTION_INFO_GAS_COST};
const LIBRARY_CALL_GAS_COST = {LIBRARY_CALL_GAS_COST};
const REPLACE_CLASS_GAS_COST = {REPLACE_CLASS_GAS_COST};
// TODO(Yoni, 1/1/2026): take into account Patricia updates and dict squash.
const STORAGE_READ_GAS_COST = {STORAGE_READ_GAS_COST};
const STORAGE_WRITE_GAS_COST = {STORAGE_WRITE_GAS_COST};
const EMIT_EVENT_GAS_COST = {EMIT_EVENT_GAS_COST};
const SEND_MESSAGE_TO_L1_GAS_COST = {SEND_MESSAGE_TO_L1_GAS_COST};
const META_TX_V0_GAS_COST = {META_TX_V0_GAS_COST};
const META_TX_V0_CALLDATA_FACTOR_GAS_COST = {META_TX_V0_CALLDATA_FACTOR_GAS_COST};

// Note the the following costs include `SYSCALL_BASE_GAS_COST` implicitly.
const SECP256K1_ADD_GAS_COST = {SECP256K1_ADD_GAS_COST};
const SECP256K1_GET_POINT_FROM_X_GAS_COST = {SECP256K1_GET_POINT_FROM_X_GAS_COST};
const SECP256K1_GET_XY_GAS_COST = {SECP256K1_GET_XY_GAS_COST};
const SECP256K1_MUL_GAS_COST = {SECP256K1_MUL_GAS_COST};
const SECP256K1_NEW_GAS_COST = {SECP256K1_NEW_GAS_COST};
const SECP256R1_ADD_GAS_COST = {SECP256R1_ADD_GAS_COST};
const SECP256R1_GET_POINT_FROM_X_GAS_COST = {SECP256R1_GET_POINT_FROM_X_GAS_COST};
const SECP256R1_GET_XY_GAS_COST = {SECP256R1_GET_XY_GAS_COST};
const SECP256R1_MUL_GAS_COST = {SECP256R1_MUL_GAS_COST};
const SECP256R1_NEW_GAS_COST = {SECP256R1_NEW_GAS_COST};

const KECCAK_GAS_COST = {KECCAK_GAS_COST};
const KECCAK_ROUND_COST_GAS_COST = {KECCAK_ROUND_COST_GAS_COST};
const SHA256_PROCESS_BLOCK_GAS_COST = {SHA256_PROCESS_BLOCK_GAS_COST};

// Cairo 1.0 error codes.
const ERROR_BLOCK_NUMBER_OUT_OF_RANGE = {ERROR_BLOCK_NUMBER_OUT_OF_RANGE};
const ERROR_OUT_OF_GAS = {ERROR_OUT_OF_GAS};
const ERROR_ENTRY_POINT_FAILED = {ERROR_ENTRY_POINT_FAILED};
const ERROR_ENTRY_POINT_NOT_FOUND = {ERROR_ENTRY_POINT_NOT_FOUND};
const ERROR_INVALID_INPUT_LEN = {ERROR_INVALID_INPUT_LEN};
const ERROR_INVALID_ARGUMENT = {ERROR_INVALID_ARGUMENT};

// The expected return value of the `__validate*__` functions of a Cairo 1.0 account contract.
const VALIDATED = {VALIDATED};

// Resources
const L1_GAS = {L1_GAS};
const L2_GAS = {L2_GAS};
const L1_DATA_GAS = {L1_DATA_GAS};
const L1_GAS_INDEX = {L1_GAS_INDEX};
const L2_GAS_INDEX = {L2_GAS_INDEX};
const L1_DATA_GAS_INDEX = {L1_DATA_GAS_INDEX};

// Round down the block number and timestamp when queried inside `__validate__`.
const VALIDATE_BLOCK_NUMBER_ROUNDING = {VALIDATE_BLOCK_NUMBER_ROUNDING};
const VALIDATE_TIMESTAMP_ROUNDING = {VALIDATE_TIMESTAMP_ROUNDING};

// List of CairoZero account contracts that require the transaction version to be 1.
{V1_BOUND_ACCOUNTS_CAIRO0}

// List of Cairo1 account contracts that require the transaction version to be 1.
{V1_BOUND_ACCOUNTS_CAIRO1}

// Max transaction tip for which a v3 transaction can be replaced with a v1 transaction.
const V1_BOUND_ACCOUNTS_MAX_TIP = {V1_BOUND_ACCOUNTS_MAX_TIP};

// List of Cairo1 account contracts that require the resource bounds to exclude data gas.
{DATA_GAS_ACCOUNTS}
Loading
Loading